From 6804680b7da2133f04513293836f70ff61d03b77 Mon Sep 17 00:00:00 2001 From: winmad Date: Thu, 7 Sep 2023 15:43:58 -0700 Subject: Fix erroneous diagnostic when checking a generic differentiable mutable method. (#3192) * Add a unit test * Fix. --------- Co-authored-by: Lifan Wu Co-authored-by: Yong He --- tests/autodiff/struct-this-parameter.slang | 50 ++++++++++++++++++++++ .../struct-this-parameter.slang.expected.txt | 4 ++ 2 files changed, 54 insertions(+) create mode 100644 tests/autodiff/struct-this-parameter.slang create mode 100644 tests/autodiff/struct-this-parameter.slang.expected.txt (limited to 'tests') diff --git a/tests/autodiff/struct-this-parameter.slang b/tests/autodiff/struct-this-parameter.slang new file mode 100644 index 000000000..9c8ddc724 --- /dev/null +++ b/tests/autodiff/struct-this-parameter.slang @@ -0,0 +1,50 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +struct A : IDifferentiable +{ + float data[3]; + + [Differentiable] + __init() + { + [ForceUnroll] for (uint i = 0; i < 3; i++) data[i] = 0.f; + } + + [mutating] + [Differentiable] + void write(float value, inout uint offset) + { + data[offset++] = value; + } + + [mutating] + [Differentiable] + void write(vector value, inout uint offset) + { + [ForceUnroll] for (uint i = 0; i < N; i++) write(value[i], offset); + } +} + +[Differentiable] +float3 run(float3 x) +{ + A a = A(); + uint offset = 0; + a.write(x, offset); + return float3(a.data[0], a.data[1], a.data[2]); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + DifferentialPair dp = diffPair(float3(1.f, 2.f, 3.f), float3(0.f)); + float3 dOut = float3(1.f, 0.5f, 0.1f); + bwd_diff(run)(dp, dOut); + outputBuffer[0] = dp.d[0]; + outputBuffer[1] = dp.d[1]; + outputBuffer[2] = dp.d[2]; +} diff --git a/tests/autodiff/struct-this-parameter.slang.expected.txt b/tests/autodiff/struct-this-parameter.slang.expected.txt new file mode 100644 index 000000000..bd23befb3 --- /dev/null +++ b/tests/autodiff/struct-this-parameter.slang.expected.txt @@ -0,0 +1,4 @@ +type: float +1.000000 +0.500000 +0.100000 -- cgit v1.2.3