diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-02-25 12:04:31 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-25 12:04:31 -0800 |
| commit | f7b9745e46db6a7e55f6e0265493350d65ea4615 (patch) | |
| tree | fb74e013a1c57876c7b94299367c6b9b8343784f /tests | |
| parent | a9f2f8a592c4514cd116c947486055788092ea56 (diff) | |
Fix a bug with hoisting 'IRVar' insts that are used outside the loop (#6446)
* Fix a bug with hoisting 'IRVar' insts that are used outside the loop
- We introduce a 'CheckpointObject' inst and use that to split loop state insts into two pieces (one for within-loop uses and one for outside-loop uses.
- This allows the two kinds of uses to be handled separately by the hoisting mechanism
- CheckpointObject is then lowered to a no-op after hoisting is complete.
* Update slang-ir-autodiff-primal-hoist.cpp
* Update slang-ir-autodiff-primal-hoist.cpp
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/reverse-continue-loop.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/reverse-loop-immediate-return.slang | 59 |
2 files changed, 60 insertions, 1 deletions
diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang index 2dfad0a61..51f17b611 100644 --- a/tests/autodiff/reverse-continue-loop.slang +++ b/tests/autodiff/reverse-continue-loop.slang @@ -16,7 +16,7 @@ float test_loop_with_continue(float y) //CHK-DAG: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item: float t = y; - //CHK-DAG: note: 4 bytes (int32_t) used for a loop counter here: + //CHK-DAG: note: 4 bytes (int32_t) used to checkpoint the following item: for (int i = 0; i < 3; i++) { if (t > 4.0) diff --git a/tests/autodiff/reverse-loop-immediate-return.slang b/tests/autodiff/reverse-loop-immediate-return.slang new file mode 100644 index 000000000..121836115 --- /dev/null +++ b/tests/autodiff/reverse-loop-immediate-return.slang @@ -0,0 +1,59 @@ + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + + +[BackwardDerivative(set_bwd)] +void set(uint idx, float x) +{ + outputBuffer[idx] = x; +} + +void set_bwd(uint idx, inout DifferentialPair<float> x) +{ + // For debugging, we'll set the derivative to 1.0 + x = DifferentialPair<float>(x.p, 1.0f); +} + +[Differentiable] +void run( + uint idx, + float x) +{ + if (idx >= 1) return; + + if (idx == 0) + { } + + for (int i = 0; i < 1; i++) + { + if (idx > 0) + { + return; + } + + if (idx == 0) + { + x = x * 2.0f; + } + } + + if (idx == 0) + { } + + set(idx, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // bwd_diff + DifferentialPair<float> dpa = DifferentialPair<float>(1.0, 0.0); + bwd_diff(run)(dispatchThreadID.x, dpa); + outputBuffer[dispatchThreadID.x] = dpa.d; + + // CHECK: type: float + // CHECK: 2.0 +} |
