summaryrefslogtreecommitdiffstats
path: root/tests/autodiff-dstdlib
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-02-19 13:05:10 -0800
committerGitHub <noreply@github.com>2025-02-19 13:05:10 -0800
commita02379208f8906272d3fd773d4b5cfe8eec3be3b (patch)
tree854f74fa2dc6da7ca660c0b7eba9407e11040c32 /tests/autodiff-dstdlib
parent0959d7ebeb6932b1949a4be10e5c472327006352 (diff)
Fix issue with `clamp`'s derivatives at the boundary. (#6403)
Diffstat (limited to 'tests/autodiff-dstdlib')
-rw-r--r--tests/autodiff-dstdlib/dstdlib-clamp.slang44
-rw-r--r--tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt8
2 files changed, 51 insertions, 1 deletions
diff --git a/tests/autodiff-dstdlib/dstdlib-clamp.slang b/tests/autodiff-dstdlib/dstdlib-clamp.slang
index 32b1cc8eb..3af12907a 100644
--- a/tests/autodiff-dstdlib/dstdlib-clamp.slang
+++ b/tests/autodiff-dstdlib/dstdlib-clamp.slang
@@ -1,7 +1,7 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
@@ -178,4 +178,46 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
outputBuffer[28] = dpmax.d.y; // Expected: 0.0
outputBuffer[29] = dpmax.d.z; // Expected: 0.3
}
+
+ // New tests: Forward-mode tests for derivative propagation at the edges with clamp(x, 0, 1)
+ {
+ // Lower edge: x exactly = 0
+ dpfloat dpx = dpfloat(0.0, 0.4);
+ dpfloat dpmin = dpfloat(0.0, 0.8);
+ dpfloat dpmax = dpfloat(1.0, 0.5);
+ dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax);
+ outputBuffer[30] = res.d; // Expected: 0.4 (propagated from x)
+ }
+
+ {
+ // Upper edge: x exactly = 1
+ dpfloat dpx = dpfloat(1.0, 0.7);
+ dpfloat dpmin = dpfloat(0.0, 0.8);
+ dpfloat dpmax = dpfloat(1.0, 0.9);
+ dpfloat res = fwd_diff(_clamp)(dpx, dpmin, dpmax);
+ outputBuffer[31] = res.d; // Expected: 0.7 (propagated from x)
+ }
+
+ // Reverse-mode tests for derivative propagation at the edges with clamp(x, 0, 1)
+ {
+ // Lower edge: x exactly = 0
+ dpfloat dpx = dpfloat(0.0, 0.0);
+ dpfloat dpmin = dpfloat(0.0, 0.0);
+ dpfloat dpmax = dpfloat(1.0, 0.0);
+ bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0);
+ outputBuffer[32] = dpx.d; // Expected: 1.0 (propagated from x)
+ outputBuffer[33] = dpmin.d; // Expected: 0.0
+ outputBuffer[34] = dpmax.d; // Expected: 0.0
+ }
+
+ {
+ // Upper edge: x exactly = 1
+ dpfloat dpx = dpfloat(1.0, 0.0);
+ dpfloat dpmin = dpfloat(0.0, 0.0);
+ dpfloat dpmax = dpfloat(1.0, 0.0);
+ bwd_diff(_clamp)(dpx, dpmin, dpmax, 1.0);
+ outputBuffer[35] = dpx.d; // Expected: 1.0 (propagated from x)
+ outputBuffer[36] = dpmin.d; // Expected: 0.0
+ outputBuffer[37] = dpmax.d; // Expected: 0.0
+ }
}
diff --git a/tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt
index b00b0060b..b18853e90 100644
--- a/tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt
+++ b/tests/autodiff-dstdlib/dstdlib-clamp.slang.expected.txt
@@ -29,3 +29,11 @@ type: float
0.000000
0.000000
0.300000
+0.400000
+0.700000
+1.000000
+0.000000
+0.000000
+1.000000
+0.000000
+0.000000