summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> 2023-08-14 03:23:32 -0400
committer: GitHub <noreply@github.com> 2023-08-14 00:23:32 -0700
commit: 0403e0556b470f6b316153caea2dc6f5c314da5b (patch)
tree: 1271dbddc28a6fccaa680dd3a6dc68fadcf45115
parent: e689d5ee8e9724fee018aa14be24f9679ec5c851 (diff)
Fix issue with nested loop unrolling (#3100)
* Do not eliminate single-iter-loops that have inner loops using their break label.
* Add test
* Delete out-old.hlsl
* Update slang-ir-autodiff-cfg-norm.cpp
* Fix whitespace
-rw-r--r-- source/slang/slang-ir-eliminate-multilevel-break.cpp 10
-rw-r--r-- source/slang/slang-ir-simplify-cfg.cpp 23
-rw-r--r-- tests/autodiff/nested-loop-unroll.slang 44
-rw-r--r-- tests/autodiff/nested-loop-unroll.slang.expected.txt 5
4 files changed, 80 insertions, 2 deletions
diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp
index 7db517309..8f7307f1b 100644
--- a/source/slang/slang-ir-eliminate-multilevel-break.cpp
+++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp
@@ -59,6 +59,11 @@ struct EliminateMultiLevelBreakContext
HashSet<IRBlock*> processedBlocks;
List<MultiLevelBreakInfo> multiLevelBreaks;
+ bool isUnreachable(IRBlock* block) // Returns true when `block` has no predecessor blocks.
+ {
+ return block->getPredecessors().getCount() == 0; // "Unreachable" here means exactly: predecessor count is zero.
+ }
+
void collectBreakableRegionBlocks(BreakableRegionInfo& info)
{
// Push break block to a stack so we can easily check if a block is a break block in its
@@ -92,7 +97,7 @@ struct EliminateMultiLevelBreakContext
collectBreakableRegionBlocks(*childRegion);
info.childRegions.add(childRegion);
block = childRegion->getBreakBlock();
- if (info.blockSet.add(block))
+ if (!isUnreachable(block) && info.blockSet.add(block))
{
info.blocks.add(block);
}
@@ -142,7 +147,8 @@ struct EliminateMultiLevelBreakContext
l->forEach(
[&](BreakableRegionInfo* region)
{
- mapBreakBlockToRegion.add(region->getBreakBlock(), region);
+ if(!isUnreachable(region->getBreakBlock()))
+ mapBreakBlockToRegion.add(region->getBreakBlock(), region);
for (auto block : region->blocks)
mapBlockToRegion.add(block, region);
});
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index c4c5b584e..f2d0c4555 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -94,6 +94,29 @@ static bool isTrivialSingleIterationLoop(
}
}
}
+
+ // We'll also check if there's an inner loop that is breaking out into this loop's break block.
+ // If so, we cannot remove it right away since it interferes with the multi-level break elimination
+ // logic.
+ //
+ // Track the break block backwards through the dominator tree, and see if we find a loop block
+ // that is not the current loop.
+ //
+ auto currBlock = loop->getBreakBlock();
+ for (;;)
+ {
+ auto parent = context.domTree->getImmediateDominator(currBlock);
+ if (!parent)
+ break;
+ currBlock = parent;
+ if (auto _loop = as<IRLoop>(currBlock->getTerminator()))
+ {
+ if (loop != _loop)
+ return false;
+ if (loop == _loop)
+ break;
+ }
+ }
return true;
}
diff --git a/tests/autodiff/nested-loop-unroll.slang b/tests/autodiff/nested-loop-unroll.slang
new file mode 100644
index 000000000..026764f8f
--- /dev/null
+++ b/tests/autodiff/nested-loop-unroll.slang
@@ -0,0 +1,44 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+static const uint levels = 8;
+
+[Differentiable]
+void eval(float3 p, out float4 output[levels]) // Writes output[0..2] only; the outer loop stops at level < 3, so entries 3..levels-1 are not assigned here.
+{
+ [ForceUnroll] for (int level = 0; level < 3; ++level)
+ {
+ float4 f = 0.f;
+
+ // tri-linear time!
+ [ForceUnroll] for (int z = 0; z < 2; ++z)
+ {
+ float wx = 0;
+ if (z != 0)
+ wx = p.x; // z == 1 iteration contributes p.x
+ else
+ wx = p.y; // z == 0 iteration contributes p.y
+
+ f += wx;
+ }
+
+ output[level] = f; // f has accumulated p.y then p.x over the two inner iterations
+ }
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ float3 p = float3(2.0, 3.0, 0);
+
+ float4 output[levels];
+ eval(p, output); // forward call; the resulting `output` array is reused as an argument to __bwd_diff below
+ DifferentialPair<float3> dp = DifferentialPair<float3>(p, 0); // NOTE(review): second argument presumably seeds the differential part with 0 - confirm
+ __bwd_diff(eval)(dp, output); // NOTE(review): presumably `output` serves as the output gradient and the result lands in dp.d - confirm
+
+ // Write output
+ outputBuffer[0] = dp.d.x; // only element 0 is written; the expected-output file lists 60.000000 first
+}
diff --git a/tests/autodiff/nested-loop-unroll.slang.expected.txt b/tests/autodiff/nested-loop-unroll.slang.expected.txt
new file mode 100644
index 000000000..c34bf7c0d
--- /dev/null
+++ b/tests/autodiff/nested-loop-unroll.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+60.000000
+0.000000
+0.000000
+0.000000