summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-02-20 23:31:05 -0800
committerGitHub <noreply@github.com>2025-02-20 23:31:05 -0800
commitca592d2791047be4df3cac44c18af99f003bd085 (patch)
tree9cbe6ca41ad95257d16bbcc01b0639505203916a /source/slang/diff.meta.slang
parent4d286aab2ec23c081f23846f5dfdb30b1c05728b (diff)
Fix gradient behavior for min() and max() functions at boundaries. When input values are equal, the gradient is split evenly between both inputs. (#6411)
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang12
1 files changed, 6 insertions, 6 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 38a3220be..790dfaa79 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -2074,7 +2074,7 @@ DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
return DifferentialPair<T>(
max(dpx.p, dpy.p),
- dpx.p > dpy.p ? dpx.d : dpy.d
+ dpx.p > dpy.p ? dpx.d : (dpx.p < dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d)))
);
}
@@ -2084,8 +2084,8 @@ __generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(max)]
void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
- dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : T.dzero());
- dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : T.dzero());
+ dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : (dpx.p < dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
+ dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : (dpy.p < dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
}
VECTOR_MATRIX_BINARY_DIFF_IMPL(max)
@@ -2099,7 +2099,7 @@ DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
{
return DifferentialPair<T>(
min(dpx.p, dpy.p),
- dpx.p < dpy.p ? dpx.d : dpy.d
+ dpx.p < dpy.p ? dpx.d : (dpx.p > dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d)))
);
}
@@ -2109,8 +2109,8 @@ __generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(min)]
void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut)
{
- dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : T.dzero());
- dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : T.dzero());
+ dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : (dpx.p > dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
+ dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : (dpy.p > dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut)));
}
VECTOR_MATRIX_BINARY_DIFF_IMPL(min)