From f7b9745e46db6a7e55f6e0265493350d65ea4615 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 25 Feb 2025 12:04:31 -0800 Subject: Fix a bug with hoisting 'IRVar' insts that are used outside the loop (#6446) * Fix a bug with hoisting 'IRVar' insts that are used outside the loop - We introduce a 'CheckpointObject' inst and use that to split loop state insts into two pieces (one for within-loop uses and one for outside-loop uses. - This allows the two kinds of uses to be handled separately by the hoisting mechanism - CheckpointObject is then lowered to a no-op after hoisting is complete. * Update slang-ir-autodiff-primal-hoist.cpp * Update slang-ir-autodiff-primal-hoist.cpp --- tests/autodiff/reverse-continue-loop.slang | 2 +- tests/autodiff/reverse-loop-immediate-return.slang | 59 ++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 tests/autodiff/reverse-loop-immediate-return.slang (limited to 'tests') diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang index 2dfad0a61..51f17b611 100644 --- a/tests/autodiff/reverse-continue-loop.slang +++ b/tests/autodiff/reverse-continue-loop.slang @@ -16,7 +16,7 @@ float test_loop_with_continue(float y) //CHK-DAG: note: 20 bytes (FixedArray ) used to checkpoint the following item: float t = y; - //CHK-DAG: note: 4 bytes (int32_t) used for a loop counter here: + //CHK-DAG: note: 4 bytes (int32_t) used to checkpoint the following item: for (int i = 0; i < 3; i++) { if (t > 4.0) diff --git a/tests/autodiff/reverse-loop-immediate-return.slang b/tests/autodiff/reverse-loop-immediate-return.slang new file mode 100644 index 000000000..121836115 --- /dev/null +++ b/tests/autodiff/reverse-loop-immediate-return.slang @@ -0,0 +1,59 @@ + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + + +[BackwardDerivative(set_bwd)] +void set(uint idx, float x) +{ + outputBuffer[idx] = x; +} + +void set_bwd(uint idx, inout DifferentialPair x) +{ + // For debugging, we'll set the derivative to 1.0 + x = DifferentialPair(x.p, 1.0f); +} + +[Differentiable] +void run( + uint idx, + float x) +{ + if (idx >= 1) return; + + if (idx == 0) + { } + + for (int i = 0; i < 1; i++) + { + if (idx > 0) + { + return; + } + + if (idx == 0) + { + x = x * 2.0f; + } + } + + if (idx == 0) + { } + + set(idx, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // bwd_diff + DifferentialPair dpa = DifferentialPair(1.0, 0.0); + bwd_diff(run)(dispatchThreadID.x, dpa); + outputBuffer[dispatchThreadID.x] = dpa.d; + + // CHECK: type: float + // CHECK: 2.0 +} -- cgit v1.2.3