diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-14 22:50:57 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-14 22:50:57 -0800 |
| commit | 1c9b33157322751c456bf7abbd386edccf4413c3 (patch) | |
| tree | bbe4d28172a839fb06ac4b9e8c983a619bf04842 /tests | |
| parent | 14fab67c5edd8eb697ffb10dbcc0467678521eef (diff) | |
Support custom backward derivative attribute. (#2594)
Diffstat (limited to 'tests')
6 files changed, 149 insertions, 0 deletions
diff --git a/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang b/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang new file mode 100644 index 000000000..bd0780174 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang @@ -0,0 +1,61 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + static float calc(float x); +} + +struct A : IInterface +{ + static float calc(float x) { return 1.0; } +}; + +struct B : IInterface +{ + static float calc(float x) { return 2.0; } +}; + +void dsqr<T:IInterface>(T obj, inout DifferentialPair<float> x, float dOut) +{ + float diff = 2.0 * x.p * dOut; + updateDiff(x, diff); +} + +[BackwardDerivative(dsqr)] +float sqr<T:IInterface>(T obj, float x) +{ + return no_diff(obj.calc(x)) + x * x; +} + +// Use automatically differentiated outer function to triger the primal/propagate func generation logic +// on a function that has user provided backward derivative. +[BackwardDifferentiable] +float sqr_outter<T:IInterface>(T obj, float x) +{ + return sqr(obj, x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A + var p = DifferentialPair<float>(2.0, 1.0); + __bwd_diff(sqr_outter)(obj, p, 1.0); // A.calc, expect 4 + outputBuffer[0] = p.d; + + obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B + p = DifferentialPair<float>(1.5, 1.0); + __bwd_diff(sqr)(obj, p, 1.0); // A.calc, expect 4 + outputBuffer[1] = p.d; // B.calc, expect 3 +} diff --git a/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang.expected.txt b/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang.expected.txt new file mode 100644 index 000000000..780ba6ed4 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +4.000000 +3.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang b/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang new file mode 100644 index 000000000..930c1c82b --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang @@ -0,0 +1,53 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + static float calc(float x); +} + +struct A : IInterface +{ + static float calc(float x) { return 1.0; } +}; + +struct B : IInterface +{ + static float calc(float x) { return 2.0; } +}; + +DifferentialPair<float> dsqr<T:IInterface>(T obj, DifferentialPair<float> x) +{ + float primal = obj.calc(x.p) + x.p * x.p; + float diff = 2.0 * x.p * x.d; + return diffPair(primal, diff); +} + +[ForwardDerivative(dsqr)] +float sqr<T:IInterface>(T obj, float x) +{ + return no_diff(obj.calc(x)) + x * x; +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A + var p = DifferentialPair<float>(2.0, 1.0); + + outputBuffer[0] = __fwd_diff(sqr)(obj, p).d; // A.calc, expect 4 + + obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B + p = DifferentialPair<float>(1.5, 1.0); + outputBuffer[1] = __fwd_diff(sqr)(obj, p).d; // B.calc, expect 3 +} diff --git a/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang.expected.txt b/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang.expected.txt new file mode 100644 index 000000000..780ba6ed4 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +4.000000 +3.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/language-server/robustness-6.slang b/tests/language-server/robustness-6.slang new file mode 100644 index 000000000..ef5924cf3 --- /dev/null +++ b/tests/language-server/robustness-6.slang @@ -0,0 +1,10 @@ +//TEST:LANG_SERVER: +//HOVER:4,8 + +float dsqr<T:II + +[ForwardDerivative(dsqr)] +float sqr<T:IInterface>(T obj, float x) +{ + return no_diff(obj.calc(x)) + x * x; +} diff --git a/tests/language-server/robustness-6.slang.expected.txt b/tests/language-server/robustness-6.slang.expected.txt new file mode 100644 index 000000000..d5aa6c8c9 --- /dev/null +++ b/tests/language-server/robustness-6.slang.expected.txt @@ -0,0 +1,13 @@ +-------- +range: 3,6 - 3,10 +content: +``` +func dsqr<T>(T obj, float x) -> float +``` + +TEST:LANG_SERVER: +HOVER:4,8 + +{REDACTED}.slang(4) + + |
