From 9f246a43667b4893040669873400e2e3813328ff Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:02:16 -0400 Subject: Support custom derivatives of member functions of differentiable types (#5124) * Initial work to support custom derivatives for member methods of differentiable types * Support custom derivatives of member functions of differentiable types - Also adds support for declaring custom derivatives via extensions. * Fix * move defs * Update slang-check-decl.cpp * Create diff-member-func-custom-derivative.slang.expected.txt * Update slang-check-decl.cpp * Fix for static custom derivatives * Fix diagnostics for [PreferRecompute] * Add backward custom derivative tests --- .../diff-member-func-custom-derivative.slang | 59 ++++++++++++++++++++++ ...ember-func-custom-derivative.slang.expected.txt | 3 ++ .../member-func-extension-custom-derivative.slang | 55 ++++++++++++++++++++ ...-extension-custom-derivative.slang.expected.txt | 2 + tests/autodiff/static-func-custom-derivative.slang | 59 ++++++++++++++++++++++ ...tatic-func-custom-derivative.slang.expected.txt | 3 ++ 6 files changed, 181 insertions(+) create mode 100644 tests/autodiff/diff-member-func-custom-derivative.slang create mode 100644 tests/autodiff/diff-member-func-custom-derivative.slang.expected.txt create mode 100644 tests/autodiff/member-func-extension-custom-derivative.slang create mode 100644 tests/autodiff/member-func-extension-custom-derivative.slang.expected.txt create mode 100644 tests/autodiff/static-func-custom-derivative.slang create mode 100644 tests/autodiff/static-func-custom-derivative.slang.expected.txt (limited to 'tests') diff --git a/tests/autodiff/diff-member-func-custom-derivative.slang b/tests/autodiff/diff-member-func-custom-derivative.slang new file mode 100644 index 000000000..4e4f540f9 --- /dev/null +++ b/tests/autodiff/diff-member-func-custom-derivative.slang @@ -0,0 +1,59 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +struct A : IDifferentiable +{ + float x; + + [ForwardDerivative(diff_f)] + float f(float v) + { + return v * v; + } + + static DifferentialPair diff_f(DifferentialPair dpa, DifferentialPair v) + { + return diffPair(v.p * v.p, v.p * v.d * 2.0); + } + + [BackwardDerivative(diff_g)] + float g(float v) + { + return v * v; + } + + static void diff_g(inout DifferentialPair dpa, inout DifferentialPair v, float dOut) + { + v = diffPair(v.p, dOut * 2.0); + } +} + +[ForwardDifferentiable] +float test(A obj, float v) +{ + return obj.f(v); +} + +[BackwardDifferentiable] +float test2(A obj, float v) +{ + return obj.g(v); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + A a = {0.0}; + var p = diffPair(3.0, 1.0); + let rs = fwd_diff(test)(diffPair(a, {1.0}), p); + + var q = diffPair(3.0); + var qa = diffPair(a); + bwd_diff(test2)(qa, q, 1.0); + + outputBuffer[0] = rs.d; + outputBuffer[1] = q.d; +} diff --git a/tests/autodiff/diff-member-func-custom-derivative.slang.expected.txt b/tests/autodiff/diff-member-func-custom-derivative.slang.expected.txt new file mode 100644 index 000000000..1bb28547d --- /dev/null +++ b/tests/autodiff/diff-member-func-custom-derivative.slang.expected.txt @@ -0,0 +1,3 @@ +type: float +6.000000 +2.000000 \ No newline at end of file diff --git a/tests/autodiff/member-func-extension-custom-derivative.slang b/tests/autodiff/member-func-extension-custom-derivative.slang new file mode 100644 index 000000000..8752dfff5 --- /dev/null +++ b/tests/autodiff/member-func-extension-custom-derivative.slang @@ -0,0 +1,55 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +struct A +{ + float x; + + float f(float v) + { + return v * v; + } +} + +extension A +{ + [ForwardDerivativeOf(f)] + DifferentialPair diff_f(DifferentialPair v) + { + return diffPair(v.p * v.p, v.p * v.d * 2.0); + } +} + +struct Foo +{ + T value; + T doThing() { return value; } +} + +extension Foo +{ + [ForwardDerivativeOf(doThing)] + DifferentialPair diff_doThing() + { + return diffPair(value, T.dzero()); + } +} + + +[ForwardDifferentiable] +float test(Foo obj, float v) +{ + return obj.doThing() * v; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + Foo a = {0.0}; + var p = diffPair(3.0, 1.0); + let rs = __fwd_diff(test)(a, p); + outputBuffer[0] = rs.d; +} diff --git a/tests/autodiff/member-func-extension-custom-derivative.slang.expected.txt b/tests/autodiff/member-func-extension-custom-derivative.slang.expected.txt new file mode 100644 index 000000000..4b1f4c0d9 --- /dev/null +++ b/tests/autodiff/member-func-extension-custom-derivative.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +0.000000 diff --git a/tests/autodiff/static-func-custom-derivative.slang b/tests/autodiff/static-func-custom-derivative.slang new file mode 100644 index 000000000..b75012735 --- /dev/null +++ b/tests/autodiff/static-func-custom-derivative.slang @@ -0,0 +1,59 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +struct A : IDifferentiable +{ + float x; + + [ForwardDerivative(diff_f)] + static float f(float v) + { + return v * v; + } + + static DifferentialPair diff_f(DifferentialPair v) + { + return diffPair(v.p * v.p, v.p * v.d * 2.0); + } + + [BackwardDerivative(diff_g)] + static float g(float v) + { + return v * v; + } + + static void diff_g(inout DifferentialPair v, float.Differential dOut) + { + v = diffPair(v.p, dOut * 2.0); + } +} + +[ForwardDifferentiable] +float test(A obj, float v) +{ + return obj.f(v); +} + +[BackwardDifferentiable] +float test2(A obj, float v) +{ + return obj.g(v); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + A a = {0.0}; + var p = diffPair(3.0, 1.0); + let rs = fwd_diff(test)(diffPair(a, {1.0}), p); + + var q = diffPair(3.0); + var qa = diffPair(a); + bwd_diff(test2)(qa, q, 1.0); + + outputBuffer[0] = rs.d; + outputBuffer[1] = q.d; +} diff --git a/tests/autodiff/static-func-custom-derivative.slang.expected.txt b/tests/autodiff/static-func-custom-derivative.slang.expected.txt new file mode 100644 index 000000000..1bb28547d --- /dev/null +++ b/tests/autodiff/static-func-custom-derivative.slang.expected.txt @@ -0,0 +1,3 @@ +type: float +6.000000 +2.000000 \ No newline at end of file -- cgit v1.2.3