From b1d7dc0707406da69d94f3915fa48f39020b93ab Mon Sep 17 00:00:00 2001 From: winmad Date: Tue, 7 Feb 2023 22:30:00 -0800 Subject: Add backward derivatives for functions in diff.meta.slang (#2633) * WIP: start adding backward derivatives * Overhaul `transposeParameterBlock` to support `inout` params. * Small bug fixes. * Bug fix on differentiable intrinsic specialization. * Fixes. * Run autodiff tests on CPU. * Clean up. * Overhaul `transposeParameterBlock` to support `inout` params. * Small bug fixes. * Bug fix on differentiable intrinsic specialization. * Fixes. * Run autodiff tests on CPU. * Clean up. * More bug fixes., * WIP: working on detach * Arithmetic simplifications and more IR clean up logic. * WIP: adding detach and abs * Fix detach and abs * Fix. * Add IR transform pass for cleaner code emit. * Fix test cases. * Fix type system logic for reference type. * Add backward derivatives for functions that already have forward derivatives * Fix changes --------- Co-authored-by: Yong He Co-authored-by: Lifan Wu --- source/slang/diff.meta.slang | 241 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 237 insertions(+), 4 deletions(-) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 055c44135..af06e6bac 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -173,6 +173,27 @@ DifferentialPair> mul(DifferentialPair> left, Differen return DifferentialPair>(primal, diff); } +__generic +[ForceInline] +[BackwardDerivativeOf(mul)] +void __d_mul(inout DifferentialPair> left, inout DifferentialPair> right, vector.Differential dOut) +{ + vector.Differential left_d_result; + matrix.Differential right_d_result; + for (int i = 0; i < N; ++i) + { + T sum = T(0); + for (int j = 0; j < M; ++j) + { + sum += right.p[i][j] * dOut[j]; + right_d_result[i][j] = left.p[i] * dOut[j]; + } + left_d_result[i] = sum; + } + left = diffPair(left.p, left_d_result); + right = diffPair(right.p, right_d_result); +} + // matrix-vector __generic [ForceInline] @@ -184,18 +205,68 @@ DifferentialPair> mul(DifferentialPair> left, Differen return DifferentialPair>(primal, diff); } +__generic +[ForceInline] +[BackwardDerivativeOf(mul)] +void __d_mul(inout DifferentialPair> left, inout DifferentialPair> right, vector.Differential dOut) +{ + matrix.Differential left_d_result; + vector.Differential right_d_result; + for (int j = 0; j < M; ++j) + { + T sum = T(0); + for (int i = 0; i < N; ++i) + { + sum += left.p[i][j] * dOut[i]; + left_d_result[i][j] = right.p[j] * dOut[i]; + } + right_d_result[j] = sum; + } + left = diffPair(left.p, left_d_result); + right = diffPair(right.p, right_d_result); +} // matrix-matrix __generic [ForceInline] [ForwardDerivativeOf(mul)] -DifferentialPair> mul(DifferentialPair> right, DifferentialPair> left) +DifferentialPair> mul(DifferentialPair> left, DifferentialPair> right) { - let primal = mul(right.p, left.p); - let diff = mul(right.d, left.p) + mul(right.p, left.d); + let primal = mul(left.p, right.p); + let diff = mul(left.d, right.p) + mul(left.p, right.d); return DifferentialPair>(primal, diff); } +__generic +[ForceInline] +[BackwardDerivativeOf(mul)] +void mul(inout DifferentialPair> left, inout DifferentialPair> right, matrix.Differential dOut) +{ + matrix.Differential left_d_result; + for (int r = 0; r < R; ++r) + for (int n = 0; n < N; ++n) + left_d_result[r][n] = T(0.0); + + matrix.Differential right_d_result; + for (int n = 0; n < N; ++n) + for (int c = 0; c < C; ++c) + right_d_result[n][c] = T(0.0); + + for (int r = 0; r < R; ++r) + { + for (int c = 0; c < C; ++c) + { + for (int n = 0; n < N; ++n) + { + left_d_result[r][n] += right.p[n][c] * dOut[r][c]; + right_d_result[n][c] += left.p[r][n] * dOut[r][c]; + } + } + } + left = diffPair(left.p, left_d_result); + right = diffPair(right.p, right_d_result); +} + #define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \ vector result; \ vector.Differential d_result; \ @@ -207,7 +278,6 @@ DifferentialPair> mul(DifferentialPair> right, Diffe } \ return DifferentialPair>(result, d_result) - #define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \ vector result; \ vector.Differential d_result; \ @@ -220,6 +290,28 @@ DifferentialPair> mul(DifferentialPair> right, Diffe } \ return DifferentialPair>(result, d_result) +#define VECTOR_MAP_BWD_D_UNARY(TYPE, COUNT, D_FUNC, VALUE, D_OUT) \ + vector.Differential d_result; \ + for (int i = 0; i < N; ++i) \ + { \ + DifferentialPair dp_elem = diffPair(VALUE.p[i], TYPE.dzero()); \ + D_FUNC(dp_elem, __slang_noop_cast(D_OUT[i])); \ + d_result[i] = __slang_noop_cast(dp_elem.d); \ + } \ + VALUE = diffPair(VALUE.p, d_result) + +#define VECTOR_MAP_BWD_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT, D_OUT) \ + vector.Differential left_d_result, right_d_result; \ + for (int i = 0; i < N; ++i) \ + { \ + DifferentialPair left_dp = diffPair(LEFT.p[i], TYPE.dzero()); \ + DifferentialPair right_dp = diffPair(RIGHT.p[i], TYPE.dzero()); \ + D_FUNC(left_dp, right_dp, __slang_noop_cast(D_OUT[i])); \ + left_d_result[i] = __slang_noop_cast(left_dp.d); \ + right_d_result[i] = __slang_noop_cast(right_dp.d); \ + } \ + LEFT = diffPair(LEFT.p, left_d_result); \ + RIGHT = diffPair(RIGHT.p, right_d_result) // Detach and set derivatives to zero @@ -240,6 +332,20 @@ DifferentialPair> __d_detach_vector(DifferentialPair> VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx); } +__generic +[BackwardDerivativeOf(detach)] +void __d_detach(inout DifferentialPair dpx, T.Differential dOut) +{ + dpx = diffPair(dpx.p, T.dzero()); +} + +__generic +[BackwardDerivativeOf(detach)] +void __d_detach_vector(inout DifferentialPair> dpx, vector.Differential dOut) +{ + dpx = diffPair(dpx.p, vector.dzero()); +} + // Natural Exponent __generic @@ -295,6 +401,22 @@ DifferentialPair> __d_abs_vector(DifferentialPair> dpx VECTOR_MAP_D_UNARY(T, N, __d_abs, dpx); } +__generic +[BackwardDerivativeOf(abs)] +void __d_abs(inout DifferentialPair dpx, T.Differential dOut) +{ + dpx = diffPair( + dpx.p, + T.dmul(__slang_noop_cast(sign(dpx.p)), dOut)); +} + +__generic +[BackwardDerivativeOf(abs)] +void __d_abs_vector(inout DifferentialPair> dpx, vector.Differential dOut) +{ + VECTOR_MAP_BWD_D_UNARY(T, N, __d_abs, dpx, dOut); +} + // Sine __generic @@ -386,6 +508,20 @@ DifferentialPair> __d_log_vector(DifferentialPair> dpx VECTOR_MAP_D_UNARY(T, N, __d_log, dpx); } +__generic +[BackwardDerivativeOf(log)] +void __d_log(inout DifferentialPair dpx, T.Differential dOut) +{ + dpx = diffPair(dpx.p, T.dmul(T(1.0) / dpx.p, dOut)); +} + +__generic +[BackwardDerivativeOf(log)] +void __d_log_vector(inout DifferentialPair> dpx, vector.Differential dOut) +{ + VECTOR_MAP_BWD_D_UNARY(T, N, __d_log, dpx, dOut); +} + // Square root __generic @@ -412,6 +548,30 @@ DifferentialPair> __d_sqrt_vector(DifferentialPair> dp VECTOR_MAP_D_UNARY(T, N, __d_sqrt, dpx); } +__generic +[BackwardDerivativeOf(sqrt)] +void __d_sqrt(inout DifferentialPair dpx, T.Differential dOut) +{ + // Special case + if (dpx.p < T(1e-6)) + { + dpx = diffPair(dpx.p, T.dzero()); + } + else + { + dpx = diffPair( + dpx.p, + T.dmul(T(0.5) / sqrt(dpx.p), dOut)); + } +} + +__generic +[BackwardDerivativeOf(sqrt)] +void __d_sqrt_vector(inout DifferentialPair> dpx, vector.Differential dOut) +{ + VECTOR_MAP_BWD_D_UNARY(T, N, __d_sqrt, dpx, dOut); +} + // Maximum __generic @@ -431,6 +591,21 @@ DifferentialPair> __d_max_vector(DifferentialPair> dpx VECTOR_MAP_D_BINARY(T, N, __d_max, dpx, dpy); } +__generic +[BackwardDerivativeOf(max)] +void __d_max(inout DifferentialPair dpx, inout DifferentialPair dpy, T.Differential dOut) +{ + dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : T.dzero()); + dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : T.dzero()); +} + +__generic +[BackwardDerivativeOf(max)] +void __d_max_vector(inout DifferentialPair> dpx, inout DifferentialPair> dpy, vector.Differential dOut) +{ + VECTOR_MAP_BWD_D_BINARY(T, N, __d_max, dpx, dpy, dOut); +} + // Minimum __generic @@ -450,6 +625,21 @@ DifferentialPair> __d_min_vector(DifferentialPair> dpx VECTOR_MAP_D_BINARY(T, N, __d_min, dpx, dpy); } +__generic +[BackwardDerivativeOf(min)] +void __d_min(inout DifferentialPair dpx, inout DifferentialPair dpy, T.Differential dOut) +{ + dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : T.dzero()); + dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : T.dzero()); +} + +__generic +[BackwardDerivativeOf(min)] +void __d_min_vector(inout DifferentialPair> dpx, inout DifferentialPair> dpy, vector.Differential dOut) +{ + VECTOR_MAP_BWD_D_BINARY(T, N, __d_min, dpx, dpy, dOut); +} + // Raise to a power __generic @@ -478,6 +668,35 @@ DifferentialPair> __d_pow_vector(DifferentialPair> dpx VECTOR_MAP_D_BINARY(T, N, __d_pow, dpx, dpy); } +__generic +[BackwardDerivativeOf(pow)] +void __d_pow(inout DifferentialPair dpx, inout DifferentialPair dpy, T.Differential dOut) +{ + // Special case + if (dpx.p < T(1e-6)) + { + dpx = diffPair(dpx.p, T.dzero()); + dpy = diffPair(dpy.p, T.dzero()); + } + else + { + T val = pow(dpx.p, dpy.p); + dpx = diffPair( + dpx.p, + T.dmul(val * dpy.p / dpx.p, dOut)); + dpy = diffPair( + dpy.p, + T.dmul(val * log(dpx.p), dOut)); + } +} + +__generic +[BackwardDerivativeOf(pow)] +void __d_pow_vector(inout DifferentialPair> dpx, inout DifferentialPair> dpy, vector.Differential dOut) +{ + VECTOR_MAP_BWD_D_BINARY(T, N, __d_pow, dpx, dpy, dOut); +} + // Vector dot product __generic @@ -494,3 +713,17 @@ DifferentialPair __d_dot(DifferentialPair> dpx, DifferentialPair } return DifferentialPair(result, d_result); } + +__generic +[BackwardDerivativeOf(dot)] +void __d_dot(inout DifferentialPair> dpx, inout DifferentialPair> dpy, T.Differential dOut) +{ + vector.Differential x_d_result, y_d_result; + for (int i = 0; i < N; ++i) + { + x_d_result[i] = dpy.p[i] * __slang_noop_cast(dOut); + y_d_result[i] = dpx.p[i] * __slang_noop_cast(dOut); + } + dpx = diffPair(dpx.p, x_d_result); + dpy = diffPair(dpy.p, y_d_result); +} -- cgit v1.2.3