From a5efbb1b775afb2f6b29b37d39947c41744bb005 Mon Sep 17 00:00:00 2001 From: Ronan Date: Sat, 26 Apr 2025 21:04:01 +0200 Subject: Added getCanonicalGenericConstraints2 (sorts constraints and allows more generic expressions) (#6787) --- source/slang/slang-check-decl.cpp | 348 ++++++++++++++++++++++++++++++++------ 1 file changed, 293 insertions(+), 55 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index e511fbc39..1e524e27f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2626,9 +2626,8 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness auto assocTypeDef = m_astBuilder->create(); assocTypeDef->nameAndLoc.name = getName("Differential"); assocTypeDef->type.type = context->conformingType; - assocTypeDef->parentDecl = context->parentDecl; + context->parentDecl->addMember(assocTypeDef); assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked); - context->parentDecl->members.add(assocTypeDef); markSelfDifferentialMembersOfType( as(context->parentDecl), @@ -2660,8 +2659,7 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness if (!aggTypeDecl) { aggTypeDecl = m_astBuilder->create(); - aggTypeDecl->parentDecl = context->parentDecl; - context->parentDecl->members.add((aggTypeDecl)); + context->parentDecl->addMember(aggTypeDecl); aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName(); aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc; context->parentDecl->invalidateMemberDictionary(); @@ -2719,8 +2717,7 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness diffField->nameAndLoc = member->nameAndLoc; diffField->type.type = diffMemberType; diffField->checkState = DeclCheckState::SignatureChecked; - diffField->parentDecl = aggTypeDecl; - aggTypeDecl->members.add(diffField); + aggTypeDecl->addMember(diffField); auto visibility = getDeclVisibility(member); addVisibilityModifier(diffField, visibility); @@ -2775,8 +2772,7 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness { auto inheritanceIDiffernetiable = m_astBuilder->create(); inheritanceIDiffernetiable->base.type = m_astBuilder->getDiffInterfaceType(); - inheritanceIDiffernetiable->parentDecl = aggTypeDecl; - aggTypeDecl->members.add(inheritanceIDiffernetiable); + aggTypeDecl->addMember(inheritanceIDiffernetiable); } // The `Differential` type of a `Differential` type is always itself. @@ -2797,9 +2793,8 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness auto assocTypeDef = m_astBuilder->create(); assocTypeDef->nameAndLoc.name = getName("Differential"); assocTypeDef->type.type = satisfyingType; - assocTypeDef->parentDecl = aggTypeDecl; + aggTypeDecl->addMember(assocTypeDef); assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked); - aggTypeDecl->members.add(assocTypeDef); } // Go through all members and collect their differential types. @@ -4319,7 +4314,7 @@ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( typeParamDeclBase->astNodeType); synTypeParamDeclBase->nameAndLoc = typeParamDeclBase->getNameAndLoc(); synTypeParamDeclBase->parameterIndex = typeParamDeclBase->parameterIndex; - synTypeParamDeclBase->parentDecl = synGenericDecl; + synGenericDecl->addMember(synTypeParamDeclBase); // Note: we intentionally do not copy GenericTypeParamDecl::initType here, // because initType maybe dependent on the original type parameters, @@ -4327,7 +4322,6 @@ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( // synthesized ones. It shouldn't be required for the implementing declaration to define // initType anyways, so we'll just save ourselves from the trouble. // - synGenericDecl->members.add(synTypeParamDeclBase); mapOrigToSynTypeParams.add(typeParamDeclBase, synTypeParamDeclBase); @@ -4345,7 +4339,7 @@ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( { auto synValParamDecl = m_astBuilder->create(); synValParamDecl->nameAndLoc = valParamDecl->nameAndLoc; - synValParamDecl->parentDecl = synGenericDecl; + synGenericDecl->addMember(synValParamDecl); synValParamDecl->parameterIndex = valParamDecl->parameterIndex; synValParamDecl->type = valParamDecl->type; @@ -4355,7 +4349,6 @@ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( // synthesized ones. It shouldn't be required for the implementing declaration to define // initType anyways, so we'll just save ourselves from the trouble. // - synGenericDecl->members.add(synValParamDecl); mapOrigToSynTypeParams.add(valParamDecl, synGenericDecl); @@ -4537,8 +4530,7 @@ void SemanticsVisitor::addRequiredParamsToSynthesizedDecl( // We need to add the parameter as a child declaration of // the method we are building. // - synParamDecl->parentDecl = synthesized; - synthesized->members.add(synParamDecl); + synthesized->addMember(synParamDecl); // Add modifiers paramType.isLeftValue = true; @@ -5508,8 +5500,7 @@ bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness( // We need to add the parameter as a child declaration of // the accessor we are building. // - synParamDecl->parentDecl = synAccessorDecl; - synAccessorDecl->members.add(synParamDecl); + synAccessorDecl->addMember(synParamDecl); // For each paramter, we will create an argument expression // to represent it in the body of the accessor. @@ -5567,8 +5558,7 @@ bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness( addModifier(synAccessorDecl, m_astBuilder->create()); synAccessorDecl->body = synBodyStmt; - synAccessorDecl->parentDecl = synPropertyDecl; - synPropertyDecl->members.add(synAccessorDecl); + synPropertyDecl->addMember(synAccessorDecl); // Register the synthesized accessor. // @@ -5709,8 +5699,7 @@ bool SemanticsVisitor::synthesizeAccessorRequirements( // We need to add the parameter as a child declaration of // the accessor we are building. // - synParamDecl->parentDecl = synAccessorDecl; - synAccessorDecl->members.add(synParamDecl); + synAccessorDecl->addMember(synParamDecl); // For each paramter, we will create an argument expression // to represent it in the body of the accessor. @@ -5869,8 +5858,7 @@ bool SemanticsVisitor::synthesizeAccessorRequirements( synAccessorDecl->body = synBodyStmt; - synAccessorDecl->parentDecl = synAccesorContainer; - synAccesorContainer->members.add(synAccessorDecl); + synAccesorContainer->addMember(synAccessorDecl); // If synthesis of an accessor worked, then we will record it into // a local dictionary. We do *not* install the accessor into the @@ -6377,7 +6365,9 @@ bool SemanticsVisitor::trySynthesizeEnumTypeMethodRequirementWitness( } synFunc->loc = context->parentDecl->closingSourceLoc; synFunc->nameAndLoc.loc = synFunc->loc; - context->parentDecl->members.add(synFunc); + // synFunc already has its parent set + SLANG_ASSERT(context->parentDecl == synFunc->parentDecl); + context->parentDecl->addMember(synFunc); context->parentDecl->invalidateMemberDictionary(); addModifier(synFunc, intrinsicOpModifier); witnessTable->add( @@ -6561,7 +6551,8 @@ bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( seqStmt->stmts.add(synReturn); Decl* witnessDecl = synGeneric ? (Decl*)synGeneric : synFunc; - context->parentDecl->members.add(witnessDecl); + SLANG_ASSERT(context->parentDecl == witnessDecl->parentDecl); + context->parentDecl->addMember(witnessDecl); context->parentDecl->invalidateMemberDictionary(); addModifier(synFunc, m_astBuilder->create()); @@ -7577,11 +7568,10 @@ void SemanticsDeclBasesVisitor::visitStructDecl(StructDecl* decl) IsSubTypeOptions::NoCaching)) { InheritanceDecl* conformanceDecl = m_astBuilder->create(); - conformanceDecl->parentDecl = decl; conformanceDecl->loc = decl->loc; conformanceDecl->base.type = defaultInitializableType; conformanceDecl->nameAndLoc.name = getName("$inheritance"); - decl->members.add(conformanceDecl); + decl->addMember(conformanceDecl); } } @@ -7940,10 +7930,9 @@ void SemanticsDeclBasesVisitor::visitEnumDecl(EnumDecl* decl) Type* enumTypeType = getASTBuilder()->getEnumTypeType(); InheritanceDecl* enumConformanceDecl = m_astBuilder->create(); - enumConformanceDecl->parentDecl = decl; enumConformanceDecl->loc = decl->loc; enumConformanceDecl->base.type = getASTBuilder()->getEnumTypeType(); - decl->members.add(enumConformanceDecl); + decl->addMember(enumConformanceDecl); // The `__EnumType` interface has one required member, the `__Tag` type. // We need to satisfy this requirement automatically, rather than require @@ -9556,8 +9545,7 @@ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl( default: break; } - decl->members.add(param); - param->parentDecl = decl; + decl->addMember(param); } } @@ -9576,8 +9564,7 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* .as(); auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); setFuncTypeIntoRequirementDecl(reqDecl, as(diffFuncType)); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; + interfaceDecl->addMember(reqDecl); if (!decl->hasModifier()) { @@ -9596,8 +9583,7 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* auto reqRef = m_astBuilder->create(); reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); + decl->addMember(reqRef); isDiffFunc = true; } if (decl->hasModifier()) @@ -9612,8 +9598,7 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* reqDecl->originalRequirementDecl = decl; cloneModifiers(reqDecl, decl); setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; + interfaceDecl->addMember(reqDecl); if (!decl->hasModifier()) { // Build decl-ref-type for this-type. @@ -9631,8 +9616,7 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* auto reqRef = m_astBuilder->create(); reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); + decl->addMember(reqRef); } isDiffFunc = true; } @@ -10123,8 +10107,7 @@ void SemanticsDeclHeaderVisitor::visitAbstractStorageDeclCommon(ContainerDecl* d GetterDecl* getterDecl = m_astBuilder->create(); getterDecl->loc = decl->loc; - getterDecl->parentDecl = decl; - decl->members.add(getterDecl); + decl->addMember(getterDecl); } } @@ -10263,8 +10246,7 @@ void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl) newValueParam->nameAndLoc.name = getName("newValue"); newValueParam->nameAndLoc.loc = decl->loc; - newValueParam->parentDecl = decl; - decl->members.add(newValueParam); + decl->addMember(newValueParam); } // The new-value parameter is expected to have the @@ -11355,6 +11337,207 @@ static void _dispatchDeclCheckingVisitor(Decl* decl, DeclCheckState state, Seman } } + +// Replace with <=> in C++20 +template +int compareThreeWays(T a, T b) +{ + if (a > b) + return -1; + else if (b > a) + return 1; + else + return 0; +} + +// lhs and rhs cannot be nullptr +int compareDecls(Decl& lhs, Decl& rhs); + +// lhs and rhs cannot be nullptr +int compareVals(Val& lhs, Val& rhs); + +template +int comparePtrs(T* lhs, T* rhs, Compare const& compare) +{ + int res = 0; + if (lhs == rhs) + res = 0; + else if (!lhs) + res = -1; + else if (!rhs) + res = 1; + else + res = compare(*lhs, *rhs); + return res; +} + +// lhs or rhs might be nullptr +int compareDecls(Decl* lhs, Decl* rhs) +{ + return comparePtrs(lhs, rhs, [&](Decl& lhs, Decl& rhs) { return compareDecls(lhs, rhs); }); +} + +// lhs or rhs might be nullptr +int compareVals(Val* lhs, Val* rhs) +{ + return comparePtrs(lhs, rhs, [&](Val& lhs, Val& rhs) { return compareVals(lhs, rhs); }); +} + +// Compare operands of lhs and rhs from offset, +// and at most count operands, if the capacity allows it. +int compareValOperands(Val& lhs, Val& rhs, Index offset, Index count) +{ + const Index lN = std::clamp(lhs.getOperandCount() - offset, 0, count); + const Index rN = std::clamp(rhs.getOperandCount() - offset, 0, count); + int res = compareThreeWays(lN, rN); + if (res) + return res; + for (Index i = 0; i < lN; ++i) + { + auto lOp = lhs.m_operands[offset + i]; + auto rOp = rhs.m_operands[offset + i]; + res = compareThreeWays(lOp.kind, rOp.kind); + if (res) + { + break; + } + switch (lOp.kind) + { + case ValNodeOperandKind::ConstantValue: + res = compareThreeWays(lOp.getIntConstant(), rOp.getIntConstant()); + break; + case ValNodeOperandKind::ValNode: + res = compareVals(lOp.getVal(), rOp.getVal()); + break; + case ValNodeOperandKind::ASTNode: + res = compareDecls(lOp.getDecl(), rOp.getDecl()); + break; + } + if (res) + { + break; + } + } + return res; +} + +// Compare operands of lhs and rhs from offset to the end +int compareValOperands(Val& lhs, Val& rhs, Index offset) +{ + return compareValOperands(lhs, rhs, offset, std::numeric_limits::max()); +} + +// Find the lowest common ancestor (LCA) of nodes a and b +// Uppon return, a and b are modified to be direct children of the LCA +// Returns nullptr if a and b have no common ancestor +// (e.g. a and b are not in the same module) +// a and b are set to each respective module +ContainerDecl* findDeclsLowestCommonAncestor(Decl*& a, Decl*& b) +{ + auto ascendToRoot = [](Decl*& d) + { + UIndex depth = 0; + while (d->parentDecl) + { + ++depth; + d = d->parentDecl; + } + return depth; + }; + + Decl* aRoot = a; + Decl* bRoot = b; + auto aDepth = ascendToRoot(aRoot); + auto bDepth = ascendToRoot(bRoot); + if (aRoot != bRoot) // Not in the same tree / module + { + a = aRoot; + b = bRoot; + return nullptr; + } + // Level nodes + Decl** toAscend = nullptr; + Decl** reference = nullptr; + UIndex n = 0; + if (aDepth > bDepth) + { + toAscend = &a; + reference = &b; + n = aDepth - bDepth; + } + else if (bDepth > aDepth) + { + toAscend = &b; + reference = &a; + n = bDepth - aDepth; + } + if (n) + { + // Level until toAscend is one level under reference + while (n > UIndex(1)) + { + *toAscend = (*toAscend)->parentDecl; + --n; + } + // If toAscend was a child of reference + if ((*toAscend)->parentDecl == *reference) + { + return (*toAscend)->parentDecl; + } + else + { + *toAscend = (*toAscend)->parentDecl; + } + } + while (a->parentDecl != b->parentDecl) + { + a = a->parentDecl; + b = b->parentDecl; + } + return a->parentDecl; +} + +int compareDecls(Decl& lhs, Decl& rhs) +{ + int res = compareThreeWays(lhs.astNodeType, rhs.astNodeType); + if (res) + return res; + Decl* lLCAChild = &lhs; + Decl* rLCAChild = &rhs; + if (ContainerDecl* lca = findDeclsLowestCommonAncestor(lLCAChild, rLCAChild)) + { + res = compareThreeWays(lca->getDeclIndex(lLCAChild), lca->getDeclIndex(rLCAChild)); + } + else + { + res = comparePtrs( + lLCAChild->getName(), + rLCAChild->getName(), + [](Name const& lName, Name const& rName) + { return strcmp(lName.text.begin(), rName.text.begin()); }); + } + return res; +} + +int compareVals(Val& lhs, Val& rhs) +{ + int res = compareThreeWays(lhs.astNodeType, rhs.astNodeType); + if (res) + return res; + res = compareValOperands(lhs, rhs, 0); + return res; +} + +int compareTypes(Type* lhs, Type* rhs) +{ + return compareVals(lhs, rhs); +} + +int compareTypes(Type& lhs, Type& rhs) +{ + return compareVals(lhs, rhs); +} + static void _getCanonicalConstraintTypes(List& outTypeList, Type* type) { if (auto andType = as(type)) @@ -11379,16 +11562,22 @@ OrderedDictionary> getCanonicalGenericCon for (auto genericTypeConstraintDecl : getMembersOfType(astBuilder, genericDecl)) { - assert( - genericTypeConstraintDecl.getDecl()->sub.type->astNodeType == ASTNodeType::DeclRefType); - auto typeParamDecl = - as(genericTypeConstraintDecl.getDecl()->sub.type)->getDeclRef().getDecl(); - List* constraintTypes = genericConstraints.tryGetValue(typeParamDecl); - if (!constraintTypes) - continue; - constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type); + if (genericTypeConstraintDecl.getDecl()->sub.type->astNodeType == ASTNodeType::DeclRefType) + { + auto typeParamDecl = as(genericTypeConstraintDecl.getDecl()->sub.type) + ->getDeclRef() + .getDecl(); + List* constraintTypes = genericConstraints.tryGetValue(typeParamDecl); + if (!constraintTypes) + continue; + constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type); + } + else + { + SLANG_UNEXPECTED("Cannot extract Cannonical Generic Constraints on non DeclRefTypes. " + "Use getCanonicalGenericConstraints2(...) instead."); + } } - OrderedDictionary> result; for (auto& constraints : genericConstraints) { @@ -11397,8 +11586,57 @@ OrderedDictionary> getCanonicalGenericCon { _getCanonicalConstraintTypes(typeList, type); } - // TODO: we also need to sort the types within the list for each generic type param. - result[constraints.key] = typeList; + const auto typeComparator = [&](Type* lhs, Type* rhs) + { return compareTypes(*lhs, *rhs) < 0; }; + typeList.sort(typeComparator); + result[constraints.key] = std::move(typeList); + } + return result; +} + + +OrderedDictionary> getCanonicalGenericConstraints2( + ASTBuilder* astBuilder, + DeclRef genericDecl) +{ + Dictionary> genericConstraints; + for (auto genericTypeConstraintDecl : + getMembersOfType(astBuilder, genericDecl)) + { + auto subExpr = genericTypeConstraintDecl.getDecl()->sub; + auto supExpr = genericTypeConstraintDecl.getDecl()->sup; + Type* typeToAdd = subExpr.type; + if (typeToAdd) + { + if (!genericConstraints.containsKey(typeToAdd)) + { + genericConstraints[typeToAdd] = HashSet(); + } + genericConstraints[typeToAdd].add(supExpr.type); + } + } + const auto typeComparator = [&](Type* lhs, Type* rhs) { return compareTypes(lhs, rhs) < 0; }; + const List sortedKeys = [&]() + { + List res; + res.reserve(genericConstraints.getCount()); + for (auto& t : genericConstraints) + { + res.add(t.first); + } + res.sort(typeComparator); + return res; + }(); + OrderedDictionary> result; + for (auto& key : sortedKeys) + { + List typeList; + for (auto type : genericConstraints[key]) + { + _getCanonicalConstraintTypes(typeList, type); + } + typeList.sort(typeComparator); + result[key] = std::move(typeList); } return result; } -- cgit v1.2.3