summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-30 19:24:09 -0800
committerGitHub <noreply@github.com>2023-01-30 19:24:09 -0800
commit499b0253c224e68ceed6e5b6b1ee9cd7d65aad0f (patch)
tree4c570a36d305c8909d633183694e0d1225f044c2 /source/slang/diff.meta.slang
parent134dd7eb26fc7988ae13559d276cbf337b4b9d27 (diff)
Make ArrayExpressionType a DeclRefType and define its autodiff extension in stdlib. (#2615)
* Allow array parameters in forward diff. * Use type canonicalization instead of coersion. * Reimplement array type. * Fix. * Update test case. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang35
1 files changed, 35 insertions, 0 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index c732d1a5e..adbf8ae48 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -127,6 +127,41 @@ void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T
p = DifferentialPair<T>(newPrimal, newDiff);
}
+__generic<T, let N:int>
+__intrinsic_op($(kIROp_MakeArrayFromElement))
+Array<T,N> makeArrayFromElement(T element);
+
+
+__generic<T:IDifferentiable, let N:int>
+extension Array<T, N> : IDifferentiable
+{
+ typedef Array<T.Differential, N> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return makeArrayFromElement<T.Differential, N>(T.dzero());
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ Array<T.Differential, N> result;
+ for (int i = 0; i < N; i++)
+ result[i] = T.dadd(a[i], b[i]);
+ return result;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ Array<T.Differential, N> result;
+ for (int i = 0; i < N; i++)
+ result[i] = T.dmul(a[i], b[i]);
+ return result;
+ }
+}
+
// vector-matrix
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]