diff options
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 62 |
1 files changed, 9 insertions, 53 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 6f4888a5d..8a46f7d60 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1563,71 +1563,27 @@ void __d_clamp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpMin, i VECTOR_MATRIX_TERNARY_DIFF_IMPL(clamp) // fma +__generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] [ForwardDerivativeOf(fma)] [PreferRecompute] -DifferentialPair<double> __d_fma(DifferentialPair<double> dpx, DifferentialPair<double> dpy, DifferentialPair<double> dpz) +DifferentialPair<T> __d_fma(DifferentialPair<T> dpx, DifferentialPair<T> dpy, DifferentialPair<T> dpz) { - return DifferentialPair<double>( + return DifferentialPair<T>( fma(dpx.p, dpy.p, dpz.p), - dpy.p * dpx.d + dpx.p * dpy.d + dpz.d); + T.dadd(T.dadd(__mul_p_d(dpy.p, dpx.d), __mul_p_d(dpx.p, dpy.d)), dpz.d)); } +__generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] [BackwardDerivativeOf(fma)] [PreferRecompute] -void __d_fma(inout DifferentialPair<double> dpx, inout DifferentialPair<double> dpy, inout DifferentialPair<double> dpz, double dOut) +void __d_fma(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpz, T.Differential dOut) { - dpx = diffPair(dpx.p, dpy.p * dOut); - dpy = diffPair(dpy.p, dpx.p * dOut); + dpx = diffPair(dpx.p, __mul_p_d(dpy.p, dOut)); + dpy = diffPair(dpy.p, __mul_p_d(dpx.p, dOut)); dpz = diffPair(dpz.p, dOut); } -__generic<let N : int> -[BackwardDifferentiable] -[ForwardDerivativeOf(fma)] -[PreferRecompute] -DifferentialPair<vector<double, N>> __d_fma_vector( - DifferentialPair<vector<double, N>> dpx, - DifferentialPair<vector<double, N>> dpy, - DifferentialPair<vector<double, N>> dpz) -{ - vector<double, N> result; - vector<double, N>.Differential d_result; - [ForceUnroll] for (int i = 0; i < N; ++i) - { - DifferentialPair<double> dp_elem = __d_fma( - DifferentialPair<double>(dpx.p[i], dpx.d[i]), - DifferentialPair<double>(dpy.p[i], dpy.d[i]), - DifferentialPair<double>(dpz.p[i], dpz.d[i])); - result[i] = dp_elem.p; - d_result[i] = dp_elem.d; - } - return DifferentialPair<vector<double, N>>(result, d_result); -} -__generic<let N : int> -[BackwardDifferentiable] -[BackwardDerivativeOf(fma)] -[PreferRecompute] -void __d_fma_vector( - inout DifferentialPair<vector<double, N>> dpx, - inout DifferentialPair<vector<double, N>> dpy, - inout DifferentialPair<vector<double, N>> dpz, - vector<double, N> dOut) -{ - vector<double, N>.Differential x_d_result, y_d_result, z_d_result; - [ForceUnroll] for (int i = 0; i < N; ++i) - { - DifferentialPair<double> x_dp = diffPair(dpx.p[i], 0.0); - DifferentialPair<double> y_dp = diffPair(dpy.p[i], 0.0); - DifferentialPair<double> z_dp = diffPair(dpz.p[i], 0.0); - __d_fma(x_dp, y_dp, z_dp, dOut[i]); - x_d_result[i] = x_dp.d; - y_d_result[i] = y_dp.d; - z_d_result[i] = z_dp.d; - } - dpx = diffPair(dpx.p, x_d_result); - dpy = diffPair(dpy.p, y_d_result); - dpz = diffPair(dpz.p, z_d_result); -} +VECTOR_MATRIX_TERNARY_DIFF_IMPL(fma) // mad __generic<T : __BuiltinFloatingPointType> |
