From 25c79ada2c0fcc6c5ecb3e71ca073109adc1d7eb Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 20 Sep 2023 15:22:51 -0400 Subject: Fix `atan2` stdlib derivative + add tests. (#3218) * Fix atan2 stdlib derivative. Add tests for atan2 * Create dstdlib-atan2.slang.expected.txt * Update tests --- source/slang/diff.meta.slang | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'source/slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 495b6b989..75c57018c 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1259,7 +1259,7 @@ __generic DifferentialPair __d_atan2(DifferentialPair dpy, DifferentialPair dpx) { T.Differential dx = __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d); - T.Differential dy = __mul_p_d(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d); + T.Differential dy = __mul_p_d(dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d); return DifferentialPair( atan2(dpy.p, dpx.p), T.dadd(dx, dy)); @@ -1271,8 +1271,8 @@ __generic [BackwardDerivativeOf(atan2)] void __d_atan2(inout DifferentialPair dpy, inout DifferentialPair dpx, T.Differential dOut) { - dpx = diffPair(dpx.p, __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d)); - dpy = diffPair(dpy.p, __mul_p_d(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d)); + dpx = diffPair(dpx.p, __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dOut)); + dpy = diffPair(dpy.p, __mul_p_d(dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dOut)); } VECTOR_MATRIX_BINARY_DIFF_IMPL(atan2) -- cgit v1.2.3