diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-02-22 19:33:42 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-22 16:33:42 -0800 |
| commit | 6eb0b4dea4da1fc21767c86cc0837d0c8b68063b (patch) | |
| tree | 8ad8fe77e57db437be5f7403fd324e218db9c578 /tests | |
| parent | 0ef7aa85d3a6b2ff1d6b25576b4d9eff188c1a6a (diff) | |
Reverse-mode AD fixes for loops with non-trivial break region (#2671)
* Fix crash when applying autodiff to functions with no arguments
* Fixes for loops where the break region is non-trivial
* Minor fix
* Implement array legalization correctly.
* Fix array legalization.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/reverse-more-loops.slang | 60 | ||||
| -rw-r--r-- | tests/autodiff/reverse-more-loops.slang.expected.txt | 5 |
2 files changed, 65 insertions, 0 deletions
diff --git a/tests/autodiff/reverse-more-loops.slang b/tests/autodiff/reverse-more-loops.slang new file mode 100644 index 000000000..173caf963 --- /dev/null +++ b/tests/autodiff/reverse-more-loops.slang @@ -0,0 +1,60 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer + +RWStructuredBuffer<float> outputBuffer; + +[BackwardDifferentiable] +float bsdf() +{ + return 0.5; +} + +[ForwardDerivativeOf(bsdf)] +DifferentialPair<float> d_bsdf() +{ + return diffPair(0.5f, 1.0f); +} + +[BackwardDerivativeOf(bsdf)] +void d_bsdf(float dOut) +{ + outputBuffer[0] += dOut; +} + +[BackwardDifferentiable] +float tracePath() +{ + float thp = 1.0; + float L = 0.0; + + uint depth = 0; + + for (int i = 0; i < 3; ++i) + { + if (depth <= 2) + { + thp = thp * bsdf(); + + L = thp * 1.0; + + if (depth >= 2) break; + + depth = depth + 1; + } + else + { + L = 0.0; + } + } + + return L; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + __bwd_diff(tracePath)(1.0); // Expect: 1.0 in outputBuffer[0] + } +} diff --git a/tests/autodiff/reverse-more-loops.slang.expected.txt b/tests/autodiff/reverse-more-loops.slang.expected.txt new file mode 100644 index 000000000..2bf164814 --- /dev/null +++ b/tests/autodiff/reverse-more-loops.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +0.750000 +0.000000 +0.000000 +0.000000 |
