summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2025-07-02 14:18:21 -0700
committerGitHub <noreply@github.com>2025-07-02 21:18:21 +0000
commit3e1dd65adff0873e0385040c5c0a003eda83de3b (patch)
tree8e8c3902d3e96e2e39346d4a306d04771f3ca121 /source
parent54a5d7f0056b4a846c790e7e019b9b5e74f76a98 (diff)
[HLSL, SPIRV_1_3] Hoist OpSelect returning a composite into `if`/`else` (#7594)
* Emit a var and hoist out OpSelect if it returns a composite.
* Clean up comment.
* Address review: check for the SPIR-V version in the SPIR-V context; use a phi node instead of a var; move insts using a list (not in-place modification).
* Format code.
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-compiler.h41
-rw-r--r--source/slang/slang-emit.cpp63
-rw-r--r--source/slang/slang-ir-insts.h11
-rw-r--r--source/slang/slang-ir-legalize-composite-select.cpp79
-rw-r--r--source/slang/slang-ir-legalize-composite-select.h11
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp27
6 files changed, 188 insertions, 44 deletions
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 57c20aed2..60f6cc92f 100644
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -2947,6 +2947,34 @@ class ExtensionTracker : public RefObject
public:
};
+/// Flags recording which optional lowering/legalization passes an IR module needs.
+///
+/// `calcRequiredLoweringPassSet` sets a flag when it sees an instruction that the
+/// corresponding pass handles; the back-end then runs only the flagged passes.
+///
+/// Every flag defaults to `false`, so a set that has not been populated yet
+/// requests no optional passes instead of holding indeterminate values. Resetting
+/// with `= {}` still clears every flag.
+struct RequiredLoweringPassSet
+{
+    bool debugInfo = false;
+    bool resultType = false;
+    bool optionalType = false;
+    bool enumType = false;
+    bool combinedTextureSamplers = false;
+    bool reinterpret = false;
+    bool generics = false;
+    bool bindExistential = false;
+    bool autodiff = false;
+    bool derivativePyBindWrapper = false;
+    bool bitcast = false;
+    bool existentialTypeLayout = false;
+    bool bindingQuery = false;
+    bool meshOutput = false;
+    bool higherOrderFunc = false;
+    bool globalVaryingVar = false;
+    bool glslSSBO = false;
+    bool byteAddressBuffer = false;
+    bool dynamicResource = false;
+    bool dynamicResourceHeap = false;
+    bool resolveVaryingInputRef = false;
+    bool specializeStageSwitch = false;
+    bool missingReturn = false;
+    /// Set when a `kIROp_Select` whose result type is neither scalar nor vector is seen.
+    bool nonVectorCompositeSelect = false;
+};
+
/// A context for code generation in the compiler back-end
struct CodeGenContext
{
@@ -3076,11 +3104,24 @@ public:
// This is a no-op if modules are not precompiled.
bool shouldSkipDownstreamLinking();
+ RequiredLoweringPassSet& getRequiredLoweringPassSet() { return m_requiredLoweringPassSet; }
+
protected:
CodeGenTarget m_targetFormat = CodeGenTarget::Unknown;
Profile m_targetProfile;
ExtensionTracker* m_extensionTracker = nullptr;
+ // To improve the performance of our backend, we will try to avoid running
+ // passes related to features not used in the user code.
+ // To do so, we will scan the IR module once, and determine which passes are needed
+ // based on the instructions used in the IR module.
+ // This will allow us to skip running passes that are not needed, without having to
+ // run all the passes only to find out that no work is needed.
+ // This is especially important for the performance of the backend, as some passes
+ // have an initialization cost (such as building reference graphs or DOM trees) that
+ // can be expensive.
+ RequiredLoweringPassSet m_requiredLoweringPassSet;
+
/// Will output assembly as well as the artifact if appropriate for the artifact type for
/// assembly output and conversion is possible
void _dumpIntermediateMaybeWithAssembly(IArtifact* artifact);
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index db8a9ba61..aaefdaee0 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -56,6 +56,7 @@
#include "slang-ir-layout.h"
#include "slang-ir-legalize-array-return-type.h"
#include "slang-ir-legalize-binary-operator.h"
+#include "slang-ir-legalize-composite-select.h"
#include "slang-ir-legalize-empty-array.h"
#include "slang-ir-legalize-global-values.h"
#include "slang-ir-legalize-image-subscript.h"
@@ -300,43 +301,6 @@ struct LinkingAndOptimizationOptions
CLikeSourceEmitter* sourceEmitter = nullptr;
};
-// To improve the performance of our backend, we will try to avoid running
-// passes related to features not used in the user code.
-// To do so, we will scan the IR module once, and determine which passes are needed
-// based on the instructions used in the IR module.
-// This will allow us to skip running passes that are not needed, without having to
-// run all the passes only to find out that no work is needed.
-// This is especially important for the performance of the backend, as some passes
-// have an initialization cost (such as building reference graphs or DOM trees) that
-// can be expensive.
-//
-struct RequiredLoweringPassSet
-{
- bool debugInfo;
- bool resultType;
- bool optionalType;
- bool enumType;
- bool combinedTextureSamplers;
- bool reinterpret;
- bool generics;
- bool bindExistential;
- bool autodiff;
- bool derivativePyBindWrapper;
- bool bitcast;
- bool existentialTypeLayout;
- bool bindingQuery;
- bool meshOutput;
- bool higherOrderFunc;
- bool globalVaryingVar;
- bool glslSSBO;
- bool byteAddressBuffer;
- bool dynamicResource;
- bool dynamicResourceHeap;
- bool resolveVaryingInputRef;
- bool specializeStageSwitch;
- bool missingReturn;
-};
-
// Scan the IR module and determine which lowering/legalization passes are needed based
// on the instructions we see.
//
@@ -471,6 +435,10 @@ void calcRequiredLoweringPassSet(
case kIROp_MissingReturn:
result.missingReturn = true;
break;
+ case kIROp_Select:
+ if (!isScalarOrVectorType(inst->getFullType()))
+ result.nonVectorCompositeSelect = true;
+ break;
}
if (!result.generics || !result.existentialTypeLayout)
{
@@ -492,7 +460,10 @@ void calcRequiredLoweringPassSet(
}
for (auto child : inst->getDecorationsAndChildren())
{
- calcRequiredLoweringPassSet(result, codeGenContext, child);
+ calcRequiredLoweringPassSet(
+ codeGenContext->getRequiredLoweringPassSet(),
+ codeGenContext,
+ child);
}
}
@@ -764,7 +735,8 @@ Result linkAndOptimizeIR(
dumpIRIfEnabled(codeGenContext, irModule, "POST IR VALIDATION");
// Scan the IR module and determine which lowering/legalization passes are needed.
- RequiredLoweringPassSet requiredLoweringPassSet = {};
+ RequiredLoweringPassSet& requiredLoweringPassSet = codeGenContext->getRequiredLoweringPassSet();
+ requiredLoweringPassSet = {};
calcRequiredLoweringPassSet(requiredLoweringPassSet, codeGenContext, irModule->getModuleInst());
// Debug info is added by the front-end, and therefore needs to be stripped out by targets that
@@ -1083,6 +1055,18 @@ Result linkAndOptimizeIR(
if (requiredLoweringPassSet.optionalType)
lowerOptionalType(irModule, sink);
+ if (requiredLoweringPassSet.nonVectorCompositeSelect)
+ {
+ switch (target)
+ {
+ case CodeGenTarget::HLSL:
+ legalizeNonVectorCompositeSelect(irModule);
+ break;
+ default:
+ break;
+ }
+ }
+
switch (target)
{
case CodeGenTarget::CPPSource:
@@ -1097,7 +1081,6 @@ Result linkAndOptimizeIR(
break;
}
- requiredLoweringPassSet = {};
calcRequiredLoweringPassSet(requiredLoweringPassSet, codeGenContext, irModule->getModuleInst());
switch (target)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index afe06f9a1..b0a0c74f9 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2423,6 +2423,17 @@ struct IRImageStore : IRInst
// Terminators
FIDDLE()
+struct IRSelect : IRInst
+{
+ FIDDLE(leafInst());
+
+ // Operand layout: 0 is the condition, 1 is the value used for the true
+ // case, and 2 is the value used for the false case.
+ IRInst* getCondition() { return getOperand(0); }
+ IRInst* getTrueResult() { return getOperand(1); }
+ IRInst* getFalseResult() { return getOperand(2); }
+};
+
+
+FIDDLE()
struct IRReturn : IRTerminatorInst
{
FIDDLE(leafInst())
diff --git a/source/slang/slang-ir-legalize-composite-select.cpp b/source/slang/slang-ir-legalize-composite-select.cpp
new file mode 100644
index 000000000..1b2ba0670
--- /dev/null
+++ b/source/slang/slang-ir-legalize-composite-select.cpp
@@ -0,0 +1,79 @@
+#include "slang-ir-legalize-composite-select.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-legalize-varying-params.h"
+#include "slang-ir-specialize-address-space.h"
+#include "slang-ir-util.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+/// Replaces a single `select` with explicit `if`/`else` control flow.
+///
+/// A conditional branch on the select's condition is emitted in front of the
+/// select. The true and false blocks each branch to a shared "after" block,
+/// passing their respective operand as the branch argument. The select and
+/// every instruction that followed it are moved into the "after" block, whose
+/// new parameter takes over all uses of the select; the select is then deleted.
+///
+/// @param builder    Builder used to emit the new blocks and instructions. Its
+///                   insertion point is overwritten by this function.
+/// @param selectInst The select to replace; must not be null.
+void legalizeASingleNonVectorCompositeSelect(IRBuilder& builder, IRSelect* selectInst)
+{
+    SLANG_ASSERT(selectInst);
+
+    auto resultType = selectInst->getFullType();
+    auto trueResult = selectInst->getTrueResult();
+    auto falseResult = selectInst->getFalseResult();
+
+    // The caller reuses one builder for every select it rewrites, so on entry the
+    // insertion point is either unset or still inside the "after" block of the
+    // previously rewritten select. Anchor it at this select so the conditional
+    // branch is emitted immediately before it.
+    builder.setInsertBefore(selectInst);
+
+    IRBlock* trueBlock = nullptr;
+    IRBlock* falseBlock = nullptr;
+    IRBlock* afterBlock = nullptr;
+    builder.emitIfElseWithBlocks(selectInst->getCondition(), trueBlock, falseBlock, afterBlock);
+
+    // Each arm only forwards its operand to the merge block.
+    builder.setInsertInto(trueBlock);
+    builder.emitBranch(afterBlock, 1, &trueResult);
+
+    builder.setInsertInto(falseBlock);
+    builder.emitBranch(afterBlock, 1, &falseResult);
+
+    // Collect the select and everything after it first, then move the collected
+    // instructions, so the sibling chain is not modified while it is being walked.
+    List<IRInst*> instsToMove;
+    instsToMove.reserve(15);
+    for (IRInst* inst = selectInst; inst; inst = inst->getNextInst())
+        instsToMove.add(inst);
+    for (auto inst : instsToMove)
+        afterBlock->insertAtEnd(inst);
+
+    // The merge block's parameter receives whichever operand the taken arm passed.
+    builder.setInsertInto(afterBlock);
+    auto param = builder.emitParam(resultType);
+    selectInst->replaceUsesWith(param);
+
+    // The select no longer has uses and can be dropped.
+    selectInst->removeAndDeallocate();
+}
+/// Visits every `IRFunc` that is a direct child of the module and rewrites each
+/// `select` whose result type is neither scalar nor vector into `if`/`else`
+/// branches (the same approach glslang takes).
+void legalizeNonVectorCompositeSelect(IRModule* module)
+{
+    IRBuilder builder(module);
+    for (auto globalInst : module->getModuleInst()->getChildren())
+    {
+        if (auto func = as<IRFunc>(globalInst))
+        {
+            for (auto block : func->getBlocks())
+            {
+                IRInst* current = block->getFirstInst();
+                while (current)
+                {
+                    // Fetch the successor up front: rewriting removes `current`.
+                    IRInst* following = current->getNextInst();
+
+                    // Replace OpSelect with if/else branch (same process as glslang)
+                    const bool isCompositeSelect =
+                        current->getOp() == kIROp_Select &&
+                        !isScalarOrVectorType(current->getFullType());
+                    if (isCompositeSelect)
+                        legalizeASingleNonVectorCompositeSelect(builder, as<IRSelect>(current));
+
+                    current = following;
+                }
+            }
+        }
+    }
+}
+} // namespace Slang
diff --git a/source/slang/slang-ir-legalize-composite-select.h b/source/slang/slang-ir-legalize-composite-select.h
new file mode 100644
index 000000000..dc11ccb21
--- /dev/null
+++ b/source/slang/slang-ir-legalize-composite-select.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "slang-compiler.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+class DiagnosticSink;
+
+/// For every function that is a direct child of `module`, replaces each select
+/// instruction whose result type is neither a scalar nor a vector with
+/// `if`/`else` branches that merge into a block parameter.
+void legalizeNonVectorCompositeSelect(IRModule* module);
+} // namespace Slang
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 2a5701e33..20b721a20 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -12,6 +12,7 @@
#include "slang-ir-inline.h"
#include "slang-ir-insts.h"
#include "slang-ir-layout.h"
+#include "slang-ir-legalize-composite-select.h"
#include "slang-ir-legalize-global-values.h"
#include "slang-ir-legalize-mesh-outputs.h"
#include "slang-ir-loop-unroll.h"
@@ -39,6 +40,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRModule* m_module;
+ CodeGenContext* m_codeGenContext;
+
DiagnosticSink* m_sink;
struct LoweredStructuredBufferTypeInfo
@@ -202,8 +205,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
SPIRVLegalizationContext(
SPIRVEmitSharedContext* sharedContext,
IRModule* module,
+ CodeGenContext* codeGenContext,
DiagnosticSink* sink)
- : m_sharedContext(sharedContext), m_module(module), m_sink(sink)
+ : m_sharedContext(sharedContext)
+ , m_module(module)
+ , m_codeGenContext(codeGenContext)
+ , m_sink(sink)
{
}
@@ -2208,6 +2215,11 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// invalid SPIR-V.
bool skipFuncParamValidation = false;
validateAtomicOperations(skipFuncParamValidation, m_sink, m_module->getModuleInst());
+
+ // If older than spirv 1.4, legalize OpSelect returning non-vector-composites
+ if (m_codeGenContext->getRequiredLoweringPassSet().nonVectorCompositeSelect &&
+ !m_sharedContext->isSpirv14OrLater())
+ legalizeNonVectorCompositeSelect(m_module);
}
void updateFunctionTypes()
@@ -2281,9 +2293,16 @@ SpvSnippet* SPIRVEmitSharedContext::getParsedSpvSnippet(IRTargetIntrinsicDecorat
return snippet;
}
-void legalizeSPIRV(SPIRVEmitSharedContext* sharedContext, IRModule* module, DiagnosticSink* sink)
+void legalizeSPIRV(
+ SPIRVEmitSharedContext* sharedContext,
+ IRModule* module,
+ CodeGenContext* codeGenContext)
{
- SPIRVLegalizationContext context(sharedContext, module, sink);
+ SPIRVLegalizationContext context(
+ sharedContext,
+ module,
+ codeGenContext,
+ codeGenContext->getSink());
context.processModule();
}
@@ -2424,7 +2443,7 @@ void legalizeIRForSPIRV(
CodeGenContext* codeGenContext)
{
SLANG_UNUSED(entryPoints);
- legalizeSPIRV(context, module, codeGenContext->getSink());
+ legalizeSPIRV(context, module, codeGenContext);
simplifyIRForSpirvLegalization(context->m_targetProgram, codeGenContext->getSink(), module);
buildEntryPointReferenceGraph(context->m_referencingEntryPoints, module);
insertFragmentShaderInterlock(context, module);