summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-loop-unroll.cpp120
-rw-r--r--source/slang/slang-ir-loop-unroll.h10
-rw-r--r--source/slang/slang-ir-util.cpp2
-rw-r--r--tests/ir/loop-unroll-2.slang23
-rw-r--r--tests/ir/loop-unroll-2.slang.expected.txt4
5 files changed, 154 insertions, 5 deletions
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp
index 725f20902..f606f0cc1 100644
--- a/source/slang/slang-ir-loop-unroll.cpp
+++ b/source/slang/slang-ir-loop-unroll.cpp
@@ -6,6 +6,7 @@
#include "slang-ir-clone.h"
#include "slang-ir-util.h"
#include "slang-ir-simplify-cfg.h"
+#include "slang-ir-dce.h"
namespace Slang
{
@@ -444,10 +445,9 @@ static bool _unrollLoop(
return loopTerminated;
}
-bool unrollLoopsInFunc(
- SharedIRBuilder* sharedBuilder,
- IRGlobalValueWithCode* func,
- DiagnosticSink* sink)
+// Collects the loop insts in a func that satisfy `filter`, inner loops first.
+template<typename TFunc>
+List<IRLoop*> collectLoopsInFunc(IRGlobalValueWithCode* func, const TFunc& filter)
{
List<IRLoop*> loops;
@@ -458,18 +458,31 @@ bool unrollLoopsInFunc(
{
if (auto loop = as<IRLoop>(block->getTerminator()))
{
- if (loop->findDecoration<IRForceUnrollDecoration>())
+ if (filter(loop))
{
loops.add(loop);
}
}
}
+ return loops;
+}
+
+bool unrollLoopsInFunc(
+ SharedIRBuilder* sharedBuilder,
+ IRGlobalValueWithCode* func,
+ DiagnosticSink* sink)
+{
+ List<IRLoop*> loops = collectLoopsInFunc(
+ func, [](IRLoop* l) { return l->findDecoration<IRForceUnrollDecoration>() != nullptr; });
if (loops.getCount() == 0)
return true;
for (auto loop : loops)
{
+ // Remove any continue jumps from the loop.
+ eliminateContinueBlocks(sharedBuilder, loop);
+
auto postOrderReverseCFG = getPostorderOnReverseCFG(func);
Dictionary<IRBlock*, int> blockOrdering;
@@ -490,6 +503,7 @@ bool unrollLoopsInFunc(
// Make sure we simplify things as much as possible before
// attempting to potentially unroll outer loop.
simplifyCFG(func);
+ eliminateDeadCode(func);
}
return true;
}
@@ -517,4 +531,100 @@ bool unrollLoopsInModule(SharedIRBuilder* sharedBuilder, IRModule* module, Diagn
return true;
}
+// Transfers the leading run of decorations and params found at the start of `src`
+// to the end of `dest`, stopping at the first child of `src` that is neither.
+static void _moveParams(IRBlock* dest, IRBlock* src)
+{
+    auto inst = src->getFirstChild();
+    while (inst)
+    {
+        const bool isDecorationOrParam = as<IRDecoration>(inst) || as<IRParam>(inst);
+        if (!isDecorationOrParam)
+            return;
+
+        // Remember the successor first; moving `inst` into `dest` presumably
+        // changes what getNextInst() would report afterwards.
+        auto following = inst->getNextInst();
+        inst->insertAtEnd(dest);
+        inst = following;
+    }
+}
+
+void eliminateContinueBlocks(SharedIRBuilder* sharedBuilder, IRLoop* loopInst)
+{
+ // Eliminate the continue jumps of `loopInst` by turning a loop in the form of:
+ //     for (;;)
+ //     {
+ //         <loop body>
+ //     continueBlock:
+ //         <continuePart>
+ //     }
+ // into:
+ //     for (;;) // original loop
+ //     {
+ //         for(;;) // breakableRegionHeader
+ //         {
+ //             <loop body>
+ //         }
+ //     breakableRegionBreakBlock:
+ //         <continuePart>
+ //     }
+ // where a continue is replaced with a "break" into breakableRegionBreakBlock.
+ // A loop whose continue block already is its target block is left untouched.
+
+ if (loopInst->getContinueBlock() == loopInst->getTargetBlock())
+ return;
+
+ // If the continue block is not reachable, remove it.
+ auto continueBlock = loopInst->getContinueBlock();
+ if (continueBlock && !continueBlock->hasMoreThanOneUse()) // presumably the one remaining use is loopInst's own operand - TODO confirm
+ {
+ loopInst->continueBlock.set(loopInst->getTargetBlock()); // the target block now doubles as the continue block
+ continueBlock->removeAndDeallocate();
+ return;
+ }
+
+ // We have determined that there is really a non-trivial continue block in the loop body,
+ // we will now introduce a breakable region for each iteration.
+
+ IRBuilder builder(sharedBuilder);
+
+ auto targetBlock = loopInst->getTargetBlock();
+
+ auto innerBreakableRegionHeader = builder.createBlock();
+ innerBreakableRegionHeader->insertBefore(targetBlock);
+
+ auto innerBreakableRegionBreakBlock = builder.createBlock();
+ innerBreakableRegionBreakBlock->insertBefore(continueBlock); // NOTE(review): continueBlock was null-checked above but is used unconditionally from here on - confirm it can never be null
+
+ loopInst->block.set(innerBreakableRegionHeader); // the original loop now targets the new header ...
+ loopInst->continueBlock.set(innerBreakableRegionHeader); // ... and also uses it as its continue block
+
+ targetBlock->replaceUsesWith(innerBreakableRegionHeader); // done before emitLoop below, so the inner loop's own references to targetBlock are not redirected
+
+ // Move decorations and params from original targetBlock to innerBreakableRegionHeader.
+ _moveParams(innerBreakableRegionHeader, targetBlock);
+
+ builder.setInsertInto(innerBreakableRegionHeader);
+ builder.emitLoop(targetBlock, innerBreakableRegionBreakBlock, targetBlock); // presumably (target, break, continue) - TODO confirm argument order
+
+ continueBlock->replaceUsesWith(innerBreakableRegionBreakBlock); // done before emitBranch below, so that new branch keeps pointing at continueBlock
+
+ builder.setInsertInto(innerBreakableRegionBreakBlock);
+ _moveParams(innerBreakableRegionBreakBlock, continueBlock); // decorations and params of the old continue block go to the new break block
+ builder.emitBranch(continueBlock); // the break block then branches to the original continue part
+}
+
+// Runs eliminateContinueBlocks on every loop of `func` whose continue block
+// is a different block from its target block.
+void eliminateContinueBlocksInFunc(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func)
+{
+    auto hasSeparateContinueBlock = [](IRLoop* candidate)
+    {
+        return candidate->getContinueBlock() != candidate->getTargetBlock();
+    };
+
+    List<IRLoop*> pendingLoops = collectLoopsInFunc(func, hasSeparateContinueBlock);
+
+    // With no matching loops the iteration below simply does nothing.
+    for (auto pendingLoop : pendingLoops)
+    {
+        eliminateContinueBlocks(sharedBuilder, pendingLoop);
+    }
+}
+
}
diff --git a/source/slang/slang-ir-loop-unroll.h b/source/slang/slang-ir-loop-unroll.h
index a63625285..340011cb5 100644
--- a/source/slang/slang-ir-loop-unroll.h
+++ b/source/slang/slang-ir-loop-unroll.h
@@ -13,4 +13,14 @@ namespace Slang
bool unrollLoopsInFunc(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func, DiagnosticSink* sink);
bool unrollLoopsInModule(SharedIRBuilder* sharedBuilder, IRModule* module, DiagnosticSink* sink);
+
+
+ // Turn a loop with continue block into a loop with only back jumps and breaks.
+ // Each iteration will be wrapped in a breakable region, where everything before `continue`
+ // is within the breakable region, and everything after `continue` is outside the breakable
+ // region. A `continue` then becomes a `break` in the inner breakable region, and a `break`
+ // becomes a multi-level break out of the parent loop.
+ void eliminateContinueBlocks(SharedIRBuilder* sharedBuilder, IRLoop* loopInst);
+ void eliminateContinueBlocksInFunc(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func); // applies eliminateContinueBlocks to each loop in `func` whose continue block differs from its target block
+
}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 0448eb649..1ea426715 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -224,7 +224,9 @@ String dumpIRToString(IRInst* root)
StringBuilder sb;
StringWriter writer(&sb, Slang::WriterFlag::AutoFlush);
IRDumpOptions options = {};
+#if 0
options.flags = IRDumpOptions::Flag::DumpDebugIds;
+#endif
dumpIR(root, options, nullptr, &writer);
return sb.ToString();
}
diff --git a/tests/ir/loop-unroll-2.slang b/tests/ir/loop-unroll-2.slang
new file mode 100644
index 000000000..aff227432
--- /dev/null
+++ b/tests/ir/loop-unroll-2.slang
@@ -0,0 +1,23 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ int sum = 0;
+ [ForceUnroll]
+ for (int i = 0; i < 2; i++) // i takes the values 0 and 1
+ {
+ [ForceUnroll(2)]
+ for (int j = 1; j < 3; j++) // j takes the values 1 and 2
+ {
+ if (i == 1 && j == 1)
+ continue; // a `continue` inside a force-unrolled inner loop; skips only (i=1, j=1)
+ sum += (i+j); // contributes 1 and 2 for i=0, then 3 for (i=1, j=2)
+ }
+ }
+ outputBuffer[0] = sum; // 1 + 2 + 3 = 6, the first value in the expected output
+}
diff --git a/tests/ir/loop-unroll-2.slang.expected.txt b/tests/ir/loop-unroll-2.slang.expected.txt
new file mode 100644
index 000000000..e20d75ba3
--- /dev/null
+++ b/tests/ir/loop-unroll-2.slang.expected.txt
@@ -0,0 +1,4 @@
+6
+0
+0
+0