From 86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 8 Mar 2023 21:52:34 -0800 Subject: Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691) * Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He --- tests/autodiff/primal-substitute-2.slang | 34 ++++++++++++++ .../primal-substitute-2.slang.expected.txt | 6 +++ tests/autodiff/primal-substitute-3.slang | 52 ++++++++++++++++++++++ .../primal-substitute-3.slang.expected.txt | 6 +++ tests/autodiff/primal-substitute.slang | 27 +++++++++++ .../autodiff/primal-substitute.slang.expected.txt | 3 ++ 6 files changed, 128 insertions(+) create mode 100644 tests/autodiff/primal-substitute-2.slang create mode 100644 tests/autodiff/primal-substitute-2.slang.expected.txt create mode 100644 tests/autodiff/primal-substitute-3.slang create mode 100644 tests/autodiff/primal-substitute-3.slang.expected.txt create mode 100644 tests/autodiff/primal-substitute.slang create mode 100644 tests/autodiff/primal-substitute.slang.expected.txt (limited to 'tests') diff --git a/tests/autodiff/primal-substitute-2.slang b/tests/autodiff/primal-substitute-2.slang new file mode 100644 index 000000000..6c53f84a6 --- /dev/null +++ b/tests/autodiff/primal-substitute-2.slang @@ -0,0 +1,34 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +float original(float x) +{ + return x * x; +} + +[PrimalSubstituteOf(original)] +[BackwardDifferentiable] +float primalSubst(float x) +{ + return 2.0f * x * x; +} + +[BackwardDifferentiable] +float caller(float x) +{ + return original(x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var a = diffPair(3.0, 1.0); + __bwd_diff(caller)(a, 1.0); + outputBuffer[0] = a.d; // Expect: 12.0 + outputBuffer[1] = __fwd_diff(caller)(diffPair(3.0, 1.0)).p; // Expect: 18.0 + outputBuffer[2] = caller(3.0); // Expect: 9.0 +} diff --git a/tests/autodiff/primal-substitute-2.slang.expected.txt b/tests/autodiff/primal-substitute-2.slang.expected.txt new file mode 100644 index 000000000..ee60dfa22 --- /dev/null +++ b/tests/autodiff/primal-substitute-2.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +12.000000 +18.000000 +9.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/primal-substitute-3.slang b/tests/autodiff/primal-substitute-3.slang new file mode 100644 index 000000000..ab2899bdc --- /dev/null +++ b/tests/autodiff/primal-substitute-3.slang @@ -0,0 +1,52 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IFoo +{ + float doSomething(); +} + +struct A : IFoo +{ + float doSomething() + { + return 0.0f; + } +} + +float original(T p, float x) +{ + p.doSomething(); + return x * x; +} + +[PrimalSubstituteOf(original)] +[BackwardDifferentiable] +float primalSubst(T p, float x) +{ + return 2.0f * x * x; +} + +[BackwardDifferentiable] +float caller(IFoo d, float x) +{ + return original(d, x); +} + +//TEST_INPUT: type_conformance A:IFoo = 0 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject(dispatchThreadID.x, 0); // A + + var a = diffPair(3.0, 1.0); + __bwd_diff(caller)(obj, a, 1.0); + outputBuffer[0] = a.d; // Expect: 12.0 + outputBuffer[1] = __fwd_diff(caller)(obj, diffPair(3.0, 1.0)).p; // Expect: 18.0 + outputBuffer[2] = caller(obj, 3.0); // Expect: 9.0 +} diff --git a/tests/autodiff/primal-substitute-3.slang.expected.txt b/tests/autodiff/primal-substitute-3.slang.expected.txt new file mode 100644 index 000000000..ee60dfa22 --- /dev/null +++ b/tests/autodiff/primal-substitute-3.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +12.000000 +18.000000 +9.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/primal-substitute.slang b/tests/autodiff/primal-substitute.slang new file mode 100644 index 000000000..01f221f2a --- /dev/null +++ b/tests/autodiff/primal-substitute.slang @@ -0,0 +1,27 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +float original(float x) +{ + return x * x; +} + +[PrimalSubstituteOf(original)] +[BackwardDifferentiable] +float primalSubst(float x) +{ + return 2.0f * x * x; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var a = diffPair(3.0, 1.0); + __bwd_diff(original)(a, 1.0); + outputBuffer[0] = a.d; // Expect: 12.0 + outputBuffer[1] = __fwd_diff(original)(diffPair(3.0, 1.0)).p; // Expect: 18.0 +} diff --git a/tests/autodiff/primal-substitute.slang.expected.txt b/tests/autodiff/primal-substitute.slang.expected.txt new file mode 100644 index 000000000..af1b9f528 --- /dev/null +++ b/tests/autodiff/primal-substitute.slang.expected.txt @@ -0,0 +1,3 @@ +type: float +12.0 +18.0 -- cgit v1.2.3