From a02379208f8906272d3fd773d4b5cfe8eec3be3b Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 19 Feb 2025 13:05:10 -0800 Subject: Fix issue with `clamp`'s derivatives at the boundary. (#6403) --- source/slang/diff.meta.slang | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'source') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 6f2bd2cd4..38a3220be 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1,4 +1,3 @@ - /// `[ForwardDerivative(fwdFn)]` attribute can be used to provide a forward-mode /// derivative implementation. /// Invoking `fwd_diff(decoratedFn)` will place a call to `fwdFn` instead of synthesizing @@ -80,7 +79,6 @@ attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; /// For member functions, or functions nested inside namespaces, `bwdFn` may need to be a fully qualified /// name. /// -/// __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; @@ -2150,7 +2148,7 @@ DifferentialPair __d_clamp(DifferentialPair dpx, DifferentialPair dpMin { return DifferentialPair( clamp(dpx.p, dpMin.p, dpMax.p), - dpx.p < dpMin.p ? dpMin.d : (dpx.p > dpMax.p ? dpMax.d : dpx.d)); + (dpx.p >= dpMin.p && dpx.p <= dpMax.p) ? dpx.d : (dpx.p < dpMin.p ? dpMin.d : dpMax.d)); } __generic [BackwardDifferentiable] @@ -2158,9 +2156,11 @@ __generic [BackwardDerivativeOf(clamp)] void __d_clamp(inout DifferentialPair dpx, inout DifferentialPair dpMin, inout DifferentialPair dpMax, T.Differential dOut) { - dpx = diffPair(dpx.p, dpx.p > dpMin.p && dpx.p < dpMax.p ? dOut : T.dzero()); - dpMin = diffPair(dpMin.p, dpx.p <= dpMin.p ? dOut : T.dzero()); - dpMax = diffPair(dpMax.p, dpx.p >= dpMax.p ? dOut : T.dzero()); + // Propagate the derivative to x if x is within [min, max] (including the boundaries). + dpx = diffPair(dpx.p, (dpx.p >= dpMin.p && dpx.p <= dpMax.p) ? dOut : T.dzero()); + // If x is strictly below min or above max, the gradient is instead applied to the clamp bounds + dpMin = diffPair(dpMin.p, dpx.p < dpMin.p ? dOut : T.dzero()); + dpMax = diffPair(dpMax.p, dpx.p > dpMax.p ? dOut : T.dzero()); } VECTOR_MATRIX_TERNARY_DIFF_IMPL(clamp) -- cgit v1.2.3