diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-02-19 13:05:10 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-19 13:05:10 -0800 |
| commit | a02379208f8906272d3fd773d4b5cfe8eec3be3b (patch) | |
| tree | 854f74fa2dc6da7ca660c0b7eba9407e11040c32 /tests/autodiff-dstdlib | |
| parent | 0959d7ebeb6932b1949a4be10e5c472327006352 (diff) | |
Fix issue with `clamp`'s derivatives at the boundary. (#6403)
Diffstat (limited to 'tests/autodiff-dstdlib')
| -rw-r--r-- | tests/autodiff-dstdlib/dstdlib-clamp.slang | 44 | ||||
| -rw-r--r-- | tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt | 8 |
2 files changed, 51 insertions, 1 deletions
diff --git a/tests/autodiff-dstdlib/dstdlib-clamp.slang b/tests/autodiff-dstdlib/dstdlib-clamp.slang index 32b1cc8eb..3af12907a 100644 --- a/tests/autodiff-dstdlib/dstdlib-clamp.slang +++ b/tests/autodiff-dstdlib/dstdlib-clamp.slang @@ -1,7 +1,7 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; @@ -178,4 +178,46 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) outputBuffer[28] = dpmax.d.y; // Expected: 0.0 outputBuffer[29] = dpmax.d.z; // Expected: 0.3 } + + // New tests: Forward-mode tests for derivative propagation at the edges with clamp(x, 0, 1) + { + // Lower edge: x exactly = 0 + dpfloat dpx = dpfloat(0.0, 0.4); + dpfloat dpmin = dpfloat(0.0, 0.8); + dpfloat dpmax = dpfloat(1.0, 0.5); + dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax); + outputBuffer[30] = res.d; // Expected: 0.4 (propagated from x) + } + + { + // Upper edge: x exactly = 1 + dpfloat dpx = dpfloat(1.0, 0.7); + dpfloat dpmin = dpfloat(0.0, 0.8); + dpfloat dpmax = dpfloat(1.0, 0.9); + dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax); + outputBuffer[31] = res.d; // Expected: 0.7 (propagated from x) + } + + // Reverse-mode tests for derivative propagation at the edges with clamp(x, 0, 1) + { + // Lower edge: x exactly = 0 + dpfloat dpx = dpfloat(0.0, 0.0); + dpfloat dpmin = dpfloat(0.0, 0.0); + dpfloat dpmax = dpfloat(1.0, 0.0); + bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0); + outputBuffer[32] = dpx.d; // Expected: 1.0 (propagated from x) + outputBuffer[33] = dpmin.d; // Expected: 0.0 + outputBuffer[34] = dpmax.d; // Expected: 0.0 + } + + { + // Upper edge: x exactly = 1 + dpfloat dpx = dpfloat(1.0, 0.0); + dpfloat dpmin = dpfloat(0.0, 0.0); + dpfloat dpmax = dpfloat(1.0, 0.0); + bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0); + outputBuffer[35] = dpx.d; // Expected: 1.0 (propagated from x) + outputBuffer[36] = dpmin.d; // Expected: 0.0 + outputBuffer[37] = dpmax.d; // Expected: 0.0 + } } diff --git a/tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt index b00b0060b..b18853e90 100644 --- a/tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt @@ -29,3 +29,11 @@ type: float 0.000000 0.000000 0.300000 +0.400000 +0.700000 +1.000000 +0.000000 +0.000000 +1.000000 +0.000000 +0.000000 |
