summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang35
1 files changed, 35 insertions, 0 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index c732d1a5e..adbf8ae48 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -127,6 +127,41 @@ void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T
p = DifferentialPair<T>(newPrimal, newDiff);
}
+__generic<T, let N:int>
+__intrinsic_op($(kIROp_MakeArrayFromElement))
+Array<T,N> makeArrayFromElement(T element);
+
+
+__generic<T:IDifferentiable, let N:int>
+extension Array<T, N> : IDifferentiable
+{
+ typedef Array<T.Differential, N> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return makeArrayFromElement<T.Differential, N>(T.dzero());
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ Array<T.Differential, N> result;
+ for (int i = 0; i < N; i++)
+ result[i] = T.dadd(a[i], b[i]);
+ return result;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ Array<T.Differential, N> result;
+ for (int i = 0; i < N; i++)
+ result[i] = T.dmul(a[i], b[i]);
+ return result;
+ }
+}
+
// vector-matrix
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]