From c6e6b7a9177bf4f7fc2f05da36c5952979006d78 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 4 Nov 2022 09:36:23 -0700 Subject: Higher order differentiation. (#2487) Co-authored-by: Yong He --- tests/autodiff/generic-impl-jvp.slang | 6 +++--- tests/autodiff/getter-setter-multi.slang | 2 +- tests/autodiff/getter-setter.slang | 2 +- tests/autodiff/high-order-forward-diff.slang | 17 +++++++++++++---- .../autodiff/high-order-forward-diff.slang.expected.txt | 5 +++++ 5 files changed, 23 insertions(+), 9 deletions(-) create mode 100644 tests/autodiff/high-order-forward-diff.slang.expected.txt (limited to 'tests') diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index 511e0b0d8..7f4c4313e 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -8,8 +8,8 @@ typedef float Real; typealias IDFloat = IFloat & IDifferentiable; -__generic -struct dvector +__generic +struct dvector : IDifferentiable { T values[N]; }; @@ -139,7 +139,7 @@ DifferentialPair dot_jvp(dpvector a, dpvector b) } __generic -struct lineardvector +struct lineardvector : IDifferentiable { myvector.Differential val; diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang index 3bf208e02..85b6a3c63 100644 --- a/tests/autodiff/getter-setter-multi.slang +++ b/tests/autodiff/getter-setter-multi.slang @@ -4,7 +4,7 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; -struct B +struct B : IDifferentiable { float3 z; float.Differential k[10]; diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang index 5842654b5..a9e01b8c6 100644 --- a/tests/autodiff/getter-setter.slang +++ b/tests/autodiff/getter-setter.slang @@ -4,7 +4,7 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; -struct B +struct B : IDifferentiable { float z; }; diff --git a/tests/autodiff/high-order-forward-diff.slang b/tests/autodiff/high-order-forward-diff.slang index fde659227..94b4d2a0d 100644 --- a/tests/autodiff/high-order-forward-diff.slang +++ b/tests/autodiff/high-order-forward-diff.slang @@ -1,15 +1,21 @@ -//DTEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//DTEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; [ForwardDifferentiable] -float f(float x) +float mySqr(float x) { return x * x; } +[ForwardDifferentiable] +float f(float x) +{ + return mySqr(x * x); +} + [ForwardDifferentiable] float df(float x) { @@ -19,5 +25,8 @@ float df(float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - outputBuffer[0] = __fwd_diff(df)(DifferentialPair(1.0, 1.0)).d(); // Expect: 2.0 + // Given f(x) = x^4, + // f''(x) = 12 * x^2 + // Expect f''(4) = 192 + outputBuffer[0] = __fwd_diff(df)(DifferentialPair(4.0, 1.0)).d(); } diff --git a/tests/autodiff/high-order-forward-diff.slang.expected.txt b/tests/autodiff/high-order-forward-diff.slang.expected.txt new file mode 100644 index 000000000..0f08247f0 --- /dev/null +++ b/tests/autodiff/high-order-forward-diff.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +192.000000 +0.000000 +0.000000 +0.000000 -- cgit v1.2.3