From ca592d2791047be4df3cac44c18af99f003bd085 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 20 Feb 2025 23:31:05 -0800 Subject: Fix gradient behavior for min() and max() functions at boundaries. When input values are equal, the gradient is split evenly between both inputs. (#6411) Co-authored-by: Yong He --- source/slang/diff.meta.slang | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'source') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 38a3220be..790dfaa79 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -2074,7 +2074,7 @@ DifferentialPair __d_max(DifferentialPair dpx, DifferentialPair dpy) { return DifferentialPair( max(dpx.p, dpy.p), - dpx.p > dpy.p ? dpx.d : dpy.d + dpx.p > dpy.p ? dpx.d : (dpx.p < dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d))) ); } @@ -2084,8 +2084,8 @@ __generic [BackwardDerivativeOf(max)] void __d_max(inout DifferentialPair dpx, inout DifferentialPair dpy, T.Differential dOut) { - dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : T.dzero()); - dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : T.dzero()); + dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : (dpx.p < dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut))); + dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : (dpy.p < dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut))); } VECTOR_MATRIX_BINARY_DIFF_IMPL(max) @@ -2099,7 +2099,7 @@ DifferentialPair __d_min(DifferentialPair dpx, DifferentialPair dpy) { return DifferentialPair( min(dpx.p, dpy.p), - dpx.p < dpy.p ? dpx.d : dpy.d + dpx.p < dpy.p ? dpx.d : (dpx.p > dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d))) ); } @@ -2109,8 +2109,8 @@ __generic [BackwardDerivativeOf(min)] void __d_min(inout DifferentialPair dpx, inout DifferentialPair dpy, T.Differential dOut) { - dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : T.dzero()); - dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : T.dzero()); + dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : (dpx.p > dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut))); + dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : (dpy.p > dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut))); } VECTOR_MATRIX_BINARY_DIFF_IMPL(min) -- cgit v1.2.3