diff options
| author | jarcherNV <jarcher@nvidia.com> | 2025-04-11 14:51:48 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-11 21:51:48 +0000 |
| commit | 61a6c211b1587a7b9ed6a24ae1ba6fe0600c80d8 (patch) | |
| tree | b67322ca56975223cc2eb897acc29155928128fd | |
| parent | 88a180ba0aa57b2d0fb4956005db2ea73dc73420 (diff) | |
Add flag to hoist instructions (#6740)
This fixes issue #6654.
Only hoist loop-invariant instructions when simplification is invoked from
prepareFuncForForwardDiff. Add a hoistLoopInvariantInsts flag to
IRSimplificationOptions, set it to true only in prepareFuncForForwardDiff,
and hoist only when the flag is set. Additionally, do not hoist instructions
out of a loop that has only a single trivial iteration.
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-defer-buffer-load.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-redundancy-removal.cpp | 25 | ||||
| -rw-r--r-- | source/slang/slang-ir-redundancy-removal.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-simplify-cfg.cpp | 24 | ||||
| -rw-r--r-- | source/slang/slang-ir-simplify-cfg.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa-simplification.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa-simplification.h | 1 | ||||
| -rw-r--r-- | tests/spirv/forceinline-nohoist.slang | 32 |
10 files changed, 79 insertions, 25 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 92c35a618..08ca56f7d 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1877,6 +1877,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) disableIRValidationAtInsert(); auto simplifyOptions = IRSimplificationOptions::getDefault(nullptr); simplifyOptions.removeRedundancy = true; + simplifyOptions.hoistLoopInvariantInsts = true; simplifyFunc(autoDiffSharedContext->targetProgram, func, simplifyOptions); enableIRValidationAtInsert(); } diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp index fea73e705..bd3b78f9e 100644 --- a/source/slang/slang-ir-defer-buffer-load.cpp +++ b/source/slang/slang-ir-defer-buffer-load.cpp @@ -167,7 +167,7 @@ struct DeferBufferLoadContext void deferBufferLoadInFunc(IRFunc* func) { - removeRedundancyInFunc(func); + removeRedundancyInFunc(func, false); currentFunc = func; dominatorTree = func->getModule()->findOrCreateDominatorTree(func); diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index 0d9b910c4..a6dac723e 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -1,6 +1,7 @@ #include "slang-ir-redundancy-removal.h" #include "slang-ir-dominators.h" +#include "slang-ir-simplify-cfg.h" #include "slang-ir-util.h" namespace Slang @@ -20,6 +21,10 @@ struct RedundancyRemovalContext auto terminatorInst = parentBlock->getTerminator(); if (auto loop = as<IRLoop>(terminatorInst)) { + // Don't bother hoisting if a loop has only a single trivial iteration. + if (isTrivialSingleIterationLoop(dom, func, loop)) + continue; + // If `inst` is outside of the loop region, don't hoist it into the loop. 
if (dom->dominates(loop->getBreakBlock(), inst)) continue; @@ -62,7 +67,8 @@ struct RedundancyRemovalContext bool removeRedundancyInBlock( Dictionary<IRBlock*, DeduplicateContext>& mapBlockToDedupContext, IRGlobalValueWithCode* func, - IRBlock* block) + IRBlock* block, + bool hoistLoopInvariantInsts) { bool result = false; auto& deduplicateContext = mapBlockToDedupContext.getValue(block); @@ -89,7 +95,8 @@ struct RedundancyRemovalContext { // This inst is unique, we should consider hoisting it // if it is inside a loop. - result |= tryHoistInstToOuterMostLoop(func, resultInst); + if (hoistLoopInvariantInsts) + result |= tryHoistInstToOuterMostLoop(func, resultInst); } } for (auto child : dom->getImmediatelyDominatedBlocks(block)) @@ -101,25 +108,25 @@ struct RedundancyRemovalContext } }; -bool removeRedundancy(IRModule* module) +bool removeRedundancy(IRModule* module, bool hoistLoopInvariantInsts) { bool changed = false; for (auto inst : module->getGlobalInsts()) { if (auto genericInst = as<IRGeneric>(inst)) { - removeRedundancyInFunc(genericInst); + removeRedundancyInFunc(genericInst, hoistLoopInvariantInsts); inst = findGenericReturnVal(genericInst); } if (auto func = as<IRFunc>(inst)) { - changed |= removeRedundancyInFunc(func); + changed |= removeRedundancyInFunc(func, hoistLoopInvariantInsts); } } return changed; } -bool removeRedundancyInFunc(IRGlobalValueWithCode* func) +bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariantInsts) { auto root = func->getFirstBlock(); if (!root) @@ -139,7 +146,11 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func) { for (auto block : workList) { - result |= context.removeRedundancyInBlock(mapBlockToDeduplicateContext, func, block); + result |= context.removeRedundancyInBlock( + mapBlockToDeduplicateContext, + func, + block, + hoistLoopInvariantInsts); for (auto child : context.dom->getImmediatelyDominatedBlocks(block)) { diff --git a/source/slang/slang-ir-redundancy-removal.h 
b/source/slang/slang-ir-redundancy-removal.h index 40f3b07a9..9d9fb43f0 100644 --- a/source/slang/slang-ir-redundancy-removal.h +++ b/source/slang/slang-ir-redundancy-removal.h @@ -7,8 +7,8 @@ namespace Slang struct IRModule; struct IRGlobalValueWithCode; -bool removeRedundancy(IRModule* module); -bool removeRedundancyInFunc(IRGlobalValueWithCode* func); +bool removeRedundancy(IRModule* module, bool hoistLoopInvariantInsts); +bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariantInsts); bool eliminateRedundantLoadStore(IRGlobalValueWithCode* func); diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index 68d79617a..280874f74 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -66,8 +66,8 @@ static IRInst* findBreakableRegionHeaderInst(IRDominatorTree* domTree, IRBlock* // Test if a loop is trivial: a trivial loop runs for a single iteration without any back edges, and // there is only one break out of the loop at the very end. The function generates `regionTree` if // it is needed and hasn't been generated yet. -static bool isTrivialSingleIterationLoop( - CFGSimplificationContext& context, +bool isTrivialSingleIterationLoop( + IRDominatorTree* domTree, IRGlobalValueWithCode* func, IRLoop* loop) { @@ -91,21 +91,21 @@ static bool isTrivialSingleIterationLoop( // // We need to verify this is a trivial loop by checking if there is any multi-level breaks // that skips out of this loop. 
- if (!context.domTree) - context.domTree = computeDominatorTree(func); + if (!domTree) + domTree = computeDominatorTree(func); bool hasMultiLevelBreaks = false; - auto loopBlocks = collectBlocksInRegion(context.domTree, loop, &hasMultiLevelBreaks); + auto loopBlocks = collectBlocksInRegion(domTree, loop, &hasMultiLevelBreaks); if (hasMultiLevelBreaks) return false; for (auto block : loopBlocks) { for (auto branchTarget : block->getSuccessors()) { - if (!context.domTree->dominates(loop->getParent(), branchTarget)) + if (!domTree->dominates(loop->getParent(), branchTarget)) return false; if (branchTarget != loop->getBreakBlock()) continue; - if (findBreakableRegionHeaderInst(context.domTree, block) != loop) + if (findBreakableRegionHeaderInst(domTree, block) != loop) { // If the break is initiated from a nested region, this is not trivial. return false; @@ -127,7 +127,7 @@ static bool isTrivialSingleIterationLoop( auto breakOriginBlock = *loop->getBreakBlock()->getPredecessors().begin(); for (auto currBlock = breakOriginBlock; currBlock; - currBlock = context.domTree->getImmediateDominator(currBlock)) + currBlock = domTree->getImmediateDominator(currBlock)) { auto terminator = currBlock->getTerminator(); if (terminator == loop) @@ -139,11 +139,11 @@ static bool isTrivialSingleIterationLoop( switch (terminator->getOp()) { case kIROp_loop: - if (isBlockInRegion(context.domTree, as<IRLoop>(terminator), breakOriginBlock)) + if (isBlockInRegion(domTree, as<IRLoop>(terminator), breakOriginBlock)) return false; break; case kIROp_Switch: - if (isBlockInRegion(context.domTree, as<IRSwitch>(terminator), breakOriginBlock)) + if (isBlockInRegion(domTree, as<IRSwitch>(terminator), breakOriginBlock)) return false; break; default: @@ -853,8 +853,10 @@ static bool processFunc(IRGlobalValueWithCode* func, CFGSimplificationOptions op // break at the end of the loop, we can remove the header and turn it into // a normal branch. 
auto targetBlock = loop->getTargetBlock(); + if (!simplificationContext.domTree) + simplificationContext.domTree = computeDominatorTree(func); if (options.removeTrivialSingleIterationLoops && - isTrivialSingleIterationLoop(simplificationContext, func, loop)) + isTrivialSingleIterationLoop(simplificationContext.domTree, func, loop)) { builder.setInsertBefore(loop); List<IRInst*> args; diff --git a/source/slang/slang-ir-simplify-cfg.h b/source/slang/slang-ir-simplify-cfg.h index 8adde3a42..3822b0fa8 100644 --- a/source/slang/slang-ir-simplify-cfg.h +++ b/source/slang/slang-ir-simplify-cfg.h @@ -5,6 +5,8 @@ namespace Slang { struct IRModule; struct IRGlobalValueWithCode; +struct IRLoop; +struct IRDominatorTree; struct CFGSimplificationOptions { @@ -14,6 +16,11 @@ struct CFGSimplificationOptions static CFGSimplificationOptions getFast() { return CFGSimplificationOptions{false, false}; } }; +bool isTrivialSingleIterationLoop( + IRDominatorTree* domTree, + IRGlobalValueWithCode* func, + IRLoop* loop); + /// Simplifies control flow graph by merging basic blocks that /// forms a simple linear chain. /// Returns true if changed. 
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 2e2ba358f..82b424daa 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -2183,7 +2183,7 @@ void simplifyIRForSpirvLegalization(TargetProgram* target, DiagnosticSink* sink, funcChanged = false; funcChanged |= applySparseConditionalConstantPropagation(func, sink); funcChanged |= peepholeOptimize(target, func); - funcChanged |= removeRedundancyInFunc(func); + funcChanged |= removeRedundancyInFunc(func, false); CFGSimplificationOptions options; options.removeTrivialSingleIterationLoops = true; options.removeSideEffectFreeLoops = false; diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index b5081dab7..41da868b6 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -90,7 +90,7 @@ void simplifyIR( funcChanged |= applySparseConditionalConstantPropagation(func, sink); funcChanged |= peepholeOptimize(target, func); if (options.removeRedundancy) - funcChanged |= removeRedundancyInFunc(func); + funcChanged |= removeRedundancyInFunc(func, options.hoistLoopInvariantInsts); funcChanged |= simplifyCFG(func, options.cfgOptions); // Note: we disregard the `changed` state from dead code elimination pass since // SCCP pass could be generating temporarily evaluated constant values and never @@ -122,7 +122,7 @@ void simplifyNonSSAIR(TargetProgram* target, IRModule* module, IRSimplificationO changed |= peepholeOptimize(target, module, options.peepholeOptions); if (!options.minimalOptimization) - changed |= removeRedundancy(module); + changed |= removeRedundancy(module, options.hoistLoopInvariantInsts); changed |= simplifyCFG(module, options.cfgOptions); // Note: we disregard the `changed` state from dead code elimination pass since @@ -153,7 +153,7 @@ void simplifyFunc( changed |= 
applySparseConditionalConstantPropagation(func, sink); changed |= peepholeOptimize(target, func); if (!options.minimalOptimization) - changed |= removeRedundancyInFunc(func); + changed |= removeRedundancyInFunc(func, options.hoistLoopInvariantInsts); changed |= simplifyCFG(func, options.cfgOptions); // Note: we disregard the `changed` state from dead code elimination pass since diff --git a/source/slang/slang-ir-ssa-simplification.h b/source/slang/slang-ir-ssa-simplification.h index fbff2fdda..e3db559d7 100644 --- a/source/slang/slang-ir-ssa-simplification.h +++ b/source/slang/slang-ir-ssa-simplification.h @@ -20,6 +20,7 @@ struct IRSimplificationOptions bool minimalOptimization = false; bool removeRedundancy = false; + bool hoistLoopInvariantInsts = false; static IRSimplificationOptions getDefault(TargetProgram* targetProgram); diff --git a/tests/spirv/forceinline-nohoist.slang b/tests/spirv/forceinline-nohoist.slang new file mode 100644 index 000000000..54db2838f --- /dev/null +++ b/tests/spirv/forceinline-nohoist.slang @@ -0,0 +1,32 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -stage compute -entry main
+
+// Verify that the call to dot is after the conditional branch.
+
+// CHECK: OpBranchConditional
+// CHECK: OpDot
+
+[ForceInline]
+float test(bool x, float3 a, float3 b) { // Returns dot(a, b) when x is true, otherwise 0.
+ float result = 0; // Value returned when the branch below is not taken.
+ if(x) {
+ result = dot(a, b); // The CHECK lines above expect this OpDot to appear after OpBranchConditional.
+ }
+ return result;
+}
+
+float caller(uniform bool x, uniform float3 a, uniform float3 b) { // Forwards its uniform parameters unchanged to test().
+ return test(x, a, b);
+}
+
+RWStructuredBuffer<float> output;
+
+uniform bool branchCheck;
+uniform float3 uniformA;
+uniform float3 uniformB;
+
+[numthreads(1,1,1)]
+[shader("compute")]
+void main() // Entry point selected by `-stage compute -entry main` in the TEST line.
+{
+ output[0] = caller(branchCheck, uniformA, uniformB); // Writes the result computed from the global uniforms into element 0 of `output`.
+}
|
