summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-10 13:43:18 -0700
committerGitHub <noreply@github.com>2023-04-10 13:43:18 -0700
commitd82992e30d5985001870e00afdf27091f59464f2 (patch)
tree6ee2c4ce4c98686e5721bda89f274343e366ab4e /source/slang/diff.meta.slang
parentea15647ba6bccb5ac48de5f4b80b8c2769d69b8f (diff)
Cleaner impl of unary stdlib derivative functions. (#2785)
* Cleaner impl of unary stdlib derivative functions. * fixup * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang68
1 files changed, 26 insertions, 42 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 2bdaccee3..cb87156f5 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -740,6 +740,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
[ForwardDerivativeOf(NAME)] \
DifferentialPair<T> __d_##NAME(DifferentialPair<T> dpx) \
{ \
+ typealias ReturnType = T; \
return DifferentialPair<T>(NAME(dpx.p), FWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
@@ -747,40 +748,29 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
[ForwardDerivativeOf(NAME)] \
DifferentialPair<vector<T, N>> __d_##NAME##_vector(DifferentialPair<vector<T, N>> dpx) \
{ \
- vector<T, N> result; \
- vector<T, N>.Differential d_result; \
- [ForceUnroll] for (int i = 0; i < N; ++i) \
- { \
- DifferentialPair<T> dp_elem = __d_##NAME( \
- DifferentialPair<T>(dpx.p[i], __slang_noop_cast<T.Differential>(dpx.d[i]))); \
- result[i] = dp_elem.p; \
- d_result[i] = __slang_noop_cast<T>(dp_elem.d); \
- } \
- return DifferentialPair<vector<T, N>>(result, d_result); \
+ typealias ReturnType = vector<T, N>; \
+ return DifferentialPair<ReturnType>(NAME(dpx.p), FWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
[BackwardDifferentiable] \
[ForwardDerivativeOf(NAME)] \
- DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpx) \
+ DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpm) \
{ \
- matrix<T, M, N> result; \
- matrix<T, M, N>.Differential d_result; \
- [ForceUnroll] for (int i = 0; i < M; ++i) \
- [ForceUnroll] for (int j = 0; j < N; ++j) \
+ typealias ReturnType = vector<T,N>; \
+ matrix<T,M,N>.Differential diff; \
+ [ForceUnroll] for (int i = 0; i < M; i++) \
{ \
- DifferentialPair<T> dp_elem = __d_##NAME( \
- DifferentialPair<T>(dpx.p[i][j], \
- __slang_noop_cast<T.Differential>(dpx.d[i][j]))); \
- result[i][j] = dp_elem.p; \
- d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \
+ var dpx = diffPair(dpm.p[i], dpm.d[i]); \
+ diff[i] = FWD_DIFF_FUNC; \
} \
- return DifferentialPair<matrix<T, M, N>>(result, d_result); \
+ return diffPair(NAME(dpm.p), diff); \
} \
__generic<T : __BuiltinFloatingPointType> \
[BackwardDifferentiable] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME(inout DifferentialPair<T> dpx, T.Differential dOut) \
{ \
+ typealias ReturnType = T; \
dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
@@ -789,32 +779,26 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
void __d_##NAME##_vector( \
inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \
{ \
- vector<T, N>.Differential d_result; \
- [ForceUnroll] for (int i = 0; i < N; ++i) \
- { \
- DifferentialPair<T> dp_elem = diffPair(dpx.p[i], T.dzero()); \
- __d_##NAME(dp_elem, __slang_noop_cast<T.Differential>(dOut[i])); \
- d_result[i] = __slang_noop_cast<T>(dp_elem.d); \
- } \
- dpx = diffPair(dpx.p, d_result); \
+ typealias ReturnType = vector<T, N>; \
+ dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
[BackwardDifferentiable] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_matrix( \
- inout DifferentialPair<matrix<T, M, N>> dpx, matrix<T, M, N>.Differential dOut) \
+ inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \
{ \
- matrix<T, M, N>.Differential d_result; \
- [ForceUnroll] for (int i = 0; i < M; ++i) \
- [ForceUnroll] for (int j = 0; j < N; ++j) \
+ typealias ReturnType = vector<T, N>; \
+ matrix<T, M, N>.Differential diff; \
+ [ForceUnroll] for (int i = 0; i < M; i++) \
{ \
- DifferentialPair<T> dp_elem = diffPair(dpx.p[i][j], T.dzero()); \
- __d_##NAME(dp_elem, __slang_noop_cast<T.Differential>(dOut[i][j])); \
- d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \
+ var dpx = diffPair(m.p[i], m.d[i]); \
+ var dOut = mdOut[i]; \
+ diff[i] = BWD_DIFF_FUNC; \
} \
- dpx = diffPair(dpx.p, d_result); \
+ m = diffPair(m.p, diff); \
}
-#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, T.dmul(DIFF_FUNC, dpx.d), T.dmul(DIFF_FUNC, dOut))
+#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, ReturnType.dmul(DIFF_FUNC, dpx.d), ReturnType.dmul(DIFF_FUNC, dOut))
// Detach and set derivatives to zero
__generic<T : IDifferentiable>
@@ -824,9 +808,9 @@ T detach(T x);
#define SLANG_SQR(x) ((x)*(x))
// Absolute value
-UNARY_DERIVATIVE_IMPL(abs, (dpx.p > T(0.0) ? dpx.d : T.dmul(T(-1.0), dpx.d)), (T.dmul(__slang_noop_cast<T>(sign(dpx.p)), dOut)))
+UNARY_DERIVATIVE_IMPL(abs, select(dpx.p > T(0.0), dpx.d, ReturnType.dmul(T(-1.0), dpx.d)), (ReturnType.dmul(__slang_noop_cast<ReturnType>(sign(dpx.p)), dOut)))
// Saturate
-UNARY_DERIVATIVE_IMPL(saturate, (dpx.p < T(0.0) || dpx.p > T(1.0) ? T.dzero() : dpx.d), (dpx.p < T(0.0) || dpx.p > T(1.0) ? T.dzero() : dOut))
+UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut))
// frac
UNARY_DERIVATIVE_IMPL(frac, dpx.d, dOut)
// raidans, degrees
@@ -849,9 +833,9 @@ SIMPLE_UNARY_DERIVATIVE_IMPL(log, T(1.0) / dpx.p)
SIMPLE_UNARY_DERIVATIVE_IMPL(log10, T(1.0) / (dpx.p * T(52.3025850929940456840179914546844)))
SIMPLE_UNARY_DERIVATIVE_IMPL(log2, T(1.0) / (dpx.p * T(50.69314718055994530941723212145818)))
// Square root
-SIMPLE_UNARY_DERIVATIVE_IMPL(sqrt, (dpx.p < T(1e-7) ? T(0.0) : T(0.5) / sqrt(dpx.p)))
+SIMPLE_UNARY_DERIVATIVE_IMPL(sqrt, T(0.5) / sqrt(max(ReturnType(T(1e-7)), dpx.p)))
// Reciprocal
-SIMPLE_UNARY_DERIVATIVE_IMPL(rcp, (dpx.p < T(1e-7) ? T(0.0) : T(-1.0) / (dpx.p * dpx.p)))
+SIMPLE_UNARY_DERIVATIVE_IMPL(rcp, T(-1.0) / max(ReturnType(T(1e-7)), dpx.p * dpx.p))
// rsqrt
SIMPLE_UNARY_DERIVATIVE_IMPL(rsqrt, T(-0.5) / (dpx.p * sqrt(dpx.p)))
// Arc-sin