summaryrefslogtreecommitdiff
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang62
1 files changed, 9 insertions, 53 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 6f4888a5d..8a46f7d60 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -1563,71 +1563,27 @@ void __d_clamp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpMin, i
VECTOR_MATRIX_TERNARY_DIFF_IMPL(clamp)
// fma
+__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
[ForwardDerivativeOf(fma)]
[PreferRecompute]
-DifferentialPair<double> __d_fma(DifferentialPair<double> dpx, DifferentialPair<double> dpy, DifferentialPair<double> dpz)
+DifferentialPair<T> __d_fma(DifferentialPair<T> dpx, DifferentialPair<T> dpy, DifferentialPair<T> dpz)
{
- return DifferentialPair<double>(
+ return DifferentialPair<T>(
fma(dpx.p, dpy.p, dpz.p),
- dpy.p * dpx.d + dpx.p * dpy.d + dpz.d);
+ T.dadd(T.dadd(__mul_p_d(dpy.p, dpx.d), __mul_p_d(dpx.p, dpy.d)), dpz.d));
}
+__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
[BackwardDerivativeOf(fma)]
[PreferRecompute]
-void __d_fma(inout DifferentialPair<double> dpx, inout DifferentialPair<double> dpy, inout DifferentialPair<double> dpz, double dOut)
+void __d_fma(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpz, T.Differential dOut)
{
- dpx = diffPair(dpx.p, dpy.p * dOut);
- dpy = diffPair(dpy.p, dpx.p * dOut);
+ dpx = diffPair(dpx.p, __mul_p_d(dpy.p, dOut));
+ dpy = diffPair(dpy.p, __mul_p_d(dpx.p, dOut));
dpz = diffPair(dpz.p, dOut);
}
-__generic<let N : int>
-[BackwardDifferentiable]
-[ForwardDerivativeOf(fma)]
-[PreferRecompute]
-DifferentialPair<vector<double, N>> __d_fma_vector(
- DifferentialPair<vector<double, N>> dpx,
- DifferentialPair<vector<double, N>> dpy,
- DifferentialPair<vector<double, N>> dpz)
-{
- vector<double, N> result;
- vector<double, N>.Differential d_result;
- [ForceUnroll] for (int i = 0; i < N; ++i)
- {
- DifferentialPair<double> dp_elem = __d_fma(
- DifferentialPair<double>(dpx.p[i], dpx.d[i]),
- DifferentialPair<double>(dpy.p[i], dpy.d[i]),
- DifferentialPair<double>(dpz.p[i], dpz.d[i]));
- result[i] = dp_elem.p;
- d_result[i] = dp_elem.d;
- }
- return DifferentialPair<vector<double, N>>(result, d_result);
-}
-__generic<let N : int>
-[BackwardDifferentiable]
-[BackwardDerivativeOf(fma)]
-[PreferRecompute]
-void __d_fma_vector(
- inout DifferentialPair<vector<double, N>> dpx,
- inout DifferentialPair<vector<double, N>> dpy,
- inout DifferentialPair<vector<double, N>> dpz,
- vector<double, N> dOut)
-{
- vector<double, N>.Differential x_d_result, y_d_result, z_d_result;
- [ForceUnroll] for (int i = 0; i < N; ++i)
- {
- DifferentialPair<double> x_dp = diffPair(dpx.p[i], 0.0);
- DifferentialPair<double> y_dp = diffPair(dpy.p[i], 0.0);
- DifferentialPair<double> z_dp = diffPair(dpz.p[i], 0.0);
- __d_fma(x_dp, y_dp, z_dp, dOut[i]);
- x_d_result[i] = x_dp.d;
- y_d_result[i] = y_dp.d;
- z_d_result[i] = z_dp.d;
- }
- dpx = diffPair(dpx.p, x_d_result);
- dpy = diffPair(dpy.p, y_d_result);
- dpz = diffPair(dpz.p, z_d_result);
-}
+VECTOR_MATRIX_TERNARY_DIFF_IMPL(fma)
// mad
__generic<T : __BuiltinFloatingPointType>