diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/reverse-continue-loop.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/reverse-loop-immediate-return.slang | 59 |
2 files changed, 60 insertions, 1 deletions
diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang index 2dfad0a61..51f17b611 100644 --- a/tests/autodiff/reverse-continue-loop.slang +++ b/tests/autodiff/reverse-continue-loop.slang @@ -16,7 +16,7 @@ float test_loop_with_continue(float y) //CHK-DAG: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item: float t = y; - //CHK-DAG: note: 4 bytes (int32_t) used for a loop counter here: + //CHK-DAG: note: 4 bytes (int32_t) used to checkpoint the following item: for (int i = 0; i < 3; i++) { if (t > 4.0) diff --git a/tests/autodiff/reverse-loop-immediate-return.slang b/tests/autodiff/reverse-loop-immediate-return.slang new file mode 100644 index 000000000..121836115 --- /dev/null +++ b/tests/autodiff/reverse-loop-immediate-return.slang @@ -0,0 +1,59 @@ + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + + +[BackwardDerivative(set_bwd)] +void set(uint idx, float x) +{ + outputBuffer[idx] = x; +} + +void set_bwd(uint idx, inout DifferentialPair<float> x) +{ + // For debugging, we'll set the derivative to 1.0 + x = DifferentialPair<float>(x.p, 1.0f); +} + +[Differentiable] +void run( + uint idx, + float x) +{ + if (idx >= 1) return; + + if (idx == 0) + { } + + for (int i = 0; i < 1; i++) + { + if (idx > 0) + { + return; + } + + if (idx == 0) + { + x = x * 2.0f; + } + } + + if (idx == 0) + { } + + set(idx, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // bwd_diff + DifferentialPair<float> dpa = DifferentialPair<float>(1.0, 0.0); + bwd_diff(run)(dispatchThreadID.x, dpa); + outputBuffer[dispatchThreadID.x] = dpa.d; + + // CHECK: type: float + // CHECK: 2.0 +} |
