diff options
Diffstat (limited to 'source/slang/slang-check-inheritance.cpp')
| -rw-r--r-- | source/slang/slang-check-inheritance.cpp | 205 |
1 files changed, 139 insertions, 66 deletions
diff --git a/source/slang/slang-check-inheritance.cpp b/source/slang/slang-check-inheritance.cpp index 3e59c5e8d..0dc80cdc3 100644 --- a/source/slang/slang-check-inheritance.cpp +++ b/source/slang/slang-check-inheritance.cpp @@ -7,14 +7,14 @@ namespace Slang { - InheritanceInfo SharedSemanticsContext::getInheritanceInfo(Type* type) + InheritanceInfo SharedSemanticsContext::getInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo) { // We cache the computed inheritance information for types, // and re-use that information whenever possible. // DeclRefTypes will have their inheritance info cached in m_mapDeclRefToInheritanceInfo. if (auto declRefType = as<DeclRefType>(type)) - return _getInheritanceInfo(declRefType->getDeclRef(), declRefType); + return _getInheritanceInfo(declRefType->getDeclRef(), declRefType, circularityInfo); // Non ordinary types are cached on m_mapTypeToInheritanceInfo. if (auto found = m_mapTypeToInheritanceInfo.tryGetValue(type)) @@ -29,22 +29,48 @@ namespace Slang // m_mapTypeToInheritanceInfo[type] = InheritanceInfo(); - auto info = _calcInheritanceInfo(type); + auto info = _calcInheritanceInfo(type, circularityInfo); m_mapTypeToInheritanceInfo[type] = info; return info; } - InheritanceInfo SharedSemanticsContext::getInheritanceInfo(DeclRef<ExtensionDecl> const& extension) + InheritanceInfo SharedSemanticsContext::getInheritanceInfo(DeclRef<ExtensionDecl> const& extension, InheritanceCircularityInfo* circularityInfo) { + if (_checkForCircularityInExtensionTargetType(extension.getDecl(), circularityInfo)) + { + // If we detect a circularity in the inheritance graph, + // we will return an empty `InheritanceInfo` to avoid + // infinite recursion. + // + return InheritanceInfo(); + } + // We bottleneck the calculation of inheritance information // for type and `extension` `DeclRef`s through a single // routine with an optional `Type` parameter. // - return _getInheritanceInfo(extension, nullptr); + InheritanceCircularityInfo newCircularityInfo(extension.getDecl(), circularityInfo); + return _getInheritanceInfo(extension, nullptr, &newCircularityInfo); } - InheritanceInfo SharedSemanticsContext::_getInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* declRefType) + bool SharedSemanticsContext::_checkForCircularityInExtensionTargetType( + Decl* decl, + InheritanceCircularityInfo* circularityInfo) + { + for (auto info = circularityInfo; info; info = info->next) + { + if (decl == info->decl) + { + getSink()->diagnose(decl, Diagnostics::circularityInExtension, decl); + return true; + } + } + + return false; + } + + InheritanceInfo SharedSemanticsContext::_getInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* declRefType, InheritanceCircularityInfo* circularityInfo) { // Just as with `Type`s, we cache and re-use the inheritance // information that has been computed for a `DeclRef` whenever @@ -62,7 +88,7 @@ namespace Slang // m_mapDeclRefToInheritanceInfo[declRef] = InheritanceInfo(); - auto info = _calcInheritanceInfo(declRef, declRefType); + auto info = _calcInheritanceInfo(declRef, declRefType, circularityInfo); m_mapDeclRefToInheritanceInfo[declRef] = info; getSession()->m_typeDictionarySize = Math::Max( @@ -71,7 +97,7 @@ namespace Slang return info; } - InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* declRefType) + InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* declRefType, InheritanceCircularityInfo* circularityInfo) { // This method is the main engine for computing linearized inheritance // lists for types and `extension` declarations. @@ -215,7 +241,7 @@ namespace Slang // SLANG_ASSERT(selfIsBaseWitness); - auto baseInheritanceInfo = getInheritanceInfo(baseType); + auto baseInheritanceInfo = getInheritanceInfo(baseType, circularityInfo); DeclRef<Decl> baseDeclRef; if (auto baseDeclRefType = as<DeclRefType>(baseType)) @@ -231,6 +257,51 @@ namespace Slang baseInheritanceInfo); }; + // If we know the type has a facet represented by `extensionTargetDeclRef`, we can consider + // all extensions on this decl to see if they apply to the type. + // + auto considerExtension = [&](DeclRef<AggTypeDecl> extensionTargetDeclRef, Dictionary<Type*, SubtypeWitness*>* additionalSubtypeWitness) + { + bool result = false; + for (auto extDecl : getCandidateExtensions(extensionTargetDeclRef, &visitor)) + { + // The list of *candidate* extensions is computed and + // cached based on the identity of the declaration alone, + // and does not take into account any generic arguments + // of either the type or the `extension`. + // + // For example, we might have an `extension` that applies + // to `vector<float,N>` for any `N`, but the `selfType` + // that we are working with could be `<vector<int,2>` so + // that the extension doesn't match. + // + // In order to make sure that we don't enumerate members + // that don't make sense in context, we must apply + // the extension to the type and see if we succeed in + // making a match. + // + auto extDeclRef = applyExtensionToType(&visitor, extDecl, selfType, additionalSubtypeWitness); + if (!extDeclRef) + continue; + + // In the case where we *do* find an extension that + // applies to the type, we add a declared base to + // represent the `extension`, knowing that its + // own linearized inheritance list will include + // any transitive based declared on the `extension`. + // + auto extInheritanceInfo = getInheritanceInfo(extDeclRef, circularityInfo); + addDirectBaseFacet( + Facet::Kind::Extension, + selfType, + selfIsSelf, + extDeclRef, + extInheritanceInfo); + result = true; + } + return result; + }; + // We now look at the structure of the declaration itself // to help us enumerate the direct bases. // @@ -280,9 +351,26 @@ namespace Slang auto genericDeclRef = genericTypeParamDeclRef.getParent().as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); - ensureDecl(&visitor, genericDeclRef.getDecl(), DeclCheckState::CanSpecializeGeneric); + if (auto extensionDecl = as<ExtensionDecl>(genericDeclRef.getDecl()->inner)) + { + if (isDeclRefTypeOf<GenericTypeParamDecl>(extensionDecl->targetType.type) == genericTypeParamDeclRef) + { + // If `T` is a generic parameter where the same generic is an extension on `T`, + // then we need to add the extension itself as a facet. + // + auto extDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, &visitor, extensionDecl); + auto selfExtFacet = new(arena) Facet::Impl( + Facet::Kind::Extension, + Facet::Directness::Direct, + extDeclRef, + selfType, + astBuilder->getTypeEqualityWitness(selfType)); + allFacets.add(selfExtFacet); + } + } + for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(astBuilder, genericDeclRef)) { auto subType = getSub(astBuilder, constraintDeclRef); @@ -326,63 +414,48 @@ namespace Slang // to consider any `extension` declarations that might apply to // a type being delared. // - // In our current system, only nominal types (those with `Decl`s) - // can be extended, so we begin by checking if the `selfType` - // is a nominal/`DeclRef` type. - // - // Note: this step will *not* apply when `declRef` is an `extension` - // declaration, since it directly checks for an `AggTypeDecl` - // instead of an `AggTypeDeclBase`. - // - // Similarly, we do *not* add the type being extended to the list - // of bases for an `extension`. - // - // These choices are important to avoid circular dependencies, where - // the linearization of an `extension` would end up depending on its - // own linearization (either directly or through a dependency on - // the linearization of the type being extended). - // - // Instead, the linearization we create here for an `extension` will - // *only* contain facets for the members introduced by the `extension` - // itself, as well as any transitive bases declared on that `extension`. + // An `extension` may apply to our type, if it directly extends + // the type, or extends a generic `T` type that are constrained + // on one of the interfaces that our type conforms to. // if (auto directAggTypeDeclRef = declRef.as<AggTypeDecl>()) { - for (auto extDecl : getCandidateExtensions(directAggTypeDeclRef, &visitor)) + considerExtension(directAggTypeDeclRef, nullptr); + } + HashSet<Type*> supTypesConsideredForExtensionApplication; + Dictionary<Type*, SubtypeWitness*> additionalSubtypeWitnesses; + for (;;) + { + // After we flatten the list of bases, we may discover additional opportunities + // to apply extensions. + List<DeclRef<AggTypeDecl>> supTypeWorkList; + for (auto curFacet : directBaseFacets) { - // The list of *candidate* extensions is computed and - // cached based on the identity of the declaration alone, - // and does not take into account any generic arguments - // of either the type or the `extension`. - // - // For example, we might have an `extension` that applies - // to `vector<float,N>` for any `N`, but the `selfType` - // that we are working with could be `<vector<int,2>` so - // that the extension doesn't match. - // - // In order to make sure that we don't enumerate members - // that don't make sense in context, we must apply - // the extension to the type and see if we succeed in - // making a match. - // - auto extDeclRef = applyExtensionToType(&visitor, extDecl, selfType); - if (!extDeclRef) + if (!curFacet->subtypeWitness) continue; - - // In the case where we *do* find an extension that - // applies to the type, we add a declared base to - // represent the `extension`, knowing that its - // own linearized inheritance list will include - // any transitive based declared on the `extension`. - // - auto extInheritanceInfo = getInheritanceInfo(extDeclRef); - addDirectBaseFacet( - Facet::Kind::Extension, - selfType, - selfIsSelf, - extDeclRef, - extInheritanceInfo); + auto inheritanceInfo = getInheritanceInfo(curFacet->subtypeWitness->getSup(), circularityInfo); + for (auto facet : inheritanceInfo.facets) + { + if (auto interfaceDeclRef = facet->origin.declRef.as<InterfaceDecl>()) + { + SubtypeWitness* transitiveWitness = curFacet->subtypeWitness; + transitiveWitness = astBuilder->getTransitiveSubtypeWitness(curFacet->subtypeWitness, facet->subtypeWitness); + additionalSubtypeWitnesses.addIfNotExists(facet->origin.type, transitiveWitness); + if (supTypesConsideredForExtensionApplication.add(facet->origin.type)) + { + supTypeWorkList.add(interfaceDeclRef); + } + } + } } + bool canExit = true; + for (auto baseItem : supTypeWorkList) + { + if (considerExtension(baseItem, &additionalSubtypeWitnesses)) + canExit = false; + } + if (canExit) + break; } // At this point, the list of direct bases (each with its own linearization) @@ -846,7 +919,7 @@ namespace Slang return false; } - InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(Type* type) + InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo) { // The majority of the interesting for for computing linearized // inheritance information arises for `DeclRef`s, but we still @@ -861,7 +934,7 @@ namespace Slang // bottleneck through the logic that gets shared between // type and `extension` declarations. // - return _getInheritanceInfo(declRefType->getDeclRef(), declRefType); + return _getInheritanceInfo(declRefType->getDeclRef(), declRefType, circularityInfo); } else if (auto conjunctionType = as<AndType>(type)) { @@ -875,8 +948,8 @@ namespace Slang // must include all the facets from the lists for `L` // and `R`, respectively. // - auto leftInfo = getInheritanceInfo(leftType); - auto rightInfo = getInheritanceInfo(rightType); + auto leftInfo = getInheritanceInfo(leftType, circularityInfo); + auto rightInfo = getInheritanceInfo(rightType, circularityInfo); // We have a case of subtype witness that can show that // `T : L` or `T : R` based on `T : L&R`. In this case, @@ -931,7 +1004,7 @@ namespace Slang } else if (auto eachType = as<EachType>(type)) { - auto elementInheritanceInfo = getInheritanceInfo(eachType->getElementType()); + auto elementInheritanceInfo = getInheritanceInfo(eachType->getElementType(), circularityInfo); SemanticsVisitor visitor(this); auto directFacet = new(arena) Facet::Impl( Facet::Kind::Type, |
