summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-02-19 13:05:10 -0800
committerGitHub <noreply@github.com>2025-02-19 13:05:10 -0800
commita02379208f8906272d3fd773d4b5cfe8eec3be3b (patch)
tree854f74fa2dc6da7ca660c0b7eba9407e11040c32 /source
parent0959d7ebeb6932b1949a4be10e5c472327006352 (diff)
Fix issue with `clamp`'s derivatives at the boundary. (#6403)
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang12
1 files changed, 6 insertions, 6 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 6f2bd2cd4..38a3220be 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -1,4 +1,3 @@
-
/// `[ForwardDerivative(fwdFn)]` attribute can be used to provide a forward-mode
/// derivative implementation.
/// Invoking `fwd_diff(decoratedFn)` will place a call to `fwdFn` instead of synthesizing
@@ -80,7 +79,6 @@ attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
/// For member functions, or functions nested inside namespaces, `bwdFn` may need to be a fully qualified
/// name.
///
-///
__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute;
@@ -2150,7 +2148,7 @@ DifferentialPair<T> __d_clamp(DifferentialPair<T> dpx, DifferentialPair<T> dpMin
{
return DifferentialPair<T>(
clamp(dpx.p, dpMin.p, dpMax.p),
- dpx.p < dpMin.p ? dpMin.d : (dpx.p > dpMax.p ? dpMax.d : dpx.d));
+ (dpx.p >= dpMin.p && dpx.p <= dpMax.p) ? dpx.d : (dpx.p < dpMin.p ? dpMin.d : dpMax.d));
}
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
@@ -2158,9 +2156,11 @@ __generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(clamp)]
void __d_clamp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpMin, inout DifferentialPair<T> dpMax, T.Differential dOut)
{
- dpx = diffPair(dpx.p, dpx.p > dpMin.p && dpx.p < dpMax.p ? dOut : T.dzero());
- dpMin = diffPair(dpMin.p, dpx.p <= dpMin.p ? dOut : T.dzero());
- dpMax = diffPair(dpMax.p, dpx.p >= dpMax.p ? dOut : T.dzero());
+ // Propagate the derivative to x if x is within [min, max] (including the boundaries).
+ dpx = diffPair(dpx.p, (dpx.p >= dpMin.p && dpx.p <= dpMax.p) ? dOut : T.dzero());
+ // If x is strictly below min or above max, the gradient is instead applied to the clamp bounds
+ dpMin = diffPair(dpMin.p, dpx.p < dpMin.p ? dOut : T.dzero());
+ dpMax = diffPair(dpMax.p, dpx.p > dpMax.p ? dOut : T.dzero());
}
VECTOR_MATRIX_TERNARY_DIFF_IMPL(clamp)