summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp348
1 files changed, 293 insertions, 55 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index e511fbc39..1e524e27f 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2626,9 +2626,8 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness
auto assocTypeDef = m_astBuilder->create<TypeDefDecl>();
assocTypeDef->nameAndLoc.name = getName("Differential");
assocTypeDef->type.type = context->conformingType;
- assocTypeDef->parentDecl = context->parentDecl;
+ context->parentDecl->addMember(assocTypeDef);
assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked);
- context->parentDecl->members.add(assocTypeDef);
markSelfDifferentialMembersOfType(
as<AggTypeDecl>(context->parentDecl),
@@ -2660,8 +2659,7 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness
if (!aggTypeDecl)
{
aggTypeDecl = m_astBuilder->create<StructDecl>();
- aggTypeDecl->parentDecl = context->parentDecl;
- context->parentDecl->members.add((aggTypeDecl));
+ context->parentDecl->addMember(aggTypeDecl);
aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName();
aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc;
context->parentDecl->invalidateMemberDictionary();
@@ -2719,8 +2717,7 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness
diffField->nameAndLoc = member->nameAndLoc;
diffField->type.type = diffMemberType;
diffField->checkState = DeclCheckState::SignatureChecked;
- diffField->parentDecl = aggTypeDecl;
- aggTypeDecl->members.add(diffField);
+ aggTypeDecl->addMember(diffField);
auto visibility = getDeclVisibility(member);
addVisibilityModifier(diffField, visibility);
@@ -2775,8 +2772,7 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness
{
auto inheritanceIDiffernetiable = m_astBuilder->create<InheritanceDecl>();
inheritanceIDiffernetiable->base.type = m_astBuilder->getDiffInterfaceType();
- inheritanceIDiffernetiable->parentDecl = aggTypeDecl;
- aggTypeDecl->members.add(inheritanceIDiffernetiable);
+ aggTypeDecl->addMember(inheritanceIDiffernetiable);
}
// The `Differential` type of a `Differential` type is always itself.
@@ -2797,9 +2793,8 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness
auto assocTypeDef = m_astBuilder->create<TypeDefDecl>();
assocTypeDef->nameAndLoc.name = getName("Differential");
assocTypeDef->type.type = satisfyingType;
- assocTypeDef->parentDecl = aggTypeDecl;
+ aggTypeDecl->addMember(assocTypeDef);
assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked);
- aggTypeDecl->members.add(assocTypeDef);
}
// Go through all members and collect their differential types.
@@ -4319,7 +4314,7 @@ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness(
typeParamDeclBase->astNodeType);
synTypeParamDeclBase->nameAndLoc = typeParamDeclBase->getNameAndLoc();
synTypeParamDeclBase->parameterIndex = typeParamDeclBase->parameterIndex;
- synTypeParamDeclBase->parentDecl = synGenericDecl;
+ synGenericDecl->addMember(synTypeParamDeclBase);
// Note: we intentionally do not copy GenericTypeParamDecl::initType here,
// because initType maybe dependent on the original type parameters,
@@ -4327,7 +4322,6 @@ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness(
// synthesized ones. It shouldn't be required for the implementing declaration to define
// initType anyways, so we'll just save ourselves from the trouble.
//
- synGenericDecl->members.add(synTypeParamDeclBase);
mapOrigToSynTypeParams.add(typeParamDeclBase, synTypeParamDeclBase);
@@ -4345,7 +4339,7 @@ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness(
{
auto synValParamDecl = m_astBuilder->create<GenericValueParamDecl>();
synValParamDecl->nameAndLoc = valParamDecl->nameAndLoc;
- synValParamDecl->parentDecl = synGenericDecl;
+ synGenericDecl->addMember(synValParamDecl);
synValParamDecl->parameterIndex = valParamDecl->parameterIndex;
synValParamDecl->type = valParamDecl->type;
@@ -4355,7 +4349,6 @@ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness(
// synthesized ones. It shouldn't be required for the implementing declaration to define
// initType anyways, so we'll just save ourselves from the trouble.
//
- synGenericDecl->members.add(synValParamDecl);
mapOrigToSynTypeParams.add(valParamDecl, synGenericDecl);
@@ -4537,8 +4530,7 @@ void SemanticsVisitor::addRequiredParamsToSynthesizedDecl(
// We need to add the parameter as a child declaration of
// the method we are building.
//
- synParamDecl->parentDecl = synthesized;
- synthesized->members.add(synParamDecl);
+ synthesized->addMember(synParamDecl);
// Add modifiers
paramType.isLeftValue = true;
@@ -5508,8 +5500,7 @@ bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness(
// We need to add the parameter as a child declaration of
// the accessor we are building.
//
- synParamDecl->parentDecl = synAccessorDecl;
- synAccessorDecl->members.add(synParamDecl);
+ synAccessorDecl->addMember(synParamDecl);
// For each paramter, we will create an argument expression
// to represent it in the body of the accessor.
@@ -5567,8 +5558,7 @@ bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness(
addModifier(synAccessorDecl, m_astBuilder->create<ForceInlineAttribute>());
synAccessorDecl->body = synBodyStmt;
- synAccessorDecl->parentDecl = synPropertyDecl;
- synPropertyDecl->members.add(synAccessorDecl);
+ synPropertyDecl->addMember(synAccessorDecl);
// Register the synthesized accessor.
//
@@ -5709,8 +5699,7 @@ bool SemanticsVisitor::synthesizeAccessorRequirements(
// We need to add the parameter as a child declaration of
// the accessor we are building.
//
- synParamDecl->parentDecl = synAccessorDecl;
- synAccessorDecl->members.add(synParamDecl);
+ synAccessorDecl->addMember(synParamDecl);
// For each paramter, we will create an argument expression
// to represent it in the body of the accessor.
@@ -5869,8 +5858,7 @@ bool SemanticsVisitor::synthesizeAccessorRequirements(
synAccessorDecl->body = synBodyStmt;
- synAccessorDecl->parentDecl = synAccesorContainer;
- synAccesorContainer->members.add(synAccessorDecl);
+ synAccesorContainer->addMember(synAccessorDecl);
// If synthesis of an accessor worked, then we will record it into
// a local dictionary. We do *not* install the accessor into the
@@ -6377,7 +6365,9 @@ bool SemanticsVisitor::trySynthesizeEnumTypeMethodRequirementWitness(
}
synFunc->loc = context->parentDecl->closingSourceLoc;
synFunc->nameAndLoc.loc = synFunc->loc;
- context->parentDecl->members.add(synFunc);
+ // synFunc already has its parent set
+ SLANG_ASSERT(context->parentDecl == synFunc->parentDecl);
+ context->parentDecl->addMember(synFunc);
context->parentDecl->invalidateMemberDictionary();
addModifier(synFunc, intrinsicOpModifier);
witnessTable->add(
@@ -6561,7 +6551,8 @@ bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness(
seqStmt->stmts.add(synReturn);
Decl* witnessDecl = synGeneric ? (Decl*)synGeneric : synFunc;
- context->parentDecl->members.add(witnessDecl);
+ SLANG_ASSERT(context->parentDecl == witnessDecl->parentDecl);
+ context->parentDecl->addMember(witnessDecl);
context->parentDecl->invalidateMemberDictionary();
addModifier(synFunc, m_astBuilder->create<SynthesizedModifier>());
@@ -7577,11 +7568,10 @@ void SemanticsDeclBasesVisitor::visitStructDecl(StructDecl* decl)
IsSubTypeOptions::NoCaching))
{
InheritanceDecl* conformanceDecl = m_astBuilder->create<InheritanceDecl>();
- conformanceDecl->parentDecl = decl;
conformanceDecl->loc = decl->loc;
conformanceDecl->base.type = defaultInitializableType;
conformanceDecl->nameAndLoc.name = getName("$inheritance");
- decl->members.add(conformanceDecl);
+ decl->addMember(conformanceDecl);
}
}
@@ -7940,10 +7930,9 @@ void SemanticsDeclBasesVisitor::visitEnumDecl(EnumDecl* decl)
Type* enumTypeType = getASTBuilder()->getEnumTypeType();
InheritanceDecl* enumConformanceDecl = m_astBuilder->create<InheritanceDecl>();
- enumConformanceDecl->parentDecl = decl;
enumConformanceDecl->loc = decl->loc;
enumConformanceDecl->base.type = getASTBuilder()->getEnumTypeType();
- decl->members.add(enumConformanceDecl);
+ decl->addMember(enumConformanceDecl);
// The `__EnumType` interface has one required member, the `__Tag` type.
// We need to satisfy this requirement automatically, rather than require
@@ -9556,8 +9545,7 @@ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl(
default:
break;
}
- decl->members.add(param);
- param->parentDecl = decl;
+ decl->addMember(param);
}
}
@@ -9576,8 +9564,7 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl*
.as<CallableDecl>();
auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef));
setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType));
- interfaceDecl->members.add(reqDecl);
- reqDecl->parentDecl = interfaceDecl;
+ interfaceDecl->addMember(reqDecl);
if (!decl->hasModifier<NoDiffThisAttribute>())
{
@@ -9596,8 +9583,7 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl*
auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
reqRef->referencedDecl = reqDecl;
- reqRef->parentDecl = decl;
- decl->members.add(reqRef);
+ decl->addMember(reqRef);
isDiffFunc = true;
}
if (decl->hasModifier<BackwardDifferentiableAttribute>())
@@ -9612,8 +9598,7 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl*
reqDecl->originalRequirementDecl = decl;
cloneModifiers(reqDecl, decl);
setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType);
- interfaceDecl->members.add(reqDecl);
- reqDecl->parentDecl = interfaceDecl;
+ interfaceDecl->addMember(reqDecl);
if (!decl->hasModifier<NoDiffThisAttribute>())
{
// Build decl-ref-type for this-type.
@@ -9631,8 +9616,7 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl*
auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
reqRef->referencedDecl = reqDecl;
- reqRef->parentDecl = decl;
- decl->members.add(reqRef);
+ decl->addMember(reqRef);
}
isDiffFunc = true;
}
@@ -10123,8 +10107,7 @@ void SemanticsDeclHeaderVisitor::visitAbstractStorageDeclCommon(ContainerDecl* d
GetterDecl* getterDecl = m_astBuilder->create<GetterDecl>();
getterDecl->loc = decl->loc;
- getterDecl->parentDecl = decl;
- decl->members.add(getterDecl);
+ decl->addMember(getterDecl);
}
}
@@ -10263,8 +10246,7 @@ void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl)
newValueParam->nameAndLoc.name = getName("newValue");
newValueParam->nameAndLoc.loc = decl->loc;
- newValueParam->parentDecl = decl;
- decl->members.add(newValueParam);
+ decl->addMember(newValueParam);
}
// The new-value parameter is expected to have the
@@ -11355,6 +11337,207 @@ static void _dispatchDeclCheckingVisitor(Decl* decl, DeclCheckState state, Seman
}
}
+
+// Replace with <=> in C++20
+template<typename T>
+int compareThreeWays(T a, T b)
+{
+ if (a > b)
+ return -1;
+ else if (b > a)
+ return 1;
+ else
+ return 0;
+}
+
+// lhs and rhs cannot be nullptr
+int compareDecls(Decl& lhs, Decl& rhs);
+
+// lhs and rhs cannot be nullptr
+int compareVals(Val& lhs, Val& rhs);
+
+template<typename T, class Compare>
+int comparePtrs(T* lhs, T* rhs, Compare const& compare)
+{
+ int res = 0;
+ if (lhs == rhs)
+ res = 0;
+ else if (!lhs)
+ res = -1;
+ else if (!rhs)
+ res = 1;
+ else
+ res = compare(*lhs, *rhs);
+ return res;
+}
+
+// lhs or rhs might be nullptr
+int compareDecls(Decl* lhs, Decl* rhs)
+{
+ return comparePtrs(lhs, rhs, [&](Decl& lhs, Decl& rhs) { return compareDecls(lhs, rhs); });
+}
+
+// lhs or rhs might be nullptr
+int compareVals(Val* lhs, Val* rhs)
+{
+ return comparePtrs(lhs, rhs, [&](Val& lhs, Val& rhs) { return compareVals(lhs, rhs); });
+}
+
+// Compare operands of lhs and rhs from offset,
+// and at most count operands, if the capacity allows it.
+int compareValOperands(Val& lhs, Val& rhs, Index offset, Index count)
+{
+ const Index lN = std::clamp<Index>(lhs.getOperandCount() - offset, 0, count);
+ const Index rN = std::clamp<Index>(rhs.getOperandCount() - offset, 0, count);
+ int res = compareThreeWays(lN, rN);
+ if (res)
+ return res;
+ for (Index i = 0; i < lN; ++i)
+ {
+ auto lOp = lhs.m_operands[offset + i];
+ auto rOp = rhs.m_operands[offset + i];
+ res = compareThreeWays(lOp.kind, rOp.kind);
+ if (res)
+ {
+ break;
+ }
+ switch (lOp.kind)
+ {
+ case ValNodeOperandKind::ConstantValue:
+ res = compareThreeWays(lOp.getIntConstant(), rOp.getIntConstant());
+ break;
+ case ValNodeOperandKind::ValNode:
+ res = compareVals(lOp.getVal(), rOp.getVal());
+ break;
+ case ValNodeOperandKind::ASTNode:
+ res = compareDecls(lOp.getDecl(), rOp.getDecl());
+ break;
+ }
+ if (res)
+ {
+ break;
+ }
+ }
+ return res;
+}
+
+// Compare operands of lhs and rhs from offset to the end
+int compareValOperands(Val& lhs, Val& rhs, Index offset)
+{
+ return compareValOperands(lhs, rhs, offset, std::numeric_limits<Index>::max());
+}
+
+// Find the lowest common ancestor (LCA) of nodes a and b
+// Uppon return, a and b are modified to be direct children of the LCA
+// Returns nullptr if a and b have no common ancestor
+// (e.g. a and b are not in the same module)
+// a and b are set to each respective module
+ContainerDecl* findDeclsLowestCommonAncestor(Decl*& a, Decl*& b)
+{
+ auto ascendToRoot = [](Decl*& d)
+ {
+ UIndex depth = 0;
+ while (d->parentDecl)
+ {
+ ++depth;
+ d = d->parentDecl;
+ }
+ return depth;
+ };
+
+ Decl* aRoot = a;
+ Decl* bRoot = b;
+ auto aDepth = ascendToRoot(aRoot);
+ auto bDepth = ascendToRoot(bRoot);
+ if (aRoot != bRoot) // Not in the same tree / module
+ {
+ a = aRoot;
+ b = bRoot;
+ return nullptr;
+ }
+ // Level nodes
+ Decl** toAscend = nullptr;
+ Decl** reference = nullptr;
+ UIndex n = 0;
+ if (aDepth > bDepth)
+ {
+ toAscend = &a;
+ reference = &b;
+ n = aDepth - bDepth;
+ }
+ else if (bDepth > aDepth)
+ {
+ toAscend = &b;
+ reference = &a;
+ n = bDepth - aDepth;
+ }
+ if (n)
+ {
+ // Level until toAscend is one level under reference
+ while (n > UIndex(1))
+ {
+ *toAscend = (*toAscend)->parentDecl;
+ --n;
+ }
+ // If toAscend was a child of reference
+ if ((*toAscend)->parentDecl == *reference)
+ {
+ return (*toAscend)->parentDecl;
+ }
+ else
+ {
+ *toAscend = (*toAscend)->parentDecl;
+ }
+ }
+ while (a->parentDecl != b->parentDecl)
+ {
+ a = a->parentDecl;
+ b = b->parentDecl;
+ }
+ return a->parentDecl;
+}
+
+int compareDecls(Decl& lhs, Decl& rhs)
+{
+ int res = compareThreeWays(lhs.astNodeType, rhs.astNodeType);
+ if (res)
+ return res;
+ Decl* lLCAChild = &lhs;
+ Decl* rLCAChild = &rhs;
+ if (ContainerDecl* lca = findDeclsLowestCommonAncestor(lLCAChild, rLCAChild))
+ {
+ res = compareThreeWays(lca->getDeclIndex(lLCAChild), lca->getDeclIndex(rLCAChild));
+ }
+ else
+ {
+ res = comparePtrs(
+ lLCAChild->getName(),
+ rLCAChild->getName(),
+ [](Name const& lName, Name const& rName)
+ { return strcmp(lName.text.begin(), rName.text.begin()); });
+ }
+ return res;
+}
+
+int compareVals(Val& lhs, Val& rhs)
+{
+ int res = compareThreeWays(lhs.astNodeType, rhs.astNodeType);
+ if (res)
+ return res;
+ res = compareValOperands(lhs, rhs, 0);
+ return res;
+}
+
+int compareTypes(Type* lhs, Type* rhs)
+{
+ return compareVals(lhs, rhs);
+}
+
+int compareTypes(Type& lhs, Type& rhs)
+{
+ return compareVals(lhs, rhs);
+}
+
static void _getCanonicalConstraintTypes(List<Type*>& outTypeList, Type* type)
{
if (auto andType = as<AndType>(type))
@@ -11379,16 +11562,22 @@ OrderedDictionary<GenericTypeParamDeclBase*, List<Type*>> getCanonicalGenericCon
for (auto genericTypeConstraintDecl :
getMembersOfType<GenericTypeConstraintDecl>(astBuilder, genericDecl))
{
- assert(
- genericTypeConstraintDecl.getDecl()->sub.type->astNodeType == ASTNodeType::DeclRefType);
- auto typeParamDecl =
- as<DeclRefType>(genericTypeConstraintDecl.getDecl()->sub.type)->getDeclRef().getDecl();
- List<Type*>* constraintTypes = genericConstraints.tryGetValue(typeParamDecl);
- if (!constraintTypes)
- continue;
- constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type);
+ if (genericTypeConstraintDecl.getDecl()->sub.type->astNodeType == ASTNodeType::DeclRefType)
+ {
+ auto typeParamDecl = as<DeclRefType>(genericTypeConstraintDecl.getDecl()->sub.type)
+ ->getDeclRef()
+ .getDecl();
+ List<Type*>* constraintTypes = genericConstraints.tryGetValue(typeParamDecl);
+ if (!constraintTypes)
+ continue;
+ constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Cannot extract Cannonical Generic Constraints on non DeclRefTypes. "
+ "Use getCanonicalGenericConstraints2(...) instead.");
+ }
}
-
OrderedDictionary<GenericTypeParamDeclBase*, List<Type*>> result;
for (auto& constraints : genericConstraints)
{
@@ -11397,8 +11586,57 @@ OrderedDictionary<GenericTypeParamDeclBase*, List<Type*>> getCanonicalGenericCon
{
_getCanonicalConstraintTypes(typeList, type);
}
- // TODO: we also need to sort the types within the list for each generic type param.
- result[constraints.key] = typeList;
+ const auto typeComparator = [&](Type* lhs, Type* rhs)
+ { return compareTypes(*lhs, *rhs) < 0; };
+ typeList.sort(typeComparator);
+ result[constraints.key] = std::move(typeList);
+ }
+ return result;
+}
+
+
+OrderedDictionary<Type*, List<Type*>> getCanonicalGenericConstraints2(
+ ASTBuilder* astBuilder,
+ DeclRef<ContainerDecl> genericDecl)
+{
+ Dictionary<Type*, HashSet<Type*>> genericConstraints;
+ for (auto genericTypeConstraintDecl :
+ getMembersOfType<GenericTypeConstraintDecl>(astBuilder, genericDecl))
+ {
+ auto subExpr = genericTypeConstraintDecl.getDecl()->sub;
+ auto supExpr = genericTypeConstraintDecl.getDecl()->sup;
+ Type* typeToAdd = subExpr.type;
+ if (typeToAdd)
+ {
+ if (!genericConstraints.containsKey(typeToAdd))
+ {
+ genericConstraints[typeToAdd] = HashSet<Type*>();
+ }
+ genericConstraints[typeToAdd].add(supExpr.type);
+ }
+ }
+ const auto typeComparator = [&](Type* lhs, Type* rhs) { return compareTypes(lhs, rhs) < 0; };
+ const List<Type*> sortedKeys = [&]()
+ {
+ List<Type*> res;
+ res.reserve(genericConstraints.getCount());
+ for (auto& t : genericConstraints)
+ {
+ res.add(t.first);
+ }
+ res.sort(typeComparator);
+ return res;
+ }();
+ OrderedDictionary<Type*, List<Type*>> result;
+ for (auto& key : sortedKeys)
+ {
+ List<Type*> typeList;
+ for (auto type : genericConstraints[key])
+ {
+ _getCanonicalConstraintTypes(typeList, type);
+ }
+ typeList.sort(typeComparator);
+ result[key] = std::move(typeList);
}
return result;
}