diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-02-20 23:31:05 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-20 23:31:05 -0800 |
| commit | ca592d2791047be4df3cac44c18af99f003bd085 (patch) | |
| tree | 9cbe6ca41ad95257d16bbcc01b0639505203916a /source/slang | |
| parent | 4d286aab2ec23c081f23846f5dfdb30b1c05728b (diff) | |
Fix gradient behavior for min() and max() functions at boundaries. When input values are equal, the gradient is split evenly between both inputs. (#6411)
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 38a3220be..790dfaa79 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -2074,7 +2074,7 @@ DifferentialPair<T> __d_max(DifferentialPair<T> dpx, DifferentialPair<T> dpy) { return DifferentialPair<T>( max(dpx.p, dpy.p), - dpx.p > dpy.p ? dpx.d : dpy.d + dpx.p > dpy.p ? dpx.d : (dpx.p < dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d))) ); } @@ -2084,8 +2084,8 @@ __generic<T : __BuiltinFloatingPointType> [BackwardDerivativeOf(max)] void __d_max(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) { - dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : T.dzero()); - dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : T.dzero()); + dpx = diffPair(dpx.p, dpx.p > dpy.p ? dOut : (dpx.p < dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut))); + dpy = diffPair(dpy.p, dpy.p > dpx.p ? dOut : (dpy.p < dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut))); } VECTOR_MATRIX_BINARY_DIFF_IMPL(max) @@ -2099,7 +2099,7 @@ DifferentialPair<T> __d_min(DifferentialPair<T> dpx, DifferentialPair<T> dpy) { return DifferentialPair<T>( min(dpx.p, dpy.p), - dpx.p < dpy.p ? dpx.d : dpy.d + dpx.p < dpy.p ? dpx.d : (dpx.p > dpy.p ? dpy.d : __mul_p_d(T(0.5), T.dadd(dpx.d, dpy.d))) ); } @@ -2109,8 +2109,8 @@ __generic<T : __BuiltinFloatingPointType> [BackwardDerivativeOf(min)] void __d_min(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Differential dOut) { - dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : T.dzero()); - dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : T.dzero()); + dpx = diffPair(dpx.p, dpx.p < dpy.p ? dOut : (dpx.p > dpy.p ? T.dzero() : __mul_p_d(T(0.5), dOut))); + dpy = diffPair(dpy.p, dpy.p < dpx.p ? dOut : (dpy.p > dpx.p ? T.dzero() : __mul_p_d(T(0.5), dOut))); } VECTOR_MATRIX_BINARY_DIFF_IMPL(min) |
