From 499b0253c224e68ceed6e5b6b1ee9cd7d65aad0f Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 30 Jan 2023 19:24:09 -0800 Subject: Make ArrayExpressionType a DeclRefType and define its autodiff extension in stdlib. (#2615) * Allow array parameters in forward diff. * Use type canonicalization instead of coersion. * Reimplement array type. * Fix. * Update test case. --------- Co-authored-by: Yong He --- source/slang/diff.meta.slang | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index c732d1a5e..adbf8ae48 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -127,6 +127,41 @@ void updatePair(inout DifferentialPair p, T newPrimal, T p = DifferentialPair(newPrimal, newDiff); } +__generic +__intrinsic_op($(kIROp_MakeArrayFromElement)) +Array makeArrayFromElement(T element); + + +__generic +extension Array : IDifferentiable +{ + typedef Array Differential; + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return makeArrayFromElement(T.dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + Array result; + for (int i = 0; i < N; i++) + result[i] = T.dadd(a[i], b[i]); + return result; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + Array result; + for (int i = 0; i < N; i++) + result[i] = T.dmul(a[i], b[i]); + return result; + } +} + // vector-matrix __generic [ForceInline] -- cgit v1.2.3