summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-inheritance.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-inheritance.cpp')
-rw-r--r--source/slang/slang-check-inheritance.cpp205
1 files changed, 139 insertions, 66 deletions
diff --git a/source/slang/slang-check-inheritance.cpp b/source/slang/slang-check-inheritance.cpp
index 3e59c5e8d..0dc80cdc3 100644
--- a/source/slang/slang-check-inheritance.cpp
+++ b/source/slang/slang-check-inheritance.cpp
@@ -7,14 +7,14 @@
namespace Slang
{
- InheritanceInfo SharedSemanticsContext::getInheritanceInfo(Type* type)
+ InheritanceInfo SharedSemanticsContext::getInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo)
{
// We cache the computed inheritance information for types,
// and re-use that information whenever possible.
// DeclRefTypes will have their inheritance info cached in m_mapDeclRefToInheritanceInfo.
if (auto declRefType = as<DeclRefType>(type))
- return _getInheritanceInfo(declRefType->getDeclRef(), declRefType);
+ return _getInheritanceInfo(declRefType->getDeclRef(), declRefType, circularityInfo);
// Non ordinary types are cached on m_mapTypeToInheritanceInfo.
if (auto found = m_mapTypeToInheritanceInfo.tryGetValue(type))
@@ -29,22 +29,48 @@ namespace Slang
//
m_mapTypeToInheritanceInfo[type] = InheritanceInfo();
- auto info = _calcInheritanceInfo(type);
+ auto info = _calcInheritanceInfo(type, circularityInfo);
m_mapTypeToInheritanceInfo[type] = info;
return info;
}
- InheritanceInfo SharedSemanticsContext::getInheritanceInfo(DeclRef<ExtensionDecl> const& extension)
+ InheritanceInfo SharedSemanticsContext::getInheritanceInfo(DeclRef<ExtensionDecl> const& extension, InheritanceCircularityInfo* circularityInfo)
{
+ if (_checkForCircularityInExtensionTargetType(extension.getDecl(), circularityInfo))
+ {
+ // If we detect a circularity in the inheritance graph,
+ // we will return an empty `InheritanceInfo` to avoid
+ // infinite recursion.
+ //
+ return InheritanceInfo();
+ }
+
// We bottleneck the calculation of inheritance information
// for type and `extension` `DeclRef`s through a single
// routine with an optional `Type` parameter.
//
- return _getInheritanceInfo(extension, nullptr);
+ InheritanceCircularityInfo newCircularityInfo(extension.getDecl(), circularityInfo);
+ return _getInheritanceInfo(extension, nullptr, &newCircularityInfo);
}
- InheritanceInfo SharedSemanticsContext::_getInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* declRefType)
+ bool SharedSemanticsContext::_checkForCircularityInExtensionTargetType(
+ Decl* decl,
+ InheritanceCircularityInfo* circularityInfo)
+ {
+ for (auto info = circularityInfo; info; info = info->next)
+ {
+ if (decl == info->decl)
+ {
+ getSink()->diagnose(decl, Diagnostics::circularityInExtension, decl);
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ InheritanceInfo SharedSemanticsContext::_getInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* declRefType, InheritanceCircularityInfo* circularityInfo)
{
// Just as with `Type`s, we cache and re-use the inheritance
// information that has been computed for a `DeclRef` whenever
@@ -62,7 +88,7 @@ namespace Slang
//
m_mapDeclRefToInheritanceInfo[declRef] = InheritanceInfo();
- auto info = _calcInheritanceInfo(declRef, declRefType);
+ auto info = _calcInheritanceInfo(declRef, declRefType, circularityInfo);
m_mapDeclRefToInheritanceInfo[declRef] = info;
getSession()->m_typeDictionarySize = Math::Max(
@@ -71,7 +97,7 @@ namespace Slang
return info;
}
- InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* declRefType)
+ InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* declRefType, InheritanceCircularityInfo* circularityInfo)
{
// This method is the main engine for computing linearized inheritance
// lists for types and `extension` declarations.
@@ -215,7 +241,7 @@ namespace Slang
//
SLANG_ASSERT(selfIsBaseWitness);
- auto baseInheritanceInfo = getInheritanceInfo(baseType);
+ auto baseInheritanceInfo = getInheritanceInfo(baseType, circularityInfo);
DeclRef<Decl> baseDeclRef;
if (auto baseDeclRefType = as<DeclRefType>(baseType))
@@ -231,6 +257,51 @@ namespace Slang
baseInheritanceInfo);
};
+ // If we know the type has a facet represented by `extensionTargetDeclRef`, we can consider
+ // all extensions on this decl to see if they apply to the type.
+ //
+ auto considerExtension = [&](DeclRef<AggTypeDecl> extensionTargetDeclRef, Dictionary<Type*, SubtypeWitness*>* additionalSubtypeWitness)
+ {
+ bool result = false;
+ for (auto extDecl : getCandidateExtensions(extensionTargetDeclRef, &visitor))
+ {
+ // The list of *candidate* extensions is computed and
+ // cached based on the identity of the declaration alone,
+ // and does not take into account any generic arguments
+ // of either the type or the `extension`.
+ //
+ // For example, we might have an `extension` that applies
+ // to `vector<float,N>` for any `N`, but the `selfType`
+ // that we are working with could be `<vector<int,2>` so
+ // that the extension doesn't match.
+ //
+ // In order to make sure that we don't enumerate members
+ // that don't make sense in context, we must apply
+ // the extension to the type and see if we succeed in
+ // making a match.
+ //
+ auto extDeclRef = applyExtensionToType(&visitor, extDecl, selfType, additionalSubtypeWitness);
+ if (!extDeclRef)
+ continue;
+
+ // In the case where we *do* find an extension that
+ // applies to the type, we add a declared base to
+ // represent the `extension`, knowing that its
+ // own linearized inheritance list will include
+ // any transitive based declared on the `extension`.
+ //
+ auto extInheritanceInfo = getInheritanceInfo(extDeclRef, circularityInfo);
+ addDirectBaseFacet(
+ Facet::Kind::Extension,
+ selfType,
+ selfIsSelf,
+ extDeclRef,
+ extInheritanceInfo);
+ result = true;
+ }
+ return result;
+ };
+
// We now look at the structure of the declaration itself
// to help us enumerate the direct bases.
//
@@ -280,9 +351,26 @@ namespace Slang
auto genericDeclRef = genericTypeParamDeclRef.getParent().as<GenericDecl>();
SLANG_ASSERT(genericDeclRef);
-
ensureDecl(&visitor, genericDeclRef.getDecl(), DeclCheckState::CanSpecializeGeneric);
+ if (auto extensionDecl = as<ExtensionDecl>(genericDeclRef.getDecl()->inner))
+ {
+ if (isDeclRefTypeOf<GenericTypeParamDecl>(extensionDecl->targetType.type) == genericTypeParamDeclRef)
+ {
+ // If `T` is a generic parameter where the same generic is an extension on `T`,
+ // then we need to add the extension itself as a facet.
+ //
+ auto extDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, &visitor, extensionDecl);
+ auto selfExtFacet = new(arena) Facet::Impl(
+ Facet::Kind::Extension,
+ Facet::Directness::Direct,
+ extDeclRef,
+ selfType,
+ astBuilder->getTypeEqualityWitness(selfType));
+ allFacets.add(selfExtFacet);
+ }
+ }
+
for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(astBuilder, genericDeclRef))
{
auto subType = getSub(astBuilder, constraintDeclRef);
@@ -326,63 +414,48 @@ namespace Slang
// to consider any `extension` declarations that might apply to
// a type being delared.
//
- // In our current system, only nominal types (those with `Decl`s)
- // can be extended, so we begin by checking if the `selfType`
- // is a nominal/`DeclRef` type.
- //
- // Note: this step will *not* apply when `declRef` is an `extension`
- // declaration, since it directly checks for an `AggTypeDecl`
- // instead of an `AggTypeDeclBase`.
- //
- // Similarly, we do *not* add the type being extended to the list
- // of bases for an `extension`.
- //
- // These choices are important to avoid circular dependencies, where
- // the linearization of an `extension` would end up depending on its
- // own linearization (either directly or through a dependency on
- // the linearization of the type being extended).
- //
- // Instead, the linearization we create here for an `extension` will
- // *only* contain facets for the members introduced by the `extension`
- // itself, as well as any transitive bases declared on that `extension`.
+ // An `extension` may apply to our type, if it directly extends
+ // the type, or extends a generic `T` type that are constrained
+ // on one of the interfaces that our type conforms to.
//
if (auto directAggTypeDeclRef = declRef.as<AggTypeDecl>())
{
- for (auto extDecl : getCandidateExtensions(directAggTypeDeclRef, &visitor))
+ considerExtension(directAggTypeDeclRef, nullptr);
+ }
+ HashSet<Type*> supTypesConsideredForExtensionApplication;
+ Dictionary<Type*, SubtypeWitness*> additionalSubtypeWitnesses;
+ for (;;)
+ {
+ // After we flatten the list of bases, we may discover additional opportunities
+ // to apply extensions.
+ List<DeclRef<AggTypeDecl>> supTypeWorkList;
+ for (auto curFacet : directBaseFacets)
{
- // The list of *candidate* extensions is computed and
- // cached based on the identity of the declaration alone,
- // and does not take into account any generic arguments
- // of either the type or the `extension`.
- //
- // For example, we might have an `extension` that applies
- // to `vector<float,N>` for any `N`, but the `selfType`
- // that we are working with could be `<vector<int,2>` so
- // that the extension doesn't match.
- //
- // In order to make sure that we don't enumerate members
- // that don't make sense in context, we must apply
- // the extension to the type and see if we succeed in
- // making a match.
- //
- auto extDeclRef = applyExtensionToType(&visitor, extDecl, selfType);
- if (!extDeclRef)
+ if (!curFacet->subtypeWitness)
continue;
-
- // In the case where we *do* find an extension that
- // applies to the type, we add a declared base to
- // represent the `extension`, knowing that its
- // own linearized inheritance list will include
- // any transitive based declared on the `extension`.
- //
- auto extInheritanceInfo = getInheritanceInfo(extDeclRef);
- addDirectBaseFacet(
- Facet::Kind::Extension,
- selfType,
- selfIsSelf,
- extDeclRef,
- extInheritanceInfo);
+ auto inheritanceInfo = getInheritanceInfo(curFacet->subtypeWitness->getSup(), circularityInfo);
+ for (auto facet : inheritanceInfo.facets)
+ {
+ if (auto interfaceDeclRef = facet->origin.declRef.as<InterfaceDecl>())
+ {
+ SubtypeWitness* transitiveWitness = curFacet->subtypeWitness;
+ transitiveWitness = astBuilder->getTransitiveSubtypeWitness(curFacet->subtypeWitness, facet->subtypeWitness);
+ additionalSubtypeWitnesses.addIfNotExists(facet->origin.type, transitiveWitness);
+ if (supTypesConsideredForExtensionApplication.add(facet->origin.type))
+ {
+ supTypeWorkList.add(interfaceDeclRef);
+ }
+ }
+ }
}
+ bool canExit = true;
+ for (auto baseItem : supTypeWorkList)
+ {
+ if (considerExtension(baseItem, &additionalSubtypeWitnesses))
+ canExit = false;
+ }
+ if (canExit)
+ break;
}
// At this point, the list of direct bases (each with its own linearization)
@@ -846,7 +919,7 @@ namespace Slang
return false;
}
- InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(Type* type)
+ InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo)
{
// The majority of the interesting for for computing linearized
// inheritance information arises for `DeclRef`s, but we still
@@ -861,7 +934,7 @@ namespace Slang
// bottleneck through the logic that gets shared between
// type and `extension` declarations.
//
- return _getInheritanceInfo(declRefType->getDeclRef(), declRefType);
+ return _getInheritanceInfo(declRefType->getDeclRef(), declRefType, circularityInfo);
}
else if (auto conjunctionType = as<AndType>(type))
{
@@ -875,8 +948,8 @@ namespace Slang
// must include all the facets from the lists for `L`
// and `R`, respectively.
//
- auto leftInfo = getInheritanceInfo(leftType);
- auto rightInfo = getInheritanceInfo(rightType);
+ auto leftInfo = getInheritanceInfo(leftType, circularityInfo);
+ auto rightInfo = getInheritanceInfo(rightType, circularityInfo);
// We have a case of subtype witness that can show that
// `T : L` or `T : R` based on `T : L&R`. In this case,
@@ -931,7 +1004,7 @@ namespace Slang
}
else if (auto eachType = as<EachType>(type))
{
- auto elementInheritanceInfo = getInheritanceInfo(eachType->getElementType());
+ auto elementInheritanceInfo = getInheritanceInfo(eachType->getElementType(), circularityInfo);
SemanticsVisitor visitor(this);
auto directFacet = new(arena) Facet::Impl(
Facet::Kind::Type,