summaryrefslogtreecommitdiffstats
path: root/tests/autodiff-dstdlib/dstdlib-sqrt.slang
diff options
context:
space:
mode:
Diffstat (limited to 'tests/autodiff-dstdlib/dstdlib-sqrt.slang')
-rw-r--r--tests/autodiff-dstdlib/dstdlib-sqrt.slang11
1 files changed, 10 insertions, 1 deletions
diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang b/tests/autodiff-dstdlib/dstdlib-sqrt.slang
index 15573c4ef..d68a2697c 100644
--- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang
+++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang
@@ -1,7 +1,7 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
@@ -50,4 +50,13 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
outputBuffer[6] = dpx.d[0]; // Expect: 0.158114
outputBuffer[7] = dpx.d[1]; // Expect: 0.577350
}
+
+ {
+ var dpx = diffPair(float2x2(4.0, 9.0, 16.0, 25.0), float2x2(0.0, 0.0, 0.0, 0.0));
+ __bwd_diff(sqrt)(dpx, float2x2(1.0, 2.0, 3.0, 4.0));
+ outputBuffer[8] = dpx.d[0][0]; // Expect: 0.25
+ outputBuffer[9] = dpx.d[0][1]; // Expect: 0.3333
+ outputBuffer[10] = dpx.d[1][0]; // Expect: 0.375
+ outputBuffer[11] = dpx.d[1][1]; // Expect: 0.4
+ }
}