summary refs log tree commit diff stats
path: root/source
diff options
context:
space:
mode:
author: Yong He <yonghe@outlook.com> 2023-02-13 10:39:12 -0800
committer: GitHub <noreply@github.com> 2023-02-13 10:39:12 -0800
commit: 977eb925b7e9cb1a763c1e5563b2bc605b6476d6 (patch)
tree: bf4922bdf76e9dbd25a2186c93097b30ffb57432 /source
parent: 4dbc74a953ae1b34ce64a4eaef3aa7feb73663b9 (diff)
Eliminate `continue` to allow unrolling any loops. (#2645)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r-- source/slang/slang-ir-loop-unroll.cpp 120
-rw-r--r-- source/slang/slang-ir-loop-unroll.h 10
-rw-r--r-- source/slang/slang-ir-util.cpp 2
3 files changed, 127 insertions, 5 deletions
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp
index 725f20902..f606f0cc1 100644
--- a/source/slang/slang-ir-loop-unroll.cpp
+++ b/source/slang/slang-ir-loop-unroll.cpp
@@ -6,6 +6,7 @@
#include "slang-ir-clone.h"
#include "slang-ir-util.h"
#include "slang-ir-simplify-cfg.h"
+#include "slang-ir-dce.h"
namespace Slang
{
@@ -444,10 +445,9 @@ static bool _unrollLoop(
return loopTerminated;
}
-bool unrollLoopsInFunc(
- SharedIRBuilder* sharedBuilder,
- IRGlobalValueWithCode* func,
- DiagnosticSink* sink)
+// Visits all loop insts in a func, inner loop first.
+template<typename TFunc>
+List<IRLoop*> collectLoopsInFunc(IRGlobalValueWithCode* func, const TFunc& filter)
{
List<IRLoop*> loops;
@@ -458,18 +458,31 @@ bool unrollLoopsInFunc(
{
if (auto loop = as<IRLoop>(block->getTerminator()))
{
- if (loop->findDecoration<IRForceUnrollDecoration>())
+ if (filter(loop))
{
loops.add(loop);
}
}
}
+ return loops;
+}
+
+bool unrollLoopsInFunc(
+ SharedIRBuilder* sharedBuilder,
+ IRGlobalValueWithCode* func,
+ DiagnosticSink* sink)
+{
+ List<IRLoop*> loops = collectLoopsInFunc(
+ func, [](IRLoop* l) { return l->findDecoration<IRForceUnrollDecoration>() != nullptr; });
if (loops.getCount() == 0)
return true;
for (auto loop : loops)
{
+ // Remove any continue jumps from the loop.
+ eliminateContinueBlocks(sharedBuilder, loop);
+
auto postOrderReverseCFG = getPostorderOnReverseCFG(func);
Dictionary<IRBlock*, int> blockOrdering;
@@ -490,6 +503,7 @@ bool unrollLoopsInFunc(
// Make sure we simplify things as much as possible before
// attempting to potentially unroll outer loop.
simplifyCFG(func);
+ eliminateDeadCode(func);
}
return true;
}
@@ -517,4 +531,100 @@ bool unrollLoopsInModule(SharedIRBuilder* sharedBuilder, IRModule* module, Diagn
return true;
}
+// Transfers the leading run of decorations and parameters of `src` to the end of `dest`.
+// The walk stops at the first child of `src` that is neither an IRDecoration nor an
+// IRParam; that child and everything after it are left in `src`.
+static void _moveParams(IRBlock* dest, IRBlock* src)
+{
+    auto inst = src->getFirstChild();
+    while (inst)
+    {
+        // Remember the successor before `inst` is re-inserted into `dest`.
+        auto following = inst->getNextInst();
+
+        const bool isLeadingInst = as<IRDecoration>(inst) || as<IRParam>(inst);
+        if (!isLeadingInst)
+            return;
+
+        inst->insertAtEnd(dest);
+        inst = following;
+    }
+}
+
+// Rewrites `loopInst` so that its continue block becomes the loop's own target block,
+// by wrapping the loop body in a newly created inner loop whose break block leads into
+// the original continue block.
+//   sharedBuilder: used to construct the IRBuilder that creates the new blocks and insts.
+//   loopInst:      the loop to rewrite; modified in place.
+// Returns early without creating blocks when the continue block already equals the
+// target block, or when the continue block has at most one use.
+void eliminateContinueBlocks(SharedIRBuilder* sharedBuilder, IRLoop* loopInst)
+{
+ // Eliminate the continue jumps by turning a loop in the form of:
+ //     for (;;)
+ //     {
+ //         <loop body>
+ //     continueBlock:
+ //         <continuePart>
+ //     }
+ // into:
+ //     for (;;) // original loop
+ //     {
+ //         for(;;) // breakableRegionHeader
+ //         {
+ //             <loop body>
+ //         }
+ //     breakableRegionBreakBlock:
+ //         <continuePart>
+ //     }
+ // where a continue is replaced with a "break" into breakableRegionBreakBlock.
+ //
+
+ // Nothing to do: the loop already continues directly to its target block.
+ if (loopInst->getContinueBlock() == loopInst->getTargetBlock())
+ return;
+
+ // If the continue block is not reachable, remove it.
+ // "Not reachable" is decided by the block having at most one use -- presumably that
+ // single use is this loop inst's own continue operand (confirm). The operand is
+ // pointed at the target block first, then the block is deleted.
+ auto continueBlock = loopInst->getContinueBlock();
+ if (continueBlock && !continueBlock->hasMoreThanOneUse())
+ {
+ loopInst->continueBlock.set(loopInst->getTargetBlock());
+ continueBlock->removeAndDeallocate();
+ return;
+ }
+ // NOTE(review): `continueBlock` is null-checked above but dereferenced unconditionally
+ // below; presumably it is never null for a well-formed loop -- confirm.
+
+ // We have determined that there is really a non-trivial continue block in the loop body,
+ // we will now introduce a breakable region for each iteration.
+
+ IRBuilder builder(sharedBuilder);
+
+ auto targetBlock = loopInst->getTargetBlock();
+
+ // New header block, placed immediately before the original target block.
+ auto innerBreakableRegionHeader = builder.createBlock();
+ innerBreakableRegionHeader->insertBefore(targetBlock);
+
+ // New break block for the inner region, placed immediately before the continue block.
+ auto innerBreakableRegionBreakBlock = builder.createBlock();
+ innerBreakableRegionBreakBlock->insertBefore(continueBlock);
+
+ // The original loop now both enters at and continues to the new header.
+ loopInst->block.set(innerBreakableRegionHeader);
+ loopInst->continueBlock.set(innerBreakableRegionHeader);
+
+ // Redirect every remaining use of the old target block to the new header. This runs
+ // before the inner loop is emitted below, so the inner loop's own references to
+ // targetBlock are not rewritten.
+ targetBlock->replaceUsesWith(innerBreakableRegionHeader);
+
+ // Move decorations and params from original targetBlock to innerBreakableRegionHeader.
+ _moveParams(innerBreakableRegionHeader, targetBlock);
+
+ // Terminate the header with the inner loop. The arguments are presumably
+ // (target, break, continue) -- confirm against IRBuilder::emitLoop.
+ builder.setInsertInto(innerBreakableRegionHeader);
+ builder.emitLoop(targetBlock, innerBreakableRegionBreakBlock, targetBlock);
+
+ // All existing uses of the continue block now refer to the inner break block.
+ continueBlock->replaceUsesWith(innerBreakableRegionBreakBlock);
+
+ // The break block takes over the continue block's leading decorations/params and then
+ // branches into the continue block. The branch is emitted after the replaceUsesWith
+ // above, so it still targets continueBlock itself.
+ builder.setInsertInto(innerBreakableRegionBreakBlock);
+ _moveParams(innerBreakableRegionBreakBlock, continueBlock);
+ builder.emitBranch(continueBlock);
+}
+
+// Applies eliminateContinueBlocks to every loop in `func` whose continue block differs
+// from its target block. The loops are gathered up front through collectLoopsInFunc
+// before any of them is rewritten.
+void eliminateContinueBlocksInFunc(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func)
+{
+    auto hasSeparateContinueBlock = [](IRLoop* l)
+    {
+        return l->getContinueBlock() != l->getTargetBlock();
+    };
+
+    List<IRLoop*> pendingLoops = collectLoopsInFunc(func, hasSeparateContinueBlock);
+
+    // An empty list makes this loop a no-op, so no separate count check is needed.
+    for (auto pendingLoop : pendingLoops)
+    {
+        eliminateContinueBlocks(sharedBuilder, pendingLoop);
+    }
+}
+
}
diff --git a/source/slang/slang-ir-loop-unroll.h b/source/slang/slang-ir-loop-unroll.h
index a63625285..340011cb5 100644
--- a/source/slang/slang-ir-loop-unroll.h
+++ b/source/slang/slang-ir-loop-unroll.h
@@ -13,4 +13,14 @@ namespace Slang
bool unrollLoopsInFunc(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func, DiagnosticSink* sink);
bool unrollLoopsInModule(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink);
+
+
+ // Turn a loop with continue block into a loop with only back jumps and breaks.
+ // Each iteration will be wrapped in a breakable region, where everything before `continue`
+ // is within the breakable region, and everything after `continue` is outside the breakable
+ // region. A `continue` then becomes a `break` in the inner breakable region, and a `break`
+ // becomes a multi-level break out of the parent loop.
+ void eliminateContinueBlocks(SharedIRBuilder* sharedBuilder, IRLoop* loopInst);
+ void eliminateContinueBlocksInFunc(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func);
+
}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 0448eb649..1ea426715 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -224,7 +224,9 @@ String dumpIRToString(IRInst* root)
StringBuilder sb;
StringWriter writer(&sb, Slang::WriterFlag::AutoFlush);
IRDumpOptions options = {};
+#if 0
options.flags = IRDumpOptions::Flag::DumpDebugIds;
+#endif
dumpIR(root, options, nullptr, &writer);
return sb.ToString();
}