diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-08-14 03:23:32 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-14 00:23:32 -0700 |
| commit | 0403e0556b470f6b316153caea2dc6f5c314da5b (patch) | |
| tree | 1271dbddc28a6fccaa680dd3a6dc68fadcf45115 | |
| parent | e689d5ee8e9724fee018aa14be24f9679ec5c851 (diff) | |
Fix issue with nested loop unrolling (#3100)
* Do not eliminate single-iteration loops when an inner loop breaks out into their break block.
* Add test
* Delete out-old.hlsl
* Update slang-ir-autodiff-cfg-norm.cpp
* Fix whitespace
| -rw-r--r-- | source/slang/slang-ir-eliminate-multilevel-break.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-simplify-cfg.cpp | 23 | ||||
| -rw-r--r-- | tests/autodiff/nested-loop-unroll.slang | 44 | ||||
| -rw-r--r-- | tests/autodiff/nested-loop-unroll.slang.expected.txt | 5 |
4 files changed, 80 insertions, 2 deletions
diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp index 7db517309..8f7307f1b 100644 --- a/source/slang/slang-ir-eliminate-multilevel-break.cpp +++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp @@ -59,6 +59,11 @@ struct EliminateMultiLevelBreakContext HashSet<IRBlock*> processedBlocks; List<MultiLevelBreakInfo> multiLevelBreaks; + bool isUnreachable(IRBlock* block) + { + return block->getPredecessors().getCount() == 0; + } + void collectBreakableRegionBlocks(BreakableRegionInfo& info) { // Push break block to a stack so we can easily check if a block is a break block in its @@ -92,7 +97,7 @@ struct EliminateMultiLevelBreakContext collectBreakableRegionBlocks(*childRegion); info.childRegions.add(childRegion); block = childRegion->getBreakBlock(); - if (info.blockSet.add(block)) + if (!isUnreachable(block) && info.blockSet.add(block)) { info.blocks.add(block); } @@ -142,7 +147,8 @@ struct EliminateMultiLevelBreakContext l->forEach( [&](BreakableRegionInfo* region) { - mapBreakBlockToRegion.add(region->getBreakBlock(), region); + if(!isUnreachable(region->getBreakBlock())) + mapBreakBlockToRegion.add(region->getBreakBlock(), region); for (auto block : region->blocks) mapBlockToRegion.add(block, region); }); diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index c4c5b584e..f2d0c4555 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -94,6 +94,29 @@ static bool isTrivialSingleIterationLoop( } } } + + // We'll also check if there's an inner loop that is breaking out into this loop's break block. + // If so, we cannot remove it right away since it interferes with the multi-level break elimination + // logic. + // + // Track the break block backwards through the dominator tree, and see if we find a loop block + // that is not the current loop. 
+ // + auto currBlock = loop->getBreakBlock(); + for (;;) + { + auto parent = context.domTree->getImmediateDominator(currBlock); + if (!parent) + break; + currBlock = parent; + if (auto _loop = as<IRLoop>(currBlock->getTerminator())) + { + if (loop != _loop) + return false; + if (loop == _loop) + break; + } + } return true; } diff --git a/tests/autodiff/nested-loop-unroll.slang b/tests/autodiff/nested-loop-unroll.slang new file mode 100644 index 000000000..026764f8f --- /dev/null +++ b/tests/autodiff/nested-loop-unroll.slang @@ -0,0 +1,44 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +static const uint levels = 8; + +[Differentiable] +void eval(float3 p, out float4 output[levels]) +{ + [ForceUnroll] for (int level = 0; level < 3; ++level) + { + float4 f = 0.f; + + // tri-linear time! + [ForceUnroll] for (int z = 0; z < 2; ++z) + { + float wx = 0; + if (z != 0) + wx = p.x; + else + wx = p.y; + + f += wx; + } + + output[level] = f; + } +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + float3 p = float3(2.0, 3.0, 0); + + float4 output[levels]; + eval(p, output); + DifferentialPair<float3> dp = DifferentialPair<float3>(p, 0); + __bwd_diff(eval)(dp, output); + + // Write output + outputBuffer[0] = dp.d.x; +} diff --git a/tests/autodiff/nested-loop-unroll.slang.expected.txt b/tests/autodiff/nested-loop-unroll.slang.expected.txt new file mode 100644 index 000000000..c34bf7c0d --- /dev/null +++ b/tests/autodiff/nested-loop-unroll.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +60.000000 +0.000000 +0.000000 +0.000000 |
