From 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 20 Oct 2022 14:22:00 -0400 Subject: Modified the new type system to support generic differentiable types … (#2413) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He --- source/slang/diff.meta.slang | 133 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 106 insertions(+), 27 deletions(-) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index e604140ae..26fec224c 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -8,18 +8,118 @@ syntax __differentiate_jvp : JVPDerivativeModifier; __attributeTarget(FuncDecl) attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute; -//@ public: - - /// Interface to denote types as differentiable. - /// Allows for user-specified differential types as - /// well as automatic generation, for when the associated type - /// hasn't been declared explicitly. +/// Interface to denote types as differentiable. +/// Allows for user-specified differential types as +/// well as automatic generation, for when the associated type +/// hasn't been declared explicitly. +/// Note that the requirements must currently be defined in this exact order +/// since the auto-diff pass relies on the order to grab the struct keys. +/// __magic_type(DifferentiableType) interface IDifferentiable { associatedtype Differential; + + static Differential zero(); + + static Differential dadd(Differential, Differential); + + static Differential dmul(This, Differential); }; +// Add extensions for the standard types +extension float : IDifferentiable +{ + typedef float Differential; + + [__unsafeForceInlineEarly] + static Differential zero() + { + return 0.f; + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return a * b; + } +} + +extension vector : IDifferentiable +{ + typedef vector Differential; + + [__unsafeForceInlineEarly] + static Differential zero() + { + return vector(0.f); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return a * b; + } +} + +extension vector : IDifferentiable +{ + typedef vector Differential; + + [__unsafeForceInlineEarly] + static Differential zero() + { + return vector(0.f); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return a * b; + } +} + +extension vector : IDifferentiable +{ + typedef vector Differential; + + [__unsafeForceInlineEarly] + static Differential zero() + { + return vector(0.f); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return a * b; + } +} + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. __generic @@ -47,24 +147,3 @@ struct __DifferentialPair return p(); } }; - -// Add extensions for the standard types -extension float : IDifferentiable -{ - typedef float Differential; -} - -extension vector : IDifferentiable -{ - typedef vector Differential; -} - -extension vector : IDifferentiable -{ - typedef vector Differential; -} - -extension vector : IDifferentiable -{ - typedef vector Differential; -} -- cgit v1.2.3