diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-13 10:39:12 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-13 10:39:12 -0800 |
| commit | 977eb925b7e9cb1a763c1e5563b2bc605b6476d6 (patch) | |
| tree | bf4922bdf76e9dbd25a2186c93097b30ffb57432 /source | |
| parent | 4dbc74a953ae1b34ce64a4eaef3aa7feb73663b9 (diff) | |
Eliminate `continue` to allow unrolling any loops. (#2645)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | source/slang/slang-ir-loop-unroll.cpp | 120 |
| -rw-r--r-- | source/slang/slang-ir-loop-unroll.h | 10 |
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 2 |
3 files changed, 127 insertions, 5 deletions
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index 725f20902..f606f0cc1 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -6,6 +6,7 @@ #include "slang-ir-clone.h" #include "slang-ir-util.h" #include "slang-ir-simplify-cfg.h" +#include "slang-ir-dce.h" namespace Slang { @@ -444,10 +445,9 @@ static bool _unrollLoop( return loopTerminated; } -bool unrollLoopsInFunc( - SharedIRBuilder* sharedBuilder, - IRGlobalValueWithCode* func, - DiagnosticSink* sink) +// Visits all loop insts in a func, inner loop first. +template<typename TFunc> +List<IRLoop*> collectLoopsInFunc(IRGlobalValueWithCode* func, const TFunc& filter) { List<IRLoop*> loops; @@ -458,18 +458,31 @@ bool unrollLoopsInFunc( { if (auto loop = as<IRLoop>(block->getTerminator())) { - if (loop->findDecoration<IRForceUnrollDecoration>()) + if (filter(loop)) { loops.add(loop); } } } + return loops; +} + +bool unrollLoopsInFunc( + SharedIRBuilder* sharedBuilder, + IRGlobalValueWithCode* func, + DiagnosticSink* sink) +{ + List<IRLoop*> loops = collectLoopsInFunc( + func, [](IRLoop* l) { return l->findDecoration<IRForceUnrollDecoration>() != nullptr; }); if (loops.getCount() == 0) return true; for (auto loop : loops) { + // Remove any continue jumps from the loop. + eliminateContinueBlocks(sharedBuilder, loop); + auto postOrderReverseCFG = getPostorderOnReverseCFG(func); Dictionary<IRBlock*, int> blockOrdering; @@ -490,6 +503,7 @@ bool unrollLoopsInFunc( // Make sure we simplify things as much as possible before // attempting to potentially unroll outer loop. 
simplifyCFG(func); + eliminateDeadCode(func); } return true; } @@ -517,4 +531,100 @@ bool unrollLoopsInModule(SharedIRBuilder* sharedBuilder, IRModule* module, Diagn return true; } +static void _moveParams(IRBlock* dest, IRBlock* src) +{ + for (auto param = src->getFirstChild(); param;) + { + auto nextInst = param->getNextInst(); + if (as<IRDecoration>(param) || as<IRParam>(param)) + { + param->insertAtEnd(dest); + } + else + { + break; + } + param = nextInst; + } +} + +void eliminateContinueBlocks(SharedIRBuilder* sharedBuilder, IRLoop* loopInst) +{ + // Eliminate the continue jumps by turning a loop in the form of: + // for (;;) + // { + // <loop body> + // continueBlock: + // <continuePart> + // } + // into: + // for (;;) // original loop + // { + // for(;;) // breakableRegionHeader + // { + // <loop body> + // } + // breakableRegionBreakBlock: + // <continuePart> + // } + // where a continue is replaced with a "break" into breakableRegionBreakBlock. + // + + if (loopInst->getContinueBlock() == loopInst->getTargetBlock()) + return; + + // If the continue block is not reachable, remove it. + auto continueBlock = loopInst->getContinueBlock(); + if (continueBlock && !continueBlock->hasMoreThanOneUse()) + { + loopInst->continueBlock.set(loopInst->getTargetBlock()); + continueBlock->removeAndDeallocate(); + return; + } + + // We have determined that there is really a non-trivial continue block in the loop body, + // we will now introduce a breakable region for each iteration. 
+ + IRBuilder builder(sharedBuilder); + + auto targetBlock = loopInst->getTargetBlock(); + + auto innerBreakableRegionHeader = builder.createBlock(); + innerBreakableRegionHeader->insertBefore(targetBlock); + + auto innerBreakableRegionBreakBlock = builder.createBlock(); + innerBreakableRegionBreakBlock->insertBefore(continueBlock); + + loopInst->block.set(innerBreakableRegionHeader); + loopInst->continueBlock.set(innerBreakableRegionHeader); + + targetBlock->replaceUsesWith(innerBreakableRegionHeader); + + // Move decorations and params from original targetBlock to innerBreakableRegionHeader. + _moveParams(innerBreakableRegionHeader, targetBlock); + + builder.setInsertInto(innerBreakableRegionHeader); + builder.emitLoop(targetBlock, innerBreakableRegionBreakBlock, targetBlock); + + continueBlock->replaceUsesWith(innerBreakableRegionBreakBlock); + + builder.setInsertInto(innerBreakableRegionBreakBlock); + _moveParams(innerBreakableRegionBreakBlock, continueBlock); + builder.emitBranch(continueBlock); +} + +void eliminateContinueBlocksInFunc(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func) +{ + List<IRLoop*> loops = collectLoopsInFunc( + func, [](IRLoop* l) { return l->getContinueBlock() != l->getTargetBlock(); }); + + if (loops.getCount() == 0) + return; + + for (auto loop : loops) + { + eliminateContinueBlocks(sharedBuilder, loop); + } +} + } diff --git a/source/slang/slang-ir-loop-unroll.h b/source/slang/slang-ir-loop-unroll.h index a63625285..340011cb5 100644 --- a/source/slang/slang-ir-loop-unroll.h +++ b/source/slang/slang-ir-loop-unroll.h @@ -13,4 +13,14 @@ namespace Slang bool unrollLoopsInFunc(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func, DiagnosticSink* sink); bool unrollLoopsInModule(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink); + + + // Turn a loop with continue block into a loop with only back jumps and breaks. 
+ // Each iteration will be wrapped in a breakable region, where everything before `continue` + // is within the breakable region, and everything after `continue` is outside the breakable + // region. A `continue` then becomes a `break` in the inner breakable region, and a `break` + // becomes a multi-level break out of the parent loop. + void eliminateContinueBlocks(SharedIRBuilder* sharedBuilder, IRLoop* loopInst); + void eliminateContinueBlocksInFunc(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func); + } diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 0448eb649..1ea426715 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -224,7 +224,9 @@ String dumpIRToString(IRInst* root) StringBuilder sb; StringWriter writer(&sb, Slang::WriterFlag::AutoFlush); IRDumpOptions options = {}; +#if 0 options.flags = IRDumpOptions::Flag::DumpDebugIds; +#endif dumpIR(root, options, nullptr, &writer); return sb.ToString(); } |
