summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-20 14:22:00 -0400
committerGitHub <noreply@github.com>2022-10-20 11:22:00 -0700
commit1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch)
treee85158637680f783caaf7f4433a6844398cd8f7b /source
parent576c8407e60143682cd40c68101c6eae8563ca3d (diff)
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang133
-rw-r--r--source/slang/slang-ast-builder.cpp17
-rw-r--r--source/slang/slang-ast-builder.h4
-rw-r--r--source/slang/slang-ast-decl.h29
-rw-r--r--source/slang/slang-ast-expr.h24
-rw-r--r--source/slang/slang-ast-modifier.h1
-rw-r--r--source/slang/slang-ast-val.cpp85
-rw-r--r--source/slang/slang-check-conformance.cpp139
-rw-r--r--source/slang/slang-check-constraint.cpp43
-rw-r--r--source/slang/slang-check-conversion.cpp4
-rw-r--r--source/slang/slang-check-decl.cpp205
-rw-r--r--source/slang/slang-check-expr.cpp278
-rw-r--r--source/slang/slang-check-impl.h149
-rw-r--r--source/slang/slang-check-overload.cpp142
-rw-r--r--source/slang/slang-check-type.cpp13
-rw-r--r--source/slang/slang-emit.cpp25
-rw-r--r--source/slang/slang-ir-dce.cpp7
-rw-r--r--source/slang/slang-ir-diff-call.cpp43
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp1478
-rw-r--r--source/slang/slang-ir-inst-defs.h8
-rw-r--r--source/slang/slang-ir-insts.h47
-rw-r--r--source/slang/slang-ir-link.cpp35
-rw-r--r--source/slang/slang-ir-ssa.cpp9
-rw-r--r--source/slang/slang-ir.cpp150
-rw-r--r--source/slang/slang-ir.h6
-rw-r--r--source/slang/slang-lower-to-ir.cpp111
26 files changed, 2725 insertions, 460 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index e604140ae..26fec224c 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -8,18 +8,118 @@ syntax __differentiate_jvp : JVPDerivativeModifier;
__attributeTarget(FuncDecl)
attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute;
-//@ public:
-
- /// Interface to denote types as differentiable.
- /// Allows for user-specified differential types as
- /// well as automatic generation, for when the associated type
- /// hasn't been declared explicitly.
+/// Interface to denote types as differentiable.
+/// Allows for user-specified differential types as
+/// well as automatic generation, for when the associated type
+/// hasn't been declared explicitly.
+/// Note that the requirements must currently be defined in this exact order
+/// since the auto-diff pass relies on the order to grab the struct keys.
+///
__magic_type(DifferentiableType)
interface IDifferentiable
{
associatedtype Differential;
+
+ static Differential zero();
+
+ static Differential dadd(Differential, Differential);
+
+ static Differential dmul(This, Differential);
};
+// Add extensions for the standard types
+extension float : IDifferentiable
+{
+ typedef float Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential zero()
+ {
+ return 0.f;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return a * b;
+ }
+}
+
+extension vector<float, 3> : IDifferentiable
+{
+ typedef vector<float, 3> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential zero()
+ {
+ return vector<float, 3>(0.f);
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return a * b;
+ }
+}
+
+extension vector<float, 2> : IDifferentiable
+{
+ typedef vector<float, 2> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential zero()
+ {
+ return vector<float, 2>(0.f);
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return a * b;
+ }
+}
+
+extension vector<float, 4> : IDifferentiable
+{
+ typedef vector<float, 4> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential zero()
+ {
+ return vector<float, 4>(0.f);
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return a * b;
+ }
+}
+
/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
__generic<T : IDifferentiable>
@@ -47,24 +147,3 @@ struct __DifferentialPair
return p();
}
};
-
-// Add extensions for the standard types
-extension float : IDifferentiable
-{
- typedef float Differential;
-}
-
-extension vector<float, 3> : IDifferentiable
-{
- typedef vector<float, 3> Differential;
-}
-
-extension vector<float, 2> : IDifferentiable
-{
- typedef vector<float, 2> Differential;
-}
-
-extension vector<float, 4> : IDifferentiable
-{
- typedef vector<float, 4> Differential;
-}
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index f8c208ac1..f6c550d69 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -179,6 +179,18 @@ Decl* SharedASTBuilder::findMagicDecl(const String& name)
return m_magicDecls[name].GetValue();
}
+Decl* SharedASTBuilder::tryFindMagicDecl(const String& name)
+{
+ if (m_magicDecls.ContainsKey(name))
+ {
+ return m_magicDecls[name].GetValue();
+ }
+ else
+ {
+ return nullptr;
+ }
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name):
@@ -308,6 +320,11 @@ DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterface()
return declRef;
}
+bool ASTBuilder::isDifferentiableInterfaceAvailable()
+{
+ return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr);
+}
+
DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg)
{
DeclRef<Decl> declRef;
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index 91fe63c88..e4ea872a0 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -45,6 +45,8 @@ public:
// Look up a magic declaration by its name
Decl* findMagicDecl(String const& name);
+ Decl* tryFindMagicDecl(String const& name);
+
/// A name pool that can be used for lookup for findClassInfo etc. It is the same pool as the Session.
NamePool* getNamePool() { return m_namePool; }
@@ -328,6 +330,8 @@ public:
DeclRef<InterfaceDecl> getDifferentiableInterface();
+ bool isDifferentiableInterfaceAvailable();
+
DeclRef<Decl> getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg);
Type* getAndType(Type* left, Type* right);
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 147bc7d22..07cfe6a0c 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -494,6 +494,35 @@ class AttributeDecl : public ContainerDecl
SyntaxClass<NodeBase> syntaxClass;
};
+// A declaration to hold differentiable type conformances generated during
+// the semantic checking phase.
+//
+class DifferentiableTypeDictionary : public ContainerDecl
+{
+ SLANG_AST_CLASS(DifferentiableTypeDictionary);
+};
+
+// A declaration to hold differentiable type conformances generated during
+// the semantic checking phase.
+//
+class DifferentiableTypeDictionaryItem : public Decl
+{
+ SLANG_AST_CLASS(DifferentiableTypeDictionaryItem);
+
+ DeclRefType* baseType;
+ SubtypeWitness* confWitness;
+};
+
+// A declaration that references another dictionary (generally from another module)
+// Used to tell the IR lowering pass to process the referenced dictionary.
+//
+class DifferentiableTypeDictionaryImportItem : public Decl
+{
+ SLANG_AST_CLASS(DifferentiableTypeDictionaryImportItem);
+
+ DeclRef<DifferentiableTypeDictionary> dictionaryRef;
+};
+
bool isInterfaceRequirement(Decl* decl);
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 13d687da0..e0a55cc29 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -38,6 +38,18 @@ class VarExpr : public DeclRefExpr
SLANG_AST_CLASS(VarExpr)
};
+class DifferentiableDeclRefExpr : public Expr
+{
+ SLANG_AST_CLASS(DifferentiableDeclRefExpr)
+
+ // Inner decl ref expr that references a differentiable expression.
+ Expr* inner = nullptr;
+
+ // Information on getters and setters if available.
+ Expr* setterExpr = nullptr;
+ Expr* getterExpr = nullptr;
+};
+
// An expression that references an overloaded set of declarations
// having the same name.
class OverloadedExpr : public Expr
@@ -428,13 +440,21 @@ class OpenRefExpr : public Expr
Expr* innerExpr = nullptr;
};
+ /// Base class for higher-order function application
+ /// Eg: foo(fn) where fn is a function expression.
+ ///
+class HigherOrderInvokeExpr : public Expr
+{
+ SLANG_ABSTRACT_AST_CLASS(HigherOrderInvokeExpr)
+ Expr* baseFunction;
+};
+
/// An expression of the form `__jvp(fn)` to access the
/// forward-mode derivative version of the function `fn`
///
-class JVPDifferentiateExpr: public Expr
+class JVPDifferentiateExpr: public HigherOrderInvokeExpr
{
SLANG_AST_CLASS(JVPDifferentiateExpr)
- Expr* baseFunction;
};
/// A type expression of the form `__TaggedUnion(A, ...)`.
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 8868b7a1d..8230f481e 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -549,6 +549,7 @@ class AttributeTargetModifier : public Modifier
SyntaxClass<NodeBase> syntaxClass;
};
+
// Base class for checked and unchecked `[name(arg0, ...)]` style attribute.
class AttributeBase : public Modifier
{
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 377dee350..a8ceaa716 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -283,7 +283,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub
{
if (constraintParam == declRef.getDecl())
{
- found = true;
+ found = true;
break;
}
index++;
@@ -443,6 +443,66 @@ HashCode TransitiveSubtypeWitness::_getHashCodeOverride()
return hash;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff)
+{
+ int diff = 0;
+
+ Type* substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
+ Type* substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
+ SubtypeWitness* substWitness = as<SubtypeWitness>(conjunctionWitness->substituteImpl(astBuilder, subst, &diff));
+
+ // If nothing changed, then we can bail out early.
+ if (!diff)
+ return this;
+
+ // Something changes, so let the caller know.
+ (*ioDiff)++;
+
+ // If the substituted witness is a conjunction, break it apart, but it's important to replace the
+ // sub and super types with the current ones since the conjunction witness will have an
+ //
+ if (auto substConjunctionWitness = as<ConjunctionSubtypeWitness>(substWitness))
+ {
+ if (indexInConjunction == 0)
+ {
+ auto witness = as<SubtypeWitness>(substConjunctionWitness->leftWitness);
+ SLANG_ASSERT(witness);
+
+ witness->sub = substSub;
+ witness->sup = substSup;
+
+ return witness;
+ }
+ else if (indexInConjunction == 1)
+ {
+ auto witness = as<SubtypeWitness>(substConjunctionWitness->rightWitness);
+ SLANG_ASSERT(witness);
+
+ witness->sub = substSub;
+ witness->sup = substSup;
+
+ return witness;
+ }
+ else
+ {
+ SLANG_UNIMPLEMENTED_X("conjunction index must be 0 or 1");
+ }
+ }
+ else
+ {
+ // In the simple case, we just construct a new conjunction subtype
+ // witness.
+ ExtractFromConjunctionSubtypeWitness* result = astBuilder->create<ExtractFromConjunctionSubtypeWitness>();
+ result->sub = substSub;
+ result->sup = substSup;
+ result->conjunctionWitness = substWitness;
+ result->indexInConjunction = indexInConjunction;
+ return result;
+ }
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
bool ExtractExistentialSubtypeWitness::_equalsValOverride(Val* val)
@@ -637,29 +697,6 @@ HashCode ExtractFromConjunctionSubtypeWitness::_getHashCodeOverride()
return combineHash(indexInConjunction, conjunctionWitness ? conjunctionWitness->getHashCode() : 0);
}
-Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
-{
- int diff = 0;
- Val* newConjunctionWitness = nullptr;
-
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
-
- if (this->conjunctionWitness)
- newConjunctionWitness = conjunctionWitness->substituteImpl(astBuilder, subst, &diff);
- *ioDiff += diff;
-
- if (diff)
- {
- auto result = astBuilder->create<ExtractFromConjunctionSubtypeWitness>();
- result->conjunctionWitness = newConjunctionWitness;
- result->sub = substSub;
- result->sup = substSup;
- return result;
- }
- return this;
-}
-
// ModifierVal
bool ModifierVal::_equalsValOverride(Val* val)
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index e0c1f3702..cf362dcdd 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -18,6 +18,62 @@ namespace Slang
return witness;
}
+
+ Val* simplifyWitness(ASTBuilder* builder, Val* witness)
+ {
+ if (auto extractFromConjunction = as<ExtractFromConjunctionSubtypeWitness>(witness))
+ {
+ auto simplWitness = simplifyWitness(builder, extractFromConjunction->conjunctionWitness);
+ if (auto conjunction = as<ConjunctionSubtypeWitness>(simplWitness))
+ {
+ auto index = extractFromConjunction->indexInConjunction;
+ SLANG_ASSERT(index == 0 || index == 1);
+ if (index == 0)
+ return conjunction->leftWitness;
+ else
+ return conjunction->rightWitness;
+ }
+
+ ExtractFromConjunctionSubtypeWitness* simplExtractFromConjunction = builder->create<ExtractFromConjunctionSubtypeWitness>();
+ simplExtractFromConjunction->sub = extractFromConjunction->sub;
+ simplExtractFromConjunction->sup = extractFromConjunction->sup;
+ simplExtractFromConjunction->indexInConjunction = extractFromConjunction->indexInConjunction;
+ simplExtractFromConjunction->conjunctionWitness = as<SubtypeWitness>(simplWitness);
+
+ return simplExtractFromConjunction;
+ }
+ else if (auto conjunctionWitness = as<ConjunctionSubtypeWitness>(witness))
+ {
+ auto simplConjunctionWitness = builder->create<ConjunctionSubtypeWitness>();
+ simplConjunctionWitness->leftWitness = as<SubtypeWitness>(simplifyWitness(builder, conjunctionWitness->leftWitness));
+ simplConjunctionWitness->rightWitness = as<SubtypeWitness>(simplifyWitness(builder, conjunctionWitness->rightWitness));
+ simplConjunctionWitness->sub = conjunctionWitness->sub;
+ simplConjunctionWitness->sup = conjunctionWitness->sup;
+
+ return simplConjunctionWitness;
+ }
+ else if (auto transitiveWitness = as<TransitiveSubtypeWitness>(witness))
+ {
+ TransitiveSubtypeWitness* simplTransitiveWitness = builder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(
+ transitiveWitness->sub,
+ transitiveWitness->sup,
+ transitiveWitness->midToSup);
+
+ simplTransitiveWitness->sub = transitiveWitness->sub;
+ simplTransitiveWitness->sup = transitiveWitness->sup;
+ simplTransitiveWitness->midToSup = as<SubtypeWitness>(simplifyWitness(builder, transitiveWitness->midToSup));
+ simplTransitiveWitness->subToMid = as<SubtypeWitness>(simplifyWitness(builder, transitiveWitness->subToMid));
+
+ return simplTransitiveWitness;
+ }
+ else
+ {
+ // TODO: Add other cases.
+ return witness;
+ }
+ }
+
+
Val* SemanticsVisitor::createTypeWitness(
Type* subType,
DeclRef<AggTypeDecl> superTypeDeclRef,
@@ -70,7 +126,7 @@ namespace Slang
// As long as there is more than one breadcrumb, we
// need to be creating transitive witnesses.
- while(bb->prev)
+ while (bb->prev)
{
// On the first iteration when processing the list
// above, the breadcrumb would be for `{ C : D }`,
@@ -83,19 +139,42 @@ namespace Slang
// where `[...]` represents the "hole" we leave
// open to fill in next.
//
- DeclaredSubtypeWitness* declaredWitness =
- m_astBuilder->getOrCreate<DeclaredSubtypeWitness>(
- bb->sub, bb->sup, bb->declRef.decl, bb->declRef.substitutions.substitutions);
+ if (bb->flavor == TypeWitnessBreadcrumb::Flavor::DeclFlavor)
+ {
+ DeclaredSubtypeWitness* declaredWitness =
+ m_astBuilder->getOrCreate<DeclaredSubtypeWitness>(
+ bb->sub, bb->sup, bb->declRef.decl, bb->declRef.substitutions.substitutions);
+
+ TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(subType, bb->sup, declaredWitness);
+ transitiveWitness->sub = subType;
+ transitiveWitness->sup = bb->sup;
+ transitiveWitness->midToSup = declaredWitness;
+
+ // Fill in the current hole, and then set the
+ // hole to point into the node we just created.
+ *link = transitiveWitness;
+ link = &transitiveWitness->subToMid;
+ }
+ else if(bb->flavor == TypeWitnessBreadcrumb::Flavor::AndTypeLeftFlavor)
+ {
+ ExtractFromConjunctionSubtypeWitness* extractWitness = m_astBuilder->create<ExtractFromConjunctionSubtypeWitness>();
+ extractWitness->sub = subType;
+ extractWitness->sup = bb->sup;
+ extractWitness->indexInConjunction = 0;
- TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(subType, bb->sup, declaredWitness);
- transitiveWitness->sub = subType;
- transitiveWitness->sup = bb->sup;
- transitiveWitness->midToSup = declaredWitness;
+ *link = extractWitness;
+ link = (SubtypeWitness**) &extractWitness->conjunctionWitness;
+ }
+ else if(bb->flavor == TypeWitnessBreadcrumb::Flavor::AndTypeRightFlavor)
+ {
+ ExtractFromConjunctionSubtypeWitness* extractWitness = m_astBuilder->create<ExtractFromConjunctionSubtypeWitness>();
+ extractWitness->sub = subType;
+ extractWitness->sup = bb->sup;
+ extractWitness->indexInConjunction = 1;
- // Fill in the current hole, and then set the
- // hole to point into the node we just created.
- *link = transitiveWitness;
- link = &transitiveWitness->subToMid;
+ *link = extractWitness;
+ link = (SubtypeWitness**) &extractWitness->conjunctionWitness;
+ }
// Move on with the list.
bb = bb->prev;
@@ -108,9 +187,14 @@ namespace Slang
DeclaredSubtypeWitness* declaredWitness = createSimpleSubtypeWitness(bb);
*link = declaredWitness;
+ // Simplify witnesses of the form ExtractFromConjunction(ConjunctionWitness(...))
+ // TODO: At some point, we need a more robust way of checking that two witnesses are in-fact 'equal'.
+ // In the meantime, this step should suffice.
+
+
// We now know that our original `witness` variable has been
// filled in, and there are no other holes.
- return witness;
+ return simplifyWitness(m_astBuilder, witness);
}
bool SemanticsVisitor::isInterfaceSafeForTaggedUnion(
@@ -379,6 +463,35 @@ namespace Slang
}
return true;
}
+ else if (auto andType = as<AndType>(subType))
+ {
+ // (L & R) is a subtype of T if either L or R is a subtype of T.
+ // Note that in this method T is explicitly a DeclRef and so cannot be a conjunction itself.
+ //
+ TypeWitnessBreadcrumb leftBreadcrumb;
+ leftBreadcrumb.prev = inBreadcrumbs;
+ leftBreadcrumb.sub = andType;
+ leftBreadcrumb.sup = DeclRefType::create(m_astBuilder, superTypeDeclRef);
+ leftBreadcrumb.declRef = makeDeclRef((Decl*)nullptr);
+ leftBreadcrumb.flavor = TypeWitnessBreadcrumb::Flavor::AndTypeLeftFlavor;
+
+ if(_isDeclaredSubtype(originalSubType, andType->left, superTypeDeclRef, outWitness, &leftBreadcrumb))
+ {
+ return true;
+ }
+
+ TypeWitnessBreadcrumb rightBreadcrumb;
+ rightBreadcrumb.prev = inBreadcrumbs;
+ rightBreadcrumb.sub = andType;
+ rightBreadcrumb.sup = DeclRefType::create(m_astBuilder, superTypeDeclRef);
+ rightBreadcrumb.declRef = makeDeclRef((Decl*)nullptr);
+ rightBreadcrumb.flavor = TypeWitnessBreadcrumb::Flavor::AndTypeRightFlavor;
+
+ if(_isDeclaredSubtype(originalSubType, andType->right, superTypeDeclRef, outWitness, &rightBreadcrumb))
+ {
+ return true;
+ }
+ }
// default is failure
return false;
}
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 24cedd7d5..f96b5a484 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -564,6 +564,19 @@ namespace Slang
}
}
+ // Two subtype witnesses can be unified if they exist (non-null) and
+ // prove that some pair of types are subtypes of types that can be unified.
+ //
+ if (auto fstWit = as<SubtypeWitness>(fst))
+ {
+ if (auto sndWit = as<SubtypeWitness>(snd))
+ {
+ return TryUnifyTypes(constraints,
+ fstWit->sup,
+ sndWit->sup);
+ }
+ }
+
SLANG_UNIMPLEMENTED_X("value unification case");
// default: fail
@@ -725,17 +738,29 @@ namespace Slang
bool SemanticsVisitor::TryUnifyConjunctionType(
ConstraintSystem& constraints,
- AndType* fst,
+ Type* fst,
Type* snd)
{
- // Unifying a type `T` with `A & B` amounts to unifying
- // `T` with `A` and also `T` with `B`.
+ // Unifying a type `A & B` with `T` amounts to unifying
+ // `A` with `T` and also `B` with `T` while
+ // unifying a type `T` with `A & B` amounts to either
+ // unifying `T` with `A` or `T` with `B`
//
// If either unification is impossible, then the full
// case is also impossible.
//
- return TryUnifyTypes(constraints, fst->left, snd)
- && TryUnifyTypes(constraints, fst->right, snd);
+ if (auto fstAndType = as<AndType>(fst))
+ {
+ return TryUnifyTypes(constraints, fstAndType->left, snd)
+ && TryUnifyTypes(constraints, fstAndType->right, snd);
+ }
+ else if (auto sndAndType = as<AndType>(snd))
+ {
+ return TryUnifyTypes(constraints, fst, sndAndType->left)
+ || TryUnifyTypes(constraints, fst, sndAndType->right);
+ }
+ else
+ return false;
}
bool SemanticsVisitor::TryUnifyTypes(
@@ -762,13 +787,9 @@ namespace Slang
// a conjunction directly, and will instead find all of the
// "leaf" types we need to constrain it to.
//
- if( auto fstAndType = as<AndType>(fst) )
- {
- return TryUnifyConjunctionType(constraints, fstAndType, snd);
- }
- if( auto sndAndType = as<AndType>(snd) )
+ if (as<AndType>(fst) || as<AndType>(snd))
{
- return TryUnifyConjunctionType(constraints, sndAndType, fst);
+ return TryUnifyConjunctionType(constraints, fst, snd);
}
// A generic parameter type can unify with anything.
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 2f5447ffb..e56d63f91 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -1168,6 +1168,10 @@ namespace Slang
m_astBuilder->getErrorType(),
fromExpr);
}
+
+ // If we coerced to a differentiable type, log it.
+ maybeRegisterDifferentiableType(m_astBuilder, expr->type);
+
return expr;
}
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index b18e1c4da..2d6e20622 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -835,7 +835,15 @@ namespace Slang
// If `decl` is a container, then we want to ensure its children.
if(auto containerDecl = as<ContainerDecl>(decl))
- {
+ {
+ bool trackDiffTypes = (as<GenericDecl>(decl) != nullptr);
+ if (trackDiffTypes)
+ {
+ // Add a context to track differentiable types.
+ DifferentiableTypeSemanticContext subDiffTypeContext;
+ visitor->getShared()->pushDiffTypeContext(&subDiffTypeContext);
+ }
+
// NOTE! We purposefully do not iterate with the for(auto childDecl : containerDecl->members) here,
// because the visitor may add to `members` whilst iteration takes place, invalidating the iterator
// and likely a crash.
@@ -857,6 +865,21 @@ namespace Slang
_ensureAllDeclsRec(visitor, childDecl, state);
}
+
+ if (trackDiffTypes)
+ {
+ auto subDiffTypeContext = visitor->getShared()->popDiffTypeContext();
+
+ // If there were any differentiable types used in differentiable
+ // methods, generate a dictionary with the required info.
+ //
+ if (subDiffTypeContext->isDictionaryRequired())
+ {
+ auto diffTypeDict = subDiffTypeContext->makeDifferentiableTypeDictionaryNode(visitor->getASTBuilder());
+ diffTypeDict->parentDecl = containerDecl;
+ containerDecl->members.add(diffTypeDict);
+ }
+ }
}
// Note: the "inner" declaration of a `GenericDecl` is currently
@@ -1234,6 +1257,49 @@ namespace Slang
}
}
+ void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*)
+ {
+ // If the autodiff core library (diff.meta.slang) has not been loaded yet, ignore any
+ // request to check differentiable types.
+ //
+ if (!m_astBuilder->isDifferentiableInterfaceAvailable())
+ return;
+
+ auto diffInterface = m_astBuilder->getDifferentiableInterface();
+
+ DeclRefType* type = nullptr;
+
+ if (auto extensionDecl = as<ExtensionDecl>(decl))
+ {
+ // If this is an extension, use the provided target type.
+ type = as<DeclRefType>(extensionDecl->targetType.type);
+ }
+ else
+ {
+ // If this is a type declaration, create a decl ref without
+ // any substitutions.
+ //
+ auto declRef = makeDeclRef(decl);
+
+ // TODO: Strip substitutions from the declreftype
+ type = DeclRefType::create(m_astBuilder, declRef);
+ }
+
+ // Skip if the declaration is the interface itself.
+ if (type->declRef == diffInterface)
+ return;
+
+ // If the DeclRefType conforms to IDifferentiable, register it with the top-level
+ // context.
+ //
+ if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, diffInterface)))
+ {
+ // TODO: Temporarily disabled to move to new system. Fix later.
+ // context->registerDifferentiableType(type, witness);
+ }
+
+ }
+
void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl)
{
// TODO: are there any other validations we can do at this point?
@@ -1287,6 +1353,23 @@ namespace Slang
ensureDecl(constraint, DeclCheckState::ReadyForReference);
}
}
+
+ // TODO(sai): Is this the right checking stage to be doing this?
+ DifferentiableTypeSemanticContext diffTypeContext;
+
+ for (Index i = 0; i < members.getCount(); ++i)
+ {
+ Decl* m = members[i];
+
+ if (auto typeParam = as<GenericTypeParamDecl>(m))
+ {
+ tryAddDifferentiableConformanceToContext(typeParam, &diffTypeContext);
+ }
+ }
+
+ auto diffTypeDictionaryNode = diffTypeContext.makeDifferentiableTypeDictionaryNode(m_astBuilder);
+ diffTypeDictionaryNode->parentDecl = genericDecl;
+ genericDecl->members.add(diffTypeDictionaryNode);
}
void SemanticsDeclBasesVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
@@ -1322,6 +1405,7 @@ namespace Slang
void visitAggTypeDecl(AggTypeDecl* aggTypeDecl)
{
checkAggTypeConformance(aggTypeDecl);
+ tryAddDifferentiableConformanceToContext(aggTypeDecl, getShared()->getDiffTypeContext());
}
// Conformances can also come via `extension` declarations, and
@@ -1330,6 +1414,7 @@ namespace Slang
void visitExtensionDecl(ExtensionDecl* extensionDecl)
{
checkExtensionConformance(extensionDecl);
+ tryAddDifferentiableConformanceToContext(extensionDecl, getShared()->getDiffTypeContext());
}
};
@@ -1486,6 +1571,32 @@ namespace Slang
// Furthermore, because a fully checked function will have checked
// its body, this also means that all function bodies and the
// declarations they contain should be fully checked.
+
+ // Generate a dictionary node to hold information about all
+ // available differentiable types in scope (including imports and stdlib)
+ //
+ if (getShared()->getDiffTypeContext()->isDictionaryRequired())
+ finishDifferentiableTypeDictionary(moduleDecl);
+ }
+
+ void SemanticsVisitor::finishDifferentiableTypeDictionary(ModuleDecl* moduleDecl)
+ {
+ // Grab the differentiable type information from imported modules.
+ for(auto importedModule : getShared()->importedModulesList)
+ {
+ this->getShared()->getDiffTypeContext()->addImportedModule(importedModule);
+ }
+
+ // Grad the differentiable type information from the standard library modules.
+ for (auto stdLibModule : this->getSession()->stdlibModules)
+ {
+ this->getShared()->getDiffTypeContext()->addImportedModule(stdLibModule->getModuleDecl());
+ }
+
+ auto diffTypeDictNode = this->getShared()->getDiffTypeContext()->makeDifferentiableTypeDictionaryNode(m_astBuilder);
+ diffTypeDictNode->parentDecl = moduleDecl;
+
+ moduleDecl->members.add(diffTypeDictNode);
}
bool SemanticsVisitor::doesSignatureMatchRequirement(
@@ -4292,7 +4403,23 @@ namespace Slang
nullptr);
args.add(val);
}
- // TODO: need to handle constraints here?
+ }
+
+ // Add defaults for constraint parameters.
+ for (auto dd : genericDecl->members)
+ {
+ if (auto constraintDecl = as<GenericTypeConstraintDecl>(dd))
+ {
+ // Convert the constraint to an appropriate witness.
+ auto witness = tryGetSubtypeWitness(constraintDecl->sub, constraintDecl->sup);
+
+ // Must be non-null since we know there's a constraint. If null, something is
+ // very wrong.
+ //
+ SLANG_ASSERT(witness);
+
+ args.add(witness);
+ }
}
GenericSubstitution* subst = m_astBuilder->getOrCreateGenericSubstitution(genericDecl, args, nullptr);
return subst;
@@ -4725,6 +4852,11 @@ namespace Slang
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
{
+ if (decl->findModifier<JVPDerivativeModifier>())
+ {
+ this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary();
+ }
+
for(auto paramDecl : decl->getParameters())
{
ensureDecl(paramDecl, DeclCheckState::ReadyForReference);
@@ -5594,6 +5726,75 @@ namespace Slang
m_candidateExtensionListsBuilt = false;
m_mapTypeDeclToCandidateExtensions.Clear();
}
+
+ void DifferentiableTypeSemanticContext::registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness)
+ {
+ // Need to generate a type dictionary since we have a declaration that works with
+ // a differentiable type.
+ //
+ this->requireDifferentiableTypeDictionary();
+
+ m_mapTypeToIDifferentiableWitness.AddIfNotExists(DeclRefTypeKey(type), witness);
+ }
+
+ List<KeyValuePair<DeclRefType*, SubtypeWitness*>> DifferentiableTypeSemanticContext::getDifferentiableTypeConformanceList()
+ {
+ List<KeyValuePair<DeclRefType*, SubtypeWitness*>> diffConformances;
+ for (auto entry : m_mapTypeToIDifferentiableWitness)
+ {
+ diffConformances.add(KeyValuePair<DeclRefType*, SubtypeWitness*>(entry.Key.type, entry.Value));
+ }
+
+ return diffConformances;
+ }
+
+ DifferentiableTypeDictionary* DifferentiableTypeSemanticContext::makeDifferentiableTypeDictionaryNode(
+ ASTBuilder* builder)
+ {
+ auto dictionary = builder->create<DifferentiableTypeDictionary>();
+
+ for (auto item : m_mapTypeToIDifferentiableWitness)
+ {
+ auto entry = builder->create<DifferentiableTypeDictionaryItem>();
+ entry->baseType = item.Key.type;
+ entry->confWitness = item.Value;
+ entry->parentDecl = dictionary;
+
+ dictionary->members.add(entry);
+ }
+
+ for (auto item : m_importedDictionaries)
+ {
+ auto entry = builder->create<DifferentiableTypeDictionaryImportItem>();
+ entry->dictionaryRef = item;
+ entry->parentDecl = dictionary;
+
+ dictionary->members.add(entry);
+ }
+
+ return dictionary;
+ }
+
+ void DifferentiableTypeSemanticContext::addImportedModule(ModuleDecl* importedModuleDecl)
+ {
+ // TODO: This is a terribly slow way to find the diff type dictionary.
+ // Switch to lookUp() when possible (this might involve naming the dictionary something)
+ //
+ for (auto diffTypeDict : importedModuleDecl->getMembersOfType<DifferentiableTypeDictionary>())
+ {
+ m_importedDictionaries.add(makeDeclRef(diffTypeDict));
+ }
+ }
+
+ void DifferentiableTypeSemanticContext::requireDifferentiableTypeDictionary()
+ {
+ this->m_isTypeDictionaryRequired = true;
+ }
+
+ bool DifferentiableTypeSemanticContext::isDictionaryRequired()
+ {
+ return this->m_isTypeDictionaryRequired;
+ }
void SharedSemanticsContext::_addCandidateExtensionsFromModule(ModuleDecl* moduleDecl)
{
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index f1ccddf15..745532c27 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -719,8 +719,219 @@ namespace Slang
return _resolveOverloadedExprImpl(overloadedExpr, mask, getSink());
}
+ Type* SemanticsVisitor::_getDifferential(ASTBuilder* builder, Type* type)
+ {
+ if (auto ptrType = as<PtrTypeBase>(type))
+ {
+ return builder->getPtrType(
+ _getDifferential(builder, ptrType->getValueType()),
+ ptrType->getClassInfo().m_name);
+ }
+ else if (auto arrayType = as<ArrayExpressionType>(type))
+ {
+ return builder->getArrayType(
+ _getDifferential(builder, arrayType->baseType),
+ arrayType->arrayLength);
+ }
+
+ if (auto declRefType = as<DeclRefType>(type))
+ {
+ if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterface())))
+ {
+ auto diffTypeLookupResult = lookUpMember(
+ getASTBuilder(),
+ this,
+ getName("Differential"),
+ type,
+ Slang::LookupMask::type,
+ Slang::LookupOptions::None);
+
+ diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult);
+
+ if (!diffTypeLookupResult.isValid())
+ {
+ // Diagnose no 'Differential' member.
+ getSink()->diagnose(declRefType->declRef, Diagnostics::typeDoesntImplementInterfaceRequirement, type, getName("Differential"));
+ }
+ else if (diffTypeLookupResult.isOverloaded())
+ {
+ SLANG_UNIMPLEMENTED_X("Ambiguous differential type declarations not supported");
+ }
+ else
+ {
+ SharedTypeExpr* baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = type;
+ baseTypeExpr->type.type = m_astBuilder->getTypeType(type);
+
+ auto diffTypeExpr = ConstructLookupResultExpr(
+ diffTypeLookupResult.item,
+ baseTypeExpr,
+ declRefType->declRef.getLoc(),
+ baseTypeExpr);
+
+ return ExtractTypeFromTypeRepr(diffTypeExpr);
+ }
+ }
+ }
+
+ return nullptr;
+ }
+
+ void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type)
+ {
+ if (!builder->isDifferentiableInterfaceAvailable())
+ {
+ return;
+ }
+
+ // Check for special cases such as PtrTypeBase<T> or Array<T>
+ // This could potentially be handled later by simply defining extensions
+ // for Ptr<T:IDifferentiable> etc..
+ //
+ if (auto ptrType = as<PtrTypeBase>(type))
+ {
+ maybeRegisterDifferentiableType(builder, ptrType->getValueType());
+ return;
+ }
+
+ if (auto arrayType = as<ArrayExpressionType>(type))
+ {
+ maybeRegisterDifferentiableType(builder, arrayType->baseType);
+ return;
+ }
+
+ if (auto declRefType = as<DeclRefType>(type))
+ {
+ if (auto subtypeWitness = as<SubtypeWitness>(
+ tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface())))
+ {
+ auto diffTypeContext = this->getShared()->innermostDiffTypeContext();
+ diffTypeContext->registerDifferentiableType((DeclRefType*)type, subtypeWitness);
+ }
+
+ return;
+ }
+ }
+
+ Expr* SemanticsVisitor::maybeMakeDifferentialExpr(Expr* checkedTerm)
+ {
+ // Check that member lookups on differentiable types have appropriate differential
+ // getters and setters.
+ if (auto declRefExpr = as<DeclRefExpr>(checkedTerm))
+ {
+
+ // Check if we have a parent container. If yes, then checkedTerm is
+ // referencing a member of this parent.
+ //
+ auto parentType = DeclRefType::create(getASTBuilder(), declRefExpr->declRef.getParent());
+
+ // Check if we have an aggregate (i.e. struct-like) type.
+ // Ignore interfaces and the case when the term refers to a function
+ //
+ if (parentType->declRef.as<AggTypeDeclBase>() &&
+ !parentType->declRef.as<InterfaceDecl>() &&
+ !declRefExpr->declRef.as<CallableDecl>())
+ {
+ // Check if the parent container type is differentiable.
+ if (auto parentDiffWitness = as<SubtypeWitness>(
+ tryGetInterfaceConformanceWitness(
+ parentType, getASTBuilder()->getDifferentiableInterface())))
+ {
+ // If yes, the member in checkedTerm should have a differential getter and setter.
+ // Otherwise, <ERROR>
+ //
+ auto diffExpr = m_astBuilder->create<DifferentiableDeclRefExpr>();
+ diffExpr->type = checkedTerm->type;
+ diffExpr->inner = checkedTerm;
+
+ {
+ auto getterName = getName("__getDifferentialFor_" + declRefExpr->name->text);
+ auto getterResult = lookUpMember(
+ getASTBuilder(),
+ this,
+ getterName,
+ parentType,
+ Slang::LookupMask::Function,
+ Slang::LookupOptions::None);
+
+ if (!getterResult.isValid())
+ {
+ // Do nothing.. we assume that this field cannot be differentiated.
+ // Could this be confusing from a user perspective?
+ }
+ else if (getterResult.isOverloaded())
+ {
+ // Diagnose ambiguous getter.
+ SLANG_UNIMPLEMENTED_X("Ambiguous differential getters not supported");
+ }
+ else
+ {
+ auto getterRefExpr = ConstructLookupResultExpr(
+ getterResult.item,
+ declRefExpr,
+ getterResult.item.declRef.getLoc(),
+ nullptr);
+
+ // Check that the type is what we expect.
+ // We're going to do this in a very crude way for now.
+ // Ideally, we want to use the overload resolution and type
+ // coercion logic in ResolveInvoke()
+ //
+
+ auto diffType = _getDifferential(m_astBuilder, checkedTerm->type.type);
+ auto diffParentType = _getDifferential(m_astBuilder, parentType);
+
+ auto ptrDiffType = m_astBuilder->getPtrType(diffType);
+ auto inoutContainerDiffType = m_astBuilder->getInOutType(diffParentType);
+
+ auto funcType = as<FuncType>(getterRefExpr->type);
+
+ if (!ptrDiffType->equals(funcType->getResultType()))
+ {
+ getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch,
+ ptrDiffType, funcType->getResultType());
+ }
+
+ if (!inoutContainerDiffType->equals(funcType->getParamType(0)))
+ {
+ getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch,
+ inoutContainerDiffType, funcType->getParamType(0));
+ }
+
+ diffExpr->getterExpr = getterRefExpr;
+ }
+ }
+
+ return diffExpr;
+ }
+ }
+ }
+
+ return checkedTerm;
+ }
+
Expr* SemanticsVisitor::CheckTerm(Expr* term)
{
+ auto checkedTerm = _CheckTerm(term);
+
+ // Differentiable type checking.
+ // TODO: This can be super slow.
+ if (this->m_parentFunc &&
+ this->m_parentFunc->findModifier<JVPDerivativeModifier>())
+ {
+ maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
+
+ if (auto declRefExpr = as<DeclRefExpr>(checkedTerm))
+ {
+ checkedTerm = maybeMakeDifferentialExpr(checkedTerm);
+ }
+ }
+
+ return checkedTerm;
+ }
+
+ Expr* SemanticsVisitor::_CheckTerm(Expr* term)
+ {
if (!term) return nullptr;
// The process of checking a term/expression can end up introducing
@@ -1677,6 +1888,13 @@ namespace Slang
return expr;
}
+ Expr* SemanticsExprVisitor::visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr)
+ {
+ auto checkedInnerTerm = CheckTerm(expr->inner);
+ expr->type = checkedInnerTerm->type;
+ return expr;
+ }
+
Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType)
{
@@ -1715,48 +1933,38 @@ namespace Slang
return primalType;
}
- Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
+ Type* SemanticsVisitor::processJVPFuncType(ASTBuilder* builder, FuncType* originalType)
{
- // Check/Resolve inner function declaration.
- expr->baseFunction = CheckTerm(expr->baseFunction);
+ // Resolve JVP type here.
+ // Note that this type checking needs to be in sync with
+ // the auto-generation logic in slang-ir-jvp-diff.cpp
- auto astBuilder = this->getASTBuilder();
+ FuncType* jvpType = builder->create<FuncType>();
- if(auto primalType = as<FuncType>(expr->baseFunction->type))
- {
- // Resolve JVP type here.
- // Note that this type checking needs to be in sync with
- // the auto-generation logic in slang-ir-jvp-diff.cpp
-
- FuncType* jvpType = astBuilder->create<FuncType>();
-
- // The JVP return type is float if primal return type is float
- // void otherwise.
- //
- jvpType->resultType = _toJVPReturnType(astBuilder, primalType->getResultType());
-
- // No support for differentiating function that throw errors, for now.
- SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType()));
- jvpType->errorType = primalType->errorType;
-
- for (UInt i = 0; i < primalType->getParamCount(); i++)
- {
- if(auto jvpParamType = _toDifferentialParamType(astBuilder, primalType->getParamType(i)))
- jvpType->paramTypes.add(jvpParamType);
- }
+ // The JVP return type is float if primal return type is float
+ // void otherwise.
+ //
+ jvpType->resultType = _toJVPReturnType(builder, originalType->getResultType());
+
+ // No support for differentiating function that throw errors, for now.
+ SLANG_ASSERT(originalType->errorType->equals(builder->getBottomType()));
+ jvpType->errorType = originalType->errorType;
- expr->type = jvpType;
- }
- else
+ for (UInt i = 0; i < originalType->getParamCount(); i++)
{
- // Error
- expr->type = astBuilder->getErrorType();
- if (!as<ErrorType>(expr->baseFunction->type))
- {
- getSink()->diagnose(expr->baseFunction->loc, Diagnostics::expectedFunction, expr->baseFunction->type);
- }
+ if(auto jvpParamType = _toDifferentialParamType(builder, originalType->getParamType(i)))
+ jvpType->paramTypes.add(jvpParamType);
}
+ return jvpType;
+ }
+
+ Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
+ {
+ this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary();
+
+ // Check/Resolve inner function declaration.
+ expr->baseFunction = CheckTerm(expr->baseFunction);
return expr;
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index c15428877..5c1c20e3a 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -237,11 +237,79 @@ namespace Slang
Dictionary<LookupRequestKey, LookupResult> lookupCache;
};
+ struct DifferentiableTypeSemanticContext
+ {
+
+ public:
+ /// Registers a type as conforming to IDifferentiable, along with a witness
+ /// describing the relationship.
+ ///
+ void registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness);
+
+ /// Returns the list of registered differentiable types.
+ List<KeyValuePair<DeclRefType*, SubtypeWitness*>> getDifferentiableTypeConformanceList();
+
+ /// Creates a DifferentiableTypeDictionary AST container node with an entry for
+ /// every registered type. This can be inserted into the appropriate context for the
+ /// auto-diff pass.
+ ///
+ DifferentiableTypeDictionary* makeDifferentiableTypeDictionaryNode(ASTBuilder* builder);
+
+ /// Creates a DifferentiableTypeDictionary AST container node with an entry for
+ /// every registered type. This can be inserted into the appropriate context for the
+ /// auto-diff pass.
+ ///
+ void addImportedModule(ModuleDecl* importedModuleDecl);
+
+ /// Set flag to indicate that the type dictionary is requried.
+ void requireDifferentiableTypeDictionary();
+
+ /// Returns flag indicating whether the type dictionary is requried.
+ bool isDictionaryRequired();
+
+ private:
+ // Nested struct to override the '==' operator for DeclRefTypes
+ struct DeclRefTypeKey
+ {
+ DeclRefType* type;
+
+ DeclRefTypeKey(DeclRefType* type) : type(type)
+ {};
+
+ DeclRefTypeKey(DeclRefTypeKey& typeKey) : type(typeKey.type)
+ {};
+
+ DeclRefTypeKey() : type(nullptr)
+ {};
+
+ bool operator==(const DeclRefTypeKey& other) const
+ {
+ return (other.type->declRef == this->type->declRef);
+ }
+
+ HashCode getHashCode() const
+ {
+ Hasher hasher;
+ hasher.hashObject(&type->declRef);
+ return hasher.getResult();
+ }
+ };
+
+ /// Mapping from types to subtype witnesses for conformance to IDifferentiable.
+ Dictionary<DeclRefTypeKey, SubtypeWitness*> m_mapTypeToIDifferentiableWitness;
+
+ /// List of external dictionaries (from imported modules)
+ List<DeclRef<DifferentiableTypeDictionary>> m_importedDictionaries;
+
+ /// Flag to indicate if a differentiable type dictionary is required.
+ bool m_isTypeDictionaryRequired = false;
+ };
/// Give a cache and a name, will remove all entries associated with a name
/// Might be useful/necessary if a new name is introduced
void removeLookupForName(TypeCheckingCache* cache, Name* name);
+
/// Shared state for a semantics-checking session.
struct SharedSemanticsContext
{
@@ -269,6 +337,10 @@ namespace Slang
//
List<ModuleDecl*> importedModulesList;
HashSet<ModuleDecl*> importedModulesSet;
+
+ DifferentiableTypeSemanticContext diffTypeContext;
+
+ List<DifferentiableTypeSemanticContext*> diffTypeContextStack;
public:
SharedSemanticsContext(
@@ -303,6 +375,29 @@ namespace Slang
return m_linkage->isInLanguageServer();
return false;
}
+
+ DifferentiableTypeSemanticContext* getDiffTypeContext()
+ {
+ return &diffTypeContext;
+ }
+
+ DifferentiableTypeSemanticContext* innermostDiffTypeContext()
+ {
+ return (diffTypeContextStack.getCount() > 0) ? diffTypeContextStack.getLast() : &diffTypeContext;
+ }
+
+ void pushDiffTypeContext(DifferentiableTypeSemanticContext* context)
+ {
+ diffTypeContextStack.add(context);
+ }
+
+ DifferentiableTypeSemanticContext* popDiffTypeContext()
+ {
+ auto context = diffTypeContextStack.getLast();
+ diffTypeContextStack.removeLast();
+ return context;
+ }
+
/// Get the list of extension declarations that appear to apply to `decl` in this context
List<ExtensionDecl*> const& getCandidateExtensionsForTypeDecl(AggTypeDecl* decl);
@@ -687,6 +782,8 @@ namespace Slang
Expr* CheckTerm(Expr* term);
+ Expr* _CheckTerm(Expr* term);
+
Expr* CreateErrorExpr(Expr* expr);
bool IsErrorExpr(Expr* expr);
@@ -716,6 +813,20 @@ namespace Slang
//
Type* _toJVPReturnType(ASTBuilder* builder, Type* primalType);
+ // Convert a function's original type to it's JVP type.
+ Type* processJVPFuncType(ASTBuilder* builder, FuncType* originalType);
+
+ // Check and register a type if it is differentiable.
+ void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type);
+
+ // Check if a term is referencing a member, and add a decoration to it's
+ // differential getter function, if one exists.
+ //
+ Expr* maybeMakeDifferentialExpr(Expr* checkedTerm);
+
+ // Construct the differential for 'type', if it exists.
+ Type* _getDifferential(ASTBuilder* builder, Type* type);
+
public:
bool ValuesAreEqual(
@@ -1004,6 +1115,16 @@ namespace Slang
DeclRef<Decl> requiredMemberDeclRef,
RefPtr<WitnessTable> witnessTable);
+ /// Registers a type as differentiable in the currrent semantic context, if the declaration represents
+ /// a subtype of IDifferentable. Does nothing otherwise.
+ void tryAddDifferentiableConformanceToContext(
+ Decl* decl,
+ DifferentiableTypeSemanticContext* context);
+
+ /// Generates a dictionary node for the module with all registered differentiable types,
+ /// as well as information about differentiable types in imported modules.
+ void finishDifferentiableTypeDictionary(ModuleDecl* moduleDecl);
+
// Find the appropriate member of a declared type to
// satisfy a requirement of an interface the type
// claims to conform to.
@@ -1259,6 +1380,23 @@ namespace Slang
Type* sub = nullptr;
Type* sup = nullptr;
DeclRef<Decl> declRef;
+
+ enum Flavor
+ {
+ // Describes a sub-type super-type relationship through a
+ // reference to an inhertiance declaration.
+ DeclFlavor,
+
+ // Describes a sub-type super-type relationship through
+ // conjunction. This doesn't necessarily have a corresponding declaration
+ // since AndTypes cannot actually be used as types.
+ // i.e. if (A & B) subtype C because A subtype C, then we use AndTypeLeft to represent
+ // that relationship.
+ AndTypeLeftFlavor,
+ AndTypeRightFlavor
+ };
+
+ Flavor flavor = DeclFlavor;
};
// Create a subtype witness based on the declared relationship
@@ -1554,6 +1692,10 @@ namespace Slang
void AddOverloadCandidate(
OverloadResolveContext& context,
OverloadCandidate& candidate);
+
+ void AddHigherOrderOverloadCandidates(
+ Expr* funcExpr,
+ OverloadResolveContext& context);
void AddFuncOverloadCandidate(
LookupResultItem item,
@@ -1621,7 +1763,7 @@ namespace Slang
bool TryUnifyConjunctionType(
ConstraintSystem& constraints,
- AndType* fst,
+ Type* fst,
Type* snd);
// Is the candidate extension declaration actually applicable to the given type
@@ -1638,7 +1780,8 @@ namespace Slang
DeclRef<Decl> inferGenericArguments(
DeclRef<GenericDecl> genericDeclRef,
OverloadResolveContext& context,
- GenericSubstitution* substWithKnownGenericArgs);
+ GenericSubstitution* substWithKnownGenericArgs,
+ List<Type*> *innerParameterTypes = nullptr);
void AddTypeOverloadCandidates(
Type* type,
@@ -1781,6 +1924,8 @@ namespace Slang
Expr* visitVarExpr(VarExpr *expr);
+ Expr* visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr *expr);
+
Expr* visitTypeCastExpr(TypeCastExpr * expr);
Expr* visitTryExpr(TryExpr* expr);
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 7dba3986a..eadf2f63d 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -715,6 +715,21 @@ namespace Slang
callExpr->originalFunctionExpr = callExpr->functionExpr;
callExpr->type = QualType(candidate.resultType);
+ // If the callee is the result of a higher-order function invocation,
+ // set it's base function to the declaration corresponding to the
+ // resolved overload.
+ //
+ if (auto higherOrderInvoke = as<HigherOrderInvokeExpr>(callExpr->functionExpr))
+ {
+ higherOrderInvoke->baseFunction = ConstructLookupResultExpr(
+ candidate.item,
+ baseExpr,
+ higherOrderInvoke->loc,
+ callExpr->functionExpr);
+
+ higherOrderInvoke->type = candidate.funcType;
+ }
+
return callExpr;
}
@@ -1174,7 +1189,8 @@ namespace Slang
DeclRef<Decl> SemanticsVisitor::inferGenericArguments(
DeclRef<GenericDecl> genericDeclRef,
OverloadResolveContext& context,
- GenericSubstitution* substWithKnownGenericArgs)
+ GenericSubstitution* substWithKnownGenericArgs,
+ List<Type*> *innerParameterTypes)
{
// We have been asked to infer zero or more arguments to
// `genericDeclRef`, in a context where it is being applied
@@ -1279,7 +1295,7 @@ namespace Slang
TryUnifyTypes(
constraints,
context.getArgTypeForInference(aa, this),
- getType(m_astBuilder, params[aa]));
+ (!innerParameterTypes) ? getType(m_astBuilder, params[aa]) : (*innerParameterTypes)[aa]);
}
}
else
@@ -1495,6 +1511,11 @@ namespace Slang
AddOverloadCandidates(item, context);
}
}
+ else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr))
+ {
+ // The expression is the result of a higher order function application.
+ AddHigherOrderOverloadCandidates(higherOrderExpr, context);
+ }
else if (auto partiallyAppliedGenericExpr = as<PartiallyAppliedGenericExpr>(funcExpr))
{
// A partially-applied generic is allowed as an overload candidate,
@@ -1520,6 +1541,121 @@ namespace Slang
}
}
+ void SemanticsVisitor::AddHigherOrderOverloadCandidates(
+ Expr* funcExpr,
+ OverloadResolveContext& context)
+ {
+ // Lookup the higher order function and process types accordingly. In the future,
+ // if there are enough varieties, we can have dispatch logic instead of an
+ // if-else ladder.
+ if (auto jvpExpr = as<JVPDifferentiateExpr>(funcExpr))
+ {
+ if (auto origFuncType = as<FuncType>(jvpExpr->baseFunction->type))
+ {
+ // Case: __jvp(name-resolved-to-decl-ref)
+
+ auto baseFuncDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<CallableDecl>();
+ SLANG_ASSERT(baseFuncDeclRef);
+
+ OverloadCandidate candidate;
+ candidate.flavor = OverloadCandidate::Flavor::Expr;
+ candidate.funcType = as<FuncType>(processJVPFuncType(this->getASTBuilder(), origFuncType));
+ candidate.resultType = candidate.funcType->getResultType();
+ candidate.item = LookupResultItem(baseFuncDeclRef);
+
+ AddOverloadCandidate(context, candidate);
+ }
+ else if (auto origOverloadedType = as<OverloadGroupType>(jvpExpr->baseFunction->type))
+ {
+ // Case: __jvp(name-resolved-to-multiple-decl-ref)
+
+ if (auto overloadExpr = as<OverloadedExpr>(jvpExpr->baseFunction))
+ {
+ for (auto item : overloadExpr->lookupResult2.items)
+ {
+ OverloadCandidate candidate;
+ candidate.flavor = OverloadCandidate::Flavor::Expr;
+ candidate.funcType = as<FuncType>(processJVPFuncType(
+ this->getASTBuilder(),
+ as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc))));
+ candidate.resultType = candidate.funcType->getResultType();
+ candidate.item = LookupResultItem(item.declRef);
+
+ AddOverloadCandidate(context, candidate);
+ }
+ }
+ else
+ {
+ // Unhandled overload expr.
+ funcExpr->type = this->getASTBuilder()->getErrorType();
+ getSink()->diagnose(funcExpr->loc,
+ Diagnostics::unimplemented,
+ funcExpr->type);
+ }
+ }
+ else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<GenericDecl>())
+ {
+ // Case: __jvp(name-resolved-to-generic-decl)
+
+ // Get inner function
+ DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>(
+ getInner(baseFuncGenericDeclRef),
+ baseFuncGenericDeclRef.substitutions);
+
+ // Pull parameter list of inner function.
+ auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>());
+
+ // Process func type to generate JVP func type.
+ auto jvpFuncType = as<FuncType>(processJVPFuncType(this->getASTBuilder(), funcType));
+
+ // Extract parameter list from processed type.
+ List<Type*> paramTypes;
+
+ for(UIndex ii = 0; ii < jvpFuncType->getParamCount(); ii++)
+ paramTypes.add(jvpFuncType->getParamType(ii));
+
+ // Try to infer generic arguments, based on the updated context.
+ DeclRef<Decl> innerRef = inferGenericArguments(
+ baseFuncGenericDeclRef,
+ context,
+ nullptr,
+ &paramTypes);
+
+ if (innerRef)
+ {
+ OverloadCandidate candidate;
+ candidate.flavor = OverloadCandidate::Flavor::Expr;
+
+ // Note that we call processJVPFuncType() again here
+ // in order to process the specialized version of the original func type.
+ // This could potentially be a declRef.substitute(jvpFuncType)
+ //
+ candidate.funcType = as<FuncType>(processJVPFuncType(
+ this->getASTBuilder(),
+ getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>())));
+
+ candidate.resultType = candidate.funcType->getResultType();
+ candidate.item = LookupResultItem(innerRef);
+
+ AddOverloadCandidate(context, candidate);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Could not resolve generic candidate");
+ }
+
+ }
+ else
+ {
+ // Unhandled case for the inner expr.
+ funcExpr->type = this->getASTBuilder()->getErrorType();
+ getSink()->diagnose(funcExpr->loc,
+ Diagnostics::expectedFunction,
+ funcExpr->type);
+ }
+ }
+ }
+
String SemanticsVisitor::getCallSignatureString(
OverloadResolveContext& context)
{
@@ -1627,8 +1763,8 @@ namespace Slang
// without needing dummy initializer/constructor declarations.
//
// Handling that special casing here (rather than in, say,
- // `visitTypeCastExpr`) would allow us to continue to ensure
// that `(T) expr` and `T(expr)` continue to be semantically
+ // `visitTypeCastExpr`) would allow us to continue to ensure
// equivalent in (almost) all cases.
if (!context.bestCandidate)
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index d402dde03..6a8f802f7 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -320,6 +320,19 @@ namespace Slang
getSink()->diagnose(typeExp.exp, Diagnostics::cannotDefinePtrTypeToManagedResource);
}
}
+
+ // Differentiable type checking.
+ // TODO: This can be super slow. Switch to caching the result asap.
+ if (this->m_parentFunc &&
+ this->m_parentFunc->findModifier<JVPDerivativeModifier>())
+ {
+ auto diffTypeContext = this->getShared()->innermostDiffTypeContext();
+ if (auto subtypeWitness = as<SubtypeWitness>(
+ tryGetInterfaceConformanceWitness(result, getASTBuilder()->getDifferentiableInterface())))
+ {
+ diffTypeContext->registerDifferentiableType((DeclRefType*)result, subtypeWitness);
+ }
+ }
*outProperType = result;
return true;
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 4666e80d8..1ea54475e 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -10,6 +10,8 @@
#include "slang-ir-collect-global-uniforms.h"
#include "slang-ir-cleanup-void.h"
#include "slang-ir-dce.h"
+#include "slang-ir-diff-call.h"
+#include "slang-ir-diff-jvp.h"
#include "slang-ir-dll-export.h"
#include "slang-ir-dll-import.h"
#include "slang-ir-eliminate-phis.h"
@@ -365,6 +367,29 @@ Result linkAndOptimizeIR(
lowerReinterpret(targetRequest, irModule, sink);
validateIRModuleIfEnabled(codeGenContext, irModule);
+
+ // Inline calls to any functions marked with [__unsafeInlineEarly] again,
+ // since we may be missing out cases prevented by the functions that we just specialzied.
+ performMandatoryEarlyInlining(irModule);
+
+ dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
+
+ // Process higher-order calles to auto-diff passes.
+ // 1. Generate JVP code wherever necessary. (Linearization or "forward-mode" pass)
+ processJVPDerivativeMarkers(irModule, sink);
+
+ // 2. Transpose JVP to VJP code wherever needed. (Transposition or "reverse-mode" pass)
+ // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet.
+
+ // 3. Fill in higher-order invocations with the generated functions.
+ processDerivativeCalls(irModule);
+
+ dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF");
+
+ validateIRModuleIfEnabled(codeGenContext, irModule);
+
+ applySparseConditionalConstantPropagation(irModule);
+ eliminateDeadCode(irModule);
// For targets that supports dynamic dispatch, we need to lower the
// generics / interface types to ordinary functions and types using
diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp
index d58e307da..7d677b488 100644
--- a/source/slang/slang-ir-dce.cpp
+++ b/source/slang/slang-ir-dce.cpp
@@ -361,6 +361,13 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o
case kIROp_WitnessTableEntry:
return true;
+ // Special dictionaries used for differentiable type tracking
+ // should be kept alive. These are removed by the auto-diff pass,
+ // once they are used.
+ case kIROp_DifferentiableTypeDictionaryItem:
+ case kIROp_DifferentiableTypeDictionary:
+ return true;
+
default:
break;
}
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
index 92044be3c..ee78246fe 100644
--- a/source/slang/slang-ir-diff-call.cpp
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -52,25 +52,50 @@ struct DerivativeCallProcessContext
// the intstructions.
void processDifferentiate(IRJVPDifferentiate* derivOfInst)
{
- IRFunc* jvpFunc = nullptr;
+ IRInst* jvpCallable = nullptr;
+
+ // First get base function
+ auto origCallable = derivOfInst->getBaseFn();
+
+ IRSpecialize* specialization = nullptr;
+
+ // If the base is a specialize inst, get the inner fn.
+ if (auto origSpecialize = as<IRSpecialize>(origCallable))
+ {
+ specialization = origSpecialize;
+ origCallable = origSpecialize->getBase();
+ }
+
+ // We should have either a generic or a function reference on our hands.
+ SLANG_ASSERT(as<IRGeneric>(origCallable) || as<IRFunc>(origCallable));
// Resolve the derivative function.
//
// Check for the 'JVPDerivativeReference' decorator on the
// base function.
- if (auto jvpRefDecorator = derivOfInst->base.get()->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ if (auto jvpRefDecorator = origCallable->findDecoration<IRJVPDerivativeReferenceDecoration>())
{
- jvpFunc = jvpRefDecorator->getJVPFunc();
+ jvpCallable = jvpRefDecorator->getJVPFunc();
+ }
+
+ SLANG_ASSERT(jvpCallable);
+
+ if (specialization)
+ {
+ // Replace the specialization target with the JVP func.
+ specialization->setOperand(0, jvpCallable);
+
+ // Then replace the JVPDifferentiate inst with the specialization.
+ derivOfInst->replaceUsesWith(specialization);
}
-
- // Substitute all uses of the 'derivativeOf' operation
- // with the resolved derivative function.
- while (auto use = derivOfInst->firstUse)
+ else
{
- use->set(jvpFunc);
+ // Substitute all uses of the 'derivativeOf' operation
+ // with the resolved derivative function.
+ derivOfInst->replaceUsesWith(jvpCallable);
}
- // Remove the 'derivativeOf'
+ // Remove the 'derivativeOf' inst.
derivOfInst->removeAndDeallocate();
}
};
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 5eee13d5e..843428c01 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -7,6 +7,10 @@
#include "slang-ir-dce.h"
#include "slang-ir-eliminate-phis.h"
+// origX, primalX, diffX
+// origX -> primalX (cloneEnv)
+// origX -> diffX (instMapD)
+
namespace Slang
{
@@ -24,7 +28,7 @@ typedef Pair<IRInst*, IRInst*> InstPair;
struct DifferentiableTypeConformanceContext
{
- Dictionary<IRInst*, IRInst*> witnessTableMap;
+ Dictionary<IRInst*, IRInst*> witnessTableMap;
IRInst* inst = nullptr;
@@ -39,6 +43,18 @@ struct DifferentiableTypeConformanceContext
// type in the conformance table associated with the concrete type.
//
IRStructKey* differentialAssocTypeStructKey = nullptr;
+
+ // The struct key for the 'zero()' associated type
+ // defined inside IDifferential. We use this to lookup the
+ // implementation of zero() for a given type.
+ //
+ IRStructKey* zeroMethodStructKey = nullptr;
+
+ // The struct key for the 'add()' associated type
+ // defined inside IDifferential. We use this to lookup the
+ // implementation of add() for a given type.
+ //
+ IRStructKey* addMethodStructKey = nullptr;
// Modules that don't use differentiable types
// won't have the IDifferentiable interface type available.
@@ -56,6 +72,9 @@ struct DifferentiableTypeConformanceContext
{
differentiableInterfaceType = parent->differentiableInterfaceType;
differentialAssocTypeStructKey = parent->differentialAssocTypeStructKey;
+ zeroMethodStructKey = parent->zeroMethodStructKey;
+ addMethodStructKey = parent->addMethodStructKey;
+
isInterfaceAvailable = parent->isInterfaceAvailable;
}
else
@@ -64,17 +83,13 @@ struct DifferentiableTypeConformanceContext
if (differentiableInterfaceType)
{
differentialAssocTypeStructKey = findDifferentialTypeStructKey();
+ zeroMethodStructKey = findZeroMethodStructKey();
+ addMethodStructKey = findAddMethodStructKey();
if (differentialAssocTypeStructKey)
isInterfaceAvailable = true;
}
}
-
- if (isInterfaceAvailable)
- {
- // Load all witness tables corresponding to the IDifferentiable interface.
- loadWitnessTablesForInterface(differentiableInterfaceType);
- }
}
DifferentiableTypeConformanceContext(IRInst* inst) :
@@ -84,35 +99,30 @@ struct DifferentiableTypeConformanceContext
// Lookup a witness table for the concreteType. One should exist if concreteType
// inherits (successfully) from IDifferentiable.
//
- IRInst* lookUpConformanceForType(IRInst* type)
+ IRInst* lookUpConformanceForType(IRBuilder* builder, IRInst* type)
{
SLANG_ASSERT(isInterfaceAvailable);
+ // TODO: Cache the returned value to avoid repeatedly scanning through
+ // blocks looking for the type entries.
+ //
+ if (auto irWitness = builder->findDifferentiableTypeEntry(type, type->getParent()))
+ {
+ return irWitness;
+ }
- if (witnessTableMap.ContainsKey(type))
- return witnessTableMap[type];
- else if (parent)
- return parent->lookUpConformanceForType(type);
- else
- return nullptr;
+ return nullptr;
}
-
- // Lookup and return the 'Differential' type declared in the concrete type
- // in order to conform to the IDifferentiable interface.
- // Note that inside a generic block, this will be a witness table lookup instruction
- // that gets resolved during the specialization pass.
- //
- IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
- {
- SLANG_ASSERT(isInterfaceAvailable);
- if (auto conformance = lookUpConformanceForType(origType))
+ IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
+ {
+ if (auto conformance = lookUpConformanceForType(builder, origType))
{
if (auto witnessTable = as<IRWitnessTable>(conformance))
{
for (auto entry : witnessTable->getEntries())
{
- if (entry->getRequirementKey() == differentialAssocTypeStructKey)
- return as<IRType>(entry->getSatisfyingVal());
+ if (entry->getRequirementKey() == key)
+ return entry->getSatisfyingVal();
}
}
else if (auto witnessTableParam = as<IRParam>(conformance))
@@ -120,12 +130,32 @@ struct DifferentiableTypeConformanceContext
return builder->emitLookupInterfaceMethodInst(
builder->getTypeKind(),
witnessTableParam,
- differentialAssocTypeStructKey);
+ key);
}
}
return nullptr;
}
+
+ // Lookup and return the 'Differential' type declared in the concrete type
+ // in order to conform to the IDifferentiable interface.
+ // Note that inside a generic block, this will be a witness table lookup instruction
+ // that gets resolved during the specialization pass.
+ //
+ IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, differentialAssocTypeStructKey);
+ }
+
+ IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, zeroMethodStructKey);
+ }
+
+ IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, addMethodStructKey);
+ }
private:
@@ -150,11 +180,26 @@ struct DifferentiableTypeConformanceContext
IRStructKey* findDifferentialTypeStructKey()
{
+ return getIDifferentiableStructKeyAtIndex(0);
+ }
+
+ IRStructKey* findZeroMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(1);
+ }
+
+ IRStructKey* findAddMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(2);
+ }
+
+ IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index)
+ {
if (as<IRModuleInst>(inst) && differentiableInterfaceType)
{
- // Assume for now that IDifferentiable has exactly one field: the 'Differential' associated type.
- SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 1);
- if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(0)))
+ // Assume for now that IDifferentiable has exactly three fields.
+ SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4);
+ if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
return as<IRStructKey>(entry->getRequirementKey());
else
{
@@ -200,12 +245,18 @@ struct DifferentiableTypeConformanceContext
genericParam = genericParam->getNextParam();
}
- UCount tableIndex = 0;
+ Count tableIndex = 0;
while (genericParam)
{
SLANG_ASSERT(!as<IRTypeType>(genericParam->getDataType()));
+
+ if (tableIndex >= typeParams.getCount())
+ break;
+
if (auto witnessTableType = as<IRWitnessTableType>(genericParam->getDataType()))
{
+ // TODO(sai): Heavily flawed way to find the right witness table.
+ // Rewrite this part
if (witnessTableType->getConformanceType() == differentiableInterfaceType)
witnessTableMap.Add(typeParams[tableIndex], genericParam);
}
@@ -222,6 +273,40 @@ struct DifferentiableTypeConformanceContext
};
+
+IRInst* findGlobal(IRInst* inst)
+{
+ if (inst->getParent() != inst->getModule()->getModuleInst())
+ {
+ return findGlobal(inst->getParent());
+ }
+
+ return inst;
+}
+
+void moveGlobalToBeforeUses(IRBuilder*, IRInst* globalInst)
+{
+ HashSet<IRInst*> globalsOfUses;
+ for (auto use = globalInst->firstUse; use; use = use->nextUse)
+ {
+ globalsOfUses.Add(findGlobal(use->getUser()));
+ }
+
+ IRInst* earliestUse = nullptr;
+ for (auto cursor = globalInst; cursor; cursor = cursor->getPrevInst())
+ {
+ if (globalsOfUses.Contains(cursor))
+ {
+ earliestUse = cursor;
+ }
+ }
+
+ if (earliestUse)
+ {
+ globalInst->insertBefore(earliestUse);
+ }
+}
+
struct DifferentialPairTypeBuilder
{
@@ -229,95 +314,246 @@ struct DifferentialPairTypeBuilder
diffConformanceContext(diffConformanceContext)
{}
- IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst)
+ IRStructField* findField(IRInst* type, IRStructKey* key)
{
- if (auto basePairStructType = as<IRStructType>(baseInst->getDataType()))
+ if (auto irStructType = as<IRStructType>(type))
{
- auto primalField = as<IRStructField>(basePairStructType->getFirstChild());
- SLANG_ASSERT(primalField);
-
- return as<IRFieldExtract>(builder->emitFieldExtract(
- primalField->getFieldType(),
- baseInst,
- primalField->getKey()
- ));
+ for (auto field : irStructType->getFields())
+ {
+ if (field->getKey() == key)
+ {
+ return field;
+ }
+ }
}
- else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType()))
+ else if (auto irSpecialize = as<IRSpecialize>(type))
{
- if (auto pairStructType = as<IRStructType>(ptrType->getValueType()))
+ if (auto irGeneric = as<IRGeneric>(irSpecialize->getBase()))
{
- auto primalField = as<IRStructField>(pairStructType->getFirstChild());
- SLANG_ASSERT(primalField);
-
- return as<IRFieldAddress>(builder->emitFieldAddress(
- builder->getPtrType(primalField->getFieldType()),
- baseInst,
- primalField->getKey()
- ));
+ if (auto irGenericStructType = as<IRStructType>(findInnerMostGenericReturnVal(irGeneric)))
+ {
+ return findField(irGenericStructType, key);
+ }
}
}
- else
+
+ return nullptr;
+ }
+
+ IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam)
+ {
+ // Get base generic that's being specialized.
+ auto genericType = as<IRGeneric>(as<IRSpecialize>(specializeInst)->getBase());
+ SLANG_ASSERT(genericType);
+
+ // Find the index of genericParam in the base generic.
+ int paramIndex = -1;
+ int currentIndex = 0;
+ for (auto param : genericType->getParams())
{
- SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>");
+ if (param == genericParam)
+ paramIndex = currentIndex;
+ currentIndex ++;
}
- return nullptr;
+
+ SLANG_ASSERT(paramIndex >= 0);
+
+ // Return the corresponding operand in the specialization inst.
+ return specializeInst->getOperand(1 + paramIndex);
}
- IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst)
+ IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key)
{
if (auto basePairStructType = as<IRStructType>(baseInst->getDataType()))
{
- auto diffField = as<IRStructField>(basePairStructType->getFirstChild()->getNextInst());
- SLANG_ASSERT(diffField);
-
return as<IRFieldExtract>(builder->emitFieldExtract(
- diffField->getFieldType(),
+ findField(basePairStructType, key)->getFieldType(),
baseInst,
- diffField->getKey()
+ key
));
}
else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType()))
{
- if (auto pairStructType = as<IRStructType>(ptrType->getValueType()))
+ if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType()))
{
- auto diffField = as<IRStructField>(pairStructType->getFirstChild()->getNextInst());
- SLANG_ASSERT(diffField);
-
- return as<IRFieldAddress>(builder->emitFieldAddress(
- builder->getPtrType(diffField->getFieldType()),
+ auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(ptrInnerSpecializedType->getBase()));
+ if (auto genericBasePairStructType = as<IRStructType>(genericType))
+ {
+ return as<IRFieldAddress>(builder->emitFieldAddress(
+ builder->getPtrType((IRType*)
+ findSpecializationForParam(
+ ptrInnerSpecializedType,
+ findField(ptrInnerSpecializedType, key)->getFieldType())),
baseInst,
- diffField->getKey()
+ key
));
+ }
+ }
+ else if (auto ptrBaseStructType = as<IRStructType>(ptrType->getValueType()))
+ {
+ return as<IRFieldAddress>(builder->emitFieldAddress(
+ builder->getPtrType((IRType*)
+ findField(ptrBaseStructType, key)->getFieldType()),
+ baseInst,
+ key));
+ }
+ }
+ else if (auto specializedType = as<IRSpecialize>(baseInst->getDataType()))
+ {
+ // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's
+ // type, emit the specialization type.
+ //
+ auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(specializedType->getBase()));
+ if (auto genericBasePairStructType = as<IRStructType>(genericType))
+ {
+ return as<IRFieldExtract>(builder->emitFieldExtract(
+ (IRType*)findSpecializationForParam(
+ specializedType,
+ findField(genericBasePairStructType, key)->getFieldType()),
+ baseInst,
+ key
+ ));
+ }
+ else if (auto genericPtrType = as<IRPtrTypeBase>(genericType))
+ {
+ if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType()))
+ {
+ return as<IRFieldAddress>(builder->emitFieldAddress(
+ builder->getPtrType((IRType*)
+ findSpecializationForParam(
+ specializedType,
+ findField(genericPairStructType, key)->getFieldType())),
+ baseInst,
+ key
+ ));
+ }
}
}
else
{
- SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>");
+ SLANG_UNEXPECTED("Unrecognized field. Cannot emit field accessor");
}
return nullptr;
}
+
+ IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst)
+ {
+ return emitFieldAccessor(builder, baseInst, this->globalPrimalKey);
+ }
+
+ IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst)
+ {
+ return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
+ }
+
+ void relocateNewTypes(IRBuilder* builder)
+ {
+ for (auto typeInst : generatedTypeList)
+ {
+ moveGlobalToBeforeUses(builder, typeInst);
+ }
+ }
+
+ void _createGenericDiffPairType(IRBuilder* builder)
+ {
+ // Insert directly at top level (skip any generic scopes etc.)
+ auto insertLoc = builder->getInsertLoc();
+ builder->setInsertInto(builder->getModule()->getModuleInst());
+
+ // Make a generic version of the pair struct.
+ auto irGeneric = builder->emitGeneric();
+ irGeneric->setFullType(builder->getTypeKind());
+ builder->setInsertInto(irGeneric);
+
+ generatedTypeList.add(irGeneric);
+
+ auto irBlock = builder->emitBlock();
+ builder->setInsertInto(irBlock);
+
+ auto pTypeParam = builder->emitParam(builder->getTypeType());
+ builder->addNameHintDecoration(pTypeParam, UnownedTerminatedStringSlice("pT"));
+
+ auto dTypeParam = builder->emitParam(builder->getTypeType());
+ builder->addNameHintDecoration(dTypeParam, UnownedTerminatedStringSlice("dT"));
+
+ auto irStructType = builder->createStructType();
+ builder->emitReturn(irStructType);
+
+ auto primalKey = _getOrCreatePrimalStructKey(builder);
+ builder->addNameHintDecoration(primalKey, UnownedTerminatedStringSlice("primal"));
+ builder->createStructField(irStructType, primalKey, (IRType*) pTypeParam);
+
+ auto diffKey = _getOrCreateDiffStructKey(builder);
+ builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential"));
+ builder->createStructField(irStructType, diffKey, (IRType*) dTypeParam);
+
+ // Reset cursor when done.
+ builder->setInsertLoc(insertLoc);
+
+ this->genericDiffPairType = irGeneric;
+ }
+
+ IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder)
+ {
+ if (!this->globalDiffKey)
+ {
+ // Insert directly at top level (skip any generic scopes etc.)
+ auto insertLoc = builder->getInsertLoc();
+ builder->setInsertInto(builder->getModule()->getModuleInst());
+
+ this->globalDiffKey = builder->createStructKey();
+ builder->addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential"));
+
+ builder->setInsertLoc(insertLoc);
+ }
+
+ return this->globalDiffKey;
+ }
+
+ IRStructKey* _getOrCreatePrimalStructKey(IRBuilder* builder)
+ {
+ if (!this->globalPrimalKey)
+ {
+ // Insert directly at top level (skip any generic scopes etc.)
+ auto insertLoc = builder->getInsertLoc();
+ builder->setInsertInto(builder->getModule()->getModuleInst());
+
+ this->globalPrimalKey = builder->createStructKey();
+ builder->addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal"));
+
+ builder->setInsertLoc(insertLoc);
+ }
+
+ return this->globalPrimalKey;
+ }
+
+ IRInst* _getOrCreateGenericDiffPairType(IRBuilder* builder)
+ {
+ if (!this->genericDiffPairType)
+ {
+ _createGenericDiffPairType(builder);
+ }
+
+ SLANG_ASSERT(this->genericDiffPairType);
+ return this->genericDiffPairType;
+ }
- IRStructType* _createDiffPairType(IRBuilder* builder, IRType* origBaseType)
+ IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType)
{
if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType))
{
- auto diffPairType = builder->createStructType();
-
- // Create a keys for the primal and differential fields.
- IRStructKey* origKey = builder->createStructKey();
- builder->addNameHintDecoration(origKey, UnownedTerminatedStringSlice("primal"));
- builder->createStructField(diffPairType, origKey, origBaseType);
+ SLANG_ASSERT(!as<IRParam>(origBaseType));
- IRStructKey* diffKey = builder->createStructKey();
- builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential"));
- builder->createStructField(diffPairType, diffKey, (IRType*)(diffBaseType));
+ auto pairStructType = builder->createStructType();
+ builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType);
+ builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*) diffBaseType);
- return diffPairType;
+ return pairStructType;
}
return nullptr;
}
- IRStructType* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType)
+ IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType)
{
if (pairTypeCache.ContainsKey(origBaseType))
return pairTypeCache[origBaseType];
@@ -328,10 +564,17 @@ struct DifferentialPairTypeBuilder
return pairType;
}
- Dictionary<IRType*, IRStructType*> pairTypeCache;
+ Dictionary<IRInst*, IRInst*> pairTypeCache;
DifferentiableTypeConformanceContext* diffConformanceContext;
+
+ IRStructKey* globalPrimalKey = nullptr;
+
+ IRStructKey* globalDiffKey = nullptr;
+ IRInst* genericDiffPairType = nullptr;
+
+ List<IRInst*> generatedTypeList;
};
struct JVPTranscriber
@@ -341,6 +584,9 @@ struct JVPTranscriber
// their differential values.
Dictionary<IRInst*, IRInst*> instMapD;
+ // Set of insts currently being transcribed. Used to avoid infinite loops.
+ HashSet<IRInst*> instsInProgress;
+
// Cloning environment to hold mapping from old to new copies for the primal
// instructions.
IRCloneEnv cloneEnv;
@@ -362,7 +608,17 @@ struct JVPTranscriber
void mapDifferentialInst(IRInst* origInst, IRInst* diffInst)
{
- instMapD.Add(origInst, diffInst);
+ if (hasDifferentialInst(origInst))
+ {
+ if (lookupDiffInst(origInst) != diffInst)
+ {
+ SLANG_UNEXPECTED("Inconsistent differential mappings");
+ }
+ }
+ else
+ {
+ instMapD.Add(origInst, diffInst);
+ }
}
void mapPrimalInst(IRInst* origInst, IRInst* primalInst)
@@ -439,6 +695,7 @@ struct JVPTranscriber
for (UIndex i = 0; i < funcType->getParamCount(); i++)
{
auto origType = funcType->getParamType(i);
+ origType = (IRType*) lookupPrimalInst(origType, origType);
if (auto diffPairType = tryGetDiffPairType(builder, origType))
newParameterTypes.add(diffPairType);
else
@@ -448,7 +705,8 @@ struct JVPTranscriber
// Transcribe return type to a pair.
// This will be void if the primal return type is non-differentiable.
//
- if (auto returnPairType = tryGetDiffPairType(builder, funcType->getResultType()))
+ auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType());
+ if (auto returnPairType = tryGetDiffPairType(builder, origResultType))
diffReturnType = returnPairType;
else
diffReturnType = builder->getVoidType();
@@ -458,41 +716,101 @@ struct JVPTranscriber
IRType* differentiateType(IRBuilder* builder, IRType* origType)
{
- switch (origType->getOp())
- {
- case kIROp_HalfType:
- case kIROp_FloatType:
- case kIROp_DoubleType:
- case kIROp_VectorType:
- return (IRType*)(diffConformanceContext->getDifferentialForType(builder, origType));
- case kIROp_OutType:
- return builder->getOutType(differentiateType(builder, as<IROutType>(origType)->getValueType()));
- case kIROp_InOutType:
- return builder->getInOutType(differentiateType(builder, as<IRInOutType>(origType)->getValueType()));
- default:
+ if (auto ptrType = as<IRPtrTypeBase>(origType))
+ return builder->getPtrType(
+ origType->getOp(),
+ differentiateType(builder, ptrType->getValueType()));
+
+ // If there is an explicit primal version of this type in the local scope, load that
+ // otherwise use the original type.
+ //
+ IRInst* primalType = lookupPrimalInst(origType, origType);
+
+ // Special case certain compound types (PtrType, FuncType, etc..)
+ // otherwise try to lookup a differential definition for the given type.
+ // If one does not exist, then we assume it's not differentiable.
+ //
+ switch (primalType->getOp())
+ {
+ case kIROp_Param:
+ if (as<IRTypeType>(primalType->getDataType()))
+ return (IRType*)(diffConformanceContext->getDifferentialForType(
+ builder,
+ (IRType*)primalType));
+ else if (as<IRWitnessTableType>(primalType->getDataType()))
+ return (IRType*)primalType;
+
+ case kIROp_ArrayType:
+ {
+ auto primalArrayType = as<IRArrayType>(primalType);
+ if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
+ return builder->getArrayType(
+ diffElementType,
+ primalArrayType->getElementCount());
+ else
+ return nullptr;
+ }
+
+ case kIROp_FuncType:
+ return differentiateFunctionType(builder, as<IRFuncType>(primalType));
+
+ case kIROp_OutType:
+ if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
+ return builder->getOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_InOutType:
+ if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType()))
+ return builder->getInOutType(diffValueType);
+ else
return nullptr;
+
+ case kIROp_TupleType:
+ {
+ auto tupleType = as<IRTupleType>(primalType);
+ List<IRType*> diffTypeList;
+ // TODO: what if we have type parameters here?
+ for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
+ diffTypeList.add(
+ differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
+
+ return builder->getTupleType(diffTypeList);
+ }
+
+ default:
+ return (IRType*)(diffConformanceContext->getDifferentialForType(builder, (IRType*)primalType));
}
}
- IRType* tryGetDiffPairType(IRBuilder* builder, IRType* origType)
+ IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
{
// If this is a PtrType (out, inout, etc..), then create diff pair from
// value type and re-apply the appropropriate PtrType wrapper.
//
- if (auto origPtrType = as<IRPtrTypeBase>(origType))
+ if (auto origPtrType = as<IRPtrTypeBase>(primalType))
{
if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
- return builder->getPtrType(origType->getOp(), diffPairValueType);
+ return builder->getPtrType(primalType->getOp(), diffPairValueType);
else
return nullptr;
}
- return pairBuilder->getOrCreateDiffPairType(builder, origType);
+ return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType);
}
InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
{
- if (auto diffPairType = tryGetDiffPairType(builder, origParam->getFullType()))
+ auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType());
+ // Do not differentiate generic type (and witness table) parameters
+ if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType))
+ {
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
{
IRParam* diffPairParam = builder->emitParam(diffPairType);
@@ -507,6 +825,7 @@ struct JVPTranscriber
pairBuilder->emitDiffFieldAccess(builder, diffPairParam));
}
+
return InstPair(
cloneInst(&cloneEnv, builder, origParam),
nullptr);
@@ -570,15 +889,13 @@ struct JVPTranscriber
auto diffLeft = findOrTranscribeDiffInst(builder, origLeft);
auto diffRight = findOrTranscribeDiffInst(builder, origRight);
- auto leftZero = builder->getFloatValue(origLeft->getDataType(), 0.0);
- auto rightZero = builder->getFloatValue(origRight->getDataType(), 0.0);
if (diffLeft || diffRight)
{
- diffLeft = diffLeft ? diffLeft : leftZero;
- diffRight = diffRight ? diffRight : rightZero;
+ diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType());
+ diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType());
- auto resultType = origArith->getDataType();
+ auto resultType = primalArith->getDataType();
switch(origArith->getOp())
{
case kIROp_Add:
@@ -608,17 +925,36 @@ struct JVPTranscriber
return InstPair(primalArith, nullptr);
}
+
+ InstPair transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic)
+ {
+ SLANG_ASSERT(origLogic->getOperandCount() == 2);
+
+ // TODO: Check other boolean cases.
+ if (as<IRBoolType>(origLogic->getDataType()))
+ {
+ // Boolean operations are not differentiable. For the linearization
+ // pass, we do not need to do anything but copy them over to the ne
+ // function.
+ auto primalLogic = cloneInst(&cloneEnv, builder, origLogic);
+ return InstPair(primalLogic, nullptr);
+ }
+
+ SLANG_UNEXPECTED("Logical operation with non-boolean result");
+ }
+
InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
{
auto origPtr = origLoad->getPtr();
auto primalLoad = cloneInst(&cloneEnv, builder, origLoad);
+ IRInst* diffLoad = nullptr;
+
if (auto diffPtr = lookupDiffInst(origPtr, nullptr))
{
- IRLoad* diffLoad = as<IRLoad>(builder->emitLoad(diffPtr));
- SLANG_ASSERT(diffLoad);
-
+ // Default case, we're loading from a known differential inst.
+ diffLoad = as<IRLoad>(builder->emitLoad(diffPtr));
return InstPair(primalLoad, diffLoad);
}
return InstPair(primalLoad, nullptr);
@@ -634,15 +970,17 @@ struct JVPTranscriber
auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr);
auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr);
+ IRInst* diffStore = nullptr;
+
// If the stored value has a differential version,
// emit a store instruction for the differential parameter.
// Otherwise, emit nothing since there's nothing to load.
//
if (diffStoreLocation && diffStoreVal)
{
- IRStore* diffStore = as<IRStore>(
- builder->emitStore(diffStoreLocation, diffStoreVal));
- SLANG_ASSERT(diffStore);
+ // Default case, storing the entire type (and not a member)
+ diffStore = as<IRStore>(
+ builder->emitStore(diffStoreLocation, diffStoreVal));
return InstPair(primalStore, diffStore);
}
@@ -653,14 +991,31 @@ struct JVPTranscriber
InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn)
{
IRInst* origReturnVal = origReturn->getVal();
-
- if (auto pairType = tryGetDiffPairType(builder, origReturnVal->getDataType()))
+
+ auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType());
+ if (as<IRFunc>(origReturnVal) || as<IRGeneric>(origReturnVal) || as<IRStructType>(origReturnVal) || as<IRFuncType>(origReturnVal))
+ {
+ // If the return value is itself a function, generic or a struct then this
+ // is likely to be a generic scope. In this case, we lookup the differential
+ // and return that.
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+ IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
+
+ // Neither of these should be nullptr.
+ SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal);
+ IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal));
+
+ return InstPair(diffReturn, diffReturn);
+ }
+ else if (auto pairType = tryGetDiffPairType(builder, returnDataType))
{
IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
-
IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
if(!diffReturnVal)
- diffReturnVal = getZeroOfType(builder, origReturnVal->getDataType());
+ diffReturnVal = getDifferentialZeroOfType(builder, returnDataType);
+
+ // If the pair type can be formed, this must be non-null.
+ SLANG_RELEASE_ASSERT(diffReturnVal);
auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal);
IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair));
@@ -668,10 +1023,12 @@ struct JVPTranscriber
}
else
{
- // If the differential return value is not available, emit a
- // void return.
- IRInst* voidReturn = builder->emitReturn();
- return InstPair(voidReturn, voidReturn);
+ // If the return type is not differentiable, emit the primal value only.
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+
+ IRInst* primalReturn = builder->emitReturn(primalReturnVal);
+ return InstPair(primalReturn, nullptr);
+
}
}
@@ -682,15 +1039,43 @@ struct JVPTranscriber
InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct)
{
IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct);
+
+ // Check if the output type can be differentiated. If it cannot be
+ // differentiated, don't differentiate the inst
+ //
+ auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType());
+ if (auto diffConstructType = differentiateType(builder, primalConstructType))
+ {
+ UCount operandCount = origConstruct->getOperandCount();
- if (as<IRConstant>(origConstruct->getOperand(0)) && origConstruct->getOperandCount() == 1)
- return InstPair(primalConstruct, nullptr);
+ List<IRInst*> diffOperands;
+ for (UIndex ii = 0; ii < operandCount; ii++)
+ {
+ // If the operand has a differential version, replace the original with
+ // the differential. Otherwise, use a zero.
+ //
+ if (auto diffInst = lookupDiffInst(origConstruct->getOperand(ii), nullptr))
+ diffOperands.add(diffInst);
+ else
+ {
+ auto operandDataType = origConstruct->getOperand(ii)->getDataType();
+ operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType);
+ diffOperands.add(getDifferentialZeroOfType(builder, operandDataType));
+ }
+ }
+
+ return InstPair(
+ primalConstruct,
+ builder->emitIntrinsicInst(
+ diffConstructType,
+ origConstruct->getOp(),
+ operandCount,
+ diffOperands.getBuffer()));
+ }
else
- getSink()->diagnose(origConstruct->sourceLoc,
- Diagnostics::unimplemented,
- "this construct instruction cannot be differentiated");
-
- return InstPair(primalConstruct, nullptr);
+ {
+ return InstPair(primalConstruct, nullptr);
+ }
}
// Differentiating a call instruction here is primarily about generating
@@ -699,13 +1084,21 @@ struct JVPTranscriber
//
InstPair transcribeCall(IRBuilder* builder, IRCall* origCall)
{
- if (auto origCallee = as<IRFunc>(origCall->getCallee()))
+
+ if (as<IRFunc>(origCall->getCallee()))
{
-
+ auto origCallee = origCall->getCallee();
+
+ // Since concrete functions are globals, the primal callee is the same
+ // as the original callee.
+ //
+ auto primalCallee = origCallee;
+
+ // TODO: If inner is not differentiable, treat as non-differentiable call.
// Build the differential callee
IRInst* diffCall = builder->emitJVPDifferentiateInst(
- differentiateFunctionType(builder, as<IRFuncType>(origCallee->getFullType())),
- origCallee);
+ differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
+ primalCallee);
List<IRInst*> args;
// Go over the parameter list and create pairs for each input (if required)
@@ -715,17 +1108,17 @@ struct JVPTranscriber
auto primalArg = findOrTranscribePrimalInst(builder, origArg);
SLANG_ASSERT(primalArg);
- auto origType = origArg->getDataType();
- if (auto pairType = tryGetDiffPairType(builder, origType))
+ auto primalType = primalArg->getDataType();
+ if (auto pairType = tryGetDiffPairType(builder, primalType))
{
-
auto diffArg = findOrTranscribeDiffInst(builder, origArg);
- // TODO(sai): This part is flawed. Replace with a call to the
- // 'zero()' interface method.
if (!diffArg)
- diffArg = getZeroOfType(builder, origType);
+ diffArg = getDifferentialZeroOfType(builder, primalType);
+ // If a pair type can be formed, this must be non-null.
+ SLANG_RELEASE_ASSERT(diffArg);
+
auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
args.add(diffPair);
@@ -737,8 +1130,11 @@ struct JVPTranscriber
}
}
+ auto diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
+ SLANG_ASSERT(diffReturnType);
+
auto callInst = builder->emitCallInst(
- tryGetDiffPairType(builder, origCall->getFullType()),
+ diffReturnType,
diffCall,
args);
@@ -746,6 +1142,13 @@ struct JVPTranscriber
pairBuilder->emitPrimalFieldAccess(builder, callInst),
pairBuilder->emitDiffFieldAccess(builder, callInst));
}
+ else if(as<IRSpecialize>(origCall->getCallee()) ||
+ as<IRLookupWitnessMethod>(origCall->getCallee()))
+ {
+ getSink()->diagnose(origCall->sourceLoc,
+ Diagnostics::unimplemented,
+ "attempting to differentiate unspecialized callee or an interface method");
+ }
else
{
// Note that this can only happen if the callee is a result
@@ -774,7 +1177,7 @@ struct JVPTranscriber
return InstPair(
primalSwizzle,
builder->emitSwizzle(
- differentiateType(builder, origSwizzle->getDataType()),
+ differentiateType(builder, primalSwizzle->getDataType()),
diffBase,
origSwizzle->getElementCount(),
swizzleIndices.getBuffer()));
@@ -806,7 +1209,7 @@ struct JVPTranscriber
return InstPair(
primalInst,
builder->emitIntrinsicInst(
- differentiateType(builder, origInst->getDataType()),
+ differentiateType(builder, primalInst->getDataType()),
origInst->getOp(),
operandCount,
diffOperands.getBuffer()));
@@ -819,17 +1222,44 @@ struct JVPTranscriber
case kIROp_unconditionalBranch:
auto origBranch = as<IRUnconditionalBranch>(origInst);
- // Branches with extra operands not handled currently.
- if (origBranch->getOperandCount() > 1)
- break;
+ // Grab the differentials for any phi nodes.
+ List<IRInst*> pairArgs;
+ for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++)
+ {
+ auto origArg = origBranch->getArg(ii);
- IRInst* diffBranch = nullptr;
+ IRInst* pairArg = nullptr;
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)origArg->getDataType()))
+ {
+ auto diffArg = lookupDiffInst(origArg, nullptr);
+ if (!diffArg)
+ {
+ diffArg = getDifferentialZeroOfType(builder, (IRType*)origArg->getDataType());
+ }
+
+ pairArg = builder->emitMakeDifferentialPair(
+ diffPairType,
+ lookupPrimalInst(origArg),
+ diffArg);
+ }
+ else
+ {
+ pairArg = lookupPrimalInst(origArg);
+ }
+ pairArgs.add(pairArg);
+ }
- if (auto diffBlock = lookupDiffInst(origBranch->getTargetBlock(), nullptr))
- diffBranch = builder->emitBranch(as<IRBlock>(diffBlock));
+ IRInst* diffBranch = nullptr;
+ if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock()))
+ {
+ diffBranch = builder->emitBranch(
+ as<IRBlock>(diffBlock),
+ pairArgs.getCount(),
+ pairArgs.getBuffer());
+ }
// For now, every block in the original fn must have a corresponding
- // block to compute both primals and derivatives.
+ // block to compute *both* primals and derivatives (i.e linearized block)
SLANG_ASSERT(diffBranch);
return InstPair(diffBranch, diffBranch);
@@ -843,12 +1273,13 @@ struct JVPTranscriber
return InstPair(nullptr, nullptr);
}
-
InstPair transcribeConst(IRBuilder*, IRInst* origInst)
{
switch(origInst->getOp())
{
case kIROp_FloatLit:
+ case kIROp_VoidLit:
+ case kIROp_IntLit:
return InstPair(origInst, nullptr);
}
@@ -860,49 +1291,439 @@ struct JVPTranscriber
return InstPair(nullptr, nullptr);
}
+ InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
+ {
+ // This is slightly counter-intuitive, but we don't perform any differentiation
+ // logic here. We simple clone the original specialize which points to the original function,
+ // or the cloned version in case we're inside a generic scope.
+ // The differentiation logic is inserted later when this is used in an IRCall.
+ // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Specialize(Fn))
+ // rather than have Specialize(JVPDifferentiate(Fn))
+ //
+ auto diffSpecialize = cloneInst(&cloneEnv, builder, origSpecialize);
+ return InstPair(diffSpecialize, diffSpecialize);
+ }
+
+ InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup)
+ {
+ // This is slightly counter-intuitive, but we don't perform any differentiation
+ // logic here. We simple clone the original lookup which points to the original function,
+ // or the cloned version in case we're inside a generic scope.
+ // The differentiation logic is inserted later when this is used in an IRCall.
+ // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Lookup(Table))
+ // rather than have Lookup(JVPDifferentiate(Table))
+ //
+ auto diffLookup = cloneInst(&cloneEnv, builder, origLookup);
+ return InstPair(diffLookup, diffLookup);
+ }
+
// In differential computation, the 'default' differential value is always zero.
// This is a consequence of differential computing being inherently linear. As a
// result, it's useful to have a method to generate zero literals of any (arithmetic) type.
+ // The current implementation requires that types are defined linearly.
//
- IRInst* getZeroOfType(IRBuilder* builder, IRType* type)
+ IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
+ {
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ // Since primalType has a corresponding differential type, we can lookup the
+ // definition for zero().
+ auto zeroMethod = this->diffConformanceContext->getZeroMethodForType(builder, primalType);
+ SLANG_ASSERT(zeroMethod);
+
+ auto emptyArgList = List<IRInst*>();
+ return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
+ }
+ else
+ {
+ // We special case a few non-differentiable types that sometimes appear in places
+ // where we're forced to provide a differential zero value. For instance,
+ // float3(float, float, int) is accepted by the compiler, but is tricky in the context
+ // of differentiation since int is non-differentiable, and should be cast to float first.
+ // In the absence of such casts, this piece of code generates appropriate zero values.
+ //
+ switch (primalType->getOp())
+ {
+ case kIROp_IntType:
+ return builder->getIntValue(primalType, 0);
+ default:
+ getSink()->diagnose(primalType->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "could not generate zero value for given type");
+ return nullptr;
+ }
+ }
+ }
+
+ InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
+ {
+ auto oldLoc = builder->getInsertLoc();
+
+ IRInst* diffBlock = builder->emitBlock();
+
+ // Note: for blocks, we setup the mapping _before_
+ // processing the children since we could encounter
+ // a lookup while processing the children.
+ //
+ mapPrimalInst(origBlock, diffBlock);
+ mapDifferentialInst(origBlock, diffBlock);
+
+ builder->setInsertInto(diffBlock);
+
+ // First transcribe every parameter in the block.
+ for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
+ this->transcribe(builder, param);
+
+ // Look for the differentiable type dictionary and clone it (and anything else we might need).
+ // TODO: This logic might have issues if there are additional instructions (say lookup_interface_requirement)
+ // that are operands.
+ // TODO: This is currently cloning the global dictionary. Should only clone dictionaries in generic blocks.
+ if (auto origDict = builder->findDifferentiableTypeDictionary(origBlock))
+ {
+ auto clonedDict = cloneInst(&cloneEnv, builder, origDict);
+ mapPrimalInst(origDict, clonedDict);
+ mapDifferentialInst(origDict, clonedDict);
+ }
+
+ // Then, run through every instruction and use the transcriber to generate the appropriate
+ // derivative code.
+ //
+ for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
+ this->transcribe(builder, child);
+
+ builder->setInsertLoc(oldLoc);
+
+ return InstPair(diffBlock, diffBlock);
+ }
+
+ InstPair transcribeFieldExtract(IRBuilder* builder, IRFieldExtract* origExtract)
{
- switch (type->getOp())
+ IRInst* origBase = origExtract->getBase();
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto diffBase = findOrTranscribeDiffInst(builder, origBase);
+
+ auto primalExtractType = (IRType*)lookupPrimalInst(origExtract->getDataType(), origExtract->getDataType());
+
+ IRInst* primalExtract = builder->emitFieldExtract(primalExtractType, primalBase, origExtract->getField());
+ IRInst* diffExtract = nullptr;
+
+ if (auto diffExtractType = differentiateType(builder, primalExtractType))
{
- case kIROp_FloatType:
- case kIROp_HalfType:
- case kIROp_DoubleType:
- return builder->getFloatValue(type, 0.0);
- case kIROp_IntType:
- return builder->getIntValue(type, 0);
- case kIROp_VectorType:
+ // Check if we have a getter.
+ if (auto getterDecoration = origExtract->findDecoration<IRDifferentialGetterDecoration>())
{
- IRInst* args[] = {getZeroOfType(builder, as<IRVectorType>(type)->getElementType())};
- return builder->emitIntrinsicInst(
- type,
- kIROp_constructVectorFromScalar,
- 1,
+
+ IRInst* getterFunc = getterDecoration->getGetterFunc();
+
+ // Must be a method with a single parameter.
+ SLANG_ASSERT(as<IRFuncType>(getterFunc->getDataType())->getParamCount() == 1);
+
+ // Our getter func accepts a _pointer_ to the target type
+ // So we have to create a variable and store our type into memory
+ // here. This will eventually get optimized out in later passes.
+ //
+ auto diffTempVar = builder->emitVar(
+ diffBase->getDataType());
+
+ builder->emitStore(diffTempVar, diffBase);
+
+ List<IRInst*> args;
+ args.add(diffTempVar);
+
+ // Emit a call to the getter. The getter will return a reference type.
+ // We need to load from this to go to a non-ptr 'solid' type.
+ //
+ auto diffGetterCall = builder->emitCallInst(
+ as<IRFuncType>(getterFunc->getDataType())->getResultType(),
+ getterFunc,
args);
+
+ diffExtract = builder->emitLoad(diffGetterCall);
}
- default:
- getSink()->diagnose(type->sourceLoc,
- Diagnostics::internalCompilerError,
- "could not generate zero value for given type");
- return nullptr;
}
+
+ return InstPair(primalExtract, diffExtract);
+ }
+
+ InstPair transcribeFieldAddress(IRBuilder* builder, IRFieldAddress* origAddress)
+ {
+ IRInst* origBase = origAddress->getBase();
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto diffBase = findOrTranscribeDiffInst(builder, origBase);
+
+ auto primalAddressType = (IRType*)lookupPrimalInst(origAddress->getDataType(), origAddress->getDataType());
+
+ IRInst* primalAddress = builder->emitFieldAddress(primalAddressType, primalBase, origAddress->getField());
+ IRInst* diffAddress = nullptr;
+
+ if (auto diffAddressType = differentiateType(builder, primalAddressType))
+ {
+ // If we have a getter associated with this field, we want to use that.
+ if (auto getterDecoration = origAddress->findDecoration<IRDifferentialGetterDecoration>())
+ {
+ auto getterFunc = getterDecoration->getGetterFunc();
+
+ // Add the base differential inst as the argument.
+ List<IRInst*> args;
+ args.add(diffBase);
+
+ diffAddress = builder->emitCallInst(
+ as<IRFuncType>(getterFunc->getDataType())->getResultType(),
+ getterFunc,
+ args);
+ }
+
+ }
+
+ return InstPair(primalAddress, diffAddress);
+ }
+
+
+ InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr)
+ {
+ SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr));
+
+ IRInst* origBase = origGetElementPtr->getOperand(0);
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1));
+
+ auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType());
+
+ IRInst* primalOperands[] = {primalBase, primalIndex};
+ IRInst* primalGetElementPtr = builder->emitIntrinsicInst(
+ primalType,
+ origGetElementPtr->getOp(),
+ 2,
+ primalOperands);
+
+ IRInst* diffGetElementPtr = nullptr;
+
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
+ {
+ IRInst* diffOperands[] = {diffBase, primalIndex};
+ diffGetElementPtr = builder->emitIntrinsicInst(
+ diffType,
+ origGetElementPtr->getOp(),
+ 2,
+ diffOperands);
+ }
+ }
+
+ return InstPair(primalGetElementPtr, diffGetElementPtr);
+ }
+
+
+ InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop)
+ {
+ // The loop comes with three blocks.. we just need to transcribe each one
+ // and assemble the new loop instruction.
+
+ // Transcribe the target block (this is the 'condition' part of the loop, which
+ // will branch into the loop body)
+ auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock());
+
+ // Transcribe the break block (this is the block after the exiting the loop)
+ auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
+
+ // Transcribe the continue block (this is the 'update' part of the loop, which will
+ // branch into the condition block)
+ auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
+
+
+ List<IRInst*> diffLoopOperands;
+ diffLoopOperands.add(diffTargetBlock);
+ diffLoopOperands.add(diffBreakBlock);
+ diffLoopOperands.add(diffContinueBlock);
+
+ // If there are any other operands, use their primal versions.
+ for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++)
+ {
+ auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii));
+ diffLoopOperands.add(primalOperand);
+ }
+
+ IRInst* diffLoop = builder->emitIntrinsicInst(
+ nullptr,
+ kIROp_loop,
+ diffLoopOperands.getCount(),
+ diffLoopOperands.getBuffer());
+
+ return InstPair(diffLoop, diffLoop);
+ }
+
+ InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse)
+ {
+ // The loop comes with three blocks.. we just need to transcribe each one
+ // and assemble the new loop instruction.
+
+ // Transcribe the target block (this is the 'condition' part of the loop, which
+ // will branch into the loop body).
+ // Note that for the condition we use the primal inst (condition values should not have a
+ // differential)
+ auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition());
+ SLANG_ASSERT(primalConditionBlock);
+
+ // Transcribe the break block (this is the block after the exiting the loop)
+ auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock());
+ SLANG_ASSERT(diffTrueBlock);
+
+ // Transcribe the continue block (this is the 'update' part of the loop, which will
+ // branch into the condition block)
+ auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock());
+ SLANG_ASSERT(diffFalseBlock);
+
+ // Transcribe the continue block (this is the 'update' part of the loop, which will
+ // branch into the condition block)
+ auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock());
+ SLANG_ASSERT(diffAfterBlock);
+
+
+ List<IRInst*> diffIfElseArgs;
+ diffIfElseArgs.add(primalConditionBlock);
+ diffIfElseArgs.add(diffTrueBlock);
+ diffIfElseArgs.add(diffFalseBlock);
+ diffIfElseArgs.add(diffAfterBlock);
+
+ // If there are any other operands, use their primal versions.
+ for (UIndex ii = diffIfElseArgs.getCount(); ii < origIfElse->getOperandCount(); ii++)
+ {
+ auto primalOperand = findOrTranscribePrimalInst(builder, origIfElse->getOperand(ii));
+ diffIfElseArgs.add(primalOperand);
+ }
+
+ IRInst* diffLoop = builder->emitIntrinsicInst(
+ nullptr,
+ kIROp_ifElse,
+ diffIfElseArgs.getCount(),
+ diffIfElseArgs.getBuffer());
+
+ return InstPair(diffLoop, diffLoop);
+ }
+
+ // Transcribe a function definition.
+ InstPair transcribeFunc(IRBuilder* builder, IRFunc* origFunc)
+ {
+ IRFunc* primalFunc = nullptr;
+
+ auto oldLoc = builder->getInsertLoc();
+
+ // If this is a top-level function, there is no need to clone it
+ // since it is visible in all the scopes.
+ // Otherwise, we need to clone it in case of generic scopes.
+ //
+ // TODO(sai): Is this the correct thing to do? Can a function cloned inside a
+ // generic scope but is not the return value of that generic, be used within
+ // that scope? Or do we have to call out to the original generic specialized with
+ // the current generic params?
+ //
+ bool isTopLevelFunc = (as<IRModuleInst>(origFunc->parent) != nullptr);
+ if (isTopLevelFunc)
+ {
+ builder->setInsertBefore(origFunc);
+ primalFunc = origFunc;
+ }
+ else
+ {
+ // TODO(sai): this might never be called, and it might never make sense
+ // to call it either. Potentially remove this.
+ primalFunc = as<IRFunc>(
+ cloneInst(&cloneEnv, builder, origFunc));
+ }
+
+ auto diffFunc = builder->createFunc();
+
+ SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
+ IRType* diffFuncType = this->differentiateFunctionType(
+ builder,
+ as<IRFuncType>(origFunc->getFullType()));
+ diffFunc->setFullType(diffFuncType);
+
+ // TODO(sai): Replace naming scheme
+ // if (auto jvpName = this->getJVPFuncName(builder, primalFn))
+ // builder->addNameHintDecoration(diffFunc, jvpName);
+
+ // Transcribe children from origFunc into diffFunc
+ builder->setInsertInto(diffFunc);
+ for (auto block = origFunc->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribe(builder, block);
+
+ // Reset builder position
+ builder->setInsertLoc(oldLoc);
+
+ return InstPair(primalFunc, diffFunc);
+ }
+
+ // Transcribe a generic definition
+ InstPair transcribeGeneric(IRBuilder* builder, IRGeneric* origGeneric)
+ {
+ // For now, we assume there's only one generic layer. So this inst must be top level
+ bool isTopLevel = (as<IRModuleInst>(origGeneric->getParent()) != nullptr);
+ SLANG_RELEASE_ASSERT(isTopLevel);
+
+ IRGeneric* primalGeneric = origGeneric;
+
+ auto oldLoc = builder->getInsertLoc();
+ builder->setInsertBefore(origGeneric);
+
+ auto diffGeneric = builder->emitGeneric();
+
+ // Process type of generic. If the generic is a function, then it's type will also be a
+ // generic and this logic will transcribe that generic first before continuing with the
+ // function itself.
+ //
+ auto primalType = primalGeneric->getFullType();
+
+ IRType* diffType = nullptr;
+ if (primalType)
+ {
+ diffType = (IRType*) findOrTranscribeDiffInst(builder, primalType);
+ }
+
+ diffGeneric->setFullType(diffType);
+
+ // TODO(sai): Replace naming scheme
+ // if (auto jvpName = this->getJVPFuncName(builder, primalFn))
+ // builder->addNameHintDecoration(diffFunc, jvpName);
+
+ // Transcribe children from origFunc into diffFunc.
+ builder->setInsertInto(diffGeneric);
+ for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribe(builder, block);
+
+ // Reset builder position.
+ builder->setInsertLoc(oldLoc);
+
+ return InstPair(primalGeneric, diffGeneric);
}
IRInst* transcribe(IRBuilder* builder, IRInst* origInst)
{
+ // If a differential intstruction is already mapped for
+ // this original inst, return that.
+ //
+ if (auto diffInst = lookupDiffInst(origInst, nullptr))
+ {
+ SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check.
+ return diffInst;
+ }
+
+ // Otherwise, dispatch to the appropriate method
+ // depending on the op-code.
+ //
+ instsInProgress.Add(origInst);
InstPair pair = transcribeInst(builder, origInst);
if (auto primalInst = pair.primal)
{
mapPrimalInst(origInst, pair.primal);
-
mapDifferentialInst(origInst, pair.differential);
return pair.differential;
}
+ instsInProgress.Remove(origInst);
+
getSink()->diagnose(origInst->sourceLoc,
Diagnostics::internalCompilerError,
"failed to transcibe instruction");
@@ -911,7 +1732,7 @@ struct JVPTranscriber
InstPair transcribeInst(IRBuilder* builder, IRInst* origInst)
{
- // Handle common operations
+ // Handle common SSA-style operations
switch (origInst->getOp())
{
case kIROp_Param:
@@ -934,6 +1755,14 @@ struct JVPTranscriber
case kIROp_Sub:
case kIROp_Div:
return transcribeBinaryArith(builder, origInst);
+
+ case kIROp_Less:
+ case kIROp_Greater:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ return transcribeBinaryLogic(builder, origInst);
case kIROp_Construct:
return transcribeConstruct(builder, origInst);
@@ -945,24 +1774,91 @@ struct JVPTranscriber
return transcribeSwizzle(builder, as<IRSwizzle>(origInst));
case kIROp_constructVectorFromScalar:
+ case kIROp_MakeTuple:
return transcribeByPassthrough(builder, origInst);
case kIROp_unconditionalBranch:
- case kIROp_conditionalBranch:
return transcribeControlFlow(builder, origInst);
case kIROp_FloatLit:
+ case kIROp_IntLit:
+ case kIROp_VoidLit:
return transcribeConst(builder, origInst);
+ case kIROp_Specialize:
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::unexpected,
+ "should not be attempting to differentiate anything specialized here.");
+
+ case kIROp_lookup_interface_method:
+ return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
+
+ case kIROp_FieldExtract:
+ return transcribeFieldExtract(builder, as<IRFieldExtract>(origInst));
+
+ case kIROp_FieldAddress:
+ return transcribeFieldAddress(builder, as<IRFieldAddress>(origInst));
+
+ case kIROp_getElement:
+ case kIROp_getElementPtr:
+ return transcribeGetElement(builder, origInst);
+
+ case kIROp_loop:
+ return transcribeLoop(builder, as<IRLoop>(origInst));
+
+ case kIROp_ifElse:
+ return transcribeIfElse(builder, as<IRIfElse>(origInst));
+
+ case kIROp_DifferentiableTypeDictionary:
+ // Ignore dictionary insts.
+ return InstPair(nullptr, nullptr);
+
}
// If none of the cases have been hit, check if the instruction is a
- // type.
- // For now we don't have logic to differentiate types that appear in blocks.
- // So, we clone and avoid differentiating them.
- //
+ // type. Only need to explicitly differentiate types if they appear inside a block.
+ //
if (auto origType = as<IRType>(origInst))
- return InstPair(cloneInst(&cloneEnv, builder, origType), nullptr);
+ {
+ // If this is a generic type, transcibe the parent
+ // generic and derive the type from the transcribed generic's
+ // return value.
+ //
+ if (as<IRGeneric>(origType->getParent()->getParent()) &&
+ findInnerMostGenericReturnVal(as<IRGeneric>(origType->getParent()->getParent())) == origType &&
+ !instsInProgress.Contains(origType->getParent()->getParent()))
+ {
+ auto origGenericType = origType->getParent()->getParent();
+ auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType);
+ auto innerDiffGenericType = findInnerMostGenericReturnVal(as<IRGeneric>(diffGenericType));
+ return InstPair(
+ origGenericType,
+ innerDiffGenericType
+ );
+ }
+ else if (as<IRBlock>(origType->getParent()))
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origType),
+ differentiateType(builder, origType));
+ else
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origType),
+ nullptr);
+ }
+
+ // Handle instructions with children
+ switch (origInst->getOp())
+ {
+ case kIROp_Func:
+ return transcribeFunc(builder, as<IRFunc>(origInst));
+
+ case kIROp_Block:
+ return transcribeBlock(builder, as<IRBlock>(origInst));
+
+ case kIROp_Generic:
+ return transcribeGeneric(builder, as<IRGeneric>(origInst));
+ }
+
// If we reach this statement, the instruction type is likely unhandled.
getSink()->diagnose(origInst->sourceLoc,
@@ -1042,6 +1938,14 @@ struct JVPDerivativeContext
// IRMakeDifferentialPair with an IRMakeStruct.
//
modified |= processPairTypes(builder, module->getModuleInst(), (&diffConformanceContextStorage));
+
+ // Temporary fix: Move generated types, if any, to before their use locations.
+ (&pairBuilderStorage)->relocateNewTypes(builder);
+
+ // Remove all kIROp_DifferentiableTypeDictionary instructions and
+ // kIROp_DifferentialGetterDecoration decorations
+ //
+ modified |= stripDiffTypeInformation(builder, module->getModuleInst());
return modified;
}
@@ -1079,19 +1983,45 @@ struct JVPDerivativeContext
if (auto jvpDiffInst = as<IRJVPDifferentiate>(child))
{
- auto baseFunction = jvpDiffInst->getBaseFn();
+ auto baseInst = jvpDiffInst->getBaseFn();
+
+ IRGlobalValueWithCode* baseFunction = nullptr;
+
+ if (auto specializeInst = as<IRSpecialize>(baseInst))
+ {
+ baseFunction = as<IRGlobalValueWithCode>(specializeInst->getBase());
+ }
+ else if (auto globalValWithCode = as<IRGlobalValueWithCode>(baseInst))
+ {
+ baseFunction = globalValWithCode;
+ }
+
+ SLANG_ASSERT(baseFunction);
+
// If the JVP Reference already exists, no need to
// differentiate again.
//
- if(lookupJVPReference(baseFunction)) continue;
+ if (lookupJVPReference(baseFunction)) continue;
- if (isFunctionMarkedForJVP(as<IRGlobalValueWithCode>(baseFunction)))
+ if (isMarkedForJVP(baseFunction))
{
- IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(baseFunction));
- builder->addJVPDerivativeReferenceDecoration(baseFunction, jvpFunction);
- workQueue->push(jvpFunction);
+ if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
+ {
+ IRInst* diffFunc = (&transcriberStorage)->transcribe(builder, baseFunction);
+ SLANG_ASSERT(diffFunc);
+ builder->addJVPDerivativeReferenceDecoration(baseFunction, diffFunc);
+ workQueue->push(diffFunc);
+ }
+ else
+ {
+ // TODO(Sai): This would probably be better with a more specific
+ // error code.
+ getSink()->diagnose(jvpDiffInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Unexpected instruction. Expected func or generic");
+ }
}
- else
+ else
{
// TODO(Sai): This would probably be better with a more specific
// error code.
@@ -1106,55 +2036,33 @@ struct JVPDerivativeContext
return true;
}
- // Run through all the global-level instructions,
- // looking for callables.
- // Note: We're only processing global callables (IRGlobalValueWithCode)
- // for now.
- //
- bool processMarkedGlobalFunctions(IRBuilder* builder)
+ IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext*)
{
- for (auto inst : module->getGlobalInsts())
+
+ if (auto pairType = as<IRDifferentialPairType>(type))
{
- // If the instr is a callable, get all the basic blocks
- if (auto callable = as<IRGlobalValueWithCode>(inst))
- {
- if (isFunctionMarkedForJVP(callable))
- {
- SLANG_ASSERT(as<IRFunc>(callable));
+ builder->setInsertBefore(pairType);
- IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(callable));
- builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction);
+ auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType(
+ builder,
+ pairType->getValueType());
- unmarkForJVP(callable);
- }
- }
- }
- return true;
- }
+ pairType->replaceUsesWith(diffPairStructType);
+ pairType->removeAndDeallocate();
- IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext* diffContext)
- {
- if (diffContext->isInterfaceAvailable)
+ return diffPairStructType;
+ }
+ else if (auto loweredStructType = as<IRStructType>(type))
{
- if (auto pairType = as<IRDifferentialPairType>(type))
- {
- builder->setInsertBefore(pairType);
-
- auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType(
- builder,
- pairType->getValueType());
-
- pairType->replaceUsesWith(diffPairStructType);
- pairType->removeAndDeallocate();
-
- return diffPairStructType;
- }
- else if (auto loweredStructType = as<IRStructType>(type))
- {
- // Already lowered to struct.
- return loweredStructType;
- }
+ // Already lowered to struct.
+ return loweredStructType;
}
+ else if (auto specializedStructType = as<IRSpecialize>(type))
+ {
+ // Already lowered to specialized struct.
+ return specializedStructType;
+ }
+
return nullptr;
}
@@ -1171,7 +2079,7 @@ struct JVPDerivativeContext
operands.add(makePairInst->getPrimalValue());
operands.add(makePairInst->getDifferentialValue());
- auto makeStructInst = builder->emitMakeStruct(as<IRStructType>(diffPairStructType), operands);
+ auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands);
makePairInst->replaceUsesWith(makeStructInst);
makePairInst->removeAndDeallocate();
@@ -1258,10 +2166,43 @@ struct JVPDerivativeContext
return modified;
}
+ bool stripDiffTypeInformation(IRBuilder* builder, IRInst* parent)
+ {
+ bool modified = false;
+
+ auto child = parent->getFirstChild();
+ while (child)
+ {
+ auto nextChild = child->getNextInst();
+
+ if (child->getOp() == kIROp_DifferentiableTypeDictionary)
+ {
+ child->removeAndDeallocate();
+ child = nextChild;
+ modified = true;
+ continue;
+ }
+
+ if (auto getterDecoration = child->findDecoration<IRDifferentialGetterDecoration>())
+ {
+ getterDecoration->removeAndDeallocate();
+ }
+
+ if (child->getFirstChild() != nullptr)
+ {
+ modified |= stripDiffTypeInformation(builder, child);
+ }
+
+ child = nextChild;
+ }
+
+ return modified;
+ }
+
// Checks decorators to see if the function should
// be differentiated (kIROp_JVPDerivativeMarkerDecoration)
//
- bool isFunctionMarkedForJVP(IRGlobalValueWithCode* callable)
+ bool isMarkedForJVP(IRGlobalValueWithCode* callable)
{
for(auto decoration = callable->getFirstDecoration();
decoration;
@@ -1292,63 +2233,8 @@ struct JVPDerivativeContext
}
}
- List<IRParam*> emitFuncParameters(IRBuilder* builder, IRFuncType* dataType)
- {
- List<IRParam*> params;
- for(UIndex i = 0; i < dataType->getParamCount(); i++)
- {
- params.add(
- builder->emitParam(dataType->getParamType(i)));
- }
- return params;
- }
-
- // Perform forward-mode automatic differentiation on
- // the intstructions.
- //
- IRFunc* emitJVPFunction(IRBuilder* builder,
- IRFunc* primalFn)
- {
- eliminatePhisInFunc(LivenessMode::Disabled, module, primalFn);
-
- builder->setInsertBefore(primalFn->getNextInst());
-
- auto jvpFn = builder->createFunc();
-
- SLANG_ASSERT(as<IRFuncType>(primalFn->getFullType()));
- IRType* jvpFuncType = transcriberStorage.differentiateFunctionType(
- builder,
- as<IRFuncType>(primalFn->getFullType()));
- jvpFn->setFullType(jvpFuncType);
-
- if (auto jvpName = getJVPFuncName(builder, primalFn))
- builder->addNameHintDecoration(jvpFn, jvpName);
-
- builder->setInsertInto(jvpFn);
-
- // Emit a block instruction for every block in the function, and map it as the
- // corresponding differential.
- //
- for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
- {
- auto jvpBlock = builder->emitBlock();
- transcriberStorage.mapDifferentialInst(block, jvpBlock);
- transcriberStorage.mapPrimalInst(block, jvpBlock);
- }
-
- // Go back over the blocks, and process the children of each block.
- for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
- {
- auto jvpBlock = as<IRBlock>(transcriberStorage.lookupDiffInst(block, block));
- SLANG_ASSERT(jvpBlock);
- emitJVPBlock(builder, block, jvpBlock);
- }
-
- return jvpFn;
- }
-
IRStringLit* getJVPFuncName(IRBuilder* builder,
- IRFunc* func)
+ IRInst* func)
{
auto oldLoc = builder->getInsertLoc();
builder->setInsertBefore(func);
@@ -1368,36 +2254,6 @@ struct JVPDerivativeContext
return name;
}
- IRBlock* emitJVPBlock(IRBuilder* builder,
- IRBlock* origBlock,
- IRBlock* jvpBlock = nullptr)
- {
- JVPTranscriber* transcriber = &(transcriberStorage);
-
- // Create if not already created, and then insert into new block.
- if (!jvpBlock)
- jvpBlock = builder->emitBlock();
- else
- builder->setInsertInto(jvpBlock);
-
-
- // First transcribe every parameter in the block.
- for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
- {
- transcriber->transcribe(builder, param);
- }
-
- // Then, run through every instruction and use the transcriber to generate the appropriate
- // derivative code.
- //
- for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
- {
- transcriber->transcribe(builder, child);
- }
-
- return jvpBlock;
- }
-
JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) :
module(module), sink(sink),
diffConformanceContextStorage(module->getModuleInst()),
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 8f8261af5..f91fc9cda 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -706,6 +706,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// generated derivative function.
INST(JVPDerivativeReferenceDecoration, jvpFnReference, 1, 0)
+ /// Used by the auto-diff pass to hold a reference to a
+ /// differential getter associated with this expression.
+ INST(DifferentialGetterDecoration, diffGetter, 1, 0)
+
/// Marks a class type as a COM interface implementation, which enables
/// the witness table to be easily picked up by emit.
INST(COMWitnessDecoration, COMWitnessDecoration, 1, 0)
@@ -805,6 +809,10 @@ INST(GenericSpecializationDictionary, GenericSpecializationDictionary, 0, PARENT
INST(ExistentialFuncSpecializationDictionary, ExistentialFuncSpecializationDictionary, 0, PARENT)
INST(ExistentialTypeSpecializationDictionary, ExistentialTypeSpecializationDictionary, 0, PARENT)
+/* Differentiable Type Dictionary */
+INST(DifferentiableTypeDictionary, DifferentiableTypeDictionary, 0, PARENT)
+INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0)
+
#undef PARENT
#undef USE_OTHER
#undef INST_RANGE
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 98bc6a0a2..33a2fbfb0 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -554,9 +554,19 @@ struct IRJVPDerivativeReferenceDecoration : IRDecoration
};
IR_LEAF_ISA(JVPDerivativeReferenceDecoration)
- IRFunc* getJVPFunc() { return as<IRFunc>(getOperand(0)); }
+ IRInst* getJVPFunc() { return getOperand(0); }
};
+struct IRDifferentialGetterDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_DifferentialGetterDecoration
+ };
+ IR_LEAF_ISA(DifferentialGetterDecoration)
+
+ IRInst* getGetterFunc() { return getOperand(0); }
+};
// An instruction that replaces the function symbol
// with it's derivative function.
@@ -573,6 +583,15 @@ struct IRJVPDifferentiate : IRInst
IR_LEAF_ISA(JVPDifferentiate)
};
+// Dictionary item mapping a type with a corresponding
+// IDifferentiable witness table
+//
+struct IRDifferentiableTypeDictionaryItem : IRInst
+{
+ IR_LEAF_ISA(DifferentiableTypeDictionaryItem)
+};
+
+
// An instruction that specializes another IR value
// (representing a generic) to a particular set of generic arguments
// (instructions representing types, witness tables, etc.)
@@ -2462,6 +2481,27 @@ public:
IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential);
+ // Emit and return a dictionary instruction to the global or generic scope.
+ IRInst* emitDifferentiableTypeDictionary();
+
+ // Emit and return a dictionary instruction to the global or generic scope,
+ // if one is not already present.
+ //
+ IRInst* findOrEmitDifferentiableTypeDictionary();
+
+ // Returns the IRDifferentiableTypeDictionary in the scope of inst.
+ IRInst* findDifferentiableTypeDictionary(IRInst* inst);
+
+ // Add a differentiable type entry to the appropriate dictionary.
+ IRInst* addDifferentiableTypeEntry(IRInst* irType, IRInst* conformanceWitness);
+
+ // Lookup a differentiable type entry in the appropriate dictionary.
+ // This recursively looks up in upper contexts.
+ //
+ IRInst* findDifferentiableTypeEntry(IRInst* irType);
+
+ IRInst* findDifferentiableTypeEntry(IRInst* irType, IRInst* scope);
+
IRInst* emitSpecializeInst(
IRType* type,
IRInst* genericVal,
@@ -3162,6 +3202,11 @@ public:
addDecoration(value, kIROp_JVPDerivativeReferenceDecoration, jvpFn);
}
+ void addDifferentialGetterDecoration(IRInst* value, IRInst* getterFn)
+ {
+ addDecoration(value, kIROp_DifferentialGetterDecoration, getterFn);
+ }
+
void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable)
{
addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1);
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index a5130e8b6..56688abae 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -238,6 +238,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
case kIROp_WitnessTable:
case kIROp_InterfaceType:
case kIROp_TaggedUnionType:
+ case kIROp_DifferentiableTypeDictionary:
return cloneGlobalValue(this, originalValue);
case kIROp_BoolLit:
@@ -592,6 +593,24 @@ IRWitnessTable* cloneWitnessTableImpl(
return clonedTable;
}
+IRInst* cloneDifferentiableTypeDictionary(
+ IRSpecContextBase* context,
+ IRBuilder* builder,
+ IRInst* originalDict,
+ IROriginalValuesForClone const& originalValues,
+ IRInst* dstDict = nullptr,
+ bool registerValue = true)
+{
+ IRInst* clonedDict = dstDict;
+ if (!clonedDict)
+ {
+ clonedDict = builder->emitDifferentiableTypeDictionary();
+ }
+ cloneSimpleGlobalValueImpl(context, originalDict, originalValues, clonedDict, registerValue);
+ return clonedDict;
+}
+
+
IRWitnessTable* cloneWitnessTableWithoutRegistering(
IRSpecContextBase* context,
IRBuilder* builder,
@@ -1118,6 +1137,9 @@ IRInst* cloneInst(
case kIROp_GlobalGenericParam:
return cloneGlobalGenericParamImpl(context, builder, cast<IRGlobalGenericParam>(originalInst), originalValues);
+
+ case kIROp_DifferentiableTypeDictionary:
+ return cloneDifferentiableTypeDictionary(context, builder, originalInst, originalValues);
default:
break;
@@ -1504,11 +1526,14 @@ LinkedIR linkIR(
{
for (auto inst : irModule->getGlobalInsts())
{
- auto bindInst = as<IRBindGlobalGenericParam>(inst);
- if (!bindInst)
- continue;
-
- cloneValue(context, bindInst);
+ if (auto bindInst = as<IRBindGlobalGenericParam>(inst))
+ {
+ cloneValue(context, bindInst);
+ }
+ else if (inst->getOp() == kIROp_DifferentiableTypeDictionary)
+ {
+ cloneValue(context, inst);
+ }
}
}
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index a496db3a8..05be164d4 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -318,10 +318,17 @@ IRInst* applyAccessChain(
auto fieldKey = accessChain->getOperand(1);
auto type = cast<IRPtrTypeBase>(accessChain->getDataType())->getValueType();
auto baseValue = applyAccessChain(context, builder, baseChain, leafVarValue);
- return builder->emitFieldExtract(
+ auto extractInst = builder->emitFieldExtract(
type,
baseValue,
fieldKey);
+
+ for (auto decoration : accessChain->getDecorations())
+ {
+ cloneDecoration(decoration, extractInst);
+ }
+
+ return extractInst;
}
case kIROp_getElementPtr:
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 46d6d445d..2aaeb4ac3 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3547,6 +3547,125 @@ namespace Slang
}
}
+
+ IRInst* IRBuilder::emitDifferentiableTypeDictionary()
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_DifferentiableTypeDictionary,
+ nullptr);
+
+ addGlobalValue(this, inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::findOrEmitDifferentiableTypeDictionary()
+ {
+ auto currentLoc = this->getInsertLoc();
+ auto currentInst = currentLoc.getInst();
+
+ if (auto diffTypeDictionary = findDifferentiableTypeDictionary(currentInst))
+ return diffTypeDictionary;
+
+ return emitDifferentiableTypeDictionary();
+ }
+
+ IRInst* IRBuilder::findDifferentiableTypeDictionary(IRInst* parent)
+ {
+ //auto parent = inst->getParent();
+ while (parent)
+ {
+ // Inserting into the top level of a module?
+ // That is fine, and we can stop searching.
+ if (as<IRModuleInst>(parent))
+ break;
+
+ // Inserting into a basic block inside of
+ // a generic? That is okay too.
+ if (auto block = as<IRBlock>(parent))
+ {
+ if (as<IRGeneric>(block->parent))
+ break;
+ }
+
+ // Otherwise, move up the chain.
+ parent = parent->parent;
+ }
+
+ for (auto child = parent->getFirstChild(); child; child = child->getNextInst())
+ {
+ if (child->getOp() == kIROp_DifferentiableTypeDictionary)
+ return child;
+ }
+
+ return nullptr;
+ }
+
+ IRInst* IRBuilder::addDifferentiableTypeEntry(IRInst* irType, IRInst* conformanceWitness)
+ {
+ auto oldLoc = this->getInsertLoc();
+
+ IRDifferentiableTypeDictionaryItem* item = nullptr;
+
+ if (auto diffTypeDictionary = findOrEmitDifferentiableTypeDictionary())
+ {
+ this->setInsertInto(diffTypeDictionary);
+
+ IRInst* args[2] = {irType, conformanceWitness};
+ item = createInstWithTrailingArgs<IRDifferentiableTypeDictionaryItem>(
+ this,
+ kIROp_DifferentiableTypeDictionaryItem,
+ nullptr,
+ 2,
+ args);
+
+ addInst(item);
+ }
+
+ this->setInsertLoc(oldLoc);
+
+ return item;
+ }
+
+ IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType, IRInst* scope)
+ {
+ for (auto child = scope->getFirstChild(); child; child = child->getNextInst())
+ {
+ if (child->getOp() == kIROp_DifferentiableTypeDictionary)
+ {
+ for (auto entry = child->getFirstChild(); entry; entry = entry->getNextInst())
+ {
+ IRInst* entryType = entry->getOperand(0);
+ IRInst* entryConformanceWitness = entry->getOperand(1);
+
+ if (irType == entryType)
+ {
+ return entryConformanceWitness;
+ }
+ }
+ }
+ }
+
+ return nullptr;
+ }
+
+ IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType)
+ {
+ auto instScope = this->getInsertLoc().getInst();
+
+ while (instScope)
+ {
+ if (auto witness = findDifferentiableTypeEntry(irType, instScope))
+ {
+ return witness;
+ }
+ instScope = instScope->getParent();
+ }
+
+ return nullptr;
+ }
+
+
IRFunc* IRBuilder::createFunc()
{
IRFunc* rsFunc = createInst<IRFunc>(
@@ -6322,6 +6441,37 @@ namespace Slang
return inst;
}
+ IRInst* findOuterGeneric(IRInst* inst)
+ {
+ if (inst)
+ {
+ inst = inst->getParent();
+ }
+ else
+ {
+ return nullptr;
+ }
+
+ while(inst)
+ {
+ if (as<IRGeneric>(inst))
+ return inst;
+
+ inst = inst->getParent();
+ }
+ return nullptr;
+ }
+
+ IRInst* findOuterMostGeneric(IRInst* inst)
+ {
+ IRInst* currInst = inst;
+ while(auto outerGeneric = findOuterGeneric(currInst))
+ {
+ currInst = outerGeneric;
+ }
+ return currInst;
+ }
+
IRGeneric* findSpecializedGeneric(IRSpecialize* specialize)
{
return as<IRGeneric>(specialize->getBase());
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index c48f4b378..a2fb1be98 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1723,6 +1723,12 @@ IRInst* findGenericReturnVal(IRGeneric* generic);
// Recursively find the inner most generic return value.
IRInst* findInnerMostGenericReturnVal(IRGeneric* generic);
+// Find the generic container, if any, that this inst is contained in
+// Returns nullptr if there is no outer container.
+IRInst* findOuterGeneric(IRInst* inst);
+// Recursively find the outer most generic container.
+IRInst* findOuterMostGeneric(IRInst* inst);
+
struct IRSpecialize;
IRGeneric* findSpecializedGeneric(IRSpecialize* specialize);
IRInst* findSpecializeReturnVal(IRSpecialize* specialize);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index b03f3ae62..dc6067868 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1146,10 +1146,6 @@ static void addLinkageDecoration(
{
builder->addExternCppDecoration(inst, mangledName);
}
- if (decl->findModifier<JVPDerivativeModifier>())
- {
- builder->addJVPDerivativeMarkerDecoration(inst);
- }
if (as<InterfaceDecl>(decl->parentDecl) &&
decl->parentDecl->hasModifier<ComInterfaceAttribute>())
{
@@ -3042,6 +3038,38 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
return info;
}
+ LoweredValInfo visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr)
+ {
+ LoweredValInfo info = lowerSubExpr(expr->inner);
+
+ IRInst* irBaseVal = nullptr;
+ switch (info.flavor)
+ {
+ case LoweredValInfo::Flavor::Simple:
+ irBaseVal = getSimpleVal(context, info);
+ break;
+
+ case LoweredValInfo::Flavor::Ptr:
+ irBaseVal = info.val;
+ break;
+
+ default:
+ SLANG_UNEXPECTED("Unhandled lowered value cases");
+ }
+
+ // If the differentiable expr has an associated getter or setter, lower it
+ // and put it in a decoration.
+ //
+ if (expr->getterExpr != nullptr)
+ {
+ auto irGetter = lowerSubExpr(expr->getterExpr);
+ SLANG_ASSERT(irGetter.flavor == LoweredValInfo::Flavor::Simple);
+ getBuilder()->addDifferentialGetterDecoration(irBaseVal, irGetter.val);
+ }
+
+ return info;
+ }
+
// Emit IR to denote the forward-mode derivative
// of the inner func-expr. This will be resolved
// to a concrete function during the derivative
@@ -5844,6 +5872,45 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo();
}
+ LoweredValInfo visitDifferentiableTypeDictionary(DifferentiableTypeDictionary* decl)
+ {
+ for (auto & member : decl->members)
+ {
+ if (auto entry = as<DifferentiableTypeDictionaryItem>(member))
+ {
+
+ // Lower type and witness.
+ IRType* irType = lowerType(context, entry->baseType);
+ IRInst* irWitness = lowerVal(context, entry->confWitness).val;
+
+ SLANG_ASSERT(irType);
+
+ // If the witness can be lowered, and the differentiable type entry exists,
+ // add an entry to the context.
+ //
+ if (irWitness && !getBuilder()->findDifferentiableTypeEntry(irType))
+ getBuilder()->addDifferentiableTypeEntry(irType, irWitness);
+ }
+ else if (auto importEntry = as<DifferentiableTypeDictionaryImportItem>(member))
+ {
+ ensureDecl(context, importEntry->dictionaryRef.getDecl());
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unrecognized item in DifferentiableTypeDictionary");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+ }
+
+ if (auto diffTypeDict = getBuilder()->findOrEmitDifferentiableTypeDictionary())
+ {
+ // Place the dictionary at the end of modules and generic blocks.
+ diffTypeDict->moveToEnd();
+ }
+
+ return LoweredValInfo();
+ }
+
#define IGNORED_CASE(NAME) \
LoweredValInfo visit##NAME(NAME*) { return LoweredValInfo(); }
@@ -5853,6 +5920,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
IGNORED_CASE(SyntaxDecl)
IGNORED_CASE(AttributeDecl)
IGNORED_CASE(NamespaceDecl)
+ IGNORED_CASE(DifferentiableTypeDictionaryItem)
#undef IGNORED_CASE
@@ -6130,7 +6198,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, nullptr);
// Register the value now, rather than later, to avoid any possible infinite recursion.
- setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(irWitnessTable));
+ setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));
auto irSubType = lowerType(subContext, subType);
irWitnessTable->setOperand(0, irSubType);
@@ -7219,6 +7287,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
+ // We only need dictionaries to be lowered for decls with executable code (i.e. statements)
+ // Do not lower type dictionaries for inhertiance decls or decls
+ // that are declaring a type, since this can create a cyclic dependancy.
+ //
+ if (as<FunctionDeclBase>(leafDecl))
+ {
+ for (auto diffTypeDict : genericDecl->getMembersOfType<DifferentiableTypeDictionary>())
+ {
+ // We directly use lowerDecl() instead of ensureDecl() to emit to
+ // the current generic block instead of the top-level module.
+ //
+ lowerDecl(subContext, diffTypeDict);
+ }
+ }
+
return irGeneric;
}
@@ -7372,6 +7455,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), genericParam);
}
+
+ // Add a differentiable type dictionary if necessary.
+ if (auto diffTypeDict = subBuilder->findDifferentiableTypeDictionary(parentGeneric->getFirstBlock()))
+ markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), diffTypeDict);
}
if (valuesToClone.Count() == 0)
{
@@ -7723,6 +7810,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
addNameHint(context, irFunc, decl);
addLinkageDecoration(context, irFunc, decl);
+ if (decl->findModifier<JVPDerivativeModifier>())
+ {
+ getBuilder()->addJVPDerivativeMarkerDecoration(irFunc);
+ }
+
FuncDeclBaseTypeInfo info;
_lowerFuncDeclBaseTypeInfo(
subContext,
@@ -8788,15 +8880,6 @@ RefPtr<IRModule> generateIRForTranslationUnit(
// temporaries whenever possible.
constructSSA(module);
- // Process higher-order-function calls before any optimization passes
- // to allow the optimizations to affect the generated funcitons.
- // 1. Process JVP derivative functions.
- processJVPDerivativeMarkers(module, compileRequest->getSink());
- // 2. Process VJP derivative functions.
- // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet.
- // 3. Replace JVP & VJP calls.
- processDerivativeCalls(module);
-
// Do basic constant folding and dead code elimination
// using Sparse Conditional Constant Propagation (SCCP)
//