diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-08 21:52:34 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-08 21:52:34 -0800 |
| commit | 86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (patch) | |
| tree | b4f9eb6cb1eea88145fde0bd1f670a8803120257 /tests | |
| parent | 257733f328f38a763c8b0c8830ff4c0d34ec9491 (diff) | |
Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691)
* Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`.
* Fix
* Fix.
* Cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/primal-substitute-2.slang | 34 | ||||
| -rw-r--r-- | tests/autodiff/primal-substitute-2.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/primal-substitute-3.slang | 52 | ||||
| -rw-r--r-- | tests/autodiff/primal-substitute-3.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/primal-substitute.slang | 27 | ||||
| -rw-r--r-- | tests/autodiff/primal-substitute.slang.expected.txt | 3 |
6 files changed, 128 insertions, 0 deletions
diff --git a/tests/autodiff/primal-substitute-2.slang b/tests/autodiff/primal-substitute-2.slang new file mode 100644 index 000000000..6c53f84a6 --- /dev/null +++ b/tests/autodiff/primal-substitute-2.slang @@ -0,0 +1,34 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +float original(float x) +{ + return x * x; +} + +[PrimalSubstituteOf(original)] +[BackwardDifferentiable] +float primalSubst(float x) +{ + return 2.0f * x * x; +} + +[BackwardDifferentiable] +float caller(float x) +{ + return original(x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var a = diffPair(3.0, 1.0); + __bwd_diff(caller)(a, 1.0); + outputBuffer[0] = a.d; // Expect: 12.0 + outputBuffer[1] = __fwd_diff(caller)(diffPair(3.0, 1.0)).p; // Expect: 18.0 + outputBuffer[2] = caller(3.0); // Expect: 9.0 +} diff --git a/tests/autodiff/primal-substitute-2.slang.expected.txt b/tests/autodiff/primal-substitute-2.slang.expected.txt new file mode 100644 index 000000000..ee60dfa22 --- /dev/null +++ b/tests/autodiff/primal-substitute-2.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +12.000000 +18.000000 +9.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/primal-substitute-3.slang b/tests/autodiff/primal-substitute-3.slang new file mode 100644 index 000000000..ab2899bdc --- /dev/null +++ b/tests/autodiff/primal-substitute-3.slang @@ -0,0 +1,52 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +interface IFoo +{ + float doSomething(); +} + +struct A : IFoo +{ + float doSomething() + { + return 0.0f; + } +} + +float original<T : IFoo>(T p, float x) +{ + p.doSomething(); + return x * x; +} + +[PrimalSubstituteOf(original)] +[BackwardDifferentiable] +float primalSubst<T : IFoo>(T p, float x) +{ + return 2.0f * x * x; +} + +[BackwardDifferentiable] +float caller(IFoo d, float x) +{ + return original(d, x); +} + +//TEST_INPUT: type_conformance A:IFoo = 0 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IFoo>(dispatchThreadID.x, 0); // A + + var a = diffPair(3.0, 1.0); + __bwd_diff(caller)(obj, a, 1.0); + outputBuffer[0] = a.d; // Expect: 12.0 + outputBuffer[1] = __fwd_diff(caller)(obj, diffPair(3.0, 1.0)).p; // Expect: 18.0 + outputBuffer[2] = caller(obj, 3.0); // Expect: 9.0 +} diff --git a/tests/autodiff/primal-substitute-3.slang.expected.txt b/tests/autodiff/primal-substitute-3.slang.expected.txt new file mode 100644 index 000000000..ee60dfa22 --- /dev/null +++ b/tests/autodiff/primal-substitute-3.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +12.000000 +18.000000 +9.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/primal-substitute.slang b/tests/autodiff/primal-substitute.slang new file mode 100644 index 000000000..01f221f2a --- /dev/null +++ b/tests/autodiff/primal-substitute.slang @@ -0,0 +1,27 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +float original(float x) +{ + return x * x; +} + +[PrimalSubstituteOf(original)] +[BackwardDifferentiable] +float primalSubst(float x) +{ + return 2.0f * x * x; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var a = diffPair(3.0, 1.0); + __bwd_diff(original)(a, 1.0); + outputBuffer[0] = a.d; // Expect: 12.0 + outputBuffer[1] = __fwd_diff(original)(diffPair(3.0, 1.0)).p; // Expect: 18.0 +} diff --git a/tests/autodiff/primal-substitute.slang.expected.txt b/tests/autodiff/primal-substitute.slang.expected.txt new file mode 100644 index 000000000..af1b9f528 --- /dev/null +++ b/tests/autodiff/primal-substitute.slang.expected.txt @@ -0,0 +1,3 @@ +type: float +12.0 +18.0 |
