diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-02-20 23:31:05 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-20 23:31:05 -0800 |
| commit | ca592d2791047be4df3cac44c18af99f003bd085 (patch) | |
| tree | 9cbe6ca41ad95257d16bbcc01b0639505203916a /tests/autodiff-dstdlib | |
| parent | 4d286aab2ec23c081f23846f5dfdb30b1c05728b (diff) | |
Fix gradient behavior for min() and max() functions at boundaries. When input values are equal, the gradient is split evenly between both inputs. (#6411)
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'tests/autodiff-dstdlib')
| -rw-r--r-- | tests/autodiff-dstdlib/dstdlib-max-min.slang | 112 | ||||
| -rw-r--r-- | tests/autodiff-dstdlib/dstdlib-max-min.slang.expected.txt | 21 | ||||
| -rw-r--r-- | tests/autodiff-dstdlib/dstdlib-max.slang | 52 | ||||
| -rw-r--r-- | tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt | 11 |
4 files changed, 133 insertions, 63 deletions
diff --git a/tests/autodiff-dstdlib/dstdlib-max-min.slang b/tests/autodiff-dstdlib/dstdlib-max-min.slang new file mode 100644 index 000000000..f37083706 --- /dev/null +++ b/tests/autodiff-dstdlib/dstdlib-max-min.slang @@ -0,0 +1,112 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef DifferentialPair<float2> dpfloat2; + +[BackwardDifferentiable] +float diffMax(float x, float y) +{ + return max(x, y); +} + +[BackwardDifferentiable] +float2 diffMax(float2 x, float2 y) +{ + return max(x, y); +} + +[BackwardDifferentiable] +float diffMin(float x, float y) +{ + return min(x, y); +} + +[BackwardDifferentiable] +float2 diffMin(float2 x, float2 y) +{ + return min(x, y); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + // Test max() with x < y + { + dpfloat dpx = dpfloat(2.0, 1.0); + dpfloat dpy = dpfloat(5.0, -2.0); + dpfloat res = __fwd_diff(diffMax)(dpx, dpy); + outputBuffer[0] = res.p; // Expect: 5.000000 + outputBuffer[1] = res.d; // Expect: -2.000000 + } + + // Test max() with x == y + { + dpfloat dpx = dpfloat(3.0, 1.0); + dpfloat dpy = dpfloat(3.0, -2.0); + dpfloat res = __fwd_diff(diffMax)(dpx, dpy); + outputBuffer[2] = res.p; // Expect: 3.000000 + outputBuffer[3] = res.d; // Expect: -0.500000 (average of 1.0 and -2.0) + } + + // Test min() with x > y + { + dpfloat dpx = dpfloat(5.0, 1.0); + dpfloat dpy = dpfloat(2.0, -2.0); + dpfloat res = __fwd_diff(diffMin)(dpx, dpy); + outputBuffer[4] = res.p; // Expect: 2.000000 + outputBuffer[5] = res.d; // Expect: -2.000000 + } + + // Test min() with x == y + { + dpfloat dpx = dpfloat(3.0, 1.0); + dpfloat dpy = dpfloat(3.0, -2.0); + dpfloat res = __fwd_diff(diffMin)(dpx, dpy); + outputBuffer[6] = res.p; // Expect: 3.000000 + outputBuffer[7] = res.d; // Expect: -0.500000 (average of 1.0 and -2.0) + } + + // Test backward-mode max() with x == y + { + dpfloat dpx = dpfloat(3.0, 0.0); + dpfloat dpy = dpfloat(3.0, 0.0); + __bwd_diff(diffMax)(dpx, dpy, 1.0); + outputBuffer[8] = dpx.d; // Expect: 0.500000 (half of gradient) + outputBuffer[9] = dpy.d; // Expect: 0.500000 (half of gradient) + } + + // Test backward-mode min() with x == y + { + dpfloat dpx = dpfloat(3.0, 0.0); + dpfloat dpy = dpfloat(3.0, 0.0); + __bwd_diff(diffMin)(dpx, dpy, 1.0); + outputBuffer[10] = dpx.d; // Expect: 0.500000 (half of gradient) + outputBuffer[11] = dpy.d; // Expect: 0.500000 (half of gradient) + } + + // Test vector max() with x == y + { + dpfloat2 dpx = dpfloat2(float2(3.0, 4.0), float2(1.0, 2.0)); + dpfloat2 dpy = dpfloat2(float2(3.0, 2.0), float2(-2.0, -3.0)); + dpfloat2 res = __fwd_diff(diffMax)(dpx, dpy); + outputBuffer[12] = res.p[0]; // Expect: 3.000000 + outputBuffer[13] = res.d[0]; // Expect: -0.500000 (average of 1.0 and -2.0) + outputBuffer[14] = res.p[1]; // Expect: 4.000000 + outputBuffer[15] = res.d[1]; // Expect: 2.000000 + } + + // Test vector min() with x == y + { + dpfloat2 dpx = dpfloat2(float2(3.0, 4.0), float2(1.0, 2.0)); + dpfloat2 dpy = dpfloat2(float2(3.0, 2.0), float2(-2.0, -3.0)); + dpfloat2 res = __fwd_diff(diffMin)(dpx, dpy); + outputBuffer[16] = res.p[0]; // Expect: 3.000000 + outputBuffer[17] = res.d[0]; // Expect: -0.500000 (average of 1.0 and -2.0) + outputBuffer[18] = res.p[1]; // Expect: 2.000000 + outputBuffer[19] = res.d[1]; // Expect: -3.000000 + } +} diff --git a/tests/autodiff-dstdlib/dstdlib-max-min.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-max-min.slang.expected.txt new file mode 100644 index 000000000..504343b58 --- /dev/null +++ b/tests/autodiff-dstdlib/dstdlib-max-min.slang.expected.txt @@ -0,0 +1,21 @@ +type: float +5.000000 +-2.000000 +3.000000 +-0.500000 +2.000000 +-2.000000 +3.000000 +-0.500000 +0.500000 +0.500000 +0.500000 +0.500000 +3.000000 +-0.500000 +4.000000 +2.000000 +3.000000 +-0.500000 +2.000000 +-3.000000
\ No newline at end of file diff --git a/tests/autodiff-dstdlib/dstdlib-max.slang b/tests/autodiff-dstdlib/dstdlib-max.slang deleted file mode 100644 index 026914c8c..000000000 --- a/tests/autodiff-dstdlib/dstdlib-max.slang +++ /dev/null @@ -1,52 +0,0 @@ -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type - -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer -RWStructuredBuffer<float> outputBuffer; - -typedef DifferentialPair<float> dpfloat; -typedef DifferentialPair<float2> dpfloat2; - -[BackwardDifferentiable] -float diffMax(float x, float y) -{ - return max(x, y); -} - -[BackwardDifferentiable] -float2 diffMax(float2 x, float2 y) -{ - return max(x, y); -} - -[numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) -{ - { - dpfloat dpx = dpfloat(2.0, 1.0); - dpfloat dpy = dpfloat(5.0, -2.0); - dpfloat res = __fwd_diff(diffMax)(dpx, dpy); - outputBuffer[0] = res.p; // Expect: 5.000000 - outputBuffer[1] = res.d; // Expect: -2.000000 - } - - { - dpfloat2 dpx = dpfloat2(float2(-3.0, 4.0), float2(-1.0, -1.0)); - dpfloat2 dpy = dpfloat2(float2(1.0, 2.0), float2(2.0, 2.0)); - dpfloat2 res = __fwd_diff(diffMax)(dpx, dpy); - outputBuffer[2] = res.p[0]; // Expect: 1.000000 - outputBuffer[3] = res.d[0]; // Expect: 2.000000 - outputBuffer[4] = res.p[1]; // Expect: 4.000000 - outputBuffer[5] = res.d[1]; // Expect: -1.000000 - } - - { - dpfloat2 dpx = dpfloat2(float2(2.0, 3.0), float2(0.0, 0.0)); - dpfloat2 dpy = dpfloat2(float2(5.0, 1.0), float2(0.0, 0.0)); - __bwd_diff(diffMax)(dpx, dpy, float2(1.0, 2.0)); - outputBuffer[6] = dpx.d[0]; // Expect: 0.000000 - outputBuffer[7] = dpx.d[1]; // Expect: 2.000000 - outputBuffer[8] = dpy.d[0]; // Expect: 1.000000 - outputBuffer[9] = dpy.d[1]; // Expect: 0.000000 - } -} diff --git a/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt deleted file mode 100644 index 4cc1e9533..000000000 --- a/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt +++ /dev/null @@ -1,11 +0,0 @@ -type: float -5.000000 --2.000000 -1.000000 -2.000000 -4.000000 --1.000000 -0.000000 -2.000000 -1.000000 -0.000000
\ No newline at end of file |
