summaryrefslogtreecommitdiff
path: root/tests/autodiff
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-27 21:21:39 -0800
committerGitHub <noreply@github.com>2023-02-27 21:21:39 -0800
commitf23e36243e9c59c02f66ec2e18b80ba4ea540f45 (patch)
tree6bf0e2a3676fe84067f70fcbda4549fa4eb6504d /tests/autodiff
parent10e2d9c7c532c204f26bb2c9f383f21b121b2ff2 (diff)
Diagnose on storing differentiable value into non-differentiable location. (#2681)
Diffstat (limited to 'tests/autodiff')
-rw-r--r--tests/autodiff/getter-setter.slang2
-rw-r--r--tests/autodiff/reverse-inout-param-2.slang6
-rw-r--r--tests/autodiff/reverse-inout-param-2.slang.expected.txt2
3 files changed, 5 insertions, 5 deletions
diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang
index 705604bbb..06caadce8 100644
--- a/tests/autodiff/getter-setter.slang
+++ b/tests/autodiff/getter-setter.slang
@@ -45,7 +45,7 @@ typedef DifferentialPair<A> dpA;
A f(A a)
{
A aout;
- aout.y = 2 * a.x;
+ aout.y = detach(2 * a.x);
aout.x = 5 * a.x;
return aout;
diff --git a/tests/autodiff/reverse-inout-param-2.slang b/tests/autodiff/reverse-inout-param-2.slang
index 18eb825e6..813238863 100644
--- a/tests/autodiff/reverse-inout-param-2.slang
+++ b/tests/autodiff/reverse-inout-param-2.slang
@@ -28,7 +28,7 @@ void g(
v1 = v2;
v2.nd = v2.nd + 1.0;
p.n = v1.nd + 1.0;
- p.m = v2.nd + 1.0 + x; // == v2.nd + 2 + x == 1 + 2 + x == 3+x
+ p.m = detach(v2.nd + 1.0 + x); // == v2.nd + 2 + x == 1 + 2 + x == 3+x
po = p;
po.m += 1.0; // == 4+x
y = p.m * x; // == (3+x)*x
@@ -39,7 +39,7 @@ void f(inout no_diff D p, out no_diff D p0, out ND v1, inout ND v2, float x, out
{
// v2.nd is 3.
g(p, p0, v1, v2, x, y);
- // v2.nd is now 4, now g is equivalent to (4+x)*x.
+ // v2.nd is now 4, now g is equivalent to detach(4+x)*x, so g' = 9.
g(p, p0, v1, v2, x, y);
}
@@ -64,7 +64,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
__bwd_diff(f)(p, v2, x, yDiffOut);
outputBuffer[0] = x.p; // should be 5, since bwd_diff does not write back new primal val.
- outputBuffer[1] = x.d; // 14
+ outputBuffer[1] = x.d; // 9
outputBuffer[2] = p.m; // 1.0
outputBuffer[3] = p.n; // 2.0
outputBuffer[4] = v2.nd; // 1.0
diff --git a/tests/autodiff/reverse-inout-param-2.slang.expected.txt b/tests/autodiff/reverse-inout-param-2.slang.expected.txt
index 65933cc7d..492ad8d10 100644
--- a/tests/autodiff/reverse-inout-param-2.slang.expected.txt
+++ b/tests/autodiff/reverse-inout-param-2.slang.expected.txt
@@ -1,6 +1,6 @@
type: float
5.000000
-14.000000
+9.000000
1.000000
2.000000
1.000000