diff options
| author | winmad <winmad.wlf@gmail.com> | 2023-02-07 22:30:00 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-07 22:30:00 -0800 |
| commit | b1d7dc0707406da69d94f3915fa48f39020b93ab (patch) | |
| tree | bb571884616aaf79fa663aec1eb8fcf92f5138cc | |
| parent | 4be623c52a6518eb86756a0369706c1d6670f6bb (diff) | |
Add backward derivatives for functions in diff.meta.slang (#2633)
* WIP: start adding backward derivatives
* Overhaul `transposeParameterBlock` to support `inout` params.
* Small bug fixes.
* Bug fix on differentiable intrinsic specialization.
* Fixes.
* Run autodiff tests on CPU.
* Clean up.
* Overhaul `transposeParameterBlock` to support `inout` params.
* Small bug fixes.
* Bug fix on differentiable intrinsic specialization.
* Fixes.
* Run autodiff tests on CPU.
* Clean up.
* More bug fixes.
* WIP: working on detach
* Arithmetic simplifications and more IR clean up logic.
* WIP: adding detach and abs
* Fix detach and abs
* Fix.
* Add IR transform pass for cleaner code emit.
* Fix test cases.
* Fix type system logic for reference type.
* Add backward derivatives for functions that already have forward derivatives
* Fix changes
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Co-authored-by: Lifan Wu <lifanw@nvidia.com>
25 files changed, 581 insertions, 40 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 055c44135..af06e6bac 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -173,6 +173,27 @@ DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, Differen return DifferentialPair<vector<T,M>>(primal, diff); } +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> +[ForceInline] +[BackwardDerivativeOf(mul)] +void __d_mul(inout DifferentialPair<vector<T, N>> left, inout DifferentialPair<matrix<T, N, M>> right, vector<T, M>.Differential dOut) +{ + vector<T, N>.Differential left_d_result; + matrix<T, N, M>.Differential right_d_result; + for (int i = 0; i < N; ++i) + { + T sum = T(0); + for (int j = 0; j < M; ++j) + { + sum += right.p[i][j] * dOut[j]; + right_d_result[i][j] = left.p[i] * dOut[j]; + } + left_d_result[i] = sum; + } + left = diffPair(left.p, left_d_result); + right = diffPair(right.p, right_d_result); +} + // matrix-vector __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] @@ -184,18 +205,68 @@ DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, Differen return DifferentialPair<vector<T,N>>(primal, diff); } +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> +[ForceInline] +[BackwardDerivativeOf(mul)] +void __d_mul(inout DifferentialPair<matrix<T, N, M>> left, inout DifferentialPair<vector<T, M>> right, vector<T, N>.Differential dOut) +{ + matrix<T, N, M>.Differential left_d_result; + vector<T, M>.Differential right_d_result; + for (int j = 0; j < M; ++j) + { + T sum = T(0); + for (int i = 0; i < N; ++i) + { + sum += left.p[i][j] * dOut[i]; + left_d_result[i][j] = right.p[j] * dOut[i]; + } + right_d_result[j] = sum; + } + left = diffPair(left.p, left_d_result); + right = diffPair(right.p, right_d_result); +} // matrix-matrix __generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int> [ForceInline] [ForwardDerivativeOf(mul)] 
-DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, DifferentialPair<matrix<T,N,C>> left) +DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, DifferentialPair<matrix<T,N,C>> right) { - let primal = mul(right.p, left.p); - let diff = mul(right.d, left.p) + mul(right.p, left.d); + let primal = mul(left.p, right.p); + let diff = mul(left.d, right.p) + mul(left.p, right.d); return DifferentialPair<matrix<T,R,C>>(primal, diff); } +__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int> +[ForceInline] +[BackwardDerivativeOf(mul)] +void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<matrix<T, N, C>> right, matrix<T, R, C>.Differential dOut) +{ + matrix<T, R, N>.Differential left_d_result; + for (int r = 0; r < R; ++r) + for (int n = 0; n < N; ++n) + left_d_result[r][n] = T(0.0); + + matrix<T, N, C>.Differential right_d_result; + for (int n = 0; n < N; ++n) + for (int c = 0; c < C; ++c) + right_d_result[n][c] = T(0.0); + + for (int r = 0; r < R; ++r) + { + for (int c = 0; c < C; ++c) + { + for (int n = 0; n < N; ++n) + { + left_d_result[r][n] += right.p[n][c] * dOut[r][c]; + right_d_result[n][c] += left.p[r][n] * dOut[r][c]; + } + } + } + left = diffPair(left.p, left_d_result); + right = diffPair(right.p, right_d_result); +} + #define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \ vector<TYPE, COUNT> result; \ vector<TYPE, COUNT>.Differential d_result; \ @@ -207,7 +278,6 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, Diffe } \ return DifferentialPair<vector<TYPE, COUNT>>(result, d_result) - #define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \ vector<TYPE, COUNT> result; \ vector<TYPE, COUNT>.Differential d_result; \ @@ -220,6 +290,28 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, Diffe } \ return DifferentialPair<vector<TYPE, COUNT>>(result, d_result) +#define VECTOR_MAP_BWD_D_UNARY(TYPE, COUNT, D_FUNC, 
VALUE, D_OUT) \ + vector<TYPE, COUNT>.Differential d_result; \ + for (int i = 0; i < N; ++i) \ + { \ + DifferentialPair<TYPE> dp_elem = diffPair(VALUE.p[i], TYPE.dzero()); \ + D_FUNC(dp_elem, __slang_noop_cast<TYPE.Differential>(D_OUT[i])); \ + d_result[i] = __slang_noop_cast<TYPE>(dp_elem.d); \ + } \ + VALUE = diffPair(VALUE.p, d_result) + +#define VECTOR_MAP_BWD_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT, D_OUT) \ + vector<TYPE, COUNT>.Differential left_d_result, right_d_result; \ + for (int i = 0; i < N; ++i) \ + { \ + DifferentialPair<TYPE> left_dp = diffPair(LEFT.p[i], TYPE.dzero()); \ + DifferentialPair<TYPE> right_dp = diffPair(RIGHT.p[i], TYPE.dzero()); \ + D_FUNC(left_dp, right_dp, __slang_noop_cast<TYPE.Differential>(D_OUT[i])); \ + left_d_result[i] = __slang_noop_cast<TYPE>(left_dp.d); \ + right_d_result[i] = __slang_noop_cast<TYPE>(right_dp.d); \ + } \ + LEFT = diffPair(LEFT.p, left_d_result); \ + RIGHT = diffPair(RIGHT.p, right_d_result) // Detach and set derivatives to zero @@ -240,6 +332,20 @@ DifferentialPair<vector<T, N>> __d_detach_vector(DifferentialPair<vector<T, N>> VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(detach)] +void __d_detach(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + dpx = diffPair(dpx.p, T.dzero()); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(detach)] +void __d_detach_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + dpx = diffPair(dpx.p, vector<T, N>.dzero()); +} + // Natural Exponent __generic<T : __BuiltinFloatingPointType> @@ -295,6 +401,22 @@ DifferentialPair<vector<T, N>> __d_abs_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_UNARY(T, N, __d_abs, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(abs)] +void __d_abs(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + dpx = diffPair( + dpx.p, + T.dmul(__slang_noop_cast<T>(sign(dpx.p)), 
dOut)); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(abs)] +void __d_abs_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + VECTOR_MAP_BWD_D_UNARY(T, N, __d_abs, dpx, dOut); +} + // Sine __generic<T : __BuiltinFloatingPointType> @@ -386,6 +508,20 @@ DifferentialPair<vector<T, N>> __d_log_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_UNARY(T, N, __d_log, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(log)] +void __d_log(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + dpx = diffPair(dpx.p, T.dmul(T(1.0) / dpx.p, dOut)); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(log)] +void __d_log_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + VECTOR_MAP_BWD_D_UNARY(T, N, __d_log, dpx, dOut); +} + // Square root __generic<T : __BuiltinFloatingPointType> @@ -412,6 +548,30 @@ DifferentialPair<vector<T, N>> __d_sqrt_vector(DifferentialPair<vector<T, N>> dp VECTOR_MAP_D_UNARY(T, N, __d_sqrt, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(sqrt)] +void __d_sqrt(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + // Special case + if (dpx.p < T(1e-6)) + { + dpx = diffPair(dpx.p, T.dzero()); + } + else + { + dpx = diffPair( + dpx.p, + T.dmul(T(0.5) / sqrt(dpx.p), dOut)); + } +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(sqrt)] +void __d_sqrt_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + VECTOR_MAP_BWD_D_UNARY(T, N, __d_sqrt, dpx, dOut); +} + // Maximum __generic<T : __BuiltinFloatingPointType> @@ -431,6 +591,21 @@ DifferentialPair<vector<T, N>> __d_max_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_BINARY(T, N, __d_max, dpx, dpy); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(max)] +void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, 
T.Differential dOut) +{ + dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : T.dzero()); + dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : T.dzero()); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(max)] +void __d_max_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut) +{ + VECTOR_MAP_BWD_D_BINARY(T, N, __d_max, dpx, dpy, dOut); +} + // Minimum __generic<T : __BuiltinFloatingPointType> @@ -450,6 +625,21 @@ DifferentialPair<vector<T, N>> __d_min_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_BINARY(T, N, __d_min, dpx, dpy); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(min)] +void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) +{ + dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : T.dzero()); + dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : T.dzero()); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(min)] +void __d_min_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut) +{ + VECTOR_MAP_BWD_D_BINARY(T, N, __d_min, dpx, dpy, dOut); +} + // Raise to a power __generic<T : __BuiltinFloatingPointType> @@ -478,6 +668,35 @@ DifferentialPair<vector<T, N>> __d_pow_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_BINARY(T, N, __d_pow, dpx, dpy); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(pow)] +void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) +{ + // Special case + if (dpx.p < T(1e-6)) + { + dpx = diffPair(dpx.p, T.dzero()); + dpy = diffPair(dpy.p, T.dzero()); + } + else + { + T val = pow(dpx.p, dpy.p); + dpx = diffPair( + dpx.p, + T.dmul(val * dpy.p / dpx.p, dOut)); + dpy = diffPair( + dpy.p, + T.dmul(val * log(dpx.p), dOut)); + } +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(pow)] +void 
__d_pow_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut) +{ + VECTOR_MAP_BWD_D_BINARY(T, N, __d_pow, dpx, dpy, dOut); +} + // Vector dot product __generic<T : __BuiltinFloatingPointType, let N : int> @@ -494,3 +713,17 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair } return DifferentialPair<T>(result, d_result); } + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(dot)] +void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut) +{ + vector<T, N>.Differential x_d_result, y_d_result; + for (int i = 0; i < N; ++i) + { + x_d_result[i] = dpy.p[i] * __slang_noop_cast<T>(dOut); + y_d_result[i] = dpx.p[i] * __slang_noop_cast<T>(dOut); + } + dpx = diffPair(dpx.p, x_d_result); + dpy = diffPair(dpy.p, y_d_result); +} diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 6c7a2c1d2..306e0dbb9 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -781,13 +781,13 @@ T detach(T x) __generic<T : __BuiltinFloatingPointType, let N : int> vector<T, N> detach(vector<T, N> x) { - VECTOR_MAP_UNARY(T, N, detach, x); + return x; } __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> matrix<T, N, M> detach(matrix<T, N, M> x) { - MATRIX_MAP_UNARY(T, N, M, detach, x); + return x; } // Absolute value (HLSL SM 1.0) diff --git a/tests/autodiff-dstdlib/dstdlib-abs.slang b/tests/autodiff-dstdlib/dstdlib-abs.slang index 0da2de4c7..b7988f573 100644 --- a/tests/autodiff-dstdlib/dstdlib-abs.slang +++ b/tests/autodiff-dstdlib/dstdlib-abs.slang @@ -1,21 +1,21 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 
0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; typedef DifferentialPair<float2> dpfloat2; typedef DifferentialPair<float3> dpfloat3; -[ForwardDifferentiable] +[BackwardDifferentiable] float diffAbs(float x) { return abs(x); } -[ForwardDifferentiable] -float3 diffAbs3(float3 x) +[BackwardDifferentiable] +float3 diffAbs(float3 x) { return abs(x); } @@ -32,12 +32,26 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) { dpfloat3 dpx = dpfloat3(float3(2.f, -3.f, -1.f), float3(-2.f, 3.f, -1.f)); - dpfloat3 res = __fwd_diff(diffAbs3)(dpx); + dpfloat3 res = __fwd_diff(diffAbs)(dpx); outputBuffer[2] = res.p[0]; // Expect: 2.000000 outputBuffer[3] = res.d[0]; // Expect: -2.000000 outputBuffer[4] = res.p[1]; // Expect: 3.000000 outputBuffer[5] = res.d[1]; // Expect: -3.000000 - outputBuffer[6] = res.p[2]; // Expect: 1.000000 - outputBuffer[7] = res.d[2]; // Expect: 1.000000 + outputBuffer[6] = res.p[2]; // Expect: 1.000000 + outputBuffer[7] = res.d[2]; // Expect: 1.000000 + } + + { + dpfloat dpx = dpfloat(-3.0, 0.0); + __bwd_diff(diffAbs)(dpx, 2.0); + outputBuffer[8] = dpx.d; // Expect: -2.000000 + } + + { + dpfloat3 dpx = dpfloat3(float3(2.f, -3.f, -1.f), float3(0.f, 0.f, 0.f)); + __bwd_diff(diffAbs)(dpx, float3(1.f, 1.f, 1.f)); + outputBuffer[9] = dpx.d[0]; // Expect: 1.000000 + outputBuffer[10] = dpx.d[1]; // Expect: -1.000000 + outputBuffer[11] = dpx.d[2]; // Expect: -1.000000 } } diff --git a/tests/autodiff-dstdlib/dstdlib-abs.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-abs.slang.expected.txt index 18dcc32fe..f50d892fd 100644 --- a/tests/autodiff-dstdlib/dstdlib-abs.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-abs.slang.expected.txt @@ -6,4 +6,8 @@ type: float 3.000000 -3.000000 1.000000 -1.000000
\ No newline at end of file +1.000000 +-2.000000 +1.000000 +-1.000000 +-1.000000
\ No newline at end of file diff --git a/tests/autodiff-dstdlib/dstdlib-detach.slang b/tests/autodiff-dstdlib/dstdlib-detach.slang index acb21a0ce..e5275b821 100644 --- a/tests/autodiff-dstdlib/dstdlib-detach.slang +++ b/tests/autodiff-dstdlib/dstdlib-detach.slang @@ -1,21 +1,27 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; typedef DifferentialPair<float2> dpfloat2; typedef DifferentialPair<float3> dpfloat3; -[ForwardDifferentiable] +[BackwardDifferentiable] float diffDetach(float x) { return detach(x); } -[ForwardDifferentiable] -float3 diffDetach3(float3 x) +[BackwardDifferentiable] +float2 diffDetach(float2 x) +{ + return detach(x); +} + +[BackwardDifferentiable] +float3 diffDetach(float3 x) { return detach(x); } @@ -32,12 +38,25 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) { dpfloat3 dpx = dpfloat3(float3(2.f, -3.f, -1.f), float3(-2.f, 3.f, -1.f)); - dpfloat3 res = __fwd_diff(diffDetach3)(dpx); + dpfloat3 res = __fwd_diff(diffDetach)(dpx); outputBuffer[2] = res.p[0]; // Expect: 2.000000 outputBuffer[3] = res.d[0]; // Expect: 0.000000 outputBuffer[4] = res.p[1]; // Expect: -3.000000 outputBuffer[5] = res.d[1]; // Expect: 0.000000 - outputBuffer[6] = res.p[2]; // Expect: -1.000000 - outputBuffer[7] = res.d[2]; // Expect: 0.000000 + outputBuffer[6] = res.p[2]; // Expect: -1.000000 + outputBuffer[7] = res.d[2]; // Expect: 0.000000 + } + + { + dpfloat dpx = dpfloat(-5.0, 1.0); + __bwd_diff(diffDetach)(dpx, 1.0); + outputBuffer[8] = dpx.d; // Expect: 0.000000 + } + + { + dpfloat2 dpx = dpfloat2(float2(1.0, -2.0), float2(1.0, 1.0)); + __bwd_diff(diffDetach)(dpx, float2(2.0, -3.0)); + 
outputBuffer[9] = dpx.d[0]; // Expect: 0.000000 + outputBuffer[10] = dpx.d[1]; // Expect: 0.000000 } } diff --git a/tests/autodiff-dstdlib/dstdlib-detach.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-detach.slang.expected.txt index 950affb9a..042fc9da7 100644 --- a/tests/autodiff-dstdlib/dstdlib-detach.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-detach.slang.expected.txt @@ -6,4 +6,7 @@ type: float -3.000000 0.000000 -1.000000 +0.000000 +0.000000 +0.000000 0.000000
\ No newline at end of file diff --git a/tests/autodiff-dstdlib/dstdlib-dot.slang b/tests/autodiff-dstdlib/dstdlib-dot.slang new file mode 100644 index 000000000..4a4e2a78b --- /dev/null +++ b/tests/autodiff-dstdlib/dstdlib-dot.slang @@ -0,0 +1,29 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float3> dpfloat3; + +[BackwardDifferentiable] +float diffDot(float3 x, float3 y) +{ + return dot(x, y); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + dpfloat3 dpx = dpfloat3(float3(-0.5, 0.7, 0.8), float3(0.0, 0.0, 0.0)); + dpfloat3 dpy = dpfloat3(float3(0.0, 0.0, 1.0), float3(0.0, 0.0, 0.0)); + __bwd_diff(diffDot)(dpx, dpy, 1.0); + outputBuffer[0] = dpx.d[0]; // Expect: 0.000000 + outputBuffer[1] = dpx.d[1]; // Expect: 0.000000 + outputBuffer[2] = dpx.d[2]; // Expect: 1.000000 + outputBuffer[3] = dpy.d[0]; // Expect: -0.500000 + outputBuffer[4] = dpy.d[1]; // Expect: 0.700000 + outputBuffer[5] = dpy.d[2]; // Expect: 0.800000 + } +} diff --git a/tests/autodiff-dstdlib/dstdlib-dot.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-dot.slang.expected.txt new file mode 100644 index 000000000..115036603 --- /dev/null +++ b/tests/autodiff-dstdlib/dstdlib-dot.slang.expected.txt @@ -0,0 +1,7 @@ +type: float +0.000000 +0.000000 +1.000000 +-0.500000 +0.700000 +0.800000
\ No newline at end of file diff --git a/tests/autodiff-dstdlib/dstdlib-log.slang b/tests/autodiff-dstdlib/dstdlib-log.slang new file mode 100644 index 000000000..6961d2f09 --- /dev/null +++ b/tests/autodiff-dstdlib/dstdlib-log.slang @@ -0,0 +1,37 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef DifferentialPair<float2> dpfloat2; + +[BackwardDifferentiable] +float diffLog(float x) +{ + return log(x); +} + +[BackwardDifferentiable] +float2 diffLog(float2 x) +{ + return log(x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(10.0, 0.0); + __bwd_diff(diffLog)(dpa, 3.0); + outputBuffer[0] = dpa.d; // Expect: 0.300000 + } + + { + dpfloat2 dpx = dpfloat2(float2(2.0, 5.0), float2(0.0, 0.0)); + __bwd_diff(diffLog)(dpx, float2(1.0, 1.0)); + outputBuffer[1] = dpx.d[0]; // Expect: 0.500000 + outputBuffer[2] = dpx.d[1]; // Expect: 0.200000 + } +} diff --git a/tests/autodiff-dstdlib/dstdlib-log.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-log.slang.expected.txt new file mode 100644 index 000000000..a04079bb7 --- /dev/null +++ b/tests/autodiff-dstdlib/dstdlib-log.slang.expected.txt @@ -0,0 +1,4 @@ +type: float +0.300000 +0.500000 +0.200000
\ No newline at end of file diff --git a/tests/autodiff-dstdlib/dstdlib-max.slang b/tests/autodiff-dstdlib/dstdlib-max.slang index d28b17e06..026914c8c 100644 --- a/tests/autodiff-dstdlib/dstdlib-max.slang +++ b/tests/autodiff-dstdlib/dstdlib-max.slang @@ -1,20 +1,20 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; typedef DifferentialPair<float2> dpfloat2; -[ForwardDifferentiable] +[BackwardDifferentiable] float diffMax(float x, float y) { return max(x, y); } -[ForwardDifferentiable] -float2 diffMax2(float2 x, float2 y) +[BackwardDifferentiable] +float2 diffMax(float2 x, float2 y) { return max(x, y); } @@ -33,10 +33,20 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) { dpfloat2 dpx = dpfloat2(float2(-3.0, 4.0), float2(-1.0, -1.0)); dpfloat2 dpy = dpfloat2(float2(1.0, 2.0), float2(2.0, 2.0)); - dpfloat2 res = __fwd_diff(diffMax2)(dpx, dpy); + dpfloat2 res = __fwd_diff(diffMax)(dpx, dpy); outputBuffer[2] = res.p[0]; // Expect: 1.000000 outputBuffer[3] = res.d[0]; // Expect: 2.000000 outputBuffer[4] = res.p[1]; // Expect: 4.000000 outputBuffer[5] = res.d[1]; // Expect: -1.000000 } + + { + dpfloat2 dpx = dpfloat2(float2(2.0, 3.0), float2(0.0, 0.0)); + dpfloat2 dpy = dpfloat2(float2(5.0, 1.0), float2(0.0, 0.0)); + __bwd_diff(diffMax)(dpx, dpy, float2(1.0, 2.0)); + outputBuffer[6] = dpx.d[0]; // Expect: 0.000000 + outputBuffer[7] = dpx.d[1]; // Expect: 2.000000 + outputBuffer[8] = dpy.d[0]; // Expect: 1.000000 + outputBuffer[9] = dpy.d[1]; // Expect: 0.000000 + } } diff --git a/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt index 
cca137aea..4cc1e9533 100644 --- a/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-max.slang.expected.txt @@ -5,3 +5,7 @@ type: float 2.000000 4.000000 -1.000000 +0.000000 +2.000000 +1.000000 +0.000000
\ No newline at end of file diff --git a/tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang b/tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang new file mode 100644 index 000000000..6419e92aa --- /dev/null +++ b/tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang @@ -0,0 +1,30 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float2x2> dpmat2; + +[BackwardDifferentiable] +float2x2 diffMul(float2x2 a, float2x2 b) +{ + return mul(a, b); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + dpmat2 dpa = dpmat2(float2x2(1.0, 2.0, 3.0, 4.0), float2x2(0.0, 0.0, 0.0, 0.0)); + dpmat2 dpb = dpmat2(float2x2(5.0, 6.0, 7.0, 8.0), float2x2(0.0, 0.0, 0.0, 0.0)); + float2x2 dOut = float2x2(1.0, -2.0, -3.0, 4.0); + __bwd_diff(diffMul)(dpa, dpb, dOut); + outputBuffer[0] = dpa.d[0][0]; // Expect: -7.000000 + outputBuffer[1] = dpa.d[0][1]; // Expect: -9.000000 + outputBuffer[2] = dpa.d[1][0]; // Expect: 9.000000 + outputBuffer[3] = dpa.d[1][1]; // Expect: 11.000000 + outputBuffer[4] = dpb.d[0][0]; // Expect: -8.000000 + outputBuffer[5] = dpb.d[0][1]; // Expect: 10.000000 + outputBuffer[6] = dpb.d[1][0]; // Expect: -10.000000 + outputBuffer[7] = dpb.d[1][1]; // Expect: 12.000000 +} diff --git a/tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang.expected.txt new file mode 100644 index 000000000..ae3e09265 --- /dev/null +++ b/tests/autodiff-dstdlib/dstdlib-mul-mat-mat.slang.expected.txt @@ -0,0 +1,9 @@ +type: float +-7.000000 +-9.000000 +9.000000 +11.000000 +-8.000000 +10.000000 +-10.000000 +12.000000
\ No newline at end of file diff --git a/tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang b/tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang new file mode 100644 index 000000000..23ec9cabb --- /dev/null +++ b/tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang @@ -0,0 +1,35 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float3> dpfloat3; +typedef DifferentialPair<float3x3> dpmat3; + +[BackwardDifferentiable] +float3 diffMul(float3x3 m, float3 v) +{ + return mul(m, v); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + dpfloat3 dpv = dpfloat3(float3(0.5, 1.2, -0.8), float3(0.0, 0.0, 0.0)); + dpmat3 dpm = dpmat3(float3x3(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0), + float3x3(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)); + __bwd_diff(diffMul)(dpm, dpv, float3(1.0, 10.0, 100.0)); + outputBuffer[0] = dpm.d[0][0]; // Expect: 0.500000 + outputBuffer[1] = dpm.d[0][1]; // Expect: 1.200000 + outputBuffer[2] = dpm.d[0][2]; // Expect: -0.800000 + outputBuffer[3] = dpm.d[1][0]; // Expect: 5.000000 + outputBuffer[4] = dpm.d[1][1]; // Expect: 12.000000 + outputBuffer[5] = dpm.d[1][2]; // Expect: -8.000000 + outputBuffer[6] = dpm.d[2][0]; // Expect: 50.000000 + outputBuffer[7] = dpm.d[2][1]; // Expect: 120.000000 + outputBuffer[8] = dpm.d[2][2]; // Expect: -80.000000 + outputBuffer[9] = dpv.d[0]; // Expect: 741.000000 + outputBuffer[10] = dpv.d[1]; // Expect: 852.000000 + outputBuffer[11] = dpv.d[2]; // Expect: 963.000000 +} diff --git a/tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang.expected.txt new file mode 100644 index 000000000..909dbaa03 --- /dev/null +++ 
b/tests/autodiff-dstdlib/dstdlib-mul-mat-vec.slang.expected.txt @@ -0,0 +1,13 @@ +type: float +0.500000 +1.200000 +-0.800000 +5.000000 +12.000000 +-8.000000 +50.000000 +120.000000 +-80.000000 +741.000000 +852.000000 +963.000000
\ No newline at end of file diff --git a/tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang b/tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang new file mode 100644 index 000000000..a4e86091a --- /dev/null +++ b/tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang @@ -0,0 +1,35 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float3> dpfloat3; +typedef DifferentialPair<float3x3> dpmat3; + +[BackwardDifferentiable] +float3 diffMul(float3 v, float3x3 m) +{ + return mul(v, m); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + dpfloat3 dpv = dpfloat3(float3(0.5, 1.2, -0.8), float3(0.0, 0.0, 0.0)); + dpmat3 dpm = dpmat3(float3x3(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0), + float3x3(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)); + __bwd_diff(diffMul)(dpv, dpm, float3(1.0, 10.0, 100.0)); + outputBuffer[0] = dpv.d[0]; // Expect: 321.000000 + outputBuffer[1] = dpv.d[1]; // Expect: 654.000000 + outputBuffer[2] = dpv.d[2]; // Expect: 987.000000 + outputBuffer[3] = dpm.d[0][0]; // Expect: 0.500000 + outputBuffer[4] = dpm.d[0][1]; // Expect: 5.000000 + outputBuffer[5] = dpm.d[0][2]; // Expect: 50.000000 + outputBuffer[6] = dpm.d[1][0]; // Expect: 1.200000 + outputBuffer[7] = dpm.d[1][1]; // Expect: 12.000000 + outputBuffer[8] = dpm.d[1][2]; // Expect: 120.000000 + outputBuffer[9] = dpm.d[2][0]; // Expect: -0.800000 + outputBuffer[10] = dpm.d[2][1]; // Expect: -8.000000 + outputBuffer[11] = dpm.d[2][2]; // Expect: -80.000000 +} diff --git a/tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang.expected.txt new file mode 100644 index 000000000..b4583a653 --- /dev/null +++ 
b/tests/autodiff-dstdlib/dstdlib-mul-vec-mat.slang.expected.txt @@ -0,0 +1,13 @@ +type: float +321.000000 +654.000000 +987.000000 +0.500000 +5.000000 +50.000000 +1.200000 +12.000000 +120.000000 +-0.800000 +-8.000000 +-80.000000
\ No newline at end of file diff --git a/tests/autodiff-dstdlib/dstdlib-pow.slang b/tests/autodiff-dstdlib/dstdlib-pow.slang index 66f919b03..a1e2b9eda 100644 --- a/tests/autodiff-dstdlib/dstdlib-pow.slang +++ b/tests/autodiff-dstdlib/dstdlib-pow.slang @@ -1,21 +1,21 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; typedef DifferentialPair<float2> dpfloat2; typedef DifferentialPair<float3> dpfloat3; -[ForwardDifferentiable] +[BackwardDifferentiable] float diffPow(float x, float y) { return pow(x, y); } -[ForwardDifferentiable] -float2 diffPow2(float2 x, float2 y) +[BackwardDifferentiable] +float2 diffPow(float2 x, float2 y) { return pow(x, y); } @@ -34,10 +34,20 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) { dpfloat2 dpx = dpfloat2(float2(10.0, 3.0), float2(0.0, -2.0)); dpfloat2 dpy = dpfloat2(float2(2.5, 4.0), float2(1.0, 1.0)); - dpfloat2 res = __fwd_diff(diffPow2)(dpx, dpy); + dpfloat2 res = __fwd_diff(diffPow)(dpx, dpy); outputBuffer[2] = res.p[0]; // Expect: 316.227722 outputBuffer[3] = res.d[0]; // Expect: 728.141235 outputBuffer[4] = res.p[1]; // Expect: 81.000000 outputBuffer[5] = res.d[1]; // Expect: -127.012398 } + + { + dpfloat2 dpx = dpfloat2(float2(2.0, 4.0), float2(0.0, 0.0)); + dpfloat2 dpy = dpfloat2(float2(3.0, -2.0), float2(0.0, 0.0)); + __bwd_diff(diffPow)(dpx, dpy, float2(1.0, 1.0)); + outputBuffer[6] = dpx.d[0]; // Expect: 12.000000 + outputBuffer[7] = dpx.d[1]; // Expect: -0.031250 + outputBuffer[8] = dpy.d[0]; // Expect: 5.545177 + outputBuffer[9] = dpy.d[1]; // Expect: 0.086643 + } } diff --git a/tests/autodiff-dstdlib/dstdlib-pow.slang.expected.txt 
b/tests/autodiff-dstdlib/dstdlib-pow.slang.expected.txt index 8ac728ccf..f1bc71090 100644 --- a/tests/autodiff-dstdlib/dstdlib-pow.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-pow.slang.expected.txt @@ -4,4 +4,8 @@ type: float 316.227722 728.141235 81.000000 --127.012398
\ No newline at end of file +-127.012398 +12.000000 +-0.031250 +5.545177 +0.086643
\ No newline at end of file diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang b/tests/autodiff-dstdlib/dstdlib-sqrt.slang index 9ef69ea34..15573c4ef 100644 --- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang +++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang @@ -1,19 +1,25 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; typedef DifferentialPair<float2> dpfloat2; typedef DifferentialPair<float3> dpfloat3; -[ForwardDifferentiable] +[BackwardDifferentiable] float diffSqrt(float x) { return sqrt(x); } +[BackwardDifferentiable] +float2 diffSqrt(float2 x) +{ + return sqrt(x); +} + [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) { @@ -37,4 +43,11 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) outputBuffer[4] = res.p; // Expect: 0.000000 outputBuffer[5] = res.d; // Expect: 0.000000 } + + { + dpfloat2 dpx = dpfloat2(float2(10.0, 3.0), float2(0.0, 0.0)); + __bwd_diff(diffSqrt)(dpx, float2(1.0, 2.0)); + outputBuffer[6] = dpx.d[0]; // Expect: 0.158114 + outputBuffer[7] = dpx.d[1]; // Expect: 0.577350 + } } diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt index 887fd13ca..fe6487fef 100644 --- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt @@ -4,4 +4,6 @@ type: float 10.000000 -0.150000 0.000000 -0.000000
\ No newline at end of file +0.000000 +0.158114 +0.577350
\ No newline at end of file diff --git a/tests/autodiff-dstdlib/dstdlib-unary.slang b/tests/autodiff-dstdlib/dstdlib-unary.slang index 4d869996d..ea1bd3d2b 100644 --- a/tests/autodiff-dstdlib/dstdlib-unary.slang +++ b/tests/autodiff-dstdlib/dstdlib-unary.slang @@ -1,24 +1,24 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; -[ForwardDifferentiable] +[BackwardDifferentiable] float f(float x) { return exp(x); } -[ForwardDifferentiable] +[BackwardDifferentiable] float g(float x) { return sin(x); } -[ForwardDifferentiable] +[BackwardDifferentiable] float h(float x) { return cos(x); @@ -37,4 +37,14 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) outputBuffer[4] = h(dpa.p); // Expect: -0.416146 outputBuffer[5] = __fwd_diff(h)(dpa).d; // Expect: -0.909297 } + + { + dpfloat dpa = dpfloat(2.0, 0.0); + __bwd_diff(f)(dpa, 1.0); + outputBuffer[6] = dpa.d; // Expect: 7.389056 + __bwd_diff(g)(dpa, 1.0); + outputBuffer[7] = dpa.d; // Expect: -0.416146 + __bwd_diff(h)(dpa, 1.0); + outputBuffer[8] = dpa.d; // Expect: -0.909297 + } } diff --git a/tests/autodiff-dstdlib/dstdlib-unary.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-unary.slang.expected.txt index 2e7338f5c..8b04eced8 100644 --- a/tests/autodiff-dstdlib/dstdlib-unary.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-unary.slang.expected.txt @@ -5,3 +5,6 @@ type: float -0.416147 -0.416147 -0.909297 +7.389056 +-0.416147 +-0.909297 diff --git a/tests/autodiff/reverse-inout-param.slang b/tests/autodiff/reverse-inout-param.slang index 7d7f4cb05..de0d8f7ed 100644 --- a/tests/autodiff/reverse-inout-param.slang +++ 
b/tests/autodiff/reverse-inout-param.slang @@ -44,4 +44,4 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) var refVal = __fwd_diff(f_ref)(3.0, diffPair(2.0, 1.0)).d; outputBuffer[2] = refVal; // 3024 -}
\ No newline at end of file +} |
