summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorjarcherNV <jarcher@nvidia.com>2025-04-11 14:51:48 -0700
committerGitHub <noreply@github.com>2025-04-11 21:51:48 +0000
commit61a6c211b1587a7b9ed6a24ae1ba6fe0600c80d8 (patch)
treeb67322ca56975223cc2eb897acc29155928128fd
parent88a180ba0aa57b2d0fb4956005db2ea73dc73420 (diff)
Add flag to hoist instructions (#6740)
This fixes issue #6654. Only hoist instructions when simplification is invoked from prepareFuncForForwardDiff. Add a flag, hoistLoopInvariantInsts, to IRSimplificationOptions and set it to true only when called from prepareFuncForForwardDiff; instructions are then hoisted only if the flag is set. Additionally, do not hoist instructions out of a loop if the loop has only a single trivial iteration.
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp1
-rw-r--r--source/slang/slang-ir-defer-buffer-load.cpp2
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp25
-rw-r--r--source/slang/slang-ir-redundancy-removal.h4
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp24
-rw-r--r--source/slang/slang-ir-simplify-cfg.h7
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp2
-rw-r--r--source/slang/slang-ir-ssa-simplification.cpp6
-rw-r--r--source/slang/slang-ir-ssa-simplification.h1
-rw-r--r--tests/spirv/forceinline-nohoist.slang32
10 files changed, 79 insertions, 25 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 92c35a618..08ca56f7d 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1877,6 +1877,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
disableIRValidationAtInsert();
auto simplifyOptions = IRSimplificationOptions::getDefault(nullptr);
simplifyOptions.removeRedundancy = true;
+ simplifyOptions.hoistLoopInvariantInsts = true;
simplifyFunc(autoDiffSharedContext->targetProgram, func, simplifyOptions);
enableIRValidationAtInsert();
}
diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp
index fea73e705..bd3b78f9e 100644
--- a/source/slang/slang-ir-defer-buffer-load.cpp
+++ b/source/slang/slang-ir-defer-buffer-load.cpp
@@ -167,7 +167,7 @@ struct DeferBufferLoadContext
void deferBufferLoadInFunc(IRFunc* func)
{
- removeRedundancyInFunc(func);
+ removeRedundancyInFunc(func, false);
currentFunc = func;
dominatorTree = func->getModule()->findOrCreateDominatorTree(func);
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index 0d9b910c4..a6dac723e 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -1,6 +1,7 @@
#include "slang-ir-redundancy-removal.h"
#include "slang-ir-dominators.h"
+#include "slang-ir-simplify-cfg.h"
#include "slang-ir-util.h"
namespace Slang
@@ -20,6 +21,10 @@ struct RedundancyRemovalContext
auto terminatorInst = parentBlock->getTerminator();
if (auto loop = as<IRLoop>(terminatorInst))
{
+ // Don't bother hoisting if a loop has only a single trivial iteration.
+ if (isTrivialSingleIterationLoop(dom, func, loop))
+ continue;
+
// If `inst` is outside of the loop region, don't hoist it into the loop.
if (dom->dominates(loop->getBreakBlock(), inst))
continue;
@@ -62,7 +67,8 @@ struct RedundancyRemovalContext
bool removeRedundancyInBlock(
Dictionary<IRBlock*, DeduplicateContext>& mapBlockToDedupContext,
IRGlobalValueWithCode* func,
- IRBlock* block)
+ IRBlock* block,
+ bool hoistLoopInvariantInsts)
{
bool result = false;
auto& deduplicateContext = mapBlockToDedupContext.getValue(block);
@@ -89,7 +95,8 @@ struct RedundancyRemovalContext
{
// This inst is unique, we should consider hoisting it
// if it is inside a loop.
- result |= tryHoistInstToOuterMostLoop(func, resultInst);
+ if (hoistLoopInvariantInsts)
+ result |= tryHoistInstToOuterMostLoop(func, resultInst);
}
}
for (auto child : dom->getImmediatelyDominatedBlocks(block))
@@ -101,25 +108,25 @@ struct RedundancyRemovalContext
}
};
-bool removeRedundancy(IRModule* module)
+bool removeRedundancy(IRModule* module, bool hoistLoopInvariantInsts)
{
bool changed = false;
for (auto inst : module->getGlobalInsts())
{
if (auto genericInst = as<IRGeneric>(inst))
{
- removeRedundancyInFunc(genericInst);
+ removeRedundancyInFunc(genericInst, hoistLoopInvariantInsts);
inst = findGenericReturnVal(genericInst);
}
if (auto func = as<IRFunc>(inst))
{
- changed |= removeRedundancyInFunc(func);
+ changed |= removeRedundancyInFunc(func, hoistLoopInvariantInsts);
}
}
return changed;
}
-bool removeRedundancyInFunc(IRGlobalValueWithCode* func)
+bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariantInsts)
{
auto root = func->getFirstBlock();
if (!root)
@@ -139,7 +146,11 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func)
{
for (auto block : workList)
{
- result |= context.removeRedundancyInBlock(mapBlockToDeduplicateContext, func, block);
+ result |= context.removeRedundancyInBlock(
+ mapBlockToDeduplicateContext,
+ func,
+ block,
+ hoistLoopInvariantInsts);
for (auto child : context.dom->getImmediatelyDominatedBlocks(block))
{
diff --git a/source/slang/slang-ir-redundancy-removal.h b/source/slang/slang-ir-redundancy-removal.h
index 40f3b07a9..9d9fb43f0 100644
--- a/source/slang/slang-ir-redundancy-removal.h
+++ b/source/slang/slang-ir-redundancy-removal.h
@@ -7,8 +7,8 @@ namespace Slang
struct IRModule;
struct IRGlobalValueWithCode;
-bool removeRedundancy(IRModule* module);
-bool removeRedundancyInFunc(IRGlobalValueWithCode* func);
+bool removeRedundancy(IRModule* module, bool hoistLoopInvariantInsts);
+bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariantInsts);
bool eliminateRedundantLoadStore(IRGlobalValueWithCode* func);
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index 68d79617a..280874f74 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -66,8 +66,8 @@ static IRInst* findBreakableRegionHeaderInst(IRDominatorTree* domTree, IRBlock*
// Test if a loop is trivial: a trivial loop runs for a single iteration without any back edges, and
// there is only one break out of the loop at the very end. The function computes a dominator tree
// for `func` if `domTree` is null.
-static bool isTrivialSingleIterationLoop(
- CFGSimplificationContext& context,
+bool isTrivialSingleIterationLoop(
+ IRDominatorTree* domTree,
IRGlobalValueWithCode* func,
IRLoop* loop)
{
@@ -91,21 +91,21 @@ static bool isTrivialSingleIterationLoop(
//
// We need to verify this is a trivial loop by checking if there is any multi-level breaks
// that skips out of this loop.
- if (!context.domTree)
- context.domTree = computeDominatorTree(func);
+ if (!domTree)
+ domTree = computeDominatorTree(func);
bool hasMultiLevelBreaks = false;
- auto loopBlocks = collectBlocksInRegion(context.domTree, loop, &hasMultiLevelBreaks);
+ auto loopBlocks = collectBlocksInRegion(domTree, loop, &hasMultiLevelBreaks);
if (hasMultiLevelBreaks)
return false;
for (auto block : loopBlocks)
{
for (auto branchTarget : block->getSuccessors())
{
- if (!context.domTree->dominates(loop->getParent(), branchTarget))
+ if (!domTree->dominates(loop->getParent(), branchTarget))
return false;
if (branchTarget != loop->getBreakBlock())
continue;
- if (findBreakableRegionHeaderInst(context.domTree, block) != loop)
+ if (findBreakableRegionHeaderInst(domTree, block) != loop)
{
// If the break is initiated from a nested region, this is not trivial.
return false;
@@ -127,7 +127,7 @@ static bool isTrivialSingleIterationLoop(
auto breakOriginBlock = *loop->getBreakBlock()->getPredecessors().begin();
for (auto currBlock = breakOriginBlock; currBlock;
- currBlock = context.domTree->getImmediateDominator(currBlock))
+ currBlock = domTree->getImmediateDominator(currBlock))
{
auto terminator = currBlock->getTerminator();
if (terminator == loop)
@@ -139,11 +139,11 @@ static bool isTrivialSingleIterationLoop(
switch (terminator->getOp())
{
case kIROp_loop:
- if (isBlockInRegion(context.domTree, as<IRLoop>(terminator), breakOriginBlock))
+ if (isBlockInRegion(domTree, as<IRLoop>(terminator), breakOriginBlock))
return false;
break;
case kIROp_Switch:
- if (isBlockInRegion(context.domTree, as<IRSwitch>(terminator), breakOriginBlock))
+ if (isBlockInRegion(domTree, as<IRSwitch>(terminator), breakOriginBlock))
return false;
break;
default:
@@ -853,8 +853,10 @@ static bool processFunc(IRGlobalValueWithCode* func, CFGSimplificationOptions op
// break at the end of the loop, we can remove the header and turn it into
// a normal branch.
auto targetBlock = loop->getTargetBlock();
+ if (!simplificationContext.domTree)
+ simplificationContext.domTree = computeDominatorTree(func);
if (options.removeTrivialSingleIterationLoops &&
- isTrivialSingleIterationLoop(simplificationContext, func, loop))
+ isTrivialSingleIterationLoop(simplificationContext.domTree, func, loop))
{
builder.setInsertBefore(loop);
List<IRInst*> args;
diff --git a/source/slang/slang-ir-simplify-cfg.h b/source/slang/slang-ir-simplify-cfg.h
index 8adde3a42..3822b0fa8 100644
--- a/source/slang/slang-ir-simplify-cfg.h
+++ b/source/slang/slang-ir-simplify-cfg.h
@@ -5,6 +5,8 @@ namespace Slang
{
struct IRModule;
struct IRGlobalValueWithCode;
+struct IRLoop;
+struct IRDominatorTree;
struct CFGSimplificationOptions
{
@@ -14,6 +16,11 @@ struct CFGSimplificationOptions
static CFGSimplificationOptions getFast() { return CFGSimplificationOptions{false, false}; }
};
+bool isTrivialSingleIterationLoop(
+ IRDominatorTree* domTree,
+ IRGlobalValueWithCode* func,
+ IRLoop* loop);
+
/// Simplifies the control flow graph by merging basic blocks that
/// form a simple linear chain.
/// Returns true if changed.
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 2e2ba358f..82b424daa 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -2183,7 +2183,7 @@ void simplifyIRForSpirvLegalization(TargetProgram* target, DiagnosticSink* sink,
funcChanged = false;
funcChanged |= applySparseConditionalConstantPropagation(func, sink);
funcChanged |= peepholeOptimize(target, func);
- funcChanged |= removeRedundancyInFunc(func);
+ funcChanged |= removeRedundancyInFunc(func, false);
CFGSimplificationOptions options;
options.removeTrivialSingleIterationLoops = true;
options.removeSideEffectFreeLoops = false;
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
index b5081dab7..41da868b6 100644
--- a/source/slang/slang-ir-ssa-simplification.cpp
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -90,7 +90,7 @@ void simplifyIR(
funcChanged |= applySparseConditionalConstantPropagation(func, sink);
funcChanged |= peepholeOptimize(target, func);
if (options.removeRedundancy)
- funcChanged |= removeRedundancyInFunc(func);
+ funcChanged |= removeRedundancyInFunc(func, options.hoistLoopInvariantInsts);
funcChanged |= simplifyCFG(func, options.cfgOptions);
// Note: we disregard the `changed` state from dead code elimination pass since
// SCCP pass could be generating temporarily evaluated constant values and never
@@ -122,7 +122,7 @@ void simplifyNonSSAIR(TargetProgram* target, IRModule* module, IRSimplificationO
changed |= peepholeOptimize(target, module, options.peepholeOptions);
if (!options.minimalOptimization)
- changed |= removeRedundancy(module);
+ changed |= removeRedundancy(module, options.hoistLoopInvariantInsts);
changed |= simplifyCFG(module, options.cfgOptions);
// Note: we disregard the `changed` state from dead code elimination pass since
@@ -153,7 +153,7 @@ void simplifyFunc(
changed |= applySparseConditionalConstantPropagation(func, sink);
changed |= peepholeOptimize(target, func);
if (!options.minimalOptimization)
- changed |= removeRedundancyInFunc(func);
+ changed |= removeRedundancyInFunc(func, options.hoistLoopInvariantInsts);
changed |= simplifyCFG(func, options.cfgOptions);
// Note: we disregard the `changed` state from dead code elimination pass since
diff --git a/source/slang/slang-ir-ssa-simplification.h b/source/slang/slang-ir-ssa-simplification.h
index fbff2fdda..e3db559d7 100644
--- a/source/slang/slang-ir-ssa-simplification.h
+++ b/source/slang/slang-ir-ssa-simplification.h
@@ -20,6 +20,7 @@ struct IRSimplificationOptions
bool minimalOptimization = false;
bool removeRedundancy = false;
+ bool hoistLoopInvariantInsts = false;
static IRSimplificationOptions getDefault(TargetProgram* targetProgram);
diff --git a/tests/spirv/forceinline-nohoist.slang b/tests/spirv/forceinline-nohoist.slang
new file mode 100644
index 000000000..54db2838f
--- /dev/null
+++ b/tests/spirv/forceinline-nohoist.slang
@@ -0,0 +1,32 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -stage compute -entry main
+
+// Verify that the call to dot is after the conditional branch.
+
+// CHECK: OpBranchConditional
+// CHECK: OpDot
+
+[ForceInline]
+float test(bool x, float3 a, float3 b) {
+ float result = 0;
+ if(x) {
+ result = dot(a, b);
+ }
+ return result;
+}
+
+float caller(uniform bool x, uniform float3 a, uniform float3 b) {
+ return test(x, a, b);
+}
+
+RWStructuredBuffer<float> output;
+
+uniform bool branchCheck;
+uniform float3 uniformA;
+uniform float3 uniformB;
+
+[numthreads(1,1,1)]
+[shader("compute")]
+void main()
+{
+ output[0] = caller(branchCheck, uniformA, uniformB);
+}