diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-defer-buffer-load.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-redundancy-removal.cpp | 25 | ||||
| -rw-r--r-- | source/slang/slang-ir-redundancy-removal.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-simplify-cfg.cpp | 24 | ||||
| -rw-r--r-- | source/slang/slang-ir-simplify-cfg.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa-simplification.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa-simplification.h | 1 |
9 files changed, 47 insertions, 25 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 92c35a618..08ca56f7d 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1877,6 +1877,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) disableIRValidationAtInsert(); auto simplifyOptions = IRSimplificationOptions::getDefault(nullptr); simplifyOptions.removeRedundancy = true; + simplifyOptions.hoistLoopInvariantInsts = true; simplifyFunc(autoDiffSharedContext->targetProgram, func, simplifyOptions); enableIRValidationAtInsert(); } diff --git a/source/slang/slang-ir-defer-buffer-load.cpp b/source/slang/slang-ir-defer-buffer-load.cpp index fea73e705..bd3b78f9e 100644 --- a/source/slang/slang-ir-defer-buffer-load.cpp +++ b/source/slang/slang-ir-defer-buffer-load.cpp @@ -167,7 +167,7 @@ struct DeferBufferLoadContext void deferBufferLoadInFunc(IRFunc* func) { - removeRedundancyInFunc(func); + removeRedundancyInFunc(func, false); currentFunc = func; dominatorTree = func->getModule()->findOrCreateDominatorTree(func); diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index 0d9b910c4..a6dac723e 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -1,6 +1,7 @@ #include "slang-ir-redundancy-removal.h" #include "slang-ir-dominators.h" +#include "slang-ir-simplify-cfg.h" #include "slang-ir-util.h" namespace Slang @@ -20,6 +21,10 @@ struct RedundancyRemovalContext auto terminatorInst = parentBlock->getTerminator(); if (auto loop = as<IRLoop>(terminatorInst)) { + // Don't bother hoisting if a loop has only a single trivial iteration. + if (isTrivialSingleIterationLoop(dom, func, loop)) + continue; + // If `inst` is outside of the loop region, don't hoist it into the loop. if (dom->dominates(loop->getBreakBlock(), inst)) continue; @@ -62,7 +67,8 @@ struct RedundancyRemovalContext bool removeRedundancyInBlock( Dictionary<IRBlock*, DeduplicateContext>& mapBlockToDedupContext, IRGlobalValueWithCode* func, - IRBlock* block) + IRBlock* block, + bool hoistLoopInvariantInsts) { bool result = false; auto& deduplicateContext = mapBlockToDedupContext.getValue(block); @@ -89,7 +95,8 @@ struct RedundancyRemovalContext { // This inst is unique, we should consider hoisting it // if it is inside a loop. - result |= tryHoistInstToOuterMostLoop(func, resultInst); + if (hoistLoopInvariantInsts) + result |= tryHoistInstToOuterMostLoop(func, resultInst); } } for (auto child : dom->getImmediatelyDominatedBlocks(block)) @@ -101,25 +108,25 @@ struct RedundancyRemovalContext } }; -bool removeRedundancy(IRModule* module) +bool removeRedundancy(IRModule* module, bool hoistLoopInvariantInsts) { bool changed = false; for (auto inst : module->getGlobalInsts()) { if (auto genericInst = as<IRGeneric>(inst)) { - removeRedundancyInFunc(genericInst); + removeRedundancyInFunc(genericInst, hoistLoopInvariantInsts); inst = findGenericReturnVal(genericInst); } if (auto func = as<IRFunc>(inst)) { - changed |= removeRedundancyInFunc(func); + changed |= removeRedundancyInFunc(func, hoistLoopInvariantInsts); } } return changed; } -bool removeRedundancyInFunc(IRGlobalValueWithCode* func) +bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariantInsts) { auto root = func->getFirstBlock(); if (!root) @@ -139,7 +146,11 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func) { for (auto block : workList) { - result |= context.removeRedundancyInBlock(mapBlockToDeduplicateContext, func, block); + result |= context.removeRedundancyInBlock( + mapBlockToDeduplicateContext, + func, + block, + hoistLoopInvariantInsts); for (auto child : context.dom->getImmediatelyDominatedBlocks(block)) { diff --git a/source/slang/slang-ir-redundancy-removal.h b/source/slang/slang-ir-redundancy-removal.h index 40f3b07a9..9d9fb43f0 100644 --- a/source/slang/slang-ir-redundancy-removal.h +++ b/source/slang/slang-ir-redundancy-removal.h @@ -7,8 +7,8 @@ namespace Slang struct IRModule; struct IRGlobalValueWithCode; -bool removeRedundancy(IRModule* module); -bool removeRedundancyInFunc(IRGlobalValueWithCode* func); +bool removeRedundancy(IRModule* module, bool hoistLoopInvariantInsts); +bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariantInsts); bool eliminateRedundantLoadStore(IRGlobalValueWithCode* func); diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index 68d79617a..280874f74 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -66,8 +66,8 @@ static IRInst* findBreakableRegionHeaderInst(IRDominatorTree* domTree, IRBlock* // Test if a loop is trivial: a trivial loop runs for a single iteration without any back edges, and // there is only one break out of the loop at the very end. The function generates `regionTree` if // it is needed and hasn't been generated yet. -static bool isTrivialSingleIterationLoop( - CFGSimplificationContext& context, +bool isTrivialSingleIterationLoop( + IRDominatorTree* domTree, IRGlobalValueWithCode* func, IRLoop* loop) { @@ -91,21 +91,21 @@ static bool isTrivialSingleIterationLoop( // // We need to verify this is a trivial loop by checking if there is any multi-level breaks // that skips out of this loop. - if (!context.domTree) - context.domTree = computeDominatorTree(func); + if (!domTree) + domTree = computeDominatorTree(func); bool hasMultiLevelBreaks = false; - auto loopBlocks = collectBlocksInRegion(context.domTree, loop, &hasMultiLevelBreaks); + auto loopBlocks = collectBlocksInRegion(domTree, loop, &hasMultiLevelBreaks); if (hasMultiLevelBreaks) return false; for (auto block : loopBlocks) { for (auto branchTarget : block->getSuccessors()) { - if (!context.domTree->dominates(loop->getParent(), branchTarget)) + if (!domTree->dominates(loop->getParent(), branchTarget)) return false; if (branchTarget != loop->getBreakBlock()) continue; - if (findBreakableRegionHeaderInst(context.domTree, block) != loop) + if (findBreakableRegionHeaderInst(domTree, block) != loop) { // If the break is initiated from a nested region, this is not trivial. return false; @@ -127,7 +127,7 @@ static bool isTrivialSingleIterationLoop( auto breakOriginBlock = *loop->getBreakBlock()->getPredecessors().begin(); for (auto currBlock = breakOriginBlock; currBlock; - currBlock = context.domTree->getImmediateDominator(currBlock)) + currBlock = domTree->getImmediateDominator(currBlock)) { auto terminator = currBlock->getTerminator(); if (terminator == loop) @@ -139,11 +139,11 @@ static bool isTrivialSingleIterationLoop( switch (terminator->getOp()) { case kIROp_loop: - if (isBlockInRegion(context.domTree, as<IRLoop>(terminator), breakOriginBlock)) + if (isBlockInRegion(domTree, as<IRLoop>(terminator), breakOriginBlock)) return false; break; case kIROp_Switch: - if (isBlockInRegion(context.domTree, as<IRSwitch>(terminator), breakOriginBlock)) + if (isBlockInRegion(domTree, as<IRSwitch>(terminator), breakOriginBlock)) return false; break; default: @@ -853,8 +853,10 @@ static bool processFunc(IRGlobalValueWithCode* func, CFGSimplificationOptions op // break at the end of the loop, we can remove the header and turn it into // a normal branch. auto targetBlock = loop->getTargetBlock(); + if (!simplificationContext.domTree) + simplificationContext.domTree = computeDominatorTree(func); if (options.removeTrivialSingleIterationLoops && - isTrivialSingleIterationLoop(simplificationContext, func, loop)) + isTrivialSingleIterationLoop(simplificationContext.domTree, func, loop)) { builder.setInsertBefore(loop); List<IRInst*> args; diff --git a/source/slang/slang-ir-simplify-cfg.h b/source/slang/slang-ir-simplify-cfg.h index 8adde3a42..3822b0fa8 100644 --- a/source/slang/slang-ir-simplify-cfg.h +++ b/source/slang/slang-ir-simplify-cfg.h @@ -5,6 +5,8 @@ namespace Slang { struct IRModule; struct IRGlobalValueWithCode; +struct IRLoop; +struct IRDominatorTree; struct CFGSimplificationOptions { @@ -14,6 +16,11 @@ struct CFGSimplificationOptions static CFGSimplificationOptions getFast() { return CFGSimplificationOptions{false, false}; } }; +bool isTrivialSingleIterationLoop( + IRDominatorTree* domTree, + IRGlobalValueWithCode* func, + IRLoop* loop); + /// Simplifies control flow graph by merging basic blocks that /// forms a simple linear chain. /// Returns true if changed. diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 2e2ba358f..82b424daa 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -2183,7 +2183,7 @@ void simplifyIRForSpirvLegalization(TargetProgram* target, DiagnosticSink* sink, funcChanged = false; funcChanged |= applySparseConditionalConstantPropagation(func, sink); funcChanged |= peepholeOptimize(target, func); - funcChanged |= removeRedundancyInFunc(func); + funcChanged |= removeRedundancyInFunc(func, false); CFGSimplificationOptions options; options.removeTrivialSingleIterationLoops = true; options.removeSideEffectFreeLoops = false; diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index b5081dab7..41da868b6 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -90,7 +90,7 @@ void simplifyIR( funcChanged |= applySparseConditionalConstantPropagation(func, sink); funcChanged |= peepholeOptimize(target, func); if (options.removeRedundancy) - funcChanged |= removeRedundancyInFunc(func); + funcChanged |= removeRedundancyInFunc(func, options.hoistLoopInvariantInsts); funcChanged |= simplifyCFG(func, options.cfgOptions); // Note: we disregard the `changed` state from dead code elimination pass since // SCCP pass could be generating temporarily evaluated constant values and never @@ -122,7 +122,7 @@ void simplifyNonSSAIR(TargetProgram* target, IRModule* module, IRSimplificationO changed |= peepholeOptimize(target, module, options.peepholeOptions); if (!options.minimalOptimization) - changed |= removeRedundancy(module); + changed |= removeRedundancy(module, options.hoistLoopInvariantInsts); changed |= simplifyCFG(module, options.cfgOptions); // Note: we disregard the `changed` state from dead code elimination pass since @@ -153,7 +153,7 @@ void simplifyFunc( changed |= applySparseConditionalConstantPropagation(func, sink); changed |= peepholeOptimize(target, func); if (!options.minimalOptimization) - changed |= removeRedundancyInFunc(func); + changed |= removeRedundancyInFunc(func, options.hoistLoopInvariantInsts); changed |= simplifyCFG(func, options.cfgOptions); // Note: we disregard the `changed` state from dead code elimination pass since diff --git a/source/slang/slang-ir-ssa-simplification.h b/source/slang/slang-ir-ssa-simplification.h index fbff2fdda..e3db559d7 100644 --- a/source/slang/slang-ir-ssa-simplification.h +++ b/source/slang/slang-ir-ssa-simplification.h @@ -20,6 +20,7 @@ struct IRSimplificationOptions bool minimalOptimization = false; bool removeRedundancy = false; + bool hoistLoopInvariantInsts = false; static IRSimplificationOptions getDefault(TargetProgram* targetProgram); |
