diff options
| author | ArielG-NV <159081215+ArielG-NV@users.noreply.github.com> | 2025-07-02 14:18:21 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-02 21:18:21 +0000 |
| commit | 3e1dd65adff0873e0385040c5c0a003eda83de3b (patch) | |
| tree | 8e8c3902d3e96e2e39346d4a306d04771f3ca121 | |
| parent | 54a5d7f0056b4a846c790e7e019b9b5e74f76a98 (diff) | |
[HLSL, SPIRV_1_3] Hoist OpSelect returning a composite into `if`/`else` (#7594)
* emit var and hoist out OpSelect if Composite
* cleanup comment
* address review
check for version in spv context
use phi node instead of using var
move inst's using a list (not in-place modification)
* format code
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
| -rw-r--r-- | source/slang/slang-compiler.h | 41 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 63 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-composite-select.cpp | 79 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-composite-select.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 27 | ||||
| -rw-r--r-- | tests/bugs/op-select-return-composite.slang | 34 |
7 files changed, 222 insertions, 44 deletions
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 57c20aed2..60f6cc92f 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -2947,6 +2947,34 @@ class ExtensionTracker : public RefObject public: }; +struct RequiredLoweringPassSet +{ + bool debugInfo; + bool resultType; + bool optionalType; + bool enumType; + bool combinedTextureSamplers; + bool reinterpret; + bool generics; + bool bindExistential; + bool autodiff; + bool derivativePyBindWrapper; + bool bitcast; + bool existentialTypeLayout; + bool bindingQuery; + bool meshOutput; + bool higherOrderFunc; + bool globalVaryingVar; + bool glslSSBO; + bool byteAddressBuffer; + bool dynamicResource; + bool dynamicResourceHeap; + bool resolveVaryingInputRef; + bool specializeStageSwitch; + bool missingReturn; + bool nonVectorCompositeSelect; +}; + /// A context for code generation in the compiler back-end struct CodeGenContext { @@ -3076,11 +3104,24 @@ public: // This is a no-op if modules are not precompiled. bool shouldSkipDownstreamLinking(); + RequiredLoweringPassSet& getRequiredLoweringPassSet() { return m_requiredLoweringPassSet; } + protected: CodeGenTarget m_targetFormat = CodeGenTarget::Unknown; Profile m_targetProfile; ExtensionTracker* m_extensionTracker = nullptr; + // To improve the performance of our backend, we will try to avoid running + // passes related to features not used in the user code. + // To do so, we will scan the IR module once, and determine which passes are needed + // based on the instructions used in the IR module. + // This will allow us to skip running passes that are not needed, without having to + // run all the passes only to find out that no work is needed. + // This is especially important for the performance of the backend, as some passes + // have an initialization cost (such as building reference graphs or DOM trees) that + // can be expensive. + RequiredLoweringPassSet m_requiredLoweringPassSet; + /// Will output assembly as well as the artifact if appropriate for the artifact type for /// assembly output and conversion is possible void _dumpIntermediateMaybeWithAssembly(IArtifact* artifact); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index db8a9ba61..aaefdaee0 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -56,6 +56,7 @@ #include "slang-ir-layout.h" #include "slang-ir-legalize-array-return-type.h" #include "slang-ir-legalize-binary-operator.h" +#include "slang-ir-legalize-composite-select.h" #include "slang-ir-legalize-empty-array.h" #include "slang-ir-legalize-global-values.h" #include "slang-ir-legalize-image-subscript.h" @@ -300,43 +301,6 @@ struct LinkingAndOptimizationOptions CLikeSourceEmitter* sourceEmitter = nullptr; }; -// To improve the performance of our backend, we will try to avoid running -// passes related to features not used in the user code. -// To do so, we will scan the IR module once, and determine which passes are needed -// based on the instructions used in the IR module. -// This will allow us to skip running passes that are not needed, without having to -// run all the passes only to find out that no work is needed. -// This is especially important for the performance of the backend, as some passes -// have an initialization cost (such as building reference graphs or DOM trees) that -// can be expensive. -// -struct RequiredLoweringPassSet -{ - bool debugInfo; - bool resultType; - bool optionalType; - bool enumType; - bool combinedTextureSamplers; - bool reinterpret; - bool generics; - bool bindExistential; - bool autodiff; - bool derivativePyBindWrapper; - bool bitcast; - bool existentialTypeLayout; - bool bindingQuery; - bool meshOutput; - bool higherOrderFunc; - bool globalVaryingVar; - bool glslSSBO; - bool byteAddressBuffer; - bool dynamicResource; - bool dynamicResourceHeap; - bool resolveVaryingInputRef; - bool specializeStageSwitch; - bool missingReturn; -}; - // Scan the IR module and determine which lowering/legalization passes are needed based // on the instructions we see. // @@ -471,6 +435,10 @@ void calcRequiredLoweringPassSet( case kIROp_MissingReturn: result.missingReturn = true; break; + case kIROp_Select: + if (!isScalarOrVectorType(inst->getFullType())) + result.nonVectorCompositeSelect = true; + break; } if (!result.generics || !result.existentialTypeLayout) { @@ -492,7 +460,10 @@ void calcRequiredLoweringPassSet( } for (auto child : inst->getDecorationsAndChildren()) { - calcRequiredLoweringPassSet(result, codeGenContext, child); + calcRequiredLoweringPassSet( + codeGenContext->getRequiredLoweringPassSet(), + codeGenContext, + child); } } @@ -764,7 +735,8 @@ Result linkAndOptimizeIR( dumpIRIfEnabled(codeGenContext, irModule, "POST IR VALIDATION"); // Scan the IR module and determine which lowering/legalization passes are needed. - RequiredLoweringPassSet requiredLoweringPassSet = {}; + RequiredLoweringPassSet& requiredLoweringPassSet = codeGenContext->getRequiredLoweringPassSet(); + requiredLoweringPassSet = {}; calcRequiredLoweringPassSet(requiredLoweringPassSet, codeGenContext, irModule->getModuleInst()); // Debug info is added by the front-end, and therefore needs to be stripped out by targets that @@ -1083,6 +1055,18 @@ Result linkAndOptimizeIR( if (requiredLoweringPassSet.optionalType) lowerOptionalType(irModule, sink); + if (requiredLoweringPassSet.nonVectorCompositeSelect) + { + switch (target) + { + case CodeGenTarget::HLSL: + legalizeNonVectorCompositeSelect(irModule); + break; + default: + break; + } + } + switch (target) { case CodeGenTarget::CPPSource: @@ -1097,7 +1081,6 @@ Result linkAndOptimizeIR( break; } - requiredLoweringPassSet = {}; calcRequiredLoweringPassSet(requiredLoweringPassSet, codeGenContext, irModule->getModuleInst()); switch (target) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index afe06f9a1..b0a0c74f9 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2423,6 +2423,17 @@ struct IRImageStore : IRInst // Terminators FIDDLE() +struct IRSelect : IRInst +{ + FIDDLE(leafInst()); + + IRInst* getCondition() { return getOperand(0); } + IRInst* getTrueResult() { return getOperand(1); } + IRInst* getFalseResult() { return getOperand(2); } +}; + + +FIDDLE() struct IRReturn : IRTerminatorInst { FIDDLE(leafInst()) diff --git a/source/slang/slang-ir-legalize-composite-select.cpp b/source/slang/slang-ir-legalize-composite-select.cpp new file mode 100644 index 000000000..1b2ba0670 --- /dev/null +++ b/source/slang/slang-ir-legalize-composite-select.cpp @@ -0,0 +1,79 @@ +#include "slang-ir-legalize-composite-select.h" + +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" +#include "slang-ir-legalize-varying-params.h" +#include "slang-ir-specialize-address-space.h" +#include "slang-ir-util.h" +#include "slang-ir.h" + +namespace Slang +{ +void legalizeASingleNonVectorCompositeSelect(IRBuilder& builder, IRSelect* selectInst) +{ + SLANG_ASSERT(selectInst); + + auto resultType = selectInst->getFullType(); + auto trueResult = selectInst->getTrueResult(); + auto falseResult = selectInst->getFalseResult(); + + IRBlock* trueBlock; + IRBlock* falseBlock; + IRBlock* afterBlock; + builder.emitIfElseWithBlocks(selectInst->getCondition(), trueBlock, falseBlock, afterBlock); + + // Generate if-select-true and else-select-false clause + builder.setInsertInto(trueBlock); + builder.emitBranch(afterBlock, 1, &trueResult); + + builder.setInsertInto(falseBlock); + builder.emitBranch(afterBlock, 1, &falseResult); + + // Move everything after the OpSelect into the "after" block + List<IRInst*> instsToMove; + instsToMove.reserve(15); + IRInst* nextInst = selectInst; + while (nextInst) + { + instsToMove.add(nextInst); + nextInst = nextInst->getNextInst(); + } + for (auto i : instsToMove) + afterBlock->insertAtEnd(i); + + // Merge result of branches into param + builder.setInsertInto(afterBlock); + auto param = builder.emitParam(resultType); + selectInst->replaceUsesWith(param); + + // Clean up + selectInst->removeAndDeallocate(); +} +void legalizeNonVectorCompositeSelect(IRModule* module) +{ + IRBuilder builder(module); + for (auto globalInst : module->getModuleInst()->getChildren()) + { + auto func = as<IRFunc>(globalInst); + if (!func) + continue; + for (auto block : func->getBlocks()) + { + auto inst = block->getFirstInst(); + IRInst* next; + for (; inst; inst = next) + { + next = inst->getNextInst(); + switch (inst->getOp()) + { + case kIROp_Select: + // Replace OpSelect with if/else branch (same process as glslang) + if (!isScalarOrVectorType(inst->getFullType())) + legalizeASingleNonVectorCompositeSelect(builder, as<IRSelect>(inst)); + continue; + } + } + } + } +} +} // namespace Slang diff --git a/source/slang/slang-ir-legalize-composite-select.h b/source/slang/slang-ir-legalize-composite-select.h new file mode 100644 index 000000000..dc11ccb21 --- /dev/null +++ b/source/slang/slang-ir-legalize-composite-select.h @@ -0,0 +1,11 @@ +#pragma once + +#include "slang-compiler.h" +#include "slang-ir.h" + +namespace Slang +{ +class DiagnosticSink; + +void legalizeNonVectorCompositeSelect(IRModule* module); +} // namespace Slang diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 2a5701e33..20b721a20 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -12,6 +12,7 @@ #include "slang-ir-inline.h" #include "slang-ir-insts.h" #include "slang-ir-layout.h" +#include "slang-ir-legalize-composite-select.h" #include "slang-ir-legalize-global-values.h" #include "slang-ir-legalize-mesh-outputs.h" #include "slang-ir-loop-unroll.h" @@ -39,6 +40,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRModule* m_module; + CodeGenContext* m_codeGenContext; + DiagnosticSink* m_sink; struct LoweredStructuredBufferTypeInfo @@ -202,8 +205,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase SPIRVLegalizationContext( SPIRVEmitSharedContext* sharedContext, IRModule* module, + CodeGenContext* codeGenContext, DiagnosticSink* sink) - : m_sharedContext(sharedContext), m_module(module), m_sink(sink) + : m_sharedContext(sharedContext) + , m_module(module) + , m_codeGenContext(codeGenContext) + , m_sink(sink) { } @@ -2208,6 +2215,11 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // invalid SPIR-V. bool skipFuncParamValidation = false; validateAtomicOperations(skipFuncParamValidation, m_sink, m_module->getModuleInst()); + + // If older than spirv 1.4, legalize OpSelect returning non-vector-composites + if (m_codeGenContext->getRequiredLoweringPassSet().nonVectorCompositeSelect && + !m_sharedContext->isSpirv14OrLater()) + legalizeNonVectorCompositeSelect(m_module); } void updateFunctionTypes() @@ -2281,9 +2293,16 @@ SpvSnippet* SPIRVEmitSharedContext::getParsedSpvSnippet(IRTargetIntrinsicDecorat return snippet; } -void legalizeSPIRV(SPIRVEmitSharedContext* sharedContext, IRModule* module, DiagnosticSink* sink) +void legalizeSPIRV( + SPIRVEmitSharedContext* sharedContext, + IRModule* module, + CodeGenContext* codeGenContext) { - SPIRVLegalizationContext context(sharedContext, module, sink); + SPIRVLegalizationContext context( + sharedContext, + module, + codeGenContext, + codeGenContext->getSink()); context.processModule(); } @@ -2424,7 +2443,7 @@ void legalizeIRForSPIRV( CodeGenContext* codeGenContext) { SLANG_UNUSED(entryPoints); - legalizeSPIRV(context, module, codeGenContext->getSink()); + legalizeSPIRV(context, module, codeGenContext); simplifyIRForSpirvLegalization(context->m_targetProgram, codeGenContext->getSink(), module); buildEntryPointReferenceGraph(context->m_referencingEntryPoints, module); insertFragmentShaderInterlock(context, module); diff --git a/tests/bugs/op-select-return-composite.slang b/tests/bugs/op-select-return-composite.slang new file mode 100644 index 000000000..eb453d457 --- /dev/null +++ b/tests/bugs/op-select-return-composite.slang @@ -0,0 +1,34 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-d3d12 -output-using-type -use-dxil +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -output-using-type -profile spirv_1_3 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +struct CompositeType +{ + __init(int dataIn) + { + data1 = dataIn; + data2 = dataIn; + } + int data1; + float data2; +} + +[numthreads(1,1,1)] +void computeMain(){ + + CompositeType composite = CompositeType(-1); + if (outputBuffer[0] == 0) + { + composite = outputBuffer[1] > -1 ? CompositeType(1) : CompositeType(-1); + } + outputBuffer[2] = composite.data1; + outputBuffer[3] = (int)composite.data2; +} + +//BUF: 0 +//BUF: 0 +//BUF: 1 +//BUF: 1 |
