summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-25 17:27:40 -0800
committerGitHub <noreply@github.com>2023-01-25 17:27:40 -0800
commit1f4c7cab13341c2e9d48df2b01ed2c048c17c152 (patch)
treeed85dda63e1c939cf474961b965b7cc1883940bb /tests
parentaa6814be1f7dea20597ae34d477e79e53d4a543f (diff)
Unify UpdateField and UpdateElement with access chain. (#2611)
* Unify UpdateField and UpdateElement with access chain. * Fix warnings. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/autodiff/reverse-addr-eliminate.slang14
-rw-r--r--tests/autodiff/reverse-addr-eliminate.slang.expected.txt2
2 files changed, 12 insertions, 4 deletions
diff --git a/tests/autodiff/reverse-addr-eliminate.slang b/tests/autodiff/reverse-addr-eliminate.slang
index daa6fa32b..e23e83e6a 100644
--- a/tests/autodiff/reverse-addr-eliminate.slang
+++ b/tests/autodiff/reverse-addr-eliminate.slang
@@ -4,6 +4,11 @@
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
+struct D : IDifferentiable
+{
+ float n;
+ float m;
+}
struct C : IDifferentiable
{
float3 t;
@@ -22,6 +27,7 @@ struct A : IDifferentiable
float y;
B fb;
C aarr[3];
+ D dv;
};
[BackwardDifferentiable]
@@ -33,7 +39,9 @@ A f(A a, int i)
aout.x = aout.y + 5 * a.x;
aout.aarr[1].t = float3(a.y, 0.0, a.x);
aout.aarr[1].t = float3(a.y, 1.0, a.x + 1.0);
-
+ D nd = { a.x * 4.0f, 1.0f };
+ aout.dv = nd;
+ aout.dv.m = aout.dv.n * 0.5f;
// Test that writes to a potentially dynamic address multiple times
// is allowed and will propagate the correct derivative.
aout.fb.arr[i].v = a.x * 2.0; // since this value is overwritten, the diff will not accumulate to a.x
@@ -48,9 +56,9 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
var dpa = diffPair(a);
- A.Differential dout = { 1.0, 1.0, { float2(0), { { float3(1.0), 1.0 }, { float3(1.0), 1.0 } } }, { { float3(1.0), 1.0 }, { float3(1.0), 1.0 }, { float3(1.0), 1.0 } } };
+ A.Differential dout = { 1.0, 1.0, { float2(0), { { float3(1.0), 1.0 }, { float3(1.0), 1.0 } } }, { { float3(1.0), 1.0 }, { float3(1.0), 1.0 }, { float3(1.0), 1.0 } }, {1.0, 1.0} };
__bwd_diff(f)(dpa, 1, dout);
- outputBuffer[0] = dpa.d.x; // Expect: 17
+ outputBuffer[0] = dpa.d.x; // Expect: 23
outputBuffer[1] = dpa.d.y; // Expect: 0
}
diff --git a/tests/autodiff/reverse-addr-eliminate.slang.expected.txt b/tests/autodiff/reverse-addr-eliminate.slang.expected.txt
index dd367f3f5..fddc3120a 100644
--- a/tests/autodiff/reverse-addr-eliminate.slang.expected.txt
+++ b/tests/autodiff/reverse-addr-eliminate.slang.expected.txt
@@ -1,5 +1,5 @@
type: float
-17.000000
+23.000000
1.000000
0.000000
0.000000