diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-07-26 17:15:21 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-26 17:15:21 -0400 |
| commit | ba89fc84267bfd09f1c8abf10a5b85d09bbc79de (patch) | |
| tree | 2c79fc5dafb89a030d22fa86cd6fa3d69a89a785 /source/slang/diff.meta.slang | |
| parent | b8ade05df10a2774d3da5ef1fb2c7479ff48989a (diff) | |
Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)` (#3029)
* Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)`
- Add AST synthesis support for generic containers
- Refactor relevant tests
* Merge dmul synthesis with dadd and dzero, and disambiguate using an enum
* Fix trailing spaces
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 97 |
1 files changed, 62 insertions, 35 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index f2fd8e3b0..3e381e55d 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -374,12 +374,13 @@ extension Array<T, N> : IDifferentiable return result; } + __generic<U : __BuiltinRealType> [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { Array<T.Differential, N> result; for (int i = 0; i < N; i++) - result[i] = T.dmul(a[i], b[i]); + result[i] = T.dmul<U>(a, b[i]); return result; } } @@ -543,8 +544,8 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair for (int i = 0; i < N; ++i) { result = result + dpx.p[i] * dpy.p[i]; - d_result = T.dadd(d_result, T.dmul(dpx.p[i], __slang_noop_cast<T.Differential>(dpy.d[i]))); - d_result = T.dadd(d_result, T.dmul(dpy.p[i], __slang_noop_cast<T.Differential>(dpx.d[i]))); + d_result = T.dadd(d_result, __slang_noop_cast<T.Differential>(dpx.p[i] * dpy.d[i])); + d_result = T.dadd(d_result, __slang_noop_cast<T.Differential>(dpy.p[i] * dpx.d[i])); } return DifferentialPair<T>(result, d_result); } @@ -797,7 +798,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve #define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \ __generic<T : __BuiltinFloatingPointType> \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<T> __d_##NAME(DifferentialPair<T> dpx) \ { \ @@ -805,7 +806,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<T>(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<vector<T, N>> __d_##NAME##_vector(DifferentialPair<vector<T, N>> dpx) \ { \ @@ -813,21 +814,21 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<ReturnType>(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpm) \ { \ - typealias ReturnType = vector<T,N>; \ - matrix<T,M,N>.Differential diff; \ + typealias ReturnType = vector<T, N>; \ + matrix<T, M, N>.Differential diff; \ [ForceUnroll] for (int i = 0; i < M; i++) \ { \ var dpx = diffPair(dpm.p[i], dpm.d[i]); \ - diff[i] = FWD_DIFF_FUNC; \ + diff[i] = __slang_noop_cast<vector<T, N>>(FWD_DIFF_FUNC); \ } \ return diffPair(NAME(dpm.p), diff); \ } \ __generic<T : __BuiltinFloatingPointType> \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME(inout DifferentialPair<T> dpx, T.Differential dOut) \ { \ @@ -835,31 +836,57 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ - inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \ + inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \ { \ typealias ReturnType = vector<T, N>; \ dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_matrix( \ - inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \ + inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \ { \ typealias ReturnType = vector<T, N>; \ matrix<T, M, N>.Differential diff; \ [ForceUnroll] for (int i = 0; i < M; i++) \ { \ var dpx = diffPair(m.p[i], m.d[i]); \ - var dOut = mdOut[i]; \ + var dOut = __slang_noop_cast<vector<T, N>>(mdOut[i]); \ diff[i] = BWD_DIFF_FUNC; \ } \ m = diffPair(m.p, diff); \ } -#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, ReturnType.dmul(DIFF_FUNC, dpx.d), ReturnType.dmul(DIFF_FUNC, dOut)) +#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, __mul_p_d(DIFF_FUNC, dpx.d), __mul_p_d(DIFF_FUNC, dOut)) + +/// Element-wise multiply for scalars and vectors for (T, T.Differential) +__generic<T : __BuiltinFloatingPointType> +[__unsafeForceInlineEarly] +[Differentiable] +T.Differential __mul_p_d(T a, T.Differential b) +{ + return __slang_noop_cast<T.Differential>(a * __slang_noop_cast<T>(b)); +} + +__generic<T : __BuiltinFloatingPointType> +[__unsafeForceInlineEarly] +[Differentiable] +T __mul_p_d(T a, T b) +{ + return (a * b); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[__unsafeForceInlineEarly] +[Differentiable] +vector<T, N> __mul_p_d(vector<T, N> a, vector<T, N> b) +{ + return a * b; +} + /// Detach and set derivatives to zero. __generic<T : IDifferentiable> @@ -871,14 +898,14 @@ T detach(T x); #define SLANG_SIGN(x) select(((x)>T(0.0)), ReturnType(T(1.0)), select(((x)==T(0.0)), ReturnType(T(0.0)), ReturnType(T(-1.0)))) // Absolute value -UNARY_DERIVATIVE_IMPL(abs, ReturnType.dmul(SLANG_SIGN(dpx.p), dpx.d), ReturnType.dmul(SLANG_SIGN(dpx.p), dOut)) +UNARY_DERIVATIVE_IMPL(abs, (__mul_p_d(SLANG_SIGN(dpx.p), (dpx.d))), (__mul_p_d(SLANG_SIGN(dpx.p), (dOut)))) // Saturate UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut)) // frac UNARY_DERIVATIVE_IMPL(frac, dpx.d, dOut) // raidans, degrees -SIMPLE_UNARY_DERIVATIVE_IMPL(radians, T(0.01745329251994329576923690768489)) -SIMPLE_UNARY_DERIVATIVE_IMPL(degrees, T(57.295779513082320876798154814105)) +SIMPLE_UNARY_DERIVATIVE_IMPL(radians, ReturnType(T(0.01745329251994329576923690768489))) +SIMPLE_UNARY_DERIVATIVE_IMPL(degrees, ReturnType(T(57.295779513082320876798154814105))) // Exponent SIMPLE_UNARY_DERIVATIVE_IMPL(exp, exp(dpx.p)) SIMPLE_UNARY_DERIVATIVE_IMPL(exp2, exp2(dpx.p)* T(50.69314718055994530941723212145818)) @@ -915,8 +942,8 @@ __generic<T : __BuiltinFloatingPointType> [ForwardDerivativeOf(atan2)] DifferentialPair<T> __d_atan2(DifferentialPair<T> dpy, DifferentialPair<T> dpx) { - T.Differential dx = T.dmul(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d); - T.Differential dy = T.dmul(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d); + T.Differential dx = __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d); + T.Differential dy = __mul_p_d(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d); return DifferentialPair<T>( atan2(dpy.p, dpx.p), T.dadd(dx, dy)); @@ -928,8 +955,8 @@ __generic<T : __BuiltinFloatingPointType> [BackwardDerivativeOf(atan2)] void __d_atan2(inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpx, T.Differential dOut) { - dpx = diffPair(dpx.p, T.dmul(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d)); - dpy = diffPair(dpy.p, T.dmul(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d)); + dpx = diffPair(dpx.p, __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d)); + dpy = diffPair(dpy.p, __mul_p_d(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d)); } VECTOR_MATRIX_BINARY_DIFF_IMPL(atan2) @@ -968,8 +995,8 @@ DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy) } T val = pow(dpx.p, dpy.p); - T.Differential d1 = T.dmul(val * log(dpx.p), dpy.d); - T.Differential d2 = T.dmul(val * dpy.p / dpx.p, dpx.d); + T.Differential d1 = __mul_p_d((val * log(dpx.p)), dpy.d); + T.Differential d2 = __mul_p_d((val * dpy.p / dpx.p), dpx.d); return DifferentialPair<T>( val, T.dadd(d1, d2) @@ -993,10 +1020,10 @@ void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Dif T val = pow(dpx.p, dpy.p); dpx = diffPair( dpx.p, - T.dmul(val * dpy.p / dpx.p, dOut)); + (__mul_p_d((val * dpy.p / dpx.p), dOut))); dpy = diffPair( dpy.p, - T.dmul(val * log(dpx.p), dOut)); + (__mul_p_d((val * log(dpx.p)), dOut))); } } @@ -1061,7 +1088,7 @@ DifferentialPair<T> __d_lerp(DifferentialPair<T> dpx, DifferentialPair<T> dpy, D { return DifferentialPair<T>( lerp(dpx.p, dpy.p, dps.p), - T.dadd(T.dadd(T.dmul((T(1.0) - dps.p), dpx.d), T.dmul(dps.p, dpy.d)), T.dmul(dpy.p - dpx.p, dps.d)) + T.dadd(T.dadd(__mul_p_d((T(1.0) - dps.p), dpx.d), __mul_p_d(dps.p, dpy.d)), __mul_p_d((dpy.p - dpx.p), dps.d)) ); } __generic<T : __BuiltinFloatingPointType> @@ -1070,9 +1097,9 @@ __generic<T : __BuiltinFloatingPointType> [BackwardDerivativeOf(lerp)] void __d_lerp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dps, T.Differential dOut) { - dpx = diffPair(dpx.p, T.dmul(T(1.0) - dps.p, dOut)); - dpy = diffPair(dpy.p, T.dmul(dps.p, dOut)); - dps = diffPair(dpy.p, T.dmul((dpy.p - dpx.p), dOut)); + dpx = diffPair(dpx.p, __mul_p_d((T(1.0) - dps.p), dOut)); + dpy = diffPair(dpy.p, __mul_p_d(dps.p, dOut)); + dps = diffPair(dpy.p, __mul_p_d((dpy.p - dpx.p), dOut)); } VECTOR_MATRIX_TERNARY_DIFF_IMPL(lerp) @@ -1175,7 +1202,7 @@ DifferentialPair<T> __d_mad(DifferentialPair<T> dpx, DifferentialPair<T> dpy, Di { return DifferentialPair<T>( mad(dpx.p, dpy.p, dpz.p), - T.dadd(T.dadd(T.dmul(dpy.p, dpx.d), T.dmul(dpx.p, dpy.d)), dpz.d)); + T.dadd(T.dadd(__mul_p_d(dpy.p, dpx.d), __mul_p_d(dpx.p, dpy.d)), dpz.d)); } __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] @@ -1183,8 +1210,8 @@ __generic<T : __BuiltinFloatingPointType> [PreferRecompute] void __d_mad(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpz, T.Differential dOut) { - dpx = diffPair(dpx.p, T.dmul(dpy.p, dOut)); - dpy = diffPair(dpy.p, T.dmul(dpx.p, dOut)); + dpx = diffPair(dpx.p, __mul_p_d(dpy.p, dOut)); + dpy = diffPair(dpy.p, __mul_p_d(dpx.p, dOut)); dpz = diffPair(dpz.p, dOut); } VECTOR_MATRIX_TERNARY_DIFF_IMPL(mad) |
