diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-10-20 14:22:00 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-20 11:22:00 -0700 |
| commit | 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch) | |
| tree | e85158637680f783caaf7f4433a6844398cd8f7b | |
| parent | 576c8407e60143682cd40c68101c6eae8563ca3d (diff) | |
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions.
* Changed a few asserts to release asserts to avoid unreferenced variable errors
* Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl
* Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types.
* Moved the auto-diff passes to after the specialization step, added a more complex generics test
* Added a generics stress test and fixed AST-side logic. IR side needs some more work
* Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions
* Changed differential getters to use pointer types, added getter type checking
* Fixed some bugs related to diff type registration and differential getters
* Removed some superfluous code
* Removed some more unused code.
* Fixed an issue with witness substitution
* Minor fix
Co-authored-by: Yong He <yonghe@outlook.com>
40 files changed, 3500 insertions, 477 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index e604140ae..26fec224c 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -8,18 +8,118 @@ syntax __differentiate_jvp : JVPDerivativeModifier; __attributeTarget(FuncDecl) attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute; -//@ public: - - /// Interface to denote types as differentiable. - /// Allows for user-specified differential types as - /// well as automatic generation, for when the associated type - /// hasn't been declared explicitly. +/// Interface to denote types as differentiable. +/// Allows for user-specified differential types as +/// well as automatic generation, for when the associated type +/// hasn't been declared explicitly. +/// Note that the requirements must currently be defined in this exact order +/// since the auto-diff pass relies on the order to grab the struct keys. +/// __magic_type(DifferentiableType) interface IDifferentiable { associatedtype Differential; + + static Differential zero(); + + static Differential dadd(Differential, Differential); + + static Differential dmul(This, Differential); }; +// Add extensions for the standard types +extension float : IDifferentiable +{ + typedef float Differential; + + [__unsafeForceInlineEarly] + static Differential zero() + { + return 0.f; + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return a * b; + } +} + +extension vector<float, 3> : IDifferentiable +{ + typedef vector<float, 3> Differential; + + [__unsafeForceInlineEarly] + static Differential zero() + { + return vector<float, 3>(0.f); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return a * b; + } +} + +extension vector<float, 2> : IDifferentiable +{ + typedef vector<float, 2> Differential; + + [__unsafeForceInlineEarly] + static Differential zero() + { + return vector<float, 2>(0.f); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return a * b; + } +} + +extension vector<float, 4> : IDifferentiable +{ + typedef vector<float, 4> Differential; + + [__unsafeForceInlineEarly] + static Differential zero() + { + return vector<float, 4>(0.f); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return a * b; + } +} + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. __generic<T : IDifferentiable> @@ -47,24 +147,3 @@ struct __DifferentialPair return p(); } }; - -// Add extensions for the standard types -extension float : IDifferentiable -{ - typedef float Differential; -} - -extension vector<float, 3> : IDifferentiable -{ - typedef vector<float, 3> Differential; -} - -extension vector<float, 2> : IDifferentiable -{ - typedef vector<float, 2> Differential; -} - -extension vector<float, 4> : IDifferentiable -{ - typedef vector<float, 4> Differential; -} diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index f8c208ac1..f6c550d69 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -179,6 +179,18 @@ Decl* SharedASTBuilder::findMagicDecl(const String& name) return m_magicDecls[name].GetValue(); } +Decl* SharedASTBuilder::tryFindMagicDecl(const String& name) +{ + if (m_magicDecls.ContainsKey(name)) + { + return m_magicDecls[name].GetValue(); + } + else + { + return nullptr; + } +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name): @@ -308,6 +320,11 @@ DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterface() return declRef; } +bool ASTBuilder::isDifferentiableInterfaceAvailable() +{ + return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr); +} + DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg) { DeclRef<Decl> declRef; diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 91fe63c88..e4ea872a0 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -45,6 +45,8 @@ public: // Look up a magic declaration by its name Decl* findMagicDecl(String const& name); + Decl* tryFindMagicDecl(String const& name); + /// A name pool that can be used for lookup for findClassInfo etc. It is the same pool as the Session. NamePool* getNamePool() { return m_namePool; } @@ -328,6 +330,8 @@ public: DeclRef<InterfaceDecl> getDifferentiableInterface(); + bool isDifferentiableInterfaceAvailable(); + DeclRef<Decl> getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg); Type* getAndType(Type* left, Type* right); diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 147bc7d22..07cfe6a0c 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -494,6 +494,35 @@ class AttributeDecl : public ContainerDecl SyntaxClass<NodeBase> syntaxClass; }; +// A declaration to hold differentiable type conformances generated during +// the semantic checking phase. +// +class DifferentiableTypeDictionary : public ContainerDecl +{ + SLANG_AST_CLASS(DifferentiableTypeDictionary); +}; + +// A declaration to hold differentiable type conformances generated during +// the semantic checking phase. +// +class DifferentiableTypeDictionaryItem : public Decl +{ + SLANG_AST_CLASS(DifferentiableTypeDictionaryItem); + + DeclRefType* baseType; + SubtypeWitness* confWitness; +}; + +// A declaration that references another dictionary (generally from another module) +// Used to tell the IR lowering pass to process the referenced dictionary. +// +class DifferentiableTypeDictionaryImportItem : public Decl +{ + SLANG_AST_CLASS(DifferentiableTypeDictionaryImportItem); + + DeclRef<DifferentiableTypeDictionary> dictionaryRef; +}; + bool isInterfaceRequirement(Decl* decl); diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 13d687da0..e0a55cc29 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -38,6 +38,18 @@ class VarExpr : public DeclRefExpr SLANG_AST_CLASS(VarExpr) }; +class DifferentiableDeclRefExpr : public Expr +{ + SLANG_AST_CLASS(DifferentiableDeclRefExpr) + + // Inner decl ref expr that references a differentiable expression. + Expr* inner = nullptr; + + // Information on getters and setters if available. + Expr* setterExpr = nullptr; + Expr* getterExpr = nullptr; +}; + // An expression that references an overloaded set of declarations // having the same name. class OverloadedExpr : public Expr @@ -428,13 +440,21 @@ class OpenRefExpr : public Expr Expr* innerExpr = nullptr; }; + /// Base class for higher-order function application + /// Eg: foo(fn) where fn is a function expression. + /// +class HigherOrderInvokeExpr : public Expr +{ + SLANG_ABSTRACT_AST_CLASS(HigherOrderInvokeExpr) + Expr* baseFunction; +}; + /// An expression of the form `__jvp(fn)` to access the /// forward-mode derivative version of the function `fn` /// -class JVPDifferentiateExpr: public Expr +class JVPDifferentiateExpr: public HigherOrderInvokeExpr { SLANG_AST_CLASS(JVPDifferentiateExpr) - Expr* baseFunction; }; /// A type expression of the form `__TaggedUnion(A, ...)`. diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 8868b7a1d..8230f481e 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -549,6 +549,7 @@ class AttributeTargetModifier : public Modifier SyntaxClass<NodeBase> syntaxClass; }; + // Base class for checked and unchecked `[name(arg0, ...)]` style attribute. class AttributeBase : public Modifier { diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 377dee350..a8ceaa716 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -283,7 +283,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub { if (constraintParam == declRef.getDecl()) { - found = true; + found = true; break; } index++; @@ -443,6 +443,66 @@ HashCode TransitiveSubtypeWitness::_getHashCodeOverride() return hash; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +{ + int diff = 0; + + Type* substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); + Type* substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substWitness = as<SubtypeWitness>(conjunctionWitness->substituteImpl(astBuilder, subst, &diff)); + + // If nothing changed, then we can bail out early. + if (!diff) + return this; + + // Something changes, so let the caller know. + (*ioDiff)++; + + // If the substituted witness is a conjunction, break it apart, but it's important to replace the + // sub and super types with the current ones since the conjunction witness will have an + // + if (auto substConjunctionWitness = as<ConjunctionSubtypeWitness>(substWitness)) + { + if (indexInConjunction == 0) + { + auto witness = as<SubtypeWitness>(substConjunctionWitness->leftWitness); + SLANG_ASSERT(witness); + + witness->sub = substSub; + witness->sup = substSup; + + return witness; + } + else if (indexInConjunction == 1) + { + auto witness = as<SubtypeWitness>(substConjunctionWitness->rightWitness); + SLANG_ASSERT(witness); + + witness->sub = substSub; + witness->sup = substSup; + + return witness; + } + else + { + SLANG_UNIMPLEMENTED_X("conjunction index must be 0 or 1"); + } + } + else + { + // In the simple case, we just construct a new conjunction subtype + // witness. + ExtractFromConjunctionSubtypeWitness* result = astBuilder->create<ExtractFromConjunctionSubtypeWitness>(); + result->sub = substSub; + result->sup = substSup; + result->conjunctionWitness = substWitness; + result->indexInConjunction = indexInConjunction; + return result; + } +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! bool ExtractExistentialSubtypeWitness::_equalsValOverride(Val* val) @@ -637,29 +697,6 @@ HashCode ExtractFromConjunctionSubtypeWitness::_getHashCodeOverride() return combineHash(indexInConjunction, conjunctionWitness ? conjunctionWitness->getHashCode() : 0); } -Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - Val* newConjunctionWitness = nullptr; - - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); - - if (this->conjunctionWitness) - newConjunctionWitness = conjunctionWitness->substituteImpl(astBuilder, subst, &diff); - *ioDiff += diff; - - if (diff) - { - auto result = astBuilder->create<ExtractFromConjunctionSubtypeWitness>(); - result->conjunctionWitness = newConjunctionWitness; - result->sub = substSub; - result->sup = substSup; - return result; - } - return this; -} - // ModifierVal bool ModifierVal::_equalsValOverride(Val* val) diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index e0c1f3702..cf362dcdd 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -18,6 +18,62 @@ namespace Slang return witness; } + + Val* simplifyWitness(ASTBuilder* builder, Val* witness) + { + if (auto extractFromConjunction = as<ExtractFromConjunctionSubtypeWitness>(witness)) + { + auto simplWitness = simplifyWitness(builder, extractFromConjunction->conjunctionWitness); + if (auto conjunction = as<ConjunctionSubtypeWitness>(simplWitness)) + { + auto index = extractFromConjunction->indexInConjunction; + SLANG_ASSERT(index == 0 || index == 1); + if (index == 0) + return conjunction->leftWitness; + else + return conjunction->rightWitness; + } + + ExtractFromConjunctionSubtypeWitness* simplExtractFromConjunction = builder->create<ExtractFromConjunctionSubtypeWitness>(); + simplExtractFromConjunction->sub = extractFromConjunction->sub; + simplExtractFromConjunction->sup = extractFromConjunction->sup; + simplExtractFromConjunction->indexInConjunction = extractFromConjunction->indexInConjunction; + simplExtractFromConjunction->conjunctionWitness = as<SubtypeWitness>(simplWitness); + + return simplExtractFromConjunction; + } + else if (auto conjunctionWitness = as<ConjunctionSubtypeWitness>(witness)) + { + auto simplConjunctionWitness = builder->create<ConjunctionSubtypeWitness>(); + simplConjunctionWitness->leftWitness = as<SubtypeWitness>(simplifyWitness(builder, conjunctionWitness->leftWitness)); + simplConjunctionWitness->rightWitness = as<SubtypeWitness>(simplifyWitness(builder, conjunctionWitness->rightWitness)); + simplConjunctionWitness->sub = conjunctionWitness->sub; + simplConjunctionWitness->sup = conjunctionWitness->sup; + + return simplConjunctionWitness; + } + else if (auto transitiveWitness = as<TransitiveSubtypeWitness>(witness)) + { + TransitiveSubtypeWitness* simplTransitiveWitness = builder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>( + transitiveWitness->sub, + transitiveWitness->sup, + transitiveWitness->midToSup); + + simplTransitiveWitness->sub = transitiveWitness->sub; + simplTransitiveWitness->sup = transitiveWitness->sup; + simplTransitiveWitness->midToSup = as<SubtypeWitness>(simplifyWitness(builder, transitiveWitness->midToSup)); + simplTransitiveWitness->subToMid = as<SubtypeWitness>(simplifyWitness(builder, transitiveWitness->subToMid)); + + return simplTransitiveWitness; + } + else + { + // TODO: Add other cases. + return witness; + } + } + + Val* SemanticsVisitor::createTypeWitness( Type* subType, DeclRef<AggTypeDecl> superTypeDeclRef, @@ -70,7 +126,7 @@ namespace Slang // As long as there is more than one breadcrumb, we // need to be creating transitive witnesses. - while(bb->prev) + while (bb->prev) { // On the first iteration when processing the list // above, the breadcrumb would be for `{ C : D }`, @@ -83,19 +139,42 @@ namespace Slang // where `[...]` represents the "hole" we leave // open to fill in next. // - DeclaredSubtypeWitness* declaredWitness = - m_astBuilder->getOrCreate<DeclaredSubtypeWitness>( - bb->sub, bb->sup, bb->declRef.decl, bb->declRef.substitutions.substitutions); + if (bb->flavor == TypeWitnessBreadcrumb::Flavor::DeclFlavor) + { + DeclaredSubtypeWitness* declaredWitness = + m_astBuilder->getOrCreate<DeclaredSubtypeWitness>( + bb->sub, bb->sup, bb->declRef.decl, bb->declRef.substitutions.substitutions); + + TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(subType, bb->sup, declaredWitness); + transitiveWitness->sub = subType; + transitiveWitness->sup = bb->sup; + transitiveWitness->midToSup = declaredWitness; + + // Fill in the current hole, and then set the + // hole to point into the node we just created. + *link = transitiveWitness; + link = &transitiveWitness->subToMid; + } + else if(bb->flavor == TypeWitnessBreadcrumb::Flavor::AndTypeLeftFlavor) + { + ExtractFromConjunctionSubtypeWitness* extractWitness = m_astBuilder->create<ExtractFromConjunctionSubtypeWitness>(); + extractWitness->sub = subType; + extractWitness->sup = bb->sup; + extractWitness->indexInConjunction = 0; - TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(subType, bb->sup, declaredWitness); - transitiveWitness->sub = subType; - transitiveWitness->sup = bb->sup; - transitiveWitness->midToSup = declaredWitness; + *link = extractWitness; + link = (SubtypeWitness**) &extractWitness->conjunctionWitness; + } + else if(bb->flavor == TypeWitnessBreadcrumb::Flavor::AndTypeRightFlavor) + { + ExtractFromConjunctionSubtypeWitness* extractWitness = m_astBuilder->create<ExtractFromConjunctionSubtypeWitness>(); + extractWitness->sub = subType; + extractWitness->sup = bb->sup; + extractWitness->indexInConjunction = 1; - // Fill in the current hole, and then set the - // hole to point into the node we just created. - *link = transitiveWitness; - link = &transitiveWitness->subToMid; + *link = extractWitness; + link = (SubtypeWitness**) &extractWitness->conjunctionWitness; + } // Move on with the list. bb = bb->prev; @@ -108,9 +187,14 @@ namespace Slang DeclaredSubtypeWitness* declaredWitness = createSimpleSubtypeWitness(bb); *link = declaredWitness; + // Simplify witnesses of the form ExtractFromConjunction(ConjunctionWitness(...)) + // TODO: At some point, we need a more robust way of checking that two witnesses are in-fact 'equal'. + // In the meantime, this step should suffice. + + // We now know that our original `witness` variable has been // filled in, and there are no other holes. - return witness; + return simplifyWitness(m_astBuilder, witness); } bool SemanticsVisitor::isInterfaceSafeForTaggedUnion( @@ -379,6 +463,35 @@ namespace Slang } return true; } + else if (auto andType = as<AndType>(subType)) + { + // (L & R) is a subtype of T if either L or R is a subtype of T. + // Note that in this method T is explicitly a DeclRef and so cannot be a conjunction itself. + // + TypeWitnessBreadcrumb leftBreadcrumb; + leftBreadcrumb.prev = inBreadcrumbs; + leftBreadcrumb.sub = andType; + leftBreadcrumb.sup = DeclRefType::create(m_astBuilder, superTypeDeclRef); + leftBreadcrumb.declRef = makeDeclRef((Decl*)nullptr); + leftBreadcrumb.flavor = TypeWitnessBreadcrumb::Flavor::AndTypeLeftFlavor; + + if(_isDeclaredSubtype(originalSubType, andType->left, superTypeDeclRef, outWitness, &leftBreadcrumb)) + { + return true; + } + + TypeWitnessBreadcrumb rightBreadcrumb; + rightBreadcrumb.prev = inBreadcrumbs; + rightBreadcrumb.sub = andType; + rightBreadcrumb.sup = DeclRefType::create(m_astBuilder, superTypeDeclRef); + rightBreadcrumb.declRef = makeDeclRef((Decl*)nullptr); + rightBreadcrumb.flavor = TypeWitnessBreadcrumb::Flavor::AndTypeRightFlavor; + + if(_isDeclaredSubtype(originalSubType, andType->right, superTypeDeclRef, outWitness, &rightBreadcrumb)) + { + return true; + } + } // default is failure return false; } diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 24cedd7d5..f96b5a484 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -564,6 +564,19 @@ namespace Slang } } + // Two subtype witnesses can be unified if they exist (non-null) and + // prove that some pair of types are subtypes of types that can be unified. + // + if (auto fstWit = as<SubtypeWitness>(fst)) + { + if (auto sndWit = as<SubtypeWitness>(snd)) + { + return TryUnifyTypes(constraints, + fstWit->sup, + sndWit->sup); + } + } + SLANG_UNIMPLEMENTED_X("value unification case"); // default: fail @@ -725,17 +738,29 @@ namespace Slang bool SemanticsVisitor::TryUnifyConjunctionType( ConstraintSystem& constraints, - AndType* fst, + Type* fst, Type* snd) { - // Unifying a type `T` with `A & B` amounts to unifying - // `T` with `A` and also `T` with `B`. + // Unifying a type `A & B` with `T` amounts to unifying + // `A` with `T` and also `B` with `T` while + // unifying a type `T` with `A & B` amounts to either + // unifying `T` with `A` or `T` with `B` // // If either unification is impossible, then the full // case is also impossible. // - return TryUnifyTypes(constraints, fst->left, snd) - && TryUnifyTypes(constraints, fst->right, snd); + if (auto fstAndType = as<AndType>(fst)) + { + return TryUnifyTypes(constraints, fstAndType->left, snd) + && TryUnifyTypes(constraints, fstAndType->right, snd); + } + else if (auto sndAndType = as<AndType>(snd)) + { + return TryUnifyTypes(constraints, fst, sndAndType->left) + || TryUnifyTypes(constraints, fst, sndAndType->right); + } + else + return false; } bool SemanticsVisitor::TryUnifyTypes( @@ -762,13 +787,9 @@ namespace Slang // a conjunction directly, and will instead find all of the // "leaf" types we need to constrain it to. // - if( auto fstAndType = as<AndType>(fst) ) - { - return TryUnifyConjunctionType(constraints, fstAndType, snd); - } - if( auto sndAndType = as<AndType>(snd) ) + if (as<AndType>(fst) || as<AndType>(snd)) { - return TryUnifyConjunctionType(constraints, sndAndType, fst); + return TryUnifyConjunctionType(constraints, fst, snd); } // A generic parameter type can unify with anything. diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 2f5447ffb..e56d63f91 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1168,6 +1168,10 @@ namespace Slang m_astBuilder->getErrorType(), fromExpr); } + + // If we coerced to a differentiable type, log it. + maybeRegisterDifferentiableType(m_astBuilder, expr->type); + return expr; } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index b18e1c4da..2d6e20622 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -835,7 +835,15 @@ namespace Slang // If `decl` is a container, then we want to ensure its children. if(auto containerDecl = as<ContainerDecl>(decl)) - { + { + bool trackDiffTypes = (as<GenericDecl>(decl) != nullptr); + if (trackDiffTypes) + { + // Add a context to track differentiable types. + DifferentiableTypeSemanticContext subDiffTypeContext; + visitor->getShared()->pushDiffTypeContext(&subDiffTypeContext); + } + // NOTE! We purposefully do not iterate with the for(auto childDecl : containerDecl->members) here, // because the visitor may add to `members` whilst iteration takes place, invalidating the iterator // and likely a crash. @@ -857,6 +865,21 @@ namespace Slang _ensureAllDeclsRec(visitor, childDecl, state); } + + if (trackDiffTypes) + { + auto subDiffTypeContext = visitor->getShared()->popDiffTypeContext(); + + // If there were any differentiable types used in differentiable + // methods, generate a dictionary with the required info. + // + if (subDiffTypeContext->isDictionaryRequired()) + { + auto diffTypeDict = subDiffTypeContext->makeDifferentiableTypeDictionaryNode(visitor->getASTBuilder()); + diffTypeDict->parentDecl = containerDecl; + containerDecl->members.add(diffTypeDict); + } + } } // Note: the "inner" declaration of a `GenericDecl` is currently @@ -1234,6 +1257,49 @@ namespace Slang } } + void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*) + { + // If the autodiff core library (diff.meta.slang) has not been loaded yet, ignore any + // request to check differentiable types. + // + if (!m_astBuilder->isDifferentiableInterfaceAvailable()) + return; + + auto diffInterface = m_astBuilder->getDifferentiableInterface(); + + DeclRefType* type = nullptr; + + if (auto extensionDecl = as<ExtensionDecl>(decl)) + { + // If this is an extension, use the provided target type. + type = as<DeclRefType>(extensionDecl->targetType.type); + } + else + { + // If this is a type declaration, create a decl ref without + // any substitutions. + // + auto declRef = makeDeclRef(decl); + + // TODO: Strip substitutions from the declreftype + type = DeclRefType::create(m_astBuilder, declRef); + } + + // Skip if the declaration is the interface itself. + if (type->declRef == diffInterface) + return; + + // If the DeclRefType conforms to IDifferentiable, register it with the top-level + // context. + // + if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, diffInterface))) + { + // TODO: Temporarily disabled to move to new system. Fix later. + // context->registerDifferentiableType(type, witness); + } + + } + void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) { // TODO: are there any other validations we can do at this point? @@ -1287,6 +1353,23 @@ namespace Slang ensureDecl(constraint, DeclCheckState::ReadyForReference); } } + + // TODO(sai): Is this the right checking stage to be doing this? + DifferentiableTypeSemanticContext diffTypeContext; + + for (Index i = 0; i < members.getCount(); ++i) + { + Decl* m = members[i]; + + if (auto typeParam = as<GenericTypeParamDecl>(m)) + { + tryAddDifferentiableConformanceToContext(typeParam, &diffTypeContext); + } + } + + auto diffTypeDictionaryNode = diffTypeContext.makeDifferentiableTypeDictionaryNode(m_astBuilder); + diffTypeDictionaryNode->parentDecl = genericDecl; + genericDecl->members.add(diffTypeDictionaryNode); } void SemanticsDeclBasesVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl) @@ -1322,6 +1405,7 @@ namespace Slang void visitAggTypeDecl(AggTypeDecl* aggTypeDecl) { checkAggTypeConformance(aggTypeDecl); + tryAddDifferentiableConformanceToContext(aggTypeDecl, getShared()->getDiffTypeContext()); } // Conformances can also come via `extension` declarations, and @@ -1330,6 +1414,7 @@ namespace Slang void visitExtensionDecl(ExtensionDecl* extensionDecl) { checkExtensionConformance(extensionDecl); + tryAddDifferentiableConformanceToContext(extensionDecl, getShared()->getDiffTypeContext()); } }; @@ -1486,6 +1571,32 @@ namespace Slang // Furthermore, because a fully checked function will have checked // its body, this also means that all function bodies and the // declarations they contain should be fully checked. + + // Generate a dictionary node to hold information about all + // available differentiable types in scope (including imports and stdlib) + // + if (getShared()->getDiffTypeContext()->isDictionaryRequired()) + finishDifferentiableTypeDictionary(moduleDecl); + } + + void SemanticsVisitor::finishDifferentiableTypeDictionary(ModuleDecl* moduleDecl) + { + // Grab the differentiable type information from imported modules. + for(auto importedModule : getShared()->importedModulesList) + { + this->getShared()->getDiffTypeContext()->addImportedModule(importedModule); + } + + // Grad the differentiable type information from the standard library modules. + for (auto stdLibModule : this->getSession()->stdlibModules) + { + this->getShared()->getDiffTypeContext()->addImportedModule(stdLibModule->getModuleDecl()); + } + + auto diffTypeDictNode = this->getShared()->getDiffTypeContext()->makeDifferentiableTypeDictionaryNode(m_astBuilder); + diffTypeDictNode->parentDecl = moduleDecl; + + moduleDecl->members.add(diffTypeDictNode); } bool SemanticsVisitor::doesSignatureMatchRequirement( @@ -4292,7 +4403,23 @@ namespace Slang nullptr); args.add(val); } - // TODO: need to handle constraints here? + } + + // Add defaults for constraint parameters. + for (auto dd : genericDecl->members) + { + if (auto constraintDecl = as<GenericTypeConstraintDecl>(dd)) + { + // Convert the constraint to an appropriate witness. + auto witness = tryGetSubtypeWitness(constraintDecl->sub, constraintDecl->sup); + + // Must be non-null since we know there's a constraint. If null, something is + // very wrong. + // + SLANG_ASSERT(witness); + + args.add(witness); + } } GenericSubstitution* subst = m_astBuilder->getOrCreateGenericSubstitution(genericDecl, args, nullptr); return subst; @@ -4725,6 +4852,11 @@ namespace Slang void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) { + if (decl->findModifier<JVPDerivativeModifier>()) + { + this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); + } + for(auto paramDecl : decl->getParameters()) { ensureDecl(paramDecl, DeclCheckState::ReadyForReference); @@ -5594,6 +5726,75 @@ namespace Slang m_candidateExtensionListsBuilt = false; m_mapTypeDeclToCandidateExtensions.Clear(); } + + void DifferentiableTypeSemanticContext::registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness) + { + // Need to generate a type dictionary since we have a declaration that works with + // a differentiable type. + // + this->requireDifferentiableTypeDictionary(); + + m_mapTypeToIDifferentiableWitness.AddIfNotExists(DeclRefTypeKey(type), witness); + } + + List<KeyValuePair<DeclRefType*, SubtypeWitness*>> DifferentiableTypeSemanticContext::getDifferentiableTypeConformanceList() + { + List<KeyValuePair<DeclRefType*, SubtypeWitness*>> diffConformances; + for (auto entry : m_mapTypeToIDifferentiableWitness) + { + diffConformances.add(KeyValuePair<DeclRefType*, SubtypeWitness*>(entry.Key.type, entry.Value)); + } + + return diffConformances; + } + + DifferentiableTypeDictionary* DifferentiableTypeSemanticContext::makeDifferentiableTypeDictionaryNode( + ASTBuilder* builder) + { + auto dictionary = builder->create<DifferentiableTypeDictionary>(); + + for (auto item : m_mapTypeToIDifferentiableWitness) + { + auto entry = builder->create<DifferentiableTypeDictionaryItem>(); + entry->baseType = item.Key.type; + entry->confWitness = item.Value; + entry->parentDecl = dictionary; + + dictionary->members.add(entry); + } + + for (auto item : m_importedDictionaries) + { + auto entry = builder->create<DifferentiableTypeDictionaryImportItem>(); + entry->dictionaryRef = item; + entry->parentDecl = dictionary; + + dictionary->members.add(entry); + } + + return dictionary; + } + + void DifferentiableTypeSemanticContext::addImportedModule(ModuleDecl* importedModuleDecl) + { + // TODO: This is a terribly slow way to find the diff type dictionary. + // Switch to lookUp() when possible (this might involve naming the dictionary something) + // + for (auto diffTypeDict : importedModuleDecl->getMembersOfType<DifferentiableTypeDictionary>()) + { + m_importedDictionaries.add(makeDeclRef(diffTypeDict)); + } + } + + void DifferentiableTypeSemanticContext::requireDifferentiableTypeDictionary() + { + this->m_isTypeDictionaryRequired = true; + } + + bool DifferentiableTypeSemanticContext::isDictionaryRequired() + { + return this->m_isTypeDictionaryRequired; + } void SharedSemanticsContext::_addCandidateExtensionsFromModule(ModuleDecl* moduleDecl) { diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index f1ccddf15..745532c27 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -719,8 +719,219 @@ namespace Slang return _resolveOverloadedExprImpl(overloadedExpr, mask, getSink()); } + Type* SemanticsVisitor::_getDifferential(ASTBuilder* builder, Type* type) + { + if (auto ptrType = as<PtrTypeBase>(type)) + { + return builder->getPtrType( + _getDifferential(builder, ptrType->getValueType()), + ptrType->getClassInfo().m_name); + } + else if (auto arrayType = as<ArrayExpressionType>(type)) + { + return builder->getArrayType( + _getDifferential(builder, arrayType->baseType), + arrayType->arrayLength); + } + + if (auto declRefType = as<DeclRefType>(type)) + { + if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterface()))) + { + auto diffTypeLookupResult = lookUpMember( + getASTBuilder(), + this, + getName("Differential"), + type, + Slang::LookupMask::type, + Slang::LookupOptions::None); + + diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult); + + if (!diffTypeLookupResult.isValid()) + { + // Diagnose no 'Differential' member. + getSink()->diagnose(declRefType->declRef, Diagnostics::typeDoesntImplementInterfaceRequirement, type, getName("Differential")); + } + else if (diffTypeLookupResult.isOverloaded()) + { + SLANG_UNIMPLEMENTED_X("Ambiguous differential type declarations not supported"); + } + else + { + SharedTypeExpr* baseTypeExpr = m_astBuilder->create<SharedTypeExpr>(); + baseTypeExpr->base.type = type; + baseTypeExpr->type.type = m_astBuilder->getTypeType(type); + + auto diffTypeExpr = ConstructLookupResultExpr( + diffTypeLookupResult.item, + baseTypeExpr, + declRefType->declRef.getLoc(), + baseTypeExpr); + + return ExtractTypeFromTypeRepr(diffTypeExpr); + } + } + } + + return nullptr; + } + + void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type) + { + if (!builder->isDifferentiableInterfaceAvailable()) + { + return; + } + + // Check for special cases such as PtrTypeBase<T> or Array<T> + // This could potentially be handled later by simply defining extensions + // for Ptr<T:IDifferentiable> etc.. + // + if (auto ptrType = as<PtrTypeBase>(type)) + { + maybeRegisterDifferentiableType(builder, ptrType->getValueType()); + return; + } + + if (auto arrayType = as<ArrayExpressionType>(type)) + { + maybeRegisterDifferentiableType(builder, arrayType->baseType); + return; + } + + if (auto declRefType = as<DeclRefType>(type)) + { + if (auto subtypeWitness = as<SubtypeWitness>( + tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface()))) + { + auto diffTypeContext = this->getShared()->innermostDiffTypeContext(); + diffTypeContext->registerDifferentiableType((DeclRefType*)type, subtypeWitness); + } + + return; + } + } + + Expr* SemanticsVisitor::maybeMakeDifferentialExpr(Expr* checkedTerm) + { + // Check that member lookups on differentiable types have appropriate differential + // getters and setters. + if (auto declRefExpr = as<DeclRefExpr>(checkedTerm)) + { + + // Check if we have a parent container. If yes, then checkedTerm is + // referencing a member of this parent. + // + auto parentType = DeclRefType::create(getASTBuilder(), declRefExpr->declRef.getParent()); + + // Check if we have an aggregate (i.e. struct-like) type. + // Ignore interfaces and the case when the term refers to a function + // + if (parentType->declRef.as<AggTypeDeclBase>() && + !parentType->declRef.as<InterfaceDecl>() && + !declRefExpr->declRef.as<CallableDecl>()) + { + // Check if the parent container type is differentiable. + if (auto parentDiffWitness = as<SubtypeWitness>( + tryGetInterfaceConformanceWitness( + parentType, getASTBuilder()->getDifferentiableInterface()))) + { + // If yes, the member in checkedTerm should have a differential getter and setter. + // Otherwise, <ERROR> + // + auto diffExpr = m_astBuilder->create<DifferentiableDeclRefExpr>(); + diffExpr->type = checkedTerm->type; + diffExpr->inner = checkedTerm; + + { + auto getterName = getName("__getDifferentialFor_" + declRefExpr->name->text); + auto getterResult = lookUpMember( + getASTBuilder(), + this, + getterName, + parentType, + Slang::LookupMask::Function, + Slang::LookupOptions::None); + + if (!getterResult.isValid()) + { + // Do nothing.. we assume that this field cannot be differentiated. + // Could this be confusing from a user perspective? + } + else if (getterResult.isOverloaded()) + { + // Diagnose ambiguous getter. + SLANG_UNIMPLEMENTED_X("Ambiguous differential getters not supported"); + } + else + { + auto getterRefExpr = ConstructLookupResultExpr( + getterResult.item, + declRefExpr, + getterResult.item.declRef.getLoc(), + nullptr); + + // Check that the type is what we expect. + // We're going to do this in a very crude way for now. + // Ideally, we want to use the overload resolution and type + // coercion logic in ResolveInvoke() + // + + auto diffType = _getDifferential(m_astBuilder, checkedTerm->type.type); + auto diffParentType = _getDifferential(m_astBuilder, parentType); + + auto ptrDiffType = m_astBuilder->getPtrType(diffType); + auto inoutContainerDiffType = m_astBuilder->getInOutType(diffParentType); + + auto funcType = as<FuncType>(getterRefExpr->type); + + if (!ptrDiffType->equals(funcType->getResultType())) + { + getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch, + ptrDiffType, funcType->getResultType()); + } + + if (!inoutContainerDiffType->equals(funcType->getParamType(0))) + { + getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch, + inoutContainerDiffType, funcType->getParamType(0)); + } + + diffExpr->getterExpr = getterRefExpr; + } + } + + return diffExpr; + } + } + } + + return checkedTerm; + } + Expr* SemanticsVisitor::CheckTerm(Expr* term) { + auto checkedTerm = _CheckTerm(term); + + // Differentiable type checking. + // TODO: This can be super slow. + if (this->m_parentFunc && + this->m_parentFunc->findModifier<JVPDerivativeModifier>()) + { + maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); + + if (auto declRefExpr = as<DeclRefExpr>(checkedTerm)) + { + checkedTerm = maybeMakeDifferentialExpr(checkedTerm); + } + } + + return checkedTerm; + } + + Expr* SemanticsVisitor::_CheckTerm(Expr* term) + { if (!term) return nullptr; // The process of checking a term/expression can end up introducing @@ -1677,6 +1888,13 @@ namespace Slang return expr; } + Expr* SemanticsExprVisitor::visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr) + { + auto checkedInnerTerm = CheckTerm(expr->inner); + expr->type = checkedInnerTerm->type; + return expr; + } + Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType) { @@ -1715,48 +1933,38 @@ namespace Slang return primalType; } - Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) + Type* SemanticsVisitor::processJVPFuncType(ASTBuilder* builder, FuncType* originalType) { - // Check/Resolve inner function declaration. - expr->baseFunction = CheckTerm(expr->baseFunction); + // Resolve JVP type here. + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-jvp-diff.cpp - auto astBuilder = this->getASTBuilder(); + FuncType* jvpType = builder->create<FuncType>(); - if(auto primalType = as<FuncType>(expr->baseFunction->type)) - { - // Resolve JVP type here. - // Note that this type checking needs to be in sync with - // the auto-generation logic in slang-ir-jvp-diff.cpp - - FuncType* jvpType = astBuilder->create<FuncType>(); - - // The JVP return type is float if primal return type is float - // void otherwise. - // - jvpType->resultType = _toJVPReturnType(astBuilder, primalType->getResultType()); - - // No support for differentiating function that throw errors, for now. - SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType())); - jvpType->errorType = primalType->errorType; - - for (UInt i = 0; i < primalType->getParamCount(); i++) - { - if(auto jvpParamType = _toDifferentialParamType(astBuilder, primalType->getParamType(i))) - jvpType->paramTypes.add(jvpParamType); - } + // The JVP return type is float if primal return type is float + // void otherwise. + // + jvpType->resultType = _toJVPReturnType(builder, originalType->getResultType()); + + // No support for differentiating function that throw errors, for now. + SLANG_ASSERT(originalType->errorType->equals(builder->getBottomType())); + jvpType->errorType = originalType->errorType; - expr->type = jvpType; - } - else + for (UInt i = 0; i < originalType->getParamCount(); i++) { - // Error - expr->type = astBuilder->getErrorType(); - if (!as<ErrorType>(expr->baseFunction->type)) - { - getSink()->diagnose(expr->baseFunction->loc, Diagnostics::expectedFunction, expr->baseFunction->type); - } + if(auto jvpParamType = _toDifferentialParamType(builder, originalType->getParamType(i))) + jvpType->paramTypes.add(jvpParamType); } + return jvpType; + } + + Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) + { + this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); + + // Check/Resolve inner function declaration. + expr->baseFunction = CheckTerm(expr->baseFunction); return expr; } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index c15428877..5c1c20e3a 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -237,11 +237,79 @@ namespace Slang Dictionary<LookupRequestKey, LookupResult> lookupCache; }; + struct DifferentiableTypeSemanticContext + { + + public: + /// Registers a type as conforming to IDifferentiable, along with a witness + /// describing the relationship. + /// + void registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness); + + /// Returns the list of registered differentiable types. + List<KeyValuePair<DeclRefType*, SubtypeWitness*>> getDifferentiableTypeConformanceList(); + + /// Creates a DifferentiableTypeDictionary AST container node with an entry for + /// every registered type. This can be inserted into the appropriate context for the + /// auto-diff pass. + /// + DifferentiableTypeDictionary* makeDifferentiableTypeDictionaryNode(ASTBuilder* builder); + + /// Creates a DifferentiableTypeDictionary AST container node with an entry for + /// every registered type. This can be inserted into the appropriate context for the + /// auto-diff pass. + /// + void addImportedModule(ModuleDecl* importedModuleDecl); + + /// Set flag to indicate that the type dictionary is requried. + void requireDifferentiableTypeDictionary(); + + /// Returns flag indicating whether the type dictionary is requried. + bool isDictionaryRequired(); + + private: + // Nested struct to override the '==' operator for DeclRefTypes + struct DeclRefTypeKey + { + DeclRefType* type; + + DeclRefTypeKey(DeclRefType* type) : type(type) + {}; + + DeclRefTypeKey(DeclRefTypeKey& typeKey) : type(typeKey.type) + {}; + + DeclRefTypeKey() : type(nullptr) + {}; + + bool operator==(const DeclRefTypeKey& other) const + { + return (other.type->declRef == this->type->declRef); + } + + HashCode getHashCode() const + { + Hasher hasher; + hasher.hashObject(&type->declRef); + return hasher.getResult(); + } + }; + + /// Mapping from types to subtype witnesses for conformance to IDifferentiable. + Dictionary<DeclRefTypeKey, SubtypeWitness*> m_mapTypeToIDifferentiableWitness; + + /// List of external dictionaries (from imported modules) + List<DeclRef<DifferentiableTypeDictionary>> m_importedDictionaries; + + /// Flag to indicate if a differentiable type dictionary is required. + bool m_isTypeDictionaryRequired = false; + }; /// Give a cache and a name, will remove all entries associated with a name /// Might be useful/necessary if a new name is introduced void removeLookupForName(TypeCheckingCache* cache, Name* name); + /// Shared state for a semantics-checking session. struct SharedSemanticsContext { @@ -269,6 +337,10 @@ namespace Slang // List<ModuleDecl*> importedModulesList; HashSet<ModuleDecl*> importedModulesSet; + + DifferentiableTypeSemanticContext diffTypeContext; + + List<DifferentiableTypeSemanticContext*> diffTypeContextStack; public: SharedSemanticsContext( @@ -303,6 +375,29 @@ namespace Slang return m_linkage->isInLanguageServer(); return false; } + + DifferentiableTypeSemanticContext* getDiffTypeContext() + { + return &diffTypeContext; + } + + DifferentiableTypeSemanticContext* innermostDiffTypeContext() + { + return (diffTypeContextStack.getCount() > 0) ? diffTypeContextStack.getLast() : &diffTypeContext; + } + + void pushDiffTypeContext(DifferentiableTypeSemanticContext* context) + { + diffTypeContextStack.add(context); + } + + DifferentiableTypeSemanticContext* popDiffTypeContext() + { + auto context = diffTypeContextStack.getLast(); + diffTypeContextStack.removeLast(); + return context; + } + /// Get the list of extension declarations that appear to apply to `decl` in this context List<ExtensionDecl*> const& getCandidateExtensionsForTypeDecl(AggTypeDecl* decl); @@ -687,6 +782,8 @@ namespace Slang Expr* CheckTerm(Expr* term); + Expr* _CheckTerm(Expr* term); + Expr* CreateErrorExpr(Expr* expr); bool IsErrorExpr(Expr* expr); @@ -716,6 +813,20 @@ namespace Slang // Type* _toJVPReturnType(ASTBuilder* builder, Type* primalType); + // Convert a function's original type to it's JVP type. + Type* processJVPFuncType(ASTBuilder* builder, FuncType* originalType); + + // Check and register a type if it is differentiable. + void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type); + + // Check if a term is referencing a member, and add a decoration to it's + // differential getter function, if one exists. + // + Expr* maybeMakeDifferentialExpr(Expr* checkedTerm); + + // Construct the differential for 'type', if it exists. + Type* _getDifferential(ASTBuilder* builder, Type* type); + public: bool ValuesAreEqual( @@ -1004,6 +1115,16 @@ namespace Slang DeclRef<Decl> requiredMemberDeclRef, RefPtr<WitnessTable> witnessTable); + /// Registers a type as differentiable in the currrent semantic context, if the declaration represents + /// a subtype of IDifferentable. Does nothing otherwise. + void tryAddDifferentiableConformanceToContext( + Decl* decl, + DifferentiableTypeSemanticContext* context); + + /// Generates a dictionary node for the module with all registered differentiable types, + /// as well as information about differentiable types in imported modules. + void finishDifferentiableTypeDictionary(ModuleDecl* moduleDecl); + // Find the appropriate member of a declared type to // satisfy a requirement of an interface the type // claims to conform to. @@ -1259,6 +1380,23 @@ namespace Slang Type* sub = nullptr; Type* sup = nullptr; DeclRef<Decl> declRef; + + enum Flavor + { + // Describes a sub-type super-type relationship through a + // reference to an inhertiance declaration. + DeclFlavor, + + // Describes a sub-type super-type relationship through + // conjunction. This doesn't necessarily have a corresponding declaration + // since AndTypes cannot actually be used as types. + // i.e. if (A & B) subtype C because A subtype C, then we use AndTypeLeft to represent + // that relationship. + AndTypeLeftFlavor, + AndTypeRightFlavor + }; + + Flavor flavor = DeclFlavor; }; // Create a subtype witness based on the declared relationship @@ -1554,6 +1692,10 @@ namespace Slang void AddOverloadCandidate( OverloadResolveContext& context, OverloadCandidate& candidate); + + void AddHigherOrderOverloadCandidates( + Expr* funcExpr, + OverloadResolveContext& context); void AddFuncOverloadCandidate( LookupResultItem item, @@ -1621,7 +1763,7 @@ namespace Slang bool TryUnifyConjunctionType( ConstraintSystem& constraints, - AndType* fst, + Type* fst, Type* snd); // Is the candidate extension declaration actually applicable to the given type @@ -1638,7 +1780,8 @@ namespace Slang DeclRef<Decl> inferGenericArguments( DeclRef<GenericDecl> genericDeclRef, OverloadResolveContext& context, - GenericSubstitution* substWithKnownGenericArgs); + GenericSubstitution* substWithKnownGenericArgs, + List<Type*> *innerParameterTypes = nullptr); void AddTypeOverloadCandidates( Type* type, @@ -1781,6 +1924,8 @@ namespace Slang Expr* visitVarExpr(VarExpr *expr); + Expr* visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr *expr); + Expr* visitTypeCastExpr(TypeCastExpr * expr); Expr* visitTryExpr(TryExpr* expr); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 7dba3986a..eadf2f63d 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -715,6 +715,21 @@ namespace Slang callExpr->originalFunctionExpr = callExpr->functionExpr; callExpr->type = QualType(candidate.resultType); + // If the callee is the result of a higher-order function invocation, + // set it's base function to the declaration corresponding to the + // resolved overload. + // + if (auto higherOrderInvoke = as<HigherOrderInvokeExpr>(callExpr->functionExpr)) + { + higherOrderInvoke->baseFunction = ConstructLookupResultExpr( + candidate.item, + baseExpr, + higherOrderInvoke->loc, + callExpr->functionExpr); + + higherOrderInvoke->type = candidate.funcType; + } + return callExpr; } @@ -1174,7 +1189,8 @@ namespace Slang DeclRef<Decl> SemanticsVisitor::inferGenericArguments( DeclRef<GenericDecl> genericDeclRef, OverloadResolveContext& context, - GenericSubstitution* substWithKnownGenericArgs) + GenericSubstitution* substWithKnownGenericArgs, + List<Type*> *innerParameterTypes) { // We have been asked to infer zero or more arguments to // `genericDeclRef`, in a context where it is being applied @@ -1279,7 +1295,7 @@ namespace Slang TryUnifyTypes( constraints, context.getArgTypeForInference(aa, this), - getType(m_astBuilder, params[aa])); + (!innerParameterTypes) ? getType(m_astBuilder, params[aa]) : (*innerParameterTypes)[aa]); } } else @@ -1495,6 +1511,11 @@ namespace Slang AddOverloadCandidates(item, context); } } + else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr)) + { + // The expression is the result of a higher order function application. + AddHigherOrderOverloadCandidates(higherOrderExpr, context); + } else if (auto partiallyAppliedGenericExpr = as<PartiallyAppliedGenericExpr>(funcExpr)) { // A partially-applied generic is allowed as an overload candidate, @@ -1520,6 +1541,121 @@ namespace Slang } } + void SemanticsVisitor::AddHigherOrderOverloadCandidates( + Expr* funcExpr, + OverloadResolveContext& context) + { + // Lookup the higher order function and process types accordingly. In the future, + // if there are enough varieties, we can have dispatch logic instead of an + // if-else ladder. + if (auto jvpExpr = as<JVPDifferentiateExpr>(funcExpr)) + { + if (auto origFuncType = as<FuncType>(jvpExpr->baseFunction->type)) + { + // Case: __jvp(name-resolved-to-decl-ref) + + auto baseFuncDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<CallableDecl>(); + SLANG_ASSERT(baseFuncDeclRef); + + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + candidate.funcType = as<FuncType>(processJVPFuncType(this->getASTBuilder(), origFuncType)); + candidate.resultType = candidate.funcType->getResultType(); + candidate.item = LookupResultItem(baseFuncDeclRef); + + AddOverloadCandidate(context, candidate); + } + else if (auto origOverloadedType = as<OverloadGroupType>(jvpExpr->baseFunction->type)) + { + // Case: __jvp(name-resolved-to-multiple-decl-ref) + + if (auto overloadExpr = as<OverloadedExpr>(jvpExpr->baseFunction)) + { + for (auto item : overloadExpr->lookupResult2.items) + { + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + candidate.funcType = as<FuncType>(processJVPFuncType( + this->getASTBuilder(), + as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc)))); + candidate.resultType = candidate.funcType->getResultType(); + candidate.item = LookupResultItem(item.declRef); + + AddOverloadCandidate(context, candidate); + } + } + else + { + // Unhandled overload expr. + funcExpr->type = this->getASTBuilder()->getErrorType(); + getSink()->diagnose(funcExpr->loc, + Diagnostics::unimplemented, + funcExpr->type); + } + } + else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<GenericDecl>()) + { + // Case: __jvp(name-resolved-to-generic-decl) + + // Get inner function + DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>( + getInner(baseFuncGenericDeclRef), + baseFuncGenericDeclRef.substitutions); + + // Pull parameter list of inner function. + auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>()); + + // Process func type to generate JVP func type. + auto jvpFuncType = as<FuncType>(processJVPFuncType(this->getASTBuilder(), funcType)); + + // Extract parameter list from processed type. + List<Type*> paramTypes; + + for(UIndex ii = 0; ii < jvpFuncType->getParamCount(); ii++) + paramTypes.add(jvpFuncType->getParamType(ii)); + + // Try to infer generic arguments, based on the updated context. + DeclRef<Decl> innerRef = inferGenericArguments( + baseFuncGenericDeclRef, + context, + nullptr, + ¶mTypes); + + if (innerRef) + { + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + + // Note that we call processJVPFuncType() again here + // in order to process the specialized version of the original func type. + // This could potentially be a declRef.substitute(jvpFuncType) + // + candidate.funcType = as<FuncType>(processJVPFuncType( + this->getASTBuilder(), + getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); + + candidate.resultType = candidate.funcType->getResultType(); + candidate.item = LookupResultItem(innerRef); + + AddOverloadCandidate(context, candidate); + } + else + { + SLANG_UNEXPECTED("Could not resolve generic candidate"); + } + + } + else + { + // Unhandled case for the inner expr. + funcExpr->type = this->getASTBuilder()->getErrorType(); + getSink()->diagnose(funcExpr->loc, + Diagnostics::expectedFunction, + funcExpr->type); + } + } + } + String SemanticsVisitor::getCallSignatureString( OverloadResolveContext& context) { @@ -1627,8 +1763,8 @@ namespace Slang // without needing dummy initializer/constructor declarations. // // Handling that special casing here (rather than in, say, - // `visitTypeCastExpr`) would allow us to continue to ensure // that `(T) expr` and `T(expr)` continue to be semantically + // `visitTypeCastExpr`) would allow us to continue to ensure // equivalent in (almost) all cases. if (!context.bestCandidate) diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index d402dde03..6a8f802f7 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -320,6 +320,19 @@ namespace Slang getSink()->diagnose(typeExp.exp, Diagnostics::cannotDefinePtrTypeToManagedResource); } } + + // Differentiable type checking. + // TODO: This can be super slow. Switch to caching the result asap. + if (this->m_parentFunc && + this->m_parentFunc->findModifier<JVPDerivativeModifier>()) + { + auto diffTypeContext = this->getShared()->innermostDiffTypeContext(); + if (auto subtypeWitness = as<SubtypeWitness>( + tryGetInterfaceConformanceWitness(result, getASTBuilder()->getDifferentiableInterface()))) + { + diffTypeContext->registerDifferentiableType((DeclRefType*)result, subtypeWitness); + } + } *outProperType = result; return true; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 4666e80d8..1ea54475e 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -10,6 +10,8 @@ #include "slang-ir-collect-global-uniforms.h" #include "slang-ir-cleanup-void.h" #include "slang-ir-dce.h" +#include "slang-ir-diff-call.h" +#include "slang-ir-diff-jvp.h" #include "slang-ir-dll-export.h" #include "slang-ir-dll-import.h" #include "slang-ir-eliminate-phis.h" @@ -365,6 +367,29 @@ Result linkAndOptimizeIR( lowerReinterpret(targetRequest, irModule, sink); validateIRModuleIfEnabled(codeGenContext, irModule); + + // Inline calls to any functions marked with [__unsafeInlineEarly] again, + // since we may be missing out cases prevented by the functions that we just specialzied. + performMandatoryEarlyInlining(irModule); + + dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); + + // Process higher-order calles to auto-diff passes. + // 1. Generate JVP code wherever necessary. (Linearization or "forward-mode" pass) + processJVPDerivativeMarkers(irModule, sink); + + // 2. Transpose JVP to VJP code wherever needed. (Transposition or "reverse-mode" pass) + // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet. + + // 3. Fill in higher-order invocations with the generated functions. + processDerivativeCalls(irModule); + + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF"); + + validateIRModuleIfEnabled(codeGenContext, irModule); + + applySparseConditionalConstantPropagation(irModule); + eliminateDeadCode(irModule); // For targets that supports dynamic dispatch, we need to lower the // generics / interface types to ordinary functions and types using diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index d58e307da..7d677b488 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -361,6 +361,13 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o case kIROp_WitnessTableEntry: return true; + // Special dictionaries used for differentiable type tracking + // should be kept alive. These are removed by the auto-diff pass, + // once they are used. + case kIROp_DifferentiableTypeDictionaryItem: + case kIROp_DifferentiableTypeDictionary: + return true; + default: break; } diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp index 92044be3c..ee78246fe 100644 --- a/source/slang/slang-ir-diff-call.cpp +++ b/source/slang/slang-ir-diff-call.cpp @@ -52,25 +52,50 @@ struct DerivativeCallProcessContext // the intstructions. void processDifferentiate(IRJVPDifferentiate* derivOfInst) { - IRFunc* jvpFunc = nullptr; + IRInst* jvpCallable = nullptr; + + // First get base function + auto origCallable = derivOfInst->getBaseFn(); + + IRSpecialize* specialization = nullptr; + + // If the base is a specialize inst, get the inner fn. + if (auto origSpecialize = as<IRSpecialize>(origCallable)) + { + specialization = origSpecialize; + origCallable = origSpecialize->getBase(); + } + + // We should have either a generic or a function reference on our hands. + SLANG_ASSERT(as<IRGeneric>(origCallable) || as<IRFunc>(origCallable)); // Resolve the derivative function. // // Check for the 'JVPDerivativeReference' decorator on the // base function. - if (auto jvpRefDecorator = derivOfInst->base.get()->findDecoration<IRJVPDerivativeReferenceDecoration>()) + if (auto jvpRefDecorator = origCallable->findDecoration<IRJVPDerivativeReferenceDecoration>()) { - jvpFunc = jvpRefDecorator->getJVPFunc(); + jvpCallable = jvpRefDecorator->getJVPFunc(); + } + + SLANG_ASSERT(jvpCallable); + + if (specialization) + { + // Replace the specialization target with the JVP func. + specialization->setOperand(0, jvpCallable); + + // Then replace the JVPDifferentiate inst with the specialization. + derivOfInst->replaceUsesWith(specialization); } - - // Substitute all uses of the 'derivativeOf' operation - // with the resolved derivative function. - while (auto use = derivOfInst->firstUse) + else { - use->set(jvpFunc); + // Substitute all uses of the 'derivativeOf' operation + // with the resolved derivative function. + derivOfInst->replaceUsesWith(jvpCallable); } - // Remove the 'derivativeOf' + // Remove the 'derivativeOf' inst. derivOfInst->removeAndDeallocate(); } }; diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 5eee13d5e..843428c01 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -7,6 +7,10 @@ #include "slang-ir-dce.h" #include "slang-ir-eliminate-phis.h" +// origX, primalX, diffX +// origX -> primalX (cloneEnv) +// origX -> diffX (instMapD) + namespace Slang { @@ -24,7 +28,7 @@ typedef Pair<IRInst*, IRInst*> InstPair; struct DifferentiableTypeConformanceContext { - Dictionary<IRInst*, IRInst*> witnessTableMap; + Dictionary<IRInst*, IRInst*> witnessTableMap; IRInst* inst = nullptr; @@ -39,6 +43,18 @@ struct DifferentiableTypeConformanceContext // type in the conformance table associated with the concrete type. // IRStructKey* differentialAssocTypeStructKey = nullptr; + + // The struct key for the 'zero()' associated type + // defined inside IDifferential. We use this to lookup the + // implementation of zero() for a given type. + // + IRStructKey* zeroMethodStructKey = nullptr; + + // The struct key for the 'add()' associated type + // defined inside IDifferential. We use this to lookup the + // implementation of add() for a given type. + // + IRStructKey* addMethodStructKey = nullptr; // Modules that don't use differentiable types // won't have the IDifferentiable interface type available. @@ -56,6 +72,9 @@ struct DifferentiableTypeConformanceContext { differentiableInterfaceType = parent->differentiableInterfaceType; differentialAssocTypeStructKey = parent->differentialAssocTypeStructKey; + zeroMethodStructKey = parent->zeroMethodStructKey; + addMethodStructKey = parent->addMethodStructKey; + isInterfaceAvailable = parent->isInterfaceAvailable; } else @@ -64,17 +83,13 @@ struct DifferentiableTypeConformanceContext if (differentiableInterfaceType) { differentialAssocTypeStructKey = findDifferentialTypeStructKey(); + zeroMethodStructKey = findZeroMethodStructKey(); + addMethodStructKey = findAddMethodStructKey(); if (differentialAssocTypeStructKey) isInterfaceAvailable = true; } } - - if (isInterfaceAvailable) - { - // Load all witness tables corresponding to the IDifferentiable interface. - loadWitnessTablesForInterface(differentiableInterfaceType); - } } DifferentiableTypeConformanceContext(IRInst* inst) : @@ -84,35 +99,30 @@ struct DifferentiableTypeConformanceContext // Lookup a witness table for the concreteType. One should exist if concreteType // inherits (successfully) from IDifferentiable. // - IRInst* lookUpConformanceForType(IRInst* type) + IRInst* lookUpConformanceForType(IRBuilder* builder, IRInst* type) { SLANG_ASSERT(isInterfaceAvailable); + // TODO: Cache the returned value to avoid repeatedly scanning through + // blocks looking for the type entries. + // + if (auto irWitness = builder->findDifferentiableTypeEntry(type, type->getParent())) + { + return irWitness; + } - if (witnessTableMap.ContainsKey(type)) - return witnessTableMap[type]; - else if (parent) - return parent->lookUpConformanceForType(type); - else - return nullptr; + return nullptr; } - - // Lookup and return the 'Differential' type declared in the concrete type - // in order to conform to the IDifferentiable interface. - // Note that inside a generic block, this will be a witness table lookup instruction - // that gets resolved during the specialization pass. - // - IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) - { - SLANG_ASSERT(isInterfaceAvailable); - if (auto conformance = lookUpConformanceForType(origType)) + IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) + { + if (auto conformance = lookUpConformanceForType(builder, origType)) { if (auto witnessTable = as<IRWitnessTable>(conformance)) { for (auto entry : witnessTable->getEntries()) { - if (entry->getRequirementKey() == differentialAssocTypeStructKey) - return as<IRType>(entry->getSatisfyingVal()); + if (entry->getRequirementKey() == key) + return entry->getSatisfyingVal(); } } else if (auto witnessTableParam = as<IRParam>(conformance)) @@ -120,12 +130,32 @@ struct DifferentiableTypeConformanceContext return builder->emitLookupInterfaceMethodInst( builder->getTypeKind(), witnessTableParam, - differentialAssocTypeStructKey); + key); } } return nullptr; } + + // Lookup and return the 'Differential' type declared in the concrete type + // in order to conform to the IDifferentiable interface. + // Note that inside a generic block, this will be a witness table lookup instruction + // that gets resolved during the specialization pass. + // + IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, differentialAssocTypeStructKey); + } + + IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, zeroMethodStructKey); + } + + IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, addMethodStructKey); + } private: @@ -150,11 +180,26 @@ struct DifferentiableTypeConformanceContext IRStructKey* findDifferentialTypeStructKey() { + return getIDifferentiableStructKeyAtIndex(0); + } + + IRStructKey* findZeroMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(1); + } + + IRStructKey* findAddMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(2); + } + + IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index) + { if (as<IRModuleInst>(inst) && differentiableInterfaceType) { - // Assume for now that IDifferentiable has exactly one field: the 'Differential' associated type. - SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 1); - if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(0))) + // Assume for now that IDifferentiable has exactly three fields. + SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4); + if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) return as<IRStructKey>(entry->getRequirementKey()); else { @@ -200,12 +245,18 @@ struct DifferentiableTypeConformanceContext genericParam = genericParam->getNextParam(); } - UCount tableIndex = 0; + Count tableIndex = 0; while (genericParam) { SLANG_ASSERT(!as<IRTypeType>(genericParam->getDataType())); + + if (tableIndex >= typeParams.getCount()) + break; + if (auto witnessTableType = as<IRWitnessTableType>(genericParam->getDataType())) { + // TODO(sai): Heavily flawed way to find the right witness table. + // Rewrite this part if (witnessTableType->getConformanceType() == differentiableInterfaceType) witnessTableMap.Add(typeParams[tableIndex], genericParam); } @@ -222,6 +273,40 @@ struct DifferentiableTypeConformanceContext }; + +IRInst* findGlobal(IRInst* inst) +{ + if (inst->getParent() != inst->getModule()->getModuleInst()) + { + return findGlobal(inst->getParent()); + } + + return inst; +} + +void moveGlobalToBeforeUses(IRBuilder*, IRInst* globalInst) +{ + HashSet<IRInst*> globalsOfUses; + for (auto use = globalInst->firstUse; use; use = use->nextUse) + { + globalsOfUses.Add(findGlobal(use->getUser())); + } + + IRInst* earliestUse = nullptr; + for (auto cursor = globalInst; cursor; cursor = cursor->getPrevInst()) + { + if (globalsOfUses.Contains(cursor)) + { + earliestUse = cursor; + } + } + + if (earliestUse) + { + globalInst->insertBefore(earliestUse); + } +} + struct DifferentialPairTypeBuilder { @@ -229,95 +314,246 @@ struct DifferentialPairTypeBuilder diffConformanceContext(diffConformanceContext) {} - IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst) + IRStructField* findField(IRInst* type, IRStructKey* key) { - if (auto basePairStructType = as<IRStructType>(baseInst->getDataType())) + if (auto irStructType = as<IRStructType>(type)) { - auto primalField = as<IRStructField>(basePairStructType->getFirstChild()); - SLANG_ASSERT(primalField); - - return as<IRFieldExtract>(builder->emitFieldExtract( - primalField->getFieldType(), - baseInst, - primalField->getKey() - )); + for (auto field : irStructType->getFields()) + { + if (field->getKey() == key) + { + return field; + } + } } - else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType())) + else if (auto irSpecialize = as<IRSpecialize>(type)) { - if (auto pairStructType = as<IRStructType>(ptrType->getValueType())) + if (auto irGeneric = as<IRGeneric>(irSpecialize->getBase())) { - auto primalField = as<IRStructField>(pairStructType->getFirstChild()); - SLANG_ASSERT(primalField); - - return as<IRFieldAddress>(builder->emitFieldAddress( - builder->getPtrType(primalField->getFieldType()), - baseInst, - primalField->getKey() - )); + if (auto irGenericStructType = as<IRStructType>(findInnerMostGenericReturnVal(irGeneric))) + { + return findField(irGenericStructType, key); + } } } - else + + return nullptr; + } + + IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam) + { + // Get base generic that's being specialized. + auto genericType = as<IRGeneric>(as<IRSpecialize>(specializeInst)->getBase()); + SLANG_ASSERT(genericType); + + // Find the index of genericParam in the base generic. + int paramIndex = -1; + int currentIndex = 0; + for (auto param : genericType->getParams()) { - SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>"); + if (param == genericParam) + paramIndex = currentIndex; + currentIndex ++; } - return nullptr; + + SLANG_ASSERT(paramIndex >= 0); + + // Return the corresponding operand in the specialization inst. + return specializeInst->getOperand(1 + paramIndex); } - IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst) + IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key) { if (auto basePairStructType = as<IRStructType>(baseInst->getDataType())) { - auto diffField = as<IRStructField>(basePairStructType->getFirstChild()->getNextInst()); - SLANG_ASSERT(diffField); - return as<IRFieldExtract>(builder->emitFieldExtract( - diffField->getFieldType(), + findField(basePairStructType, key)->getFieldType(), baseInst, - diffField->getKey() + key )); } else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType())) { - if (auto pairStructType = as<IRStructType>(ptrType->getValueType())) + if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType())) { - auto diffField = as<IRStructField>(pairStructType->getFirstChild()->getNextInst()); - SLANG_ASSERT(diffField); - - return as<IRFieldAddress>(builder->emitFieldAddress( - builder->getPtrType(diffField->getFieldType()), + auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(ptrInnerSpecializedType->getBase())); + if (auto genericBasePairStructType = as<IRStructType>(genericType)) + { + return as<IRFieldAddress>(builder->emitFieldAddress( + builder->getPtrType((IRType*) + findSpecializationForParam( + ptrInnerSpecializedType, + findField(ptrInnerSpecializedType, key)->getFieldType())), baseInst, - diffField->getKey() + key )); + } + } + else if (auto ptrBaseStructType = as<IRStructType>(ptrType->getValueType())) + { + return as<IRFieldAddress>(builder->emitFieldAddress( + builder->getPtrType((IRType*) + findField(ptrBaseStructType, key)->getFieldType()), + baseInst, + key)); + } + } + else if (auto specializedType = as<IRSpecialize>(baseInst->getDataType())) + { + // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's + // type, emit the specialization type. + // + auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(specializedType->getBase())); + if (auto genericBasePairStructType = as<IRStructType>(genericType)) + { + return as<IRFieldExtract>(builder->emitFieldExtract( + (IRType*)findSpecializationForParam( + specializedType, + findField(genericBasePairStructType, key)->getFieldType()), + baseInst, + key + )); + } + else if (auto genericPtrType = as<IRPtrTypeBase>(genericType)) + { + if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType())) + { + return as<IRFieldAddress>(builder->emitFieldAddress( + builder->getPtrType((IRType*) + findSpecializationForParam( + specializedType, + findField(genericPairStructType, key)->getFieldType())), + baseInst, + key + )); + } } } else { - SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>"); + SLANG_UNEXPECTED("Unrecognized field. Cannot emit field accessor"); } return nullptr; } + + IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst) + { + return emitFieldAccessor(builder, baseInst, this->globalPrimalKey); + } + + IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst) + { + return emitFieldAccessor(builder, baseInst, this->globalDiffKey); + } + + void relocateNewTypes(IRBuilder* builder) + { + for (auto typeInst : generatedTypeList) + { + moveGlobalToBeforeUses(builder, typeInst); + } + } + + void _createGenericDiffPairType(IRBuilder* builder) + { + // Insert directly at top level (skip any generic scopes etc.) + auto insertLoc = builder->getInsertLoc(); + builder->setInsertInto(builder->getModule()->getModuleInst()); + + // Make a generic version of the pair struct. + auto irGeneric = builder->emitGeneric(); + irGeneric->setFullType(builder->getTypeKind()); + builder->setInsertInto(irGeneric); + + generatedTypeList.add(irGeneric); + + auto irBlock = builder->emitBlock(); + builder->setInsertInto(irBlock); + + auto pTypeParam = builder->emitParam(builder->getTypeType()); + builder->addNameHintDecoration(pTypeParam, UnownedTerminatedStringSlice("pT")); + + auto dTypeParam = builder->emitParam(builder->getTypeType()); + builder->addNameHintDecoration(dTypeParam, UnownedTerminatedStringSlice("dT")); + + auto irStructType = builder->createStructType(); + builder->emitReturn(irStructType); + + auto primalKey = _getOrCreatePrimalStructKey(builder); + builder->addNameHintDecoration(primalKey, UnownedTerminatedStringSlice("primal")); + builder->createStructField(irStructType, primalKey, (IRType*) pTypeParam); + + auto diffKey = _getOrCreateDiffStructKey(builder); + builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential")); + builder->createStructField(irStructType, diffKey, (IRType*) dTypeParam); + + // Reset cursor when done. + builder->setInsertLoc(insertLoc); + + this->genericDiffPairType = irGeneric; + } + + IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder) + { + if (!this->globalDiffKey) + { + // Insert directly at top level (skip any generic scopes etc.) + auto insertLoc = builder->getInsertLoc(); + builder->setInsertInto(builder->getModule()->getModuleInst()); + + this->globalDiffKey = builder->createStructKey(); + builder->addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential")); + + builder->setInsertLoc(insertLoc); + } + + return this->globalDiffKey; + } + + IRStructKey* _getOrCreatePrimalStructKey(IRBuilder* builder) + { + if (!this->globalPrimalKey) + { + // Insert directly at top level (skip any generic scopes etc.) + auto insertLoc = builder->getInsertLoc(); + builder->setInsertInto(builder->getModule()->getModuleInst()); + + this->globalPrimalKey = builder->createStructKey(); + builder->addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal")); + + builder->setInsertLoc(insertLoc); + } + + return this->globalPrimalKey; + } + + IRInst* _getOrCreateGenericDiffPairType(IRBuilder* builder) + { + if (!this->genericDiffPairType) + { + _createGenericDiffPairType(builder); + } + + SLANG_ASSERT(this->genericDiffPairType); + return this->genericDiffPairType; + } - IRStructType* _createDiffPairType(IRBuilder* builder, IRType* origBaseType) + IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType) { if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType)) { - auto diffPairType = builder->createStructType(); - - // Create a keys for the primal and differential fields. - IRStructKey* origKey = builder->createStructKey(); - builder->addNameHintDecoration(origKey, UnownedTerminatedStringSlice("primal")); - builder->createStructField(diffPairType, origKey, origBaseType); + SLANG_ASSERT(!as<IRParam>(origBaseType)); - IRStructKey* diffKey = builder->createStructKey(); - builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential")); - builder->createStructField(diffPairType, diffKey, (IRType*)(diffBaseType)); + auto pairStructType = builder->createStructType(); + builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType); + builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*) diffBaseType); - return diffPairType; + return pairStructType; } return nullptr; } - IRStructType* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType) + IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType) { if (pairTypeCache.ContainsKey(origBaseType)) return pairTypeCache[origBaseType]; @@ -328,10 +564,17 @@ struct DifferentialPairTypeBuilder return pairType; } - Dictionary<IRType*, IRStructType*> pairTypeCache; + Dictionary<IRInst*, IRInst*> pairTypeCache; DifferentiableTypeConformanceContext* diffConformanceContext; + + IRStructKey* globalPrimalKey = nullptr; + + IRStructKey* globalDiffKey = nullptr; + IRInst* genericDiffPairType = nullptr; + + List<IRInst*> generatedTypeList; }; struct JVPTranscriber @@ -341,6 +584,9 @@ struct JVPTranscriber // their differential values. Dictionary<IRInst*, IRInst*> instMapD; + // Set of insts currently being transcribed. Used to avoid infinite loops. + HashSet<IRInst*> instsInProgress; + // Cloning environment to hold mapping from old to new copies for the primal // instructions. IRCloneEnv cloneEnv; @@ -362,7 +608,17 @@ struct JVPTranscriber void mapDifferentialInst(IRInst* origInst, IRInst* diffInst) { - instMapD.Add(origInst, diffInst); + if (hasDifferentialInst(origInst)) + { + if (lookupDiffInst(origInst) != diffInst) + { + SLANG_UNEXPECTED("Inconsistent differential mappings"); + } + } + else + { + instMapD.Add(origInst, diffInst); + } } void mapPrimalInst(IRInst* origInst, IRInst* primalInst) @@ -439,6 +695,7 @@ struct JVPTranscriber for (UIndex i = 0; i < funcType->getParamCount(); i++) { auto origType = funcType->getParamType(i); + origType = (IRType*) lookupPrimalInst(origType, origType); if (auto diffPairType = tryGetDiffPairType(builder, origType)) newParameterTypes.add(diffPairType); else @@ -448,7 +705,8 @@ struct JVPTranscriber // Transcribe return type to a pair. // This will be void if the primal return type is non-differentiable. // - if (auto returnPairType = tryGetDiffPairType(builder, funcType->getResultType())) + auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType()); + if (auto returnPairType = tryGetDiffPairType(builder, origResultType)) diffReturnType = returnPairType; else diffReturnType = builder->getVoidType(); @@ -458,41 +716,101 @@ struct JVPTranscriber IRType* differentiateType(IRBuilder* builder, IRType* origType) { - switch (origType->getOp()) - { - case kIROp_HalfType: - case kIROp_FloatType: - case kIROp_DoubleType: - case kIROp_VectorType: - return (IRType*)(diffConformanceContext->getDifferentialForType(builder, origType)); - case kIROp_OutType: - return builder->getOutType(differentiateType(builder, as<IROutType>(origType)->getValueType())); - case kIROp_InOutType: - return builder->getInOutType(differentiateType(builder, as<IRInOutType>(origType)->getValueType())); - default: + if (auto ptrType = as<IRPtrTypeBase>(origType)) + return builder->getPtrType( + origType->getOp(), + differentiateType(builder, ptrType->getValueType())); + + // If there is an explicit primal version of this type in the local scope, load that + // otherwise use the original type. + // + IRInst* primalType = lookupPrimalInst(origType, origType); + + // Special case certain compound types (PtrType, FuncType, etc..) + // otherwise try to lookup a differential definition for the given type. + // If one does not exist, then we assume it's not differentiable. + // + switch (primalType->getOp()) + { + case kIROp_Param: + if (as<IRTypeType>(primalType->getDataType())) + return (IRType*)(diffConformanceContext->getDifferentialForType( + builder, + (IRType*)primalType)); + else if (as<IRWitnessTableType>(primalType->getDataType())) + return (IRType*)primalType; + + case kIROp_ArrayType: + { + auto primalArrayType = as<IRArrayType>(primalType); + if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) + return builder->getArrayType( + diffElementType, + primalArrayType->getElementCount()); + else + return nullptr; + } + + case kIROp_FuncType: + return differentiateFunctionType(builder, as<IRFuncType>(primalType)); + + case kIROp_OutType: + if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType())) + return builder->getOutType(diffValueType); + else + return nullptr; + + case kIROp_InOutType: + if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType())) + return builder->getInOutType(diffValueType); + else return nullptr; + + case kIROp_TupleType: + { + auto tupleType = as<IRTupleType>(primalType); + List<IRType*> diffTypeList; + // TODO: what if we have type parameters here? + for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) + diffTypeList.add( + differentiateType(builder, (IRType*)tupleType->getOperand(ii))); + + return builder->getTupleType(diffTypeList); + } + + default: + return (IRType*)(diffConformanceContext->getDifferentialForType(builder, (IRType*)primalType)); } } - IRType* tryGetDiffPairType(IRBuilder* builder, IRType* origType) + IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) { // If this is a PtrType (out, inout, etc..), then create diff pair from // value type and re-apply the appropropriate PtrType wrapper. // - if (auto origPtrType = as<IRPtrTypeBase>(origType)) + if (auto origPtrType = as<IRPtrTypeBase>(primalType)) { if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) - return builder->getPtrType(origType->getOp(), diffPairValueType); + return builder->getPtrType(primalType->getOp(), diffPairValueType); else return nullptr; } - return pairBuilder->getOrCreateDiffPairType(builder, origType); + return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType); } InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) { - if (auto diffPairType = tryGetDiffPairType(builder, origParam->getFullType())) + auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType()); + // Do not differentiate generic type (and witness table) parameters + if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType)) + { + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) { IRParam* diffPairParam = builder->emitParam(diffPairType); @@ -507,6 +825,7 @@ struct JVPTranscriber pairBuilder->emitDiffFieldAccess(builder, diffPairParam)); } + return InstPair( cloneInst(&cloneEnv, builder, origParam), nullptr); @@ -570,15 +889,13 @@ struct JVPTranscriber auto diffLeft = findOrTranscribeDiffInst(builder, origLeft); auto diffRight = findOrTranscribeDiffInst(builder, origRight); - auto leftZero = builder->getFloatValue(origLeft->getDataType(), 0.0); - auto rightZero = builder->getFloatValue(origRight->getDataType(), 0.0); if (diffLeft || diffRight) { - diffLeft = diffLeft ? diffLeft : leftZero; - diffRight = diffRight ? diffRight : rightZero; + diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType()); + diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType()); - auto resultType = origArith->getDataType(); + auto resultType = primalArith->getDataType(); switch(origArith->getOp()) { case kIROp_Add: @@ -608,17 +925,36 @@ struct JVPTranscriber return InstPair(primalArith, nullptr); } + + InstPair transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic) + { + SLANG_ASSERT(origLogic->getOperandCount() == 2); + + // TODO: Check other boolean cases. + if (as<IRBoolType>(origLogic->getDataType())) + { + // Boolean operations are not differentiable. For the linearization + // pass, we do not need to do anything but copy them over to the ne + // function. + auto primalLogic = cloneInst(&cloneEnv, builder, origLogic); + return InstPair(primalLogic, nullptr); + } + + SLANG_UNEXPECTED("Logical operation with non-boolean result"); + } + InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad) { auto origPtr = origLoad->getPtr(); auto primalLoad = cloneInst(&cloneEnv, builder, origLoad); + IRInst* diffLoad = nullptr; + if (auto diffPtr = lookupDiffInst(origPtr, nullptr)) { - IRLoad* diffLoad = as<IRLoad>(builder->emitLoad(diffPtr)); - SLANG_ASSERT(diffLoad); - + // Default case, we're loading from a known differential inst. + diffLoad = as<IRLoad>(builder->emitLoad(diffPtr)); return InstPair(primalLoad, diffLoad); } return InstPair(primalLoad, nullptr); @@ -634,15 +970,17 @@ struct JVPTranscriber auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr); auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr); + IRInst* diffStore = nullptr; + // If the stored value has a differential version, // emit a store instruction for the differential parameter. // Otherwise, emit nothing since there's nothing to load. // if (diffStoreLocation && diffStoreVal) { - IRStore* diffStore = as<IRStore>( - builder->emitStore(diffStoreLocation, diffStoreVal)); - SLANG_ASSERT(diffStore); + // Default case, storing the entire type (and not a member) + diffStore = as<IRStore>( + builder->emitStore(diffStoreLocation, diffStoreVal)); return InstPair(primalStore, diffStore); } @@ -653,14 +991,31 @@ struct JVPTranscriber InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn) { IRInst* origReturnVal = origReturn->getVal(); - - if (auto pairType = tryGetDiffPairType(builder, origReturnVal->getDataType())) + + auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType()); + if (as<IRFunc>(origReturnVal) || as<IRGeneric>(origReturnVal) || as<IRStructType>(origReturnVal) || as<IRFuncType>(origReturnVal)) + { + // If the return value is itself a function, generic or a struct then this + // is likely to be a generic scope. In this case, we lookup the differential + // and return that. + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); + + // Neither of these should be nullptr. + SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal); + IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal)); + + return InstPair(diffReturn, diffReturn); + } + else if (auto pairType = tryGetDiffPairType(builder, returnDataType)) { IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); - IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); if(!diffReturnVal) - diffReturnVal = getZeroOfType(builder, origReturnVal->getDataType()); + diffReturnVal = getDifferentialZeroOfType(builder, returnDataType); + + // If the pair type can be formed, this must be non-null. + SLANG_RELEASE_ASSERT(diffReturnVal); auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal); IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair)); @@ -668,10 +1023,12 @@ struct JVPTranscriber } else { - // If the differential return value is not available, emit a - // void return. - IRInst* voidReturn = builder->emitReturn(); - return InstPair(voidReturn, voidReturn); + // If the return type is not differentiable, emit the primal value only. + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + + IRInst* primalReturn = builder->emitReturn(primalReturnVal); + return InstPair(primalReturn, nullptr); + } } @@ -682,15 +1039,43 @@ struct JVPTranscriber InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct) { IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct); + + // Check if the output type can be differentiated. If it cannot be + // differentiated, don't differentiate the inst + // + auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType()); + if (auto diffConstructType = differentiateType(builder, primalConstructType)) + { + UCount operandCount = origConstruct->getOperandCount(); - if (as<IRConstant>(origConstruct->getOperand(0)) && origConstruct->getOperandCount() == 1) - return InstPair(primalConstruct, nullptr); + List<IRInst*> diffOperands; + for (UIndex ii = 0; ii < operandCount; ii++) + { + // If the operand has a differential version, replace the original with + // the differential. Otherwise, use a zero. + // + if (auto diffInst = lookupDiffInst(origConstruct->getOperand(ii), nullptr)) + diffOperands.add(diffInst); + else + { + auto operandDataType = origConstruct->getOperand(ii)->getDataType(); + operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType); + diffOperands.add(getDifferentialZeroOfType(builder, operandDataType)); + } + } + + return InstPair( + primalConstruct, + builder->emitIntrinsicInst( + diffConstructType, + origConstruct->getOp(), + operandCount, + diffOperands.getBuffer())); + } else - getSink()->diagnose(origConstruct->sourceLoc, - Diagnostics::unimplemented, - "this construct instruction cannot be differentiated"); - - return InstPair(primalConstruct, nullptr); + { + return InstPair(primalConstruct, nullptr); + } } // Differentiating a call instruction here is primarily about generating @@ -699,13 +1084,21 @@ struct JVPTranscriber // InstPair transcribeCall(IRBuilder* builder, IRCall* origCall) { - if (auto origCallee = as<IRFunc>(origCall->getCallee())) + + if (as<IRFunc>(origCall->getCallee())) { - + auto origCallee = origCall->getCallee(); + + // Since concrete functions are globals, the primal callee is the same + // as the original callee. + // + auto primalCallee = origCallee; + + // TODO: If inner is not differentiable, treat as non-differentiable call. // Build the differential callee IRInst* diffCall = builder->emitJVPDifferentiateInst( - differentiateFunctionType(builder, as<IRFuncType>(origCallee->getFullType())), - origCallee); + differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), + primalCallee); List<IRInst*> args; // Go over the parameter list and create pairs for each input (if required) @@ -715,17 +1108,17 @@ struct JVPTranscriber auto primalArg = findOrTranscribePrimalInst(builder, origArg); SLANG_ASSERT(primalArg); - auto origType = origArg->getDataType(); - if (auto pairType = tryGetDiffPairType(builder, origType)) + auto primalType = primalArg->getDataType(); + if (auto pairType = tryGetDiffPairType(builder, primalType)) { - auto diffArg = findOrTranscribeDiffInst(builder, origArg); - // TODO(sai): This part is flawed. Replace with a call to the - // 'zero()' interface method. if (!diffArg) - diffArg = getZeroOfType(builder, origType); + diffArg = getDifferentialZeroOfType(builder, primalType); + // If a pair type can be formed, this must be non-null. + SLANG_RELEASE_ASSERT(diffArg); + auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); args.add(diffPair); @@ -737,8 +1130,11 @@ struct JVPTranscriber } } + auto diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); + SLANG_ASSERT(diffReturnType); + auto callInst = builder->emitCallInst( - tryGetDiffPairType(builder, origCall->getFullType()), + diffReturnType, diffCall, args); @@ -746,6 +1142,13 @@ struct JVPTranscriber pairBuilder->emitPrimalFieldAccess(builder, callInst), pairBuilder->emitDiffFieldAccess(builder, callInst)); } + else if(as<IRSpecialize>(origCall->getCallee()) || + as<IRLookupWitnessMethod>(origCall->getCallee())) + { + getSink()->diagnose(origCall->sourceLoc, + Diagnostics::unimplemented, + "attempting to differentiate unspecialized callee or an interface method"); + } else { // Note that this can only happen if the callee is a result @@ -774,7 +1177,7 @@ struct JVPTranscriber return InstPair( primalSwizzle, builder->emitSwizzle( - differentiateType(builder, origSwizzle->getDataType()), + differentiateType(builder, primalSwizzle->getDataType()), diffBase, origSwizzle->getElementCount(), swizzleIndices.getBuffer())); @@ -806,7 +1209,7 @@ struct JVPTranscriber return InstPair( primalInst, builder->emitIntrinsicInst( - differentiateType(builder, origInst->getDataType()), + differentiateType(builder, primalInst->getDataType()), origInst->getOp(), operandCount, diffOperands.getBuffer())); @@ -819,17 +1222,44 @@ struct JVPTranscriber case kIROp_unconditionalBranch: auto origBranch = as<IRUnconditionalBranch>(origInst); - // Branches with extra operands not handled currently. - if (origBranch->getOperandCount() > 1) - break; + // Grab the differentials for any phi nodes. + List<IRInst*> pairArgs; + for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++) + { + auto origArg = origBranch->getArg(ii); - IRInst* diffBranch = nullptr; + IRInst* pairArg = nullptr; + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)origArg->getDataType())) + { + auto diffArg = lookupDiffInst(origArg, nullptr); + if (!diffArg) + { + diffArg = getDifferentialZeroOfType(builder, (IRType*)origArg->getDataType()); + } + + pairArg = builder->emitMakeDifferentialPair( + diffPairType, + lookupPrimalInst(origArg), + diffArg); + } + else + { + pairArg = lookupPrimalInst(origArg); + } + pairArgs.add(pairArg); + } - if (auto diffBlock = lookupDiffInst(origBranch->getTargetBlock(), nullptr)) - diffBranch = builder->emitBranch(as<IRBlock>(diffBlock)); + IRInst* diffBranch = nullptr; + if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock())) + { + diffBranch = builder->emitBranch( + as<IRBlock>(diffBlock), + pairArgs.getCount(), + pairArgs.getBuffer()); + } // For now, every block in the original fn must have a corresponding - // block to compute both primals and derivatives. + // block to compute *both* primals and derivatives (i.e linearized block) SLANG_ASSERT(diffBranch); return InstPair(diffBranch, diffBranch); @@ -843,12 +1273,13 @@ struct JVPTranscriber return InstPair(nullptr, nullptr); } - InstPair transcribeConst(IRBuilder*, IRInst* origInst) { switch(origInst->getOp()) { case kIROp_FloatLit: + case kIROp_VoidLit: + case kIROp_IntLit: return InstPair(origInst, nullptr); } @@ -860,49 +1291,439 @@ struct JVPTranscriber return InstPair(nullptr, nullptr); } + InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) + { + // This is slightly counter-intuitive, but we don't perform any differentiation + // logic here. We simple clone the original specialize which points to the original function, + // or the cloned version in case we're inside a generic scope. + // The differentiation logic is inserted later when this is used in an IRCall. + // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Specialize(Fn)) + // rather than have Specialize(JVPDifferentiate(Fn)) + // + auto diffSpecialize = cloneInst(&cloneEnv, builder, origSpecialize); + return InstPair(diffSpecialize, diffSpecialize); + } + + InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup) + { + // This is slightly counter-intuitive, but we don't perform any differentiation + // logic here. We simple clone the original lookup which points to the original function, + // or the cloned version in case we're inside a generic scope. + // The differentiation logic is inserted later when this is used in an IRCall. + // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Lookup(Table)) + // rather than have Lookup(JVPDifferentiate(Table)) + // + auto diffLookup = cloneInst(&cloneEnv, builder, origLookup); + return InstPair(diffLookup, diffLookup); + } + // In differential computation, the 'default' differential value is always zero. // This is a consequence of differential computing being inherently linear. As a // result, it's useful to have a method to generate zero literals of any (arithmetic) type. + // The current implementation requires that types are defined linearly. // - IRInst* getZeroOfType(IRBuilder* builder, IRType* type) + IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) + { + if (auto diffType = differentiateType(builder, primalType)) + { + // Since primalType has a corresponding differential type, we can lookup the + // definition for zero(). + auto zeroMethod = this->diffConformanceContext->getZeroMethodForType(builder, primalType); + SLANG_ASSERT(zeroMethod); + + auto emptyArgList = List<IRInst*>(); + return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); + } + else + { + // We special case a few non-differentiable types that sometimes appear in places + // where we're forced to provide a differential zero value. For instance, + // float3(float, float, int) is accepted by the compiler, but is tricky in the context + // of differentiation since int is non-differentiable, and should be cast to float first. + // In the absence of such casts, this piece of code generates appropriate zero values. + // + switch (primalType->getOp()) + { + case kIROp_IntType: + return builder->getIntValue(primalType, 0); + default: + getSink()->diagnose(primalType->sourceLoc, + Diagnostics::internalCompilerError, + "could not generate zero value for given type"); + return nullptr; + } + } + } + + InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) + { + auto oldLoc = builder->getInsertLoc(); + + IRInst* diffBlock = builder->emitBlock(); + + // Note: for blocks, we setup the mapping _before_ + // processing the children since we could encounter + // a lookup while processing the children. + // + mapPrimalInst(origBlock, diffBlock); + mapDifferentialInst(origBlock, diffBlock); + + builder->setInsertInto(diffBlock); + + // First transcribe every parameter in the block. + for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) + this->transcribe(builder, param); + + // Look for the differentiable type dictionary and clone it (and anything else we might need). + // TODO: This logic might have issues if there are additional instructions (say lookup_interface_requirement) + // that are operands. + // TODO: This is currently cloning the global dictionary. Should only clone dictionaries in generic blocks. + if (auto origDict = builder->findDifferentiableTypeDictionary(origBlock)) + { + auto clonedDict = cloneInst(&cloneEnv, builder, origDict); + mapPrimalInst(origDict, clonedDict); + mapDifferentialInst(origDict, clonedDict); + } + + // Then, run through every instruction and use the transcriber to generate the appropriate + // derivative code. + // + for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) + this->transcribe(builder, child); + + builder->setInsertLoc(oldLoc); + + return InstPair(diffBlock, diffBlock); + } + + InstPair transcribeFieldExtract(IRBuilder* builder, IRFieldExtract* origExtract) { - switch (type->getOp()) + IRInst* origBase = origExtract->getBase(); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto diffBase = findOrTranscribeDiffInst(builder, origBase); + + auto primalExtractType = (IRType*)lookupPrimalInst(origExtract->getDataType(), origExtract->getDataType()); + + IRInst* primalExtract = builder->emitFieldExtract(primalExtractType, primalBase, origExtract->getField()); + IRInst* diffExtract = nullptr; + + if (auto diffExtractType = differentiateType(builder, primalExtractType)) { - case kIROp_FloatType: - case kIROp_HalfType: - case kIROp_DoubleType: - return builder->getFloatValue(type, 0.0); - case kIROp_IntType: - return builder->getIntValue(type, 0); - case kIROp_VectorType: + // Check if we have a getter. + if (auto getterDecoration = origExtract->findDecoration<IRDifferentialGetterDecoration>()) { - IRInst* args[] = {getZeroOfType(builder, as<IRVectorType>(type)->getElementType())}; - return builder->emitIntrinsicInst( - type, - kIROp_constructVectorFromScalar, - 1, + + IRInst* getterFunc = getterDecoration->getGetterFunc(); + + // Must be a method with a single parameter. + SLANG_ASSERT(as<IRFuncType>(getterFunc->getDataType())->getParamCount() == 1); + + // Our getter func accepts a _pointer_ to the target type + // So we have to create a variable and store our type into memory + // here. This will eventually get optimized out in later passes. + // + auto diffTempVar = builder->emitVar( + diffBase->getDataType()); + + builder->emitStore(diffTempVar, diffBase); + + List<IRInst*> args; + args.add(diffTempVar); + + // Emit a call to the getter. The getter will return a reference type. + // We need to load from this to go to a non-ptr 'solid' type. + // + auto diffGetterCall = builder->emitCallInst( + as<IRFuncType>(getterFunc->getDataType())->getResultType(), + getterFunc, args); + + diffExtract = builder->emitLoad(diffGetterCall); } - default: - getSink()->diagnose(type->sourceLoc, - Diagnostics::internalCompilerError, - "could not generate zero value for given type"); - return nullptr; } + + return InstPair(primalExtract, diffExtract); + } + + InstPair transcribeFieldAddress(IRBuilder* builder, IRFieldAddress* origAddress) + { + IRInst* origBase = origAddress->getBase(); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto diffBase = findOrTranscribeDiffInst(builder, origBase); + + auto primalAddressType = (IRType*)lookupPrimalInst(origAddress->getDataType(), origAddress->getDataType()); + + IRInst* primalAddress = builder->emitFieldAddress(primalAddressType, primalBase, origAddress->getField()); + IRInst* diffAddress = nullptr; + + if (auto diffAddressType = differentiateType(builder, primalAddressType)) + { + // If we have a getter associated with this field, we want to use that. + if (auto getterDecoration = origAddress->findDecoration<IRDifferentialGetterDecoration>()) + { + auto getterFunc = getterDecoration->getGetterFunc(); + + // Add the base differential inst as the argument. + List<IRInst*> args; + args.add(diffBase); + + diffAddress = builder->emitCallInst( + as<IRFuncType>(getterFunc->getDataType())->getResultType(), + getterFunc, + args); + } + + } + + return InstPair(primalAddress, diffAddress); + } + + + InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr) + { + SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr)); + + IRInst* origBase = origGetElementPtr->getOperand(0); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1)); + + auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType()); + + IRInst* primalOperands[] = {primalBase, primalIndex}; + IRInst* primalGetElementPtr = builder->emitIntrinsicInst( + primalType, + origGetElementPtr->getOp(), + 2, + primalOperands); + + IRInst* diffGetElementPtr = nullptr; + + if (auto diffType = differentiateType(builder, primalType)) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + { + IRInst* diffOperands[] = {diffBase, primalIndex}; + diffGetElementPtr = builder->emitIntrinsicInst( + diffType, + origGetElementPtr->getOp(), + 2, + diffOperands); + } + } + + return InstPair(primalGetElementPtr, diffGetElementPtr); + } + + + InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop) + { + // The loop comes with three blocks.. we just need to transcribe each one + // and assemble the new loop instruction. + + // Transcribe the target block (this is the 'condition' part of the loop, which + // will branch into the loop body) + auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock()); + + // Transcribe the break block (this is the block after the exiting the loop) + auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); + + // Transcribe the continue block (this is the 'update' part of the loop, which will + // branch into the condition block) + auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); + + + List<IRInst*> diffLoopOperands; + diffLoopOperands.add(diffTargetBlock); + diffLoopOperands.add(diffBreakBlock); + diffLoopOperands.add(diffContinueBlock); + + // If there are any other operands, use their primal versions. + for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++) + { + auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii)); + diffLoopOperands.add(primalOperand); + } + + IRInst* diffLoop = builder->emitIntrinsicInst( + nullptr, + kIROp_loop, + diffLoopOperands.getCount(), + diffLoopOperands.getBuffer()); + + return InstPair(diffLoop, diffLoop); + } + + InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) + { + // The loop comes with three blocks.. we just need to transcribe each one + // and assemble the new loop instruction. + + // Transcribe the target block (this is the 'condition' part of the loop, which + // will branch into the loop body). + // Note that for the condition we use the primal inst (condition values should not have a + // differential) + auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition()); + SLANG_ASSERT(primalConditionBlock); + + // Transcribe the break block (this is the block after the exiting the loop) + auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock()); + SLANG_ASSERT(diffTrueBlock); + + // Transcribe the continue block (this is the 'update' part of the loop, which will + // branch into the condition block) + auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock()); + SLANG_ASSERT(diffFalseBlock); + + // Transcribe the continue block (this is the 'update' part of the loop, which will + // branch into the condition block) + auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock()); + SLANG_ASSERT(diffAfterBlock); + + + List<IRInst*> diffIfElseArgs; + diffIfElseArgs.add(primalConditionBlock); + diffIfElseArgs.add(diffTrueBlock); + diffIfElseArgs.add(diffFalseBlock); + diffIfElseArgs.add(diffAfterBlock); + + // If there are any other operands, use their primal versions. + for (UIndex ii = diffIfElseArgs.getCount(); ii < origIfElse->getOperandCount(); ii++) + { + auto primalOperand = findOrTranscribePrimalInst(builder, origIfElse->getOperand(ii)); + diffIfElseArgs.add(primalOperand); + } + + IRInst* diffLoop = builder->emitIntrinsicInst( + nullptr, + kIROp_ifElse, + diffIfElseArgs.getCount(), + diffIfElseArgs.getBuffer()); + + return InstPair(diffLoop, diffLoop); + } + + // Transcribe a function definition. + InstPair transcribeFunc(IRBuilder* builder, IRFunc* origFunc) + { + IRFunc* primalFunc = nullptr; + + auto oldLoc = builder->getInsertLoc(); + + // If this is a top-level function, there is no need to clone it + // since it is visible in all the scopes. + // Otherwise, we need to clone it in case of generic scopes. + // + // TODO(sai): Is this the correct thing to do? Can a function cloned inside a + // generic scope but is not the return value of that generic, be used within + // that scope? Or do we have to call out to the original generic specialized with + // the current generic params? + // + bool isTopLevelFunc = (as<IRModuleInst>(origFunc->parent) != nullptr); + if (isTopLevelFunc) + { + builder->setInsertBefore(origFunc); + primalFunc = origFunc; + } + else + { + // TODO(sai): this might never be called, and it might never make sense + // to call it either. Potentially remove this. + primalFunc = as<IRFunc>( + cloneInst(&cloneEnv, builder, origFunc)); + } + + auto diffFunc = builder->createFunc(); + + SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); + IRType* diffFuncType = this->differentiateFunctionType( + builder, + as<IRFuncType>(origFunc->getFullType())); + diffFunc->setFullType(diffFuncType); + + // TODO(sai): Replace naming scheme + // if (auto jvpName = this->getJVPFuncName(builder, primalFn)) + // builder->addNameHintDecoration(diffFunc, jvpName); + + // Transcribe children from origFunc into diffFunc + builder->setInsertInto(diffFunc); + for (auto block = origFunc->getFirstBlock(); block; block = block->getNextBlock()) + this->transcribe(builder, block); + + // Reset builder position + builder->setInsertLoc(oldLoc); + + return InstPair(primalFunc, diffFunc); + } + + // Transcribe a generic definition + InstPair transcribeGeneric(IRBuilder* builder, IRGeneric* origGeneric) + { + // For now, we assume there's only one generic layer. So this inst must be top level + bool isTopLevel = (as<IRModuleInst>(origGeneric->getParent()) != nullptr); + SLANG_RELEASE_ASSERT(isTopLevel); + + IRGeneric* primalGeneric = origGeneric; + + auto oldLoc = builder->getInsertLoc(); + builder->setInsertBefore(origGeneric); + + auto diffGeneric = builder->emitGeneric(); + + // Process type of generic. If the generic is a function, then it's type will also be a + // generic and this logic will transcribe that generic first before continuing with the + // function itself. + // + auto primalType = primalGeneric->getFullType(); + + IRType* diffType = nullptr; + if (primalType) + { + diffType = (IRType*) findOrTranscribeDiffInst(builder, primalType); + } + + diffGeneric->setFullType(diffType); + + // TODO(sai): Replace naming scheme + // if (auto jvpName = this->getJVPFuncName(builder, primalFn)) + // builder->addNameHintDecoration(diffFunc, jvpName); + + // Transcribe children from origFunc into diffFunc. + builder->setInsertInto(diffGeneric); + for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) + this->transcribe(builder, block); + + // Reset builder position. + builder->setInsertLoc(oldLoc); + + return InstPair(primalGeneric, diffGeneric); } IRInst* transcribe(IRBuilder* builder, IRInst* origInst) { + // If a differential intstruction is already mapped for + // this original inst, return that. + // + if (auto diffInst = lookupDiffInst(origInst, nullptr)) + { + SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check. + return diffInst; + } + + // Otherwise, dispatch to the appropriate method + // depending on the op-code. + // + instsInProgress.Add(origInst); InstPair pair = transcribeInst(builder, origInst); if (auto primalInst = pair.primal) { mapPrimalInst(origInst, pair.primal); - mapDifferentialInst(origInst, pair.differential); return pair.differential; } + instsInProgress.Remove(origInst); + getSink()->diagnose(origInst->sourceLoc, Diagnostics::internalCompilerError, "failed to transcibe instruction"); @@ -911,7 +1732,7 @@ struct JVPTranscriber InstPair transcribeInst(IRBuilder* builder, IRInst* origInst) { - // Handle common operations + // Handle common SSA-style operations switch (origInst->getOp()) { case kIROp_Param: @@ -934,6 +1755,14 @@ struct JVPTranscriber case kIROp_Sub: case kIROp_Div: return transcribeBinaryArith(builder, origInst); + + case kIROp_Less: + case kIROp_Greater: + case kIROp_And: + case kIROp_Or: + case kIROp_Geq: + case kIROp_Leq: + return transcribeBinaryLogic(builder, origInst); case kIROp_Construct: return transcribeConstruct(builder, origInst); @@ -945,24 +1774,91 @@ struct JVPTranscriber return transcribeSwizzle(builder, as<IRSwizzle>(origInst)); case kIROp_constructVectorFromScalar: + case kIROp_MakeTuple: return transcribeByPassthrough(builder, origInst); case kIROp_unconditionalBranch: - case kIROp_conditionalBranch: return transcribeControlFlow(builder, origInst); case kIROp_FloatLit: + case kIROp_IntLit: + case kIROp_VoidLit: return transcribeConst(builder, origInst); + case kIROp_Specialize: + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::unexpected, + "should not be attempting to differentiate anything specialized here."); + + case kIROp_lookup_interface_method: + return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst)); + + case kIROp_FieldExtract: + return transcribeFieldExtract(builder, as<IRFieldExtract>(origInst)); + + case kIROp_FieldAddress: + return transcribeFieldAddress(builder, as<IRFieldAddress>(origInst)); + + case kIROp_getElement: + case kIROp_getElementPtr: + return transcribeGetElement(builder, origInst); + + case kIROp_loop: + return transcribeLoop(builder, as<IRLoop>(origInst)); + + case kIROp_ifElse: + return transcribeIfElse(builder, as<IRIfElse>(origInst)); + + case kIROp_DifferentiableTypeDictionary: + // Ignore dictionary insts. + return InstPair(nullptr, nullptr); + } // If none of the cases have been hit, check if the instruction is a - // type. - // For now we don't have logic to differentiate types that appear in blocks. - // So, we clone and avoid differentiating them. - // + // type. Only need to explicitly differentiate types if they appear inside a block. + // if (auto origType = as<IRType>(origInst)) - return InstPair(cloneInst(&cloneEnv, builder, origType), nullptr); + { + // If this is a generic type, transcibe the parent + // generic and derive the type from the transcribed generic's + // return value. + // + if (as<IRGeneric>(origType->getParent()->getParent()) && + findInnerMostGenericReturnVal(as<IRGeneric>(origType->getParent()->getParent())) == origType && + !instsInProgress.Contains(origType->getParent()->getParent())) + { + auto origGenericType = origType->getParent()->getParent(); + auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType); + auto innerDiffGenericType = findInnerMostGenericReturnVal(as<IRGeneric>(diffGenericType)); + return InstPair( + origGenericType, + innerDiffGenericType + ); + } + else if (as<IRBlock>(origType->getParent())) + return InstPair( + cloneInst(&cloneEnv, builder, origType), + differentiateType(builder, origType)); + else + return InstPair( + cloneInst(&cloneEnv, builder, origType), + nullptr); + } + + // Handle instructions with children + switch (origInst->getOp()) + { + case kIROp_Func: + return transcribeFunc(builder, as<IRFunc>(origInst)); + + case kIROp_Block: + return transcribeBlock(builder, as<IRBlock>(origInst)); + + case kIROp_Generic: + return transcribeGeneric(builder, as<IRGeneric>(origInst)); + } + // If we reach this statement, the instruction type is likely unhandled. getSink()->diagnose(origInst->sourceLoc, @@ -1042,6 +1938,14 @@ struct JVPDerivativeContext // IRMakeDifferentialPair with an IRMakeStruct. // modified |= processPairTypes(builder, module->getModuleInst(), (&diffConformanceContextStorage)); + + // Temporary fix: Move generated types, if any, to before their use locations. + (&pairBuilderStorage)->relocateNewTypes(builder); + + // Remove all kIROp_DifferentiableTypeDictionary instructions and + // kIROp_DifferentialGetterDecoration decorations + // + modified |= stripDiffTypeInformation(builder, module->getModuleInst()); return modified; } @@ -1079,19 +1983,45 @@ struct JVPDerivativeContext if (auto jvpDiffInst = as<IRJVPDifferentiate>(child)) { - auto baseFunction = jvpDiffInst->getBaseFn(); + auto baseInst = jvpDiffInst->getBaseFn(); + + IRGlobalValueWithCode* baseFunction = nullptr; + + if (auto specializeInst = as<IRSpecialize>(baseInst)) + { + baseFunction = as<IRGlobalValueWithCode>(specializeInst->getBase()); + } + else if (auto globalValWithCode = as<IRGlobalValueWithCode>(baseInst)) + { + baseFunction = globalValWithCode; + } + + SLANG_ASSERT(baseFunction); + // If the JVP Reference already exists, no need to // differentiate again. // - if(lookupJVPReference(baseFunction)) continue; + if (lookupJVPReference(baseFunction)) continue; - if (isFunctionMarkedForJVP(as<IRGlobalValueWithCode>(baseFunction))) + if (isMarkedForJVP(baseFunction)) { - IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(baseFunction)); - builder->addJVPDerivativeReferenceDecoration(baseFunction, jvpFunction); - workQueue->push(jvpFunction); + if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) + { + IRInst* diffFunc = (&transcriberStorage)->transcribe(builder, baseFunction); + SLANG_ASSERT(diffFunc); + builder->addJVPDerivativeReferenceDecoration(baseFunction, diffFunc); + workQueue->push(diffFunc); + } + else + { + // TODO(Sai): This would probably be better with a more specific + // error code. + getSink()->diagnose(jvpDiffInst->sourceLoc, + Diagnostics::internalCompilerError, + "Unexpected instruction. Expected func or generic"); + } } - else + else { // TODO(Sai): This would probably be better with a more specific // error code. @@ -1106,55 +2036,33 @@ struct JVPDerivativeContext return true; } - // Run through all the global-level instructions, - // looking for callables. - // Note: We're only processing global callables (IRGlobalValueWithCode) - // for now. - // - bool processMarkedGlobalFunctions(IRBuilder* builder) + IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext*) { - for (auto inst : module->getGlobalInsts()) + + if (auto pairType = as<IRDifferentialPairType>(type)) { - // If the instr is a callable, get all the basic blocks - if (auto callable = as<IRGlobalValueWithCode>(inst)) - { - if (isFunctionMarkedForJVP(callable)) - { - SLANG_ASSERT(as<IRFunc>(callable)); + builder->setInsertBefore(pairType); - IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(callable)); - builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction); + auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType( + builder, + pairType->getValueType()); - unmarkForJVP(callable); - } - } - } - return true; - } + pairType->replaceUsesWith(diffPairStructType); + pairType->removeAndDeallocate(); - IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext* diffContext) - { - if (diffContext->isInterfaceAvailable) + return diffPairStructType; + } + else if (auto loweredStructType = as<IRStructType>(type)) { - if (auto pairType = as<IRDifferentialPairType>(type)) - { - builder->setInsertBefore(pairType); - - auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType( - builder, - pairType->getValueType()); - - pairType->replaceUsesWith(diffPairStructType); - pairType->removeAndDeallocate(); - - return diffPairStructType; - } - else if (auto loweredStructType = as<IRStructType>(type)) - { - // Already lowered to struct. - return loweredStructType; - } + // Already lowered to struct. + return loweredStructType; } + else if (auto specializedStructType = as<IRSpecialize>(type)) + { + // Already lowered to specialized struct. + return specializedStructType; + } + return nullptr; } @@ -1171,7 +2079,7 @@ struct JVPDerivativeContext operands.add(makePairInst->getPrimalValue()); operands.add(makePairInst->getDifferentialValue()); - auto makeStructInst = builder->emitMakeStruct(as<IRStructType>(diffPairStructType), operands); + auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands); makePairInst->replaceUsesWith(makeStructInst); makePairInst->removeAndDeallocate(); @@ -1258,10 +2166,43 @@ struct JVPDerivativeContext return modified; } + bool stripDiffTypeInformation(IRBuilder* builder, IRInst* parent) + { + bool modified = false; + + auto child = parent->getFirstChild(); + while (child) + { + auto nextChild = child->getNextInst(); + + if (child->getOp() == kIROp_DifferentiableTypeDictionary) + { + child->removeAndDeallocate(); + child = nextChild; + modified = true; + continue; + } + + if (auto getterDecoration = child->findDecoration<IRDifferentialGetterDecoration>()) + { + getterDecoration->removeAndDeallocate(); + } + + if (child->getFirstChild() != nullptr) + { + modified |= stripDiffTypeInformation(builder, child); + } + + child = nextChild; + } + + return modified; + } + // Checks decorators to see if the function should // be differentiated (kIROp_JVPDerivativeMarkerDecoration) // - bool isFunctionMarkedForJVP(IRGlobalValueWithCode* callable) + bool isMarkedForJVP(IRGlobalValueWithCode* callable) { for(auto decoration = callable->getFirstDecoration(); decoration; @@ -1292,63 +2233,8 @@ struct JVPDerivativeContext } } - List<IRParam*> emitFuncParameters(IRBuilder* builder, IRFuncType* dataType) - { - List<IRParam*> params; - for(UIndex i = 0; i < dataType->getParamCount(); i++) - { - params.add( - builder->emitParam(dataType->getParamType(i))); - } - return params; - } - - // Perform forward-mode automatic differentiation on - // the intstructions. - // - IRFunc* emitJVPFunction(IRBuilder* builder, - IRFunc* primalFn) - { - eliminatePhisInFunc(LivenessMode::Disabled, module, primalFn); - - builder->setInsertBefore(primalFn->getNextInst()); - - auto jvpFn = builder->createFunc(); - - SLANG_ASSERT(as<IRFuncType>(primalFn->getFullType())); - IRType* jvpFuncType = transcriberStorage.differentiateFunctionType( - builder, - as<IRFuncType>(primalFn->getFullType())); - jvpFn->setFullType(jvpFuncType); - - if (auto jvpName = getJVPFuncName(builder, primalFn)) - builder->addNameHintDecoration(jvpFn, jvpName); - - builder->setInsertInto(jvpFn); - - // Emit a block instruction for every block in the function, and map it as the - // corresponding differential. - // - for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock()) - { - auto jvpBlock = builder->emitBlock(); - transcriberStorage.mapDifferentialInst(block, jvpBlock); - transcriberStorage.mapPrimalInst(block, jvpBlock); - } - - // Go back over the blocks, and process the children of each block. - for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock()) - { - auto jvpBlock = as<IRBlock>(transcriberStorage.lookupDiffInst(block, block)); - SLANG_ASSERT(jvpBlock); - emitJVPBlock(builder, block, jvpBlock); - } - - return jvpFn; - } - IRStringLit* getJVPFuncName(IRBuilder* builder, - IRFunc* func) + IRInst* func) { auto oldLoc = builder->getInsertLoc(); builder->setInsertBefore(func); @@ -1368,36 +2254,6 @@ struct JVPDerivativeContext return name; } - IRBlock* emitJVPBlock(IRBuilder* builder, - IRBlock* origBlock, - IRBlock* jvpBlock = nullptr) - { - JVPTranscriber* transcriber = &(transcriberStorage); - - // Create if not already created, and then insert into new block. - if (!jvpBlock) - jvpBlock = builder->emitBlock(); - else - builder->setInsertInto(jvpBlock); - - - // First transcribe every parameter in the block. - for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) - { - transcriber->transcribe(builder, param); - } - - // Then, run through every instruction and use the transcriber to generate the appropriate - // derivative code. - // - for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) - { - transcriber->transcribe(builder, child); - } - - return jvpBlock; - } - JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : module(module), sink(sink), diffConformanceContextStorage(module->getModuleInst()), diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 8f8261af5..f91fc9cda 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -706,6 +706,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// generated derivative function. INST(JVPDerivativeReferenceDecoration, jvpFnReference, 1, 0) + /// Used by the auto-diff pass to hold a reference to a + /// differential getter associated with this expression. + INST(DifferentialGetterDecoration, diffGetter, 1, 0) + /// Marks a class type as a COM interface implementation, which enables /// the witness table to be easily picked up by emit. INST(COMWitnessDecoration, COMWitnessDecoration, 1, 0) @@ -805,6 +809,10 @@ INST(GenericSpecializationDictionary, GenericSpecializationDictionary, 0, PARENT INST(ExistentialFuncSpecializationDictionary, ExistentialFuncSpecializationDictionary, 0, PARENT) INST(ExistentialTypeSpecializationDictionary, ExistentialTypeSpecializationDictionary, 0, PARENT) +/* Differentiable Type Dictionary */ +INST(DifferentiableTypeDictionary, DifferentiableTypeDictionary, 0, PARENT) +INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0) + #undef PARENT #undef USE_OTHER #undef INST_RANGE diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 98bc6a0a2..33a2fbfb0 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -554,9 +554,19 @@ struct IRJVPDerivativeReferenceDecoration : IRDecoration }; IR_LEAF_ISA(JVPDerivativeReferenceDecoration) - IRFunc* getJVPFunc() { return as<IRFunc>(getOperand(0)); } + IRInst* getJVPFunc() { return getOperand(0); } }; +struct IRDifferentialGetterDecoration : IRDecoration +{ + enum + { + kOp = kIROp_DifferentialGetterDecoration + }; + IR_LEAF_ISA(DifferentialGetterDecoration) + + IRInst* getGetterFunc() { return getOperand(0); } +}; // An instruction that replaces the function symbol // with it's derivative function. @@ -573,6 +583,15 @@ struct IRJVPDifferentiate : IRInst IR_LEAF_ISA(JVPDifferentiate) }; +// Dictionary item mapping a type with a corresponding +// IDifferentiable witness table +// +struct IRDifferentiableTypeDictionaryItem : IRInst +{ + IR_LEAF_ISA(DifferentiableTypeDictionaryItem) +}; + + // An instruction that specializes another IR value // (representing a generic) to a particular set of generic arguments // (instructions representing types, witness tables, etc.) @@ -2462,6 +2481,27 @@ public: IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); + // Emit and return a dictionary instruction to the global or generic scope. + IRInst* emitDifferentiableTypeDictionary(); + + // Emit and return a dictionary instruction to the global or generic scope, + // if one is not already present. + // + IRInst* findOrEmitDifferentiableTypeDictionary(); + + // Returns the IRDifferentiableTypeDictionary in the scope of inst. + IRInst* findDifferentiableTypeDictionary(IRInst* inst); + + // Add a differentiable type entry to the appropriate dictionary. + IRInst* addDifferentiableTypeEntry(IRInst* irType, IRInst* conformanceWitness); + + // Lookup a differentiable type entry in the appropriate dictionary. + // This recursively looks up in upper contexts. + // + IRInst* findDifferentiableTypeEntry(IRInst* irType); + + IRInst* findDifferentiableTypeEntry(IRInst* irType, IRInst* scope); + IRInst* emitSpecializeInst( IRType* type, IRInst* genericVal, @@ -3162,6 +3202,11 @@ public: addDecoration(value, kIROp_JVPDerivativeReferenceDecoration, jvpFn); } + void addDifferentialGetterDecoration(IRInst* value, IRInst* getterFn) + { + addDecoration(value, kIROp_DifferentialGetterDecoration, getterFn); + } + void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable) { addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1); diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index a5130e8b6..56688abae 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -238,6 +238,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) case kIROp_WitnessTable: case kIROp_InterfaceType: case kIROp_TaggedUnionType: + case kIROp_DifferentiableTypeDictionary: return cloneGlobalValue(this, originalValue); case kIROp_BoolLit: @@ -592,6 +593,24 @@ IRWitnessTable* cloneWitnessTableImpl( return clonedTable; } +IRInst* cloneDifferentiableTypeDictionary( + IRSpecContextBase* context, + IRBuilder* builder, + IRInst* originalDict, + IROriginalValuesForClone const& originalValues, + IRInst* dstDict = nullptr, + bool registerValue = true) +{ + IRInst* clonedDict = dstDict; + if (!clonedDict) + { + clonedDict = builder->emitDifferentiableTypeDictionary(); + } + cloneSimpleGlobalValueImpl(context, originalDict, originalValues, clonedDict, registerValue); + return clonedDict; +} + + IRWitnessTable* cloneWitnessTableWithoutRegistering( IRSpecContextBase* context, IRBuilder* builder, @@ -1118,6 +1137,9 @@ IRInst* cloneInst( case kIROp_GlobalGenericParam: return cloneGlobalGenericParamImpl(context, builder, cast<IRGlobalGenericParam>(originalInst), originalValues); + + case kIROp_DifferentiableTypeDictionary: + return cloneDifferentiableTypeDictionary(context, builder, originalInst, originalValues); default: break; @@ -1504,11 +1526,14 @@ LinkedIR linkIR( { for (auto inst : irModule->getGlobalInsts()) { - auto bindInst = as<IRBindGlobalGenericParam>(inst); - if (!bindInst) - continue; - - cloneValue(context, bindInst); + if (auto bindInst = as<IRBindGlobalGenericParam>(inst)) + { + cloneValue(context, bindInst); + } + else if (inst->getOp() == kIROp_DifferentiableTypeDictionary) + { + cloneValue(context, inst); + } } } diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index a496db3a8..05be164d4 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -318,10 +318,17 @@ IRInst* applyAccessChain( auto fieldKey = accessChain->getOperand(1); auto type = cast<IRPtrTypeBase>(accessChain->getDataType())->getValueType(); auto baseValue = applyAccessChain(context, builder, baseChain, leafVarValue); - return builder->emitFieldExtract( + auto extractInst = builder->emitFieldExtract( type, baseValue, fieldKey); + + for (auto decoration : accessChain->getDecorations()) + { + cloneDecoration(decoration, extractInst); + } + + return extractInst; } case kIROp_getElementPtr: diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 46d6d445d..2aaeb4ac3 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3547,6 +3547,125 @@ namespace Slang } } + + IRInst* IRBuilder::emitDifferentiableTypeDictionary() + { + auto inst = createInst<IRInst>( + this, + kIROp_DifferentiableTypeDictionary, + nullptr); + + addGlobalValue(this, inst); + return inst; + } + + IRInst* IRBuilder::findOrEmitDifferentiableTypeDictionary() + { + auto currentLoc = this->getInsertLoc(); + auto currentInst = currentLoc.getInst(); + + if (auto diffTypeDictionary = findDifferentiableTypeDictionary(currentInst)) + return diffTypeDictionary; + + return emitDifferentiableTypeDictionary(); + } + + IRInst* IRBuilder::findDifferentiableTypeDictionary(IRInst* parent) + { + //auto parent = inst->getParent(); + while (parent) + { + // Inserting into the top level of a module? + // That is fine, and we can stop searching. + if (as<IRModuleInst>(parent)) + break; + + // Inserting into a basic block inside of + // a generic? That is okay too. + if (auto block = as<IRBlock>(parent)) + { + if (as<IRGeneric>(block->parent)) + break; + } + + // Otherwise, move up the chain. + parent = parent->parent; + } + + for (auto child = parent->getFirstChild(); child; child = child->getNextInst()) + { + if (child->getOp() == kIROp_DifferentiableTypeDictionary) + return child; + } + + return nullptr; + } + + IRInst* IRBuilder::addDifferentiableTypeEntry(IRInst* irType, IRInst* conformanceWitness) + { + auto oldLoc = this->getInsertLoc(); + + IRDifferentiableTypeDictionaryItem* item = nullptr; + + if (auto diffTypeDictionary = findOrEmitDifferentiableTypeDictionary()) + { + this->setInsertInto(diffTypeDictionary); + + IRInst* args[2] = {irType, conformanceWitness}; + item = createInstWithTrailingArgs<IRDifferentiableTypeDictionaryItem>( + this, + kIROp_DifferentiableTypeDictionaryItem, + nullptr, + 2, + args); + + addInst(item); + } + + this->setInsertLoc(oldLoc); + + return item; + } + + IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType, IRInst* scope) + { + for (auto child = scope->getFirstChild(); child; child = child->getNextInst()) + { + if (child->getOp() == kIROp_DifferentiableTypeDictionary) + { + for (auto entry = child->getFirstChild(); entry; entry = entry->getNextInst()) + { + IRInst* entryType = entry->getOperand(0); + IRInst* entryConformanceWitness = entry->getOperand(1); + + if (irType == entryType) + { + return entryConformanceWitness; + } + } + } + } + + return nullptr; + } + + IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType) + { + auto instScope = this->getInsertLoc().getInst(); + + while (instScope) + { + if (auto witness = findDifferentiableTypeEntry(irType, instScope)) + { + return witness; + } + instScope = instScope->getParent(); + } + + return nullptr; + } + + IRFunc* IRBuilder::createFunc() { IRFunc* rsFunc = createInst<IRFunc>( @@ -6322,6 +6441,37 @@ namespace Slang return inst; } + IRInst* findOuterGeneric(IRInst* inst) + { + if (inst) + { + inst = inst->getParent(); + } + else + { + return nullptr; + } + + while(inst) + { + if (as<IRGeneric>(inst)) + return inst; + + inst = inst->getParent(); + } + return nullptr; + } + + IRInst* findOuterMostGeneric(IRInst* inst) + { + IRInst* currInst = inst; + while(auto outerGeneric = findOuterGeneric(currInst)) + { + currInst = outerGeneric; + } + return currInst; + } + IRGeneric* findSpecializedGeneric(IRSpecialize* specialize) { return as<IRGeneric>(specialize->getBase()); diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index c48f4b378..a2fb1be98 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1723,6 +1723,12 @@ IRInst* findGenericReturnVal(IRGeneric* generic); // Recursively find the inner most generic return value. IRInst* findInnerMostGenericReturnVal(IRGeneric* generic); +// Find the generic container, if any, that this inst is contained in +// Returns nullptr if there is no outer container. +IRInst* findOuterGeneric(IRInst* inst); +// Recursively find the outer most generic container. +IRInst* findOuterMostGeneric(IRInst* inst); + struct IRSpecialize; IRGeneric* findSpecializedGeneric(IRSpecialize* specialize); IRInst* findSpecializeReturnVal(IRSpecialize* specialize); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index b03f3ae62..dc6067868 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1146,10 +1146,6 @@ static void addLinkageDecoration( { builder->addExternCppDecoration(inst, mangledName); } - if (decl->findModifier<JVPDerivativeModifier>()) - { - builder->addJVPDerivativeMarkerDecoration(inst); - } if (as<InterfaceDecl>(decl->parentDecl) && decl->parentDecl->hasModifier<ComInterfaceAttribute>()) { @@ -3042,6 +3038,38 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return info; } + LoweredValInfo visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr) + { + LoweredValInfo info = lowerSubExpr(expr->inner); + + IRInst* irBaseVal = nullptr; + switch (info.flavor) + { + case LoweredValInfo::Flavor::Simple: + irBaseVal = getSimpleVal(context, info); + break; + + case LoweredValInfo::Flavor::Ptr: + irBaseVal = info.val; + break; + + default: + SLANG_UNEXPECTED("Unhandled lowered value cases"); + } + + // If the differentiable expr has an associated getter or setter, lower it + // and put it in a decoration. + // + if (expr->getterExpr != nullptr) + { + auto irGetter = lowerSubExpr(expr->getterExpr); + SLANG_ASSERT(irGetter.flavor == LoweredValInfo::Flavor::Simple); + getBuilder()->addDifferentialGetterDecoration(irBaseVal, irGetter.val); + } + + return info; + } + // Emit IR to denote the forward-mode derivative // of the inner func-expr. This will be resolved // to a concrete function during the derivative @@ -5844,6 +5872,45 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return LoweredValInfo(); } + LoweredValInfo visitDifferentiableTypeDictionary(DifferentiableTypeDictionary* decl) + { + for (auto & member : decl->members) + { + if (auto entry = as<DifferentiableTypeDictionaryItem>(member)) + { + + // Lower type and witness. + IRType* irType = lowerType(context, entry->baseType); + IRInst* irWitness = lowerVal(context, entry->confWitness).val; + + SLANG_ASSERT(irType); + + // If the witness can be lowered, and the differentiable type entry exists, + // add an entry to the context. + // + if (irWitness && !getBuilder()->findDifferentiableTypeEntry(irType)) + getBuilder()->addDifferentiableTypeEntry(irType, irWitness); + } + else if (auto importEntry = as<DifferentiableTypeDictionaryImportItem>(member)) + { + ensureDecl(context, importEntry->dictionaryRef.getDecl()); + } + else + { + SLANG_UNEXPECTED("Unrecognized item in DifferentiableTypeDictionary"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + } + + if (auto diffTypeDict = getBuilder()->findOrEmitDifferentiableTypeDictionary()) + { + // Place the dictionary at the end of modules and generic blocks. + diffTypeDict->moveToEnd(); + } + + return LoweredValInfo(); + } + #define IGNORED_CASE(NAME) \ LoweredValInfo visit##NAME(NAME*) { return LoweredValInfo(); } @@ -5853,6 +5920,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> IGNORED_CASE(SyntaxDecl) IGNORED_CASE(AttributeDecl) IGNORED_CASE(NamespaceDecl) + IGNORED_CASE(DifferentiableTypeDictionaryItem) #undef IGNORED_CASE @@ -6130,7 +6198,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, nullptr); // Register the value now, rather than later, to avoid any possible infinite recursion. - setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(irWitnessTable)); + setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable))); auto irSubType = lowerType(subContext, subType); irWitnessTable->setOperand(0, irSubType); @@ -7219,6 +7287,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } + // We only need dictionaries to be lowered for decls with executable code (i.e. statements) + // Do not lower type dictionaries for inhertiance decls or decls + // that are declaring a type, since this can create a cyclic dependancy. + // + if (as<FunctionDeclBase>(leafDecl)) + { + for (auto diffTypeDict : genericDecl->getMembersOfType<DifferentiableTypeDictionary>()) + { + // We directly use lowerDecl() instead of ensureDecl() to emit to + // the current generic block instead of the top-level module. + // + lowerDecl(subContext, diffTypeDict); + } + } + return irGeneric; } @@ -7372,6 +7455,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), genericParam); } + + // Add a differentiable type dictionary if necessary. + if (auto diffTypeDict = subBuilder->findDifferentiableTypeDictionary(parentGeneric->getFirstBlock())) + markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), diffTypeDict); } if (valuesToClone.Count() == 0) { @@ -7723,6 +7810,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> addNameHint(context, irFunc, decl); addLinkageDecoration(context, irFunc, decl); + if (decl->findModifier<JVPDerivativeModifier>()) + { + getBuilder()->addJVPDerivativeMarkerDecoration(irFunc); + } + FuncDeclBaseTypeInfo info; _lowerFuncDeclBaseTypeInfo( subContext, @@ -8788,15 +8880,6 @@ RefPtr<IRModule> generateIRForTranslationUnit( // temporaries whenever possible. constructSSA(module); - // Process higher-order-function calls before any optimization passes - // to allow the optimizations to affect the generated funcitons. - // 1. Process JVP derivative functions. - processJVPDerivativeMarkers(module, compileRequest->getSink()); - // 2. Process VJP derivative functions. - // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet. - // 3. Replace JVP & VJP calls. - processDerivativeCalls(module); - // Do basic constant folding and dead code elimination // using Sparse Conditional Constant Propagation (SCCP) // diff --git a/tests/autodiff/generic-custom-jvp.slang b/tests/autodiff/generic-custom-jvp.slang new file mode 100644 index 000000000..3f0d85b60 --- /dev/null +++ b/tests/autodiff/generic-custom-jvp.slang @@ -0,0 +1,35 @@ +//TEST_IGNORE_FILE: + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typealias IDFloat = IFloat & IDifferentiable; + +__generic<T : IDFloat> +typedef __DifferentialPair<T> dfloat; + +import test_intrinsics; + +dpfloat my_pow_jvp(dpfloat x, dpfloat n) +{ + return dpfloat( + pow(x.p(), n.p()), + x.d() * n.p() * pow(x.p(), n.p()-1) + n.d() * pow(x.p(), n.p()) * log(x.p())); +} + +[__custom_jvp(my_pow_jvp)] +float _pow(float, float); + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(5.0, 1.0); + dpfloat dpn = dpfloat(2, 0.0); + + outputBuffer[0] = __jvp(_pow)(dpa, dpn).d(); // Expect: 10.0 + outputBuffer[1] = __jvp(_pow)( + dpfloat(dpa.p(), 0.0), + dpfloat(dpn.p(), 1.0)).d(); // Expect: 40.23595 + } +} diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang new file mode 100644 index 000000000..5bf3a25c3 --- /dev/null +++ b/tests/autodiff/generic-impl-jvp.slang @@ -0,0 +1,304 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef float Real; + +typealias IDFloat = IFloat & IDifferentiable; + +__generic<T, let N : int> +struct dvector +{ + T values[N]; +}; + +__generic<T : IDFloat, let N : int> +struct myvector : IDifferentiable +{ + T values[N]; + typedef dvector<T.Differential, N> Differential; + + [__unsafeForceInlineEarly] + static Ptr<T.Differential[N]> __getDifferentialFor_values(inout Differential d) + { + return &(d.values); + } + + __init(T c) + { + for (int i = 0; i < N; i++) + { + values[i] = c; + } + } + + static Differential dadd(Differential a, Differential b) + { + Differential output; + + for (int i = 0; i < N; i++) + { + output.values[i] = T.dadd(a.values[i], b.values[i]); + } + + return output; + } + + + static Differential dmul(This a, Differential b) + { + Differential output; + + for (int i = 0; i < N; i++) + { + output.values[i] = T.dmul(a.values[i], b.values[i]); + } + + return output; + } + + static Differential zero() + { + Differential output; + + for (int i = 0; i < N; i++) + { + output.values[i] = T.zero(); + } + + return output; + } +}; + +__generic<T : IDFloat, let N : int> +__differentiate_jvp myvector<T, N> operator +(myvector<T, N> a, myvector<T, N> b) +{ + myvector<T, N> output; + for (int i = 0; i < N; i++) + { + output.values[i] = a.values[i] + b.values[i]; + } + return output; +} + +__generic<T : IDFloat, let N : int> + __differentiate_jvp myvector<T, N> operator *(myvector<T, N> a, myvector<T, N> b) +{ + myvector<T, N> output; + for (int i = 0; i < N; i++) + { + output.values[i] = a.values[i] * b.values[i]; + } + return output; +} + +__generic<T : IDFloat, let N : int> + __differentiate_jvp myvector<T, N> operator *(T a, myvector<T, N> b) +{ + myvector<T, N> output; + for (int i = 0; i < N; i++) + { + output.values[i] = a * b.values[i]; + } + return output; +} + +__generic<T : IDFloat, let N : int> +[__custom_jvp(dot_jvp)] +T dot(myvector<T, N> a, myvector<T, N> b) +{ + T curr = (T)0.0; + for (int i = 0; i < N; i++) + { + curr = curr + (a.values[i] * b.values[i]); + } + + return curr; +} + +__generic<T : IDFloat, let N : int> +typedef __DifferentialPair<myvector<T, N>> dpvector; + +__generic<T : IDFloat, let N : int> +__DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b) +{ + T.Differential curr_d = (T.zero()); + T curr_p = (T)0.0; + for (int i = 0; i < N; i++) + { + curr_p = curr_p + (a.p().values[i] * b.p().values[i]); + curr_d = T.dadd( + curr_d, + T.dadd( + T.dmul(a.p().values[i], b.d().values[i]), + T.dmul(b.p().values[i], a.d().values[i]))); + } + + return __DifferentialPair<T>(curr_p, curr_d); +} + +__generic<let N : int> +struct lineardvector +{ + myvector<Real, N>.Differential val; + + __init(vector<Real.Differential, N> a) + { + for (int i = 0; i < N; i++) + { + val.values[i] = a[i]; + } + } +}; + +__generic<let N : int> +struct linearvector : MyLinearArithmeticType, IDifferentiable +{ + typedef lineardvector<N> Differential; + + myvector<Real, N> val; + + [__unsafeForceInlineEarly] + static Ptr<myvector<Real, N>.Differential> __getDifferentialFor_val(inout Differential dvec) + { + return &(dvec.val); + } + + static void __setDifferentialForVal(lineardvector<N> dvec, myvector<Real, N>.Differential v) + { + dvec.val = v; + } + + static __differentiate_jvp linearvector<N> ladd(linearvector<N> a, linearvector<N> b) + { + return linearvector<N>(a.val + b.val); + } + + static __differentiate_jvp linearvector<N> lmul(linearvector<N> a, linearvector<N> b) + { + return linearvector<N>(a.val * b.val); + } + + static __differentiate_jvp linearvector<N> lscale(float a, linearvector<N> b) + { + return linearvector<N>(a * b.val); + } + + static __differentiate_jvp float ldot(linearvector<N> a, linearvector<N> b) + { + return dot(a.val, b.val); + } + + static Differential zero() + { + lineardvector<N> dout; + dout.val = myvector<Real, N>.zero(); + return dout; + } + + static Differential dadd(Differential a, Differential b) + { + return { myvector<Real, N>.dadd(a.val, b.val) }; + } + + static Differential dmul(This a, Differential b) + { + return { myvector<Real, N>.dmul(a.val, b.val) }; + } + + __differentiate_jvp __init(vector<Real, N> a) + { + for (int i = 0; i < N; i++) + { + val.values[i] = a[i]; + } + } + + __differentiate_jvp __init(myvector<Real, N> a) + { + val = a; + } +}; + +typedef linearvector<3> myfloat3; +typedef linearvector<4> myfloat4; + +typedef lineardvector<3> mydfloat3; +typedef lineardvector<4> mydfloat4; + +typedef __DifferentialPair<Real> dpfloat; + +interface MyLinearArithmeticType +{ + static This ladd(This a, This b); + static This lmul(This a, This b); + static This lscale(Real a, This b); + static Real ldot(This a, This b); +}; + +typedef __DifferentialPair<myfloat4> dpfloat4; +typedef __DifferentialPair<myfloat3> dpfloat3; + +extension float : MyLinearArithmeticType +{ + static __differentiate_jvp float ladd(float a, float b) + { + return a + b; + } + + static __differentiate_jvp float lmul(float a, float b) + { + return a * b; + } + + static __differentiate_jvp float lscale(float a, float b) + { + return a * b; + } + + static __differentiate_jvp float ldot(float a, float b) + { + return a * b; + } +}; + +typealias MyLinearArithmeticDifferentiableType = IDifferentiable & MyLinearArithmeticType; + +__generic<T : MyLinearArithmeticDifferentiableType> +__differentiate_jvp T operator +(T a, T b) +{ + return T.ladd(a, b); +} + +__generic<T : MyLinearArithmeticDifferentiableType> +__differentiate_jvp T operator *(T a, T b) +{ + return T.lmul(a, b); +} + +__generic<G : MyLinearArithmeticDifferentiableType> +__differentiate_jvp G f(G x) +{ + G a = x + x; + G b = x * x; + + return a * a + G.lscale((Real)3.0, x); +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 1.0); + dpfloat4 dpf4 = dpfloat4(myfloat4(float4(1.5, 2.0, 0.5, 1.0)), mydfloat4(float4(0.5, 0.8, 1.6, 2.5))); + dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), mydfloat3(float3(0.5, 1.5, 2.5))); + + outputBuffer[0] = f(dpa.p()); // Expect: 22.0 + outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5 + outputBuffer[2] = __jvp(f)(dpf4).d().val.values[3]; // Expect: 27.5 + outputBuffer[3] = __jvp(f)(dpf3).d().val.values[1]; // Expect: 40.5 + } +} diff --git a/tests/autodiff/generic-impl-jvp.slang.expected.txt b/tests/autodiff/generic-impl-jvp.slang.expected.txt new file mode 100644 index 000000000..ceeaf120e --- /dev/null +++ b/tests/autodiff/generic-impl-jvp.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +22.000000 +9.500000 +27.500000 +40.500000 +0.000000 diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 48993c21c..54a99cae9 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -1,30 +1,202 @@ -//TEST_IGNORE_FILE:(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST_IGNORE_FILE:(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -typedef __DifferentialPair<float> dpfloat; -typedef __DifferentialPair<double> dpdouble; -typedef __DifferentialPair<float3> dpfloat3; +typedef float Real; -__generic<T:__BuiltinArithmeticType> -__differentiate_jvp T g(T x) +__generic<let N : int> +struct myvector { - return x + x; + vector<Real, N> val; } +extension myvector<3> : MyLinearArithmeticType +{ + static __differentiate_jvp myvector<3> ladd(myvector<3> a, myvector<3> b) + { + return myvector<3>(a.val + b.val); + } + + static __differentiate_jvp myvector<3> lmul(myvector<3> a, myvector<3> b) + { + return myvector<3>(a.val * b.val); + } + + static __differentiate_jvp myvector<3> lscale(float a, myvector<3> b) + { + return myvector<3>(a * b.val); + } + + static __differentiate_jvp float ldot(myvector<3> a, myvector<3> b) + { + return dot(a.val, b.val); + } + + __differentiate_jvp __init(vector<Real, 3> a) + { + val = a; + } +}; + + +extension myvector<4> : MyLinearArithmeticType +{ + static __differentiate_jvp myvector<4> ladd(myvector<4> a, myvector<4> b) + { + return myvector<4>(a.val + b.val); + } + + static __differentiate_jvp myvector<4> lmul(myvector<4> a, myvector<4> b) + { + return myvector<4>(a.val * b.val); + } + + static __differentiate_jvp myvector<4> lscale(float a, myvector<4> b) + { + return myvector<4>(a * b.val); + } + + static __differentiate_jvp float ldot(myvector<4> a, myvector<4> b) + { + return dot(a.val, b.val); + } + + __differentiate_jvp __init(vector<Real, 4> a) + { + val = a; + } + +}; + +typedef myvector<3> myfloat3; +typedef myvector<4> myfloat4; + +typedef __DifferentialPair<Real> dpfloat; + +interface MyLinearArithmeticType +{ + static This ladd(This a, This b); + static This lmul(This a, This b); + static This lscale(Real a, This b); + static Real ldot(This a, This b); +}; + +extension myfloat3 : IDifferentiable +{ + typedef myfloat3 Differential; + + [__unsafeForceInlineEarly] + static Ptr<float3> __getDifferentialFor_val(inout Differential dx) + { + return &(dx.val); + } + + static Differential zero() + { + return myfloat3(0); + } + + static __differentiate_jvp Differential dadd(Differential a, Differential b) + { + return a + b; + } + + static __differentiate_jvp Differential dmul(Differential a, Differential b) + { + return a * b; + } + +}; + +extension myfloat4 : IDifferentiable +{ + typedef myfloat4 Differential; + + [__unsafeForceInlineEarly] + static Ptr<float4> __getDifferentialFor_val(inout Differential dx) + { + return &(dx.val); + } + + static Differential zero() + { + return myfloat4(0); + } + + static __differentiate_jvp Differential dadd(Differential a, Differential b) + { + return a + b; + } + + static __differentiate_jvp Differential dmul(Differential a, Differential b) + { + return a * b; + } +}; + +typedef __DifferentialPair<myfloat4> dpfloat4; +typedef __DifferentialPair<myfloat3> dpfloat3; + +extension float : MyLinearArithmeticType +{ + static __differentiate_jvp float ladd(float a, float b) + { + return a + b; + } + + static __differentiate_jvp float lmul(float a, float b) + { + return a * b; + } + + static __differentiate_jvp float lscale(float a, float b) + { + return a * b; + } + + static __differentiate_jvp float ldot(float a, float b) + { + return a * b; + } +}; + +typealias MyLinearArithmeticDifferentiableType = IDifferentiable & MyLinearArithmeticType; + +__generic<T : MyLinearArithmeticDifferentiableType> +__differentiate_jvp T operator +(T a, T b) +{ + return T.ladd(a, b); +} + +__generic<T : MyLinearArithmeticDifferentiableType> +__differentiate_jvp T operator *(T a, T b) +{ + return T.lmul(a, b); +} + +__generic<G : MyLinearArithmeticDifferentiableType> +__differentiate_jvp G f(G x) +{ + G a = x + x; + G b = x * x; + + return a * a + G.lscale((Real)3.0, x); +} + + [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { { dpfloat dpa = dpfloat(2.0, 1.0); - dpdouble dpb = dpdouble(1.5, 2.0); - dpfloat3 dpf3 = dpfloat3(float3(1.0, 3.0, 5.0), float3(0.5, 1.5, 2.5)); + dpfloat4 dpf4 = dpfloat4(myfloat4(float4(1.5, 2.0, 0.5, 1.0)), myfloat4(float4(0.5, 0.8, 1.6, 2.5))); + dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), myfloat3(float3(0.5, 1.5, 2.5))); - outputBuffer[0] = f(dpa.p()); // Expect: 1 - outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.0)).d(); // Expect: 0 - outputBuffer[2] = (float)__jvp(f)(dpb).d(); // Expect: 2 - outputBuffer[3] = __jvp(f)(dpf3).d().y; // Expect: 1.5 + outputBuffer[0] = f(dpa.p()); // Expect: 22.0 + outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5 + outputBuffer[2] = __jvp(f)(dpf4).d().val.w; // Expect: 27.5 + outputBuffer[3] = __jvp(f)(dpf3).d().val.y; // Expect: 40.5 } } diff --git a/tests/autodiff/generic-jvp.slang.expected.txt b/tests/autodiff/generic-jvp.slang.expected.txt new file mode 100644 index 000000000..ceeaf120e --- /dev/null +++ b/tests/autodiff/generic-jvp.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +22.000000 +9.500000 +27.500000 +40.500000 +0.000000 diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang new file mode 100644 index 000000000..61cb96a07 --- /dev/null +++ b/tests/autodiff/getter-setter-multi.slang @@ -0,0 +1,83 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct B +{ + float3 z; + float.Differential k[10]; +}; + +struct A : IDifferentiable +{ + typedef B Differential; + + float3 x; + float y[10]; + + [__unsafeForceInlineEarly] + static Ptr<float3.Differential> __getDifferentialFor_x(inout Differential b) + { + return &(b.z); + } + + [__unsafeForceInlineEarly] + static Ptr<float.Differential[10]> __getDifferentialFor_y(inout Differential b) + { + return &(b.k); + } + + [__unsafeForceInlineEarly] + static Differential zero() + { + B b = {0.0}; + return b; + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + B o = {a.z + b.z}; + return o; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + B o = {a.x * b.z}; + return o; + } +}; + +typedef __DifferentialPair<A> dpA; + +__differentiate_jvp A f(A a) +{ + A aout; + + aout.y[5] = (2 * a.x).y; + aout.y[2] = (3 * a.y[4]); + aout.x = float3(5 * a.x.z, 3 * a.x.y, 0.5 * a.x.x); + + return aout; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + float arr[10] = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 }; + A a = {float3(1.0, 2.0, 3.0), arr}; + + float d_arr[10] = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0 }; + B b = {float3(1.0, 0.5, 0.3), d_arr}; + + dpA dpa = dpA(a, b); + + outputBuffer[0] = __jvp(f)(dpa).d().z.z; // Expect: 0.5 + outputBuffer[1] = __jvp(f)(dpa).d().k[5]; // Expect: 1 + outputBuffer[2] = __jvp(f)(dpa).d().k[2]; // Expect: 1.5 + } +} diff --git a/tests/autodiff/getter-setter-multi.slang.expected.txt b/tests/autodiff/getter-setter-multi.slang.expected.txt new file mode 100644 index 000000000..ece9872b0 --- /dev/null +++ b/tests/autodiff/getter-setter-multi.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +0.500000 +1.000000 +1.500000 +0.000000 +0.000000 diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang new file mode 100644 index 000000000..6b280433b --- /dev/null +++ b/tests/autodiff/getter-setter.slang @@ -0,0 +1,69 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct B +{ + float z; +}; + +struct A : IDifferentiable +{ + typedef B Differential; + + float x; + float y; + + [__unsafeForceInlineEarly] + static Ptr<float.Differential> __getDifferentialFor_x(inout Differential b) + { + return &(b.z); + } + + [__unsafeForceInlineEarly] + static Differential zero() + { + B b = {0.0}; + return b; + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + B o = {a.z + b.z}; + return o; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + B o = {a.x * b.z}; + return o; + } +}; + +typedef __DifferentialPair<A> dpA; + +__differentiate_jvp A f(A a) +{ + A aout; + aout.y = 2 * a.x; + aout.x = 5 * a.x; + + return aout; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + A a = {1.0, 2.0}; + B b = {0.2}; + + dpA dpa = dpA(a, b); + + outputBuffer[0] = __jvp(f)(dpa).d().z; // Expect: 1 + } +} diff --git a/tests/autodiff/getter-setter.slang.expected.txt b/tests/autodiff/getter-setter.slang.expected.txt new file mode 100644 index 000000000..ca54c9afe --- /dev/null +++ b/tests/autodiff/getter-setter.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +1.000000 +0.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/imported-custom-jvp.slang b/tests/autodiff/imported-custom-jvp.slang new file mode 100644 index 000000000..ee8bdf51d --- /dev/null +++ b/tests/autodiff/imported-custom-jvp.slang @@ -0,0 +1,25 @@ +//TEST_IGNORE_FILE: + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +import test_intrinsics_jvp; + +typedef __DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +__differentiate_jvp float f(float x) +{ + return pow_(x, 2.0); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 1.0); + dpfloat dpb = dpfloat(1.5, 1.0); + + outputBuffer[0] = __jvp(f)(dpa).d(); // Expect: 2 + } +} diff --git a/tests/autodiff/overloads-jvp.slang b/tests/autodiff/overloads-jvp.slang new file mode 100644 index 000000000..26b5c0076 --- /dev/null +++ b/tests/autodiff/overloads-jvp.slang @@ -0,0 +1,40 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef __DifferentialPair<float> dpfloat; +typedef __DifferentialPair<float3> dpfloat3; + +__differentiate_jvp float f(float a) +{ + return a * a + a; +} + +__differentiate_jvp float f(float3 a) +{ + return a.x * a.y + a.z; +} + +__differentiate_jvp float g(float a) +{ + // df((2.0, 4.0, 6.0), (1.0, 2.0, 3.0)) + // 2.0 * 2.0 + 4.0 * 1.0 + 3.0 = 11.0 + return f(float3(a, 2*a, 3*a)); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 1.0); + dpfloat3 dpf3 = dpfloat3(float3(1.0, 3.0, 5.0), float3(0.5, 1.5, 2.5)); + + outputBuffer[0] = f(dpa.p()); // Expect: 6 + outputBuffer[1] = f(dpf3.p()); // Expect: 8 + outputBuffer[2] = __jvp(f)(dpf3).d(); // Expect: 5.5 + outputBuffer[3] = __jvp(f)(dpa).d(); // Expect: 5 + outputBuffer[4] = __jvp(g)(dpa).d(); // Expect: 11.0 + } +} diff --git a/tests/autodiff/overloads-jvp.slang.expected.txt b/tests/autodiff/overloads-jvp.slang.expected.txt new file mode 100644 index 000000000..999777e1e --- /dev/null +++ b/tests/autodiff/overloads-jvp.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +6.0 +8.0 +5.5 +5.0 +11.0
\ No newline at end of file diff --git a/tests/autodiff/test-intrinsics-jvp.slang b/tests/autodiff/test-intrinsics-jvp.slang index 333c89189..cb4c5c6b4 100644 --- a/tests/autodiff/test-intrinsics-jvp.slang +++ b/tests/autodiff/test-intrinsics-jvp.slang @@ -14,4 +14,5 @@ float max_(float x, float y); float max_jvp(float x, float y, float dx, float dy) { return (x > y) ? dx : dy; -}
\ No newline at end of file +} + diff --git a/tests/autodiff/vector-arithmetic-jvp.slang b/tests/autodiff/vector-arithmetic-jvp.slang index 393cc18ec..e05d94733 100644 --- a/tests/autodiff/vector-arithmetic-jvp.slang +++ b/tests/autodiff/vector-arithmetic-jvp.slang @@ -37,11 +37,10 @@ __differentiate_jvp float4 j(float4 x, float4 y) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - { + { float3 a = float3(2.0, 2.0, 2.0); float3 b = float3(1.5, 1.5, 1.5); float3 da = float3(1.0, 1.0, 1.0); - //dpfloat3 dpa = dpfloat3(a, da); float2 a2 = float2(2.0, 1.0); float2 b2 = float2(1.5, -2.0); |
