/// Modifier to mark a function for forward-mode differentiation,
/// i.e. the compiler will automatically generate a new function
/// that computes the Jacobian-vector product of the original.
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;

// Custom Forward Derivative Function reference:
// associates a function with a user-supplied forward-derivative function,
// named by the attribute argument.
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;

// NOTE(review): presumably the backward (reverse-mode) counterpart of
// [ForwardDifferentiable] — confirm against the compiler's handling of
// BackwardDifferentiableAttribute.
__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute;

// Placed on a derivative function, naming the primal function it is the
// forward derivative of (the inverse direction of [ForwardDerivative]).
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;

// Applies to any declaration and takes a member name as its argument.
// NOTE(review): presumably names the member that holds the derivative of the
// annotated declaration — confirm.
__attributeTarget(DeclBase)
attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;

// Exclude "this" parameter from differentiation.
__attributeTarget(FunctionDeclBase)
attribute_syntax [NoDiffThis] : NoDiffThisAttribute;

/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
__generic<T : IDifferentiable>
__magic_type(DifferentialPairType)
__intrinsic_type($(kIROp_DifferentialPairType))
struct DifferentialPair : IDifferentiable
{
    // The differential of a pair is itself a pair, built over T.Differential.
    typedef DifferentialPair<T.Differential> Differential;
    typedef T.Differential DifferentialElementType;

    /// Builds a pair from a primal value and its differential.
    __intrinsic_op($(kIROp_MakeDifferentialPair))
    __init(T _primal, T.Differential _differential);

    /// Primal component.
    property p : T
    {
        __intrinsic_op($(kIROp_DifferentialPairGetPrimal))
        get;
    }

    /// Alias of `p`: lowers to the same primal-getter op.
    property v : T
    {
        __intrinsic_op($(kIROp_DifferentialPairGetPrimal))
        get;
    }

    /// Differential component.
    property d : T.Differential
    {
        __intrinsic_op($(kIROp_DifferentialPairGetDifferential))
        get;
    }

    /// Method form of `d`.
    [__unsafeForceInlineEarly]
    T.Differential getDifferential()
    {
        return d;
    }

    /// Method form of `p`.
    [__unsafeForceInlineEarly]
    T getPrimal()
    {
        return p;
    }

    /// Zero of the pair's differential space: each half is the zero of its
    /// own differential type.
    [__unsafeForceInlineEarly]
    static Differential dzero()
    {
        return Differential(T.dzero(), T.Differential.dzero());
    }

    /// Component-wise sum of two differential pairs.
    [__unsafeForceInlineEarly]
    static Differential dadd(Differential a, Differential b)
    {
        return Differential(
            T.dadd(a.p, b.p),
            T.Differential.dadd(a.d, b.d));
    }

    /// Component-wise product: primal halves combine through T.dmul,
    /// differential halves through T.Differential.dmul.
    [__unsafeForceInlineEarly]
    static Differential dmul(This a, Differential b)
    {
        return Differential(
            T.dmul(a.p, b.p),
            T.Differential.dmul(a.d, b.d));
    }
};

/// Constructs a DifferentialPair from an explicit primal and differential.
__generic<T : IDifferentiable>
__intrinsic_op($(kIROp_MakeDifferentialPair))
DifferentialPair<T> diffPair(T primal, T.Differential diff);

/// Constructs a DifferentialPair whose differential is T.dzero().
__generic<T : IDifferentiable>
[__unsafeForceInlineEarly]
DifferentialPair<T> diffPair(T primal)
{
    return diffPair(primal, T.dzero());
}

/// Replaces the primal half of `p`, keeping its differential.
[__unsafeForceInlineEarly]
void updatePrimal<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal)
{
    p = DifferentialPair<T>(newPrimal, p.d);
}

/// Replaces the differential half of `p`, keeping its primal.
[__unsafeForceInlineEarly]
void updateDiff<T : IDifferentiable>(inout DifferentialPair<T> p, T.Differential newDiff)
{
    p = DifferentialPair<T>(p.p, newDiff);
}

/// Replaces both halves of `p`.
[__unsafeForceInlineEarly]
void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T.Differential newDiff)
{
    p = DifferentialPair<T>(newPrimal, newDiff);
}

// vector-matrix
// Product rule: d(v * M) = dv * M + v * dM.
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, DifferentialPair<matrix<T, N, M>> right)
{
    let primal = mul(left.p, right.p);
    let diff = mul(left.d, right.p) + mul(left.p, right.d);
    return DifferentialPair<vector<T, M>>(primal, diff);
}

// matrix-vector
// Product rule: d(M * v) = dM * v + M * dv.
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
DifferentialPair<vector<T, N>> mul(DifferentialPair<matrix<T, N, M>> left, DifferentialPair<vector<T, M>> right)
{
    let primal = mul(left.p, right.p);
    let diff = mul(left.d, right.p) + mul(left.p, right.d);
    return DifferentialPair<vector<T, N>>(primal, diff);
}

// matrix-matrix
// Product rule: d(A * B) = dA * B + A * dB.
// NOTE: the parameter names are reversed relative to their position — the
// first operand is called `right` and the second `left`. The body follows
// positional order (first * second); names are kept for compatibility.
__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
DifferentialPair<matrix<T, R, C>> mul(DifferentialPair<matrix<T, R, N>> right, DifferentialPair<matrix<T, N, C>> left)
{
    let primal = mul(right.p, left.p);
    let diff = mul(right.d, left.p) + mul(right.p, left.d);
    return DifferentialPair<matrix<T, R, C>>(primal, diff);
}

// Expands to the body of a vector forward derivative: applies the scalar
// forward derivative D_FUNC to each of the COUNT components of VALUE and
// returns the reassembled DifferentialPair<vector<TYPE, COUNT>>.
// The loop bound is COUNT; previously it was a hard-coded `N`, which ignored
// the COUNT argument and only compiled where a generic named `N` was in scope.
#define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \
    vector<TYPE, COUNT> result; \
    vector<TYPE, COUNT>.Differential d_result; \
    for (int i = 0; i < COUNT; ++i) \
    { \
        DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(VALUE.p[i], __slang_noop_cast<TYPE.Differential>(VALUE.d[i]))); \
        result[i] = dp_elem.p; \
        d_result[i] = __slang_noop_cast<TYPE>(dp_elem.d); \
    } \
    return DifferentialPair<vector<TYPE, COUNT>>(result, d_result)

// Binary counterpart of VECTOR_MAP_D_UNARY: D_FUNC receives the i-th
// components of LEFT and RIGHT. Loop bound is COUNT for the same reason.
#define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \
    vector<TYPE, COUNT> result; \
    vector<TYPE, COUNT>.Differential d_result; \
    for (int i = 0; i < COUNT; ++i) \
    { \
        DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(LEFT.p[i], __slang_noop_cast<TYPE.Differential>(LEFT.d[i])), \
            DifferentialPair<TYPE>(RIGHT.p[i], __slang_noop_cast<TYPE.Differential>(RIGHT.d[i]))); \
        result[i] = dp_elem.p; \
        d_result[i] = __slang_noop_cast<TYPE>(dp_elem.d); \
    } \
    return DifferentialPair<vector<TYPE, COUNT>>(result, d_result)

// Detach and set derivatives to zero:
// the primal passes through unchanged, the differential becomes T.dzero().
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(detach)]
DifferentialPair<T> __d_detach(DifferentialPair<T> dpx)
{
    return DifferentialPair<T>(dpx.p, T.dzero());
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(detach)]
DifferentialPair<vector<T, N>> __d_detach_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx);
}

// Natural Exponent: d(e^x) = e^x * dx.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(exp)]
DifferentialPair<T> __d_exp(DifferentialPair<T> dpx)
{
    // exp(x) is both the primal and the derivative factor; evaluate it once.
    T val = exp(dpx.p);
    return DifferentialPair<T>(val, T.dmul(val, dpx.d));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(exp)]
DifferentialPair<vector<T, N>> __d_exp_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_exp, dpx);
}
// Absolute value: d|x| = sign(x) * dx.
// For x <= 0 (including exactly 0) the -dx branch is taken.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(abs)]
DifferentialPair<T> __d_abs(DifferentialPair<T> dpx)
{
    return DifferentialPair<T>(
        abs(dpx.p),
        dpx.p > T(0.0) ? dpx.d : T.dmul(T(-1.0), dpx.d));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(abs)]
DifferentialPair<vector<T, N>> __d_abs_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_abs, dpx);
}

// Sine: d(sin x) = cos(x) * dx.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(sin)]
DifferentialPair<T> __d_sin(DifferentialPair<T> dpx)
{
    return DifferentialPair<T>(sin(dpx.p), T.dmul(cos(dpx.p), dpx.d));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(sin)]
DifferentialPair<vector<T, N>> __d_sin_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_sin, dpx);
}

// Cosine: d(cos x) = -sin(x) * dx.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(cos)]
DifferentialPair<T> __d_cos(DifferentialPair<T> dpx)
{
    return DifferentialPair<T>(cos(dpx.p), T.dmul(-sin(dpx.p), dpx.d));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(cos)]
DifferentialPair<vector<T, N>> __d_cos_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_cos, dpx);
}

// Base-e logarithm: d(ln x) = (1 / x) * dx.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(log)]
DifferentialPair<T> __d_log(DifferentialPair<T> dpx)
{
    return DifferentialPair<T>(log(dpx.p), T.dmul(T(1.0) / dpx.p, dpx.d));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(log)]
DifferentialPair<vector<T, N>> __d_log_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_log, dpx);
}

// Square root: d(sqrt x) = (0.5 / sqrt(x)) * dx.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(sqrt)]
DifferentialPair<T> __d_sqrt(DifferentialPair<T> dpx)
{
    // The primal must equal what the non-differentiated sqrt returns, so it
    // is always computed. (Previously inputs below 1e-6 got a primal of 0,
    // which is wrong for small positive inputs, e.g. sqrt(1e-7) ~ 3.2e-4.)
    T val = sqrt(dpx.p);

    // Special case: near or below zero, 0.5 / sqrt(x) is unbounded or
    // undefined, so the derivative is suppressed.
    if (dpx.p < T(1e-6))
    {
        return DifferentialPair<T>(val, T.dzero());
    }

    return DifferentialPair<T>(val, T.dmul(T(0.5) / val, dpx.d));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(sqrt)]
DifferentialPair<vector<T, N>> __d_sqrt_vector(DifferentialPair<vector<T, N>> dpx)
{
    VECTOR_MAP_D_UNARY(T, N, __d_sqrt, dpx);
}

// Maximum: the derivative follows whichever argument is selected.
// On ties (x == y) the derivative of y is used.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(max)]
DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
    return DifferentialPair<T>(
        max(dpx.p, dpy.p),
        dpx.p > dpy.p ? dpx.d : dpy.d);
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(max)]
DifferentialPair<vector<T, N>> __d_max_vector(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy)
{
    VECTOR_MAP_D_BINARY(T, N, __d_max, dpx, dpy);
}

// Minimum: the derivative follows whichever argument is selected.
// On ties (x == y) the derivative of y is used.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(min)]
DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
    return DifferentialPair<T>(
        min(dpx.p, dpy.p),
        dpx.p < dpy.p ? dpx.d : dpy.d);
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(min)]
DifferentialPair<vector<T, N>> __d_min_vector(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy)
{
    VECTOR_MAP_D_BINARY(T, N, __d_min, dpx, dpy);
}

// Raise to a power:
// d(x^y) = x^y * ln(x) * dy + x^y * (y / x) * dx.
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(pow)]
DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
    // The primal must equal what the non-differentiated pow returns, so it is
    // always computed. (Previously bases below 1e-6 got a primal of 0, which
    // is wrong e.g. for pow(x, 0) == 1 or pow(1e-7, 0.5) ~ 3.2e-4.)
    T val = pow(dpx.p, dpy.p);

    // Special case: for bases near or below zero, ln(x) and y / x are
    // unbounded or undefined, so the derivative is suppressed.
    if (dpx.p < T(1e-6))
    {
        return DifferentialPair<T>(val, T.dzero());
    }

    T.Differential d1 = T.dmul(val * log(dpx.p), dpy.d);
    T.Differential d2 = T.dmul(val * dpy.p / dpx.p, dpx.d);
    return DifferentialPair<T>(val, T.dadd(d1, d2));
}

__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(pow)]
DifferentialPair<vector<T, N>> __d_pow_vector(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy)
{
    VECTOR_MAP_D_BINARY(T, N, __d_pow, dpx, dpy);
}

// Vector dot product:
// d(x . y) = sum_i (x_i * dy_i + y_i * dx_i).
__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(dot)]
DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy)
{
    T result = T(0);
    T.Differential d_result = T.dzero();
    for (int i = 0; i < N; ++i)
    {
        result = result + dpx.p[i] * dpy.p[i];
        d_result = T.dadd(d_result, T.dmul(dpx.p[i], __slang_noop_cast<T.Differential>(dpy.d[i])));
        d_result = T.dadd(d_result, T.dmul(dpy.p[i], __slang_noop_cast<T.Differential>(dpx.d[i])));
    }
    return DifferentialPair<T>(result, d_result);
}