diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-10-03 16:02:16 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-03 16:02:16 -0400 |
| commit | 9f246a43667b4893040669873400e2e3813328ff (patch) | |
| tree | f1fafe8c266b1db6f5f2cb76ab4fb7332cc2be54 /tests | |
| parent | aa64c853142076b17bd020f1386ea5fc6fcd5e3e (diff) | |
Support custom derivatives of member functions of differentiable types (#5124)
* Initial work to support custom derivatives for member methods of differentiable types
* Support custom derivatives of member functions of differentiable types
- Also adds support for declaring custom derivatives via extensions.
* Fix
* move defs
* Update slang-check-decl.cpp
* Create diff-member-func-custom-derivative.slang.expected.txt
* Update slang-check-decl.cpp
* Fix for static custom derivatives
* Fix diagnostics for [PreferRecompute]
* Add backward custom derivative tests
Diffstat (limited to 'tests')
6 files changed, 181 insertions, 0 deletions
diff --git a/tests/autodiff/diff-member-func-custom-derivative.slang b/tests/autodiff/diff-member-func-custom-derivative.slang new file mode 100644 index 000000000..4e4f540f9 --- /dev/null +++ b/tests/autodiff/diff-member-func-custom-derivative.slang @@ -0,0 +1,59 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct A : IDifferentiable +{ + float x; + + [ForwardDerivative(diff_f)] + float f(float v) + { + return v * v; + } + + static DifferentialPair<float> diff_f(DifferentialPair<A> dpa, DifferentialPair<float> v) + { + return diffPair(v.p * v.p, v.p * v.d * 2.0); + } + + [BackwardDerivative(diff_g)] + float g(float v) + { + return v * v; + } + + static void diff_g(inout DifferentialPair<A> dpa, inout DifferentialPair<float> v, float dOut) + { + v = diffPair(v.p, dOut * 2.0); + } +} + +[ForwardDifferentiable] +float test(A obj, float v) +{ + return obj.f(v); +} + +[BackwardDifferentiable] +float test2(A obj, float v) +{ + return obj.g(v); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + A a = {0.0}; + var p = diffPair(3.0, 1.0); + let rs = fwd_diff(test)(diffPair(a, {1.0}), p); + + var q = diffPair(3.0); + var qa = diffPair(a); + bwd_diff(test2)(qa, q, 1.0); + + outputBuffer[0] = rs.d; + outputBuffer[1] = q.d; +} diff --git a/tests/autodiff/diff-member-func-custom-derivative.slang.expected.txt b/tests/autodiff/diff-member-func-custom-derivative.slang.expected.txt new file mode 100644 index 000000000..1bb28547d --- /dev/null +++ b/tests/autodiff/diff-member-func-custom-derivative.slang.expected.txt @@ -0,0 +1,3 @@ +type: float +6.000000 +2.000000
\ No newline at end of file diff --git a/tests/autodiff/member-func-extension-custom-derivative.slang b/tests/autodiff/member-func-extension-custom-derivative.slang new file mode 100644 index 000000000..8752dfff5 --- /dev/null +++ b/tests/autodiff/member-func-extension-custom-derivative.slang @@ -0,0 +1,55 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct A +{ + float x; + + float f(float v) + { + return v * v; + } +} + +extension A +{ + [ForwardDerivativeOf(f)] + DifferentialPair<float> diff_f(DifferentialPair<float> v) + { + return diffPair(v.p * v.p, v.p * v.d * 2.0); + } +} + +struct Foo<T : IDifferentiable> +{ + T value; + T doThing() { return value; } +} + +extension<T : IDifferentiable> Foo<T> +{ + [ForwardDerivativeOf(doThing)] + DifferentialPair<T> diff_doThing() + { + return diffPair(value, T.dzero()); + } +} + + +[ForwardDifferentiable] +float test(Foo<float> obj, float v) +{ + return obj.doThing() * v; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + Foo<float> a = {0.0}; + var p = diffPair(3.0, 1.0); + let rs = __fwd_diff(test)(a, p); + outputBuffer[0] = rs.d; +} diff --git a/tests/autodiff/member-func-extension-custom-derivative.slang.expected.txt b/tests/autodiff/member-func-extension-custom-derivative.slang.expected.txt new file mode 100644 index 000000000..4b1f4c0d9 --- /dev/null +++ b/tests/autodiff/member-func-extension-custom-derivative.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +0.000000 diff --git a/tests/autodiff/static-func-custom-derivative.slang b/tests/autodiff/static-func-custom-derivative.slang new file mode 100644 index 000000000..b75012735 --- /dev/null +++ b/tests/autodiff/static-func-custom-derivative.slang @@ -0,0 +1,59 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct A : IDifferentiable +{ + float x; + + [ForwardDerivative(diff_f)] + static float f(float v) + { + return v * v; + } + + static DifferentialPair<float> diff_f(DifferentialPair<float> v) + { + return diffPair(v.p * v.p, v.p * v.d * 2.0); + } + + [BackwardDerivative(diff_g)] + static float g(float v) + { + return v * v; + } + + static void diff_g(inout DifferentialPair<float> v, float.Differential dOut) + { + v = diffPair(v.p, dOut * 2.0); + } +} + +[ForwardDifferentiable] +float test(A obj, float v) +{ + return obj.f(v); +} + +[BackwardDifferentiable] +float test2(A obj, float v) +{ + return obj.g(v); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + A a = {0.0}; + var p = diffPair(3.0, 1.0); + let rs = fwd_diff(test)(diffPair(a, {1.0}), p); + + var q = diffPair(3.0); + var qa = diffPair(a); + bwd_diff(test2)(qa, q, 1.0); + + outputBuffer[0] = rs.d; + outputBuffer[1] = q.d; +} diff --git a/tests/autodiff/static-func-custom-derivative.slang.expected.txt b/tests/autodiff/static-func-custom-derivative.slang.expected.txt new file mode 100644 index 000000000..1bb28547d --- /dev/null +++ b/tests/autodiff/static-func-custom-derivative.slang.expected.txt @@ -0,0 +1,3 @@ +type: float +6.000000 +2.000000
\ No newline at end of file |
