summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-20 14:22:00 -0400
committerGitHub <noreply@github.com>2022-10-20 11:22:00 -0700
commit1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch)
treee85158637680f783caaf7f4433a6844398cd8f7b /source/slang/diff.meta.slang
parent576c8407e60143682cd40c68101c6eae8563ca3d (diff)
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang133
1 files changed, 106 insertions, 27 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index e604140ae..26fec224c 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -8,18 +8,118 @@ syntax __differentiate_jvp : JVPDerivativeModifier;
__attributeTarget(FuncDecl)
attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute;
-//@ public:
-
- /// Interface to denote types as differentiable.
- /// Allows for user-specified differential types as
- /// well as automatic generation, for when the associated type
- /// hasn't been declared explicitly.
+/// Interface to denote types as differentiable.
+/// Allows for user-specified differential types as
+/// well as automatic generation, for when the associated type
+/// hasn't been declared explicitly.
+/// Note that the requirements must currently be defined in this exact order
+/// since the auto-diff pass relies on the order to grab the struct keys.
+///
__magic_type(DifferentiableType)
interface IDifferentiable
{
associatedtype Differential;
+
+ static Differential zero();
+
+ static Differential dadd(Differential, Differential);
+
+ static Differential dmul(This, Differential);
};
+// Add extensions for the standard types
+extension float : IDifferentiable
+{
+ typedef float Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential zero()
+ {
+ return 0.f;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return a * b;
+ }
+}
+
+extension vector<float, 3> : IDifferentiable
+{
+ typedef vector<float, 3> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential zero()
+ {
+ return vector<float, 3>(0.f);
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return a * b;
+ }
+}
+
+extension vector<float, 2> : IDifferentiable
+{
+ typedef vector<float, 2> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential zero()
+ {
+ return vector<float, 2>(0.f);
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return a * b;
+ }
+}
+
+extension vector<float, 4> : IDifferentiable
+{
+ typedef vector<float, 4> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential zero()
+ {
+ return vector<float, 4>(0.f);
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return a * b;
+ }
+}
+
/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
__generic<T : IDifferentiable>
@@ -47,24 +147,3 @@ struct __DifferentialPair
return p();
}
};
-
-// Add extensions for the standard types
-extension float : IDifferentiable
-{
- typedef float Differential;
-}
-
-extension vector<float, 3> : IDifferentiable
-{
- typedef vector<float, 3> Differential;
-}
-
-extension vector<float, 2> : IDifferentiable
-{
- typedef vector<float, 2> Differential;
-}
-
-extension vector<float, 4> : IDifferentiable
-{
- typedef vector<float, 4> Differential;
-}