summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp1
-rw-r--r--source/slang/slang-ir-defer-buffer-load.cpp2
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp25
-rw-r--r--source/slang/slang-ir-redundancy-removal.h4
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp24
-rw-r--r--source/slang/slang-ir-simplify-cfg.h7
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp2
-rw-r--r--source/slang/slang-ir-ssa-simplification.cpp6
-rw-r--r--source/slang/slang-ir-ssa-simplification.h1
9 files changed, 47 insertions, 25 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 92c35a618..08ca56f7d 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1877,6 +1877,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
disableIRValidationAtInsert();
auto simplifyOptions = IRSimplificationOptions::getDefault(nullptr);
simplifyOptions.removeRedundancy = true;
+ simplifyOptions.hoistLoopInvariantInsts = true;
simplifyFunc(autoDiffSharedContext->targetProgram, func, simplifyOptions);
enableIRValidationAtInsert();
}
diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp
index fea73e705..bd3b78f9e 100644
--- a/source/slang/slang-ir-defer-buffer-load.cpp
+++ b/source/slang/slang-ir-defer-buffer-load.cpp
@@ -167,7 +167,7 @@ struct DeferBufferLoadContext
void deferBufferLoadInFunc(IRFunc* func)
{
- removeRedundancyInFunc(func);
+ removeRedundancyInFunc(func, false);
currentFunc = func;
dominatorTree = func->getModule()->findOrCreateDominatorTree(func);
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index 0d9b910c4..a6dac723e 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -1,6 +1,7 @@
#include "slang-ir-redundancy-removal.h"
#include "slang-ir-dominators.h"
+#include "slang-ir-simplify-cfg.h"
#include "slang-ir-util.h"
namespace Slang
@@ -20,6 +21,10 @@ struct RedundancyRemovalContext
auto terminatorInst = parentBlock->getTerminator();
if (auto loop = as<IRLoop>(terminatorInst))
{
+ // Don't bother hoisting if a loop has only a single trivial iteration.
+ if (isTrivialSingleIterationLoop(dom, func, loop))
+ continue;
+
// If `inst` is outside of the loop region, don't hoist it into the loop.
if (dom->dominates(loop->getBreakBlock(), inst))
continue;
@@ -62,7 +67,8 @@ struct RedundancyRemovalContext
bool removeRedundancyInBlock(
Dictionary<IRBlock*, DeduplicateContext>& mapBlockToDedupContext,
IRGlobalValueWithCode* func,
- IRBlock* block)
+ IRBlock* block,
+ bool hoistLoopInvariantInsts)
{
bool result = false;
auto& deduplicateContext = mapBlockToDedupContext.getValue(block);
@@ -89,7 +95,8 @@ struct RedundancyRemovalContext
{
// This inst is unique, we should consider hoisting it
// if it is inside a loop.
- result |= tryHoistInstToOuterMostLoop(func, resultInst);
+ if (hoistLoopInvariantInsts)
+ result |= tryHoistInstToOuterMostLoop(func, resultInst);
}
}
for (auto child : dom->getImmediatelyDominatedBlocks(block))
@@ -101,25 +108,25 @@ struct RedundancyRemovalContext
}
};
-bool removeRedundancy(IRModule* module)
+bool removeRedundancy(IRModule* module, bool hoistLoopInvariantInsts)
{
bool changed = false;
for (auto inst : module->getGlobalInsts())
{
if (auto genericInst = as<IRGeneric>(inst))
{
- removeRedundancyInFunc(genericInst);
+ removeRedundancyInFunc(genericInst, hoistLoopInvariantInsts);
inst = findGenericReturnVal(genericInst);
}
if (auto func = as<IRFunc>(inst))
{
- changed |= removeRedundancyInFunc(func);
+ changed |= removeRedundancyInFunc(func, hoistLoopInvariantInsts);
}
}
return changed;
}
-bool removeRedundancyInFunc(IRGlobalValueWithCode* func)
+bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariantInsts)
{
auto root = func->getFirstBlock();
if (!root)
@@ -139,7 +146,11 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func)
{
for (auto block : workList)
{
- result |= context.removeRedundancyInBlock(mapBlockToDeduplicateContext, func, block);
+ result |= context.removeRedundancyInBlock(
+ mapBlockToDeduplicateContext,
+ func,
+ block,
+ hoistLoopInvariantInsts);
for (auto child : context.dom->getImmediatelyDominatedBlocks(block))
{
diff --git a/source/slang/slang-ir-redundancy-removal.h b/source/slang/slang-ir-redundancy-removal.h
index 40f3b07a9..9d9fb43f0 100644
--- a/source/slang/slang-ir-redundancy-removal.h
+++ b/source/slang/slang-ir-redundancy-removal.h
@@ -7,8 +7,8 @@ namespace Slang
struct IRModule;
struct IRGlobalValueWithCode;
-bool removeRedundancy(IRModule* module);
-bool removeRedundancyInFunc(IRGlobalValueWithCode* func);
+bool removeRedundancy(IRModule* module, bool hoistLoopInvariantInsts);
+bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariantInsts);
bool eliminateRedundantLoadStore(IRGlobalValueWithCode* func);
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index 68d79617a..280874f74 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -66,8 +66,8 @@ static IRInst* findBreakableRegionHeaderInst(IRDominatorTree* domTree, IRBlock*
// Test if a loop is trivial: a trivial loop runs for a single iteration without any back edges, and
// there is only one break out of the loop at the very end. The function generates `regionTree` if
// it is needed and hasn't been generated yet.
-static bool isTrivialSingleIterationLoop(
- CFGSimplificationContext& context,
+bool isTrivialSingleIterationLoop(
+ IRDominatorTree* domTree,
IRGlobalValueWithCode* func,
IRLoop* loop)
{
@@ -91,21 +91,21 @@ static bool isTrivialSingleIterationLoop(
//
// We need to verify this is a trivial loop by checking if there is any multi-level breaks
// that skips out of this loop.
- if (!context.domTree)
- context.domTree = computeDominatorTree(func);
+ if (!domTree)
+ domTree = computeDominatorTree(func);
bool hasMultiLevelBreaks = false;
- auto loopBlocks = collectBlocksInRegion(context.domTree, loop, &hasMultiLevelBreaks);
+ auto loopBlocks = collectBlocksInRegion(domTree, loop, &hasMultiLevelBreaks);
if (hasMultiLevelBreaks)
return false;
for (auto block : loopBlocks)
{
for (auto branchTarget : block->getSuccessors())
{
- if (!context.domTree->dominates(loop->getParent(), branchTarget))
+ if (!domTree->dominates(loop->getParent(), branchTarget))
return false;
if (branchTarget != loop->getBreakBlock())
continue;
- if (findBreakableRegionHeaderInst(context.domTree, block) != loop)
+ if (findBreakableRegionHeaderInst(domTree, block) != loop)
{
// If the break is initiated from a nested region, this is not trivial.
return false;
@@ -127,7 +127,7 @@ static bool isTrivialSingleIterationLoop(
auto breakOriginBlock = *loop->getBreakBlock()->getPredecessors().begin();
for (auto currBlock = breakOriginBlock; currBlock;
- currBlock = context.domTree->getImmediateDominator(currBlock))
+ currBlock = domTree->getImmediateDominator(currBlock))
{
auto terminator = currBlock->getTerminator();
if (terminator == loop)
@@ -139,11 +139,11 @@ static bool isTrivialSingleIterationLoop(
switch (terminator->getOp())
{
case kIROp_loop:
- if (isBlockInRegion(context.domTree, as<IRLoop>(terminator), breakOriginBlock))
+ if (isBlockInRegion(domTree, as<IRLoop>(terminator), breakOriginBlock))
return false;
break;
case kIROp_Switch:
- if (isBlockInRegion(context.domTree, as<IRSwitch>(terminator), breakOriginBlock))
+ if (isBlockInRegion(domTree, as<IRSwitch>(terminator), breakOriginBlock))
return false;
break;
default:
@@ -853,8 +853,10 @@ static bool processFunc(IRGlobalValueWithCode* func, CFGSimplificationOptions op
// break at the end of the loop, we can remove the header and turn it into
// a normal branch.
auto targetBlock = loop->getTargetBlock();
+ if (!simplificationContext.domTree)
+ simplificationContext.domTree = computeDominatorTree(func);
if (options.removeTrivialSingleIterationLoops &&
- isTrivialSingleIterationLoop(simplificationContext, func, loop))
+ isTrivialSingleIterationLoop(simplificationContext.domTree, func, loop))
{
builder.setInsertBefore(loop);
List<IRInst*> args;
diff --git a/source/slang/slang-ir-simplify-cfg.h b/source/slang/slang-ir-simplify-cfg.h
index 8adde3a42..3822b0fa8 100644
--- a/source/slang/slang-ir-simplify-cfg.h
+++ b/source/slang/slang-ir-simplify-cfg.h
@@ -5,6 +5,8 @@ namespace Slang
{
struct IRModule;
struct IRGlobalValueWithCode;
+struct IRLoop;
+struct IRDominatorTree;
struct CFGSimplificationOptions
{
@@ -14,6 +16,11 @@ struct CFGSimplificationOptions
static CFGSimplificationOptions getFast() { return CFGSimplificationOptions{false, false}; }
};
+bool isTrivialSingleIterationLoop(
+ IRDominatorTree* domTree,
+ IRGlobalValueWithCode* func,
+ IRLoop* loop);
+
/// Simplifies control flow graph by merging basic blocks that
/// forms a simple linear chain.
/// Returns true if changed.
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 2e2ba358f..82b424daa 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -2183,7 +2183,7 @@ void simplifyIRForSpirvLegalization(TargetProgram* target, DiagnosticSink* sink,
funcChanged = false;
funcChanged |= applySparseConditionalConstantPropagation(func, sink);
funcChanged |= peepholeOptimize(target, func);
- funcChanged |= removeRedundancyInFunc(func);
+ funcChanged |= removeRedundancyInFunc(func, false);
CFGSimplificationOptions options;
options.removeTrivialSingleIterationLoops = true;
options.removeSideEffectFreeLoops = false;
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
index b5081dab7..41da868b6 100644
--- a/source/slang/slang-ir-ssa-simplification.cpp
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -90,7 +90,7 @@ void simplifyIR(
funcChanged |= applySparseConditionalConstantPropagation(func, sink);
funcChanged |= peepholeOptimize(target, func);
if (options.removeRedundancy)
- funcChanged |= removeRedundancyInFunc(func);
+ funcChanged |= removeRedundancyInFunc(func, options.hoistLoopInvariantInsts);
funcChanged |= simplifyCFG(func, options.cfgOptions);
// Note: we disregard the `changed` state from dead code elimination pass since
// SCCP pass could be generating temporarily evaluated constant values and never
@@ -122,7 +122,7 @@ void simplifyNonSSAIR(TargetProgram* target, IRModule* module, IRSimplificationO
changed |= peepholeOptimize(target, module, options.peepholeOptions);
if (!options.minimalOptimization)
- changed |= removeRedundancy(module);
+ changed |= removeRedundancy(module, options.hoistLoopInvariantInsts);
changed |= simplifyCFG(module, options.cfgOptions);
// Note: we disregard the `changed` state from dead code elimination pass since
@@ -153,7 +153,7 @@ void simplifyFunc(
changed |= applySparseConditionalConstantPropagation(func, sink);
changed |= peepholeOptimize(target, func);
if (!options.minimalOptimization)
- changed |= removeRedundancyInFunc(func);
+ changed |= removeRedundancyInFunc(func, options.hoistLoopInvariantInsts);
changed |= simplifyCFG(func, options.cfgOptions);
// Note: we disregard the `changed` state from dead code elimination pass since
diff --git a/source/slang/slang-ir-ssa-simplification.h b/source/slang/slang-ir-ssa-simplification.h
index fbff2fdda..e3db559d7 100644
--- a/source/slang/slang-ir-ssa-simplification.h
+++ b/source/slang/slang-ir-ssa-simplification.h
@@ -20,6 +20,7 @@ struct IRSimplificationOptions
bool minimalOptimization = false;
bool removeRedundancy = false;
+ bool hoistLoopInvariantInsts = false;
static IRSimplificationOptions getDefault(TargetProgram* targetProgram);