diff options
| author | winmad <winmad.wlf@gmail.com> | 2022-11-14 16:43:55 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-14 16:43:55 -0800 |
| commit | 25affe8e724fe4ee60a3b8ec2c494926930ba59f (patch) | |
| tree | 39d2d3d209a99152e80bf40c395002697d2c3338 /source/slang/diff.meta.slang | |
| parent | 368ec3116ea0f10f44acbf76b5dc9e34d6ff3d32 (diff) | |
Adding some math functions and their derivatives (#2497)
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 234 |
1 files changed, 216 insertions, 18 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 6f1008277..69ced9156 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1,4 +1,3 @@ - /// Modifer to mark a function for forward-mode differentiation. /// i.e. the compiler will automatically generate a new function /// that computes the jacobian-vector product of the original. @@ -7,14 +6,14 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; // Custom Forward Derivative Function reference __attributeTarget(FunctionDeclBase) -attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; +attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) -attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; +attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; __attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; @@ -90,11 +89,53 @@ struct DifferentialPair : IDifferentiable } }; -#define VECTOR_MAP_UNARY(TYPE, COUNT, FUNC, VALUE) \ - vector<TYPE,COUNT> result; for(int i = 0; i < COUNT; ++i) { result[i] = FUNC(VALUE[i]); } return result + +#define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \ + vector<TYPE, COUNT> result; \ + vector<TYPE, COUNT>.Differential d_result; \ + for (int i = 0; i < N; ++i) \ + { \ + DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(VALUE.p[i], __slang_noop_cast<TYPE.Differential>(VALUE.d[i]))); \ + result[i] = dp_elem.p; \ + d_result[i] = __slang_noop_cast<TYPE>(dp_elem.d); \ + } \ + return DifferentialPair<vector<TYPE, COUNT>>(result, d_result) + + +#define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \ + vector<TYPE, COUNT> result; \ + vector<TYPE, COUNT>.Differential d_result; \ + for (int i = 0; i < N; ++i) \ + { \ + DifferentialPair<TYPE> dp_elem = D_FUNC(DifferentialPair<TYPE>(LEFT.p[i], __slang_noop_cast<TYPE.Differential>(LEFT.d[i])), \ + DifferentialPair<TYPE>(RIGHT.p[i], __slang_noop_cast<TYPE.Differential>(RIGHT.d[i]))); \ + result[i] = dp_elem.p; \ + d_result[i] = __slang_noop_cast<TYPE>(dp_elem.d); \ + } \ + return DifferentialPair<vector<TYPE, COUNT>>(result, d_result) + + +// Detach and set derivatives to zero + +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(detach)] +DifferentialPair<T> __d_detach(DifferentialPair<T> dpx) +{ + return DifferentialPair<T>( + dpx.p, + T.dzero() + ); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(detach)] +DifferentialPair<vector<T, N>> __d_detach_vector(DifferentialPair<vector<T, N>> dpx) +{ + VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx); +} // Natural Exponent - + __generic<T : __BuiltinFloatingPointType> [ForwardDerivativeOf(exp)] DifferentialPair<T> __d_exp(DifferentialPair<T> dpx) @@ -104,35 +145,192 @@ DifferentialPair<T> __d_exp(DifferentialPair<T> dpx) T.dmul(exp(dpx.p), dpx.d)); } -__generic<T:__BuiltinFloatingPointType, let N : int> +__generic<T : __BuiltinFloatingPointType, let N : int> [ForwardDerivativeOf(exp)] DifferentialPair<vector<T, N>> __d_exp_vector(DifferentialPair<vector<T, N>> dpx) { - vector<T, N> result; - vector<T, N>.Differential d_result; - for(int i = 0; i < N; ++i) - { - DifferentialPair<T> dpexp = __d_exp(DifferentialPair<T>(dpx.p[i], __slang_noop_cast<T.Differential>(dpx.d[i]))); - result[i] = dpexp.p; - d_result[i] = __slang_noop_cast<T>(dpexp.d); - } - return DifferentialPair<vector<T, N>>(result, d_result); + VECTOR_MAP_D_UNARY(T, N, __d_exp, dpx); +} + +// Absolute value + +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(abs)] +DifferentialPair<T> __d_abs(DifferentialPair<T> dpx) +{ + return DifferentialPair<T>( + abs(dpx.p), + dpx.p > T(0.0) ? dpx.d : T.dmul(T(-1.0), dpx.d) + ); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(abs)] +DifferentialPair<vector<T, N>> __d_abs_vector(DifferentialPair<vector<T, N>> dpx) +{ + VECTOR_MAP_D_UNARY(T, N, __d_abs, dpx); } +// Sine + __generic<T : __BuiltinFloatingPointType> [ForwardDerivativeOf(sin)] -DifferentialPair<T> d_sin(DifferentialPair<T> dpx) +DifferentialPair<T> __d_sin(DifferentialPair<T> dpx) { return DifferentialPair<T>( sin(dpx.p), T.dmul(cos(dpx.p), dpx.d)); } +__generic<T : __BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(sin)] +DifferentialPair<vector<T, N>> __d_sin_vector(DifferentialPair<vector<T, N>> dpx) +{ + VECTOR_MAP_D_UNARY(T, N, __d_sin, dpx); +} + +// Cosine + __generic<T : __BuiltinFloatingPointType> [ForwardDerivativeOf(cos)] -DifferentialPair<T> d_cos(DifferentialPair<T> dpx) +DifferentialPair<T> __d_cos(DifferentialPair<T> dpx) { return DifferentialPair<T>( cos(dpx.p), T.dmul(-sin(dpx.p), dpx.d)); } + +__generic<T : __BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(cos)] +DifferentialPair<vector<T, N>> __d_cos_vector(DifferentialPair<vector<T, N>> dpx) +{ + VECTOR_MAP_D_UNARY(T, N, __d_cos, dpx); +} + +// Base-e logarithm + +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(log)] +DifferentialPair<T> __d_log(DifferentialPair<T> dpx) +{ + return DifferentialPair<T>( + log(dpx.p), + T.dmul(T(1.0) / dpx.p, dpx.d) + ); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(log)] +DifferentialPair<vector<T, N>> __d_log_vector(DifferentialPair<vector<T, N>> dpx) +{ + VECTOR_MAP_D_UNARY(T, N, __d_log, dpx); +} + +// Square root + +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(sqrt)] +DifferentialPair<T> __d_sqrt(DifferentialPair<T> dpx) +{ + // Special case + if (dpx.p < T(1e-6)) + { + return DifferentialPair<T>(T(0.0), T.dzero()); + } + + T val = sqrt(dpx.p); + return DifferentialPair<T>( + val, + T.dmul(T(0.5) / val, dpx.d) + ); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(sqrt)] +DifferentialPair<vector<T, N>> __d_sqrt_vector(DifferentialPair<vector<T, N>> dpx) +{ + VECTOR_MAP_D_UNARY(T, N, __d_sqrt, dpx); +} + +// Maximum + +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(max)] +DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy) +{ + return DifferentialPair<T>( + max(dpx.p, dpy.p), + dpx.p > dpy.p ? dpx.d : dpy.d + ); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(max)] +DifferentialPair<vector<T, N>> __d_max_vector(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy) +{ + VECTOR_MAP_D_BINARY(T, N, __d_max, dpx, dpy); +} + +// Minimum + +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(min)] +DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy) +{ + return DifferentialPair<T>( + min(dpx.p, dpy.p), + dpx.p < dpy.p ? dpx.d : dpy.d + ); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(min)] +DifferentialPair<vector<T, N>> __d_min_vector(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy) +{ + VECTOR_MAP_D_BINARY(T, N, __d_min, dpx, dpy); +} + +// Raise to a power + +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(pow)] +DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy) +{ + // Special case + if (dpx.p < T(1e-6)) + { + return DifferentialPair<T>(T(0.0), T.dzero()); + } + + T val = pow(dpx.p, dpy.p); + T.Differential d1 = T.dmul(val * log(dpx.p), dpy.d); + T.Differential d2 = T.dmul(val * dpy.p / dpx.p, dpx.d); + return DifferentialPair<T>( + val, + T.dadd(d1, d2) + ); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(pow)] +DifferentialPair<vector<T, N>> __d_pow_vector(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy) +{ + VECTOR_MAP_D_BINARY(T, N, __d_pow, dpx, dpy); +} + +// Vector dot product + +__generic<T : __BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(dot)] +DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy) +{ + T result = T(0); + T.Differential d_result = T.dzero(); + for (int i = 0; i < N; ++i) + { + result = result + dpx.p[i] * dpy.p[i]; + d_result = T.dadd(d_result, T.dmul(dpx.p[i], __slang_noop_cast<T.Differential>(dpy.d[i]))); + d_result = T.dadd(d_result, T.dmul(dpy.p[i], __slang_noop_cast<T.Differential>(dpx.d[i]))); + } + return DifferentialPair<T>(result, d_result); +} |
