diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-decl.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ast-iterator.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 34 | ||||
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-check-conversion.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 63 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-check-inheritance.cpp | 157 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-language-server-ast-lookup.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 30 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 166 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 3 |
17 files changed, 439 insertions, 102 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 4ed38ef69..36aa3a313 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -583,6 +583,11 @@ class GenericTypeConstraintDecl : public TypeConstraintDecl TypeExp sub; TypeExp sup; + // If this decl is defined in a where clause, store the source location of the where token. + SourceLoc whereTokenLoc = SourceLoc(); + + bool isEqualityConstraint = false; + // Overrides should be public so base classes can access const TypeExp& _getSupOverride() const { return sup; } }; diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 40a943dc0..24f98391c 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -477,6 +477,18 @@ void ASTIterator<CallbackFunc, FilterFunc>::visitDecl(DeclBase* decl) } else if (auto typeConstraint = as<TypeConstraintDecl>(decl)) { + if (auto genericTypeConstraint = as<GenericTypeConstraintDecl>(typeConstraint)) + { + // A generic constraint decl has a left hand side and right hand side expression + // for the base and super type of the constraint. + // In the case of a folded-in constraint syntax as in `Foo<T:IBar>`, + // the left hand side of the constraint is represented by the same token + // as the parameter decl itself, so we don't need to traverse into it. + // In the case of `Foo<T> where T:IBar`, the left hand side is its own + // expression so we do want to traverse it. + if (genericTypeConstraint->whereTokenLoc.isValid()) + visitExpr(genericTypeConstraint->sub.exp); + } visitExpr(typeConstraint->getSup().exp); } else if (auto typedefDecl = as<TypeDefDecl>(decl)) diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 21948dc04..83a4cf353 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1510,14 +1510,6 @@ namespace Slang const RequirementDictionary& getRequirementDictionary() { - if (m_requirementDictionary.getCount() != m_requirements.getCount()) - { - for (Index i = m_requirementDictionary.getCount(); i < m_requirements.getCount(); i++) - { - auto& r = m_requirements[i]; - m_requirementDictionary.add(r.key, r.value); - } - } return m_requirementDictionary; } @@ -1532,11 +1524,8 @@ namespace Slang // Whether or not this witness table is an extern declaration. bool isExtern = false; - // Satisfying values of each requirement. - List<KeyValuePair<Decl*, RequirementWitness>> m_requirements; - // Cached dictionary for looking up satisfying values. - SLANG_UNREFLECTED RequirementDictionary m_requirementDictionary; + RequirementDictionary m_requirementDictionary; RefPtr<WitnessTable> specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 3a1318696..401d73e29 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -72,7 +72,7 @@ class DeclRefType : public Type }; template<typename T> -DeclRef<T> isDeclRefTypeOf(Type* type) +DeclRef<T> isDeclRefTypeOf(Val* type) { if (auto declRefType = as<DeclRefType>(type)) { diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 2599ce46a..8752195cb 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -626,6 +626,13 @@ class DeclaredSubtypeWitness : public SubtypeWitness return as<DeclRefBase>(getOperand(2)); } + bool isEquality() + { + if (auto declRef = getDeclRef().as<GenericTypeConstraintDecl>()) + return declRef.getDecl()->isEqualityConstraint; + return false; + } + // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); @@ -895,4 +902,31 @@ void SubstitutionSet::forEachSubstitutionArg(F func) const } } } + +inline bool isTypeEqualityWitness(Val* witness) +{ + if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness)) + { + return declaredWitness->isEquality(); + } + else if (as<TypeEqualityWitness>(witness)) + { + return true; + } + else if (auto eachWitness = as<EachSubtypeWitness>(witness)) + { + return isTypeEqualityWitness(eachWitness->getPatternTypeWitness()); + } + else if (auto typePackWitness = as<TypePackSubtypeWitness>(witness)) + { + for (Index i = 0; i < typePackWitness->getCount(); i++) + { + if (!isTypeEqualityWitness(typePackWitness->getWitness(i))) + return false; + } + return true; + } + return false; +} + } // namespace Slang diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 3e3ed5297..90b0e44f5 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -648,7 +648,7 @@ namespace Slang else if (auto subEachType = as<EachType>(constraintDeclRef.getDecl()->sub.type)) constrainedGenericParams.add(as<DeclRefType>(subEachType->getElementType())->getDeclRef().getDecl()); - if (sub->equals(sup)) + if (sub->equals(sup) && isDeclRefTypeOf<InterfaceDecl>(sup)) { // We are trying to use an interface type itself to conform to the // type constraint. We can reach this case when the user code does @@ -674,6 +674,14 @@ namespace Slang sup, system->additionalSubtypeWitnesses ? IsSubTypeOptions::NoCaching : IsSubTypeOptions::None); } + + if (constraintDecl->isEqualityConstraint) + { + // If constraint is an equality constraint, we need to make sure + // the witness is equality witness. + if (!isTypeEqualityWitness(subTypeWitness)) + subTypeWitness = nullptr; + } if(subTypeWitness) { diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 85040dc55..c0d7feaff 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1019,6 +1019,22 @@ namespace Slang *outCost = kConversionCost_CastToInterface; return true; } + else if (auto fromIsToWitness = tryGetSubtypeWitness(toType, fromType)) + { + // Is toType and fromType the same via some type equality witness? + // If so there is no need to do any conversion. + // + if (isTypeEqualityWitness(fromIsToWitness)) + { + if (outToExpr) + { + *outToExpr = createCastToSuperTypeExpr(toType, fromExpr, fromIsToWitness); + } + if (outCost) + *outCost = 0; + return true; + } + } // Disallow converting to a ParameterGroupType. // diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index ad3f94fc3..02e3241a9 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2278,14 +2278,18 @@ namespace Slang markSelfDifferentialMembersOfType(as<AggTypeDecl>(context->parentDecl), context->conformingType); + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(context->conformingType)); if (doesTypeSatisfyAssociatedTypeConstraintRequirement(context->conformingType, requirementDeclRef, witnessTable)) { - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(context->conformingType)); // Increase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types. m_astBuilder->incrementEpoch(); return true; } + else + { + witnessTable->m_requirementDictionary.remove(requirementDeclRef.getDecl()); + } // Something went wrong. return false; @@ -2471,19 +2475,16 @@ namespace Slang // conformance on the synthesized decl. checkAggTypeConformance(aggTypeDecl); - if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable)) + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); + if (!doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable)) { - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); - - // Incrase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types. - m_astBuilder->incrementEpoch(); - return true; + // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always succeed. + // If not, there is something wrong with the code synthesis logic. For now we just return false + // instead of crashing so the user can work around the issues. + witnessTable->m_requirementDictionary.remove(requirementDeclRef.getDecl()); + return false; } - - // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always succeed. - // If not, there is something wrong with the code synthesis logic. For now we just return false - // instead of crashing so the user can work around the issues. - return false; + return true; } void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) @@ -2497,7 +2498,7 @@ namespace Slang CheckConstraintSubType(decl->sub); decl->sub = TranslateTypeNodeForced(decl->sub); decl->sup = TranslateTypeNodeForced(decl->sup); - if (!isValidGenericConstraintType(decl->sup) && !as<ErrorType>(decl->sub.type)) + if (!decl->isEqualityConstraint && !isValidGenericConstraintType(decl->sup) && !as<ErrorType>(decl->sub.type)) { getSink()->diagnose(decl->sup.exp, Diagnostics::invalidTypeForConstraint, decl->sup); } @@ -3548,18 +3549,28 @@ namespace Slang bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeConstraintRequirement(Type* satisfyingType, DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, RefPtr<WitnessTable> witnessTable) { + SLANG_UNUSED(satisfyingType); + // We will enumerate the type constraints placed on the // associated type and see if they can be satisfied. // bool conformance = true; Val* witness = nullptr; - for (auto requiredConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(m_astBuilder, requiredAssociatedTypeDeclRef)) + for (auto requiredConstraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(m_astBuilder, requiredAssociatedTypeDeclRef)) { // Grab the type we expect to conform to from the constraint. auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef); + auto subType = getSub(m_astBuilder, requiredConstraintDeclRef); + // Perform a search for a witness to the subtype relationship. - witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); + witness = tryGetSubtypeWitness(subType, requiredSuperType); + if (witness) + { + auto genConstraint = as<GenericTypeConstraintDecl>(requiredConstraintDeclRef.getDecl()); + if (genConstraint && genConstraint->isEqualityConstraint && !isTypeEqualityWitness(witness)) + witness = nullptr; + } if (witness) { // If a subtype witness was found, then the conformance @@ -3588,6 +3599,15 @@ namespace Slang if (declRefType->getDeclRef().getDecl()->hasModifier<ToBeSynthesizedModifier>()) return false; } + + // Register the satisfying type to the witness table + // before checking the constraints, since the subtype of + // the constraints maybe referencing the satisfying type via + // witness lookups. + auto requirementWitness = RequirementWitness(satisfyingType->getCanonicalType()); + witnessTable->m_requirementDictionary[requiredAssociatedTypeDeclRef.getDecl()] + = requirementWitness; + // We need to confirm that the chosen type `satisfyingType`, // meets all the constraints placed on the associated type // requirement `requiredAssociatedTypeDeclRef`. @@ -3601,15 +3621,11 @@ namespace Slang // TODO: if any conformance check failed, we should probably include // that in an error message produced about not satisfying the requirement. - if(conformance) + if (!conformance) { - // If all the constraints were satisfied, then the chosen - // type can indeed satisfy the interface requirement. - witnessTable->add( - requiredAssociatedTypeDeclRef.getDecl(), - RequirementWitness(satisfyingType->getCanonicalType())); + witnessTable->m_requirementDictionary.remove(requiredAssociatedTypeDeclRef.getDecl()); } - + return conformance; } @@ -9667,7 +9683,8 @@ namespace Slang ASTNodeType::DeclRefType); auto typeParamDecl = as<DeclRefType>(genericTypeConstraintDecl.getDecl()->sub.type)->getDeclRef().getDecl(); List<Type*>* constraintTypes = genericConstraints.tryGetValue(typeParamDecl); - assert(constraintTypes); + if (!constraintTypes) + continue; constraintTypes->add(genericTypeConstraintDecl.getDecl()->getSup().type); } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index f39d5be1d..dc4568f8a 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -766,6 +766,12 @@ namespace Slang InheritanceInfo _calcInheritanceInfo(Type* type, InheritanceCircularityInfo* circularityInfo); InheritanceInfo _calcInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* correspondingType, InheritanceCircularityInfo* circularityInfo); + // Get the inner most generic decl that a decl-ref is dependent on. + // For example, `Foo<T>` depends on the generic decl that defines `T`. + // + DeclRef<GenericDecl> getDependentGenericParent(DeclRef<Decl> declRef); + void getDependentGenericParentImpl(DeclRef<GenericDecl>& genericParent, DeclRef<Decl> declRef); + struct DirectBaseInfo { FacetList facets; diff --git a/source/slang/slang-check-inheritance.cpp b/source/slang/slang-check-inheritance.cpp index 20f41c1bb..0d3929901 100644 --- a/source/slang/slang-check-inheritance.cpp +++ b/source/slang/slang-check-inheritance.cpp @@ -97,6 +97,51 @@ namespace Slang return info; } + void SharedSemanticsContext::getDependentGenericParentImpl(DeclRef<GenericDecl>& genericParent, DeclRef<Decl> declRef) + { + auto mergeParent = [](DeclRef<GenericDecl>& currentParent, DeclRef<GenericDecl> newParent) + { + if (!currentParent) + { + currentParent = newParent; + return; + } + if (currentParent == newParent) + return; + if (newParent.getDecl()->isChildOf(currentParent.getDecl())) + currentParent = newParent; + }; + + if (declRef.as<GenericTypeParamDeclBase>()) + { + if (!genericParent) + mergeParent(genericParent, declRef.getParent().as<GenericDecl>()); + return; + } + else if (auto lookupDeclRef = as<LookupDeclRef>(declRef.declRefBase)) + { + if (auto lookupSourceDeclRef = isDeclRefTypeOf<Decl>(lookupDeclRef->getLookupSource())) + getDependentGenericParentImpl(genericParent, lookupSourceDeclRef); + } + else if (auto genericAppDeclRef = as<GenericAppDeclRef>(declRef.declRefBase)) + { + for (Index i = 0; i < genericAppDeclRef->getArgCount(); i++) + { + if (auto argDeclRef = isDeclRefTypeOf<Decl>(genericAppDeclRef->getArg(i))) + { + getDependentGenericParentImpl(genericParent, argDeclRef); + } + } + } + } + + DeclRef<GenericDecl> SharedSemanticsContext::getDependentGenericParent(DeclRef<Decl> declRef) + { + DeclRef<GenericDecl> genericParent; + getDependentGenericParentImpl(genericParent, declRef); + return genericParent; + } + InheritanceInfo SharedSemanticsContext::_calcInheritanceInfo(DeclRef<Decl> declRef, DeclRefType* declRefType, InheritanceCircularityInfo* circularityInfo) { // This method is the main engine for computing linearized inheritance @@ -305,39 +350,82 @@ namespace Slang // We now look at the structure of the declaration itself // to help us enumerate the direct bases. // - if (auto aggTypeDeclBaseRef = declRef.as<AggTypeDeclBase>()) + auto currentDeclRef = declRef; + for (; currentDeclRef;) { - // In the case where we have an aggregate type or `extension` - // declaration, we can use the explicit list of direct bases. - // - for (auto typeConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(_getASTBuilder(), aggTypeDeclBaseRef)) + if (auto aggTypeDeclBaseRef = currentDeclRef.as<AggTypeDeclBase>()) { - visitor.ensureDecl(typeConstraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl); - - // Note: In certain cases something takes the *syntactic* form of an inheritance - // clause, but it is not actually something that should be treated as implying - // a subtype relationship. For example, an `enum` declaration can use what looks - // like an inheritance clause to indicate its underlying "tag type." - // - // We skip such pseudo-inheritance relationships for the purposes of determining - // the linearized list of bases. + // In the case where we have an aggregate type or `extension` + // declaration, we can use the explicit list of direct bases. // - if (typeConstraintDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>()) - continue; + for (auto typeConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(_getASTBuilder(), aggTypeDeclBaseRef)) + { + // Note: In certain cases something takes the *syntactic* form of an inheritance + // clause, but it is not actually something that should be treated as implying + // a subtype relationship. For example, an `enum` declaration can use what looks + // like an inheritance clause to indicate its underlying "tag type." + // + // We skip such pseudo-inheritance relationships for the purposes of determining + // the linearized list of bases. + // + if (typeConstraintDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>()) + continue; - // The base type and subtype witness can easily be determined - // using the `InheritanceDecl`. - // - auto baseType = getSup(astBuilder, typeConstraintDeclRef); - auto satisfyingWitness = astBuilder->getDeclaredSubtypeWitness( - selfType, - baseType, - typeConstraintDeclRef); + // The only case we will ever see a GenericTypeConstraintDecl inside a AggTypeDecl is when + // AggTypeDecl is a associatedtype decl. In this case, we will only lookup the type constraint + // if the constraint is on the associated type itself. + // + auto genericTypeConstraintDeclRef = typeConstraintDeclRef.as<GenericTypeConstraintDecl>(); + if (genericTypeConstraintDeclRef) + { + // If the base expr on the constraint isn't even a `VarExpr`, then it can't be referencing + // the associated type itself and we can skip this constraint. + if (!genericTypeConstraintDeclRef.getDecl()->sub.type + && !as<VarExpr>(genericTypeConstraintDeclRef.getDecl()->sub.exp)) + continue; + } + + visitor.ensureDecl(typeConstraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl); - addDirectBaseType(baseType, satisfyingWitness); + // For generic type constraint decls, always make sure it is about the type being checked. + // + if (genericTypeConstraintDeclRef) + { + auto subType = getSub(astBuilder, genericTypeConstraintDeclRef); + if (subType != selfType) + continue; + } + else if (currentDeclRef != declRef) + { + continue; + } + // The base type and subtype witness can easily be determined + // using the `InheritanceDecl`. + // + auto baseType = getSup(astBuilder, typeConstraintDeclRef); + auto satisfyingWitness = astBuilder->getDeclaredSubtypeWitness( + selfType, + baseType, + typeConstraintDeclRef); + + addDirectBaseType(baseType, satisfyingWitness); + } } + if (currentDeclRef.as<AssocTypeDecl>()) + { + // If the current type is an associated type, continue inspecting the base/parent of the + // associatedtype to discover additional constraints defined on the parent associatedtype decls. + // + if (auto lookupDeclRef = as<LookupDeclRef>(currentDeclRef.declRefBase)) + { + currentDeclRef = isDeclRefTypeOf<Decl>(lookupDeclRef->getLookupSource()).as<AssocTypeDecl>(); + continue; + } + } + break; } - else if (auto genericTypeParamDeclRef = declRef.as<GenericTypeParamDeclBase>()) + + if (auto genericDeclRef = getDependentGenericParent(declRef)) { // The constraints placed on a generic type parameter are siblings of that // parameter in its parent `GenericDecl`, so we need to enumerate all of @@ -349,13 +437,11 @@ namespace Slang // representation would need to take into account canonicalization of // constraints. - auto genericDeclRef = genericTypeParamDeclRef.getParent().as<GenericDecl>(); - SLANG_ASSERT(genericDeclRef); ensureDecl(&visitor, genericDeclRef.getDecl(), DeclCheckState::CanSpecializeGeneric); if (auto extensionDecl = as<ExtensionDecl>(genericDeclRef.getDecl()->inner)) { - if (isDeclRefTypeOf<GenericTypeParamDecl>(extensionDecl->targetType.type) == genericTypeParamDeclRef) + if (isDeclRefTypeOf<GenericTypeParamDecl>(extensionDecl->targetType.type) == declRef) { // If `T` is a generic parameter where the same generic is an extension on `T`, // then we need to add the extension itself as a facet. @@ -377,13 +463,10 @@ namespace Slang auto superType = getSup(astBuilder, constraintDeclRef); // We only consider constraints where the type represented - // by `genericTypeParamDeclRef` is the subtype, since those + // by `declRef` is the subtype, since those // constraints are the ones that give us information about // the declared supertypes. // - // TODO: consider whether other kinds of constraints could - // also apply here. - // auto subDeclRefType = as<DeclRefType>(subType); if (!subDeclRefType) { @@ -394,7 +477,7 @@ namespace Slang if (!subDeclRefType) continue; } - if (subDeclRefType->getDeclRef() != genericTypeParamDeclRef) + if (subDeclRefType->getDeclRef() != declRef) continue; // Because the constraint is a declared inheritance relationship, @@ -402,9 +485,9 @@ namespace Slang // as in all the preceding cases. // auto satisfyingWitness = _getASTBuilder()->getDeclaredSubtypeWitness( - selfType, - superType, - constraintDeclRef); + selfType, + superType, + constraintDeclRef); addDirectBaseType(superType, satisfyingWitness); } } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index f308f340d..5afca48b3 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -291,6 +291,10 @@ INST(IndexedFieldKey, indexedFieldKey, 2, HOISTABLE) // A placeholder witness that ThisType implements the enclosing interface. // Used only in interface definitions. INST(ThisTypeWitness, thisTypeWitness, 1, 0) + +// A placeholder witness for the fact that two types are equal. +INST(TypeEqualityWitness, TypeEqualityWitness, 2, HOISTABLE) + INST(GlobalHashedStringLiterals, global_hashed_string_literals, 0, 0) INST(Module, module, 0, PARENT) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 8f648aabd..fc963697a 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4075,6 +4075,8 @@ public: IRInst* createThisTypeWitness(IRType* interfaceType); + IRInst* getTypeEqualityWitness(IRType* witnessType, IRType* type1, IRType* type2); + IRInterfaceRequirementEntry* createInterfaceRequirementEntry( IRInst* requirementKey, IRInst* requirementVal); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index acc6abd57..d02c01105 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4614,6 +4614,15 @@ namespace Slang return witness; } + IRInst* IRBuilder::getTypeEqualityWitness(IRType* witnessType, IRType* type1, IRType* type2) + { + IRInst* operands[2] = { type1, type2 }; + return (IRType*)createIntrinsicInst( + witnessType, + kIROp_TypeEqualityWitness, + 2, + operands); + } IRStructType* IRBuilder::createStructType() { @@ -8347,6 +8356,7 @@ namespace Slang case kIROp_InterfaceRequirementEntry: case kIROp_Block: case kIROp_Each: + case kIROp_TypeEqualityWitness: return false; /// Liveness markers have no side effects diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 2d6ed2568..06b7c937f 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -717,6 +717,14 @@ bool _findAstNodeImpl(ASTLookupContext& context, SyntaxNode* node) ASTLookupExprVisitor visitor(&context); if (visitor.dispatchIfNotNull(typeConstraint->getSup().exp)) return true; + if (auto genTypeConstraint = as<GenericTypeConstraintDecl>(node)) + { + if (genTypeConstraint->whereTokenLoc.isValid()) + { + if (visitor.dispatchIfNotNull(genTypeConstraint->sub.exp)) + return true; + } + } } else if (auto typedefDecl = as<TypeDefDecl>(node)) { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index f62bb631e..813467743 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1753,6 +1753,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower lowerType(context, val->getSup()))); } + LoweredValInfo visitTypeEqualityWitness(TypeEqualityWitness* val) + { + auto subType = lowerType(context, val->getSub()); + auto supType = lowerType(context, val->getSup()); + auto witnessType = context->irBuilder->getWitnessTableType( + lowerType(context, val->getSup())); + return LoweredValInfo::simple(context->irBuilder->getTypeEqualityWitness(witnessType, subType, supType)); + } + LoweredValInfo visitTransitiveSubtypeWitness( TransitiveSubtypeWitness* val) { @@ -4989,6 +4998,19 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> auto superType = lowerType(context, expr->type); auto value = lowerRValueExpr(context, expr->valueArg); + // First, we check if the witness is a type equality witness. + // If so, we can simply emit a bit cast to the target type that should eventually + // fold out to a no-op. + // Note: if we are going to equivalent but not identical types in the future, + // then the cast between equivalent types shouldn't be as simple as a bit cast + // and will require actual coercion logic between the two types. + // For now, we don't support type equivalence witness so this is safe for + // equal types. + if (isTypeEqualityWitness(expr->witnessArg)) + { + return LoweredValInfo::simple(getBuilder()->emitBitCast(superType, getSimpleVal(context, value))); + } + // The actual operation that we need to perform here // depends on the kind of subtype relationship we // are making use of. @@ -5039,7 +5061,6 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> return emitCastToConcreteSuperTypeRec(value, superType, expr->witnessArg); } } - SLANG_UNEXPECTED("unexpected case of subtype relationship"); UNREACHABLE_RETURN(LoweredValInfo()); } @@ -8400,8 +8421,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> for (auto constraintDecl : decl->getMembersOfType<GenericTypeConstraintDecl>()) { auto baseType = lowerType(context, constraintDecl->sup.type); - SLANG_ASSERT(baseType && baseType->getOp() == kIROp_InterfaceType); - constraintInterfaces.add((IRInterfaceType*)baseType); + if (baseType && baseType->getOp() == kIROp_InterfaceType) + constraintInterfaces.add((IRInterfaceType*)baseType); } auto assocType = context->irBuilder->getAssociatedType( constraintInterfaces.getArrayView().arrayView); @@ -10686,7 +10707,8 @@ LoweredValInfo emitDeclRef( for (auto argVal : genericSubst->getArgs()) { auto irArgVal = lowerSimpleVal(context, argVal); - SLANG_ASSERT(irArgVal); + if (!irArgVal) + continue; // It is possible that some of the arguments to the generic // represent conformances to conjunction types like `A & B`. diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 121c68be7..5729adb29 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -148,6 +148,12 @@ namespace Slang resetLookupScope(); } + void PushScope(Scope* newScope) + { + currentScope = newScope; + resetLookupScope(); + } + void pushScopeAndSetParent(ContainerDecl* containerDecl) { containerDecl->parentDecl = currentScope->containerDecl; @@ -306,6 +312,8 @@ namespace Slang static Expr* _parseGenericArg(Parser* parser); + static Expr* parsePrefixExpr(Parser* parser); + // static void Unexpected( @@ -721,6 +729,16 @@ namespace Slang return false; } + bool AdvanceIf(Parser* parser, char const* text, Token* outToken) + { + if (parser->LookAheadToken(text)) + { + *outToken = parser->ReadToken(); + return true; + } + return false; + } + /// Information on how to parse certain pairs of matches tokens struct MatchedTokenInfo { @@ -1496,10 +1514,47 @@ namespace Slang } else { - // default case is a type parameter - paramDecl = parser->astBuilder->create<GenericTypeParamDecl>(); - parser->FillPosition(paramDecl); - paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); + // Disambiguate between a type parameter and a value parameter. + // If next token is "typename", then it is a type parameter. + bool isTypeParam = AdvanceIf(parser, "typename"); + if (!isTypeParam) + { + // Otherwise, if the next token is an identifier, followed by a colon, comma, '=' or '>', then it is a type parameter. + isTypeParam = parser->LookAheadToken(TokenType::Identifier); + auto nextNextTokenType = peekTokenType(parser, 1); + switch (nextNextTokenType) + { + case TokenType::Colon: + case TokenType::Comma: + case TokenType::OpGreater: + case TokenType::OpAssign: + break; + default: + isTypeParam = false; + break; + } + } + + if (isTypeParam) + { + // Parse as a type parameter. + paramDecl = parser->astBuilder->create<GenericTypeParamDecl>(); + parser->FillPosition(paramDecl); + paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); + } + else + { + // Parse as a traditional syntax value parameter in the form of `type paramName`. + auto valueParamDecl = parser->astBuilder->create<GenericValueParamDecl>(); + parser->FillPosition(valueParamDecl); + valueParamDecl->type = parser->ParseTypeExp(); + valueParamDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); + if (AdvanceIf(parser, TokenType::OpAssign)) + { + valueParamDecl->initExpr = parser->ParseInitExpr(); + } + return valueParamDecl; + } } if (AdvanceIf(parser, TokenType::Colon)) { @@ -1600,7 +1655,43 @@ namespace Slang } else { - return parseInner(nullptr); + auto genericParent = parser->currentScope ? as<GenericDecl>(parser->currentScope->containerDecl) : nullptr; + return parseInner(genericParent); + } + } + + static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericParent) + { + if (!genericParent) + return; + Token whereToken; + while (AdvanceIf(parser, "where", &whereToken)) + { + auto subType = parser->ParseTypeExp(); + if (AdvanceIf(parser, TokenType::Colon)) + { + for (;;) + { + auto constraint = parser->astBuilder->create<GenericTypeConstraintDecl>(); + constraint->whereTokenLoc = whereToken.loc; + parser->FillPosition(constraint); + constraint->sub = subType; + constraint->sup = parser->ParseTypeExp(); + AddMember(genericParent, constraint); + if (!AdvanceIf(parser, TokenType::Comma)) + break; + } + } + else if (AdvanceIf(parser, TokenType::OpEql)) + { + auto constraint = parser->astBuilder->create<GenericTypeConstraintDecl>(); + constraint->whereTokenLoc = whereToken.loc; + constraint->isEqualityConstraint = true; + parser->FillPosition(constraint); + constraint->sub = subType; + constraint->sup = parser->ParseTypeExp(); + AddMember(genericParent, constraint); + } } } @@ -1702,7 +1793,7 @@ namespace Slang decl->loc = declaratorInfo.nameAndLoc.loc; decl->nameAndLoc = declaratorInfo.nameAndLoc; - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { // HACK: The return type of the function will already have been // parsed in a scope that didn't include the function's generic @@ -1731,6 +1822,12 @@ namespace Slang } _parseOptSemantics(parser, decl); + + auto funcScope = parser->currentScope; + parser->PopScope(); + maybeParseGenericConstraints(parser, genericParent); + parser->PushScope(funcScope); + decl->body = parseOptBody(parser); if (auto block = as<BlockStmt>(decl->body)) { @@ -2597,7 +2694,7 @@ namespace Slang } else if (parser->LookAheadToken("expand") || parser->LookAheadToken("each")) { - typeSpec.expr = parser->ParseExpression(); + typeSpec.expr = parsePrefixExpr(parser); return typeSpec; } // Uncomment should we decide to enable (a,b,c) tuple types @@ -3335,12 +3432,13 @@ namespace Slang static NodeBase* parseExtensionDecl(Parser* parser, void* /*userData*/) { - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { ExtensionDecl* decl = parser->astBuilder->create<ExtensionDecl>(); parser->FillPosition(decl); decl->targetType = parser->ParseTypeExp(); parseOptionalInheritanceClause(parser, decl); + maybeParseGenericConstraints(parser, genericParent); parseDeclBody(parser, decl); return decl; }); @@ -3357,16 +3455,27 @@ namespace Slang parser->FillPosition(paramConstraint); // substitution needs to be filled during check - Type* paramType = DeclRefType::create(parser->astBuilder, DeclRef<Decl>(decl)); + Type* paramType = nullptr; + if (as<GenericTypeParamDeclBase>(decl)) + { + paramType = DeclRefType::create(parser->astBuilder, DeclRef<Decl>(decl)); - SharedTypeExpr* paramTypeExpr = parser->astBuilder->create<SharedTypeExpr>(); - paramTypeExpr->loc = decl->loc; - paramTypeExpr->base.type = paramType; - paramTypeExpr->type = QualType(parser->astBuilder->getTypeType(paramType)); + SharedTypeExpr* paramTypeExpr = parser->astBuilder->create<SharedTypeExpr>(); + paramTypeExpr->loc = decl->loc; + paramTypeExpr->base.type = paramType; + paramTypeExpr->type = QualType(parser->astBuilder->getTypeType(paramType)); - paramConstraint->sub = TypeExp(paramTypeExpr); - paramConstraint->sup = parser->ParseTypeExp(); + paramConstraint->sub = TypeExp(paramTypeExpr); + } + else if (as<AssocTypeDecl>(decl)) + { + auto varExpr = parser->astBuilder->create<VarExpr>(); + varExpr->scope = parser->currentScope; + varExpr->name = decl->getName(); + paramConstraint->sub.exp = varExpr; + } + paramConstraint->sup = parser->ParseTypeExp(); AddMember(decl, paramConstraint); } while (AdvanceIf(parser, TokenType::Comma)); } @@ -3380,6 +3489,7 @@ namespace Slang assocTypeDecl->nameAndLoc = NameLoc(nameToken); assocTypeDecl->loc = nameToken.loc; parseOptionalGenericConstraints(parser, assocTypeDecl); + maybeParseGenericConstraints(parser, assocTypeDecl); parser->ReadToken(TokenType::Semicolon); return assocTypeDecl; } @@ -3424,11 +3534,12 @@ namespace Slang AdvanceIf(parser, TokenType::CompletionRequest); decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { // We allow for an inheritance clause on a `struct` // so that it can conform to interfaces. parseOptionalInheritanceClause(parser, decl); + maybeParseGenericConstraints(parser, genericParent); parseDeclBody(parser, decl); return decl; }); @@ -3661,7 +3772,7 @@ namespace Slang { ConstructorDecl* decl = parser->astBuilder->create<ConstructorDecl>(); - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { // Note: we leave the source location of this decl as invalid, to // trigger the fallback logic that fills in the location of the @@ -3680,6 +3791,10 @@ namespace Slang decl->nameAndLoc.name = getName(parser, "$init"); parseParameterList(parser, decl); + auto funcScope = parser->currentScope; + parser->PopScope(); + maybeParseGenericConstraints(parser, genericParent); + parser->PushScope(funcScope); decl->body = parseOptBody(parser); @@ -3997,7 +4112,7 @@ namespace Slang parser->FillPosition(decl); decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { parser->PushScope(decl); parseModernParamList(parser, decl); @@ -4009,6 +4124,10 @@ namespace Slang { decl->returnType = parser->ParseTypeExp(); } + auto funcScope = parser->currentScope; + parser->PopScope(); + maybeParseGenericConstraints(parser, genericParent); + parser->PushScope(funcScope); decl->body = parseOptBody(parser); if (auto blockStmt = as<BlockStmt>(decl->body)) decl->closingSourceLoc = blockStmt->closingSourceLoc; @@ -4025,8 +4144,9 @@ namespace Slang parser->FillPosition(decl); decl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { + maybeParseGenericConstraints(parser, genericParent); if( expect(parser, TokenType::OpAssign) ) { decl->type = parser->ParseTypeExp(); @@ -4860,7 +4980,7 @@ namespace Slang rs->nameAndLoc.name = generateName(this); rs->nameAndLoc.loc = rs->loc; } - return parseOptGenericDecl(this, [&](GenericDecl*) + return parseOptGenericDecl(this, [&](GenericDecl* genericParent) { // We allow for an inheritance clause on a `struct` // so that it can conform to interfaces. @@ -4878,6 +4998,7 @@ namespace Slang rs->hasBody = false; return rs; } + maybeParseGenericConstraints(this, genericParent); parseDeclBody(this, rs); return rs; }); @@ -4977,9 +5098,10 @@ namespace Slang addModifier(decl, parser->astBuilder->create<TransparentModifier>()); } - return parseOptGenericDecl(parser, [&](GenericDecl*) + return parseOptGenericDecl(parser, [&](GenericDecl* genericParent) { parseOptionalInheritanceClause(parser, decl); + maybeParseGenericConstraints(parser, genericParent); parser->ReadToken(TokenType::LBrace); Token closingToken; parser->pushScopeAndSetParent(decl); @@ -7583,7 +7705,7 @@ namespace Slang { EachExpr* eachExpr = parser->astBuilder->create<EachExpr>(); eachExpr->loc = loc; - eachExpr->baseExpr = parser->ParseLeafExpression(); + eachExpr->baseExpr = parsePostfixExpr(parser); return eachExpr; } diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index a55f0eb1a..ec7169e04 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -338,7 +338,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt RefPtr<WitnessTable> result = new WitnessTable(); result->baseType = as<Type>(newBaseType); result->witnessedType = as<Type>(newWitnessedType); - for (auto requirement : m_requirements) + for (auto requirement : m_requirementDictionary) { auto newRequirement = requirement.value.specialize(astBuilder, subst); result->add(requirement.key, newRequirement); @@ -500,7 +500,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt void WitnessTable::add(Decl* decl, RequirementWitness const& witness) { - m_requirements.add(KeyValuePair<Decl*, RequirementWitness>(decl, witness)); m_requirementDictionary.add(decl, witness); } |
