summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-09-05 11:24:19 -0700
committerGitHub <noreply@github.com>2024-09-05 11:24:19 -0700
commitd655302465457c5d3285ae5339201a0769cc38dc (patch)
tree4c0946ba4ea4879831133370d2203f569c135c35 /source
parenta88055c6f5190ca62bb4aa853b4f0fa11546278f (diff)
Support `where` clause and type equality constraint. (#4986)
* Support `where` clause. * Fix. * Fix parser. * Enhance test to cover traditional __generic syntax. * Update user-guide. * Support `where` clause on associatedtype. * Fix. * Put in more comments.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-decl.h5
-rw-r--r--source/slang/slang-ast-iterator.h12
-rw-r--r--source/slang/slang-ast-support-types.h13
-rw-r--r--source/slang/slang-ast-type.h2
-rw-r--r--source/slang/slang-ast-val.h34
-rw-r--r--source/slang/slang-check-constraint.cpp10
-rw-r--r--source/slang/slang-check-conversion.cpp16
-rw-r--r--source/slang/slang-check-decl.cpp63
-rw-r--r--source/slang/slang-check-impl.h6
-rw-r--r--source/slang/slang-check-inheritance.cpp157
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir.cpp10
-rw-r--r--source/slang/slang-language-server-ast-lookup.cpp8
-rw-r--r--source/slang/slang-lower-to-ir.cpp30
-rw-r--r--source/slang/slang-parser.cpp166
-rw-r--r--source/slang/slang-syntax.cpp3
17 files changed, 439 insertions, 102 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 4ed38ef69..36aa3a313 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -583,6 +583,11 @@ class GenericTypeConstraintDecl : public TypeConstraintDecl
TypeExp sub;
TypeExp sup;
+ // If this decl is defined in a where clause, store the source location of the where token.
+ SourceLoc whereTokenLoc = SourceLoc();
+
+ bool isEqualityConstraint = false;
+
// Overrides should be public so base classes can access
const TypeExp& _getSupOverride() const { return sup; }
};
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h
index 40a943dc0..24f98391c 100644
--- a/source/slang/slang-ast-iterator.h
+++ b/source/slang/slang-ast-iterator.h
@@ -477,6 +477,18 @@ void ASTIterator<CallbackFunc, FilterFunc>::visitDecl(DeclBase* decl)
}
else if (auto typeConstraint = as<TypeConstraintDecl>(decl))
{
+ if (auto genericTypeConstraint = as<GenericTypeConstraintDecl>(typeConstraint))
+ {
+ // A generic constraint decl has a left hand side and right hand side expression
+ // for the base and super type of the constraint.
+ // In the case of a folded-in constraint syntax as in `Foo<T:IBar>`,
+ // the left hand side of the constraint is represented by the same token
+ // as the parameter decl itself, so we don't need to traverse into it.
+ // In the case of `Foo<T> where T:IBar`, the left hand side is its own
+ // expression so we do want to traverse it.
+ if (genericTypeConstraint->whereTokenLoc.isValid())
+ visitExpr(genericTypeConstraint->sub.exp);
+ }
visitExpr(typeConstraint->getSup().exp);
}
else if (auto typedefDecl = as<TypeDefDecl>(decl))
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 21948dc04..83a4cf353 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -1510,14 +1510,6 @@ namespace Slang
const RequirementDictionary& getRequirementDictionary()
{
- if (m_requirementDictionary.getCount() != m_requirements.getCount())
- {
- for (Index i = m_requirementDictionary.getCount(); i < m_requirements.getCount(); i++)
- {
- auto& r = m_requirements[i];
- m_requirementDictionary.add(r.key, r.value);
- }
- }
return m_requirementDictionary;
}
@@ -1532,11 +1524,8 @@ namespace Slang
// Whether or not this witness table is an extern declaration.
bool isExtern = false;
- // Satisfying values of each requirement.
- List<KeyValuePair<Decl*, RequirementWitness>> m_requirements;
-
// Cached dictionary for looking up satisfying values.
- SLANG_UNREFLECTED RequirementDictionary m_requirementDictionary;
+ RequirementDictionary m_requirementDictionary;
RefPtr<WitnessTable> specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst);
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 3a1318696..401d73e29 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -72,7 +72,7 @@ class DeclRefType : public Type
};
template<typename T>
-DeclRef<T> isDeclRefTypeOf(Type* type)
+DeclRef<T> isDeclRefTypeOf(Val* type)
{
if (auto declRefType = as<DeclRefType>(type))
{
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 2599ce46a..8752195cb 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -626,6 +626,13 @@ class DeclaredSubtypeWitness : public SubtypeWitness
return as<DeclRefBase>(getOperand(2));
}
+ bool isEquality()
+ {
+ if (auto declRef = getDeclRef().as<GenericTypeConstraintDecl>())
+ return declRef.getDecl()->isEqualityConstraint;
+ return false;
+ }
+
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
@@ -895,4 +902,31 @@ void SubstitutionSet::forEachSubstitutionArg(F func) const
}
}
}
+
+inline bool isTypeEqualityWitness(Val* witness)
+{
+ if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness))
+ {
+ return declaredWitness->isEquality();
+ }
+ else if (as<TypeEqualityWitness>(witness))
+ {
+ return true;
+ }
+ else if (auto eachWitness = as<EachSubtypeWitness>(witness))
+ {
+ return isTypeEqualityWitness(eachWitness->getPatternTypeWitness());
+ }
+ else if (auto typePackWitness = as<TypePackSubtypeWitness>(witness))
+ {
+ for (Index i = 0; i < typePackWitness->getCount(); i++)
+ {
+ if (!isTypeEqualityWitness(typePackWitness->getWitness(i)))
+ return false;
+ }
+ return true;
+ }
+ return false;
+}
+
} // namespace Slang
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 3e3ed5297..90b0e44f5 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -648,7 +648,7 @@ namespace Slang
else if (auto subEachType = as<EachType>(constraintDeclRef.getDecl()->sub.type))
constrainedGenericParams.add(as<DeclRefType>(subEachType->getElementType())->getDeclRef().getDecl());
- if (sub->equals(sup))
+ if (sub->equals(sup) && isDeclRefTypeOf<InterfaceDecl>(sup))
{
// We are trying to use an interface type itself to conform to the
// type constraint. We can reach this case when the user code does
@@ -674,6 +674,14 @@ namespace Slang
sup,
system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None);
}
+
+ if (constraintDecl->isEqualityConstraint)
+ {
+ // If constraint is an equality constraint, we need to make sure
+ // the witness is equality witness.
+ if (!isTypeEqualityWitness(subTypeWitness))
+ subTypeWitness = nullptr;
+ }
if(subTypeWitness)
{
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 85040dc55..c0d7feaff 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -1019,6 +1019,22 @@ namespace Slang
*outCost = kConversionCost_CastToInterface;
return true;
}
+ else if (auto fromIsToWitness = tryGetSubtypeWitness(toType, fromType))
+ {
+ // Is toType and fromType the same via some type equality witness?
+ // If so there is no need to do any conversion.
+ //
+ if (isTypeEqualityWitness(fromIsToWitness))
+ {
+ if (outToExpr)
+ {
+ *outToExpr = createCastToSuperTypeExpr(toType, fromExpr, fromIsToWitness);
+ }
+ if (outCost)
+ *outCost = 0;
+ return true;
+ }
+ }
// Disallow converting to a ParameterGroupType.
//
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index ad3f94fc3..02e3241a9 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2278,14 +2278,18 @@ namespace Slang
markSelfDifferentialMembersOfType(as<AggTypeDecl>(context->parentDecl), context->conformingType);
+ witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(context->conformingType));
if (doesTypeSatisfyAssociatedTypeConstraintRequirement(context->conformingType, requirementDeclRef, witnessTable))
{
- witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(context->conformingType));
// Increase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types.
m_astBuilder->incrementEpoch();
return true;
}
+ else
+ {
+ witnessTable->m_requirementDictionary.remove(requirementDeclRef.getDecl());
+ }
// Something went wrong.
return false;
@@ -2471,19 +2475,16 @@ namespace Slang
// conformance on the synthesized decl.
checkAggTypeConformance(aggTypeDecl);
- if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable))
+ witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType));
+ if (!doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable))
{
- witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType));
-
- // Incrase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types.
- m_astBuilder->incrementEpoch();
- return true;
+ // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always succeed.
+ // If not, there is something wrong with the code synthesis logic. For now we just return false
+ // instead of crashing so the user can work around the issues.
+ witnessTable->m_requirementDictionary.remove(requirementDeclRef.getDecl());
+ return false;
}
-
- // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always succeed.
- // If not, there is something wrong with the code synthesis logic. For now we just return false
- // instead of crashing so the user can work around the issues.
- return false;
+ return true;
}
void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl)
@@ -2497,7 +2498,7 @@ namespace Slang
CheckConstraintSubType(decl->sub);
decl->sub = TranslateTypeNodeForced(decl->sub);
decl->sup = TranslateTypeNodeForced(decl->sup);
- if (!isValidGenericConstraintType(decl->sup) && !as<ErrorType>(decl->sub.type))
+ if (!decl->isEqualityConstraint && !isValidGenericConstraintType(decl->sup) && !as<ErrorType>(decl->sub.type))
{
getSink()->diagnose(decl->sup.exp, Diagnostics::invalidTypeForConstraint, decl->sup);
}
@@ -3548,18 +3549,28 @@ namespace Slang
bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeConstraintRequirement(Type* satisfyingType, DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, RefPtr<WitnessTable> witnessTable)
{
+ SLANG_UNUSED(satisfyingType);
+
// We will enumerate the type constraints placed on the
// associated type and see if they can be satisfied.
//
bool conformance = true;
Val* witness = nullptr;
- for (auto requiredConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(m_astBuilder, requiredAssociatedTypeDeclRef))
+ for (auto requiredConstraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(m_astBuilder, requiredAssociatedTypeDeclRef))
{
// Grab the type we expect to conform to from the constraint.
auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef);
+ auto subType = getSub(m_astBuilder, requiredConstraintDeclRef);
+
// Perform a search for a witness to the subtype relationship.
- witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType);
+ witness = tryGetSubtypeWitness(subType, requiredSuperType);
+ if (witness)
+ {
+ auto genConstraint = as<GenericTypeConstraintDecl>(requiredConstraintDeclRef.getDecl());
+ if (genConstraint && genConstraint->isEqualityConstraint && !isTypeEqualityWitness(witness))
+ witness = nullptr;
+ }
if (witness)
{
// If a subtype witness was found, then the conformance
@@ -3588,6 +3599,15 @@ namespace Slang
if (declRefType->getDeclRef().getDecl()->hasModifier<ToBeSynthesizedModifier>())
return false;
}
+
+ // Register the satisfying type to the witness table
+ // before checking the constraints, since the subtype of
+ // the constraints maybe referencing the satisfying type via
+ // witness lookups.
+ auto requirementWitness = RequirementWitness(satisfyingType->getCanonicalType());
+ witnessTable->m_requirementDictionary[requiredAssociatedTypeDeclRef.getDecl()]
+ = requirementWitness;
+
// We need to confirm that the chosen type `satisfyingType`,
// meets all the constraints placed on the associated type
// requirement `requiredAssociatedTypeDeclRef`.
@@ -3601,15 +3621,11 @@ namespace Slang
// TODO: if any conformance check failed, we should probably include
// that in an error message produced about not satisfying the requirement.
- if(conformance)
+ if (!conformance)
{
- // If all the constraints were satisfied, then the chosen
- // type can indeed satisfy the interface requirement.
- witnessTable->add(
- requiredAssociatedTypeDeclRef.getDecl(),
- RequirementWitness(satisfyingType->getCanonicalType()));
+ witnessTable->m_requirementDictionary.remove(requiredAssociatedTypeDeclRef.getDecl());
}
-
+
return conformance;
}
@@ -9667,7 +9683,8 @@ namespace Slang
ASTNodeType::DeclRefType);
auto typeParamDecl = as<DeclRefType>(genericTypeConstraintDecl.getDecl()->sub.type)->getDeclRef().getDecl();
List<Type*>* constraintTypes = genericConstraints.tryGetValue(typeParamDecl);
- assert(constraintTypes);
+ if (!constraintTypes)
+ continue;
constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type);
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index f39d5be1d..dc4568f8a 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -766,6 +766,12 @@ namespace Slang
InheritanceInfo _calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo);
InheritanceInfo _calcInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* correspondingType, InheritanceCircularityInfo* circularityInfo);
+ // Get the inner most generic decl that a decl-ref is dependent on.
+ // For example, `Foo<T>` depends on the generic decl that defines `T`.
+ //
+ DeclRef<GenericDecl> getDependentGenericParent(DeclRef<Decl> declRef);
+ void getDependentGenericParentImpl(DeclRef<GenericDecl>& genericParent, DeclRef<Decl> declRef);
+
struct DirectBaseInfo
{
FacetList facets;
diff --git a/source/slang/slang-check-inheritance.cpp b/source/slang/slang-check-inheritance.cpp
index 20f41c1bb..0d3929901 100644
--- a/source/slang/slang-check-inheritance.cpp
+++ b/source/slang/slang-check-inheritance.cpp
@@ -97,6 +97,51 @@ namespace Slang
return info;
}
+ void SharedSemanticsContext::getDependentGenericParentImpl(DeclRef<GenericDecl>& genericParent, DeclRef<Decl> declRef)
+ {
+ auto mergeParent = [](DeclRef<GenericDecl>& currentParent, DeclRef<GenericDecl> newParent)
+ {
+ if (!currentParent)
+ {
+ currentParent = newParent;
+ return;
+ }
+ if (currentParent == newParent)
+ return;
+ if (newParent.getDecl()->isChildOf(currentParent.getDecl()))
+ currentParent = newParent;
+ };
+
+ if (declRef.as<GenericTypeParamDeclBase>())
+ {
+ if (!genericParent)
+ mergeParent(genericParent, declRef.getParent().as<GenericDecl>());
+ return;
+ }
+ else if (auto lookupDeclRef = as<LookupDeclRef>(declRef.declRefBase))
+ {
+ if (auto lookupSourceDeclRef = isDeclRefTypeOf<Decl>(lookupDeclRef->getLookupSource()))
+ getDependentGenericParentImpl(genericParent, lookupSourceDeclRef);
+ }
+ else if (auto genericAppDeclRef = as<GenericAppDeclRef>(declRef.declRefBase))
+ {
+ for (Index i = 0; i < genericAppDeclRef->getArgCount(); i++)
+ {
+ if (auto argDeclRef = isDeclRefTypeOf<Decl>(genericAppDeclRef->getArg(i)))
+ {
+ getDependentGenericParentImpl(genericParent, argDeclRef);
+ }
+ }
+ }
+ }
+
+ DeclRef<GenericDecl> SharedSemanticsContext::getDependentGenericParent(DeclRef<Decl> declRef)
+ {
+ DeclRef<GenericDecl> genericParent;
+ getDependentGenericParentImpl(genericParent, declRef);
+ return genericParent;
+ }
+
InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* declRefType, InheritanceCircularityInfo* circularityInfo)
{
// This method is the main engine for computing linearized inheritance
@@ -305,39 +350,82 @@ namespace Slang
// We now look at the structure of the declaration itself
// to help us enumerate the direct bases.
//
- if (auto aggTypeDeclBaseRef = declRef.as<AggTypeDeclBase>())
+ auto currentDeclRef = declRef;
+ for (; currentDeclRef;)
{
- // In the case where we have an aggregate type or `extension`
- // declaration, we can use the explicit list of direct bases.
- //
- for (auto typeConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(_getASTBuilder(), aggTypeDeclBaseRef))
+ if (auto aggTypeDeclBaseRef = currentDeclRef.as<AggTypeDeclBase>())
{
- visitor.ensureDecl(typeConstraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl);
-
- // Note: In certain cases something takes the *syntactic* form of an inheritance
- // clause, but it is not actually something that should be treated as implying
- // a subtype relationship. For example, an `enum` declaration can use what looks
- // like an inheritance clause to indicate its underlying "tag type."
- //
- // We skip such pseudo-inheritance relationships for the purposes of determining
- // the linearized list of bases.
+ // In the case where we have an aggregate type or `extension`
+ // declaration, we can use the explicit list of direct bases.
//
- if (typeConstraintDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>())
- continue;
+ for (auto typeConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(_getASTBuilder(), aggTypeDeclBaseRef))
+ {
+ // Note: In certain cases something takes the *syntactic* form of an inheritance
+ // clause, but it is not actually something that should be treated as implying
+ // a subtype relationship. For example, an `enum` declaration can use what looks
+ // like an inheritance clause to indicate its underlying "tag type."
+ //
+ // We skip such pseudo-inheritance relationships for the purposes of determining
+ // the linearized list of bases.
+ //
+ if (typeConstraintDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>())
+ continue;
- // The base type and subtype witness can easily be determined
- // using the `InheritanceDecl`.
- //
- auto baseType = getSup(astBuilder, typeConstraintDeclRef);
- auto satisfyingWitness = astBuilder->getDeclaredSubtypeWitness(
- selfType,
- baseType,
- typeConstraintDeclRef);
+ // The only case we will ever see a GenericTypeConstraintDecl inside a AggTypeDecl is when
+ // AggTypeDecl is a associatedtype decl. In this case, we will only lookup the type constraint
+ // if the constraint is on the associated type itself.
+ //
+ auto genericTypeConstraintDeclRef = typeConstraintDeclRef.as<GenericTypeConstraintDecl>();
+ if (genericTypeConstraintDeclRef)
+ {
+ // If the base expr on the constraint isn't even a `VarExpr`, then it can't be referencing
+ // the associated type itself and we can skip this constraint.
+ if (!genericTypeConstraintDeclRef.getDecl()->sub.type
+ && !as<VarExpr>(genericTypeConstraintDeclRef.getDecl()->sub.exp))
+ continue;
+ }
+
+ visitor.ensureDecl(typeConstraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl);
- addDirectBaseType(baseType, satisfyingWitness);
+ // For generic type constraint decls, always make sure it is about the type being checked.
+ //
+ if (genericTypeConstraintDeclRef)
+ {
+ auto subType = getSub(astBuilder, genericTypeConstraintDeclRef);
+ if (subType != selfType)
+ continue;
+ }
+ else if (currentDeclRef != declRef)
+ {
+ continue;
+ }
+ // The base type and subtype witness can easily be determined
+ // using the `InheritanceDecl`.
+ //
+ auto baseType = getSup(astBuilder, typeConstraintDeclRef);
+ auto satisfyingWitness = astBuilder->getDeclaredSubtypeWitness(
+ selfType,
+ baseType,
+ typeConstraintDeclRef);
+
+ addDirectBaseType(baseType, satisfyingWitness);
+ }
}
+ if (currentDeclRef.as<AssocTypeDecl>())
+ {
+ // If the current type is an associated type, continue inspecting the base/parent of the
+ // associatedtype to discover additional constraints defined on the parent associatedtype decls.
+ //
+ if (auto lookupDeclRef = as<LookupDeclRef>(currentDeclRef.declRefBase))
+ {
+ currentDeclRef = isDeclRefTypeOf<Decl>(lookupDeclRef->getLookupSource()).as<AssocTypeDecl>();
+ continue;
+ }
+ }
+ break;
}
- else if (auto genericTypeParamDeclRef = declRef.as<GenericTypeParamDeclBase>())
+
+ if (auto genericDeclRef = getDependentGenericParent(declRef))
{
// The constraints placed on a generic type parameter are siblings of that
// parameter in its parent `GenericDecl`, so we need to enumerate all of
@@ -349,13 +437,11 @@ namespace Slang
// representation would need to take into account canonicalization of
// constraints.
- auto genericDeclRef = genericTypeParamDeclRef.getParent().as<GenericDecl>();
- SLANG_ASSERT(genericDeclRef);
ensureDecl(&visitor, genericDeclRef.getDecl(), DeclCheckState::CanSpecializeGeneric);
if (auto extensionDecl = as<ExtensionDecl>(genericDeclRef.getDecl()->inner))
{
- if (isDeclRefTypeOf<GenericTypeParamDecl>(extensionDecl->targetType.type) == genericTypeParamDeclRef)
+ if (isDeclRefTypeOf<GenericTypeParamDecl>(extensionDecl->targetType.type) == declRef)
{
// If `T` is a generic parameter where the same generic is an extension on `T`,
// then we need to add the extension itself as a facet.
@@ -377,13 +463,10 @@ namespace Slang
auto superType = getSup(astBuilder, constraintDeclRef);
// We only consider constraints where the type represented
- // by `genericTypeParamDeclRef` is the subtype, since those
+ // by `declRef` is the subtype, since those
// constraints are the ones that give us information about
// the declared supertypes.
//
- // TODO: consider whether other kinds of constraints could
- // also apply here.
- //
auto subDeclRefType = as<DeclRefType>(subType);
if (!subDeclRefType)
{
@@ -394,7 +477,7 @@ namespace Slang
if (!subDeclRefType)
continue;
}
- if (subDeclRefType->getDeclRef() != genericTypeParamDeclRef)
+ if (subDeclRefType->getDeclRef() != declRef)
continue;
// Because the constraint is a declared inheritance relationship,
@@ -402,9 +485,9 @@ namespace Slang
// as in all the preceding cases.
//
auto satisfyingWitness = _getASTBuilder()->getDeclaredSubtypeWitness(
- selfType,
- superType,
- constraintDeclRef);
+ selfType,
+ superType,
+ constraintDeclRef);
addDirectBaseType(superType, satisfyingWitness);
}
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f308f340d..5afca48b3 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -291,6 +291,10 @@ INST(IndexedFieldKey, indexedFieldKey, 2, HOISTABLE)
// A placeholder witness that ThisType implements the enclosing interface.
// Used only in interface definitions.
INST(ThisTypeWitness, thisTypeWitness, 1, 0)
+
+// A placeholder witness for the fact that two types are equal.
+INST(TypeEqualityWitness, TypeEqualityWitness, 2, HOISTABLE)
+
INST(GlobalHashedStringLiterals, global_hashed_string_literals, 0, 0)
INST(Module, module, 0, PARENT)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 8f648aabd..fc963697a 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4075,6 +4075,8 @@ public:
IRInst* createThisTypeWitness(IRType* interfaceType);
+ IRInst* getTypeEqualityWitness(IRType* witnessType, IRType* type1, IRType* type2);
+
IRInterfaceRequirementEntry* createInterfaceRequirementEntry(
IRInst* requirementKey,
IRInst* requirementVal);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index acc6abd57..d02c01105 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4614,6 +4614,15 @@ namespace Slang
return witness;
}
+ IRInst* IRBuilder::getTypeEqualityWitness(IRType* witnessType, IRType* type1, IRType* type2)
+ {
+ IRInst* operands[2] = { type1, type2 };
+ return (IRType*)createIntrinsicInst(
+ witnessType,
+ kIROp_TypeEqualityWitness,
+ 2,
+ operands);
+ }
IRStructType* IRBuilder::createStructType()
{
@@ -8347,6 +8356,7 @@ namespace Slang
case kIROp_InterfaceRequirementEntry:
case kIROp_Block:
case kIROp_Each:
+ case kIROp_TypeEqualityWitness:
return false;
/// Liveness markers have no side effects
diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp
index 2d6ed2568..06b7c937f 100644
--- a/source/slang/slang-language-server-ast-lookup.cpp
+++ b/source/slang/slang-language-server-ast-lookup.cpp
@@ -717,6 +717,14 @@ bool _findAstNodeImpl(ASTLookupContext& context, SyntaxNode* node)
ASTLookupExprVisitor visitor(&context);
if (visitor.dispatchIfNotNull(typeConstraint->getSup().exp))
return true;
+ if (auto genTypeConstraint = as<GenericTypeConstraintDecl>(node))
+ {
+ if (genTypeConstraint->whereTokenLoc.isValid())
+ {
+ if (visitor.dispatchIfNotNull(genTypeConstraint->sub.exp))
+ return true;
+ }
+ }
}
else if (auto typedefDecl = as<TypeDefDecl>(node))
{
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index f62bb631e..813467743 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1753,6 +1753,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
lowerType(context, val->getSup())));
}
+ LoweredValInfo visitTypeEqualityWitness(TypeEqualityWitness* val)
+ {
+ auto subType = lowerType(context, val->getSub());
+ auto supType = lowerType(context, val->getSup());
+ auto witnessType = context->irBuilder->getWitnessTableType(
+ lowerType(context, val->getSup()));
+ return LoweredValInfo::simple(context->irBuilder->getTypeEqualityWitness(witnessType, subType, supType));
+ }
+
LoweredValInfo visitTransitiveSubtypeWitness(
TransitiveSubtypeWitness* val)
{
@@ -4989,6 +4998,19 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
auto superType = lowerType(context, expr->type);
auto value = lowerRValueExpr(context, expr->valueArg);
+ // First, we check if the witness is a type equality witness.
+ // If so, we can simply emit a bit cast to the target type that should eventually
+ // fold out to a no-op.
+ // Note: if we are going to equivalent but not identical types in the future,
+ // then the cast between equivalent types shouldn't be as simple as a bit cast
+ // and will require actual coercion logic between the two types.
+ // For now, we don't support type equivalence witness so this is safe for
+ // equal types.
+ if (isTypeEqualityWitness(expr->witnessArg))
+ {
+ return LoweredValInfo::simple(getBuilder()->emitBitCast(superType, getSimpleVal(context, value)));
+ }
+
// The actual operation that we need to perform here
// depends on the kind of subtype relationship we
// are making use of.
@@ -5039,7 +5061,6 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
return emitCastToConcreteSuperTypeRec(value, superType, expr->witnessArg);
}
}
-
SLANG_UNEXPECTED("unexpected case of subtype relationship");
UNREACHABLE_RETURN(LoweredValInfo());
}
@@ -8400,8 +8421,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
for (auto constraintDecl : decl->getMembersOfType<GenericTypeConstraintDecl>())
{
auto baseType = lowerType(context, constraintDecl->sup.type);
- SLANG_ASSERT(baseType && baseType->getOp() == kIROp_InterfaceType);
- constraintInterfaces.add((IRInterfaceType*)baseType);
+ if (baseType && baseType->getOp() == kIROp_InterfaceType)
+ constraintInterfaces.add((IRInterfaceType*)baseType);
}
auto assocType = context->irBuilder->getAssociatedType(
constraintInterfaces.getArrayView().arrayView);
@@ -10686,7 +10707,8 @@ LoweredValInfo emitDeclRef(
for (auto argVal : genericSubst->getArgs())
{
auto irArgVal = lowerSimpleVal(context, argVal);
- SLANG_ASSERT(irArgVal);
+ if (!irArgVal)
+ continue;
// It is possible that some of the arguments to the generic
// represent conformances to conjunction types like `A & B`.
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 121c68be7..5729adb29 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -148,6 +148,12 @@ namespace Slang
resetLookupScope();
}
+ void PushScope(Scope* newScope)
+ {
+ currentScope = newScope;
+ resetLookupScope();
+ }
+
void pushScopeAndSetParent(ContainerDecl* containerDecl)
{
containerDecl->parentDecl = currentScope->containerDecl;
@@ -306,6 +312,8 @@ namespace Slang
static Expr* _parseGenericArg(Parser* parser);
+ static Expr* parsePrefixExpr(Parser* parser);
+
//
static void Unexpected(
@@ -721,6 +729,16 @@ namespace Slang
return false;
}
+ bool AdvanceIf(Parser* parser, char const* text, Token* outToken)
+ {
+ if (parser->LookAheadToken(text))
+ {
+ *outToken = parser->ReadToken();
+ return true;
+ }
+ return false;
+ }
+
/// Information on how to parse certain pairs of matches tokens
struct MatchedTokenInfo
{
@@ -1496,10 +1514,47 @@ namespace Slang
}
else
{
- // default case is a type parameter
- paramDecl = parser->astBuilder->create<GenericTypeParamDecl>();
- parser->FillPosition(paramDecl);
- paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier));
+ // Disambiguate between a type parameter and a value parameter.
+ // If next token is "typename", then it is a type parameter.
+ bool isTypeParam = AdvanceIf(parser, "typename");
+ if (!isTypeParam)
+ {
+ // Otherwise, if the next token is an identifier, followed by a colon, comma, '=' or '>', then it is a type parameter.
+ isTypeParam = parser->LookAheadToken(TokenType::Identifier);
+ auto nextNextTokenType = peekTokenType(parser, 1);
+ switch (nextNextTokenType)
+ {
+ case TokenType::Colon:
+ case TokenType::Comma:
+ case TokenType::OpGreater:
+ case TokenType::OpAssign:
+ break;
+ default:
+ isTypeParam = false;
+ break;
+ }
+ }
+
+ if (isTypeParam)
+ {
+ // Parse as a type parameter.
+ paramDecl = parser->astBuilder->create<GenericTypeParamDecl>();
+ parser->FillPosition(paramDecl);
+ paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier));
+ }
+ else
+ {
+ // Parse as a traditional syntax value parameter in the form of `type paramName`.
+ auto valueParamDecl = parser->astBuilder->create<GenericValueParamDecl>();
+ parser->FillPosition(valueParamDecl);
+ valueParamDecl->type = parser->ParseTypeExp();
+ valueParamDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier));
+ if (AdvanceIf(parser, TokenType::OpAssign))
+ {
+ valueParamDecl->initExpr = parser->ParseInitExpr();
+ }
+ return valueParamDecl;
+ }
}
if (AdvanceIf(parser, TokenType::Colon))
{
@@ -1600,7 +1655,43 @@ namespace Slang
}
else
{
- return parseInner(nullptr);
+ auto genericParent = parser->currentScope ? as<GenericDecl>(parser->currentScope->containerDecl) : nullptr;
+ return parseInner(genericParent);
+ }
+ }
+
+ static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericParent)
+ {
+ if (!genericParent)
+ return;
+ Token whereToken;
+ while (AdvanceIf(parser, "where", &whereToken))
+ {
+ auto subType = parser->ParseTypeExp();
+ if (AdvanceIf(parser, TokenType::Colon))
+ {
+ for (;;)
+ {
+ auto constraint = parser->astBuilder->create<GenericTypeConstraintDecl>();
+ constraint->whereTokenLoc = whereToken.loc;
+ parser->FillPosition(constraint);
+ constraint->sub = subType;
+ constraint->sup = parser->ParseTypeExp();
+ AddMember(genericParent, constraint);
+ if (!AdvanceIf(parser, TokenType::Comma))
+ break;
+ }
+ }
+ else if (AdvanceIf(parser, TokenType::OpEql))
+ {
+ auto constraint = parser->astBuilder->create<GenericTypeConstraintDecl>();
+ constraint->whereTokenLoc = whereToken.loc;
+ constraint->isEqualityConstraint = true;
+ parser->FillPosition(constraint);
+ constraint->sub = subType;
+ constraint->sup = parser->ParseTypeExp();
+ AddMember(genericParent, constraint);
+ }
}
}
@@ -1702,7 +1793,7 @@ namespace Slang
decl->loc = declaratorInfo.nameAndLoc.loc;
decl->nameAndLoc = declaratorInfo.nameAndLoc;
- return parseOptGenericDecl(parser, [&](GenericDecl*)
+ return parseOptGenericDecl(parser, [&](GenericDecl* genericParent)
{
// HACK: The return type of the function will already have been
// parsed in a scope that didn't include the function's generic
@@ -1731,6 +1822,12 @@ namespace Slang
}
_parseOptSemantics(parser, decl);
+
+ auto funcScope = parser->currentScope;
+ parser->PopScope();
+ maybeParseGenericConstraints(parser, genericParent);
+ parser->PushScope(funcScope);
+
decl->body = parseOptBody(parser);
if (auto block = as<BlockStmt>(decl->body))
{
@@ -2597,7 +2694,7 @@ namespace Slang
}
else if (parser->LookAheadToken("expand") || parser->LookAheadToken("each"))
{
- typeSpec.expr = parser->ParseExpression();
+ typeSpec.expr = parsePrefixExpr(parser);
return typeSpec;
}
// Uncomment should we decide to enable (a,b,c) tuple types
@@ -3335,12 +3432,13 @@ namespace Slang
static NodeBase* parseExtensionDecl(Parser* parser, void* /*userData*/)
{
- return parseOptGenericDecl(parser, [&](GenericDecl*)
+ return parseOptGenericDecl(parser, [&](GenericDecl* genericParent)
{
ExtensionDecl* decl = parser->astBuilder->create<ExtensionDecl>();
parser->FillPosition(decl);
decl->targetType = parser->ParseTypeExp();
parseOptionalInheritanceClause(parser, decl);
+ maybeParseGenericConstraints(parser, genericParent);
parseDeclBody(parser, decl);
return decl;
});
@@ -3357,16 +3455,27 @@ namespace Slang
parser->FillPosition(paramConstraint);
// substitution needs to be filled during check
- Type* paramType = DeclRefType::create(parser->astBuilder, DeclRef<Decl>(decl));
+ Type* paramType = nullptr;
+ if (as<GenericTypeParamDeclBase>(decl))
+ {
+ paramType = DeclRefType::create(parser->astBuilder, DeclRef<Decl>(decl));
- SharedTypeExpr* paramTypeExpr = parser->astBuilder->create<SharedTypeExpr>();
- paramTypeExpr->loc = decl->loc;
- paramTypeExpr->base.type = paramType;
- paramTypeExpr->type = QualType(parser->astBuilder->getTypeType(paramType));
+ SharedTypeExpr* paramTypeExpr = parser->astBuilder->create<SharedTypeExpr>();
+ paramTypeExpr->loc = decl->loc;
+ paramTypeExpr->base.type = paramType;
+ paramTypeExpr->type = QualType(parser->astBuilder->getTypeType(paramType));
- paramConstraint->sub = TypeExp(paramTypeExpr);
- paramConstraint->sup = parser->ParseTypeExp();
+ paramConstraint->sub = TypeExp(paramTypeExpr);
+ }
+ else if (as<AssocTypeDecl>(decl))
+ {
+ auto varExpr = parser->astBuilder->create<VarExpr>();
+ varExpr->scope = parser->currentScope;
+ varExpr->name = decl->getName();
+ paramConstraint->sub.exp = varExpr;
+ }
+ paramConstraint->sup = parser->ParseTypeExp();
AddMember(decl, paramConstraint);
} while (AdvanceIf(parser, TokenType::Comma));
}
@@ -3380,6 +3489,7 @@ namespace Slang
assocTypeDecl->nameAndLoc = NameLoc(nameToken);
assocTypeDecl->loc = nameToken.loc;
parseOptionalGenericConstraints(parser, assocTypeDecl);
+ maybeParseGenericConstraints(parser, assocTypeDecl);
parser->ReadToken(TokenType::Semicolon);
return assocTypeDecl;
}
@@ -3424,11 +3534,12 @@ namespace Slang
AdvanceIf(parser, TokenType::CompletionRequest);
decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier));
- return parseOptGenericDecl(parser, [&](GenericDecl*)
+ return parseOptGenericDecl(parser, [&](GenericDecl* genericParent)
{
// We allow for an inheritance clause on a `struct`
// so that it can conform to interfaces.
parseOptionalInheritanceClause(parser, decl);
+ maybeParseGenericConstraints(parser, genericParent);
parseDeclBody(parser, decl);
return decl;
});
@@ -3661,7 +3772,7 @@ namespace Slang
{
ConstructorDecl* decl = parser->astBuilder->create<ConstructorDecl>();
- return parseOptGenericDecl(parser, [&](GenericDecl*)
+ return parseOptGenericDecl(parser, [&](GenericDecl* genericParent)
{
// Note: we leave the source location of this decl as invalid, to
// trigger the fallback logic that fills in the location of the
@@ -3680,6 +3791,10 @@ namespace Slang
decl->nameAndLoc.name = getName(parser, "$init");
parseParameterList(parser, decl);
+ auto funcScope = parser->currentScope;
+ parser->PopScope();
+ maybeParseGenericConstraints(parser, genericParent);
+ parser->PushScope(funcScope);
decl->body = parseOptBody(parser);
@@ -3997,7 +4112,7 @@ namespace Slang
parser->FillPosition(decl);
decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier));
- return parseOptGenericDecl(parser, [&](GenericDecl*)
+ return parseOptGenericDecl(parser, [&](GenericDecl* genericParent)
{
parser->PushScope(decl);
parseModernParamList(parser, decl);
@@ -4009,6 +4124,10 @@ namespace Slang
{
decl->returnType = parser->ParseTypeExp();
}
+ auto funcScope = parser->currentScope;
+ parser->PopScope();
+ maybeParseGenericConstraints(parser, genericParent);
+ parser->PushScope(funcScope);
decl->body = parseOptBody(parser);
if (auto blockStmt = as<BlockStmt>(decl->body))
decl->closingSourceLoc = blockStmt->closingSourceLoc;
@@ -4025,8 +4144,9 @@ namespace Slang
parser->FillPosition(decl);
decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier));
- return parseOptGenericDecl(parser, [&](GenericDecl*)
+ return parseOptGenericDecl(parser, [&](GenericDecl* genericParent)
{
+ maybeParseGenericConstraints(parser, genericParent);
if( expect(parser, TokenType::OpAssign) )
{
decl->type = parser->ParseTypeExp();
@@ -4860,7 +4980,7 @@ namespace Slang
rs->nameAndLoc.name = generateName(this);
rs->nameAndLoc.loc = rs->loc;
}
- return parseOptGenericDecl(this, [&](GenericDecl*)
+ return parseOptGenericDecl(this, [&](GenericDecl* genericParent)
{
// We allow for an inheritance clause on a `struct`
// so that it can conform to interfaces.
@@ -4878,6 +4998,7 @@ namespace Slang
rs->hasBody = false;
return rs;
}
+ maybeParseGenericConstraints(this, genericParent);
parseDeclBody(this, rs);
return rs;
});
@@ -4977,9 +5098,10 @@ namespace Slang
addModifier(decl, parser->astBuilder->create<TransparentModifier>());
}
- return parseOptGenericDecl(parser, [&](GenericDecl*)
+ return parseOptGenericDecl(parser, [&](GenericDecl* genericParent)
{
parseOptionalInheritanceClause(parser, decl);
+ maybeParseGenericConstraints(parser, genericParent);
parser->ReadToken(TokenType::LBrace);
Token closingToken;
parser->pushScopeAndSetParent(decl);
@@ -7583,7 +7705,7 @@ namespace Slang
{
EachExpr* eachExpr = parser->astBuilder->create<EachExpr>();
eachExpr->loc = loc;
- eachExpr->baseExpr = parser->ParseLeafExpression();
+ eachExpr->baseExpr = parsePostfixExpr(parser);
return eachExpr;
}
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index a55f0eb1a..ec7169e04 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -338,7 +338,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
RefPtr<WitnessTable> result = new WitnessTable();
result->baseType = as<Type>(newBaseType);
result->witnessedType = as<Type>(newWitnessedType);
- for (auto requirement : m_requirements)
+ for (auto requirement : m_requirementDictionary)
{
auto newRequirement = requirement.value.specialize(astBuilder, subst);
result->add(requirement.key, newRequirement);
@@ -500,7 +500,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
void WitnessTable::add(Decl* decl, RequirementWitness const& witness)
{
- m_requirements.add(KeyValuePair<Decl*, RequirementWitness>(decl, witness));
m_requirementDictionary.add(decl, witness);
}