summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorwinmad <winmad.wlf@gmail.com>2023-02-07 22:30:00 -0800
committerGitHub <noreply@github.com>2023-02-07 22:30:00 -0800
commitb1d7dc0707406da69d94f3915fa48f39020b93ab (patch)
treebb571884616aaf79fa663aec1eb8fcf92f5138cc
parent4be623c52a6518eb86756a0369706c1d6670f6bb (diff)
Add backward derivatives for functions in diff.meta.slang (#2633)
* WIP: start adding backward derivatives * Overhaul `transposeParameterBlock` to support `inout` params. * Small bug fixes. * Bug fix on differentiable intrinsic specialization. * Fixes. * Run autodiff tests on CPU. * Clean up. * Overhaul `transposeParameterBlock` to support `inout` params. * Small bug fixes. * Bug fix on differentiable intrinsic specialization. * Fixes. * Run autodiff tests on CPU. * Clean up. * More bug fixes., * WIP: working on detach * Arithmetic simplifications and more IR clean up logic. * WIP: adding detach and abs * Fix detach and abs * Fix. * Add IR transform pass for cleaner code emit. * Fix test cases. * Fix type system logic for reference type. * Add backward derivatives for functions that already have forward derivatives * Fix changes --------- Co-authored-by: Yong He <yhe@nvidia.com> Co-authored-by: Lifan Wu <lifanw@nvidia.com>
-rw-r--r--source/slang/diff.meta.slang241
-rw-r--r--source/slang/hlsl.meta.slang4
-rw-r--r--tests/autodiff-dstdlib/dstdlib-abs.slang28
-rw-r--r--tests/autodiff-dstdlib/dstdlib-abs.slang.expected.txt6
-rw-r--r--tests/autodiff-dstdlib/dstdlib-detach.slang33
-rw-r--r--tests/autodiff-dstdlib/dstdlib-detach.slang.expected.txt3
-rw-r--r--tests/autodiff-dstdlib/dstdlib-dot.slang29
-rw-r--r--tests/autodiff-dstdlib/dstdlib-dot.slang.expected.txt7
-rw-r--r--tests/autodiff-dstdlib/dstdlib-log.slang37
-rw-r--r--tests/autodiff-dstdlib/dstdlib-log.slang.expected.txt4
-rw-r--r--tests/autodiff-dstdlib/dstdlib-max.slang20
-rw-r--r--tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt4
-rw-r--r--tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang30
-rw-r--r--tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang.expected.txt9
-rw-r--r--tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang35
-rw-r--r--tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang.expected.txt13
-rw-r--r--tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang35
-rw-r--r--tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang.expected.txt13
-rw-r--r--tests/autodiff-dstdlib/dstdlib-pow.slang20
-rw-r--r--tests/autodiff-dstdlib/dstdlib-pow.slang.expected.txt6
-rw-r--r--tests/autodiff-dstdlib/dstdlib-sqrt.slang17
-rw-r--r--tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt4
-rw-r--r--tests/autodiff-dstdlib/dstdlib-unary.slang18
-rw-r--r--tests/autodiff-dstdlib/dstdlib-unary.slang.expected.txt3
-rw-r--r--tests/autodiff/reverse-inout-param.slang2
25 files changed, 581 insertions, 40 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 055c44135..af06e6bac 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -173,6 +173,27 @@ DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, Differen
return DifferentialPair<vector<T,M>>(primal, diff);
}
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
+[ForceInline]
+[BackwardDerivativeOf(mul)]
+void __d_mul(inout DifferentialPair<vector<T, N>> left, inout DifferentialPair<matrix<T, N, M>> right, vector<T, M>.Differential dOut)
+{
+ vector<T, N>.Differential left_d_result;
+ matrix<T, N, M>.Differential right_d_result;
+ for (int i = 0; i < N; ++i)
+ {
+ T sum = T(0);
+ for (int j = 0; j < M; ++j)
+ {
+ sum += right.p[i][j] * dOut[j];
+ right_d_result[i][j] = left.p[i] * dOut[j];
+ }
+ left_d_result[i] = sum;
+ }
+ left = diffPair(left.p, left_d_result);
+ right = diffPair(right.p, right_d_result);
+}
+
// matrix-vector
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
@@ -184,18 +205,68 @@ DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, Differen
return DifferentialPair<vector<T,N>>(primal, diff);
}
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
+[ForceInline]
+[BackwardDerivativeOf(mul)]
+void __d_mul(inout DifferentialPair<matrix<T, N, M>> left, inout DifferentialPair<vector<T, M>> right, vector<T, N>.Differential dOut)
+{
+ matrix<T, N, M>.Differential left_d_result;
+ vector<T, M>.Differential right_d_result;
+ for (int j = 0; j < M; ++j)
+ {
+ T sum = T(0);
+ for (int i = 0; i < N; ++i)
+ {
+ sum += left.p[i][j] * dOut[i];
+ left_d_result[i][j] = right.p[j] * dOut[i];
+ }
+ right_d_result[j] = sum;
+ }
+ left = diffPair(left.p, left_d_result);
+ right = diffPair(right.p, right_d_result);
+}
// matrix-matrix
__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
-DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, DifferentialPair<matrix<T,N,C>> left)
+DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, DifferentialPair<matrix<T,N,C>> right)
{
- let primal = mul(right.p, left.p);
- let diff = mul(right.d, left.p) + mul(right.p, left.d);
+ let primal = mul(left.p, right.p);
+ let diff = mul(left.d, right.p) + mul(left.p, right.d);
return DifferentialPair<matrix<T,R,C>>(primal, diff);
}
+__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
+[ForceInline]
+[BackwardDerivativeOf(mul)]
+void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<matrix<T, N, C>> right, matrix<T, R, C>.Differential dOut)
+{
+ matrix<T, R, N>.Differential left_d_result;
+ for (int r = 0; r < R; ++r)
+ for (int n = 0; n < N; ++n)
+ left_d_result[r][n] = T(0.0);
+
+ matrix<T, N, C>.Differential right_d_result;
+ for (int n = 0; n < N; ++n)
+ for (int c = 0; c < C; ++c)
+ right_d_result[n][c] = T(0.0);
+
+ for (int r = 0; r < R; ++r)
+ {
+ for (int c = 0; c < C; ++c)
+ {
+ for (int n = 0; n < N; ++n)
+ {
+ left_d_result[r][n] += right.p[n][c] * dOut[r][c];
+ right_d_result[n][c] += left.p[r][n] * dOut[r][c];
+ }
+ }
+ }
+ left = diffPair(left.p, left_d_result);
+ right = diffPair(right.p, right_d_result);
+}
+
#define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \
vector<TYPE, COUNT> result; \
vector<TYPE, COUNT>.Differential d_result; \
@@ -207,7 +278,6 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, Diffe
} \
return DifferentialPair<vector<TYPE, COUNT>>(result, d_result)
-
#define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \
vector<TYPE, COUNT> result; \
vector<TYPE, COUNT>.Differential d_result; \
@@ -220,6 +290,28 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, Diffe
} \
return DifferentialPair<vector<TYPE, COUNT>>(result, d_result)
+#define VECTOR_MAP_BWD_D_UNARY(TYPE, COUNT, D_FUNC, VALUE, D_OUT) \
+ vector<TYPE, COUNT>.Differential d_result; \
+ for (int i = 0; i < N; ++i) \
+ { \
+ DifferentialPair<TYPE> dp_elem = diffPair(VALUE.p[i], TYPE.dzero()); \
+ D_FUNC(dp_elem, __slang_noop_cast<TYPE.Differential>(D_OUT[i])); \
+ d_result[i] = __slang_noop_cast<TYPE>(dp_elem.d); \
+ } \
+ VALUE = diffPair(VALUE.p, d_result)
+
+#define VECTOR_MAP_BWD_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT, D_OUT) \
+ vector<TYPE, COUNT>.Differential left_d_result, right_d_result; \
+ for (int i = 0; i < N; ++i) \
+ { \
+ DifferentialPair<TYPE> left_dp = diffPair(LEFT.p[i], TYPE.dzero()); \
+ DifferentialPair<TYPE> right_dp = diffPair(RIGHT.p[i], TYPE.dzero()); \
+ D_FUNC(left_dp, right_dp, __slang_noop_cast<TYPE.Differential>(D_OUT[i])); \
+ left_d_result[i] = __slang_noop_cast<TYPE>(left_dp.d); \
+ right_d_result[i] = __slang_noop_cast<TYPE>(right_dp.d); \
+ } \
+ LEFT = diffPair(LEFT.p, left_d_result); \
+ RIGHT = diffPair(RIGHT.p, right_d_result)
// Detach and set derivatives to zero
@@ -240,6 +332,20 @@ DifferentialPair<vector<T, N>> __d_detach_vector(DifferentialPair<vector<T, N>>
VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx);
}
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(detach)]
+void __d_detach(inout DifferentialPair<T> dpx, T.Differential dOut)
+{
+ dpx = diffPair(dpx.p, T.dzero());
+}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(detach)]
+void __d_detach_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
+{
+ dpx = diffPair(dpx.p, vector<T, N>.dzero());
+}
+
// Natural Exponent
__generic<T : __BuiltinFloatingPointType>
@@ -295,6 +401,22 @@ DifferentialPair<vector<T, N>> __d_abs_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_UNARY(T, N, __d_abs, dpx);
}
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(abs)]
+void __d_abs(inout DifferentialPair<T> dpx, T.Differential dOut)
+{
+ dpx = diffPair(
+ dpx.p,
+ T.dmul(__slang_noop_cast<T>(sign(dpx.p)), dOut));
+}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(abs)]
+void __d_abs_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
+{
+ VECTOR_MAP_BWD_D_UNARY(T, N, __d_abs, dpx, dOut);
+}
+
// Sine
__generic<T : __BuiltinFloatingPointType>
@@ -386,6 +508,20 @@ DifferentialPair<vector<T, N>> __d_log_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_UNARY(T, N, __d_log, dpx);
}
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(log)]
+void __d_log(inout DifferentialPair<T> dpx, T.Differential dOut)
+{
+ dpx = diffPair(dpx.p, T.dmul(T(1.0) / dpx.p, dOut));
+}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(log)]
+void __d_log_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
+{
+ VECTOR_MAP_BWD_D_UNARY(T, N, __d_log, dpx, dOut);
+}
+
// Square root
__generic<T : __BuiltinFloatingPointType>
@@ -412,6 +548,30 @@ DifferentialPair<vector<T, N>> __d_sqrt_vector(DifferentialPair<vector<T, N>> dp
VECTOR_MAP_D_UNARY(T, N, __d_sqrt, dpx);
}
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(sqrt)]
+void __d_sqrt(inout DifferentialPair<T> dpx, T.Differential dOut)
+{
+ // Special case
+ if (dpx.p < T(1e-6))
+ {
+ dpx = diffPair(dpx.p, T.dzero());
+ }
+ else
+ {
+ dpx = diffPair(
+ dpx.p,
+ T.dmul(T(0.5) / sqrt(dpx.p), dOut));
+ }
+}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(sqrt)]
+void __d_sqrt_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
+{
+ VECTOR_MAP_BWD_D_UNARY(T, N, __d_sqrt, dpx, dOut);
+}
+
// Maximum
__generic<T : __BuiltinFloatingPointType>
@@ -431,6 +591,21 @@ DifferentialPair<vector<T, N>> __d_max_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_BINARY(T, N, __d_max, dpx, dpy);
}
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(max)]
+void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
+{
+ dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : T.dzero());
+ dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : T.dzero());
+}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(max)]
+void __d_max_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut)
+{
+ VECTOR_MAP_BWD_D_BINARY(T, N, __d_max, dpx, dpy, dOut);
+}
+
// Minimum
__generic<T : __BuiltinFloatingPointType>
@@ -450,6 +625,21 @@ DifferentialPair<vector<T, N>> __d_min_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_BINARY(T, N, __d_min, dpx, dpy);
}
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(min)]
+void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
+{
+ dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : T.dzero());
+ dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : T.dzero());
+}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(min)]
+void __d_min_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut)
+{
+ VECTOR_MAP_BWD_D_BINARY(T, N, __d_min, dpx, dpy, dOut);
+}
+
// Raise to a power
__generic<T : __BuiltinFloatingPointType>
@@ -478,6 +668,35 @@ DifferentialPair<vector<T, N>> __d_pow_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_BINARY(T, N, __d_pow, dpx, dpy);
}
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(pow)]
+void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
+{
+ // Special case
+ if (dpx.p < T(1e-6))
+ {
+ dpx = diffPair(dpx.p, T.dzero());
+ dpy = diffPair(dpy.p, T.dzero());
+ }
+ else
+ {
+ T val = pow(dpx.p, dpy.p);
+ dpx = diffPair(
+ dpx.p,
+ T.dmul(val * dpy.p / dpx.p, dOut));
+ dpy = diffPair(
+ dpy.p,
+ T.dmul(val * log(dpx.p), dOut));
+ }
+}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(pow)]
+void __d_pow_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut)
+{
+ VECTOR_MAP_BWD_D_BINARY(T, N, __d_pow, dpx, dpy, dOut);
+}
+
// Vector dot product
__generic<T : __BuiltinFloatingPointType, let N : int>
@@ -494,3 +713,17 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair
}
return DifferentialPair<T>(result, d_result);
}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(dot)]
+void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut)
+{
+ vector<T, N>.Differential x_d_result, y_d_result;
+ for (int i = 0; i < N; ++i)
+ {
+ x_d_result[i] = dpy.p[i] * __slang_noop_cast<T>(dOut);
+ y_d_result[i] = dpx.p[i] * __slang_noop_cast<T>(dOut);
+ }
+ dpx = diffPair(dpx.p, x_d_result);
+ dpy = diffPair(dpy.p, y_d_result);
+}
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 6c7a2c1d2..306e0dbb9 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -781,13 +781,13 @@ T detach(T x)
__generic<T : __BuiltinFloatingPointType, let N : int>
vector<T, N> detach(vector<T, N> x)
{
- VECTOR_MAP_UNARY(T, N, detach, x);
+ return x;
}
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
matrix<T, N, M> detach(matrix<T, N, M> x)
{
- MATRIX_MAP_UNARY(T, N, M, detach, x);
+ return x;
}
// Absolute value (HLSL SM 1.0)
diff --git a/tests/autodiff-dstdlib/dstdlib-abs.slang b/tests/autodiff-dstdlib/dstdlib-abs.slang
index 0da2de4c7..b7988f573 100644
--- a/tests/autodiff-dstdlib/dstdlib-abs.slang
+++ b/tests/autodiff-dstdlib/dstdlib-abs.slang
@@ -1,21 +1,21 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
typedef DifferentialPair<float2> dpfloat2;
typedef DifferentialPair<float3> dpfloat3;
-[ForwardDifferentiable]
+[BackwardDifferentiable]
float diffAbs(float x)
{
return abs(x);
}
-[ForwardDifferentiable]
-float3 diffAbs3(float3 x)
+[BackwardDifferentiable]
+float3 diffAbs(float3 x)
{
return abs(x);
}
@@ -32,12 +32,26 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
dpfloat3 dpx = dpfloat3(float3(2.f, -3.f, -1.f), float3(-2.f, 3.f, -1.f));
- dpfloat3 res = __fwd_diff(diffAbs3)(dpx);
+ dpfloat3 res = __fwd_diff(diffAbs)(dpx);
outputBuffer[2] = res.p[0]; // Expect: 2.000000
outputBuffer[3] = res.d[0]; // Expect: -2.000000
outputBuffer[4] = res.p[1]; // Expect: 3.000000
outputBuffer[5] = res.d[1]; // Expect: -3.000000
- outputBuffer[6] = res.p[2]; // Expect: 1.000000
- outputBuffer[7] = res.d[2]; // Expect: 1.000000
+ outputBuffer[6] = res.p[2]; // Expect: 1.000000
+ outputBuffer[7] = res.d[2]; // Expect: 1.000000
+ }
+
+ {
+ dpfloat dpx = dpfloat(-3.0, 0.0);
+ __bwd_diff(diffAbs)(dpx, 2.0);
+ outputBuffer[8] = dpx.d; // Expect: -2.000000
+ }
+
+ {
+ dpfloat3 dpx = dpfloat3(float3(2.f, -3.f, -1.f), float3(0.f, 0.f, 0.f));
+ __bwd_diff(diffAbs)(dpx, float3(1.f, 1.f, 1.f));
+ outputBuffer[9] = dpx.d[0]; // Expect: 1.000000
+ outputBuffer[10] = dpx.d[1]; // Expect: -1.000000
+ outputBuffer[11] = dpx.d[2]; // Expect: -1.000000
}
}
diff --git a/tests/autodiff-dstdlib/dstdlib-abs.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-abs.slang.expected.txt
index 18dcc32fe..f50d892fd 100644
--- a/tests/autodiff-dstdlib/dstdlib-abs.slang.expected.txt
+++ b/tests/autodiff-dstdlib/dstdlib-abs.slang.expected.txt
@@ -6,4 +6,8 @@ type: float
3.000000
-3.000000
1.000000
-1.000000 \ No newline at end of file
+1.000000
+-2.000000
+1.000000
+-1.000000
+-1.000000 \ No newline at end of file
diff --git a/tests/autodiff-dstdlib/dstdlib-detach.slang b/tests/autodiff-dstdlib/dstdlib-detach.slang
index acb21a0ce..e5275b821 100644
--- a/tests/autodiff-dstdlib/dstdlib-detach.slang
+++ b/tests/autodiff-dstdlib/dstdlib-detach.slang
@@ -1,21 +1,27 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
typedef DifferentialPair<float2> dpfloat2;
typedef DifferentialPair<float3> dpfloat3;
-[ForwardDifferentiable]
+[BackwardDifferentiable]
float diffDetach(float x)
{
return detach(x);
}
-[ForwardDifferentiable]
-float3 diffDetach3(float3 x)
+[BackwardDifferentiable]
+float2 diffDetach(float2 x)
+{
+ return detach(x);
+}
+
+[BackwardDifferentiable]
+float3 diffDetach(float3 x)
{
return detach(x);
}
@@ -32,12 +38,25 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
dpfloat3 dpx = dpfloat3(float3(2.f, -3.f, -1.f), float3(-2.f, 3.f, -1.f));
- dpfloat3 res = __fwd_diff(diffDetach3)(dpx);
+ dpfloat3 res = __fwd_diff(diffDetach)(dpx);
outputBuffer[2] = res.p[0]; // Expect: 2.000000
outputBuffer[3] = res.d[0]; // Expect: 0.000000
outputBuffer[4] = res.p[1]; // Expect: -3.000000
outputBuffer[5] = res.d[1]; // Expect: 0.000000
- outputBuffer[6] = res.p[2]; // Expect: -1.000000
- outputBuffer[7] = res.d[2]; // Expect: 0.000000
+ outputBuffer[6] = res.p[2]; // Expect: -1.000000
+ outputBuffer[7] = res.d[2]; // Expect: 0.000000
+ }
+
+ {
+ dpfloat dpx = dpfloat(-5.0, 1.0);
+ __bwd_diff(diffDetach)(dpx, 1.0);
+ outputBuffer[8] = dpx.d; // Expect: 0.000000
+ }
+
+ {
+ dpfloat2 dpx = dpfloat2(float2(1.0, -2.0), float2(1.0, 1.0));
+ __bwd_diff(diffDetach)(dpx, float2(2.0, -3.0));
+ outputBuffer[9] = dpx.d[0]; // Expect: 0.000000
+ outputBuffer[10] = dpx.d[1]; // Expect: 0.000000
}
}
diff --git a/tests/autodiff-dstdlib/dstdlib-detach.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-detach.slang.expected.txt
index 950affb9a..042fc9da7 100644
--- a/tests/autodiff-dstdlib/dstdlib-detach.slang.expected.txt
+++ b/tests/autodiff-dstdlib/dstdlib-detach.slang.expected.txt
@@ -6,4 +6,7 @@ type: float
-3.000000
0.000000
-1.000000
+0.000000
+0.000000
+0.000000
0.000000 \ No newline at end of file
diff --git a/tests/autodiff-dstdlib/dstdlib-dot.slang b/tests/autodiff-dstdlib/dstdlib-dot.slang
new file mode 100644
index 000000000..4a4e2a78b
--- /dev/null
+++ b/tests/autodiff-dstdlib/dstdlib-dot.slang
@@ -0,0 +1,29 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float3> dpfloat3;
+
+[BackwardDifferentiable]
+float diffDot(float3 x, float3 y)
+{
+ return dot(x, y);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ {
+ dpfloat3 dpx = dpfloat3(float3(-0.5, 0.7, 0.8), float3(0.0, 0.0, 0.0));
+ dpfloat3 dpy = dpfloat3(float3(0.0, 0.0, 1.0), float3(0.0, 0.0, 0.0));
+ __bwd_diff(diffDot)(dpx, dpy, 1.0);
+ outputBuffer[0] = dpx.d[0]; // Expect: 0.000000
+ outputBuffer[1] = dpx.d[1]; // Expect: 0.000000
+ outputBuffer[2] = dpx.d[2]; // Expect: 1.000000
+ outputBuffer[3] = dpy.d[0]; // Expect: -0.500000
+ outputBuffer[4] = dpy.d[1]; // Expect: 0.700000
+ outputBuffer[5] = dpy.d[2]; // Expect: 0.800000
+ }
+}
diff --git a/tests/autodiff-dstdlib/dstdlib-dot.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-dot.slang.expected.txt
new file mode 100644
index 000000000..115036603
--- /dev/null
+++ b/tests/autodiff-dstdlib/dstdlib-dot.slang.expected.txt
@@ -0,0 +1,7 @@
+type: float
+0.000000
+0.000000
+1.000000
+-0.500000
+0.700000
+0.800000 \ No newline at end of file
diff --git a/tests/autodiff-dstdlib/dstdlib-log.slang b/tests/autodiff-dstdlib/dstdlib-log.slang
new file mode 100644
index 000000000..6961d2f09
--- /dev/null
+++ b/tests/autodiff-dstdlib/dstdlib-log.slang
@@ -0,0 +1,37 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef DifferentialPair<float2> dpfloat2;
+
+[BackwardDifferentiable]
+float diffLog(float x)
+{
+ return log(x);
+}
+
+[BackwardDifferentiable]
+float2 diffLog(float2 x)
+{
+ return log(x);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(10.0, 0.0);
+ __bwd_diff(diffLog)(dpa, 3.0);
+ outputBuffer[0] = dpa.d; // Expect: 0.300000
+ }
+
+ {
+ dpfloat2 dpx = dpfloat2(float2(2.0, 5.0), float2(0.0, 0.0));
+ __bwd_diff(diffLog)(dpx, float2(1.0, 1.0));
+ outputBuffer[1] = dpx.d[0]; // Expect: 0.500000
+ outputBuffer[2] = dpx.d[1]; // Expect: 0.200000
+ }
+}
diff --git a/tests/autodiff-dstdlib/dstdlib-log.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-log.slang.expected.txt
new file mode 100644
index 000000000..a04079bb7
--- /dev/null
+++ b/tests/autodiff-dstdlib/dstdlib-log.slang.expected.txt
@@ -0,0 +1,4 @@
+type: float
+0.300000
+0.500000
+0.200000 \ No newline at end of file
diff --git a/tests/autodiff-dstdlib/dstdlib-max.slang b/tests/autodiff-dstdlib/dstdlib-max.slang
index d28b17e06..026914c8c 100644
--- a/tests/autodiff-dstdlib/dstdlib-max.slang
+++ b/tests/autodiff-dstdlib/dstdlib-max.slang
@@ -1,20 +1,20 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
typedef DifferentialPair<float2> dpfloat2;
-[ForwardDifferentiable]
+[BackwardDifferentiable]
float diffMax(float x, float y)
{
return max(x, y);
}
-[ForwardDifferentiable]
-float2 diffMax2(float2 x, float2 y)
+[BackwardDifferentiable]
+float2 diffMax(float2 x, float2 y)
{
return max(x, y);
}
@@ -33,10 +33,20 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
dpfloat2 dpx = dpfloat2(float2(-3.0, 4.0), float2(-1.0, -1.0));
dpfloat2 dpy = dpfloat2(float2(1.0, 2.0), float2(2.0, 2.0));
- dpfloat2 res = __fwd_diff(diffMax2)(dpx, dpy);
+ dpfloat2 res = __fwd_diff(diffMax)(dpx, dpy);
outputBuffer[2] = res.p[0]; // Expect: 1.000000
outputBuffer[3] = res.d[0]; // Expect: 2.000000
outputBuffer[4] = res.p[1]; // Expect: 4.000000
outputBuffer[5] = res.d[1]; // Expect: -1.000000
}
+
+ {
+ dpfloat2 dpx = dpfloat2(float2(2.0, 3.0), float2(0.0, 0.0));
+ dpfloat2 dpy = dpfloat2(float2(5.0, 1.0), float2(0.0, 0.0));
+ __bwd_diff(diffMax)(dpx, dpy, float2(1.0, 2.0));
+ outputBuffer[6] = dpx.d[0]; // Expect: 0.000000
+ outputBuffer[7] = dpx.d[1]; // Expect: 2.000000
+ outputBuffer[8] = dpy.d[0]; // Expect: 1.000000
+ outputBuffer[9] = dpy.d[1]; // Expect: 0.000000
+ }
}
diff --git a/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt
index cca137aea..4cc1e9533 100644
--- a/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt
+++ b/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt
@@ -5,3 +5,7 @@ type: float
2.000000
4.000000
-1.000000
+0.000000
+2.000000
+1.000000
+0.000000 \ No newline at end of file
diff --git a/tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang b/tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang
new file mode 100644
index 000000000..6419e92aa
--- /dev/null
+++ b/tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang
@@ -0,0 +1,30 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float2x2> dpmat2;
+
+[BackwardDifferentiable]
+float2x2 diffMul(float2x2 a, float2x2 b)
+{
+ return mul(a, b);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ dpmat2 dpa = dpmat2(float2x2(1.0, 2.0, 3.0, 4.0), float2x2(0.0, 0.0, 0.0, 0.0));
+ dpmat2 dpb = dpmat2(float2x2(5.0, 6.0, 7.0, 8.0), float2x2(0.0, 0.0, 0.0, 0.0));
+ float2x2 dOut = float2x2(1.0, -2.0, -3.0, 4.0);
+ __bwd_diff(diffMul)(dpa, dpb, dOut);
+ outputBuffer[0] = dpa.d[0][0]; // Expect: -7.000000
+ outputBuffer[1] = dpa.d[0][1]; // Expect: -9.000000
+ outputBuffer[2] = dpa.d[1][0]; // Expect: 9.000000
+ outputBuffer[3] = dpa.d[1][1]; // Expect: 11.000000
+ outputBuffer[4] = dpb.d[0][0]; // Expect: -8.000000
+ outputBuffer[5] = dpb.d[0][1]; // Expect: 10.000000
+ outputBuffer[6] = dpb.d[1][0]; // Expect: -10.000000
+ outputBuffer[7] = dpb.d[1][1]; // Expect: 12.000000
+}
diff --git a/tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang.expected.txt
new file mode 100644
index 000000000..ae3e09265
--- /dev/null
+++ b/tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang.expected.txt
@@ -0,0 +1,9 @@
+type: float
+-7.000000
+-9.000000
+9.000000
+11.000000
+-8.000000
+10.000000
+-10.000000
+12.000000 \ No newline at end of file
diff --git a/tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang b/tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang
new file mode 100644
index 000000000..23ec9cabb
--- /dev/null
+++ b/tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang
@@ -0,0 +1,35 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float3> dpfloat3;
+typedef DifferentialPair<float3x3> dpmat3;
+
+[BackwardDifferentiable]
+float3 diffMul(float3x3 m, float3 v)
+{
+ return mul(m, v);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ dpfloat3 dpv = dpfloat3(float3(0.5, 1.2, -0.8), float3(0.0, 0.0, 0.0));
+ dpmat3 dpm = dpmat3(float3x3(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0),
+ float3x3(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0));
+ __bwd_diff(diffMul)(dpm, dpv, float3(1.0, 10.0, 100.0));
+ outputBuffer[0] = dpm.d[0][0]; // Expect: 0.500000
+ outputBuffer[1] = dpm.d[0][1]; // Expect: 1.200000
+ outputBuffer[2] = dpm.d[0][2]; // Expect: -0.800000
+ outputBuffer[3] = dpm.d[1][0]; // Expect: 5.000000
+ outputBuffer[4] = dpm.d[1][1]; // Expect: 12.000000
+ outputBuffer[5] = dpm.d[1][2]; // Expect: -8.000000
+ outputBuffer[6] = dpm.d[2][0]; // Expect: 50.000000
+ outputBuffer[7] = dpm.d[2][1]; // Expect: 120.000000
+ outputBuffer[8] = dpm.d[2][2]; // Expect: -80.000000
+ outputBuffer[9] = dpv.d[0]; // Expect: 741.000000
+ outputBuffer[10] = dpv.d[1]; // Expect: 852.000000
+ outputBuffer[11] = dpv.d[2]; // Expect: 963.000000
+}
diff --git a/tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang.expected.txt
new file mode 100644
index 000000000..909dbaa03
--- /dev/null
+++ b/tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang.expected.txt
@@ -0,0 +1,13 @@
+type: float
+0.500000
+1.200000
+-0.800000
+5.000000
+12.000000
+-8.000000
+50.000000
+120.000000
+-80.000000
+741.000000
+852.000000
+963.000000 \ No newline at end of file
diff --git a/tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang b/tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang
new file mode 100644
index 000000000..a4e86091a
--- /dev/null
+++ b/tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang
@@ -0,0 +1,35 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float3> dpfloat3;
+typedef DifferentialPair<float3x3> dpmat3;
+
+[BackwardDifferentiable]
+float3 diffMul(float3 v, float3x3 m)
+{
+ return mul(v, m);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ dpfloat3 dpv = dpfloat3(float3(0.5, 1.2, -0.8), float3(0.0, 0.0, 0.0));
+ dpmat3 dpm = dpmat3(float3x3(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0),
+ float3x3(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0));
+ __bwd_diff(diffMul)(dpv, dpm, float3(1.0, 10.0, 100.0));
+ outputBuffer[0] = dpv.d[0]; // Expect: 321.000000
+ outputBuffer[1] = dpv.d[1]; // Expect: 654.000000
+ outputBuffer[2] = dpv.d[2]; // Expect: 987.000000
+ outputBuffer[3] = dpm.d[0][0]; // Expect: 0.500000
+ outputBuffer[4] = dpm.d[0][1]; // Expect: 5.000000
+ outputBuffer[5] = dpm.d[0][2]; // Expect: 50.000000
+ outputBuffer[6] = dpm.d[1][0]; // Expect: 1.200000
+ outputBuffer[7] = dpm.d[1][1]; // Expect: 12.000000
+ outputBuffer[8] = dpm.d[1][2]; // Expect: 120.000000
+ outputBuffer[9] = dpm.d[2][0]; // Expect: -0.800000
+ outputBuffer[10] = dpm.d[2][1]; // Expect: -8.000000
+ outputBuffer[11] = dpm.d[2][2]; // Expect: -80.000000
+}
diff --git a/tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang.expected.txt
new file mode 100644
index 000000000..b4583a653
--- /dev/null
+++ b/tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang.expected.txt
@@ -0,0 +1,13 @@
+type: float
+321.000000
+654.000000
+987.000000
+0.500000
+5.000000
+50.000000
+1.200000
+12.000000
+120.000000
+-0.800000
+-8.000000
+-80.000000 \ No newline at end of file
diff --git a/tests/autodiff-dstdlib/dstdlib-pow.slang b/tests/autodiff-dstdlib/dstdlib-pow.slang
index 66f919b03..a1e2b9eda 100644
--- a/tests/autodiff-dstdlib/dstdlib-pow.slang
+++ b/tests/autodiff-dstdlib/dstdlib-pow.slang
@@ -1,21 +1,21 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
typedef DifferentialPair<float2> dpfloat2;
typedef DifferentialPair<float3> dpfloat3;
-[ForwardDifferentiable]
+[BackwardDifferentiable]
float diffPow(float x, float y)
{
return pow(x, y);
}
-[ForwardDifferentiable]
-float2 diffPow2(float2 x, float2 y)
+[BackwardDifferentiable]
+float2 diffPow(float2 x, float2 y)
{
return pow(x, y);
}
@@ -34,10 +34,20 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
dpfloat2 dpx = dpfloat2(float2(10.0, 3.0), float2(0.0, -2.0));
dpfloat2 dpy = dpfloat2(float2(2.5, 4.0), float2(1.0, 1.0));
- dpfloat2 res = __fwd_diff(diffPow2)(dpx, dpy);
+ dpfloat2 res = __fwd_diff(diffPow)(dpx, dpy);
outputBuffer[2] = res.p[0]; // Expect: 316.227722
outputBuffer[3] = res.d[0]; // Expect: 728.141235
outputBuffer[4] = res.p[1]; // Expect: 81.000000
outputBuffer[5] = res.d[1]; // Expect: -127.012398
}
+
+ {
+ dpfloat2 dpx = dpfloat2(float2(2.0, 4.0), float2(0.0, 0.0));
+ dpfloat2 dpy = dpfloat2(float2(3.0, -2.0), float2(0.0, 0.0));
+ __bwd_diff(diffPow)(dpx, dpy, float2(1.0, 1.0));
+ outputBuffer[6] = dpx.d[0]; // Expect: 12.000000
+ outputBuffer[7] = dpx.d[1]; // Expect: -0.031250
+ outputBuffer[8] = dpy.d[0]; // Expect: 5.545177
+ outputBuffer[9] = dpy.d[1]; // Expect: 0.086643
+ }
}
diff --git a/tests/autodiff-dstdlib/dstdlib-pow.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-pow.slang.expected.txt
index 8ac728ccf..f1bc71090 100644
--- a/tests/autodiff-dstdlib/dstdlib-pow.slang.expected.txt
+++ b/tests/autodiff-dstdlib/dstdlib-pow.slang.expected.txt
@@ -4,4 +4,8 @@ type: float
316.227722
728.141235
81.000000
--127.012398 \ No newline at end of file
+-127.012398
+12.000000
+-0.031250
+5.545177
+0.086643 \ No newline at end of file
diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang b/tests/autodiff-dstdlib/dstdlib-sqrt.slang
index 9ef69ea34..15573c4ef 100644
--- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang
+++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang
@@ -1,19 +1,25 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
typedef DifferentialPair<float2> dpfloat2;
typedef DifferentialPair<float3> dpfloat3;
-[ForwardDifferentiable]
+[BackwardDifferentiable]
float diffSqrt(float x)
{
return sqrt(x);
}
+[BackwardDifferentiable]
+float2 diffSqrt(float2 x)
+{
+ return sqrt(x);
+}
+
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
@@ -37,4 +43,11 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
outputBuffer[4] = res.p; // Expect: 0.000000
outputBuffer[5] = res.d; // Expect: 0.000000
}
+
+ {
+ dpfloat2 dpx = dpfloat2(float2(10.0, 3.0), float2(0.0, 0.0));
+ __bwd_diff(diffSqrt)(dpx, float2(1.0, 2.0));
+ outputBuffer[6] = dpx.d[0]; // Expect: 0.158114
+ outputBuffer[7] = dpx.d[1]; // Expect: 0.577350
+ }
}
diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt
index 887fd13ca..fe6487fef 100644
--- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt
+++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt
@@ -4,4 +4,6 @@ type: float
10.000000
-0.150000
0.000000
-0.000000 \ No newline at end of file
+0.000000
+0.158114
+0.577350 \ No newline at end of file
diff --git a/tests/autodiff-dstdlib/dstdlib-unary.slang b/tests/autodiff-dstdlib/dstdlib-unary.slang
index 4d869996d..ea1bd3d2b 100644
--- a/tests/autodiff-dstdlib/dstdlib-unary.slang
+++ b/tests/autodiff-dstdlib/dstdlib-unary.slang
@@ -1,24 +1,24 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
-[ForwardDifferentiable]
+[BackwardDifferentiable]
float f(float x)
{
return exp(x);
}
-[ForwardDifferentiable]
+[BackwardDifferentiable]
float g(float x)
{
return sin(x);
}
-[ForwardDifferentiable]
+[BackwardDifferentiable]
float h(float x)
{
return cos(x);
@@ -37,4 +37,14 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
outputBuffer[4] = h(dpa.p); // Expect: -0.416146
outputBuffer[5] = __fwd_diff(h)(dpa).d; // Expect: -0.909297
}
+
+ {
+ dpfloat dpa = dpfloat(2.0, 0.0);
+ __bwd_diff(f)(dpa, 1.0);
+ outputBuffer[6] = dpa.d; // Expect: 7.389056
+ __bwd_diff(g)(dpa, 1.0);
+ outputBuffer[7] = dpa.d; // Expect: -0.416146
+ __bwd_diff(h)(dpa, 1.0);
+ outputBuffer[8] = dpa.d; // Expect: -0.909297
+ }
}
diff --git a/tests/autodiff-dstdlib/dstdlib-unary.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-unary.slang.expected.txt
index 2e7338f5c..8b04eced8 100644
--- a/tests/autodiff-dstdlib/dstdlib-unary.slang.expected.txt
+++ b/tests/autodiff-dstdlib/dstdlib-unary.slang.expected.txt
@@ -5,3 +5,6 @@ type: float
-0.416147
-0.416147
-0.909297
+7.389056
+-0.416147
+-0.909297
diff --git a/tests/autodiff/reverse-inout-param.slang b/tests/autodiff/reverse-inout-param.slang
index 7d7f4cb05..de0d8f7ed 100644
--- a/tests/autodiff/reverse-inout-param.slang
+++ b/tests/autodiff/reverse-inout-param.slang
@@ -44,4 +44,4 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
var refVal = __fwd_diff(f_ref)(3.0, diffPair(2.0, 1.0)).d;
outputBuffer[2] = refVal; // 3024
-} \ No newline at end of file
+}