diff options
| author | winmad <winmad.wlf@gmail.com> | 2023-02-07 22:30:00 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-07 22:30:00 -0800 |
| commit | b1d7dc0707406da69d94f3915fa48f39020b93ab (patch) | |
| tree | bb571884616aaf79fa663aec1eb8fcf92f5138cc /source/slang/diff.meta.slang | |
| parent | 4be623c52a6518eb86756a0369706c1d6670f6bb (diff) | |
Add backward derivatives for functions in diff.meta.slang (#2633)
* WIP: start adding backward derivatives
* Overhaul `transposeParameterBlock` to support `inout` params.
* Small bug fixes.
* Bug fix on differentiable intrinsic specialization.
* Fixes.
* Run autodiff tests on CPU.
* Clean up.
* Overhaul `transposeParameterBlock` to support `inout` params.
* Small bug fixes.
* Bug fix on differentiable intrinsic specialization.
* Fixes.
* Run autodiff tests on CPU.
* Clean up.
* More bug fixes.,
* WIP: working on detach
* Arithmetic simplifications and more IR clean up logic.
* WIP: adding detach and abs
* Fix detach and abs
* Fix.
* Add IR transform pass for cleaner code emit.
* Fix test cases.
* Fix type system logic for reference type.
* Add backward derivatives for functions that already have forward derivatives
* Fix changes
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Co-authored-by: Lifan Wu <lifanw@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 241 |
1 files changed, 237 insertions, 4 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 055c44135..af06e6bac 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -173,6 +173,27 @@ DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, Differen return DifferentialPair<vector<T,M>>(primal, diff); } +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> +[ForceInline] +[BackwardDerivativeOf(mul)] +void __d_mul(inout DifferentialPair<vector<T, N>> left, inout DifferentialPair<matrix<T, N, M>> right, vector<T, M>.Differential dOut) +{ + vector<T, N>.Differential left_d_result; + matrix<T, N, M>.Differential right_d_result; + for (int i = 0; i < N; ++i) + { + T sum = T(0); + for (int j = 0; j < M; ++j) + { + sum += right.p[i][j] * dOut[j]; + right_d_result[i][j] = left.p[i] * dOut[j]; + } + left_d_result[i] = sum; + } + left = diffPair(left.p, left_d_result); + right = diffPair(right.p, right_d_result); +} + // matrix-vector __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] @@ -184,18 +205,68 @@ DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, Differen return DifferentialPair<vector<T,N>>(primal, diff); } +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> +[ForceInline] +[BackwardDerivativeOf(mul)] +void __d_mul(inout DifferentialPair<matrix<T, N, M>> left, inout DifferentialPair<vector<T, M>> right, vector<T, N>.Differential dOut) +{ + matrix<T, N, M>.Differential left_d_result; + vector<T, M>.Differential right_d_result; + for (int j = 0; j < M; ++j) + { + T sum = T(0); + for (int i = 0; i < N; ++i) + { + sum += left.p[i][j] * dOut[i]; + left_d_result[i][j] = right.p[j] * dOut[i]; + } + right_d_result[j] = sum; + } + left = diffPair(left.p, left_d_result); + right = diffPair(right.p, right_d_result); +} // matrix-matrix __generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int> [ForceInline] [ForwardDerivativeOf(mul)] -DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, DifferentialPair<matrix<T,N,C>> left) +DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, DifferentialPair<matrix<T,N,C>> right) { - let primal = mul(right.p, left.p); - let diff = mul(right.d, left.p) + mul(right.p, left.d); + let primal = mul(left.p, right.p); + let diff = mul(left.d, right.p) + mul(left.p, right.d); return DifferentialPair<matrix<T,R,C>>(primal, diff); } +__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int> +[ForceInline] +[BackwardDerivativeOf(mul)] +void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<matrix<T, N, C>> right, matrix<T, R, C>.Differential dOut) +{ + matrix<T, R, N>.Differential left_d_result; + for (int r = 0; r < R; ++r) + for (int n = 0; n < N; ++n) + left_d_result[r][n] = T(0.0); + + matrix<T, N, C>.Differential right_d_result; + for (int n = 0; n < N; ++n) + for (int c = 0; c < C; ++c) + right_d_result[n][c] = T(0.0); + + for (int r = 0; r < R; ++r) + { + for (int c = 0; c < C; ++c) + { + for (int n = 0; n < N; ++n) + { + left_d_result[r][n] += right.p[n][c] * dOut[r][c]; + right_d_result[n][c] += left.p[r][n] * dOut[r][c]; + } + } + } + left = diffPair(left.p, left_d_result); + right = diffPair(right.p, right_d_result); +} + #define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \ vector<TYPE, COUNT> result; \ vector<TYPE, COUNT>.Differential d_result; \ @@ -207,7 +278,6 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, Diffe } \ return DifferentialPair<vector<TYPE, COUNT>>(result, d_result) - #define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \ vector<TYPE, COUNT> result; \ vector<TYPE, COUNT>.Differential d_result; \ @@ -220,6 +290,28 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, Diffe } \ return DifferentialPair<vector<TYPE, COUNT>>(result, d_result) +#define VECTOR_MAP_BWD_D_UNARY(TYPE, COUNT, D_FUNC, VALUE, D_OUT) \ + vector<TYPE, COUNT>.Differential d_result; \ + for (int i = 0; i < N; ++i) \ + { \ + DifferentialPair<TYPE> dp_elem = diffPair(VALUE.p[i], TYPE.dzero()); \ + D_FUNC(dp_elem, __slang_noop_cast<TYPE.Differential>(D_OUT[i])); \ + d_result[i] = __slang_noop_cast<TYPE>(dp_elem.d); \ + } \ + VALUE = diffPair(VALUE.p, d_result) + +#define VECTOR_MAP_BWD_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT, D_OUT) \ + vector<TYPE, COUNT>.Differential left_d_result, right_d_result; \ + for (int i = 0; i < N; ++i) \ + { \ + DifferentialPair<TYPE> left_dp = diffPair(LEFT.p[i], TYPE.dzero()); \ + DifferentialPair<TYPE> right_dp = diffPair(RIGHT.p[i], TYPE.dzero()); \ + D_FUNC(left_dp, right_dp, __slang_noop_cast<TYPE.Differential>(D_OUT[i])); \ + left_d_result[i] = __slang_noop_cast<TYPE>(left_dp.d); \ + right_d_result[i] = __slang_noop_cast<TYPE>(right_dp.d); \ + } \ + LEFT = diffPair(LEFT.p, left_d_result); \ + RIGHT = diffPair(RIGHT.p, right_d_result) // Detach and set derivatives to zero @@ -240,6 +332,20 @@ DifferentialPair<vector<T, N>> __d_detach_vector(DifferentialPair<vector<T, N>> VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(detach)] +void __d_detach(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + dpx = diffPair(dpx.p, T.dzero()); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(detach)] +void __d_detach_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + dpx = diffPair(dpx.p, vector<T, N>.dzero()); +} + // Natural Exponent __generic<T : __BuiltinFloatingPointType> @@ -295,6 +401,22 @@ DifferentialPair<vector<T, N>> __d_abs_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_UNARY(T, N, __d_abs, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(abs)] +void __d_abs(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + dpx = diffPair( + dpx.p, + T.dmul(__slang_noop_cast<T>(sign(dpx.p)), dOut)); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(abs)] +void __d_abs_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + VECTOR_MAP_BWD_D_UNARY(T, N, __d_abs, dpx, dOut); +} + // Sine __generic<T : __BuiltinFloatingPointType> @@ -386,6 +508,20 @@ DifferentialPair<vector<T, N>> __d_log_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_UNARY(T, N, __d_log, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(log)] +void __d_log(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + dpx = diffPair(dpx.p, T.dmul(T(1.0) / dpx.p, dOut)); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(log)] +void __d_log_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + VECTOR_MAP_BWD_D_UNARY(T, N, __d_log, dpx, dOut); +} + // Square root __generic<T : __BuiltinFloatingPointType> @@ -412,6 +548,30 @@ DifferentialPair<vector<T, N>> __d_sqrt_vector(DifferentialPair<vector<T, N>> dp VECTOR_MAP_D_UNARY(T, N, __d_sqrt, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(sqrt)] +void __d_sqrt(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + // Special case + if (dpx.p < T(1e-6)) + { + dpx = diffPair(dpx.p, T.dzero()); + } + else + { + dpx = diffPair( + dpx.p, + T.dmul(T(0.5) / sqrt(dpx.p), dOut)); + } +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(sqrt)] +void __d_sqrt_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + VECTOR_MAP_BWD_D_UNARY(T, N, __d_sqrt, dpx, dOut); +} + // Maximum __generic<T : __BuiltinFloatingPointType> @@ -431,6 +591,21 @@ DifferentialPair<vector<T, N>> __d_max_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_BINARY(T, N, __d_max, dpx, dpy); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(max)] +void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) +{ + dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : T.dzero()); + dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : T.dzero()); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(max)] +void __d_max_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut) +{ + VECTOR_MAP_BWD_D_BINARY(T, N, __d_max, dpx, dpy, dOut); +} + // Minimum __generic<T : __BuiltinFloatingPointType> @@ -450,6 +625,21 @@ DifferentialPair<vector<T, N>> __d_min_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_BINARY(T, N, __d_min, dpx, dpy); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(min)] +void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) +{ + dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : T.dzero()); + dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : T.dzero()); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(min)] +void __d_min_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut) +{ + VECTOR_MAP_BWD_D_BINARY(T, N, __d_min, dpx, dpy, dOut); +} + // Raise to a power __generic<T : __BuiltinFloatingPointType> @@ -478,6 +668,35 @@ DifferentialPair<vector<T, N>> __d_pow_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_BINARY(T, N, __d_pow, dpx, dpy); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(pow)] +void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) +{ + // Special case + if (dpx.p < T(1e-6)) + { + dpx = diffPair(dpx.p, T.dzero()); + dpy = diffPair(dpy.p, T.dzero()); + } + else + { + T val = pow(dpx.p, dpy.p); + dpx = diffPair( + dpx.p, + T.dmul(val * dpy.p / dpx.p, dOut)); + dpy = diffPair( + dpy.p, + T.dmul(val * log(dpx.p), dOut)); + } +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(pow)] +void __d_pow_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut) +{ + VECTOR_MAP_BWD_D_BINARY(T, N, __d_pow, dpx, dpy, dOut); +} + // Vector dot product __generic<T : __BuiltinFloatingPointType, let N : int> @@ -494,3 +713,17 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair } return DifferentialPair<T>(result, d_result); } + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(dot)] +void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut) +{ + vector<T, N>.Differential x_d_result, y_d_result; + for (int i = 0; i < N; ++i) + { + x_d_result[i] = dpy.p[i] * __slang_noop_cast<T>(dOut); + y_d_result[i] = dpx.p[i] * __slang_noop_cast<T>(dOut); + } + dpx = diffPair(dpx.p, x_d_result); + dpy = diffPair(dpy.p, y_d_result); +} |
