From 257733f328f38a763c8b0c8830ff4c0d34ec9491 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 7 Mar 2023 11:22:32 -0800 Subject: Reuse higher-order `ResolveInvoke` logic to resolve func refs in `[*DerivativeOf]` attribs. (#2688) * Reuse higher-order `ResolveInvoke` logic to resolve func refs in [*DerivativeOf] attribs. * Add diff implementation matrix versions of binary and ternary intrinsics. * Add diff impl for legacy intrinsics. * Fix diagnostics of using non-differentiable function in a diff operator. * Add diff implementation for `determinant`. --------- Co-authored-by: Yong He --- source/slang/diff.meta.slang | 195 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 185 insertions(+), 10 deletions(-) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index c303b39d9..54f927816 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -380,6 +380,24 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair>(result, d_result); \ } \ + __generic \ + [ForwardDerivativeOf(NAME)] \ + DifferentialPair> __d_##NAME##_matrix( \ + DifferentialPair> dpx, DifferentialPair> dpy) \ + { \ + matrix result; \ + matrix.Differential d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair dp_elem = __d_##NAME( \ + DifferentialPair(dpx.p[i][j], __slang_noop_cast(dpx.d[i][j])), \ + DifferentialPair(dpy.p[i][j], __slang_noop_cast(dpy.d[i][j]))); \ + result[i][j] = dp_elem.p; \ + d_result[i][j] = __slang_noop_cast(dp_elem.d); \ + } \ + return DifferentialPair>(result, d_result); \ + } \ __generic \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ @@ -398,6 +416,26 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair \ + [BackwardDerivativeOf(NAME)] \ + void __d_##NAME##_matrix( \ + inout DifferentialPair> dpx, \ + inout DifferentialPair> dpy, \ + matrix.Differential dOut) \ + { \ + matrix.Differential left_d_result, right_d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair left_dp = diffPair(dpx.p[i][j], T.dzero()); \ + DifferentialPair right_dp = diffPair(dpy.p[i][j], T.dzero()); \ + __d_##NAME(left_dp, right_dp, __slang_noop_cast(dOut[i][j])); \ + left_d_result[i][j] = __slang_noop_cast(left_dp.d); \ + right_d_result[i][j] = __slang_noop_cast(right_dp.d); \ + } \ + dpx = diffPair(dpx.p, left_d_result); \ + dpy = diffPair(dpy.p, right_d_result); \ } #define VECTOR_MATRIX_TERNARY_DIFF_IMPL(NAME) \ @@ -407,7 +445,7 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair> dpx, \ DifferentialPair> dpy, \ DifferentialPair> dpz) \ -{ \ + { \ vector result; \ vector.Differential d_result; \ [ForceUnroll] for (int i = 0; i < N; ++i) \ @@ -421,6 +459,27 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair>(result, d_result); \ } \ + __generic \ + [ForwardDerivativeOf(NAME)] \ + DifferentialPair> __d_##NAME##_matrix( \ + DifferentialPair> dpx, \ + DifferentialPair> dpy, \ + DifferentialPair> dpz) \ + { \ + matrix result; \ + matrix.Differential d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair dp_elem = __d_##NAME( \ + DifferentialPair(dpx.p[i][j], __slang_noop_cast(dpx.d[i][j])), \ + DifferentialPair(dpy.p[i][j], __slang_noop_cast(dpy.d[i][j])), \ + DifferentialPair(dpz.p[i][j], __slang_noop_cast(dpz.d[i][j]))); \ + result[i][j] = dp_elem.p; \ + d_result[i][j] = __slang_noop_cast(dp_elem.d); \ + } \ + return DifferentialPair>(result, d_result); \ + } \ __generic \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ @@ -444,6 +503,31 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair \ + [BackwardDerivativeOf(NAME)] \ + void __d_##NAME##_matrix( \ + inout DifferentialPair> dpx, \ + inout DifferentialPair> dpy, \ + inout DifferentialPair> dpz, \ + matrix.Differential dOut) \ + { \ + matrix.Differential left_d_result, middle_d_result, right_d_result; \ + [ForceUnroll] for (int i = 0; i < M; ++i) \ + [ForceUnroll] for (int j = 0; j < N; ++j) \ + { \ + DifferentialPair left_dp = diffPair(dpx.p[i][j], T.dzero()); \ + DifferentialPair middle_dp = diffPair(dpy.p[i][j], T.dzero()); \ + DifferentialPair right_dp = diffPair(dpz.p[i][j], T.dzero()); \ + __d_##NAME(left_dp, middle_dp, right_dp, \ + __slang_noop_cast(dOut[i][j])); \ + left_d_result[i][j] = __slang_noop_cast(left_dp.d); \ + middle_d_result[i][j] = __slang_noop_cast(middle_dp.d); \ + right_d_result[i][j] = __slang_noop_cast(right_dp.d); \ + } \ + dpx = diffPair(dpx.p, left_d_result); \ + dpy = diffPair(dpy.p, middle_d_result); \ + dpz = diffPair(dpz.p, right_d_result); \ } #define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \ @@ -999,24 +1083,19 @@ void __d_sincos(DifferentialPair> x, out DifferentialPair [BackwardDerivativeOf(sincos)] [ForceInline] void __d_sincos(inout DifferentialPair x, T.Differential dS, T.Differential dC) { - __bwd_diff(__sincos_impl)(x, s, c); + __bwd_diff(__sincos_impl)(x, dS, dC); } __generic [BackwardDerivativeOf(sincos)] [ForceInline] void __d_sincos(inout DifferentialPair> x, vector.Differential dS, vector.Differential dC) { - __bwd_diff(__sincos_impl)(x, s, c); + __bwd_diff(__sincos_impl)(x, dS, dC); } __generic @@ -1024,7 +1103,103 @@ __generic [ForceInline] void __d_sincos(inout DifferentialPair> x, matrix.Differential dS, matrix.Differential dC) { - __bwd_diff(__sincos_impl)(x, s, c); + __bwd_diff(__sincos_impl)(x, dS, dC); +} + +// dst (obsolete) +__generic +[BackwardDifferentiable] +vector __dst_impl(vector src0, vector src1) +{ + vector dest; + dest.x = T(1.0); + dest.y = src0.y * src1.y; + dest.z = src0.z; + dest.w = src1.w; ; + return dest; +} +__generic +[ForwardDerivativeOf(dst)] +[ForceInline] +DifferentialPair> __d_dst(DifferentialPair> src0, DifferentialPair> src1) +{ + return __fwd_diff(__dst_impl)(src0, src1); +} +__generic +[BackwardDerivativeOf(dst)] +[ForceInline] +void __d_dst(inout DifferentialPair> src0, inout DifferentialPair> src1, vector.Differential dOut) +{ + __bwd_diff(__dst_impl)(src0, src1, dOut); +} + +// Legacy lighting function (obsolete) +__target_intrinsic(hlsl) +[__readNone] +[BackwardDifferentiable] +float4 __lit_impl(float n_dot_l, float n_dot_h, float m) +{ + let ambient = 1.0f; + let diffuse = max(n_dot_l, 0.0f); + let specular = ((n_dot_l < 0.0f || n_dot_h < 0.0) ? 0.0 : pow(n_dot_h, m)); + return float4(ambient, diffuse, specular, 1.0f); +} +[ForwardDerivativeOf(lit)] +[ForceInline] +DifferentialPair __d_lit(DifferentialPair n_dot_l, DifferentialPair n_dot_h, DifferentialPair m) +{ + return __fwd_diff(__lit_impl)(n_dot_l, n_dot_h, m); +} +[BackwardDerivativeOf(lit)] +[ForceInline] +void __d_lit(inout DifferentialPair n_dot_l, inout DifferentialPair n_dot_h, inout DifferentialPair m, float4 dOut) +{ + __bwd_diff(__lit_impl)(n_dot_l, n_dot_h, m, dOut); } -#endif \ No newline at end of file +// Matrix determinant +__generic +[BackwardDifferentiable] +[__readNone] +T __determinant_impl(matrix m) +{ + if (N == 1) + return m[0][0]; + else if (N == 2) + return m[0][0] * m[1][1] - m[0][1] * m[1][0]; + else if (N == 3) + { + return m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2]) + - m[1][0] * (m[0][1] * m[2][2] - m[2][1] * m[0][2]) + + m[2][0] * (m[0][1] * m[1][2] - m[1][1] * m[0][2]); + } + else if (N == 4) + { + T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; + T s01 = m[2][1] * m[3][3] - m[3][1] * m[2][3]; + T s02 = m[2][1] * m[3][2] - m[3][1] * m[2][2]; + T s03 = m[2][0] * m[3][3] - m[3][0] * m[2][3]; + T s04 = m[2][0] * m[3][2] - m[3][0] * m[2][2]; + T s05 = m[2][0] * m[3][1] - m[3][0] * m[2][1]; + + return m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02) + - m[0][1] * (m[1][0] * s00 - m[1][2] * s03 + m[1][3] * s04) + + m[0][2] * (m[1][0] * s01 - m[1][1] * s03 + m[1][3] * s05) + - m[0][3] * (m[1][0] * s02 - m[1][1] * s04 + m[1][2] * s05); + } + return T(0.0); +} +__generic +[ForwardDerivativeOf(determinant)] +[ForceInline] +DifferentialPair __determinant_impl(DifferentialPair> m) +{ + return __fwd_diff(__determinant_impl)(m); +} +__generic +[BackwardDerivativeOf(determinant)] +[ForceInline] +void __d_determinant(inout DifferentialPair> m, T.Differential dOut) +{ + __bwd_diff(__determinant_impl)(m, dOut); +} -- cgit v1.2.3