diff options
| author | winmad <winmad.wlf@gmail.com> | 2023-09-07 15:43:58 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-07 15:43:58 -0700 |
| commit | 6804680b7da2133f04513293836f70ff61d03b77 (patch) | |
| tree | 68092223595fac8c8c2ac6dcb3b305842b7f0516 /tests | |
| parent | 269282fd3647f1b201d2aae4c82b0c0af16c6420 (diff) | |
Fix erroneous diagnostic when checking a generic differentiable mutable method. (#3192)
* Add a unit test
* Fix.
---------
Co-authored-by: Lifan Wu <lifanw@nvidia.com>
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/struct-this-parameter.slang | 50 | ||||
| -rw-r--r-- | tests/autodiff/struct-this-parameter.slang.expected.txt | 4 |
2 files changed, 54 insertions, 0 deletions
diff --git a/tests/autodiff/struct-this-parameter.slang b/tests/autodiff/struct-this-parameter.slang new file mode 100644 index 000000000..9c8ddc724 --- /dev/null +++ b/tests/autodiff/struct-this-parameter.slang @@ -0,0 +1,50 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct A : IDifferentiable +{ + float data[3]; + + [Differentiable] + __init() + { + [ForceUnroll] for (uint i = 0; i < 3; i++) data[i] = 0.f; + } + + [mutating] + [Differentiable] + void write(float value, inout uint offset) + { + data[offset++] = value; + } + + [mutating] + [Differentiable] + void write<let N : int>(vector<float, N> value, inout uint offset) + { + [ForceUnroll] for (uint i = 0; i < N; i++) write(value[i], offset); + } +} + +[Differentiable] +float3 run(float3 x) +{ + A a = A(); + uint offset = 0; + a.write(x, offset); + return float3(a.data[0], a.data[1], a.data[2]); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + DifferentialPair<float3> dp = diffPair(float3(1.f, 2.f, 3.f), float3(0.f)); + float3 dOut = float3(1.f, 0.5f, 0.1f); + bwd_diff(run)(dp, dOut); + outputBuffer[0] = dp.d[0]; + outputBuffer[1] = dp.d[1]; + outputBuffer[2] = dp.d[2]; +} diff --git a/tests/autodiff/struct-this-parameter.slang.expected.txt b/tests/autodiff/struct-this-parameter.slang.expected.txt new file mode 100644 index 000000000..bd23befb3 --- /dev/null +++ b/tests/autodiff/struct-this-parameter.slang.expected.txt @@ -0,0 +1,4 @@ +type: float +1.000000 +0.500000 +0.100000 |
