diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-09-26 20:50:13 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-26 17:50:13 -0700 |
| commit | c5c8cfbb360d9a763f549df48636effde839eacd (patch) | |
| tree | 6b055d36e71749d70ace575fc180a23500331b00 | |
| parent | a18dca27392b257ba2cc58ceabdf15471f34ee25 (diff) | |
Handle the case where the parent if-else region's after-block is unreachable. (#3241)
Also added a test for this.
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.cpp | 24 | ||||
| -rw-r--r-- | tests/autodiff/control-flow-bug.slang | 60 | ||||
| -rw-r--r-- | tests/autodiff/control-flow-bug.slang.expected.txt | 6 |
3 files changed, 90 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index e7c269756..3602e77ae 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -356,6 +356,30 @@ struct CFGNormalizationPass // afterBaseRegion = true; + // One case we do check for is if the after block is 'unreachable' + // i.e. the terminator is an `unreachable` instruction. + // In this case, we can safely assume that the after block does not + // have anything to execute. Further, we need to re-wire the + // previously unreachable block to the parent break block. + // Note that this operation is safe because if the after block was + // originally unreachable, all potential paths to it must have + // broken out of the region. + // + if (auto unreachInst = as<IRUnreachable>(afterBlock->getTerminator())) + { + // Link it to the parentAfterBlock. + builder.setInsertInto(afterBlock); + unreachInst->removeAndDeallocate(); + + builder.emitBranch(parentAfterBlock); + + // We can now safely assume that the after block is empty. + // Set 'afterBaseRegion' to false, which should lead the rest + // of the logic to avoid splitting the after block + // + afterBaseRegion = false; + } + // Do we need to split the after region? if (afterBaseRegion && afterBreakRegion) { diff --git a/tests/autodiff/control-flow-bug.slang b/tests/autodiff/control-flow-bug.slang new file mode 100644 index 000000000..187d7f6f8 --- /dev/null +++ b/tests/autodiff/control-flow-bug.slang @@ -0,0 +1,60 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[Differentiable] [PreferRecompute] +float3 fetch(float2 uv) +{ + if (uv.x > 0.5f) + { + if (uv.x > 0.7f) + return float3(2.) * uv.y; + else + return float3(1.) * uv.y; + } + else + { + if (uv.x > 0.3f) + return float3(4.) * uv.y; + else + return float3(3.) * uv.y; + } +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + float2 uv = (float2)dispatchThreadID.xy / float2(512, 512); + float3 color = fetch(uv); + outputBuffer[0] = color.x; // Expect: 0.0 + + { + DifferentialPair<float2> dpuv = diffPair(float2(0.6f)); + bwd_diff(fetch)(dpuv, float3(1.f)); + + outputBuffer[1] = dpuv.d.y; // Expect: 1.0 * 3 = 3 + } + + { + DifferentialPair<float2> dpuv = diffPair(float2(0.8f)); + bwd_diff(fetch)(dpuv, float3(1.f)); + + outputBuffer[2] = dpuv.d.y; // Expect: 2.0 * 3 = 6 + } + + { + DifferentialPair<float2> dpuv = diffPair(float2(0.1f)); + bwd_diff(fetch)(dpuv, float3(1.f)); + + outputBuffer[3] = dpuv.d.y; // Expect: 3.0 * 3 = 9 + } + + { + DifferentialPair<float2> dpuv = diffPair(float2(0.4f)); + bwd_diff(fetch)(dpuv, float3(1.f)); + + outputBuffer[4] = dpuv.d.y; // Expect: 4.0 * 3 = 12 + } +}
\ No newline at end of file diff --git a/tests/autodiff/control-flow-bug.slang.expected.txt b/tests/autodiff/control-flow-bug.slang.expected.txt new file mode 100644 index 000000000..07e59ce80 --- /dev/null +++ b/tests/autodiff/control-flow-bug.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +0.000000 +3.000000 +6.000000 +9.000000 +12.000000 |
