From e004511b5f75bb24df1adec71b005146917afb39 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 23 Mar 2023 22:27:30 -0400 Subject: AD: Fix type checking for higher-order custom derivatives definitions (#2729) * Fixed type coercion issue with higher-order user defined methods * Placed associated type lookup method in a loop * Update high-order-user-defined-derivative.slang * Revert changes to associated type lookup method --- .../high-order-user-defined-derivative.slang | 77 ++++++++++++++++++++++ ...rder-user-defined-derivative.slang.expected.txt | 10 +++ 2 files changed, 87 insertions(+) create mode 100644 tests/autodiff/high-order-user-defined-derivative.slang create mode 100644 tests/autodiff/high-order-user-defined-derivative.slang.expected.txt (limited to 'tests') diff --git a/tests/autodiff/high-order-user-defined-derivative.slang b/tests/autodiff/high-order-user-defined-derivative.slang new file mode 100644 index 000000000..4ad4aad12 --- /dev/null +++ b/tests/autodiff/high-order-user-defined-derivative.slang @@ -0,0 +1,77 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +//TEST_INPUT:ubuffer(data=[0.0 1.0 2.0 3.0], stride=4):name=endpointBuffer +RWStructuredBuffer endpointBuffer; + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=endpointDifferentialBuffer +RWStructuredBuffer endpointDifferentialBuffer; + +struct LineSegment : IDifferentiable +{ + float x0; + float x1; + + [BackwardDifferentiable] + __init(float _x0, float _x1) + { + x0 = _x0; + x1 = _x1; + } +}; + +[BackwardDerivative(d_loadLineSegment)] +[ForwardDerivative(fwd_loadLineSegment)] +LineSegment loadLineSegment(uint id) +{ + return {endpointBuffer[id * 2], endpointBuffer[id * 2 + 1]}; +} + +[BackwardDerivative(d_fwd_loadLineSegment)] +DifferentialPair fwd_loadLineSegment(uint id) +{ + return DifferentialPair(loadLineSegment(id), LineSegment.dzero()); +} + +void d_loadLineSegment(uint id, LineSegment.Differential d_ls) +{ + endpointDifferentialBuffer[id * 2] += d_ls.x0; + endpointDifferentialBuffer[id * 2 + 1] += d_ls.x1; +} + +void d_fwd_loadLineSegment(uint id, DifferentialPair.Differential dp_ls) +{ + endpointDifferentialBuffer[id * 2] += dp_ls.p.x0; + endpointDifferentialBuffer[id * 2 + 1] += dp_ls.p.x1; +} + +[BackwardDifferentiable] +float something() +{ + LineSegment ls = __fwd_diff(loadLineSegment)(1).p; + return ls.x0 + ls.x1; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + LineSegment ls = __fwd_diff(loadLineSegment)(0).p; + outputBuffer[0] = ls.x0; // Expect: 0 + outputBuffer[1] = ls.x1; // Expect: 1 + } + + { + LineSegment.Differential d_ls = __fwd_diff(loadLineSegment)(0).d; + outputBuffer[2] = d_ls.x1; // Expect: 0 + } + + { + // Expect: 2.0 in endpointDifferentialBuffer[2] + // Expect: 2.0 in endpointDifferentialBuffer[3] + __bwd_diff(something)(2.0); + } +} \ No newline at end of file diff --git a/tests/autodiff/high-order-user-defined-derivative.slang.expected.txt b/tests/autodiff/high-order-user-defined-derivative.slang.expected.txt new file mode 100644 index 000000000..b5f259133 --- /dev/null +++ b/tests/autodiff/high-order-user-defined-derivative.slang.expected.txt @@ -0,0 +1,10 @@ +type: float +0.000000 +1.000000 +0.000000 +0.000000 +type: float +0.000000 +0.000000 +2.000000 +2.000000 -- cgit v1.2.3