From 25affe8e724fe4ee60a3b8ec2c494926930ba59f Mon Sep 17 00:00:00 2001 From: winmad Date: Mon, 14 Nov 2022 16:43:55 -0800 Subject: Adding some math functions and their derivatives (#2497) --- source/slang/diff.meta.slang | 234 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 216 insertions(+), 18 deletions(-) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 6f1008277..69ced9156 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1,4 +1,3 @@ - /// Modifer to mark a function for forward-mode differentiation. /// i.e. the compiler will automatically generate a new function /// that computes the jacobian-vector product of the original. @@ -7,14 +6,14 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; // Custom Forward Derivative Function reference __attributeTarget(FunctionDeclBase) -attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; +attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) -attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; +attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; __attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; @@ -90,11 +89,53 @@ struct DifferentialPair : IDifferentiable } }; -#define VECTOR_MAP_UNARY(TYPE, COUNT, FUNC, VALUE) \ - vector result; for(int i = 0; i < COUNT; ++i) { result[i] = FUNC(VALUE[i]); } return result + +#define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \ + vector result; \ + vector.Differential d_result; \ + for (int i = 0; i < N; ++i) \ + { \ + DifferentialPair dp_elem = D_FUNC(DifferentialPair(VALUE.p[i], __slang_noop_cast(VALUE.d[i]))); \ + result[i] = dp_elem.p; \ + d_result[i] = __slang_noop_cast(dp_elem.d); \ + } \ + return DifferentialPair>(result, d_result) + + +#define VECTOR_MAP_D_BINARY(TYPE, COUNT, D_FUNC, LEFT, RIGHT) \ + vector result; \ + vector.Differential d_result; \ + for (int i = 0; i < N; ++i) \ + { \ + DifferentialPair dp_elem = D_FUNC(DifferentialPair(LEFT.p[i], __slang_noop_cast(LEFT.d[i])), \ + DifferentialPair(RIGHT.p[i], __slang_noop_cast(RIGHT.d[i]))); \ + result[i] = dp_elem.p; \ + d_result[i] = __slang_noop_cast(dp_elem.d); \ + } \ + return DifferentialPair>(result, d_result) + + +// Detach and set derivatives to zero + +__generic +[ForwardDerivativeOf(detach)] +DifferentialPair __d_detach(DifferentialPair dpx) +{ + return DifferentialPair( + dpx.p, + T.dzero() + ); +} + +__generic +[ForwardDerivativeOf(detach)] +DifferentialPair> __d_detach_vector(DifferentialPair> dpx) +{ + VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx); +} // Natural Exponent - + __generic [ForwardDerivativeOf(exp)] DifferentialPair __d_exp(DifferentialPair dpx) @@ -104,35 +145,192 @@ DifferentialPair __d_exp(DifferentialPair dpx) T.dmul(exp(dpx.p), dpx.d)); } -__generic +__generic [ForwardDerivativeOf(exp)] DifferentialPair> __d_exp_vector(DifferentialPair> dpx) { - vector result; - vector.Differential d_result; - for(int i = 0; i < N; ++i) - { - DifferentialPair dpexp = __d_exp(DifferentialPair(dpx.p[i], __slang_noop_cast(dpx.d[i]))); - result[i] = dpexp.p; - d_result[i] = __slang_noop_cast(dpexp.d); - } - return DifferentialPair>(result, d_result); + VECTOR_MAP_D_UNARY(T, N, __d_exp, dpx); +} + +// Absolute value + +__generic +[ForwardDerivativeOf(abs)] +DifferentialPair __d_abs(DifferentialPair dpx) +{ + return DifferentialPair( + abs(dpx.p), + dpx.p > T(0.0) ? dpx.d : T.dmul(T(-1.0), dpx.d) + ); +} + +__generic +[ForwardDerivativeOf(abs)] +DifferentialPair> __d_abs_vector(DifferentialPair> dpx) +{ + VECTOR_MAP_D_UNARY(T, N, __d_abs, dpx); } +// Sine + __generic [ForwardDerivativeOf(sin)] -DifferentialPair d_sin(DifferentialPair dpx) +DifferentialPair __d_sin(DifferentialPair dpx) { return DifferentialPair( sin(dpx.p), T.dmul(cos(dpx.p), dpx.d)); } +__generic +[ForwardDerivativeOf(sin)] +DifferentialPair> __d_sin_vector(DifferentialPair> dpx) +{ + VECTOR_MAP_D_UNARY(T, N, __d_sin, dpx); +} + +// Cosine + __generic [ForwardDerivativeOf(cos)] -DifferentialPair d_cos(DifferentialPair dpx) +DifferentialPair __d_cos(DifferentialPair dpx) { return DifferentialPair( cos(dpx.p), T.dmul(-sin(dpx.p), dpx.d)); } + +__generic +[ForwardDerivativeOf(cos)] +DifferentialPair> __d_cos_vector(DifferentialPair> dpx) +{ + VECTOR_MAP_D_UNARY(T, N, __d_cos, dpx); +} + +// Base-e logarithm + +__generic +[ForwardDerivativeOf(log)] +DifferentialPair __d_log(DifferentialPair dpx) +{ + return DifferentialPair( + log(dpx.p), + T.dmul(T(1.0) / dpx.p, dpx.d) + ); +} + +__generic +[ForwardDerivativeOf(log)] +DifferentialPair> __d_log_vector(DifferentialPair> dpx) +{ + VECTOR_MAP_D_UNARY(T, N, __d_log, dpx); +} + +// Square root + +__generic +[ForwardDerivativeOf(sqrt)] +DifferentialPair __d_sqrt(DifferentialPair dpx) +{ + // Special case + if (dpx.p < T(1e-6)) + { + return DifferentialPair(T(0.0), T.dzero()); + } + + T val = sqrt(dpx.p); + return DifferentialPair( + val, + T.dmul(T(0.5) / val, dpx.d) + ); +} + +__generic +[ForwardDerivativeOf(sqrt)] +DifferentialPair> __d_sqrt_vector(DifferentialPair> dpx) +{ + VECTOR_MAP_D_UNARY(T, N, __d_sqrt, dpx); +} + +// Maximum + +__generic +[ForwardDerivativeOf(max)] +DifferentialPair __d_max(DifferentialPair dpx, DifferentialPair dpy) +{ + return DifferentialPair( + max(dpx.p, dpy.p), + dpx.p > dpy.p ? dpx.d : dpy.d + ); +} + +__generic +[ForwardDerivativeOf(max)] +DifferentialPair> __d_max_vector(DifferentialPair> dpx, DifferentialPair> dpy) +{ + VECTOR_MAP_D_BINARY(T, N, __d_max, dpx, dpy); +} + +// Minimum + +__generic +[ForwardDerivativeOf(min)] +DifferentialPair __d_min(DifferentialPair dpx, DifferentialPair dpy) +{ + return DifferentialPair( + min(dpx.p, dpy.p), + dpx.p < dpy.p ? dpx.d : dpy.d + ); +} + +__generic +[ForwardDerivativeOf(min)] +DifferentialPair> __d_min_vector(DifferentialPair> dpx, DifferentialPair> dpy) +{ + VECTOR_MAP_D_BINARY(T, N, __d_min, dpx, dpy); +} + +// Raise to a power + +__generic +[ForwardDerivativeOf(pow)] +DifferentialPair __d_pow(DifferentialPair dpx, DifferentialPair dpy) +{ + // Special case + if (dpx.p < T(1e-6)) + { + return DifferentialPair(T(0.0), T.dzero()); + } + + T val = pow(dpx.p, dpy.p); + T.Differential d1 = T.dmul(val * log(dpx.p), dpy.d); + T.Differential d2 = T.dmul(val * dpy.p / dpx.p, dpx.d); + return DifferentialPair( + val, + T.dadd(d1, d2) + ); +} + +__generic +[ForwardDerivativeOf(pow)] +DifferentialPair> __d_pow_vector(DifferentialPair> dpx, DifferentialPair> dpy) +{ + VECTOR_MAP_D_BINARY(T, N, __d_pow, dpx, dpy); +} + +// Vector dot product + +__generic +[ForwardDerivativeOf(dot)] +DifferentialPair __d_dot(DifferentialPair> dpx, DifferentialPair> dpy) +{ + T result = T(0); + T.Differential d_result = T.dzero(); + for (int i = 0; i < N; ++i) + { + result = result + dpx.p[i] * dpy.p[i]; + d_result = T.dadd(d_result, T.dmul(dpx.p[i], __slang_noop_cast(dpy.d[i]))); + d_result = T.dadd(d_result, T.dmul(dpy.p[i], __slang_noop_cast(dpx.d[i]))); + } + return DifferentialPair(result, d_result); +} -- cgit v1.2.3