summaryrefslogtreecommitdiff
path: root/tests/autodiff/reverse-struct-out.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-10 18:46:57 -0800
committerGitHub <noreply@github.com>2023-02-10 18:46:57 -0800
commitaec57d849ae20a305d08348cf543d19eabc2e2d6 (patch)
treeafac620a888d27ee1000b036c4ab8c3773180af3 /tests/autodiff/reverse-struct-out.slang
parent6e7b424953ae6732d4863e887e7e452396095d71 (diff)
Fix several autodiff bugs. (#2643)
Diffstat (limited to 'tests/autodiff/reverse-struct-out.slang')
-rw-r--r--tests/autodiff/reverse-struct-out.slang50
1 files changed, 50 insertions, 0 deletions
diff --git a/tests/autodiff/reverse-struct-out.slang b/tests/autodiff/reverse-struct-out.slang
new file mode 100644
index 000000000..af2f8becf
--- /dev/null
+++ b/tests/autodiff/reverse-struct-out.slang
@@ -0,0 +1,50 @@
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct A : IDifferentiable
+{
+ float x;
+ float y;
+};
+
+struct B : IDifferentiable
+{
+ float x;
+ float y;
+};
+
+typedef DifferentialPair<A> dpA;
+
+float id(float x)
+{
+ return x;
+}
+
+[BackwardDifferentiable]
+void f(A input, out B rs)
+{
+ rs.x = input.x * input.x;
+ // Derivative of rs.x should still propagate through this no_diff call.
+ rs.y = no_diff id(input.y);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ A a = {3.0, 2.0};
+ A.Differential azero = {0.0, 0.0};
+
+ dpA dpa = dpA(a, azero);
+
+ B.Differential dout = {1.0, 1.0};
+
+ __bwd_diff(f)(dpa, dout);
+ outputBuffer[0] = dpa.d.x; // Expect: 6
+ outputBuffer[1] = dpa.d.y; // Expect: 0
+ }
+}