diff options
| author | Ronan <ro.cailleau@gmail.com> | 2025-04-26 21:04:01 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-26 12:04:01 -0700 |
| commit | a5efbb1b775afb2f6b29b37d39947c41744bb005 (patch) | |
| tree | ae5c1e11544d2411816ee6fcfb29b2820d41fcb0 /source | |
| parent | d84aeeffdba388aec7a781c35973bf404d37fe80 (diff) | |
Added getCanonicalGenericConstraints2 (sorts constraints and allows more generic expressions) (#6787)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-base.h | 36 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.h | 20 | ||||
| -rw-r--r-- | source/slang/slang-ast-synthesis.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-synthesis.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 348 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-check.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-mangle.cpp | 17 |
10 files changed, 370 insertions, 90 deletions
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 72da9cf56..5affcb756 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -146,6 +146,24 @@ struct ValNodeOperand ValNodeOperand() { values.intOperand = 0; } + int64_t getIntConstant() const + { + SLANG_ASSERT(kind == ValNodeOperandKind::ConstantValue); + return values.intOperand; + } + + Val* getVal() const + { + SLANG_ASSERT(kind == ValNodeOperandKind::ValNode); + return (Val*)values.nodeOperand; + } + + Decl* getDecl() const + { + SLANG_ASSERT(kind == ValNodeOperandKind::ASTNode); + return (Decl*)values.nodeOperand; + } + explicit ValNodeOperand(NodeBase* node) { if constexpr (sizeof(values.nodeOperand) < sizeof(values.intOperand)) @@ -424,23 +442,11 @@ class Val : public NodeBase Val* resolveImpl(); Val* resolve(); - Val* getOperand(Index index) const - { - SLANG_ASSERT(m_operands[index].kind == ValNodeOperandKind::ValNode); - return (Val*)m_operands[index].values.nodeOperand; - } + Val* getOperand(Index index) const { return m_operands[index].getVal(); } - Decl* getDeclOperand(Index index) const - { - SLANG_ASSERT(m_operands[index].kind == ValNodeOperandKind::ASTNode); - return (Decl*)(m_operands[index].values.nodeOperand); - } + Decl* getDeclOperand(Index index) const { return m_operands[index].getDecl(); } - int64_t getIntConstOperand(Index index) const - { - SLANG_ASSERT(m_operands[index].kind == ValNodeOperandKind::ConstantValue); - return m_operands[index].values.intOperand; - } + int64_t getIntConstOperand(Index index) const { return m_operands[index].getIntConstant(); } Index getOperandCount() const { return m_operands.getCount(); } diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp index 530f983d9..4c5d32f71 100644 --- a/source/slang/slang-ast-decl.cpp +++ b/source/slang/slang-ast-decl.cpp @@ -105,6 +105,20 @@ void ContainerDecl::buildMemberDictionary() SLANG_ASSERT(isMemberDictionaryValid()); } +Index ContainerDecl::getDeclIndex(Decl* decl) +{ + if (Index* ptr = mapDeclMemberToIndex.tryGetValue(decl)) + { + return *ptr; + } + Index res = members.findFirstIndex([&](Decl* d) { return d == decl; }); + if (res >= Index(0)) + { + mapDeclMemberToIndex[decl] = res; + } + return res; +} + bool isLocalVar(const Decl* decl) { const auto varDecl = as<VarDecl>(decl); diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 261d2458a..5b0c883b0 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -47,7 +47,11 @@ class ContainerDecl : public Decl bool isMemberDictionaryValid() const { return dictionaryLastCount == members.getCount(); } - void invalidateMemberDictionary() { dictionaryLastCount = -1; } + void invalidateMemberDictionary() + { + dictionaryLastCount = -1; + mapDeclMemberToIndex.clear(); + } Dictionary<Name*, Decl*>& getMemberDictionary() { @@ -66,10 +70,22 @@ class ContainerDecl : public Decl if (member) { member->parentDecl = this; + auto index = members.getCount(); members.add(member); + mapDeclMemberToIndex[member] = index; } } + static void setParent(ContainerDecl* parent, Decl* child) + { + if (child) + child->parentDecl = parent; + if (parent) + parent->addMember(child); + } + + Index getDeclIndex(Decl* d); + SLANG_UNREFLECTED // We don't want to reflect the following fields private : @@ -84,6 +100,8 @@ class ContainerDecl : public Decl // This is built on demand before performing lookup. Dictionary<Name*, Decl*> memberDictionary; + Dictionary<Decl*, Index> mapDeclMemberToIndex; + // A list of transparent members, to be used in lookup // Note: this is only valid if `memberDictionaryIsValid` is true List<TransparentMemberInfo> transparentMembers; diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp index 46ba81d16..58c68c369 100644 --- a/source/slang/slang-ast-synthesis.cpp +++ b/source/slang/slang-ast-synthesis.cpp @@ -179,8 +179,7 @@ DeclStmt* ASTSynthesizer::emitVarDeclStmt(Type* type, Name* name, Expr* initVal) varDecl->type.type = type; varDecl->nameAndLoc.name = name; varDecl->initExpr = initVal; - varDecl->parentDecl = scope.m_scope->containerDecl; - varDecl->parentDecl->members.add(varDecl); + scope.m_scope->containerDecl->addMember(varDecl); auto stmt = m_builder->create<DeclStmt>(); stmt->decl = varDecl; _addStmtToScope(stmt); diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h index b68bea39c..591b7edde 100644 --- a/source/slang/slang-ast-synthesis.h +++ b/source/slang/slang-ast-synthesis.h @@ -71,9 +71,7 @@ public: ASTEmitScope scope = getCurrentScope(); auto scopeDecl = m_builder->create<ScopeDecl>(); auto newScope = m_builder->create<Scope>(); - scopeDecl->parentDecl = scope.m_parent; - if (scope.m_parent) - scope.m_parent->members.add(scopeDecl); + ContainerDecl::setParent(scope.m_parent, scopeDecl); newScope->parent = scope.m_scope; newScope->containerDecl = scopeDecl; scope.m_scope = newScope; diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index e511fbc39..1e524e27f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2626,9 +2626,8 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness auto assocTypeDef = m_astBuilder->create<TypeDefDecl>(); assocTypeDef->nameAndLoc.name = getName("Differential"); assocTypeDef->type.type = context->conformingType; - assocTypeDef->parentDecl = context->parentDecl; + context->parentDecl->addMember(assocTypeDef); assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked); - context->parentDecl->members.add(assocTypeDef); markSelfDifferentialMembersOfType( as<AggTypeDecl>(context->parentDecl), @@ -2660,8 +2659,7 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness if (!aggTypeDecl) { aggTypeDecl = m_astBuilder->create<StructDecl>(); - aggTypeDecl->parentDecl = context->parentDecl; - context->parentDecl->members.add((aggTypeDecl)); + context->parentDecl->addMember(aggTypeDecl); aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName(); aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc; context->parentDecl->invalidateMemberDictionary(); @@ -2719,8 +2717,7 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness diffField->nameAndLoc = member->nameAndLoc; diffField->type.type = diffMemberType; diffField->checkState = DeclCheckState::SignatureChecked; - diffField->parentDecl = aggTypeDecl; - aggTypeDecl->members.add(diffField); + aggTypeDecl->addMember(diffField); auto visibility = getDeclVisibility(member); addVisibilityModifier(diffField, visibility); @@ -2775,8 +2772,7 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness { auto inheritanceIDiffernetiable = m_astBuilder->create<InheritanceDecl>(); inheritanceIDiffernetiable->base.type = m_astBuilder->getDiffInterfaceType(); - inheritanceIDiffernetiable->parentDecl = aggTypeDecl; - aggTypeDecl->members.add(inheritanceIDiffernetiable); + aggTypeDecl->addMember(inheritanceIDiffernetiable); } // The `Differential` type of a `Differential` type is always itself. @@ -2797,9 +2793,8 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness auto assocTypeDef = m_astBuilder->create<TypeDefDecl>(); assocTypeDef->nameAndLoc.name = getName("Differential"); assocTypeDef->type.type = satisfyingType; - assocTypeDef->parentDecl = aggTypeDecl; + aggTypeDecl->addMember(assocTypeDef); assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked); - aggTypeDecl->members.add(assocTypeDef); } // Go through all members and collect their differential types. @@ -4319,7 +4314,7 @@ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( typeParamDeclBase->astNodeType); synTypeParamDeclBase->nameAndLoc = typeParamDeclBase->getNameAndLoc(); synTypeParamDeclBase->parameterIndex = typeParamDeclBase->parameterIndex; - synTypeParamDeclBase->parentDecl = synGenericDecl; + synGenericDecl->addMember(synTypeParamDeclBase); // Note: we intentionally do not copy GenericTypeParamDecl::initType here, // because initType maybe dependent on the original type parameters, @@ -4327,7 +4322,6 @@ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( // synthesized ones. It shouldn't be required for the implementing declaration to define // initType anyways, so we'll just save ourselves from the trouble. // - synGenericDecl->members.add(synTypeParamDeclBase); mapOrigToSynTypeParams.add(typeParamDeclBase, synTypeParamDeclBase); @@ -4345,7 +4339,7 @@ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( { auto synValParamDecl = m_astBuilder->create<GenericValueParamDecl>(); synValParamDecl->nameAndLoc = valParamDecl->nameAndLoc; - synValParamDecl->parentDecl = synGenericDecl; + synGenericDecl->addMember(synValParamDecl); synValParamDecl->parameterIndex = valParamDecl->parameterIndex; synValParamDecl->type = valParamDecl->type; @@ -4355,7 +4349,6 @@ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( // synthesized ones. It shouldn't be required for the implementing declaration to define // initType anyways, so we'll just save ourselves from the trouble. // - synGenericDecl->members.add(synValParamDecl); mapOrigToSynTypeParams.add(valParamDecl, synGenericDecl); @@ -4537,8 +4530,7 @@ void SemanticsVisitor::addRequiredParamsToSynthesizedDecl( // We need to add the parameter as a child declaration of // the method we are building. // - synParamDecl->parentDecl = synthesized; - synthesized->members.add(synParamDecl); + synthesized->addMember(synParamDecl); // Add modifiers paramType.isLeftValue = true; @@ -5508,8 +5500,7 @@ bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness( // We need to add the parameter as a child declaration of // the accessor we are building. // - synParamDecl->parentDecl = synAccessorDecl; - synAccessorDecl->members.add(synParamDecl); + synAccessorDecl->addMember(synParamDecl); // For each paramter, we will create an argument expression // to represent it in the body of the accessor. @@ -5567,8 +5558,7 @@ bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness( addModifier(synAccessorDecl, m_astBuilder->create<ForceInlineAttribute>()); synAccessorDecl->body = synBodyStmt; - synAccessorDecl->parentDecl = synPropertyDecl; - synPropertyDecl->members.add(synAccessorDecl); + synPropertyDecl->addMember(synAccessorDecl); // Register the synthesized accessor. // @@ -5709,8 +5699,7 @@ bool SemanticsVisitor::synthesizeAccessorRequirements( // We need to add the parameter as a child declaration of // the accessor we are building. // - synParamDecl->parentDecl = synAccessorDecl; - synAccessorDecl->members.add(synParamDecl); + synAccessorDecl->addMember(synParamDecl); // For each paramter, we will create an argument expression // to represent it in the body of the accessor. @@ -5869,8 +5858,7 @@ bool SemanticsVisitor::synthesizeAccessorRequirements( synAccessorDecl->body = synBodyStmt; - synAccessorDecl->parentDecl = synAccesorContainer; - synAccesorContainer->members.add(synAccessorDecl); + synAccesorContainer->addMember(synAccessorDecl); // If synthesis of an accessor worked, then we will record it into // a local dictionary. We do *not* install the accessor into the @@ -6377,7 +6365,9 @@ bool SemanticsVisitor::trySynthesizeEnumTypeMethodRequirementWitness( } synFunc->loc = context->parentDecl->closingSourceLoc; synFunc->nameAndLoc.loc = synFunc->loc; - context->parentDecl->members.add(synFunc); + // synFunc already has its parent set + SLANG_ASSERT(context->parentDecl == synFunc->parentDecl); + context->parentDecl->addMember(synFunc); context->parentDecl->invalidateMemberDictionary(); addModifier(synFunc, intrinsicOpModifier); witnessTable->add( @@ -6561,7 +6551,8 @@ bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( seqStmt->stmts.add(synReturn); Decl* witnessDecl = synGeneric ? (Decl*)synGeneric : synFunc; - context->parentDecl->members.add(witnessDecl); + SLANG_ASSERT(context->parentDecl == witnessDecl->parentDecl); + context->parentDecl->addMember(witnessDecl); context->parentDecl->invalidateMemberDictionary(); addModifier(synFunc, m_astBuilder->create<SynthesizedModifier>()); @@ -7577,11 +7568,10 @@ void SemanticsDeclBasesVisitor::visitStructDecl(StructDecl* decl) IsSubTypeOptions::NoCaching)) { InheritanceDecl* conformanceDecl = m_astBuilder->create<InheritanceDecl>(); - conformanceDecl->parentDecl = decl; conformanceDecl->loc = decl->loc; conformanceDecl->base.type = defaultInitializableType; conformanceDecl->nameAndLoc.name = getName("$inheritance"); - decl->members.add(conformanceDecl); + decl->addMember(conformanceDecl); } } @@ -7940,10 +7930,9 @@ void SemanticsDeclBasesVisitor::visitEnumDecl(EnumDecl* decl) Type* enumTypeType = getASTBuilder()->getEnumTypeType(); InheritanceDecl* enumConformanceDecl = m_astBuilder->create<InheritanceDecl>(); - enumConformanceDecl->parentDecl = decl; enumConformanceDecl->loc = decl->loc; enumConformanceDecl->base.type = getASTBuilder()->getEnumTypeType(); - decl->members.add(enumConformanceDecl); + decl->addMember(enumConformanceDecl); // The `__EnumType` interface has one required member, the `__Tag` type. // We need to satisfy this requirement automatically, rather than require @@ -9556,8 +9545,7 @@ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl( default: break; } - decl->members.add(param); - param->parentDecl = decl; + decl->addMember(param); } } @@ -9576,8 +9564,7 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* .as<CallableDecl>(); auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; + interfaceDecl->addMember(reqDecl); if (!decl->hasModifier<NoDiffThisAttribute>()) { @@ -9596,8 +9583,7 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); + decl->addMember(reqRef); isDiffFunc = true; } if (decl->hasModifier<BackwardDifferentiableAttribute>()) @@ -9612,8 +9598,7 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* reqDecl->originalRequirementDecl = decl; cloneModifiers(reqDecl, decl); setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; + interfaceDecl->addMember(reqDecl); if (!decl->hasModifier<NoDiffThisAttribute>()) { // Build decl-ref-type for this-type. @@ -9631,8 +9616,7 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); + decl->addMember(reqRef); } isDiffFunc = true; } @@ -10123,8 +10107,7 @@ void SemanticsDeclHeaderVisitor::visitAbstractStorageDeclCommon(ContainerDecl* d GetterDecl* getterDecl = m_astBuilder->create<GetterDecl>(); getterDecl->loc = decl->loc; - getterDecl->parentDecl = decl; - decl->members.add(getterDecl); + decl->addMember(getterDecl); } } @@ -10263,8 +10246,7 @@ void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl) newValueParam->nameAndLoc.name = getName("newValue"); newValueParam->nameAndLoc.loc = decl->loc; - newValueParam->parentDecl = decl; - decl->members.add(newValueParam); + decl->addMember(newValueParam); } // The new-value parameter is expected to have the @@ -11355,6 +11337,207 @@ static void _dispatchDeclCheckingVisitor(Decl* decl, DeclCheckState state, Seman } } + +// Replace with <=> in C++20 +template<typename T> +int compareThreeWays(T a, T b) +{ + if (a > b) + return -1; + else if (b > a) + return 1; + else + return 0; +} + +// lhs and rhs cannot be nullptr +int compareDecls(Decl& lhs, Decl& rhs); + +// lhs and rhs cannot be nullptr +int compareVals(Val& lhs, Val& rhs); + +template<typename T, class Compare> +int comparePtrs(T* lhs, T* rhs, Compare const& compare) +{ + int res = 0; + if (lhs == rhs) + res = 0; + else if (!lhs) + res = -1; + else if (!rhs) + res = 1; + else + res = compare(*lhs, *rhs); + return res; +} + +// lhs or rhs might be nullptr +int compareDecls(Decl* lhs, Decl* rhs) +{ + return comparePtrs(lhs, rhs, [&](Decl& lhs, Decl& rhs) { return compareDecls(lhs, rhs); }); +} + +// lhs or rhs might be nullptr +int compareVals(Val* lhs, Val* rhs) +{ + return comparePtrs(lhs, rhs, [&](Val& lhs, Val& rhs) { return compareVals(lhs, rhs); }); +} + +// Compare operands of lhs and rhs from offset, +// and at most count operands, if the capacity allows it. +int compareValOperands(Val& lhs, Val& rhs, Index offset, Index count) +{ + const Index lN = std::clamp<Index>(lhs.getOperandCount() - offset, 0, count); + const Index rN = std::clamp<Index>(rhs.getOperandCount() - offset, 0, count); + int res = compareThreeWays(lN, rN); + if (res) + return res; + for (Index i = 0; i < lN; ++i) + { + auto lOp = lhs.m_operands[offset + i]; + auto rOp = rhs.m_operands[offset + i]; + res = compareThreeWays(lOp.kind, rOp.kind); + if (res) + { + break; + } + switch (lOp.kind) + { + case ValNodeOperandKind::ConstantValue: + res = compareThreeWays(lOp.getIntConstant(), rOp.getIntConstant()); + break; + case ValNodeOperandKind::ValNode: + res = compareVals(lOp.getVal(), rOp.getVal()); + break; + case ValNodeOperandKind::ASTNode: + res = compareDecls(lOp.getDecl(), rOp.getDecl()); + break; + } + if (res) + { + break; + } + } + return res; +} + +// Compare operands of lhs and rhs from offset to the end +int compareValOperands(Val& lhs, Val& rhs, Index offset) +{ + return compareValOperands(lhs, rhs, offset, std::numeric_limits<Index>::max()); +} + +// Find the lowest common ancestor (LCA) of nodes a and b +// Uppon return, a and b are modified to be direct children of the LCA +// Returns nullptr if a and b have no common ancestor +// (e.g. a and b are not in the same module) +// a and b are set to each respective module +ContainerDecl* findDeclsLowestCommonAncestor(Decl*& a, Decl*& b) +{ + auto ascendToRoot = [](Decl*& d) + { + UIndex depth = 0; + while (d->parentDecl) + { + ++depth; + d = d->parentDecl; + } + return depth; + }; + + Decl* aRoot = a; + Decl* bRoot = b; + auto aDepth = ascendToRoot(aRoot); + auto bDepth = ascendToRoot(bRoot); + if (aRoot != bRoot) // Not in the same tree / module + { + a = aRoot; + b = bRoot; + return nullptr; + } + // Level nodes + Decl** toAscend = nullptr; + Decl** reference = nullptr; + UIndex n = 0; + if (aDepth > bDepth) + { + toAscend = &a; + reference = &b; + n = aDepth - bDepth; + } + else if (bDepth > aDepth) + { + toAscend = &b; + reference = &a; + n = bDepth - aDepth; + } + if (n) + { + // Level until toAscend is one level under reference + while (n > UIndex(1)) + { + *toAscend = (*toAscend)->parentDecl; + --n; + } + // If toAscend was a child of reference + if ((*toAscend)->parentDecl == *reference) + { + return (*toAscend)->parentDecl; + } + else + { + *toAscend = (*toAscend)->parentDecl; + } + } + while (a->parentDecl != b->parentDecl) + { + a = a->parentDecl; + b = b->parentDecl; + } + return a->parentDecl; +} + +int compareDecls(Decl& lhs, Decl& rhs) +{ + int res = compareThreeWays(lhs.astNodeType, rhs.astNodeType); + if (res) + return res; + Decl* lLCAChild = &lhs; + Decl* rLCAChild = &rhs; + if (ContainerDecl* lca = findDeclsLowestCommonAncestor(lLCAChild, rLCAChild)) + { + res = compareThreeWays(lca->getDeclIndex(lLCAChild), lca->getDeclIndex(rLCAChild)); + } + else + { + res = comparePtrs( + lLCAChild->getName(), + rLCAChild->getName(), + [](Name const& lName, Name const& rName) + { return strcmp(lName.text.begin(), rName.text.begin()); }); + } + return res; +} + +int compareVals(Val& lhs, Val& rhs) +{ + int res = compareThreeWays(lhs.astNodeType, rhs.astNodeType); + if (res) + return res; + res = compareValOperands(lhs, rhs, 0); + return res; +} + +int compareTypes(Type* lhs, Type* rhs) +{ + return compareVals(lhs, rhs); +} + +int compareTypes(Type& lhs, Type& rhs) +{ + return compareVals(lhs, rhs); +} + static void _getCanonicalConstraintTypes(List<Type*>& outTypeList, Type* type) { if (auto andType = as<AndType>(type)) @@ -11379,16 +11562,22 @@ OrderedDictionary<GenericTypeParamDeclBase*, List<Type*>> getCanonicalGenericCon for (auto genericTypeConstraintDecl : getMembersOfType<GenericTypeConstraintDecl>(astBuilder, genericDecl)) { - assert( - genericTypeConstraintDecl.getDecl()->sub.type->astNodeType == ASTNodeType::DeclRefType); - auto typeParamDecl = - as<DeclRefType>(genericTypeConstraintDecl.getDecl()->sub.type)->getDeclRef().getDecl(); - List<Type*>* constraintTypes = genericConstraints.tryGetValue(typeParamDecl); - if (!constraintTypes) - continue; - constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type); + if (genericTypeConstraintDecl.getDecl()->sub.type->astNodeType == ASTNodeType::DeclRefType) + { + auto typeParamDecl = as<DeclRefType>(genericTypeConstraintDecl.getDecl()->sub.type) + ->getDeclRef() + .getDecl(); + List<Type*>* constraintTypes = genericConstraints.tryGetValue(typeParamDecl); + if (!constraintTypes) + continue; + constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type); + } + else + { + SLANG_UNEXPECTED("Cannot extract Cannonical Generic Constraints on non DeclRefTypes. " + "Use getCanonicalGenericConstraints2(...) instead."); + } } - OrderedDictionary<GenericTypeParamDeclBase*, List<Type*>> result; for (auto& constraints : genericConstraints) { @@ -11397,8 +11586,57 @@ OrderedDictionary<GenericTypeParamDeclBase*, List<Type*>> getCanonicalGenericCon { _getCanonicalConstraintTypes(typeList, type); } - // TODO: we also need to sort the types within the list for each generic type param. - result[constraints.key] = typeList; + const auto typeComparator = [&](Type* lhs, Type* rhs) + { return compareTypes(*lhs, *rhs) < 0; }; + typeList.sort(typeComparator); + result[constraints.key] = std::move(typeList); + } + return result; +} + + +OrderedDictionary<Type*, List<Type*>> getCanonicalGenericConstraints2( + ASTBuilder* astBuilder, + DeclRef<ContainerDecl> genericDecl) +{ + Dictionary<Type*, HashSet<Type*>> genericConstraints; + for (auto genericTypeConstraintDecl : + getMembersOfType<GenericTypeConstraintDecl>(astBuilder, genericDecl)) + { + auto subExpr = genericTypeConstraintDecl.getDecl()->sub; + auto supExpr = genericTypeConstraintDecl.getDecl()->sup; + Type* typeToAdd = subExpr.type; + if (typeToAdd) + { + if (!genericConstraints.containsKey(typeToAdd)) + { + genericConstraints[typeToAdd] = HashSet<Type*>(); + } + genericConstraints[typeToAdd].add(supExpr.type); + } + } + const auto typeComparator = [&](Type* lhs, Type* rhs) { return compareTypes(lhs, rhs) < 0; }; + const List<Type*> sortedKeys = [&]() + { + List<Type*> res; + res.reserve(genericConstraints.getCount()); + for (auto& t : genericConstraints) + { + res.add(t.first); + } + res.sort(typeComparator); + return res; + }(); + OrderedDictionary<Type*, List<Type*>> result; + for (auto& key : sortedKeys) + { + List<Type*> typeList; + for (auto type : genericConstraints[key]) + { + _getCanonicalConstraintTypes(typeList, type); + } + typeList.sort(typeComparator); + result[key] = std::move(typeList); } return result; } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 7b774f300..2c595dd4a 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -663,8 +663,7 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( auto structDecl = m_astBuilder->create<StructDecl>(); auto conformanceDecl = m_astBuilder->create<InheritanceDecl>(); conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType(); - conformanceDecl->parentDecl = structDecl; - structDecl->members.add(conformanceDecl); + structDecl->addMember(conformanceDecl); structDecl->parentDecl = parent; synthesizedDecl = structDecl; @@ -678,10 +677,9 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef); structDecl->members.add(typeDef); - synthesizedDecl->parentDecl = parent; synthesizedDecl->nameAndLoc.name = item.declRef.getName(); synthesizedDecl->loc = parent->loc; - parent->members.add(synthesizedDecl); + parent->addMember(synthesizedDecl); parent->invalidateMemberDictionary(); // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can @@ -697,7 +695,6 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( // auto typeDef = m_astBuilder->create<TypeAliasDecl>(); typeDef->nameAndLoc.name = item.declRef.getName(); - typeDef->parentDecl = parent; // Compute the decl's type as if it is referred to from itself. This is important // because subType may have substitutions from the context it is used in, while this @@ -708,7 +705,7 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( synthesizedDecl = parent; - parent->members.add(typeDef); + parent->addMember(typeDef); parent->invalidateMemberDictionary(); markSelfDifferentialMembersOfType(parent, subType); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index e0b203fb6..cebcbe540 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -285,8 +285,7 @@ AttributeDecl* SemanticsVisitor::lookUpAttributeDecl(Name* attributeName, Scope* paramDecl->loc = member->loc; paramDecl->setCheckState(DeclCheckState::DefinitionChecked); - paramDecl->parentDecl = attrDecl; - attrDecl->members.add(paramDecl); + attrDecl->addMember(paramDecl); } } @@ -297,8 +296,7 @@ AttributeDecl* SemanticsVisitor::lookUpAttributeDecl(Name* attributeName, Scope* // // TODO: handle the case where `parentDecl` is generic? // - attrDecl->parentDecl = parentDecl; - parentDecl->members.add(attrDecl); + parentDecl->addMember(attrDecl); SLANG_ASSERT(!parentDecl->isMemberDictionaryValid()); diff --git a/source/slang/slang-check.h b/source/slang/slang-check.h index f1392e9ce..f4d86cff9 100644 --- a/source/slang/slang-check.h +++ b/source/slang/slang-check.h @@ -29,4 +29,7 @@ Type* unwrapModifiedType(Type* type); OrderedDictionary<GenericTypeParamDeclBase*, List<Type*>> getCanonicalGenericConstraints( ASTBuilder* builder, DeclRef<ContainerDecl> genericDecl); +OrderedDictionary<Type*, List<Type*>> getCanonicalGenericConstraints2( + ASTBuilder* builder, + DeclRef<ContainerDecl> genericDecl); } // namespace Slang diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index f5878cb1d..aa30eef9d 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -576,14 +576,23 @@ void emitQualifiedName(ManglingContext* context, DeclRef<Decl> declRef, bool inc } auto canonicalizedConstraints = - getCanonicalGenericConstraints(context->astBuilder, parentGenericDeclRef); + getCanonicalGenericConstraints2(context->astBuilder, parentGenericDeclRef); for (auto& constraint : canonicalizedConstraints) { - for (auto type : constraint.value) + if (constraint.value.getCount() > 0) { emitRaw(context, "C"); - emitQualifiedName(context, makeDeclRef(constraint.key), true); - emitType(context, type); + emitType(context, constraint.key); + int counter = 0; + for (auto type : constraint.value) + { + if (counter > 0) + { + emitRaw(context, "_"); + } + ++counter; + emitType(context, type); + } } } } |
