summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-03-24 19:50:51 -0400
committerGitHub <noreply@github.com>2023-03-24 16:50:51 -0700
commit7292edbd3eba3da7e8490ad19169a7d18283057a (patch)
treeb49fb1ba6a76d9775f788057d91b22b88b4fc19c
parente794de0d63e6de9be564c971fd40486ecf631293 (diff)
Added `[BackwardDifferentiable]` tags for intrinsic + builtin methods (#2732)
* Added higher-order differentiability decorators for built-ins + preliminary tests * Update diff.meta.slang
-rw-r--r--source/slang/diff.meta.slang58
-rw-r--r--tests/autodiff/high-order-builtins-1.slang47
-rw-r--r--tests/autodiff/high-order-builtins-1.slang.expected.txt5
-rw-r--r--tests/autodiff/high-order-builtins-2.slang47
-rw-r--r--tests/autodiff/high-order-builtins-2.slang.expected.txt5
5 files changed, 157 insertions, 5 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 26a673512..bbe94dbc2 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -292,6 +292,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
#define VECTOR_MATRIX_BINARY_DIFF_IMPL(NAME) \
__generic<T : __BuiltinFloatingPointType, let N : int> \
+ [BackwardDifferentiable] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<vector<T, N>> __d_##NAME##_vector( \
DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy) \
@@ -309,6 +310,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<vector<T, N>>(result, d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [BackwardDifferentiable] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \
DifferentialPair<matrix<T, M, N>> dpx, DifferentialPair<matrix<T, M, N>> dpy) \
@@ -327,6 +329,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<matrix<T, M, N>>(result, d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
+ [BackwardDifferentiable] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_vector( \
inout DifferentialPair<vector<T, N>> dpx, \
@@ -346,6 +349,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
dpy = diffPair(dpy.p, right_d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [BackwardDifferentiable] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_matrix( \
inout DifferentialPair<matrix<T, M, N>> dpx, \
@@ -368,6 +372,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
#define VECTOR_MATRIX_TERNARY_DIFF_IMPL(NAME) \
__generic<T : __BuiltinFloatingPointType, let N : int> \
+ [BackwardDifferentiable] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<vector<T, N>> __d_##NAME##_vector( \
DifferentialPair<vector<T, N>> dpx, \
@@ -388,8 +393,9 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<vector<T, N>>(result, d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [BackwardDifferentiable] \
[ForwardDerivativeOf(NAME)] \
- DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \
+ DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \
DifferentialPair<matrix<T, M, N>> dpx, \
DifferentialPair<matrix<T, M, N>> dpy, \
DifferentialPair<matrix<T, M, N>> dpz) \
@@ -409,6 +415,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<matrix<T, M, N>>(result, d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
+ [BackwardDifferentiable] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_vector( \
inout DifferentialPair<vector<T, N>> dpx, \
@@ -433,6 +440,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
dpz = diffPair(dpz.p, right_d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [BackwardDifferentiable] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_matrix( \
inout DifferentialPair<matrix<T, M, N>> dpx, \
@@ -460,12 +468,14 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
#define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \
__generic<T : __BuiltinFloatingPointType> \
+ [BackwardDifferentiable] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<T> __d_##NAME(DifferentialPair<T> dpx) \
{ \
return DifferentialPair<T>(NAME(dpx.p), FWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
+ [BackwardDifferentiable] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<vector<T, N>> __d_##NAME##_vector(DifferentialPair<vector<T, N>> dpx) \
{ \
@@ -481,6 +491,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<vector<T, N>>(result, d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [BackwardDifferentiable] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpx) \
{ \
@@ -498,12 +509,14 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<matrix<T, M, N>>(result, d_result); \
} \
__generic<T : __BuiltinFloatingPointType> \
+ [BackwardDifferentiable] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME(inout DifferentialPair<T> dpx, T.Differential dOut) \
{ \
dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
+ [BackwardDifferentiable] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_vector( \
inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \
@@ -518,6 +531,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
dpx = diffPair(dpx.p, d_result); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [BackwardDifferentiable] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_matrix( \
inout DifferentialPair<matrix<T, M, N>> dpx, matrix<T, M, N>.Differential dOut) \
@@ -581,6 +595,7 @@ SIMPLE_UNARY_DERIVATIVE_IMPL(atan, T(1.0) / (T(1.0) + dpx.p * dpx.p))
// Atan2
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[ForwardDerivativeOf(atan2)]
DifferentialPair<T> __d_atan2(DifferentialPair<T> dpy, DifferentialPair<T> dpx)
{
@@ -592,6 +607,7 @@ DifferentialPair<T> __d_atan2(DifferentialPair<T> dpy, DifferentialPair<T> dpx)
}
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[BackwardDerivativeOf(atan2)]
void __d_atan2(inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpx, T.Differential dOut)
{
@@ -603,12 +619,14 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(atan2)
// fmod
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[ForwardDerivativeOf(fmod)]
DifferentialPair<T> __d_fmod(DifferentialPair<T> x, DifferentialPair<T> y)
{
return DifferentialPair<T>(fmod(x.p, y.p), x.d);
}
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[BackwardDerivativeOf(fmod)]
void __d_fmod(inout DifferentialPair<T> x, inout DifferentialPair<T> y, T.Differential dOut)
{
@@ -619,6 +637,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(fmod)
// Raise to a power
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[ForwardDerivativeOf(pow)]
DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
@@ -638,6 +657,7 @@ DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
}
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[BackwardDerivativeOf(pow)]
void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
@@ -663,6 +683,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(pow)
// Maximum
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[ForwardDerivativeOf(max)]
DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
@@ -673,6 +694,7 @@ DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
}
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[BackwardDerivativeOf(max)]
void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
@@ -684,6 +706,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(max)
// Minimum
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[ForwardDerivativeOf(min)]
DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
@@ -694,6 +717,7 @@ DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
}
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[BackwardDerivativeOf(min)]
void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
@@ -705,6 +729,7 @@ VECTOR_MATRIX_BINARY_DIFF_IMPL(min)
// Lerp
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[ForwardDerivativeOf(lerp)]
DifferentialPair<T> __d_lerp(DifferentialPair<T> dpx, DifferentialPair<T> dpy, DifferentialPair<T> dps)
{
@@ -714,6 +739,7 @@ DifferentialPair<T> __d_lerp(DifferentialPair<T> dpx, DifferentialPair<T> dpy, D
);
}
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[BackwardDerivativeOf(lerp)]
void __d_lerp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dps, T.Differential dOut)
{
@@ -725,6 +751,7 @@ VECTOR_MATRIX_TERNARY_DIFF_IMPL(lerp)
// Clamp
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[ForwardDerivativeOf(clamp)]
DifferentialPair<T> __d_clamp(DifferentialPair<T> dpx, DifferentialPair<T> dpMin, DifferentialPair<T> dpMax)
{
@@ -733,6 +760,7 @@ DifferentialPair<T> __d_clamp(DifferentialPair<T> dpx, DifferentialPair<T> dpMin
dpx.p < dpMin.p ? (dpx.p > dpMax.p ? dpMax.d : dpx.d) : dpMin.d);
}
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[BackwardDerivativeOf(clamp)]
void __d_clamp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpMin, inout DifferentialPair<T> dpMax, T.Differential dOut)
{
@@ -743,6 +771,7 @@ void __d_clamp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpMin, i
VECTOR_MATRIX_TERNARY_DIFF_IMPL(clamp)
// fma
+[BackwardDifferentiable]
[ForwardDerivativeOf(fma)]
DifferentialPair<double> __d_fma(DifferentialPair<double> dpx, DifferentialPair<double> dpy, DifferentialPair<double> dpz)
{
@@ -750,6 +779,7 @@ DifferentialPair<double> __d_fma(DifferentialPair<double> dpx, DifferentialPair<
fma(dpx.p, dpy.p, dpz.p),
dpy.p * dpx.d + dpx.p * dpy.d + dpz.d);
}
+[BackwardDifferentiable]
[BackwardDerivativeOf(fma)]
void __d_fma(inout DifferentialPair<double> dpx, inout DifferentialPair<double> dpy, inout DifferentialPair<double> dpz, double dOut)
{
@@ -758,6 +788,7 @@ void __d_fma(inout DifferentialPair<double> dpx, inout DifferentialPair<double>
dpz = diffPair(dpz.p, dOut);
}
__generic<let N : int>
+[BackwardDifferentiable]
[ForwardDerivativeOf(fma)]
DifferentialPair<vector<double, N>> __d_fma_vector(
DifferentialPair<vector<double, N>> dpx,
@@ -778,6 +809,7 @@ DifferentialPair<vector<double, N>> __d_fma_vector(
return DifferentialPair<vector<double, N>>(result, d_result);
}
__generic<let N : int>
+[BackwardDifferentiable]
[BackwardDerivativeOf(fma)]
void __d_fma_vector(
inout DifferentialPair<vector<double, N>> dpx,
@@ -803,6 +835,7 @@ void __d_fma_vector(
// mad
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[ForwardDerivativeOf(mad)]
DifferentialPair<T> __d_mad(DifferentialPair<T> dpx, DifferentialPair<T> dpy, DifferentialPair<T> dpz)
{
@@ -811,6 +844,7 @@ DifferentialPair<T> __d_mad(DifferentialPair<T> dpx, DifferentialPair<T> dpy, Di
T.dadd(T.dadd(T.dmul(dpy.p, dpx.d), T.dmul(dpx.p, dpy.d)), dpz.d));
}
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[BackwardDerivativeOf(mad)]
void __d_mad(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpz, T.Differential dOut)
{
@@ -829,12 +863,14 @@ T __smoothstep_impl(T minVal, T maxVal, T x)
return t * t * (T(3.0) - T(2.0) * t);
}
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[ForwardDerivativeOf(smoothstep)]
DifferentialPair<T> __d_smoothstep(DifferentialPair<T> minVal, DifferentialPair<T> maxVal, DifferentialPair<T> x)
{
return __fwd_diff(__smoothstep_impl)(minVal, maxVal, x);
}
__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
[BackwardDerivativeOf(smoothstep)]
void __d_smoothstep(inout DifferentialPair<T> minVal, inout DifferentialPair<T> maxVal, inout DifferentialPair<T> x, T.Differential dOut)
{
@@ -856,6 +892,7 @@ T __length_impl(vector<T, N> x)
}
__generic<T: __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
[ForwardDerivativeOf(length)]
[ForceInline]
DifferentialPair<T> __d_length(DifferentialPair<vector<T, N>> x)
@@ -864,6 +901,7 @@ DifferentialPair<T> __d_length(DifferentialPair<vector<T, N>> x)
}
__generic<T: __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
[BackwardDerivativeOf(length)]
[ForceInline]
void __d_length(inout DifferentialPair<vector<T, N>> x, T.Differential dOut)
@@ -872,13 +910,14 @@ void __d_length(inout DifferentialPair<vector<T, N>> x, T.Differential dOut)
}
// Vector distance
-__generic<T : __BuiltinFloatingPointType, let N : int>
+__generic<T: __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
T __distance_impl(vector<T, N> x, vector<T, N> y)
{
return length(y - x);
}
__generic<T: __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
[ForwardDerivativeOf(distance)]
[ForceInline]
DifferentialPair<T> __d_distance(DifferentialPair<vector<T, N>> x, DifferentialPair<vector<T, N>> y)
@@ -887,6 +926,7 @@ DifferentialPair<T> __d_distance(DifferentialPair<vector<T, N>> x, DifferentialP
}
__generic<T: __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
[BackwardDerivativeOf(distance)]
[ForceInline]
void __d_distance(inout DifferentialPair<vector<T, N>> x, inout DifferentialPair<vector<T, N>> y, T.Differential dOut)
@@ -903,13 +943,15 @@ vector<T, N> __normalize_impl(vector<T, N> x)
return x * r;
}
__generic<T: __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
[ForwardDerivativeOf(normalize)]
[ForceInline]
DifferentialPair<vector<T, N>> __d_normalize(DifferentialPair<vector<T, N>> x)
{
return __fwd_diff(__normalize_impl)(x);
}
-__generic<T: __BuiltinFloatingPointType, let N : int>
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
[BackwardDerivativeOf(normalize)]
[ForceInline]
void __d_distance(inout DifferentialPair<vector<T, N>> x, vector<T, N>.Differential dOut)
@@ -924,14 +966,16 @@ vector<T, N> __reflect_impl(vector<T, N> i, vector<T, N> n)
{
return i - n * (T(2.0) * dot(i, n));
}
-__generic<T: __BuiltinFloatingPointType, let N : int>
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
[ForwardDerivativeOf(reflect)]
[ForceInline]
DifferentialPair<vector<T, N>> __d_reflect(DifferentialPair<vector<T, N>> i, DifferentialPair<vector<T, N>> n)
{
return __fwd_diff(__reflect_impl)(i, n);
}
-__generic<T: __BuiltinFloatingPointType, let N : int>
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
[BackwardDerivativeOf(reflect)]
[ForceInline]
void __d_reflect(inout DifferentialPair<vector<T, N>> i, inout DifferentialPair<vector<T, N>> n, vector<T, N>.Differential dOut)
@@ -948,6 +992,7 @@ vector<T, N> __refract_impl(vector<T, N> i, vector<T, N> n, T eta)
return (k < T(0.0)) ? vector<T, N>(T(0.0)) : eta * i - (eta * dot(n, i) + sqrt(max(T(0.0),k))) * n;
}
__generic<T: __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
[ForwardDerivativeOf(refract)]
[ForceInline]
DifferentialPair<vector<T, N>> __d_refract(DifferentialPair<vector<T, N>> i, DifferentialPair<vector<T, N>> n, DifferentialPair<T> eta)
@@ -955,6 +1000,7 @@ DifferentialPair<vector<T, N>> __d_refract(DifferentialPair<vector<T, N>> i, Dif
return __fwd_diff(__refract_impl)(i, n, eta);
}
__generic<T: __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
[BackwardDerivativeOf(refract)]
[ForceInline]
void __d_refract(inout DifferentialPair<vector<T, N>> i, inout DifferentialPair<vector<T, N>> n, inout DifferentialPair<T> eta, vector<T, N>.Differential dOut)
@@ -1053,6 +1099,7 @@ T __determinant_impl(matrix<T,N,N> m)
return result;
}
__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
[ForwardDerivativeOf(determinant)]
[ForceInline]
DifferentialPair<T> __determinant_impl(DifferentialPair<matrix<T,N,N>> m)
@@ -1060,6 +1107,7 @@ DifferentialPair<T> __determinant_impl(DifferentialPair<matrix<T,N,N>> m)
return __fwd_diff(__determinant_impl)(m);
}
__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
[BackwardDerivativeOf(determinant)]
[ForceInline]
void __d_determinant(inout DifferentialPair<matrix<T,N,N>> m, T.Differential dOut)
diff --git a/tests/autodiff/high-order-builtins-1.slang b/tests/autodiff/high-order-builtins-1.slang
new file mode 100644
index 000000000..6b0c33ca6
--- /dev/null
+++ b/tests/autodiff/high-order-builtins-1.slang
@@ -0,0 +1,47 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[BackwardDifferentiable]
+float f(float x)
+{
+ return x * x;
+}
+
+[BackwardDifferentiable]
+float outerF(float x)
+{
+ return f(sin(x));
+}
+
+[BackwardDifferentiable]
+float df(float x)
+{
+ return __fwd_diff(outerF)(DifferentialPair<float>(x, 1.0)).d; // 4*sin^3(x)
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // Given f(x) = sin^2(x),
+ // f'(x) = 2*sin(x)*cos(x)
+ // f''(x) = 2*cos^2(x) - 2*sin^2(x)
+ //
+
+ // Expect f''(4) = -0.291
+ {
+ var p = diffPair(4.0, 0.0);
+ __bwd_diff(df)(p, 1.0);
+ outputBuffer[0] = p.d;
+ }
+
+ // Expect f''(4) = -0.653643
+ {
+ var p = diffPair(2.0, 0.0);
+ __bwd_diff(df)(p, 0.5);
+ outputBuffer[1] = p.d;
+ }
+}
diff --git a/tests/autodiff/high-order-builtins-1.slang.expected.txt b/tests/autodiff/high-order-builtins-1.slang.expected.txt
new file mode 100644
index 000000000..4fa4ade6d
--- /dev/null
+++ b/tests/autodiff/high-order-builtins-1.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+-0.291000
+-0.653643
+0.000000
+0.000000
diff --git a/tests/autodiff/high-order-builtins-2.slang b/tests/autodiff/high-order-builtins-2.slang
new file mode 100644
index 000000000..b15dee7c7
--- /dev/null
+++ b/tests/autodiff/high-order-builtins-2.slang
@@ -0,0 +1,47 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[BackwardDifferentiable]
+float f(float x)
+{
+ return sin(x) + x * x;
+}
+
+[BackwardDifferentiable]
+float outerF(float x)
+{
+ return f(pow(x, 3));
+}
+
+[BackwardDifferentiable]
+float df(float x)
+{
+ return __fwd_diff(outerF)(DifferentialPair<float>(x, 1.0)).d; // x^3 + x^6
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // Given f(x) = sin(x^3) + x^6
+ // f'(x) = cos(x^3)*(3*x^2) + 6*x^5
+ // f''(x) = 6*cos(x^3)*x - 9*sin(x^3)*(x^4) + 30*x^4
+ //
+
+ // Expect f''(4) = -0.291
+ {
+ var p = diffPair(1.0, 0.0);
+ __bwd_diff(df)(p, 1.0);
+ outputBuffer[0] = p.d;
+ }
+
+ // Expect f''(4) = -0.653643
+ {
+ var p = diffPair(2.0, 0.0);
+ __bwd_diff(df)(p, 0.5);
+ outputBuffer[1] = p.d;
+ }
+}
diff --git a/tests/autodiff/high-order-builtins-2.slang.expected.txt b/tests/autodiff/high-order-builtins-2.slang.expected.txt
new file mode 100644
index 000000000..9f5d74b29
--- /dev/null
+++ b/tests/autodiff/high-order-builtins-2.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+25.668574
+167.893206
+0.000000
+0.000000