diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-11-21 10:29:57 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-21 10:29:57 -0500 |
| commit | 545de51298ddda52ac51ded03ad489c98bdda397 (patch) | |
| tree | def78374f743d2c722fbde45eba60951a6f5c8f9 /tests | |
| parent | d58e08f8237a1888ceaad53402d534679ea83b1a (diff) | |
WIP: Fixed inout struct and added testing for calls to non-differentiable functions (#2505)
* Added non-differentiable call test
* Extended testing for nondifferentiable calls
* Fixed subtle issue with extensions on generic types not applying the correct substitutions, leading to unspecialized generics at the emit stage
* More fixes. inout struct params now work fine
* Update inout-struct-parameters-jvp.slang
* Update slang-ir.cpp
* Fixed hoisting lookup_interface_method
* Fixed non-diff call return value
* Fixed issue with phi nodes
* Fixed problem with IRSpecialize preventing hoisitng of DifferentialPairType
* Fixed non-diff call test to conform to the new 'no_diff' system
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/inout-struct-parameters-jvp.slang | 41 | ||||
| -rw-r--r-- | tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt | 5 | ||||
| -rw-r--r-- | tests/autodiff/nondiff-call.slang | 66 | ||||
| -rw-r--r-- | tests/autodiff/nondiff-call.slang.expected.txt | 6 |
4 files changed, 118 insertions, 0 deletions
diff --git a/tests/autodiff/inout-struct-parameters-jvp.slang b/tests/autodiff/inout-struct-parameters-jvp.slang new file mode 100644 index 000000000..80ff57b7d --- /dev/null +++ b/tests/autodiff/inout-struct-parameters-jvp.slang @@ -0,0 +1,41 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; + +struct A : IDifferentiable +{ + float p; + float3 q; +} + +[ForwardDifferentiable] +void g(A a, inout A aout) +{ + float t = a.p + a.q.y * a.q.x; + aout.p = aout.p + t; + aout.q = aout.q * t; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float p = 1.0; + float3 q = float3(1.0, 2.0, 3.0); + + float dp = 1.0; + float3 dq = float3(1.0, 0.5, 0.25); + + DifferentialPair<A> dpa = DifferentialPair<A>({p, q}, {dp, dq}); + + __fwd_diff(g)(DifferentialPair<A>( { p, q }, { dp, dq }), dpa); + + outputBuffer[0] = dpa.p.p; // Expect: 4.0 + outputBuffer[1] = dpa.d.q.x; // Expect: 6.5 + outputBuffer[2] = dpa.d.q.y; // Expect: 8.5 + outputBuffer[3] = dpa.d.q.z; // Expect: 11.25 + +}
\ No newline at end of file diff --git a/tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt b/tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt new file mode 100644 index 000000000..4cc3c313d --- /dev/null +++ b/tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +4.000000 +6.500000 +8.500000 +11.25000 diff --git a/tests/autodiff/nondiff-call.slang b/tests/autodiff/nondiff-call.slang new file mode 100644 index 000000000..d62de1b78 --- /dev/null +++ b/tests/autodiff/nondiff-call.slang @@ -0,0 +1,66 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef DifferentialPair<float3> dpfloat3; + +[ForwardDifferentiable] +float f(float x) +{ + return x * x + x * x * x; +} + +[ForwardDifferentiable] +float f2(float x) +{ + return f(x); +} + +float g(float x) +{ + return x * x + x * x * x; +} + +[ForwardDifferentiable] +float g2(float x) +{ + return no_diff(g(x)); +} + +struct A +{ + float o; + + [ForwardDifferentiable] + float doSomethingDifferentiable(float b) + { + return o + b; + } + + float doSomethingNotDifferentiable(float b) + { + return o * b; + } +} + +[ForwardDifferentiable] +float h2(A a, float k) +{ + float v = k * k; + return no_diff(a.doSomethingNotDifferentiable(k)) + v; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + outputBuffer[0] = f2(1.0); // Expect: 2.0 + outputBuffer[1] = __fwd_diff(f2)(dpfloat(1.0, 1.0)).d; // Expect: 5.0 + outputBuffer[2] = __fwd_diff(f2)(dpfloat(1.0, 1.0)).p; // Expect: 2.0 + outputBuffer[3] = __fwd_diff(g2)(dpfloat(1.0, 1.0)).d; // Expect: 0.0 + outputBuffer[4] = __fwd_diff(h2)({1.0}, DifferentialPair<float>(1.0, 2.0)).d; // Expect: 4.0 + } +} diff --git a/tests/autodiff/nondiff-call.slang.expected.txt b/tests/autodiff/nondiff-call.slang.expected.txt new file mode 100644 index 000000000..8f85913bc --- /dev/null +++ b/tests/autodiff/nondiff-call.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +2.000000 +5.000000 +2.000000 +0.000000 +4.000000
\ No newline at end of file |
