summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-12-08 14:56:20 -0800
committerGitHub <noreply@github.com>2022-12-08 14:56:20 -0800
commit41eb19e65a0974e23048bd7b3b1eb1e2f569b1d0 (patch)
treec6cde57da4d3415d86d09213936a48d3d26e07e1 /source/slang/diff.meta.slang
parent468bb7ecf65c000c308adae511bf65a1ca4cc412 (diff)
Auto-diff for matrix operations. (#2559)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang32
1 files changed, 32 insertions, 0 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index a97ab9eaf..248112810 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -110,6 +110,38 @@ void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T
p = DifferentialPair<T>(newPrimal, newDiff);
}
+// vector-matrix
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(glsl, "($1 * $0)")
+DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, DifferentialPair<matrix<T, N, M>> right)
+{
+ let primal = mul(left.p, right.p);
+ let diff = mul(left.d, right.p) + mul(left.p, right.d);
+ return DifferentialPair<vector<T,M>>(primal, diff);
+}
+
+// matrix-vector
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
+[ForceInline]
+[ForwardDerivativeOf(mul)]
+DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, DifferentialPair<vector<T,M>> right)
+{
+ let primal = mul(left.p, right.p);
+ let diff = mul(left.d, right.p) + mul(left.p, right.d);
+ return DifferentialPair<vector<T,N>>(primal, diff);
+}
+
+
+// matrix-matrix
+__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
+[ForwardDerivativeOf(mul)]
+DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, DifferentialPair<matrix<T,N,C>> left)
+{
+ let primal = mul(right.p, left.p);
+ let diff = mul(right.d, left.p) + mul(right.p, left.d);
+ return DifferentialPair<matrix<T,R,C>>(primal, diff);
+}
#define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \
vector<TYPE, COUNT> result; \