diff options
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index c732d1a5e..adbf8ae48 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -127,6 +127,41 @@ void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T p = DifferentialPair<T>(newPrimal, newDiff); } +__generic<T, let N:int> +__intrinsic_op($(kIROp_MakeArrayFromElement)) +Array<T,N> makeArrayFromElement(T element); + + +__generic<T:IDifferentiable, let N:int> +extension Array<T, N> : IDifferentiable +{ + typedef Array<T.Differential, N> Differential; + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return makeArrayFromElement<T.Differential, N>(T.dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + Array<T.Differential, N> result; + for (int i = 0; i < N; i++) + result[i] = T.dadd(a[i], b[i]); + return result; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + Array<T.Differential, N> result; + for (int i = 0; i < N; i++) + result[i] = T.dmul(a[i], b[i]); + return result; + } +} + // vector-matrix __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] |
