diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.h | 30 | ||||
| -rw-r--r-- | source/slang/slang-ast-dump.cpp | 29 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 186 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 30 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 114 | ||||
| -rw-r--r-- | source/slang/slang-check-type.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 36 | ||||
| -rw-r--r-- | source/slang/slang-ir-dce.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 451 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 28 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 29 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 33 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 122 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 92 | ||||
| -rw-r--r-- | source/slang/slang-serialize-ast-type-info.h | 3 |
21 files changed, 325 insertions, 898 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 5df9d01fe..ce52dbb56 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2737,9 +2737,6 @@ __attributeTarget(InterfaceDecl) attribute_syntax [Specialize] : SpecializeAttribute; __attributeTarget(DeclBase) -attribute_syntax [Differentiable] : DifferentiableAttribute; - -__attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; enum _BuiltinRequirementKind diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 90175dd9d..cbd3f0f0c 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -511,36 +511,6 @@ class AttributeDecl : public ContainerDecl SyntaxClass<NodeBase> syntaxClass; }; -// A declaration to hold differentiable type conformances generated during -// the semantic checking phase. -// -class DifferentiableTypeDictionary : public ContainerDecl -{ - SLANG_AST_CLASS(DifferentiableTypeDictionary); -}; - -// A declaration to hold differentiable type conformances generated during -// the semantic checking phase. -// -class DifferentiableTypeDictionaryItem : public Decl -{ - SLANG_AST_CLASS(DifferentiableTypeDictionaryItem); - - DeclRefType* baseType; - SubtypeWitness* confWitness; -}; - -// A declaration that references another dictionary (generally from another module) -// Used to tell the IR lowering pass to process the referenced dictionary. -// -class DifferentiableTypeDictionaryImportItem : public Decl -{ - SLANG_AST_CLASS(DifferentiableTypeDictionaryImportItem); - - DeclRef<DifferentiableTypeDictionary> dictionaryRef; -}; - - bool isInterfaceRequirement(Decl* decl); } // namespace Slang diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index 455a9db74..fc3c015e0 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -408,6 +408,35 @@ struct ASTDumpContext m_writer->emit("}"); } + template <typename KEY, typename VALUE> + void dump(const OrderedDictionary<KEY, VALUE>& dict) + { + m_writer->emit(" { \n"); + m_writer->indent(); + + for (auto iter : dict) + { + const auto& key = iter.Key; + const auto& value = iter.Value; + + dump(key); + m_writer->emit(" : "); + dump(value); + + m_writer->emit("\n"); + } + + m_writer->dedent(); + m_writer->emit("}"); + } + + void dump(DeclRefBase declRef) + { + StringBuilder sb; + sb << declRef; + m_writer->emit(sb.ToString()); + } + void dump(const DeclCheckStateExt& extState) { auto state = extState.getState(); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 0c1eb8d49..67ff297dc 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -977,6 +977,9 @@ class SpecializeAttribute : public Attribute class DifferentiableAttribute : public Attribute { SLANG_AST_CLASS(DifferentiableAttribute) + + /// Mapping from types to subtype witnesses for conformance to IDifferentiable. + OrderedDictionary<DeclRefBase, SubtypeWitness*> m_mapTypeToIDifferentiableWitness; }; class DllImportAttribute : public Attribute diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index d4a781846..015e6969c 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1510,7 +1510,6 @@ namespace Slang DAddFunc, ///< The `IDifferentiable.dadd` function requirement DMulFunc, ///< The `IDifferentiable.dmul` function requirement }; - } // namespace Slang #endif diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index f60fbcc2c..7140d541a 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -903,14 +903,6 @@ namespace Slang // If `decl` is a container, then we want to ensure its children. if(auto containerDecl = as<ContainerDecl>(decl)) { - bool trackDiffTypes = (as<GenericDecl>(decl) != nullptr); - if (trackDiffTypes) - { - // Add a context to track differentiable types. - DifferentiableTypeSemanticContext subDiffTypeContext; - visitor->getShared()->pushDiffTypeContext(&subDiffTypeContext); - } - // NOTE! We purposefully do not iterate with the for(auto childDecl : containerDecl->members) here, // because the visitor may add to `members` whilst iteration takes place, invalidating the iterator // and likely a crash. @@ -932,22 +924,6 @@ namespace Slang _ensureAllDeclsRec(visitor, childDecl, state); } - - if (trackDiffTypes) - { - auto subDiffTypeContext = visitor->getShared()->popDiffTypeContext(); - - // If there were any differentiable types used in differentiable - // methods, generate a dictionary with the required info. - // - if (subDiffTypeContext->isDictionaryRequired()) - { - auto diffTypeDict = subDiffTypeContext->makeDifferentiableTypeDictionaryNode(visitor->getASTBuilder()); - diffTypeDict->parentDecl = containerDecl; - containerDecl->members.add(diffTypeDict); - containerDecl->invalidateMemberDictionary(); - } - } } // Note: the "inner" declaration of a `GenericDecl` is currently @@ -1541,49 +1517,6 @@ namespace Slang return false; } - void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*) - { - // If the autodiff core library (diff.meta.slang) has not been loaded yet, ignore any - // request to check differentiable types. - // - if (!m_astBuilder->isDifferentiableInterfaceAvailable()) - return; - - auto diffInterface = m_astBuilder->getDifferentiableInterface(); - - DeclRefType* type = nullptr; - - if (auto extensionDecl = as<ExtensionDecl>(decl)) - { - // If this is an extension, use the provided target type. - type = as<DeclRefType>(extensionDecl->targetType.type); - } - else - { - // If this is a type declaration, create a decl ref without - // any substitutions. - // - auto declRef = makeDeclRef(decl); - - // TODO: Strip substitutions from the declreftype - type = DeclRefType::create(m_astBuilder, declRef); - } - - // Skip if the declaration is the interface itself. - if (type->declRef == diffInterface) - return; - - // If the DeclRefType conforms to IDifferentiable, register it with the top-level - // context. - // - if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, diffInterface))) - { - // TODO: Temporarily disabled to move to new system. Fix later. - // context->registerDifferentiableType(type, witness); - } - - } - void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) { // TODO: are there any other validations we can do at this point? @@ -1637,23 +1570,6 @@ namespace Slang ensureDecl(constraint, DeclCheckState::ReadyForReference); } } - - // TODO(sai): Is this the right checking stage to be doing this? - DifferentiableTypeSemanticContext diffTypeContext; - - for (Index i = 0; i < members.getCount(); ++i) - { - Decl* m = members[i]; - - if (auto typeParam = as<GenericTypeParamDecl>(m)) - { - tryAddDifferentiableConformanceToContext(typeParam, &diffTypeContext); - } - } - - auto diffTypeDictionaryNode = diffTypeContext.makeDifferentiableTypeDictionaryNode(m_astBuilder); - diffTypeDictionaryNode->parentDecl = genericDecl; - genericDecl->members.add(diffTypeDictionaryNode); } void SemanticsDeclBasesVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl) @@ -1689,7 +1605,6 @@ namespace Slang void visitAggTypeDecl(AggTypeDecl* aggTypeDecl) { checkAggTypeConformance(aggTypeDecl); - tryAddDifferentiableConformanceToContext(aggTypeDecl, getShared()->getDiffTypeContext()); } // Conformances can also come via `extension` declarations, and @@ -1698,7 +1613,6 @@ namespace Slang void visitExtensionDecl(ExtensionDecl* extensionDecl) { checkExtensionConformance(extensionDecl); - tryAddDifferentiableConformanceToContext(extensionDecl, getShared()->getDiffTypeContext()); } }; @@ -1855,32 +1769,6 @@ namespace Slang // Furthermore, because a fully checked function will have checked // its body, this also means that all function bodies and the // declarations they contain should be fully checked. - - // Generate a dictionary node to hold information about all - // available differentiable types in scope (including imports and stdlib) - // - if (getShared()->getDiffTypeContext()->isDictionaryRequired()) - finishDifferentiableTypeDictionary(moduleDecl); - } - - void SemanticsVisitor::finishDifferentiableTypeDictionary(ModuleDecl* moduleDecl) - { - // Grab the differentiable type information from imported modules. - for(auto importedModule : getShared()->importedModulesList) - { - this->getShared()->getDiffTypeContext()->addImportedModule(importedModule); - } - - // Grad the differentiable type information from the standard library modules. - for (auto stdLibModule : this->getSession()->stdlibModules) - { - this->getShared()->getDiffTypeContext()->addImportedModule(stdLibModule->getModuleDecl()); - } - - auto diffTypeDictNode = this->getShared()->getDiffTypeContext()->makeDifferentiableTypeDictionaryNode(m_astBuilder); - diffTypeDictNode->parentDecl = moduleDecl; - - moduleDecl->members.add(diffTypeDictNode); } bool SemanticsVisitor::doesSignatureMatchRequirement( @@ -5374,11 +5262,6 @@ namespace Slang void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) { - if (decl->findModifier<DifferentiableAttribute>()) - { - this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); - } - for(auto paramDecl : decl->getParameters()) { ensureDecl(paramDecl, DeclCheckState::ReadyForReference); @@ -6249,75 +6132,6 @@ namespace Slang m_mapTypeDeclToCandidateExtensions.Clear(); } - void DifferentiableTypeSemanticContext::registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness) - { - // Need to generate a type dictionary since we have a declaration that works with - // a differentiable type. - // - this->requireDifferentiableTypeDictionary(); - - m_mapTypeToIDifferentiableWitness.AddIfNotExists(DeclRefTypeKey(type), witness); - } - - List<KeyValuePair<DeclRefType*, SubtypeWitness*>> DifferentiableTypeSemanticContext::getDifferentiableTypeConformanceList() - { - List<KeyValuePair<DeclRefType*, SubtypeWitness*>> diffConformances; - for (auto entry : m_mapTypeToIDifferentiableWitness) - { - diffConformances.add(KeyValuePair<DeclRefType*, SubtypeWitness*>(entry.Key.type, entry.Value)); - } - - return diffConformances; - } - - DifferentiableTypeDictionary* DifferentiableTypeSemanticContext::makeDifferentiableTypeDictionaryNode( - ASTBuilder* builder) - { - auto dictionary = builder->create<DifferentiableTypeDictionary>(); - - for (auto item : m_mapTypeToIDifferentiableWitness) - { - auto entry = builder->create<DifferentiableTypeDictionaryItem>(); - entry->baseType = item.Key.type; - entry->confWitness = item.Value; - entry->parentDecl = dictionary; - - dictionary->members.add(entry); - } - - for (auto item : m_importedDictionaries) - { - auto entry = builder->create<DifferentiableTypeDictionaryImportItem>(); - entry->dictionaryRef = item; - entry->parentDecl = dictionary; - - dictionary->members.add(entry); - } - - return dictionary; - } - - void DifferentiableTypeSemanticContext::addImportedModule(ModuleDecl* importedModuleDecl) - { - // TODO: This is a terribly slow way to find the diff type dictionary. - // Switch to lookUp() when possible (this might involve naming the dictionary something) - // - for (auto diffTypeDict : importedModuleDecl->getMembersOfType<DifferentiableTypeDictionary>()) - { - m_importedDictionaries.add(makeDeclRef(diffTypeDict)); - } - } - - void DifferentiableTypeSemanticContext::requireDifferentiableTypeDictionary() - { - this->m_isTypeDictionaryRequired = true; - } - - bool DifferentiableTypeSemanticContext::isDictionaryRequired() - { - return this->m_isTypeDictionaryRequired; - } - void SharedSemanticsContext::_addCandidateExtensionsFromModule(ModuleDecl* moduleDecl) { for( auto& entry : moduleDecl->mapTypeToCandidateExtensions ) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 251849ede..ad199300a 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -899,6 +899,16 @@ namespace Slang return result; } + void SemanticsVisitor::registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness) + { + SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr); + if (witness) + { + m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness.AddIfNotExists(type->declRef, witness); + } + } + + void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type) { if (!builder->isDifferentiableInterfaceAvailable()) @@ -906,6 +916,11 @@ namespace Slang return; } + if (!m_parentDifferentiableAttr) + { + return; + } + // Check for special cases such as PtrTypeBase<T> or Array<T> // This could potentially be handled later by simply defining extensions // for Ptr<T:IDifferentiable> etc.. @@ -927,10 +942,8 @@ namespace Slang if (auto subtypeWitness = as<SubtypeWitness>( tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface()))) { - auto diffTypeContext = this->getShared()->innermostDiffTypeContext(); - diffTypeContext->registerDifferentiableType((DeclRefType*)type, subtypeWitness); + registerDifferentiableType((DeclRefType*)type, subtypeWitness); } - return; } } @@ -2007,20 +2020,9 @@ namespace Slang Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) { - this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); - // Check/Resolve inner function declaration. expr->baseFunction = CheckTerm(expr->baseFunction); - // Register parameter types. - if (auto funcType = as<FuncType>(expr->baseFunction->type.type)) - { - for (UInt i = 0; i < funcType->getParamCount(); i++) - { - maybeRegisterDifferentiableType(m_astBuilder, funcType->getParamType(i)); - } - } - // For now we only support using higher order expr as callee in an invoke expr. // The actual type of the higher order function will be derived during resolve invoke. expr->type = m_astBuilder->getBottomType(); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 33455e42d..a0141911a 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -214,75 +214,6 @@ namespace Slang Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache; Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache; }; - - struct DifferentiableTypeSemanticContext - { - - public: - /// Registers a type as conforming to IDifferentiable, along with a witness - /// describing the relationship. - /// - void registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness); - - /// Returns the list of registered differentiable types. - List<KeyValuePair<DeclRefType*, SubtypeWitness*>> getDifferentiableTypeConformanceList(); - - /// Creates a DifferentiableTypeDictionary AST container node with an entry for - /// every registered type. This can be inserted into the appropriate context for the - /// auto-diff pass. - /// - DifferentiableTypeDictionary* makeDifferentiableTypeDictionaryNode(ASTBuilder* builder); - - /// Creates a DifferentiableTypeDictionary AST container node with an entry for - /// every registered type. This can be inserted into the appropriate context for the - /// auto-diff pass. - /// - void addImportedModule(ModuleDecl* importedModuleDecl); - - /// Set flag to indicate that the type dictionary is requried. - void requireDifferentiableTypeDictionary(); - - /// Returns flag indicating whether the type dictionary is requried. - bool isDictionaryRequired(); - - private: - // Nested struct to override the '==' operator for DeclRefTypes - struct DeclRefTypeKey - { - DeclRefType* type; - - DeclRefTypeKey(DeclRefType* type) : type(type) - {}; - - DeclRefTypeKey(DeclRefTypeKey& typeKey) : type(typeKey.type) - {}; - - DeclRefTypeKey() : type(nullptr) - {}; - - bool operator==(const DeclRefTypeKey& other) const - { - return (other.type->declRef == this->type->declRef); - } - - HashCode getHashCode() const - { - Hasher hasher; - hasher.hashObject(&type->declRef); - return hasher.getResult(); - } - }; - - /// Mapping from types to subtype witnesses for conformance to IDifferentiable. - OrderedDictionary<DeclRefTypeKey, SubtypeWitness*> m_mapTypeToIDifferentiableWitness; - - /// List of external dictionaries (from imported modules) - List<DeclRef<DifferentiableTypeDictionary>> m_importedDictionaries; - - /// Flag to indicate if a differentiable type dictionary is required. - bool m_isTypeDictionaryRequired = false; - }; - /// Shared state for a semantics-checking session. struct SharedSemanticsContext { @@ -310,11 +241,6 @@ namespace Slang // List<ModuleDecl*> importedModulesList; HashSet<ModuleDecl*> importedModulesSet; - - DifferentiableTypeSemanticContext diffTypeContext; - - List<DifferentiableTypeSemanticContext*> diffTypeContextStack; - public: SharedSemanticsContext( Linkage* linkage, @@ -349,28 +275,6 @@ namespace Slang return false; } - DifferentiableTypeSemanticContext* getDiffTypeContext() - { - return &diffTypeContext; - } - - DifferentiableTypeSemanticContext* innermostDiffTypeContext() - { - return (diffTypeContextStack.getCount() > 0) ? diffTypeContextStack.getLast() : &diffTypeContext; - } - - void pushDiffTypeContext(DifferentiableTypeSemanticContext* context) - { - diffTypeContextStack.add(context); - } - - DifferentiableTypeSemanticContext* popDiffTypeContext() - { - auto context = diffTypeContextStack.getLast(); - diffTypeContextStack.removeLast(); - return context; - } - /// Get the list of extension declarations that appear to apply to `decl` in this context List<ExtensionDecl*> const& getCandidateExtensionsForTypeDecl(AggTypeDecl* decl); @@ -431,6 +335,7 @@ namespace Slang SemanticsContext result(*this); result.m_parentFunc = parentFunc; result.m_outerStmts = nullptr; + result.m_parentDifferentiableAttr = parentFunc->findModifier<DifferentiableAttribute>(); return result; } @@ -519,6 +424,8 @@ namespace Slang /// The parent function (if any) that surrounds the statement being checked. FunctionDeclBase* m_parentFunc = nullptr; + DifferentiableAttribute* m_parentDifferentiableAttr = nullptr; + /// The linked list of lexically surrounding statements. OuterStmtInfo* m_outerStmts = nullptr; @@ -801,6 +708,11 @@ namespace Slang // Convert a function's original type to it's JVP type. Type* processJVPFuncType(FuncType* originalType); + /// Registers a type as conforming to IDifferentiable, along with a witness + /// describing the relationship. + /// + void registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness); + // Check and register a type if it is differentiable. void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type); @@ -1129,16 +1041,6 @@ namespace Slang DeclRef<AssocTypeDecl> requirementDeclRef, RefPtr<WitnessTable> witnessTable); - /// Registers a type as differentiable in the currrent semantic context, if the declaration represents - /// a subtype of IDifferentable. Does nothing otherwise. - void tryAddDifferentiableConformanceToContext( - Decl* decl, - DifferentiableTypeSemanticContext* context); - - /// Generates a dictionary node for the module with all registered differentiable types, - /// as well as information about differentiable types in imported modules. - void finishDifferentiableTypeDictionary(ModuleDecl* moduleDecl); - // Find the appropriate member of a declared type to // satisfy a requirement of an interface the type // claims to conform to. diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index 6bc4b9d36..d402dde03 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -320,19 +320,6 @@ namespace Slang getSink()->diagnose(typeExp.exp, Diagnostics::cannotDefinePtrTypeToManagedResource); } } - - // Differentiable type checking. - // TODO: This can be super slow. Switch to caching the result asap. - if (this->m_parentFunc && - this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>()) - { - auto diffTypeContext = this->getShared()->innermostDiffTypeContext(); - if (auto subtypeWitness = as<SubtypeWitness>( - tryGetInterfaceConformanceWitness(result, getASTBuilder()->getDifferentiableInterface()))) - { - diffTypeContext->registerDifferentiableType((DeclRefType*)result, subtypeWitness); - } - } *outProperType = result; return true; diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index dd3acff78..f62007bb0 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -7,6 +7,7 @@ #include "slang-mangled-lexer.h" #include "slang-ir-clone.h" +#include "slang-ir-util.h" #include "../compiler-core/slang-artifact-desc-util.h" @@ -80,39 +81,6 @@ static UnownedStringSlice _getTypePrefix(IROp op) } } -static IROp _getTypeStyle(IROp op) -{ - switch (op) - { - case kIROp_VoidType: - case kIROp_BoolType: - { - return op; - } - case kIROp_Int8Type: - case kIROp_Int16Type: - case kIROp_IntType: - case kIROp_UInt8Type: - case kIROp_UInt16Type: - case kIROp_UIntType: - case kIROp_Int64Type: - case kIROp_UInt64Type: - case kIROp_IntPtrType: - case kIROp_UIntPtrType: - { - // All int like - return kIROp_IntType; - } - case kIROp_HalfType: - case kIROp_FloatType: - case kIROp_DoubleType: - { - // All float like - return kIROp_FloatType; - } - default: return kIROp_Invalid; - } -} static IROp _getCType(IROp op) { @@ -912,7 +880,7 @@ void CPPSourceEmitter::_emitAnyAllDefinition(const UnownedStringSlice& funcName, IRType* retType = specOp->returnType; auto retTypeName = _getTypeName(retType); - IROp style = _getTypeStyle(elementType->getOp()); + IROp style = getTypeStyle(elementType->getOp()); const TypeDimension dim = _getTypeDimension(paramType0, false); diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 7d677b488..d58e307da 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -361,13 +361,6 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o case kIROp_WitnessTableEntry: return true; - // Special dictionaries used for differentiable type tracking - // should be kept alive. These are removed by the auto-diff pass, - // once they are used. - case kIROp_DifferentiableTypeDictionaryItem: - case kIROp_DifferentiableTypeDictionary: - return true; - default: break; } diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 3d02d4fc0..d0bf8f347 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -6,6 +6,7 @@ #include "slang-ir-clone.h" #include "slang-ir-dce.h" #include "slang-ir-eliminate-phis.h" +#include "slang-ir-util.h" // origX, primalX, diffX // origX -> primalX (cloneEnv) @@ -26,11 +27,9 @@ struct Pair typedef Pair<IRInst*, IRInst*> InstPair; -struct DifferentiableTypeConformanceContext +struct AutoDiffSharedContext { - Dictionary<IRInst*, IRInst*> witnessTableMap; - - IRInst* inst = nullptr; + IRModuleInst* moduleInst = nullptr; // A reference to the builtin IDifferentiable interface type. // We use this to look up all the other types (and type exprs) @@ -62,114 +61,27 @@ struct DifferentiableTypeConformanceContext // bool isInterfaceAvailable = false; - // For handling generic blocks, we use a parent pointer to allow - // looking up types in all relevant scopes. - DifferentiableTypeConformanceContext* parent = nullptr; - DifferentiableTypeConformanceContext(DifferentiableTypeConformanceContext* parent, IRInst* inst) : parent(parent), inst(inst) + AutoDiffSharedContext(IRModuleInst* inModuleInst) + : moduleInst(inModuleInst) { - if (parent) + differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface()); + if (differentiableInterfaceType) { - differentiableInterfaceType = parent->differentiableInterfaceType; - differentialAssocTypeStructKey = parent->differentialAssocTypeStructKey; - zeroMethodStructKey = parent->zeroMethodStructKey; - addMethodStructKey = parent->addMethodStructKey; - - isInterfaceAvailable = parent->isInterfaceAvailable; - } - else - { - differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface()); - if (differentiableInterfaceType) - { - differentialAssocTypeStructKey = findDifferentialTypeStructKey(); - zeroMethodStructKey = findZeroMethodStructKey(); - addMethodStructKey = findAddMethodStructKey(); - - if (differentialAssocTypeStructKey) - isInterfaceAvailable = true; - } - } - } - - DifferentiableTypeConformanceContext(IRInst* inst) : - DifferentiableTypeConformanceContext(nullptr, inst) - {} + differentialAssocTypeStructKey = findDifferentialTypeStructKey(); + zeroMethodStructKey = findZeroMethodStructKey(); + addMethodStructKey = findAddMethodStructKey(); - // Lookup a witness table for the concreteType. One should exist if concreteType - // inherits (successfully) from IDifferentiable. - // - IRInst* lookUpConformanceForType(IRBuilder* builder, IRInst* type) - { - SLANG_ASSERT(isInterfaceAvailable); - // TODO: Cache the returned value to avoid repeatedly scanning through - // blocks looking for the type entries. - // - if (auto irWitness = builder->findDifferentiableTypeEntry(type, type->getParent())) - { - return irWitness; - } - - return nullptr; - } - - IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) - { - if (auto conformance = lookUpConformanceForType(builder, origType)) - { - if (auto witnessTable = as<IRWitnessTable>(conformance)) - { - for (auto entry : witnessTable->getEntries()) - { - if (entry->getRequirementKey() == key) - return entry->getSatisfyingVal(); - } - } - else if (auto witnessTableParam = as<IRParam>(conformance)) - { - return builder->emitLookupInterfaceMethodInst( - builder->getTypeKind(), - witnessTableParam, - key); - } - } - - return nullptr; - } - - // Lookup and return the 'Differential' type declared in the concrete type - // in order to conform to the IDifferentiable interface. - // Note that inside a generic block, this will be a witness table lookup instruction - // that gets resolved during the specialization pass. - // - IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) - { - switch (origType->getOp()) - { - case kIROp_FloatType: - case kIROp_HalfType: - case kIROp_DoubleType: - case kIROp_VectorType: - return origType; + if (differentialAssocTypeStructKey) + isInterfaceAvailable = true; } - return lookUpInterfaceMethod(builder, origType, differentialAssocTypeStructKey); - } - - IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) - { - return lookUpInterfaceMethod(builder, origType, zeroMethodStructKey); - } - - IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) - { - return lookUpInterfaceMethod(builder, origType, addMethodStructKey); } private: IRInst* findDifferentiableInterface() { - if (auto module = as<IRModuleInst>(inst)) + if (auto module = as<IRModuleInst>(moduleInst)) { for (auto globalInst : module->getGlobalInsts()) { @@ -203,7 +115,7 @@ struct DifferentiableTypeConformanceContext IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index) { - if (as<IRModuleInst>(inst) && differentiableInterfaceType) + if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) { // Assume for now that IDifferentiable has exactly four fields. SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4); @@ -217,110 +129,126 @@ struct DifferentiableTypeConformanceContext return nullptr; } +}; - void loadWitnessTablesForInterface(IRInst* interfaceType) +namespace +{ + +IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) +{ + if (auto witnessTable = as<IRWitnessTable>(witness)) { - - if (auto module = as<IRModuleInst>(inst)) + for (auto entry : witnessTable->getEntries()) { - for (auto globalInst : module->getGlobalInsts()) - { - if (globalInst->getOp() == kIROp_WitnessTable && - cast<IRWitnessTableType>(globalInst->getDataType())->getConformanceType() == - interfaceType) - { - // TODO: Can we have multiple conformances for the same pair of types? - // TODO: Can type instrs be duplicated (i.e. two different float types)? And if they are duplicated, can - // we supply the dictionary with a custom equality rule that uses 'type1->equals(type2)' - witnessTableMap.Add(as<IRWitnessTable>(globalInst)->getConcreteType(), globalInst); - } - } + if (entry->getRequirementKey() == requirementKey) + return entry->getSatisfyingVal(); } - else if (auto generic = as<IRGeneric>(inst)) - { - List<IRParam*> typeParams; + } + else if (auto witnessTableParam = as<IRParam>(witness)) + { + return builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + witnessTableParam, + requirementKey); + } + return nullptr; +} + +} + +struct DifferentiableTypeConformanceContext +{ + AutoDiffSharedContext* sharedContext; + + IRGlobalValueWithCode* parentFunc = nullptr; + Dictionary<IRType*, IRInst*> differentiableWitnessDictionary; + + DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared) + : sharedContext(shared) + {} + + void setFunc(IRGlobalValueWithCode* func) + { + parentFunc = func; + + auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); + SLANG_RELEASE_ASSERT(decor); - auto genericParam = generic->getFirstParam(); - while (genericParam) + // Build lookup dictionary for type witnesses. + for (auto child = decor->getFirstChild(); child; child = child->next) + { + if (auto item = as<IRDifferentiableTypeDictionaryItem>(child)) { - if (as<IRTypeType>(genericParam->getDataType())) + auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType()); + if (existingItem) { - typeParams.add(genericParam); + if (auto witness = as<IRWitnessTable>(item->getWitness())) + { + if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType) + continue; + } + *existingItem = item->getWitness(); } else - break; - - genericParam = genericParam->getNextParam(); - } - - Count tableIndex = 0; - while (genericParam) - { - SLANG_ASSERT(!as<IRTypeType>(genericParam->getDataType())); - - if (tableIndex >= typeParams.getCount()) - break; - - if (auto witnessTableType = as<IRWitnessTableType>(genericParam->getDataType())) { - // TODO(sai): Heavily flawed way to find the right witness table. - // Rewrite this part - if (witnessTableType->getConformanceType() == differentiableInterfaceType) - witnessTableMap.Add(typeParams[tableIndex], genericParam); + differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness()); } - else - break; - - tableIndex += 1; - genericParam = genericParam->getNextParam(); } - } - } -}; - -IRInst* findGlobal(IRInst* inst) -{ - if (inst->getParent() != inst->getModule()->getModuleInst()) + // Lookup a witness table for the concreteType. One should exist if concreteType + // inherits (successfully) from IDifferentiable. + // + IRInst* lookUpConformanceForType(IRInst* type) { - return findGlobal(inst->getParent()); + IRInst* foundResult = nullptr; + differentiableWitnessDictionary.TryGetValue(type, foundResult); + return foundResult; } - return inst; -} - -void moveGlobalToBeforeUses(IRBuilder*, IRInst* globalInst) -{ - HashSet<IRInst*> globalsOfUses; - for (auto use = globalInst->firstUse; use; use = use->nextUse) + IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) { - globalsOfUses.Add(findGlobal(use->getUser())); + if (auto conformance = lookUpConformanceForType(origType)) + { + return _lookupWitness(builder, conformance, key); + } + return nullptr; } - IRInst* earliestUse = nullptr; - for (auto cursor = globalInst; cursor; cursor = cursor->getPrevInst()) - { - if (globalsOfUses.Contains(cursor)) + // Lookup and return the 'Differential' type declared in the concrete type + // in order to conform to the IDifferentiable interface. + // Note that inside a generic block, this will be a witness table lookup instruction + // that gets resolved during the specialization pass. + // + IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) + { + switch (origType->getOp()) { - earliestUse = cursor; + case kIROp_FloatType: + case kIROp_HalfType: + case kIROp_DoubleType: + case kIROp_VectorType: + return origType; } + return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); } - if (earliestUse) + IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) { - globalInst->insertBefore(earliestUse); + return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); + } + + IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); } -} + +}; struct DifferentialPairTypeBuilder { - - DifferentialPairTypeBuilder(DifferentiableTypeConformanceContext* diffConformanceContext) : - diffConformanceContext(diffConformanceContext) - {} IRStructField* findField(IRInst* type, IRStructKey* key) { @@ -454,14 +382,6 @@ struct DifferentialPairTypeBuilder return emitFieldAccessor(builder, baseInst, this->globalDiffKey); } - void relocateNewTypes(IRBuilder* builder) - { - for (auto typeInst : generatedTypeList) - { - moveGlobalToBeforeUses(builder, typeInst); - } - } - IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder) { if (!this->globalDiffKey) @@ -496,27 +416,23 @@ struct DifferentialPairTypeBuilder return this->globalPrimalKey; } - IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType) + IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType) { - if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType)) - { - SLANG_ASSERT(!as<IRParam>(origBaseType)); - - auto pairStructType = builder->createStructType(); - builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType); - builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*) diffBaseType); + SLANG_ASSERT(!as<IRParam>(origBaseType)); + SLANG_ASSERT(diffType); + auto pairStructType = builder->createStructType(); + builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType); + builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType); - return pairStructType; - } - return nullptr; + return pairStructType; } - IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType) + IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType) { if (pairTypeCache.ContainsKey(origBaseType)) return pairTypeCache[origBaseType]; - auto pairType = _createDiffPairType(builder, origBaseType); + auto pairType = _createDiffPairType(builder, origBaseType, diffType); pairTypeCache.Add(origBaseType, pairType); return pairType; @@ -524,8 +440,6 @@ struct DifferentialPairTypeBuilder Dictionary<IRInst*, IRInst*> pairTypeCache; - DifferentiableTypeConformanceContext* diffConformanceContext; - IRStructKey* globalPrimalKey = nullptr; IRStructKey* globalDiffKey = nullptr; @@ -553,11 +467,17 @@ struct JVPTranscriber DiagnosticSink* sink; // Type conformance information. - DifferentiableTypeConformanceContext* diffConformanceContext; + AutoDiffSharedContext* autoDiffSharedContext; // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct DifferentialPairTypeBuilder* pairBuilder; + DifferentiableTypeConformanceContext differentiableTypeConformanceContext; + + JVPTranscriber(AutoDiffSharedContext* shared) + : differentiableTypeConformanceContext(shared) + {} + DiagnosticSink* getSink() { SLANG_ASSERT(sink); @@ -692,7 +612,7 @@ struct JVPTranscriber { case kIROp_Param: if (as<IRTypeType>(primalType->getDataType())) - return (IRType*)(diffConformanceContext->getDifferentialForType( + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( builder, (IRType*)primalType)); else if (as<IRWitnessTableType>(primalType->getDataType())) @@ -737,7 +657,7 @@ struct JVPTranscriber } default: - return (IRType*)(diffConformanceContext->getDifferentialForType(builder, (IRType*)primalType)); + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); } } @@ -753,8 +673,10 @@ struct JVPTranscriber else return nullptr; } - - return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType); + auto diffType = differentiateType(builder, primalType); + if (diffType) + return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType, diffType); + return nullptr; } InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) @@ -1325,7 +1247,7 @@ struct JVPTranscriber { // Since primalType has a corresponding differential type, we can lookup the // definition for zero(). - auto zeroMethod = this->diffConformanceContext->getZeroMethodForType(builder, primalType); + auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); SLANG_ASSERT(zeroMethod); auto emptyArgList = List<IRInst*>(); @@ -1333,6 +1255,11 @@ struct JVPTranscriber } else { + if (isScalarIntegerType(primalType)) + { + return builder->getIntValue(primalType, 0); + } + getSink()->diagnose(primalType->sourceLoc, Diagnostics::internalCompilerError, "could not generate zero value for given type"); @@ -1359,17 +1286,6 @@ struct JVPTranscriber for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) this->transcribe(builder, param); - // Look for the differentiable type dictionary and clone it (and anything else we might need). - // TODO: This logic might have issues if there are additional instructions (say lookup_interface_requirement) - // that are operands. - // TODO: This is currently cloning the global dictionary. Should only clone dictionaries in generic blocks. - if (auto origDict = builder->findDifferentiableTypeDictionary(origBlock)) - { - auto clonedDict = cloneInst(&cloneEnv, builder, origDict); - mapPrimalInst(origDict, clonedDict); - mapDifferentialInst(origDict, clonedDict); - } - // Then, run through every instruction and use the transcriber to generate the appropriate // derivative code. // @@ -1547,6 +1463,8 @@ struct JVPTranscriber { IRFunc* primalFunc = nullptr; + differentiableTypeConformanceContext.setFunc(origFunc); + auto oldLoc = builder->getInsertLoc(); // If this is a top-level function, there is no need to clone it @@ -1602,6 +1520,16 @@ struct JVPTranscriber // Transcribe a generic definition InstPair transcribeGeneric(IRBuilder* builder, IRGeneric* origGeneric) { + auto innerVal = findInnerMostGenericReturnVal(origGeneric); + if (auto innerFunc = as<IRFunc>(innerVal)) + { + differentiableTypeConformanceContext.setFunc(innerFunc); + } + else + { + return InstPair(origGeneric, nullptr); + } + // For now, we assume there's only one generic layer. So this inst must be top level bool isTopLevel = (as<IRModuleInst>(origGeneric->getParent()) != nullptr); SLANG_RELEASE_ASSERT(isTopLevel); @@ -1757,10 +1685,6 @@ struct JVPTranscriber case kIROp_ifElse: return transcribeIfElse(builder, as<IRIfElse>(origInst)); - case kIROp_DifferentiableTypeDictionary: - // Ignore dictionary insts. - return InstPair(nullptr, nullptr); - } // If none of the cases have been hit, check if the instruction is a @@ -1885,11 +1809,8 @@ struct JVPDerivativeContext // IRDifferentialPairGetPrimal with 'primal' field access, and // IRMakeDifferentialPair with an IRMakeStruct. // - modified |= processPairTypes(builder, module->getModuleInst(), (&diffConformanceContextStorage)); + modified |= processPairTypes(builder, module->getModuleInst()); - // Temporary fix: Move generated types, if any, to before their use locations. - (&pairBuilderStorage)->relocateNewTypes(builder); - return modified; } @@ -1981,7 +1902,7 @@ struct JVPDerivativeContext return true; } - IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext*) + IRInst* lowerPairType(IRBuilder* builder, IRType* type) { if (auto pairType = as<IRDifferentialPairType>(type)) @@ -1990,13 +1911,18 @@ struct JVPDerivativeContext if (!as<IRType>(pairType->getValueType())) { - // Do not handle non-concrete types. return nullptr; } - + auto witness = pairType->getWitness(); + auto diffType = _lookupWitness(builder, witness, autoDiffSharedContextStorage.differentialAssocTypeStructKey); + if (!diffType) + { + return nullptr; + } auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType( builder, - pairType->getValueType()); + pairType->getValueType(), + (IRType*)(diffType)); pairType->replaceUsesWith(diffPairStructType); pairType->removeAndDeallocate(); @@ -2017,12 +1943,12 @@ struct JVPDerivativeContext return nullptr; } - IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext) + IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst) { if (auto makePairInst = as<IRMakeDifferentialPair>(inst)) { - if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext)) + if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType())) { builder->setInsertBefore(makePairInst); @@ -2041,11 +1967,11 @@ struct JVPDerivativeContext return nullptr; } - IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext) + IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst) { if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst)) { - if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext)) + if (lowerPairType(builder, getDiffInst->getBase()->getDataType())) { builder->setInsertBefore(getDiffInst); @@ -2057,7 +1983,7 @@ struct JVPDerivativeContext } else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst)) { - if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext)) + if (lowerPairType(builder, getPrimalInst->getBase()->getDataType())) { builder->setInsertBefore(getPrimalInst); @@ -2072,16 +1998,10 @@ struct JVPDerivativeContext return nullptr; } - bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren, DifferentiableTypeConformanceContext* diffContext) + bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren) { bool modified = false; - // Create a new sub-context to scan witness tables inside workItem - // (mainly relevant if instWithChildren is a generic scope) - // - auto subContext = DifferentiableTypeConformanceContext(diffContext, instWithChildren); - (&pairBuilderStorage)->diffConformanceContext = (&subContext); - for (auto child = instWithChildren->getFirstChild(); child; ) { // Make sure the builder is at the right level. @@ -2092,53 +2012,21 @@ struct JVPDerivativeContext switch (child->getOp()) { case kIROp_DifferentialPairType: - lowerPairType(builder, as<IRType>(child), &subContext); + lowerPairType(builder, as<IRType>(child)); break; case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: - lowerPairAccess(builder, child, &subContext); + lowerPairAccess(builder, child); break; case kIROp_MakeDifferentialPair: - lowerMakePair(builder, child, &subContext); + lowerMakePair(builder, child); break; default: if (child->getFirstChild()) - modified = processPairTypes(builder, child, (&subContext)) | modified; - } - - child = nextChild; - } - - // Reset the context back to the parent. - (&pairBuilderStorage)->diffConformanceContext = diffContext; - - return modified; - } - - bool stripDiffTypeInformation(IRInst* parent) - { - bool modified = false; - - auto child = parent->getFirstChild(); - while (child) - { - auto nextChild = child->getNextInst(); - - switch (child->getOp()) - { - case kIROp_DifferentiableTypeDictionary: - child->removeAndDeallocate(); - child = nextChild; - modified = true; - continue; - } - - if (child->getFirstChild() != nullptr) - { - modified |= stripDiffTypeInformation(child); + modified = processPairTypes(builder, child) | modified; } child = nextChild; @@ -2186,12 +2074,13 @@ struct JVPDerivativeContext } JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : - module(module), sink(sink), - diffConformanceContextStorage(module->getModuleInst()), - pairBuilderStorage(&diffConformanceContextStorage) + module(module), + sink(sink), + autoDiffSharedContextStorage(module->getModuleInst()), + transcriberStorage(&autoDiffSharedContextStorage) { transcriberStorage.sink = sink; - transcriberStorage.diffConformanceContext = &(diffConformanceContextStorage); + transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage); transcriberStorage.pairBuilder = &(pairBuilderStorage); } @@ -2221,7 +2110,7 @@ struct JVPDerivativeContext // Context to find and manage the witness tables for types // implementing `IDifferentiable` - DifferentiableTypeConformanceContext diffConformanceContextStorage; + AutoDiffSharedContext autoDiffSharedContextStorage; // Builder for dealing with differential pair types. DifferentialPairTypeBuilder pairBuilderStorage; @@ -2243,7 +2132,6 @@ bool processForwardDifferentiableFuncs( JVPDerivativeContext context(module, sink); bool changed = context.processModule(); - changed |= context.stripDiffTypeInformation(module->getModuleInst()); return changed; } @@ -2258,6 +2146,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) { case kIROp_ForwardDerivativeDecoration: case kIROp_DerivativeMemberDecoration: + case kIROp_DifferentiableTypeDictionaryDecoration: decor->removeAndDeallocate(); break; default: diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 61aa28bbe..431446f01 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -715,6 +715,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// the witness table to be easily picked up by emit. INST(COMWitnessDecoration, COMWitnessDecoration, 1, 0) + /* Differentiable Type Dictionary */ + INST(DifferentiableTypeDictionaryDecoration, DifferentiableTypeDictionaryDecoration, 0, PARENT) + /// Marks a struct type as being used as a structured buffer block. /// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration. INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0) @@ -812,7 +815,6 @@ INST(ExistentialFuncSpecializationDictionary, ExistentialFuncSpecializationDicti INST(ExistentialTypeSpecializationDictionary, ExistentialTypeSpecializationDictionary, 0, PARENT) /* Differentiable Type Dictionary */ -INST(DifferentiableTypeDictionary, DifferentiableTypeDictionary, 0, PARENT) INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0) #undef PARENT diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index deb81134b..989777944 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -598,6 +598,14 @@ struct IRForwardDifferentiate : IRInst struct IRDifferentiableTypeDictionaryItem : IRInst { IR_LEAF_ISA(DifferentiableTypeDictionaryItem) + + IRInst* getConcreteType() { return getOperand(0); } + IRInst* getWitness() { return getOperand(1); } +}; + +struct IRDifferentiableTypeDictionaryDecoration : IRInst +{ + IR_LEAF_ISA(DifferentiableTypeDictionaryDecoration) }; @@ -2490,26 +2498,10 @@ public: IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); - // Emit and return a dictionary instruction to the global or generic scope. - IRInst* emitDifferentiableTypeDictionary(); - - // Emit and return a dictionary instruction to the global or generic scope, - // if one is not already present. - // - IRInst* findOrEmitDifferentiableTypeDictionary(); - - // Returns the IRDifferentiableTypeDictionary in the scope of inst. - IRInst* findDifferentiableTypeDictionary(IRInst* inst); + IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target); // Add a differentiable type entry to the appropriate dictionary. - IRInst* addDifferentiableTypeEntry(IRInst* irType, IRInst* conformanceWitness); - - // Lookup a differentiable type entry in the appropriate dictionary. - // This recursively looks up in upper contexts. - // - IRInst* findDifferentiableTypeEntry(IRInst* irType); - - IRInst* findDifferentiableTypeEntry(IRInst* irType, IRInst* scope); + IRInst* addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness); IRInst* emitSpecializeInst( IRType* type, diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index eb899b69c..ad4f691f1 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -238,7 +238,6 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) case kIROp_WitnessTable: case kIROp_InterfaceType: case kIROp_TaggedUnionType: - case kIROp_DifferentiableTypeDictionary: return cloneGlobalValue(this, originalValue); case kIROp_BoolLit: @@ -593,24 +592,6 @@ IRWitnessTable* cloneWitnessTableImpl( return clonedTable; } -IRInst* cloneDifferentiableTypeDictionary( - IRSpecContextBase* context, - IRBuilder* builder, - IRInst* originalDict, - IROriginalValuesForClone const& originalValues, - IRInst* dstDict = nullptr, - bool registerValue = true) -{ - IRInst* clonedDict = dstDict; - if (!clonedDict) - { - clonedDict = builder->emitDifferentiableTypeDictionary(); - } - cloneSimpleGlobalValueImpl(context, originalDict, originalValues, clonedDict, registerValue); - return clonedDict; -} - - IRWitnessTable* cloneWitnessTableWithoutRegistering( IRSpecContextBase* context, IRBuilder* builder, @@ -1138,9 +1119,6 @@ IRInst* cloneInst( case kIROp_GlobalGenericParam: return cloneGlobalGenericParamImpl(context, builder, cast<IRGlobalGenericParam>(originalInst), originalValues); - case kIROp_DifferentiableTypeDictionary: - return cloneDifferentiableTypeDictionary(context, builder, originalInst, originalValues); - default: break; } @@ -1164,9 +1142,8 @@ IRInst* cloneInst( } builder->addInst(clonedInst); context->builder = oldBuilder; - cloneDecorations(context, clonedInst, originalInst); + cloneDecorationsAndChildren(context, clonedInst, originalInst); cloneExtraDecorations(context, clonedInst, originalValues); - return clonedInst; } @@ -1530,10 +1507,6 @@ LinkedIR linkIR( { cloneValue(context, bindInst); } - else if (inst->getOp() == kIROp_DifferentiableTypeDictionary) - { - cloneValue(context, inst); - } } } diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 1f13eb754..214f10ef9 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -66,5 +66,38 @@ bool isComInterfaceType(IRType* type) return false; } +IROp getTypeStyle(IROp op) +{ + switch (op) + { + case kIROp_VoidType: + case kIROp_BoolType: + { + return op; + } + case kIROp_Int8Type: + case kIROp_Int16Type: + case kIROp_IntType: + case kIROp_UInt8Type: + case kIROp_UInt16Type: + case kIROp_UIntType: + case kIROp_Int64Type: + case kIROp_UInt64Type: + case kIROp_IntPtrType: + case kIROp_UIntPtrType: + { + // All int like + return kIROp_IntType; + } + case kIROp_HalfType: + case kIROp_FloatType: + case kIROp_DoubleType: + { + // All float like + return kIROp_FloatType; + } + default: return kIROp_Invalid; + } +} } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 2300c929d..b6690a28c 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -24,6 +24,14 @@ Dictionary<IRInst*, IRInst*> buildInterfaceRequirementDict(IRInterfaceType* inte bool isComInterfaceType(IRType* type); + +IROp getTypeStyle(IROp op); + +inline bool isScalarIntegerType(IRType* type) +{ + return getTypeStyle(type->getOp()) == kIROp_IntType; +} + } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 382f7be5e..083ef98c5 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3546,133 +3546,35 @@ namespace Slang value->insertAtEnd(parent); } } - - IRInst* IRBuilder::emitDifferentiableTypeDictionary() - { - auto inst = createInst<IRInst>( - this, - kIROp_DifferentiableTypeDictionary, - nullptr); - - addGlobalValue(this, inst); - return inst; - } - - IRInst* IRBuilder::findOrEmitDifferentiableTypeDictionary() + IRInst* IRBuilder::addDifferentiableTypeDictionaryDecoration(IRInst* target) { - auto currentLoc = this->getInsertLoc(); - auto currentInst = currentLoc.getInst(); - - if (auto diffTypeDictionary = findDifferentiableTypeDictionary(currentInst)) - return diffTypeDictionary; - - return emitDifferentiableTypeDictionary(); + return addDecoration(target, kIROp_DifferentiableTypeDictionaryDecoration); } - IRInst* IRBuilder::findDifferentiableTypeDictionary(IRInst* parent) - { - //auto parent = inst->getParent(); - while (parent) - { - // Inserting into the top level of a module? - // That is fine, and we can stop searching. - if (as<IRModuleInst>(parent)) - break; - - // Inserting into a basic block inside of - // a generic? That is okay too. - if (auto block = as<IRBlock>(parent)) - { - if (as<IRGeneric>(block->parent)) - break; - } - - // Otherwise, move up the chain. - parent = parent->parent; - } - - for (auto child = parent->getFirstChild(); child; child = child->getNextInst()) - { - if (child->getOp() == kIROp_DifferentiableTypeDictionary) - return child; - } - - return nullptr; - } - - IRInst* IRBuilder::addDifferentiableTypeEntry(IRInst* irType, IRInst* conformanceWitness) + IRInst* IRBuilder::addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness) { auto oldLoc = this->getInsertLoc(); IRDifferentiableTypeDictionaryItem* item = nullptr; - if (auto diffTypeDictionary = findOrEmitDifferentiableTypeDictionary()) - { - this->setInsertInto(diffTypeDictionary); + this->setInsertInto(dictDecoration); - IRInst* args[2] = {irType, conformanceWitness}; - item = createInstWithTrailingArgs<IRDifferentiableTypeDictionaryItem>( - this, - kIROp_DifferentiableTypeDictionaryItem, - nullptr, - 2, - args); + IRInst* args[2] = {irType, conformanceWitness}; + item = createInstWithTrailingArgs<IRDifferentiableTypeDictionaryItem>( + this, + kIROp_DifferentiableTypeDictionaryItem, + nullptr, + 2, + args); - addInst(item); - } + addInst(item); this->setInsertLoc(oldLoc); return item; } - IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType, IRInst* scope) - { - IRInst* foundResult = nullptr; - for (auto child = scope->getFirstChild(); child; child = child->getNextInst()) - { - if (child->getOp() == kIROp_DifferentiableTypeDictionary) - { - for (auto entry = child->getFirstChild(); entry; entry = entry->getNextInst()) - { - IRInst* entryType = entry->getOperand(0); - IRInst* entryConformanceWitness = entry->getOperand(1); - - if (irType == entryType) - { - foundResult = entryConformanceWitness; - // If the found witness table is not a trivial one (i.e. DifferentialBottom:IDifferential), - // return immediately. Otherwise, continue the search to see if we can find a better one. - if (auto witness = as<IRWitnessTable>(foundResult)) - { - if (witness->getConcreteType()->getOp() != kIROp_DifferentialBottomType) - return foundResult; - } - } - } - } - } - - return foundResult; - } - - IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType) - { - auto instScope = this->getInsertLoc().getInst(); - - while (instScope) - { - if (auto witness = findDifferentiableTypeEntry(irType, instScope)) - { - return witness; - } - instScope = instScope->getParent(); - } - - return nullptr; - } - IRFunc* IRBuilder::createFunc() { diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index a2fb1be98..9295ca2f5 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1323,6 +1323,7 @@ SIMPLE_IR_TYPE(GenericKind, Kind) struct IRDifferentialPairType : IRType { IRType* getValueType() { return (IRType*)getOperand(0); } + IRInst* getWitness() { return (IRInst*)getOperand(1); } IR_LEAF_ISA(DifferentialPairType) }; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index e2b14f1e3..f8d8282d8 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -5866,47 +5866,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return LoweredValInfo(); } - LoweredValInfo visitDifferentiableTypeDictionary(DifferentiableTypeDictionary* decl) - { - for (auto & member : decl->members) - { - if (auto entry = as<DifferentiableTypeDictionaryItem>(member)) - { - - // Lower type and witness. - IRType* irType = lowerType(context, entry->baseType); - IRInst* irWitness = lowerVal(context, entry->confWitness).val; - - SLANG_ASSERT(irType); - - // If the witness can be lowered, and the differentiable type entry exists, - // add an entry to the context. - // - if (irWitness && !getBuilder()->findDifferentiableTypeEntry(irType)) - { - getBuilder()->addDifferentiableTypeEntry(irType, irWitness); - } - } - else if (auto importEntry = as<DifferentiableTypeDictionaryImportItem>(member)) - { - ensureDecl(context, importEntry->dictionaryRef.getDecl()); - } - else - { - SLANG_UNEXPECTED("Unrecognized item in DifferentiableTypeDictionary"); - UNREACHABLE_RETURN(LoweredValInfo()); - } - } - - if (auto diffTypeDict = getBuilder()->findOrEmitDifferentiableTypeDictionary()) - { - // Place the dictionary at the end of modules and generic blocks. - diffTypeDict->moveToEnd(); - } - - return LoweredValInfo(); - } - #define IGNORED_CASE(NAME) \ LoweredValInfo visit##NAME(NAME*) { return LoweredValInfo(); } @@ -5916,7 +5875,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> IGNORED_CASE(SyntaxDecl) IGNORED_CASE(AttributeDecl) IGNORED_CASE(NamespaceDecl) - IGNORED_CASE(DifferentiableTypeDictionaryItem) #undef IGNORED_CASE @@ -7119,6 +7077,27 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> builder->addDecoration(inst, kIROp_DerivativeMemberDecoration, key); } + void lowerDifferentiableAttribute(IRGenContext* subContext, IRInst* inst, DifferentiableAttribute* attr) + { + auto irDict = getBuilder()->addDifferentiableTypeDictionaryDecoration(inst); + for (auto& entry : attr->m_mapTypeToIDifferentiableWitness) + { + // Lower type and witness. + IRType* irType = lowerType(subContext, entry.Value->sub); + IRInst* irWitness = lowerVal(subContext, entry.Value).val; + + SLANG_ASSERT(irType); + + // If the witness can be lowered, and the differentiable type entry exists, + // add an entry to the context. + // + if (irWitness) + { + getBuilder()->addDifferentiableTypeEntry(irDict, irType, irWitness); + } + } + } + LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl) { // Each field declaration in the AST translates into @@ -7170,14 +7149,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // output for the chosen target. addTargetIntrinsicDecorations(irFieldKey, fieldDecl); - return LoweredValInfo::simple(irFieldKey); } - - - - bool isImportedDecl(Decl* decl) { return Slang::isImportedDecl(context, decl); @@ -7196,6 +7170,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> GenericTypeConstraintDecl* constraintDecl, IRType* supType) { + auto subBuilder = subContext->irBuilder; // There are two cases we care about here. @@ -7311,21 +7286,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } - // We only need dictionaries to be lowered for decls with executable code (i.e. statements) - // Do not lower type dictionaries for inhertiance decls or decls - // that are declaring a type, since this can create a cyclic dependancy. - // - if (as<FunctionDeclBase>(leafDecl)) - { - for (auto diffTypeDict : genericDecl->getMembersOfType<DifferentiableTypeDictionary>()) - { - // We directly use lowerDecl() instead of ensureDecl() to emit to - // the current generic block instead of the top-level module. - // - lowerDecl(subContext, diffTypeDict); - } - } - return irGeneric; } @@ -7479,10 +7439,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), genericParam); } - - // Add a differentiable type dictionary if necessary. - if (auto diffTypeDict = subBuilder->findDifferentiableTypeDictionary(parentGeneric->getFirstBlock())) - markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), diffTypeDict); } if (valuesToClone.Count() == 0) { @@ -7838,6 +7794,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addForwardDifferentiableDecoration(irFunc); } + if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>()) + { + lowerDifferentiableAttribute(subContext, irFunc, differentialAttr); + } // Always force inline diff setter accessor to prevent downstream compiler from complaining // fields are not fully initialized for the first `inout` parameter. diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h index 937ecc95f..0412ef4da 100644 --- a/source/slang/slang-serialize-ast-type-info.h +++ b/source/slang/slang-serialize-ast-type-info.h @@ -83,6 +83,9 @@ struct SerialGetFieldType<DeclRef<T>> template <typename T> struct SerialTypeInfo<DeclRef<T>> : public SerialDeclRefBaseTypeInfo {}; +template<> +struct SerialTypeInfo<DeclRefBase> : public SerialDeclRefBaseTypeInfo {}; + // MatrixCoord can just go as is template <> struct SerialTypeInfo<MatrixCoord> : SerialIdentityTypeInfo<MatrixCoord> {}; |
