summaryrefslogtreecommitdiff
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-07 11:22:32 -0800
committerGitHub <noreply@github.com>2023-03-07 11:22:32 -0800
commit257733f328f38a763c8b0c8830ff4c0d34ec9491 (patch)
tree87e444746f353d69a365380904f3f8caf15fbfec /source/slang/diff.meta.slang
parent6f31eae79d5b4297d0099c5779a9806a786cf9f8 (diff)
Reuse higher-order `ResolveInvoke` logic to resolve func refs in `[*DerivativeOf]` attribs. (#2688)
* Reuse higher-order `ResolveInvoke` logic to resolve func refs in [*DerivativeOf] attribs. * Add diff implementation matrix versions of binary and ternary intrinsics. * Add diff impl for legacy intrinsics. * Fix diagnostics of using non-differentiable function in a diff operator. * Add diff implementation for `determinant`. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang195
1 files changed, 185 insertions, 10 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index c303b39d9..54f927816 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -380,6 +380,24 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
} \
return DifferentialPair<vector<T, N>>(result, d_result); \
} \
+ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [ForwardDerivativeOf(NAME)] \
+ DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \
+ DifferentialPair<matrix<T, M, N>> dpx, DifferentialPair<matrix<T, M, N>> dpy) \
+ { \
+ matrix<T, M, N> result; \
+ matrix<T, M, N>.Differential d_result; \
+ [ForceUnroll] for (int i = 0; i < M; ++i) \
+ [ForceUnroll] for (int j = 0; j < N; ++j) \
+ { \
+ DifferentialPair<T> dp_elem = __d_##NAME( \
+ DifferentialPair<T>(dpx.p[i][j], __slang_noop_cast<T.Differential>(dpx.d[i][j])), \
+ DifferentialPair<T>(dpy.p[i][j], __slang_noop_cast<T.Differential>(dpy.d[i][j]))); \
+ result[i][j] = dp_elem.p; \
+ d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \
+ } \
+ return DifferentialPair<matrix<T, M, N>>(result, d_result); \
+ } \
__generic<T : __BuiltinFloatingPointType, let N : int> \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_vector( \
@@ -398,6 +416,26 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
} \
dpx = diffPair(dpx.p, left_d_result); \
dpy = diffPair(dpy.p, right_d_result); \
+ } \
+ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [BackwardDerivativeOf(NAME)] \
+ void __d_##NAME##_matrix( \
+ inout DifferentialPair<matrix<T, M, N>> dpx, \
+ inout DifferentialPair<matrix<T, M, N>> dpy, \
+ matrix<T, M, N>.Differential dOut) \
+ { \
+ matrix<T, M, N>.Differential left_d_result, right_d_result; \
+ [ForceUnroll] for (int i = 0; i < M; ++i) \
+ [ForceUnroll] for (int j = 0; j < N; ++j) \
+ { \
+ DifferentialPair<T> left_dp = diffPair(dpx.p[i][j], T.dzero()); \
+ DifferentialPair<T> right_dp = diffPair(dpy.p[i][j], T.dzero()); \
+ __d_##NAME(left_dp, right_dp, __slang_noop_cast<T.Differential>(dOut[i][j])); \
+ left_d_result[i][j] = __slang_noop_cast<T>(left_dp.d); \
+ right_d_result[i][j] = __slang_noop_cast<T>(right_dp.d); \
+ } \
+ dpx = diffPair(dpx.p, left_d_result); \
+ dpy = diffPair(dpy.p, right_d_result); \
}
#define VECTOR_MATRIX_TERNARY_DIFF_IMPL(NAME) \
@@ -407,7 +445,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
DifferentialPair<vector<T, N>> dpx, \
DifferentialPair<vector<T, N>> dpy, \
DifferentialPair<vector<T, N>> dpz) \
-{ \
+ { \
vector<T, N> result; \
vector<T, N>.Differential d_result; \
[ForceUnroll] for (int i = 0; i < N; ++i) \
@@ -421,6 +459,27 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
} \
return DifferentialPair<vector<T, N>>(result, d_result); \
} \
+ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [ForwardDerivativeOf(NAME)] \
+ DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \
+ DifferentialPair<matrix<T, M, N>> dpx, \
+ DifferentialPair<matrix<T, M, N>> dpy, \
+ DifferentialPair<matrix<T, M, N>> dpz) \
+ { \
+ matrix<T, M, N> result; \
+ matrix<T, M, N>.Differential d_result; \
+ [ForceUnroll] for (int i = 0; i < M; ++i) \
+ [ForceUnroll] for (int j = 0; j < N; ++j) \
+ { \
+ DifferentialPair<T> dp_elem = __d_##NAME( \
+ DifferentialPair<T>(dpx.p[i][j], __slang_noop_cast<T.Differential>(dpx.d[i][j])), \
+ DifferentialPair<T>(dpy.p[i][j], __slang_noop_cast<T.Differential>(dpy.d[i][j])), \
+ DifferentialPair<T>(dpz.p[i][j], __slang_noop_cast<T.Differential>(dpz.d[i][j]))); \
+ result[i][j] = dp_elem.p; \
+ d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \
+ } \
+ return DifferentialPair<matrix<T, M, N>>(result, d_result); \
+ } \
__generic<T : __BuiltinFloatingPointType, let N : int> \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_vector( \
@@ -444,6 +503,31 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
dpx = diffPair(dpx.p, left_d_result); \
dpy = diffPair(dpy.p, middle_d_result); \
dpz = diffPair(dpz.p, right_d_result); \
+ } \
+ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [BackwardDerivativeOf(NAME)] \
+ void __d_##NAME##_matrix( \
+ inout DifferentialPair<matrix<T, M, N>> dpx, \
+ inout DifferentialPair<matrix<T, M, N>> dpy, \
+ inout DifferentialPair<matrix<T, M, N>> dpz, \
+ matrix<T, M, N>.Differential dOut) \
+ { \
+ matrix<T, M, N>.Differential left_d_result, middle_d_result, right_d_result; \
+ [ForceUnroll] for (int i = 0; i < M; ++i) \
+ [ForceUnroll] for (int j = 0; j < N; ++j) \
+ { \
+ DifferentialPair<T> left_dp = diffPair(dpx.p[i][j], T.dzero()); \
+ DifferentialPair<T> middle_dp = diffPair(dpy.p[i][j], T.dzero()); \
+ DifferentialPair<T> right_dp = diffPair(dpz.p[i][j], T.dzero()); \
+ __d_##NAME(left_dp, middle_dp, right_dp, \
+ __slang_noop_cast<T.Differential>(dOut[i][j])); \
+ left_d_result[i][j] = __slang_noop_cast<T>(left_dp.d); \
+ middle_d_result[i][j] = __slang_noop_cast<T>(middle_dp.d); \
+ right_d_result[i][j] = __slang_noop_cast<T>(right_dp.d); \
+ } \
+ dpx = diffPair(dpx.p, left_d_result); \
+ dpy = diffPair(dpy.p, middle_d_result); \
+ dpz = diffPair(dpz.p, right_d_result); \
}
#define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \
@@ -999,24 +1083,19 @@ void __d_sincos(DifferentialPair<matrix<T, N, M>> x, out DifferentialPair<matrix
__fwd_diff(__sincos_impl)(x, s, c);
}
-#if 0
-// TODO: this is not working right now since our type system can't resolve
-// the overload to `sincos` in `[BackwardDerivativeOf]` attribute. We need to implement
-// a proper overload resolver for custom backward derivatives.
-
__generic<T: __BuiltinFloatingPointType>
[BackwardDerivativeOf(sincos)]
[ForceInline]
void __d_sincos(inout DifferentialPair<T> x, T.Differential dS, T.Differential dC)
{
- __bwd_diff(__sincos_impl)(x, s, c);
+ __bwd_diff(__sincos_impl)(x, dS, dC);
}
__generic<T: __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(sincos)]
[ForceInline]
void __d_sincos(inout DifferentialPair<vector<T, N>> x, vector<T, N>.Differential dS, vector<T, N>.Differential dC)
{
- __bwd_diff(__sincos_impl)(x, s, c);
+ __bwd_diff(__sincos_impl)(x, dS, dC);
}
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
@@ -1024,7 +1103,103 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
void __d_sincos(inout DifferentialPair<matrix<T, N, M>> x, matrix<T, N, M>.Differential dS, matrix<T, N, M>.Differential dC)
{
- __bwd_diff(__sincos_impl)(x, s, c);
+ __bwd_diff(__sincos_impl)(x, dS, dC);
+}
+
+// dst (obsolete)
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
+vector<T, 4> __dst_impl(vector<T, 4> src0, vector<T, 4> src1)
+{
+ vector<T, 4> dest;
+ dest.x = T(1.0);
+ dest.y = src0.y * src1.y;
+ dest.z = src0.z;
+ dest.w = src1.w; ;
+ return dest;
+}
+__generic<T : __BuiltinFloatingPointType>
+[ForwardDerivativeOf(dst)]
+[ForceInline]
+DifferentialPair<vector<T, 4>> __d_dst(DifferentialPair<vector<T, 4>> src0, DifferentialPair<vector<T, 4>> src1)
+{
+ return __fwd_diff(__dst_impl)(src0, src1);
+}
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(dst)]
+[ForceInline]
+void __d_dst(inout DifferentialPair<vector<T, 4>> src0, inout DifferentialPair<vector<T, 4>> src1, vector<T, 4>.Differential dOut)
+{
+ __bwd_diff(__dst_impl)(src0, src1, dOut);
+}
+
+// Legacy lighting function (obsolete)
+__target_intrinsic(hlsl)
+[__readNone]
+[BackwardDifferentiable]
+float4 __lit_impl(float n_dot_l, float n_dot_h, float m)
+{
+ let ambient = 1.0f;
+ let diffuse = max(n_dot_l, 0.0f);
+ let specular = ((n_dot_l < 0.0f || n_dot_h < 0.0) ? 0.0 : pow(n_dot_h, m));
+ return float4(ambient, diffuse, specular, 1.0f);
+}
+[ForwardDerivativeOf(lit)]
+[ForceInline]
+DifferentialPair<float4> __d_lit(DifferentialPair<float> n_dot_l, DifferentialPair<float> n_dot_h, DifferentialPair<float> m)
+{
+ return __fwd_diff(__lit_impl)(n_dot_l, n_dot_h, m);
+}
+[BackwardDerivativeOf(lit)]
+[ForceInline]
+void __d_lit(inout DifferentialPair<float> n_dot_l, inout DifferentialPair<float> n_dot_h, inout DifferentialPair<float> m, float4 dOut)
+{
+ __bwd_diff(__lit_impl)(n_dot_l, n_dot_h, m, dOut);
}
-#endif \ No newline at end of file
+// Matrix determinant
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
+[__readNone]
+T __determinant_impl(matrix<T,N,N> m)
+{
+ if (N == 1)
+ return m[0][0];
+ else if (N == 2)
+ return m[0][0] * m[1][1] - m[0][1] * m[1][0];
+ else if (N == 3)
+ {
+ return m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2])
+ - m[1][0] * (m[0][1] * m[2][2] - m[2][1] * m[0][2])
+ + m[2][0] * (m[0][1] * m[1][2] - m[1][1] * m[0][2]);
+ }
+ else if (N == 4)
+ {
+ T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3];
+ T s01 = m[2][1] * m[3][3] - m[3][1] * m[2][3];
+ T s02 = m[2][1] * m[3][2] - m[3][1] * m[2][2];
+ T s03 = m[2][0] * m[3][3] - m[3][0] * m[2][3];
+ T s04 = m[2][0] * m[3][2] - m[3][0] * m[2][2];
+ T s05 = m[2][0] * m[3][1] - m[3][0] * m[2][1];
+
+ return m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02)
+ - m[0][1] * (m[1][0] * s00 - m[1][2] * s03 + m[1][3] * s04)
+ + m[0][2] * (m[1][0] * s01 - m[1][1] * s03 + m[1][3] * s05)
+ - m[0][3] * (m[1][0] * s02 - m[1][1] * s04 + m[1][2] * s05);
+ }
+ return T(0.0);
+}
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[ForwardDerivativeOf(determinant)]
+[ForceInline]
+DifferentialPair<T> __determinant_impl(DifferentialPair<matrix<T,N,N>> m)
+{
+ return __fwd_diff(__determinant_impl)(m);
+}
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(determinant)]
+[ForceInline]
+void __d_determinant(inout DifferentialPair<matrix<T,N,N>> m, T.Differential dOut)
+{
+ __bwd_diff(__determinant_impl)(m, dOut);
+}