summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorwinmad <winmad.wlf@gmail.com>2023-02-07 22:30:00 -0800
committerGitHub <noreply@github.com>2023-02-07 22:30:00 -0800
commitb1d7dc0707406da69d94f3915fa48f39020b93ab (patch)
treebb571884616aaf79fa663aec1eb8fcf92f5138cc /source/slang/diff.meta.slang
parent4be623c52a6518eb86756a0369706c1d6670f6bb (diff)
Add backward derivatives for functions in diff.meta.slang (#2633)
* WIP: start adding backward derivatives * Overhaul `transposeParameterBlock` to support `inout` params. * Small bug fixes. * Bug fix on differentiable intrinsic specialization. * Fixes. * Run autodiff tests on CPU. * Clean up. * Overhaul `transposeParameterBlock` to support `inout` params. * Small bug fixes. * Bug fix on differentiable intrinsic specialization. * Fixes. * Run autodiff tests on CPU. * Clean up. * More bug fixes., * WIP: working on detach * Arithmetic simplifications and more IR clean up logic. * WIP: adding detach and abs * Fix detach and abs * Fix. * Add IR transform pass for cleaner code emit. * Fix test cases. * Fix type system logic for reference type. * Add backward derivatives for functions that already have forward derivatives * Fix changes --------- Co-authored-by: Yong He <yhe@nvidia.com> Co-authored-by: Lifan Wu <lifanw@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang241
1 files changed, 237 insertions, 4 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 055c44135..af06e6bac 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -173,6 +173,27 @@ DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, Differen
return DifferentialPair<vector<T,M>>(primal, diff);
}
+// Backward derivative of mul(vector<T, N>, matrix<T, N, M>).
+// Given dOut, the gradient with respect to the M-component product, this stores
+//   d(left)[row]       = sum over col of right.p[row][col] * dOut[col]
+//   d(right)[row][col] = left.p[row] * dOut[col]
+// in the differential halves of `left` and `right`. The primal halves are
+// written back unchanged.
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
+[ForceInline]
+[BackwardDerivativeOf(mul)]
+void __d_mul(inout DifferentialPair<vector<T, N>> left, inout DifferentialPair<matrix<T, N, M>> right, vector<T, M>.Differential dOut)
+{
+    vector<T, N>.Differential leftGrad;
+    matrix<T, N, M>.Differential rightGrad;
+    for (int row = 0; row < N; ++row)
+    {
+        // Each entry of the matrix gradient is a single outer-product term.
+        for (int col = 0; col < M; ++col)
+            rightGrad[row][col] = left.p[row] * dOut[col];
+
+        // The vector gradient is matrix row `row` dotted with dOut.
+        T rowDot = T(0);
+        for (int col = 0; col < M; ++col)
+            rowDot += right.p[row][col] * dOut[col];
+        leftGrad[row] = rowDot;
+    }
+    left = diffPair(left.p, leftGrad);
+    right = diffPair(right.p, rightGrad);
+}
+
// matrix-vector
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
@@ -184,18 +205,68 @@ DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, Differen
return DifferentialPair<vector<T,N>>(primal, diff);
}
+// Backward derivative of mul(matrix<T, N, M>, vector<T, M>).
+// Given dOut, the gradient with respect to the N-component product, this stores
+//   d(left)[row][col] = right.p[col] * dOut[row]
+//   d(right)[col]     = sum over row of left.p[row][col] * dOut[row]
+// in the differential halves of `left` and `right`. The primal halves are
+// written back unchanged.
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
+[ForceInline]
+[BackwardDerivativeOf(mul)]
+void __d_mul(inout DifferentialPair<matrix<T, N, M>> left, inout DifferentialPair<vector<T, M>> right, vector<T, N>.Differential dOut)
+{
+    matrix<T, N, M>.Differential leftGrad;
+    vector<T, M>.Differential rightGrad;
+    for (int col = 0; col < M; ++col)
+    {
+        // Each entry of the matrix gradient is a single outer-product term.
+        for (int row = 0; row < N; ++row)
+            leftGrad[row][col] = right.p[col] * dOut[row];
+
+        // The vector gradient is matrix column `col` dotted with dOut.
+        T colDot = T(0);
+        for (int row = 0; row < N; ++row)
+            colDot += left.p[row][col] * dOut[row];
+        rightGrad[col] = colDot;
+    }
+    left = diffPair(left.p, leftGrad);
+    right = diffPair(right.p, rightGrad);
+}
// matrix-matrix
+// Forward derivative of the matrix * matrix overload of mul.
+// The primal result is mul(left.p, right.p); the differential follows the
+// product rule: mul(left.d, right.p) + mul(left.p, right.d).
__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
-DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, DifferentialPair<matrix<T,N,C>> left)
+DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, DifferentialPair<matrix<T,N,C>> right)
{
- let primal = mul(right.p, left.p);
- let diff = mul(right.d, left.p) + mul(right.p, left.d);
+ let primal = mul(left.p, right.p);
+ let diff = mul(left.d, right.p) + mul(left.p, right.d);
return DifferentialPair<matrix<T,R,C>>(primal, diff);
}
+// Backward derivative of mul(matrix<T, R, N>, matrix<T, N, C>).
+// Given dOut, the gradient with respect to the R x C product, this stores
+//   d(left)[r][n]  = sum over c of right.p[n][c] * dOut[r][c]
+//   d(right)[n][c] = sum over r of left.p[r][n]  * dOut[r][c]
+// in the differential halves of `left` and `right`. The primal halves are
+// written back unchanged.
+// NOTE(review): the other backward derivatives of mul in this file are named
+// __d_mul; this one is named mul. Confirm whether that is intentional.
+__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
+[ForceInline]
+[BackwardDerivativeOf(mul)]
+void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<matrix<T, N, C>> right, matrix<T, R, C>.Differential dOut)
+{
+    // Gradient of the left operand: accumulate over the output columns.
+    matrix<T, R, N>.Differential leftGrad;
+    for (int r = 0; r < R; ++r)
+    {
+        for (int n = 0; n < N; ++n)
+        {
+            T acc = T(0.0);
+            for (int c = 0; c < C; ++c)
+                acc += right.p[n][c] * dOut[r][c];
+            leftGrad[r][n] = acc;
+        }
+    }
+
+    // Gradient of the right operand: accumulate over the output rows.
+    matrix<T, N, C>.Differential rightGrad;
+    for (int n = 0; n < N; ++n)
+    {
+        for (int c = 0; c < C; ++c)
+        {
+            T acc = T(0.0);
+            for (int r = 0; r < R; ++r)
+                acc += left.p[r][n] * dOut[r][c];
+            rightGrad[n][c] = acc;
+        }
+    }
+
+    left = diffPair(left.p, leftGrad);
+    right = diffPair(right.p, rightGrad);
+}
+
#define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \
vector<TYPE, COUNT> result; \
vector<TYPE, COUNT>.Differential d_result; \
@@ -207,7 +278,6 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, Diffe
} \
return DifferentialPair<vector<TYPE, COUNT>>(result, d_result)
-
#define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \
vector<TYPE, COUNT> result; \
vector<TYPE, COUNT>.Differential d_result; \
@@ -220,6 +290,28 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, Diffe
} \
return DifferentialPair<vector<TYPE, COUNT>>(result, d_result)
+// Body of a vector backward derivative built from a scalar one.
+// For each of the COUNT components of VALUE it builds a scalar
+// DifferentialPair<TYPE> with a zero differential, calls
+// D_FUNC(pair, D_OUT[i]) to fill in the component gradient, and collects the
+// results. VALUE is then overwritten with (VALUE.p, collected gradients).
+// The expansion declares the locals `d_result`, `i` and `dp_elem` at the
+// expansion site.
+// The loop bound is the COUNT argument; it must not rely on a name such as N
+// being in scope where the macro is expanded.
+#define VECTOR_MAP_BWD_D_UNARY(TYPE, COUNT, D_FUNC, VALUE, D_OUT) \
+    vector<TYPE, COUNT>.Differential d_result; \
+    for (int i = 0; i < COUNT; ++i) \
+    { \
+        DifferentialPair<TYPE> dp_elem = diffPair(VALUE.p[i], TYPE.dzero()); \
+        D_FUNC(dp_elem, __slang_noop_cast<TYPE.Differential>(D_OUT[i])); \
+        d_result[i] = __slang_noop_cast<TYPE>(dp_elem.d); \
+    } \
+    VALUE = diffPair(VALUE.p, d_result)
+
+// Body of a two-operand vector backward derivative built from a scalar one.
+// For each of the COUNT components it builds scalar DifferentialPair<TYPE>
+// values for LEFT and RIGHT with zero differentials, calls
+// D_FUNC(leftPair, rightPair, D_OUT[i]) to fill in both component gradients,
+// and collects the results. LEFT and RIGHT are then overwritten with their
+// primal values paired with the collected gradients.
+// The expansion declares the locals `left_d_result`, `right_d_result`, `i`,
+// `left_dp` and `right_dp` at the expansion site.
+// The loop bound is the COUNT argument; it must not rely on a name such as N
+// being in scope where the macro is expanded.
+#define VECTOR_MAP_BWD_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT, D_OUT) \
+    vector<TYPE, COUNT>.Differential left_d_result, right_d_result; \
+    for (int i = 0; i < COUNT; ++i) \
+    { \
+        DifferentialPair<TYPE> left_dp = diffPair(LEFT.p[i], TYPE.dzero()); \
+        DifferentialPair<TYPE> right_dp = diffPair(RIGHT.p[i], TYPE.dzero()); \
+        D_FUNC(left_dp, right_dp, __slang_noop_cast<TYPE.Differential>(D_OUT[i])); \
+        left_d_result[i] = __slang_noop_cast<TYPE>(left_dp.d); \
+        right_d_result[i] = __slang_noop_cast<TYPE>(right_dp.d); \
+    } \
+    LEFT = diffPair(LEFT.p, left_d_result); \
+    RIGHT = diffPair(RIGHT.p, right_d_result)
// Detach and set derivatives to zero
@@ -240,6 +332,20 @@ DifferentialPair<vector<T, N>> __d_detach_vector(DifferentialPair<vector<T, N>>
VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx);
}
+// Backward derivative of scalar detach.
+// dOut is ignored: the differential written back to dpx is always zero, and
+// the primal value is kept as is.
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(detach)]
+void __d_detach(inout DifferentialPair<T> dpx, T.Differential dOut)
+{
+    T.Differential zeroGrad = T.dzero();
+    dpx = diffPair(dpx.p, zeroGrad);
+}
+
+// Backward derivative of vector detach.
+// dOut is ignored: the differential written back to dpx is always the zero
+// vector, and the primal value is kept as is.
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(detach)]
+void __d_detach_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
+{
+    vector<T, N>.Differential zeroGrad = vector<T, N>.dzero();
+    dpx = diffPair(dpx.p, zeroGrad);
+}
+
// Natural Exponent
__generic<T : __BuiltinFloatingPointType>
@@ -295,6 +401,22 @@ DifferentialPair<vector<T, N>> __d_abs_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_UNARY(T, N, __d_abs, dpx);
}
+// Backward derivative of scalar abs.
+// Writes sign(dpx.p) * dOut into the differential of dpx; the primal value is
+// kept as is.
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(abs)]
+void __d_abs(inout DifferentialPair<T> dpx, T.Differential dOut)
+{
+    // d|x|/dx is the sign of x.
+    T signOfX = __slang_noop_cast<T>(sign(dpx.p));
+    dpx = diffPair(dpx.p, T.dmul(signOfX, dOut));
+}
+
+// Backward derivative of abs for vector<T, N> operands.
+// Expands VECTOR_MAP_BWD_D_UNARY with the scalar __d_abs as the per-component
+// derivative, updating dpx from dOut.
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(abs)]
+void __d_abs_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
+{
+ VECTOR_MAP_BWD_D_UNARY(T, N, __d_abs, dpx, dOut);
+}
+
// Sine
__generic<T : __BuiltinFloatingPointType>
@@ -386,6 +508,20 @@ DifferentialPair<vector<T, N>> __d_log_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_UNARY(T, N, __d_log, dpx);
}
+// Backward derivative of scalar log.
+// Writes (1 / dpx.p) * dOut into the differential of dpx; the primal value is
+// kept as is.
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(log)]
+void __d_log(inout DifferentialPair<T> dpx, T.Differential dOut)
+{
+    // d(log x)/dx = 1 / x.
+    T reciprocal = T(1.0) / dpx.p;
+    dpx = diffPair(dpx.p, T.dmul(reciprocal, dOut));
+}
+
+// Backward derivative of log for vector<T, N> operands.
+// Expands VECTOR_MAP_BWD_D_UNARY with the scalar __d_log as the per-component
+// derivative, updating dpx from dOut.
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(log)]
+void __d_log_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
+{
+ VECTOR_MAP_BWD_D_UNARY(T, N, __d_log, dpx, dOut);
+}
+
// Square root
__generic<T : __BuiltinFloatingPointType>
@@ -412,6 +548,30 @@ DifferentialPair<vector<T, N>> __d_sqrt_vector(DifferentialPair<vector<T, N>> dp
VECTOR_MAP_D_UNARY(T, N, __d_sqrt, dpx);
}
+// Backward derivative of scalar sqrt.
+// Writes (0.5 / sqrt(dpx.p)) * dOut into the differential of dpx, except when
+// dpx.p is below 1e-6, in which case the differential is set to zero. The
+// primal value is kept as is.
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(sqrt)]
+void __d_sqrt(inout DifferentialPair<T> dpx, T.Differential dOut)
+{
+    // Start from the special-case result and overwrite it only when the
+    // operand is not below the threshold.
+    T.Differential grad = T.dzero();
+    if (!(dpx.p < T(1e-6)))
+    {
+        grad = T.dmul(T(0.5) / sqrt(dpx.p), dOut);
+    }
+    dpx = diffPair(dpx.p, grad);
+}
+
+// Backward derivative of sqrt for vector<T, N> operands.
+// Expands VECTOR_MAP_BWD_D_UNARY with the scalar __d_sqrt as the per-component
+// derivative, updating dpx from dOut.
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(sqrt)]
+void __d_sqrt_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
+{
+ VECTOR_MAP_BWD_D_UNARY(T, N, __d_sqrt, dpx, dOut);
+}
+
// Maximum
__generic<T : __BuiltinFloatingPointType>
@@ -431,6 +591,21 @@ DifferentialPair<vector<T, N>> __d_max_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_BINARY(T, N, __d_max, dpx, dpy);
}
+// Backward derivative of scalar max.
+// Exactly one operand receives dOut: dpx when dpx.p > dpy.p, otherwise dpy.
+// The other operand receives a zero differential. Ties therefore go to dpy
+// instead of dropping the gradient for both operands.
+// NOTE(review): confirm that the forward derivative of max breaks ties the
+// same way.
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(max)]
+void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
+{
+    // Decide once so the two assignments can never both select zero.
+    bool xIsLarger = dpx.p > dpy.p;
+    dpx = diffPair(dpx.p, xIsLarger ? dOut : T.dzero());
+    dpy = diffPair(dpy.p, xIsLarger ? T.dzero() : dOut);
+}
+
+// Backward derivative of max for vector<T, N> operands.
+// Expands VECTOR_MAP_BWD_D_BINARY with the scalar __d_max as the
+// per-component derivative, updating dpx and dpy from dOut.
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(max)]
+void __d_max_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut)
+{
+ VECTOR_MAP_BWD_D_BINARY(T, N, __d_max, dpx, dpy, dOut);
+}
+
// Minimum
__generic<T : __BuiltinFloatingPointType>
@@ -450,6 +625,21 @@ DifferentialPair<vector<T, N>> __d_min_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_BINARY(T, N, __d_min, dpx, dpy);
}
+// Backward derivative of scalar min.
+// Exactly one operand receives dOut: dpx when dpx.p < dpy.p, otherwise dpy.
+// The other operand receives a zero differential. Ties therefore go to dpy
+// instead of dropping the gradient for both operands.
+// NOTE(review): confirm that the forward derivative of min breaks ties the
+// same way.
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(min)]
+void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
+{
+    // Decide once so the two assignments can never both select zero.
+    bool xIsSmaller = dpx.p < dpy.p;
+    dpx = diffPair(dpx.p, xIsSmaller ? dOut : T.dzero());
+    dpy = diffPair(dpy.p, xIsSmaller ? T.dzero() : dOut);
+}
+
+// Backward derivative of min for vector<T, N> operands.
+// Expands VECTOR_MAP_BWD_D_BINARY with the scalar __d_min as the
+// per-component derivative, updating dpx and dpy from dOut.
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(min)]
+void __d_min_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut)
+{
+ VECTOR_MAP_BWD_D_BINARY(T, N, __d_min, dpx, dpy, dOut);
+}
+
// Raise to a power
__generic<T : __BuiltinFloatingPointType>
@@ -478,6 +668,35 @@ DifferentialPair<vector<T, N>> __d_pow_vector(DifferentialPair<vector<T, N>> dpx
VECTOR_MAP_D_BINARY(T, N, __d_pow, dpx, dpy);
}
+// Backward derivative of scalar pow(x, y).
+// With v = pow(dpx.p, dpy.p), writes
+//   d(dpx) = (v * dpy.p / dpx.p) * dOut
+//   d(dpy) = (v * log(dpx.p)) * dOut
+// except when dpx.p is below 1e-6, in which case both differentials are set
+// to zero. The primal values are kept as is.
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(pow)]
+void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
+{
+    // Start from the special-case result and overwrite it only when the base
+    // is not below the threshold.
+    T.Differential baseGrad = T.dzero();
+    T.Differential exponentGrad = T.dzero();
+    if (!(dpx.p < T(1e-6)))
+    {
+        T powValue = pow(dpx.p, dpy.p);
+        baseGrad = T.dmul(powValue * dpy.p / dpx.p, dOut);
+        exponentGrad = T.dmul(powValue * log(dpx.p), dOut);
+    }
+    dpx = diffPair(dpx.p, baseGrad);
+    dpy = diffPair(dpy.p, exponentGrad);
+}
+
+// Backward derivative of pow for vector<T, N> operands.
+// Expands VECTOR_MAP_BWD_D_BINARY with the scalar __d_pow as the
+// per-component derivative, updating dpx and dpy from dOut.
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(pow)]
+void __d_pow_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut)
+{
+ VECTOR_MAP_BWD_D_BINARY(T, N, __d_pow, dpx, dpy, dOut);
+}
+
// Vector dot product
__generic<T : __BuiltinFloatingPointType, let N : int>
@@ -494,3 +713,17 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair
}
return DifferentialPair<T>(result, d_result);
}
+
+// Backward derivative of the vector dot product.
+// Given the scalar gradient dOut, writes
+//   d(dpx)[k] = dpy.p[k] * dOut
+//   d(dpy)[k] = dpx.p[k] * dOut
+// into the differential halves of dpx and dpy. The primal halves are kept as
+// is.
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(dot)]
+void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut)
+{
+    // The incoming gradient is the same for every component; cast it once.
+    T outGrad = __slang_noop_cast<T>(dOut);
+
+    vector<T, N>.Differential xGrad;
+    vector<T, N>.Differential yGrad;
+    for (int k = 0; k < N; ++k)
+    {
+        xGrad[k] = dpy.p[k] * outGrad;
+        yGrad[k] = dpx.p[k] * outGrad;
+    }
+    dpx = diffPair(dpx.p, xGrad);
+    dpy = diffPair(dpy.p, yGrad);
+}