diff options
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 28 | ||||
| -rw-r--r-- | tests/autodiff/reverse-while-loop-2.slang | 6 |
2 files changed, 30 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 2283ebf5c..05884d13d 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -320,8 +320,34 @@ struct DiffTransposePass // Old cond block becomes new pre-break block. IRBlock* revBreakBlock = revBlockMap[currentBlock]; - // Old true-side starting block becomes loop end block. + // Old true-side starting block becomes loop end block... IRBlock* revLoopEndBlock = revBlockMap[trueBlock]; + + // ... unless the true block has multiple successors, in which + // case revLoopEndBlock is the merge block for some other if-else region + // + // We will insert a new block after revLookEndBlock, which will serve as the + // actual end block. + // + HashSet<IRBlock*> uniqueSuccessors; + for (auto successor : trueBlock->getSuccessors()) + uniqueSuccessors.add(successor); + if (uniqueSuccessors.getCount() > 1) + { + auto revLookPreEndBlock = revLoopEndBlock; + builder.setInsertAfter(revLookPreEndBlock); + revLoopEndBlock = builder.emitBlock(); + + if (isDifferentialInst(trueBlock)) + { + builder.markInstAsDifferential(revLoopEndBlock); + } + + builder.setInsertInto(revLookPreEndBlock); + builder.emitBranch(revLoopEndBlock); + } + + // Then, branch from the new loop end block to the new cond block. builder.setInsertInto(revLoopEndBlock); builder.emitBranch( revCondBlock, diff --git a/tests/autodiff/reverse-while-loop-2.slang b/tests/autodiff/reverse-while-loop-2.slang index 70b9d5a13..9ad9ac466 100644 --- a/tests/autodiff/reverse-while-loop-2.slang +++ b/tests/autodiff/reverse-while-loop-2.slang @@ -1,6 +1,6 @@ -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -Xslang -loop-inversion -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -Xslang -loop-inversion -//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj -Xslang -loop-inversion +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; |
