summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h28
-rw-r--r--tests/autodiff/reverse-while-loop-2.slang6
2 files changed, 30 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 2283ebf5c..05884d13d 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -320,8 +320,34 @@ struct DiffTransposePass
// Old cond block becomes new pre-break block.
IRBlock* revBreakBlock = revBlockMap[currentBlock];
- // Old true-side starting block becomes loop end block.
+ // Old true-side starting block becomes loop end block...
IRBlock* revLoopEndBlock = revBlockMap[trueBlock];
+
+ // ... unless the true block has multiple successors, in which
+ // case revLoopEndBlock is the merge block for some other if-else region
+ //
+ // We will insert a new block after revLoopEndBlock, which will serve as the
+ // actual end block.
+ //
+ HashSet<IRBlock*> uniqueSuccessors;
+ for (auto successor : trueBlock->getSuccessors())
+ uniqueSuccessors.add(successor);
+ if (uniqueSuccessors.getCount() > 1)
+ {
+ auto revLookPreEndBlock = revLoopEndBlock;
+ builder.setInsertAfter(revLookPreEndBlock);
+ revLoopEndBlock = builder.emitBlock();
+
+ if (isDifferentialInst(trueBlock))
+ {
+ builder.markInstAsDifferential(revLoopEndBlock);
+ }
+
+ builder.setInsertInto(revLookPreEndBlock);
+ builder.emitBranch(revLoopEndBlock);
+ }
+
+ // Then, branch from the new loop end block to the new cond block.
builder.setInsertInto(revLoopEndBlock);
builder.emitBranch(
revCondBlock,
diff --git a/tests/autodiff/reverse-while-loop-2.slang b/tests/autodiff/reverse-while-loop-2.slang
index 70b9d5a13..9ad9ac466 100644
--- a/tests/autodiff/reverse-while-loop-2.slang
+++ b/tests/autodiff/reverse-while-loop-2.slang
@@ -1,6 +1,6 @@
-//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -Xslang -loop-inversion
-//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -Xslang -loop-inversion
-//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj -Xslang -loop-inversion
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;