/// Modifier to mark a function for forward-mode differentiation,
/// i.e. the compiler will automatically generate a new function
/// that computes the jacobian-vector product of the original.
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;

// Custom forward derivative function reference: the attribute takes the
// name of a function as its single argument.
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;

// Backward-mode counterpart of [ForwardDerivative]; also takes a function name.
__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute;

// NOTE(review): presumably the backward-mode counterpart of
// [ForwardDifferentiable] -- confirm against the compiler's handling of
// BackwardDifferentiableAttribute.
__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute;

// Used in this file on hand-written derivative functions to tie them to the
// primal function named in the argument (e.g. [ForwardDerivativeOf(exp)]).
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;

// Backward-mode counterpart of [ForwardDerivativeOf], used the same way in
// this file (e.g. [BackwardDerivativeOf(exp)]).
__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDerivativeOf(function)] : BackwardDerivativeOfAttribute;

// Applies to any declaration (DeclBase) and takes a member name.
// NOTE(review): presumably maps a member to the corresponding member of the
// differential type -- confirm.
__attributeTarget(DeclBase)
attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;

// Exclude "this" parameter from differentiation.
__attributeTarget(FunctionDeclBase)
attribute_syntax [NoDiffThis] : NoDiffThisAttribute;

/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
__generic<T : IDifferentiable>
__magic_type(DifferentialPairType)
__intrinsic_type($(kIROp_DifferentialPairType))
struct DifferentialPair : IDifferentiable
{
    // NOTE(review): every generic parameter list and type-argument list in
    // this section was missing from the text under review (angle-bracket
    // content stripped). They have been reconstructed from how T, N, M, R
    // and C are used in the bodies; confirm constraints such as
    // `T : IDifferentiable` / `T : __BuiltinFloatingPointType` and the
    // `__slang_noop_cast<...>` targets against the primal declarations.

    // The differential of a pair is itself a pair over T.Differential.
    typedef DifferentialPair<T.Differential> Differential;
    typedef T.Differential DifferentialElementType;

    /// Build a pair from a primal value and its differential.
    __intrinsic_op($(kIROp_MakeDifferentialPair))
    __init(T _primal, T.Differential _differential);

    /// Primal value.
    property p : T
    {
        __intrinsic_op($(kIROp_DifferentialPairGetPrimal))
        get;
    }

    /// Alias of `p`: maps to the same get-primal intrinsic.
    property v : T
    {
        __intrinsic_op($(kIROp_DifferentialPairGetPrimal))
        get;
    }

    /// Differential value.
    property d : T.Differential
    {
        __intrinsic_op($(kIROp_DifferentialPairGetDifferential))
        get;
    }

    /// Method form of the `d` property.
    [__unsafeForceInlineEarly]
    T.Differential getDifferential()
    {
        return d;
    }

    /// Method form of the `p` property.
    [__unsafeForceInlineEarly]
    T getPrimal()
    {
        return p;
    }

    /// Zero differential: both halves are the zero of their own type.
    [__unsafeForceInlineEarly]
    static Differential dzero()
    {
        return Differential(T.dzero(), T.Differential.dzero());
    }

    /// Component-wise sum of two pair differentials.
    [__unsafeForceInlineEarly]
    static Differential dadd(Differential a, Differential b)
    {
        return Differential(
            T.dadd(a.p, b.p),
            T.Differential.dadd(a.d, b.d));
    }

    /// Component-wise product of a pair with a pair differential.
    [__unsafeForceInlineEarly]
    static Differential dmul(This a, Differential b)
    {
        return Differential(
            T.dmul(a.p, b.p),
            T.Differential.dmul(a.d, b.d));
    }
};

/// Intrinsic constructor for a DifferentialPair from primal and differential.
__generic<T : IDifferentiable>
__intrinsic_op($(kIROp_MakeDifferentialPair))
DifferentialPair<T> diffPair(T primal, T.Differential diff);

/// Build a pair whose differential is T.dzero().
__generic<T : IDifferentiable>
[__unsafeForceInlineEarly]
DifferentialPair<T> diffPair(T primal)
{
    return diffPair(primal, T.dzero());
}

/// Replace the primal half of `p`, keeping its differential.
[__unsafeForceInlineEarly]
void updatePrimal<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal)
{
    p = DifferentialPair<T>(newPrimal, p.d);
}

/// Replace the differential half of `p`, keeping its primal.
[__unsafeForceInlineEarly]
void updateDiff<T : IDifferentiable>(inout DifferentialPair<T> p, T.Differential newDiff)
{
    p = DifferentialPair<T>(p.p, newDiff);
}

/// Replace both halves of `p`.
[__unsafeForceInlineEarly]
void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T.Differential newDiff)
{
    p = DifferentialPair<T>(newPrimal, newDiff);
}

/// Intrinsic that builds an array from a single element value.
__generic<T, let N : int>
__intrinsic_op($(kIROp_MakeArrayFromElement))
Array<T, N> makeArrayFromElement(T element);

/// Arrays of differentiable elements are differentiable element-wise.
__generic<T : IDifferentiable, let N : int>
extension Array<T, N> : IDifferentiable
{
    typedef Array<T.Differential, N> Differential;

    [__unsafeForceInlineEarly]
    static Differential dzero()
    {
        return makeArrayFromElement<T.Differential, N>(T.dzero());
    }

    [__unsafeForceInlineEarly]
    static Differential dadd(Differential a, Differential b)
    {
        Array<T.Differential, N> result;
        for (int i = 0; i < N; i++)
            result[i] = T.dadd(a[i], b[i]);
        return result;
    }

    [__unsafeForceInlineEarly]
    static Differential dmul(This a, Differential b)
    {
        Array<T.Differential, N> result;
        for (int i = 0; i < N; i++)
            result[i] = T.dmul(a[i], b[i]);
        return result;
    }
}

// vector-matrix

/// Forward derivative of mul(vector, matrix): product rule.
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, DifferentialPair<matrix<T, N, M>> right)
{
    let primal = mul(left.p, right.p);
    let diff = mul(left.d, right.p) + mul(left.p, right.d);
    return DifferentialPair<vector<T, M>>(primal, diff);
}

/// Backward derivative of mul(vector, matrix).
/// d(left)[i] = sum_j right[i][j] * dOut[j]; d(right)[i][j] = left[i] * dOut[j].
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[BackwardDerivativeOf(mul)]
void __d_mul(inout DifferentialPair<vector<T, N>> left, inout DifferentialPair<matrix<T, N, M>> right, vector<T, M>.Differential dOut)
{
    vector<T, N>.Differential left_d_result;
    matrix<T, N, M>.Differential right_d_result;
    for (int i = 0; i < N; ++i)
    {
        T sum = T(0);
        for (int j = 0; j < M; ++j)
        {
            sum += right.p[i][j] * dOut[j];
            right_d_result[i][j] = left.p[i] * dOut[j];
        }
        left_d_result[i] = sum;
    }

    left = diffPair(left.p, left_d_result);
    right = diffPair(right.p, right_d_result);
}

// matrix-vector

/// Forward derivative of mul(matrix, vector): product rule.
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
DifferentialPair<vector<T, N>> mul(DifferentialPair<matrix<T, N, M>> left, DifferentialPair<vector<T, M>> right)
{
    let primal = mul(left.p, right.p);
    let diff = mul(left.d, right.p) + mul(left.p, right.d);
    return DifferentialPair<vector<T, N>>(primal, diff);
}

/// Backward derivative of mul(matrix, vector).
/// d(right)[j] = sum_i left[i][j] * dOut[i]; d(left)[i][j] = right[j] * dOut[i].
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[BackwardDerivativeOf(mul)]
void __d_mul(inout DifferentialPair<matrix<T, N, M>> left, inout DifferentialPair<vector<T, M>> right, vector<T, N>.Differential dOut)
{
    matrix<T, N, M>.Differential left_d_result;
    vector<T, M>.Differential right_d_result;
    for (int j = 0; j < M; ++j)
    {
        T sum = T(0);
        for (int i = 0; i < N; ++i)
        {
            sum += left.p[i][j] * dOut[i];
            left_d_result[i][j] = right.p[j] * dOut[i];
        }
        right_d_result[j] = sum;
    }

    left = diffPair(left.p, left_d_result);
    right = diffPair(right.p, right_d_result);
}

// matrix-matrix

/// Forward derivative of mul(matrix, matrix): product rule.
__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
DifferentialPair<matrix<T, R, C>> mul(DifferentialPair<matrix<T, R, N>> left, DifferentialPair<matrix<T, N, C>> right)
{
    let primal = mul(left.p, right.p);
    let diff = mul(left.d, right.p) + mul(left.p, right.d);
    return DifferentialPair<matrix<T, R, C>>(primal, diff);
}

/// Backward derivative of mul(matrix, matrix).
/// Accumulates d(left)[r][n] += right[n][c] * dOut[r][c] and
/// d(right)[n][c] += left[r][n] * dOut[r][c] over all (r, c, n).
__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
[ForceInline]
[BackwardDerivativeOf(mul)]
void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<matrix<T, N, C>> right, matrix<T, R, C>.Differential dOut)
{
    // Both accumulators must start from zero because they are summed into.
    matrix<T, R, N>.Differential left_d_result;
    for (int r = 0; r < R; ++r)
        for (int n = 0; n < N; ++n)
            left_d_result[r][n] = T(0.0);

    matrix<T, N, C>.Differential right_d_result;
    for (int n = 0; n < N; ++n)
        for (int c = 0; c < C; ++c)
            right_d_result[n][c] = T(0.0);

    for (int r = 0; r < R; ++r)
    {
        for (int c = 0; c < C; ++c)
        {
            for (int n = 0; n < N; ++n)
            {
                left_d_result[r][n] += right.p[n][c] * dOut[r][c];
                right_d_result[n][c] += left.p[r][n] * dOut[r][c];
            }
        }
    }

    left = diffPair(left.p, left_d_result);
    right = diffPair(right.p, right_d_result);
}

// Helper macros that lift a scalar derivative function D_FUNC to vectors by
// applying it per element. The loops are bounded by the COUNT argument (the
// previous text looped to a captured `N`, which only worked because every
// caller happened to pass N as COUNT).

// Forward-mode, one operand: expands to a function body that returns the pair.
#define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \
    vector<TYPE, COUNT> result; \
    vector<TYPE, COUNT>.Differential d_result; \
    for (int i = 0; i < COUNT; ++i) \
    { \
        DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(VALUE.p[i], __slang_noop_cast<TYPE.Differential>(VALUE.d[i]))); \
        result[i] = dp_elem.p; \
        d_result[i] = __slang_noop_cast<TYPE>(dp_elem.d); \
    } \
    return DifferentialPair<vector<TYPE, COUNT>>(result, d_result)

// Forward-mode, two operands.
#define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \
    vector<TYPE, COUNT> result; \
    vector<TYPE, COUNT>.Differential d_result; \
    for (int i = 0; i < COUNT; ++i) \
    { \
        DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(LEFT.p[i], __slang_noop_cast<TYPE.Differential>(LEFT.d[i])), \
                                                DifferentialPair<TYPE>(RIGHT.p[i], __slang_noop_cast<TYPE.Differential>(RIGHT.d[i]))); \
        result[i] = dp_elem.p; \
        d_result[i] = __slang_noop_cast<TYPE>(dp_elem.d); \
    } \
    return DifferentialPair<vector<TYPE, COUNT>>(result, d_result)

// Backward-mode, one operand: writes the gradient back into VALUE.
#define VECTOR_MAP_BWD_D_UNARY(TYPE, COUNT, D_FUNC, VALUE, D_OUT) \
    vector<TYPE, COUNT>.Differential d_result; \
    for (int i = 0; i < COUNT; ++i) \
    { \
        DifferentialPair<TYPE> dp_elem = diffPair(VALUE.p[i], TYPE.dzero()); \
        D_FUNC(dp_elem, __slang_noop_cast<TYPE.Differential>(D_OUT[i])); \
        d_result[i] = __slang_noop_cast<TYPE>(dp_elem.d); \
    } \
    VALUE = diffPair(VALUE.p, d_result)

// Backward-mode, two operands: writes gradients back into LEFT and RIGHT.
#define VECTOR_MAP_BWD_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT, D_OUT) \
    vector<TYPE, COUNT>.Differential left_d_result, right_d_result; \
    for (int i = 0; i < COUNT; ++i) \
    { \
        DifferentialPair<TYPE> left_dp = diffPair(LEFT.p[i], TYPE.dzero()); \
        DifferentialPair<TYPE> right_dp = diffPair(RIGHT.p[i], TYPE.dzero()); \
        D_FUNC(left_dp, right_dp, __slang_noop_cast<TYPE.Differential>(D_OUT[i])); \
        left_d_result[i] = __slang_noop_cast<TYPE>(left_dp.d); \
        right_d_result[i] = __slang_noop_cast<TYPE>(right_dp.d); \
    } \
    LEFT = diffPair(LEFT.p, left_d_result); \
    RIGHT = diffPair(RIGHT.p, right_d_result)

// Detach and set derivatives to zero

/// Forward derivative of detach: primal passes through, differential is zero.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(detach)]
DifferentialPair<T> __d_detach(DifferentialPair<T> dpx)
{
    return DifferentialPair<T>(dpx.p, T.dzero());
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(detach)]
DifferentialPair<vector<T, N>> __d_detach_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx);
}

/// Backward derivative of detach: no gradient flows to the input.
__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(detach)]
void __d_detach(inout DifferentialPair<T> dpx, T.Differential dOut)
{
    dpx = diffPair(dpx.p, T.dzero());
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(detach)]
void __d_detach_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
{
    dpx = diffPair(dpx.p, vector<T, N>.dzero());
}

// Natural Exponent

/// d/dx exp(x) = exp(x).
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(exp)]
DifferentialPair<T> __d_exp(DifferentialPair<T> dpx)
{
    return DifferentialPair<T>(
        exp(dpx.p),
        T.dmul(exp(dpx.p), dpx.d));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(exp)]
DifferentialPair<vector<T, N>> __d_exp_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_exp, dpx);
}

__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(exp)]
void __d_exp(inout DifferentialPair<T> dpx, T.Differential dOut)
{
    dpx = diffPair(dpx.p, T.dmul(exp(dpx.p), dOut));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(exp)]
void __d_exp_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
{
    dpx = diffPair(dpx.p, vector<T, N>.dmul(exp(dpx.p), dOut));
}

// Absolute value

/// Forward derivative of abs.
/// NOTE(review): for dpx.p == 0 this takes the negated branch, while the
/// backward version below scales by sign(dpx.p) -- the two may disagree at
/// exactly zero; confirm which convention is intended.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(abs)]
DifferentialPair<T> __d_abs(DifferentialPair<T> dpx)
{
    return DifferentialPair<T>(
        abs(dpx.p),
        dpx.p > T(0.0) ? dpx.d : T.dmul(T(-1.0), dpx.d));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(abs)]
DifferentialPair<vector<T, N>> __d_abs_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_abs, dpx);
}

/// Backward derivative of abs: gradient scaled by sign(x).
__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(abs)]
void __d_abs(inout DifferentialPair<T> dpx, T.Differential dOut)
{
    // NOTE(review): cast target reconstructed as T (dmul's first operand type) -- confirm.
    dpx = diffPair(dpx.p, T.dmul(__slang_noop_cast<T>(sign(dpx.p)), dOut));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(abs)]
void __d_abs_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
{
    VECTOR_MAP_BWD_D_UNARY(T, N, __d_abs, dpx, dOut);
}

// Sine

/// d/dx sin(x) = cos(x).
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(sin)]
DifferentialPair<T> __d_sin(DifferentialPair<T> dpx)
{
    return DifferentialPair<T>(
        sin(dpx.p),
        T.dmul(cos(dpx.p), dpx.d));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(sin)]
DifferentialPair<vector<T, N>> __d_sin_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_sin, dpx);
}

__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(sin)]
void __d_sin(inout DifferentialPair<T> dpx, T.Differential dOut)
{
    dpx = diffPair(dpx.p, T.dmul(cos(dpx.p), dOut));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(sin)]
void __d_sin_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
{
    dpx = diffPair(dpx.p, vector<T, N>.dmul(cos(dpx.p), dOut));
}

// Cosine

/// d/dx cos(x) = -sin(x).
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(cos)]
DifferentialPair<T> __d_cos(DifferentialPair<T> dpx)
{
    return DifferentialPair<T>(
        cos(dpx.p),
        T.dmul(-sin(dpx.p), dpx.d));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(cos)]
DifferentialPair<vector<T, N>> __d_cos_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_cos, dpx);
}

__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(cos)]
void __d_cos(inout DifferentialPair<T> dpx, T.Differential dOut)
{
    dpx = diffPair(dpx.p, T.dmul(-sin(dpx.p), dOut));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(cos)]
void __d_cos_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
{
    dpx = diffPair(dpx.p, vector<T, N>.dmul(-sin(dpx.p), dOut));
}

// Base-e logarithm

/// d/dx log(x) = 1 / x.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(log)]
DifferentialPair<T> __d_log(DifferentialPair<T> dpx)
{
    return DifferentialPair<T>(
        log(dpx.p),
        T.dmul(T(1.0) / dpx.p, dpx.d));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(log)]
DifferentialPair<vector<T, N>> __d_log_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_log, dpx);
}

__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(log)]
void __d_log(inout DifferentialPair<T> dpx, T.Differential dOut)
{
    dpx = diffPair(dpx.p, T.dmul(T(1.0) / dpx.p, dOut));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(log)]
void __d_log_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
{
    VECTOR_MAP_BWD_D_UNARY(T, N, __d_log, dpx, dOut);
}

// Square root

/// Forward derivative of sqrt: d/dx sqrt(x) = 0.5 / sqrt(x).
/// For inputs below 1e-6 the derivative is clamped to zero because the
/// analytic expression diverges there. The primal half is always the real
/// sqrt(x): a derivative function must not change the function value
/// (previously the special case returned a primal of 0).
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(sqrt)]
DifferentialPair<T> __d_sqrt(DifferentialPair<T> dpx)
{
    T val = sqrt(dpx.p);

    // Special case
    if (dpx.p < T(1e-6))
    {
        return DifferentialPair<T>(val, T.dzero());
    }

    return DifferentialPair<T>(
        val,
        T.dmul(T(0.5) / val, dpx.d));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(sqrt)]
DifferentialPair<vector<T, N>> __d_sqrt_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_sqrt, dpx);
}

/// Backward derivative of sqrt, with the same small-input clamp as the
/// forward version.
__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(sqrt)]
void __d_sqrt(inout DifferentialPair<T> dpx, T.Differential dOut)
{
    // Special case
    if (dpx.p < T(1e-6))
    {
        dpx = diffPair(dpx.p, T.dzero());
    }
    else
    {
        dpx = diffPair(dpx.p, T.dmul(T(0.5) / sqrt(dpx.p), dOut));
    }
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(sqrt)]
void __d_sqrt_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
{
    VECTOR_MAP_BWD_D_UNARY(T, N, __d_sqrt, dpx, dOut);
}

// Maximum

/// Forward derivative of max: takes the differential of the larger operand;
/// ties (and unordered comparisons) select dpy.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(max)]
DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
    return DifferentialPair<T>(
        max(dpx.p, dpy.p),
        dpx.p > dpy.p ? dpx.d : dpy.d);
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(max)]
DifferentialPair<vector<T, N>> __d_max_vector(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy)
{
    VECTOR_MAP_D_BINARY(T, N, __d_max, dpx, dpy);
}

/// Backward derivative of max. Exactly one operand receives dOut, chosen by
/// the same predicate as the forward version, so a tie routes the gradient
/// to dpy instead of silently dropping it for both operands.
__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(max)]
void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
    bool xSelected = dpx.p > dpy.p;
    dpx = diffPair(dpx.p, xSelected ? dOut : T.dzero());
    dpy = diffPair(dpy.p, xSelected ? T.dzero() : dOut);
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(max)]
void __d_max_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut)
{
    VECTOR_MAP_BWD_D_BINARY(T, N, __d_max, dpx, dpy, dOut);
}

// Minimum

/// Forward derivative of min: takes the differential of the smaller operand;
/// ties (and unordered comparisons) select dpy.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(min)]
DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
    return DifferentialPair<T>(
        min(dpx.p, dpy.p),
        dpx.p < dpy.p ? dpx.d : dpy.d);
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(min)]
DifferentialPair<vector<T, N>> __d_min_vector(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy)
{
    VECTOR_MAP_D_BINARY(T, N, __d_min, dpx, dpy);
}

/// Backward derivative of min. Uses the forward version's predicate so that
/// a tie routes the gradient to dpy rather than to neither operand.
__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(min)]
void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
    bool xSelected = dpx.p < dpy.p;
    dpx = diffPair(dpx.p, xSelected ? dOut : T.dzero());
    dpy = diffPair(dpy.p, xSelected ? T.dzero() : dOut);
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(min)]
void __d_min_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut)
{
    VECTOR_MAP_BWD_D_BINARY(T, N, __d_min, dpx, dpy, dOut);
}

// Raise to a power

/// Forward derivative of pow(x, y):
///   d = pow(x, y) * log(x) * dy + pow(x, y) * y / x * dx.
/// For x below 1e-6 the derivative is clamped to zero (log(x) and y / x are
/// not usable there). The primal half is always the real pow(x, y)
/// (previously the special case returned a primal of 0).
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(pow)]
DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
    T val = pow(dpx.p, dpy.p);

    // Special case
    if (dpx.p < T(1e-6))
    {
        return DifferentialPair<T>(val, T.dzero());
    }

    T.Differential d1 = T.dmul(val * log(dpx.p), dpy.d);
    T.Differential d2 = T.dmul(val * dpy.p / dpx.p, dpx.d);
    return DifferentialPair<T>(
        val,
        T.dadd(d1, d2));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(pow)]
DifferentialPair<vector<T, N>> __d_pow_vector(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy)
{
    VECTOR_MAP_D_BINARY(T, N, __d_pow, dpx, dpy);
}

/// Backward derivative of pow(x, y), with the same small-base clamp as the
/// forward version.
__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(pow)]
void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
    // Capture both primals up front so neither gradient expression depends
    // on the order in which dpx and dpy are overwritten below.
    T x = dpx.p;
    T y = dpy.p;

    // Special case
    if (x < T(1e-6))
    {
        dpx = diffPair(x, T.dzero());
        dpy = diffPair(y, T.dzero());
    }
    else
    {
        T val = pow(x, y);
        dpx = diffPair(x, T.dmul(val * y / x, dOut));
        dpy = diffPair(y, T.dmul(val * log(x), dOut));
    }
}
// NOTE(review): the generic parameter lists and type arguments below were
// missing from the text under review (stray `>` tokens remained); they are
// reconstructed from how T and N are used in the bodies -- confirm against
// the primal declarations of pow and dot.

/// Backward derivative of the vector overload of pow: applies the scalar
/// backward derivative __d_pow to each component.
__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(pow)]
void __d_pow_vector(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, vector<T, N>.Differential dOut)
{
    VECTOR_MAP_BWD_D_BINARY(T, N, __d_pow, dpx, dpy, dOut);
}

// Vector dot product

/// Forward derivative of dot(x, y):
///   primal = sum_i x[i] * y[i]
///   diff   = sum_i (x[i] * dy[i] + y[i] * dx[i])
__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(dot)]
DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy)
{
    T result = T(0);
    T.Differential d_result = T.dzero();
    for (int i = 0; i < N; ++i)
    {
        result = result + dpx.p[i] * dpy.p[i];
        // Product rule, accumulated one term at a time.
        d_result = T.dadd(d_result, T.dmul(dpx.p[i], __slang_noop_cast<T.Differential>(dpy.d[i])));
        d_result = T.dadd(d_result, T.dmul(dpy.p[i], __slang_noop_cast<T.Differential>(dpx.d[i])));
    }
    return DifferentialPair<T>(result, d_result);
}

/// Backward derivative of dot(x, y): d(x)[i] = y[i] * dOut, d(y)[i] = x[i] * dOut.
__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(dot)]
void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut)
{
    vector<T, N>.Differential x_d_result, y_d_result;
    for (int i = 0; i < N; ++i)
    {
        x_d_result[i] = dpy.p[i] * __slang_noop_cast<T>(dOut);
        y_d_result[i] = dpx.p[i] * __slang_noop_cast<T>(dOut);
    }
    dpx = diffPair(dpx.p, x_d_result);
    dpy = diffPair(dpy.p, y_d_result);
}