summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-05-14 18:29:09 -0400
committerGitHub <noreply@github.com>2024-05-14 15:29:09 -0700
commit5ceb8569b1ac7898c437b0c47ad29a5d8a9f7d90 (patch)
treeb7ec73d72c5d468fe9526f10bc72bc77b2b14ef8
parent291b4cd82cebeed39d8c06c8208fc415dfa32a48 (diff)
Fix CFG reversal logic for loops (#4162)
Handles a corner case where the first block after the condition on the true-side is another condition. This would currently result in an invalid reverse graph, where the reverse version of the true-block is the merge point for two different branching insts (the reverse version of the loop as well as the second condition). This patch simply adds a blank block when constructing the reverse-loop (similar to critical edge breaking) so that each branch inst in the reversed loop has a unique merge block.
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h28
-rw-r--r--tests/autodiff/reverse-while-loop-2.slang6
2 files changed, 30 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 2283ebf5c..05884d13d 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -320,8 +320,34 @@ struct DiffTransposePass
// Old cond block becomes new pre-break block.
IRBlock* revBreakBlock = revBlockMap[currentBlock];
- // Old true-side starting block becomes loop end block.
+ // Old true-side starting block becomes loop end block...
IRBlock* revLoopEndBlock = revBlockMap[trueBlock];
+
+ // ... unless the true block has multiple successors, in which
+ // case revLoopEndBlock is the merge block for some other if-else region
+ //
+ // We will insert a new block after revLookEndBlock, which will serve as the
+ // actual end block.
+ //
+ HashSet<IRBlock*> uniqueSuccessors;
+ for (auto successor : trueBlock->getSuccessors())
+ uniqueSuccessors.add(successor);
+ if (uniqueSuccessors.getCount() > 1)
+ {
+ auto revLookPreEndBlock = revLoopEndBlock;
+ builder.setInsertAfter(revLookPreEndBlock);
+ revLoopEndBlock = builder.emitBlock();
+
+ if (isDifferentialInst(trueBlock))
+ {
+ builder.markInstAsDifferential(revLoopEndBlock);
+ }
+
+ builder.setInsertInto(revLookPreEndBlock);
+ builder.emitBranch(revLoopEndBlock);
+ }
+
+ // Then, branch from the new loop end block to the new cond block.
builder.setInsertInto(revLoopEndBlock);
builder.emitBranch(
revCondBlock,
diff --git a/tests/autodiff/reverse-while-loop-2.slang b/tests/autodiff/reverse-while-loop-2.slang
index 70b9d5a13..9ad9ac466 100644
--- a/tests/autodiff/reverse-while-loop-2.slang
+++ b/tests/autodiff/reverse-while-loop-2.slang
@@ -1,6 +1,6 @@
-//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -Xslang -loop-inversion
-//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -Xslang -loop-inversion
-//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj -Xslang -loop-inversion
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;