diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-10-20 14:22:00 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-20 11:22:00 -0700 |
| commit | 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch) | |
| tree | e85158637680f783caaf7f4433a6844398cd8f7b /source/slang/slang-check-decl.cpp | |
| parent | 576c8407e60143682cd40c68101c6eae8563ca3d (diff) | |
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions.
* Changed a few asserts to release asserts to avoid unreferenced variable errors
* Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl
* Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types.
* Moved the auto-diff passes to after the specialization step, added a more complex generics test
* Added a generics stress test and fixed AST-side logic. IR side needs some more work
* Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions
* Changed differential getters to use pointer types, added getter type checking
* Fixed some bugs related to diff type registration and differential getters
* Removed some superfluous code
* Removed some more unused code.
* Fixed an issue with witness substitution
* Minor fix
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 205 |
1 files changed, 203 insertions, 2 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index b18e1c4da..2d6e20622 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -835,7 +835,15 @@ namespace Slang // If `decl` is a container, then we want to ensure its children. if(auto containerDecl = as<ContainerDecl>(decl)) - { + { + bool trackDiffTypes = (as<GenericDecl>(decl) != nullptr); + if (trackDiffTypes) + { + // Add a context to track differentiable types. + DifferentiableTypeSemanticContext subDiffTypeContext; + visitor->getShared()->pushDiffTypeContext(&subDiffTypeContext); + } + // NOTE! We purposefully do not iterate with the for(auto childDecl : containerDecl->members) here, // because the visitor may add to `members` whilst iteration takes place, invalidating the iterator // and likely a crash. @@ -857,6 +865,21 @@ namespace Slang _ensureAllDeclsRec(visitor, childDecl, state); } + + if (trackDiffTypes) + { + auto subDiffTypeContext = visitor->getShared()->popDiffTypeContext(); + + // If there were any differentiable types used in differentiable + // methods, generate a dictionary with the required info. + // + if (subDiffTypeContext->isDictionaryRequired()) + { + auto diffTypeDict = subDiffTypeContext->makeDifferentiableTypeDictionaryNode(visitor->getASTBuilder()); + diffTypeDict->parentDecl = containerDecl; + containerDecl->members.add(diffTypeDict); + } + } } // Note: the "inner" declaration of a `GenericDecl` is currently @@ -1234,6 +1257,49 @@ namespace Slang } } + void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*) + { + // If the autodiff core library (diff.meta.slang) has not been loaded yet, ignore any + // request to check differentiable types. + // + if (!m_astBuilder->isDifferentiableInterfaceAvailable()) + return; + + auto diffInterface = m_astBuilder->getDifferentiableInterface(); + + DeclRefType* type = nullptr; + + if (auto extensionDecl = as<ExtensionDecl>(decl)) + { + // If this is an extension, use the provided target type. + type = as<DeclRefType>(extensionDecl->targetType.type); + } + else + { + // If this is a type declaration, create a decl ref without + // any substitutions. + // + auto declRef = makeDeclRef(decl); + + // TODO: Strip substitutions from the declreftype + type = DeclRefType::create(m_astBuilder, declRef); + } + + // Skip if the declaration is the interface itself. + if (type->declRef == diffInterface) + return; + + // If the DeclRefType conforms to IDifferentiable, register it with the top-level + // context. + // + if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, diffInterface))) + { + // TODO: Temporarily disabled to move to new system. Fix later. + // context->registerDifferentiableType(type, witness); + } + + } + void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) { // TODO: are there any other validations we can do at this point? @@ -1287,6 +1353,23 @@ namespace Slang ensureDecl(constraint, DeclCheckState::ReadyForReference); } } + + // TODO(sai): Is this the right checking stage to be doing this? + DifferentiableTypeSemanticContext diffTypeContext; + + for (Index i = 0; i < members.getCount(); ++i) + { + Decl* m = members[i]; + + if (auto typeParam = as<GenericTypeParamDecl>(m)) + { + tryAddDifferentiableConformanceToContext(typeParam, &diffTypeContext); + } + } + + auto diffTypeDictionaryNode = diffTypeContext.makeDifferentiableTypeDictionaryNode(m_astBuilder); + diffTypeDictionaryNode->parentDecl = genericDecl; + genericDecl->members.add(diffTypeDictionaryNode); } void SemanticsDeclBasesVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl) @@ -1322,6 +1405,7 @@ namespace Slang void visitAggTypeDecl(AggTypeDecl* aggTypeDecl) { checkAggTypeConformance(aggTypeDecl); + tryAddDifferentiableConformanceToContext(aggTypeDecl, getShared()->getDiffTypeContext()); } // Conformances can also come via `extension` declarations, and @@ -1330,6 +1414,7 @@ namespace Slang void visitExtensionDecl(ExtensionDecl* extensionDecl) { checkExtensionConformance(extensionDecl); + tryAddDifferentiableConformanceToContext(extensionDecl, getShared()->getDiffTypeContext()); } }; @@ -1486,6 +1571,32 @@ namespace Slang // Furthermore, because a fully checked function will have checked // its body, this also means that all function bodies and the // declarations they contain should be fully checked. + + // Generate a dictionary node to hold information about all + // available differentiable types in scope (including imports and stdlib) + // + if (getShared()->getDiffTypeContext()->isDictionaryRequired()) + finishDifferentiableTypeDictionary(moduleDecl); + } + + void SemanticsVisitor::finishDifferentiableTypeDictionary(ModuleDecl* moduleDecl) + { + // Grab the differentiable type information from imported modules. + for(auto importedModule : getShared()->importedModulesList) + { + this->getShared()->getDiffTypeContext()->addImportedModule(importedModule); + } + + // Grad the differentiable type information from the standard library modules. + for (auto stdLibModule : this->getSession()->stdlibModules) + { + this->getShared()->getDiffTypeContext()->addImportedModule(stdLibModule->getModuleDecl()); + } + + auto diffTypeDictNode = this->getShared()->getDiffTypeContext()->makeDifferentiableTypeDictionaryNode(m_astBuilder); + diffTypeDictNode->parentDecl = moduleDecl; + + moduleDecl->members.add(diffTypeDictNode); } bool SemanticsVisitor::doesSignatureMatchRequirement( @@ -4292,7 +4403,23 @@ namespace Slang nullptr); args.add(val); } - // TODO: need to handle constraints here? + } + + // Add defaults for constraint parameters. + for (auto dd : genericDecl->members) + { + if (auto constraintDecl = as<GenericTypeConstraintDecl>(dd)) + { + // Convert the constraint to an appropriate witness. + auto witness = tryGetSubtypeWitness(constraintDecl->sub, constraintDecl->sup); + + // Must be non-null since we know there's a constraint. If null, something is + // very wrong. + // + SLANG_ASSERT(witness); + + args.add(witness); + } } GenericSubstitution* subst = m_astBuilder->getOrCreateGenericSubstitution(genericDecl, args, nullptr); return subst; @@ -4725,6 +4852,11 @@ namespace Slang void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) { + if (decl->findModifier<JVPDerivativeModifier>()) + { + this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); + } + for(auto paramDecl : decl->getParameters()) { ensureDecl(paramDecl, DeclCheckState::ReadyForReference); @@ -5594,6 +5726,75 @@ namespace Slang m_candidateExtensionListsBuilt = false; m_mapTypeDeclToCandidateExtensions.Clear(); } + + void DifferentiableTypeSemanticContext::registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness) + { + // Need to generate a type dictionary since we have a declaration that works with + // a differentiable type. + // + this->requireDifferentiableTypeDictionary(); + + m_mapTypeToIDifferentiableWitness.AddIfNotExists(DeclRefTypeKey(type), witness); + } + + List<KeyValuePair<DeclRefType*, SubtypeWitness*>> DifferentiableTypeSemanticContext::getDifferentiableTypeConformanceList() + { + List<KeyValuePair<DeclRefType*, SubtypeWitness*>> diffConformances; + for (auto entry : m_mapTypeToIDifferentiableWitness) + { + diffConformances.add(KeyValuePair<DeclRefType*, SubtypeWitness*>(entry.Key.type, entry.Value)); + } + + return diffConformances; + } + + DifferentiableTypeDictionary* DifferentiableTypeSemanticContext::makeDifferentiableTypeDictionaryNode( + ASTBuilder* builder) + { + auto dictionary = builder->create<DifferentiableTypeDictionary>(); + + for (auto item : m_mapTypeToIDifferentiableWitness) + { + auto entry = builder->create<DifferentiableTypeDictionaryItem>(); + entry->baseType = item.Key.type; + entry->confWitness = item.Value; + entry->parentDecl = dictionary; + + dictionary->members.add(entry); + } + + for (auto item : m_importedDictionaries) + { + auto entry = builder->create<DifferentiableTypeDictionaryImportItem>(); + entry->dictionaryRef = item; + entry->parentDecl = dictionary; + + dictionary->members.add(entry); + } + + return dictionary; + } + + void DifferentiableTypeSemanticContext::addImportedModule(ModuleDecl* importedModuleDecl) + { + // TODO: This is a terribly slow way to find the diff type dictionary. + // Switch to lookUp() when possible (this might involve naming the dictionary something) + // + for (auto diffTypeDict : importedModuleDecl->getMembersOfType<DifferentiableTypeDictionary>()) + { + m_importedDictionaries.add(makeDeclRef(diffTypeDict)); + } + } + + void DifferentiableTypeSemanticContext::requireDifferentiableTypeDictionary() + { + this->m_isTypeDictionaryRequired = true; + } + + bool DifferentiableTypeSemanticContext::isDictionaryRequired() + { + return this->m_isTypeDictionaryRequired; + } void SharedSemanticsContext::_addCandidateExtensionsFromModule(ModuleDecl* moduleDecl) { |
