From d82992e30d5985001870e00afdf27091f59464f2 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 10 Apr 2023 13:43:18 -0700 Subject: Cleaner impl of unary stdlib derivative functions. (#2785) * Cleaner impl of unary stdlib derivative functions. * fixup * Fix. --------- Co-authored-by: Yong He --- source/slang/diff.meta.slang | 68 +++++++++++++++++--------------------------- 1 file changed, 26 insertions(+), 42 deletions(-) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 2bdaccee3..cb87156f5 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -740,6 +740,7 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair __d_##NAME(DifferentialPair dpx) \ { \ + typealias ReturnType = T; \ return DifferentialPair(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic \ @@ -747,40 +748,29 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair> __d_##NAME##_vector(DifferentialPair> dpx) \ { \ - vector result; \ - vector.Differential d_result; \ - [ForceUnroll] for (int i = 0; i < N; ++i) \ - { \ - DifferentialPair dp_elem = __d_##NAME( \ - DifferentialPair(dpx.p[i], __slang_noop_cast(dpx.d[i]))); \ - result[i] = dp_elem.p; \ - d_result[i] = __slang_noop_cast(dp_elem.d); \ - } \ - return DifferentialPair>(result, d_result); \ + typealias ReturnType = vector; \ + return DifferentialPair(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic \ [BackwardDifferentiable] \ [ForwardDerivativeOf(NAME)] \ - DifferentialPair> __d_##NAME##_m(DifferentialPair> dpx) \ + DifferentialPair> __d_##NAME##_m(DifferentialPair> dpm) \ { \ - matrix result; \ - matrix.Differential d_result; \ - [ForceUnroll] for (int i = 0; i < M; ++i) \ - [ForceUnroll] for (int j = 0; j < N; ++j) \ + typealias ReturnType = vector; \ + matrix.Differential diff; \ + [ForceUnroll] for (int i = 0; i < M; i++) \ { \ - DifferentialPair dp_elem = __d_##NAME( \ - DifferentialPair(dpx.p[i][j], \ - __slang_noop_cast(dpx.d[i][j]))); \ - result[i][j] = dp_elem.p; \ - d_result[i][j] = __slang_noop_cast(dp_elem.d); \ + var dpx = diffPair(dpm.p[i], dpm.d[i]); \ + diff[i] = FWD_DIFF_FUNC; \ } \ - return DifferentialPair>(result, d_result); \ + return diffPair(NAME(dpm.p), diff); \ } \ __generic \ [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME(inout DifferentialPair dpx, T.Differential dOut) \ { \ + typealias ReturnType = T; \ dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic \ @@ -789,32 +779,26 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair> dpx, vector.Differential dOut) \ { \ - vector.Differential d_result; \ - [ForceUnroll] for (int i = 0; i < N; ++i) \ - { \ - DifferentialPair dp_elem = diffPair(dpx.p[i], T.dzero()); \ - __d_##NAME(dp_elem, __slang_noop_cast(dOut[i])); \ - d_result[i] = __slang_noop_cast(dp_elem.d); \ - } \ - dpx = diffPair(dpx.p, d_result); \ + typealias ReturnType = vector; \ + dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic \ [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_matrix( \ - inout DifferentialPair> dpx, matrix.Differential dOut) \ + inout DifferentialPair> m, matrix.Differential mdOut) \ { \ - matrix.Differential d_result; \ - [ForceUnroll] for (int i = 0; i < M; ++i) \ - [ForceUnroll] for (int j = 0; j < N; ++j) \ + typealias ReturnType = vector; \ + matrix.Differential diff; \ + [ForceUnroll] for (int i = 0; i < M; i++) \ { \ - DifferentialPair dp_elem = diffPair(dpx.p[i][j], T.dzero()); \ - __d_##NAME(dp_elem, __slang_noop_cast(dOut[i][j])); \ - d_result[i][j] = __slang_noop_cast(dp_elem.d); \ + var dpx = diffPair(m.p[i], m.d[i]); \ + var dOut = mdOut[i]; \ + diff[i] = BWD_DIFF_FUNC; \ } \ - dpx = diffPair(dpx.p, d_result); \ + m = diffPair(m.p, diff); \ } -#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, T.dmul(DIFF_FUNC, dpx.d), T.dmul(DIFF_FUNC, dOut)) +#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, ReturnType.dmul(DIFF_FUNC, dpx.d), ReturnType.dmul(DIFF_FUNC, dOut)) // Detach and set derivatives to zero __generic @@ -824,9 +808,9 @@ T detach(T x); #define SLANG_SQR(x) ((x)*(x)) // Absolute value -UNARY_DERIVATIVE_IMPL(abs, (dpx.p > T(0.0) ? dpx.d : T.dmul(T(-1.0), dpx.d)), (T.dmul(__slang_noop_cast(sign(dpx.p)), dOut))) +UNARY_DERIVATIVE_IMPL(abs, select(dpx.p > T(0.0), dpx.d, ReturnType.dmul(T(-1.0), dpx.d)), (ReturnType.dmul(__slang_noop_cast(sign(dpx.p)), dOut))) // Saturate -UNARY_DERIVATIVE_IMPL(saturate, (dpx.p < T(0.0) || dpx.p > T(1.0) ? T.dzero() : dpx.d), (dpx.p < T(0.0) || dpx.p > T(1.0) ? T.dzero() : dOut)) +UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut)) // frac UNARY_DERIVATIVE_IMPL(frac, dpx.d, dOut) // raidans, degrees @@ -849,9 +833,9 @@ SIMPLE_UNARY_DERIVATIVE_IMPL(log, T(1.0) / dpx.p) SIMPLE_UNARY_DERIVATIVE_IMPL(log10, T(1.0) / (dpx.p * T(52.3025850929940456840179914546844))) SIMPLE_UNARY_DERIVATIVE_IMPL(log2, T(1.0) / (dpx.p * T(50.69314718055994530941723212145818))) // Square root -SIMPLE_UNARY_DERIVATIVE_IMPL(sqrt, (dpx.p < T(1e-7) ? T(0.0) : T(0.5) / sqrt(dpx.p))) +SIMPLE_UNARY_DERIVATIVE_IMPL(sqrt, T(0.5) / sqrt(max(ReturnType(T(1e-7)), dpx.p))) // Reciprocal -SIMPLE_UNARY_DERIVATIVE_IMPL(rcp, (dpx.p < T(1e-7) ? T(0.0) : T(-1.0) / (dpx.p * dpx.p))) +SIMPLE_UNARY_DERIVATIVE_IMPL(rcp, T(-1.0) / max(ReturnType(T(1e-7)), dpx.p * dpx.p)) // rsqrt SIMPLE_UNARY_DERIVATIVE_IMPL(rsqrt, T(-0.5) / (dpx.p * sqrt(dpx.p))) // Arc-sin -- cgit v1.2.3