diff options
| author | Yong He <yonghe@outlook.com> | 2023-04-10 13:43:18 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-10 13:43:18 -0700 |
| commit | d82992e30d5985001870e00afdf27091f59464f2 (patch) | |
| tree | 6ee2c4ce4c98686e5721bda89f274343e366ab4e /source/slang/diff.meta.slang | |
| parent | ea15647ba6bccb5ac48de5f4b80b8c2769d69b8f (diff) | |
Cleaner impl of unary stdlib derivative functions. (#2785)
* Cleaner impl of unary stdlib derivative functions.
* fixup
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 68 |
1 files changed, 26 insertions, 42 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 2bdaccee3..cb87156f5 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -740,6 +740,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve [ForwardDerivativeOf(NAME)] \ DifferentialPair<T> __d_##NAME(DifferentialPair<T> dpx) \ { \ + typealias ReturnType = T; \ return DifferentialPair<T>(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ @@ -747,40 +748,29 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve [ForwardDerivativeOf(NAME)] \ DifferentialPair<vector<T, N>> __d_##NAME##_vector(DifferentialPair<vector<T, N>> dpx) \ { \ - vector<T, N> result; \ - vector<T, N>.Differential d_result; \ - [ForceUnroll] for (int i = 0; i < N; ++i) \ - { \ - DifferentialPair<T> dp_elem = __d_##NAME( \ - DifferentialPair<T>(dpx.p[i], __slang_noop_cast<T.Differential>(dpx.d[i]))); \ - result[i] = dp_elem.p; \ - d_result[i] = __slang_noop_cast<T>(dp_elem.d); \ - } \ - return DifferentialPair<vector<T, N>>(result, d_result); \ + typealias ReturnType = vector<T, N>; \ + return DifferentialPair<ReturnType>(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ [BackwardDifferentiable] \ [ForwardDerivativeOf(NAME)] \ - DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpx) \ + DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpm) \ { \ - matrix<T, M, N> result; \ - matrix<T, M, N>.Differential d_result; \ - [ForceUnroll] for (int i = 0; i < M; ++i) \ - [ForceUnroll] for (int j = 0; j < N; ++j) \ + typealias ReturnType = vector<T,N>; \ + matrix<T,M,N>.Differential diff; \ + [ForceUnroll] for (int i = 0; i < M; i++) \ { \ - DifferentialPair<T> dp_elem = __d_##NAME( \ - DifferentialPair<T>(dpx.p[i][j], \ - __slang_noop_cast<T.Differential>(dpx.d[i][j]))); \ - result[i][j] = dp_elem.p; \ - d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \ + var dpx = diffPair(dpm.p[i], dpm.d[i]); \ + diff[i] = FWD_DIFF_FUNC; \ } \ - return DifferentialPair<matrix<T, M, N>>(result, d_result); \ + return diffPair(NAME(dpm.p), diff); \ } \ __generic<T : __BuiltinFloatingPointType> \ [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME(inout DifferentialPair<T> dpx, T.Differential dOut) \ { \ + typealias ReturnType = T; \ dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ @@ -789,32 +779,26 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve void __d_##NAME##_vector( \ inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \ { \ - vector<T, N>.Differential d_result; \ - [ForceUnroll] for (int i = 0; i < N; ++i) \ - { \ - DifferentialPair<T> dp_elem = diffPair(dpx.p[i], T.dzero()); \ - __d_##NAME(dp_elem, __slang_noop_cast<T.Differential>(dOut[i])); \ - d_result[i] = __slang_noop_cast<T>(dp_elem.d); \ - } \ - dpx = diffPair(dpx.p, d_result); \ + typealias ReturnType = vector<T, N>; \ + dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_matrix( \ - inout DifferentialPair<matrix<T, M, N>> dpx, matrix<T, M, N>.Differential dOut) \ + inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \ { \ - matrix<T, M, N>.Differential d_result; \ - [ForceUnroll] for (int i = 0; i < M; ++i) \ - [ForceUnroll] for (int j = 0; j < N; ++j) \ + typealias ReturnType = vector<T, N>; \ + matrix<T, M, N>.Differential diff; \ + [ForceUnroll] for (int i = 0; i < M; i++) \ { \ - DifferentialPair<T> dp_elem = diffPair(dpx.p[i][j], T.dzero()); \ - __d_##NAME(dp_elem, __slang_noop_cast<T.Differential>(dOut[i][j])); \ - d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \ + var dpx = diffPair(m.p[i], m.d[i]); \ + var dOut = mdOut[i]; \ + diff[i] = BWD_DIFF_FUNC; \ } \ - dpx = diffPair(dpx.p, d_result); \ + m = diffPair(m.p, diff); \ } -#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, T.dmul(DIFF_FUNC, dpx.d), T.dmul(DIFF_FUNC, dOut)) +#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, ReturnType.dmul(DIFF_FUNC, dpx.d), ReturnType.dmul(DIFF_FUNC, dOut)) // Detach and set derivatives to zero __generic<T : IDifferentiable> @@ -824,9 +808,9 @@ T detach(T x); #define SLANG_SQR(x) ((x)*(x)) // Absolute value -UNARY_DERIVATIVE_IMPL(abs, (dpx.p > T(0.0) ? dpx.d : T.dmul(T(-1.0), dpx.d)), (T.dmul(__slang_noop_cast<T>(sign(dpx.p)), dOut))) +UNARY_DERIVATIVE_IMPL(abs, select(dpx.p > T(0.0), dpx.d, ReturnType.dmul(T(-1.0), dpx.d)), (ReturnType.dmul(__slang_noop_cast<ReturnType>(sign(dpx.p)), dOut))) // Saturate -UNARY_DERIVATIVE_IMPL(saturate, (dpx.p < T(0.0) || dpx.p > T(1.0) ? T.dzero() : dpx.d), (dpx.p < T(0.0) || dpx.p > T(1.0) ? T.dzero() : dOut)) +UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut)) // frac UNARY_DERIVATIVE_IMPL(frac, dpx.d, dOut) // raidans, degrees @@ -849,9 +833,9 @@ SIMPLE_UNARY_DERIVATIVE_IMPL(log, T(1.0) / dpx.p) SIMPLE_UNARY_DERIVATIVE_IMPL(log10, T(1.0) / (dpx.p * T(52.3025850929940456840179914546844))) SIMPLE_UNARY_DERIVATIVE_IMPL(log2, T(1.0) / (dpx.p * T(50.69314718055994530941723212145818))) // Square root -SIMPLE_UNARY_DERIVATIVE_IMPL(sqrt, (dpx.p < T(1e-7) ? T(0.0) : T(0.5) / sqrt(dpx.p))) +SIMPLE_UNARY_DERIVATIVE_IMPL(sqrt, T(0.5) / sqrt(max(ReturnType(T(1e-7)), dpx.p))) // Reciprocal -SIMPLE_UNARY_DERIVATIVE_IMPL(rcp, (dpx.p < T(1e-7) ? T(0.0) : T(-1.0) / (dpx.p * dpx.p))) +SIMPLE_UNARY_DERIVATIVE_IMPL(rcp, T(-1.0) / max(ReturnType(T(1e-7)), dpx.p * dpx.p)) // rsqrt SIMPLE_UNARY_DERIVATIVE_IMPL(rsqrt, T(-0.5) / (dpx.p * sqrt(dpx.p))) // Arc-sin |
