diff options
| author | Yong He <yonghe@outlook.com> | 2024-08-28 09:23:08 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-28 09:23:08 -0700 |
| commit | 638e5fb000d4e242a91e8b653da4a72daec0efda (patch) | |
| tree | cfcd15c1fc6bdee624eb33abac3268241b086dec /source/slang/diff.meta.slang | |
| parent | 16595a8379e9dbfa1845fd72f3531ff3372da3ef (diff) | |
Make tuple types work in autodiff. (#4923)
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index a4c468ef7..80aca230a 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1210,6 +1210,31 @@ extension Array<T, N> : IDifferentiable } } +__generic<each T : IDifferentiable> +extension Tuple<T> : IDifferentiable +{ + typealias Differential = Tuple<expand(each T).Differential>; + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return makeTuple(expand (each T).dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return makeTuple(expand(each T).dadd(each a, each b)); + } + + __generic<U : __BuiltinRealType> + [__unsafeForceInlineEarly] + static Differential dmul(U a, Differential b) + { + return makeTuple(expand(each T).dmul(a, each b)); + } +} + // Matrix transpose __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] |
