summaryrefslogtreecommitdiff
path: root/tests/autodiff
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-12-08 14:56:20 -0800
committerGitHub <noreply@github.com>2022-12-08 14:56:20 -0800
commit41eb19e65a0974e23048bd7b3b1eb1e2f569b1d0 (patch)
treec6cde57da4d3415d86d09213936a48d3d26e07e1 /tests/autodiff
parent468bb7ecf65c000c308adae511bf65a1ca4cc412 (diff)
Auto-diff for matrix operations. (#2559)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests/autodiff')
-rw-r--r--tests/autodiff/matrix-arithmetic-fwd.slang41
-rw-r--r--tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt5
2 files changed, 46 insertions, 0 deletions
diff --git a/tests/autodiff/matrix-arithmetic-fwd.slang b/tests/autodiff/matrix-arithmetic-fwd.slang
new file mode 100644
index 000000000..7a953cef8
--- /dev/null
+++ b/tests/autodiff/matrix-arithmetic-fwd.slang
@@ -0,0 +1,41 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[ForwardDifferentiable]
+float3x3 g(float3x3 x, float3x3 y)
+{
+ float3x3 a = x + y;
+ float3x3 b = x - y;
+ return a * b + 2 * x * y;
+}
+
+[ForwardDifferentiable]
+float h(float2x2 x, float2x2 y)
+{
+ let t = mul(x, y);
+ return t[0][0] + t[0][1] + t[1][0] + t[1][1];
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ float3x3 a = float3x3(2.0);
+ float3x3 b = float3x3(1.5);
+ float3x3 da = float3x3(1.0);
+
+ outputBuffer[0] = __fwd_diff(g)(
+ DifferentialPair<float3x3>(a, da),
+ DifferentialPair<float3x3>(b, da)).d._11; // Expect: 8
+
+ float2x2 l = float2x2(1.0, 2.0, 3.0, 4.0);
+ float2x2 r = float2x2(10.0, 11.0, 12.0, 13.0);
+ float2x2 d = float2x2(1.0, 0.0, 1.0, 1.0);
+
+ //float2x2 epsilon = d * 0.001f;
+ //outputBuffer[1] = (h(l + epsilon, r + epsilon) - h(l - epsilon, r - epsilon)) / (epsilon[0][0] * 2.0));
+
+ outputBuffer[1] = __fwd_diff(h)(DifferentialPair<float2x2>(l, d), DifferentialPair<float2x2>(r, d)).d; // Expect 83.0
+}
diff --git a/tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt b/tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt
new file mode 100644
index 000000000..c595048c3
--- /dev/null
+++ b/tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+8.0
+83.0
+0.0
+0.0 \ No newline at end of file