diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-07 11:22:32 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-07 11:22:32 -0800 |
| commit | 257733f328f38a763c8b0c8830ff4c0d34ec9491 (patch) | |
| tree | 87e444746f353d69a365380904f3f8caf15fbfec /source/slang/diff.meta.slang | |
| parent | 6f31eae79d5b4297d0099c5779a9806a786cf9f8 (diff) | |
Reuse higher-order `ResolveInvoke` logic to resolve func refs in `[*DerivativeOf]` attribs. (#2688)
* Reuse higher-order `ResolveInvoke` logic to resolve func refs in [*DerivativeOf] attribs.
* Add diff implementation matrix versions of binary and ternary intrinsics.
* Add diff impl for legacy intrinsics.
* Fix diagnostics of using non-differentiable function in a diff operator.
* Add diff implementation for `determinant`.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 195 |
1 files changed, 185 insertions, 10 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index c303b39d9..54f927816 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -380,6 +380,24 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve } \ return DifferentialPair<vector<T, N>>(result, d_result); \ } \ + __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [ForwardDerivativeOf(NAME)] \ + DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \ + DifferentialPair<matrix<T, M, N>> dpx, DifferentialPair<matrix<T, M, N>> dpy) \ + { \ + matrix<T, M, N> result; \ + matrix<T, M, N>.Differential d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair<T> dp_elem = __d_##NAME( \ + DifferentialPair<T>(dpx.p[i][j], __slang_noop_cast<T.Differential>(dpx.d[i][j])), \ + DifferentialPair<T>(dpy.p[i][j], __slang_noop_cast<T.Differential>(dpy.d[i][j]))); \ + result[i][j] = dp_elem.p; \ + d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \ + } \ + return DifferentialPair<matrix<T, M, N>>(result, d_result); \ + } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ @@ -398,6 +416,26 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve } \ dpx = diffPair(dpx.p, left_d_result); \ dpy = diffPair(dpy.p, right_d_result); \ + } \ + __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [BackwardDerivativeOf(NAME)] \ + void __d_##NAME##_matrix( \ + inout DifferentialPair<matrix<T, M, N>> dpx, \ + inout DifferentialPair<matrix<T, M, N>> dpy, \ + matrix<T, M, N>.Differential dOut) \ + { \ + matrix<T, M, N>.Differential left_d_result, right_d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair<T> left_dp = diffPair(dpx.p[i][j], T.dzero()); \ + DifferentialPair<T> right_dp = diffPair(dpy.p[i][j], T.dzero()); \ + __d_##NAME(left_dp, right_dp, __slang_noop_cast<T.Differential>(dOut[i][j])); \ + left_d_result[i][j] = __slang_noop_cast<T>(left_dp.d); \ + right_d_result[i][j] = __slang_noop_cast<T>(right_dp.d); \ + } \ + dpx = diffPair(dpx.p, left_d_result); \ + dpy = diffPair(dpy.p, right_d_result); \ } #define VECTOR_MATRIX_TERNARY_DIFF_IMPL(NAME) \ @@ -407,7 +445,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve DifferentialPair<vector<T, N>> dpx, \ DifferentialPair<vector<T, N>> dpy, \ DifferentialPair<vector<T, N>> dpz) \ -{ \ + { \ vector<T, N> result; \ vector<T, N>.Differential d_result; \ [ForceUnroll] for (int i = 0; i < N; ++i) \ @@ -421,6 +459,27 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve } \ return DifferentialPair<vector<T, N>>(result, d_result); \ } \ + __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [ForwardDerivativeOf(NAME)] \ + DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \ + DifferentialPair<matrix<T, M, N>> dpx, \ + DifferentialPair<matrix<T, M, N>> dpy, \ + DifferentialPair<matrix<T, M, N>> dpz) \ + { \ + matrix<T, M, N> result; \ + matrix<T, M, N>.Differential d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair<T> dp_elem = __d_##NAME( \ + DifferentialPair<T>(dpx.p[i][j], __slang_noop_cast<T.Differential>(dpx.d[i][j])), \ + DifferentialPair<T>(dpy.p[i][j], __slang_noop_cast<T.Differential>(dpy.d[i][j])), \ + DifferentialPair<T>(dpz.p[i][j], __slang_noop_cast<T.Differential>(dpz.d[i][j]))); \ + result[i][j] = dp_elem.p; \ + d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \ + } \ + return DifferentialPair<matrix<T, M, N>>(result, d_result); \ + } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ @@ -444,6 +503,31 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve dpx = diffPair(dpx.p, left_d_result); \ dpy = diffPair(dpy.p, middle_d_result); \ dpz = diffPair(dpz.p, right_d_result); \ + } \ + __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ + [BackwardDerivativeOf(NAME)] \ + void __d_##NAME##_matrix( \ + inout DifferentialPair<matrix<T, M, N>> dpx, \ + inout DifferentialPair<matrix<T, M, N>> dpy, \ + inout DifferentialPair<matrix<T, M, N>> dpz, \ + matrix<T, M, N>.Differential dOut) \ + { \ + matrix<T, M, N>.Differential left_d_result, middle_d_result, right_d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair<T> left_dp = diffPair(dpx.p[i][j], T.dzero()); \ + DifferentialPair<T> middle_dp = diffPair(dpy.p[i][j], T.dzero()); \ + DifferentialPair<T> right_dp = diffPair(dpz.p[i][j], T.dzero()); \ + __d_##NAME(left_dp, middle_dp, right_dp, \ + __slang_noop_cast<T.Differential>(dOut[i][j])); \ + left_d_result[i][j] = __slang_noop_cast<T>(left_dp.d); \ + middle_d_result[i][j] = __slang_noop_cast<T>(middle_dp.d); \ + right_d_result[i][j] = __slang_noop_cast<T>(right_dp.d); \ + } \ + dpx = diffPair(dpx.p, left_d_result); \ + dpy = diffPair(dpy.p, middle_d_result); \ + dpz = diffPair(dpz.p, right_d_result); \ } #define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \ @@ -999,24 +1083,19 @@ void __d_sincos(DifferentialPair<matrix<T, N, M>> x, out DifferentialPair<matrix __fwd_diff(__sincos_impl)(x, s, c); } -#if 0 -// TODO: this is not working right now since our type system can't resolve -// the overload to `sincos` in `[BackwardDerivativeOf]` attribute. We need to implement -// a proper overload resolver for custom backward derivatives. - __generic<T: __BuiltinFloatingPointType> [BackwardDerivativeOf(sincos)] [ForceInline] void __d_sincos(inout DifferentialPair<T> x, T.Differential dS, T.Differential dC) { - __bwd_diff(__sincos_impl)(x, s, c); + __bwd_diff(__sincos_impl)(x, dS, dC); } __generic<T: __BuiltinFloatingPointType, let N : int> [BackwardDerivativeOf(sincos)] [ForceInline] void __d_sincos(inout DifferentialPair<vector<T, N>> x, vector<T, N>.Differential dS, vector<T, N>.Differential dC) { - __bwd_diff(__sincos_impl)(x, s, c); + __bwd_diff(__sincos_impl)(x, dS, dC); } __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> @@ -1024,7 +1103,103 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] void __d_sincos(inout DifferentialPair<matrix<T, N, M>> x, matrix<T, N, M>.Differential dS, matrix<T, N, M>.Differential dC) { - __bwd_diff(__sincos_impl)(x, s, c); + __bwd_diff(__sincos_impl)(x, dS, dC); +} + +// dst (obsolete) +__generic<T : __BuiltinFloatingPointType> +[BackwardDifferentiable] +vector<T, 4> __dst_impl(vector<T, 4> src0, vector<T, 4> src1) +{ + vector<T, 4> dest; + dest.x = T(1.0); + dest.y = src0.y * src1.y; + dest.z = src0.z; + dest.w = src1.w; ; + return dest; +} +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(dst)] +[ForceInline] +DifferentialPair<vector<T, 4>> __d_dst(DifferentialPair<vector<T, 4>> src0, DifferentialPair<vector<T, 4>> src1) +{ + return __fwd_diff(__dst_impl)(src0, src1); +} +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(dst)] +[ForceInline] +void __d_dst(inout DifferentialPair<vector<T, 4>> src0, inout DifferentialPair<vector<T, 4>> src1, vector<T, 4>.Differential dOut) +{ + __bwd_diff(__dst_impl)(src0, src1, dOut); +} + +// Legacy lighting function (obsolete) +__target_intrinsic(hlsl) +[__readNone] +[BackwardDifferentiable] +float4 __lit_impl(float n_dot_l, float n_dot_h, float m) +{ + let ambient = 1.0f; + let diffuse = max(n_dot_l, 0.0f); + let specular = ((n_dot_l < 0.0f || n_dot_h < 0.0) ? 0.0 : pow(n_dot_h, m)); + return float4(ambient, diffuse, specular, 1.0f); +} +[ForwardDerivativeOf(lit)] +[ForceInline] +DifferentialPair<float4> __d_lit(DifferentialPair<float> n_dot_l, DifferentialPair<float> n_dot_h, DifferentialPair<float> m) +{ + return __fwd_diff(__lit_impl)(n_dot_l, n_dot_h, m); +} +[BackwardDerivativeOf(lit)] +[ForceInline] +void __d_lit(inout DifferentialPair<float> n_dot_l, inout DifferentialPair<float> n_dot_h, inout DifferentialPair<float> m, float4 dOut) +{ + __bwd_diff(__lit_impl)(n_dot_l, n_dot_h, m, dOut); } -#endif
\ No newline at end of file +// Matrix determinant +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDifferentiable] +[__readNone] +T __determinant_impl(matrix<T,N,N> m) +{ + if (N == 1) + return m[0][0]; + else if (N == 2) + return m[0][0] * m[1][1] - m[0][1] * m[1][0]; + else if (N == 3) + { + return m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2]) + - m[1][0] * (m[0][1] * m[2][2] - m[2][1] * m[0][2]) + + m[2][0] * (m[0][1] * m[1][2] - m[1][1] * m[0][2]); + } + else if (N == 4) + { + T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; + T s01 = m[2][1] * m[3][3] - m[3][1] * m[2][3]; + T s02 = m[2][1] * m[3][2] - m[3][1] * m[2][2]; + T s03 = m[2][0] * m[3][3] - m[3][0] * m[2][3]; + T s04 = m[2][0] * m[3][2] - m[3][0] * m[2][2]; + T s05 = m[2][0] * m[3][1] - m[3][0] * m[2][1]; + + return m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02) + - m[0][1] * (m[1][0] * s00 - m[1][2] * s03 + m[1][3] * s04) + + m[0][2] * (m[1][0] * s01 - m[1][1] * s03 + m[1][3] * s05) + - m[0][3] * (m[1][0] * s02 - m[1][1] * s04 + m[1][2] * s05); + } + return T(0.0); +} +__generic<T : __BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(determinant)] +[ForceInline] +DifferentialPair<T> __determinant_impl(DifferentialPair<matrix<T,N,N>> m) +{ + return __fwd_diff(__determinant_impl)(m); +} +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(determinant)] +[ForceInline] +void __d_determinant(inout DifferentialPair<matrix<T,N,N>> m, T.Differential dOut) +{ + __bwd_diff(__determinant_impl)(m, dOut); +} |
