diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-25 17:27:40 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-25 17:27:40 -0800 |
| commit | 1f4c7cab13341c2e9d48df2b01ed2c048c17c152 (patch) | |
| tree | ed85dda63e1c939cf474961b965b7cc1883940bb /tests | |
| parent | aa6814be1f7dea20597ae34d477e79e53d4a543f (diff) | |
Unify UpdateField and UpdateElement with access chain. (#2611)
* Unify UpdateField and UpdateElement with access chain.
* Fix warnings.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/reverse-addr-eliminate.slang | 14 | ||||
| -rw-r--r-- | tests/autodiff/reverse-addr-eliminate.slang.expected.txt | 2 |
2 files changed, 12 insertions, 4 deletions
diff --git a/tests/autodiff/reverse-addr-eliminate.slang b/tests/autodiff/reverse-addr-eliminate.slang index daa6fa32b..e23e83e6a 100644 --- a/tests/autodiff/reverse-addr-eliminate.slang +++ b/tests/autodiff/reverse-addr-eliminate.slang @@ -4,6 +4,11 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; +struct D : IDifferentiable +{ + float n; + float m; +} struct C : IDifferentiable { float3 t; @@ -22,6 +27,7 @@ struct A : IDifferentiable float y; B fb; C aarr[3]; + D dv; }; [BackwardDifferentiable] @@ -33,7 +39,9 @@ A f(A a, int i) aout.x = aout.y + 5 * a.x; aout.aarr[1].t = float3(a.y, 0.0, a.x); aout.aarr[1].t = float3(a.y, 1.0, a.x + 1.0); - + D nd = { a.x * 4.0f, 1.0f }; + aout.dv = nd; + aout.dv.m = aout.dv.n * 0.5f; // Test that writes to a potentially dynamic address multiple times // is allowed and will propagate the correct derivative. aout.fb.arr[i].v = a.x * 2.0; // since this value is overwritten, the diff will not accumulate to a.x @@ -48,9 +56,9 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) var dpa = diffPair(a); - A.Differential dout = { 1.0, 1.0, { float2(0), { { float3(1.0), 1.0 }, { float3(1.0), 1.0 } } }, { { float3(1.0), 1.0 }, { float3(1.0), 1.0 }, { float3(1.0), 1.0 } } }; + A.Differential dout = { 1.0, 1.0, { float2(0), { { float3(1.0), 1.0 }, { float3(1.0), 1.0 } } }, { { float3(1.0), 1.0 }, { float3(1.0), 1.0 }, { float3(1.0), 1.0 } }, {1.0, 1.0} }; __bwd_diff(f)(dpa, 1, dout); - outputBuffer[0] = dpa.d.x; // Expect: 17 + outputBuffer[0] = dpa.d.x; // Expect: 23 outputBuffer[1] = dpa.d.y; // Expect: 0 } diff --git a/tests/autodiff/reverse-addr-eliminate.slang.expected.txt b/tests/autodiff/reverse-addr-eliminate.slang.expected.txt index dd367f3f5..fddc3120a 100644 --- a/tests/autodiff/reverse-addr-eliminate.slang.expected.txt +++ b/tests/autodiff/reverse-addr-eliminate.slang.expected.txt @@ -1,5 +1,5 @@ type: float -17.000000 +23.000000 1.000000 0.000000 0.000000 |
