diff options
| author | Yong He <yonghe@outlook.com> | 2024-08-20 20:51:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-20 20:51:57 -0700 |
| commit | f9f6a28df40f418ddd0c8ff3b9cacccdb085e202 (patch) | |
| tree | a6bafa63cee4f9bbcfe496de54af6e5727bb021e /source/slang/slang-check-constraint.cpp | |
| parent | 03e1e17745920c8e3a7b6f4e3b1e64062589604a (diff) | |
Support dependent generic constraints. (#4870)
* Support dependent generic constraints.
* Fix warning.
* Update comment.
* Fix.
* Add a test case to verify fix of #3804.
* Address review.
Diffstat (limited to 'source/slang/slang-check-constraint.cpp')
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 294 |
1 files changed, 185 insertions, 109 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 0f6da156d..afcde8a5b 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -57,6 +57,7 @@ namespace Slang { Type* SemanticsVisitor::TryJoinVectorAndScalarType( + ConstraintSystem* constraints, VectorExpressionType* vectorType, BasicExpressionType* scalarType) { @@ -65,6 +66,7 @@ namespace Slang // That is, the join of a vector and a scalar type is // a vector type with a joined element type. auto joinElementType = TryJoinTypes( + constraints, vectorType->getElementType(), scalarType); if(!joinElementType) @@ -76,6 +78,7 @@ namespace Slang } Type* SemanticsVisitor::_tryJoinTypeWithInterface( + ConstraintSystem* constraints, Type* type, Type* interfaceType) { @@ -158,6 +161,31 @@ namespace Slang return bestType; } + // If `interfaceType` represents some generic interface type, such as `IFoo<T>`, and `type` conforms to + // some `IFoo<X>`, then we should attempt to unify the them to discover constraints for + // `T`. + if (auto interfaceDeclRef = isDeclRefTypeOf<InterfaceDecl>(interfaceType)) + { + if (as<GenericAppDeclRef>(interfaceDeclRef.declRefBase)) + { + auto inheritanceInfo = getShared()->getInheritanceInfo(type); + for (auto facet : inheritanceInfo.facets) + { + if (facet->origin.declRef.getDecl() == interfaceDeclRef.getDecl()) + { + auto unificationResult = TryUnifyTypes( + *constraints, + ValUnificationContext(), + QualType(facet->getType()), + interfaceType); + + if (unificationResult) + return type; + } + } + } + } + // For all other cases, we will just bail out for now. // // TODO: In the future we should build some kind of side data structure @@ -174,6 +202,7 @@ namespace Slang } Type* SemanticsVisitor::TryJoinTypes( + ConstraintSystem* constraints, QualType left, QualType right) { @@ -201,7 +230,7 @@ namespace Slang // We can also join a vector and a scalar if(auto rightVector = as<VectorExpressionType>(right)) { - return TryJoinVectorAndScalarType(rightVector, leftBasic); + return TryJoinVectorAndScalarType(constraints, rightVector, leftBasic); } } @@ -217,6 +246,7 @@ namespace Slang // Try to join the element types auto joinElementType = TryJoinTypes( + constraints, QualType(leftVector->getElementType(), left.isLeftValue), QualType(rightVector->getElementType(), right.isLeftValue)); if(!joinElementType) @@ -230,7 +260,7 @@ namespace Slang // We can also join a vector and a scalar if(auto rightBasic = as<BasicExpressionType>(right)) { - return TryJoinVectorAndScalarType(leftVector, rightBasic); + return TryJoinVectorAndScalarType(constraints, leftVector, rightBasic); } } @@ -240,7 +270,7 @@ namespace Slang if( auto leftInterfaceRef = leftDeclRefType->getDeclRef().as<InterfaceDecl>() ) { // - return _tryJoinTypeWithInterface(right, left); + return _tryJoinTypeWithInterface(constraints, right, left); } } if(auto rightDeclRefType = as<DeclRefType>(right)) @@ -248,7 +278,7 @@ namespace Slang if( auto rightInterfaceRef = rightDeclRefType->getDeclRef().as<InterfaceDecl>() ) { // - return _tryJoinTypeWithInterface(left, right); + return _tryJoinTypeWithInterface(constraints, left, right); } } @@ -263,6 +293,7 @@ namespace Slang for (Index i = 0; i < leftTypePack->getTypeCount(); ++i) { auto joinedType = TryJoinTypes( + constraints, QualType(leftTypePack->getElementType(i), left.isLeftValue), QualType(rightTypePack->getElementType(i), right.isLeftValue)); if(!joinedType) @@ -285,6 +316,8 @@ namespace Slang ArrayView<Val*> knownGenericArgs, ConversionCost& outBaseCost) { + ensureDecl(genericDeclRef.getDecl(), DeclCheckState::ReadyForLookup); + outBaseCost = kConversionCost_None; // For now the "solver" is going to be ridiculously simplistic. @@ -310,7 +343,7 @@ namespace Slang return DeclRef<Decl>(); } - // Once have built up the full list of constraints we are trying to satisfy, + // Once have built up the initial list of constraints we are trying to satisfy, // we will attempt to solve for each parameter in a way that satisfies all // the constraints that apply to that parameter. // @@ -321,7 +354,7 @@ namespace Slang // solution for how to assign the parameters in a way that satisfies all // the constraints. // - List<Val*> args; + ShortList<Val*> args; // If the context is such that some of the arguments are already specified // or known, we need to go ahead and use those arguments direclty (whether @@ -337,38 +370,44 @@ namespace Slang } } - // We will then iterate over the explicit parameters of the generic - // and try to solve for each. - // - Count paramCounter = 0; - for (auto m : getMembers(m_astBuilder, genericDeclRef)) + // The state of currently solved arguments. + struct SolvedArg + { + IntVal* val = nullptr; + bool isOptional = true; + ShortList<QualType, 8> types; + }; + ShortList<SolvedArg> solvedArgs; + + // We will then iterate over the constraints trying to solve all generic parameters. + // Note that we do not use ranged for here, because processing one constraint may lead to + // new constraints being discovered. + for (Index constraintIndex = 0; constraintIndex < system->constraints.getCount(); constraintIndex++) { - if (auto typeParam = m.as<GenericTypeParamDeclBase>()) + // Note: it is important to keep a copy of the constraint here instead of + // using a reference, because the constraint list may be modified during the + // loop as we discover new constraints. + // + auto c = system->constraints[constraintIndex]; + if (auto typeParam = as<GenericTypeParamDeclBase>(c.decl)) { - // If the parameter is a type pack, then we may have - // constraints that apply to invidual elements of the pack. - // We will need to handle the type pack case slightly differently. - // - bool isPack = as<GenericTypePackParamDecl>(typeParam) != nullptr; - + SLANG_ASSERT(typeParam->parameterIndex != -1); // If the parameter is one where we already know // the argument value to use, we don't bother with // trying to solve for it, and treat any constraints // on such a parameter as implicitly solved-for. // - Index paramIndex = paramCounter++; - if (paramIndex < knownGenericArgCount) + if (typeParam->parameterIndex < knownGenericArgCount) { - for (auto& c : system->constraints) - { - if (c.decl != typeParam.getDecl()) - continue; - - c.satisfied = true; - } + system->constraints[constraintIndex].satisfied = true; continue; } + // If the parameter is a type pack, then we may have + // constraints that apply to invidual elements of the pack. + // We will need to handle the type pack case slightly differently. + // + bool isPack = as<GenericTypePackParamDecl>(typeParam) != nullptr; // We will use a temporary list to hold the resolved types // for this generic parameter. @@ -376,50 +415,128 @@ namespace Slang // in the list. For type pack parameters, there can be one type // for each element in the pack. // - ShortList<QualType> types; + if (solvedArgs.getCount() <= typeParam->parameterIndex) + { + solvedArgs.setCount(typeParam->parameterIndex + 1); + } + auto& types = solvedArgs[typeParam->parameterIndex].types; if (!isPack) types.setCount(1); - bool typeConstraintOptional = true; + bool& typeConstraintOptional = solvedArgs[typeParam->parameterIndex].isOptional; - for (auto& c : system->constraints) + QualType* ptype = nullptr; + if (isPack) { - if (c.decl != typeParam.getDecl()) - continue; - QualType* ptype = nullptr; - if (isPack) - { - types.setCount(Math::Max(types.getCount(), c.indexInPack + 1)); - ptype = &types[c.indexInPack]; - } - else - ptype = &types[0]; - QualType& type = *ptype; + types.setCount(Math::Max(types.getCount(), c.indexInPack + 1)); + ptype = &types[c.indexInPack]; + } + else + ptype = &types[0]; + QualType& type = *ptype; - auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue); - SLANG_RELEASE_ASSERT(cType); + auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue); + SLANG_RELEASE_ASSERT(cType); - if (!type || (typeConstraintOptional && !c.isOptional)) + if (!type || (typeConstraintOptional && !c.isOptional)) + { + type = cType; + typeConstraintOptional = c.isOptional; + } + else if (!typeConstraintOptional) + { + // If the type parameter is already constrained to a known type, + // we need to make sure our resolved type can satisfy both constraints. + // We do so by updating the resolved type to be the "join" of the current + // solution and the type in the new constraint. If such join cannot be found, + // it means it is not possible to have a compatible solution that meets all + // constraints and we should fail. + // + // Another detail here is that during type joining, we may discover + // new constraints from the base types of the types being joined. + // We will pass the constraint system to `TryJoinTypes` which can + // add new constraints to the system, and we will process the new constraints + // in the next iteration. + // + auto joinType = TryJoinTypes(system, type, cType); + if (!joinType) { - type = cType; - typeConstraintOptional = c.isOptional; + // failure! + return DeclRef<Decl>(); } - else if (!typeConstraintOptional) + type = QualType(joinType, type.isLeftValue || cType.isLeftValue); + } + + c.satisfied = true; + } + else if (auto valParam = as<GenericValueParamDecl>(c.decl)) + { + SLANG_ASSERT(valParam->parameterIndex != -1); + + // If the parameter is one where we already know + // the argument value to use, we don't bother with + // trying to solve for it, and treat any constraints + // on such a parameter as implicitly solved-for. + // + if (valParam->parameterIndex < knownGenericArgCount) + { + system->constraints[constraintIndex].satisfied = true; + continue; + } + + if (solvedArgs.getCount() <= valParam->parameterIndex) + solvedArgs.setCount(valParam->parameterIndex + 1); + IntVal*& val = solvedArgs[valParam->parameterIndex].val; + bool& valOptional = solvedArgs[valParam->parameterIndex].isOptional; + + auto cVal = as<IntVal>(c.val); + SLANG_RELEASE_ASSERT(cVal); + + if (!val || (valOptional && !c.isOptional)) + { + val = cVal; + valOptional = c.isOptional; + } + else + { + if(!valOptional && !val->equals(cVal)) { - auto joinType = TryJoinTypes(type, cType); - if (!joinType) - { - // failure! - return DeclRef<Decl>(); - } - type = QualType(joinType, type.isLeftValue || cType.isLeftValue); + // failure! + return DeclRef<Decl>(); } - - c.satisfied = true; } + c.satisfied = true; + } + system->constraints[constraintIndex].satisfied = c.satisfied; + } + + // After we processed all constraints, `solvedTypes` and `solvedVals` + // should have been filled with the resolved types and values for the + // generic parameters. We can now verify if they are complete and consolidate + // them into final argument list. + for (auto member : genericDeclRef.getDecl()->members) + { + if (auto typeParam = as<GenericTypeParamDeclBase>(member)) + { + SLANG_ASSERT(typeParam->parameterIndex != -1); + + if (typeParam->parameterIndex < knownGenericArgCount) + continue; + bool isPack = as<GenericTypePackParamDecl>(typeParam) != nullptr; + if (typeParam->parameterIndex >= solvedArgs.getCount()) + { + // If the parameter is not a type pack and we don't have a + // resolved type for it, we should fail. + if (!isPack) + return DeclRef<Decl>(); + // If the parameter is a type pack, we should add an empty + // type list to solvedTypes. + solvedArgs.setCount(typeParam->parameterIndex + 1); + } + auto& types = solvedArgs[typeParam->parameterIndex].types; // Fail if any of the resolved type element is empty. - for (auto t: types) + for (auto t : types) { if (!t) return DeclRef<Decl>(); @@ -427,7 +544,9 @@ namespace Slang if (!isPack) { // If the generic parameter is not a pack, we can simply add the first type. - SLANG_ASSERT(types.getCount() == 1); + if (types.getCount() != 1) + return DeclRef<Decl>(); + args.add(types[0]); } else @@ -453,56 +572,17 @@ namespace Slang } } } - else if (auto valParam = m.as<GenericValueParamDecl>()) + else if (auto valParam = as<GenericValueParamDecl>(member)) { - // If the parameter is one where we already know - // the argument value to use, we don't bother with - // trying to solve for it, and treat any constraints - // on such a parameter as implicitly solved-for. - // - Index paramIndex = paramCounter++; - if (paramIndex < knownGenericArgCount) - { - for (auto& c : system->constraints) - { - if (c.decl != typeParam.getDecl()) - continue; + SLANG_ASSERT(valParam->parameterIndex != -1); - c.satisfied = true; - } + if (valParam->parameterIndex < knownGenericArgCount) continue; - } - - // TODO(tfoley): maybe support more than integers some day? - // TODO(tfoley): figure out how this needs to interact with - // compile-time integers that aren't just constants... - IntVal* val = nullptr; - bool valOptional = true; - for (auto& c : system->constraints) - { - if (c.decl != valParam.getDecl()) - continue; - - auto cVal = as<IntVal>(c.val); - SLANG_RELEASE_ASSERT(cVal); - if (!val || (valOptional && !c.isOptional)) - { - val = cVal; - valOptional = c.isOptional; - } - else - { - if(!valOptional && !val->equals(cVal)) - { - // failure! - return DeclRef<Decl>(); - } - } - - c.satisfied = true; - } + if (valParam->parameterIndex >= solvedArgs.getCount()) + return DeclRef<Decl>(); + auto val = solvedArgs[valParam->parameterIndex].val; if (!val) { // failure! @@ -510,10 +590,6 @@ namespace Slang } args.add(val); } - else - { - // ignore anything that isn't a generic parameter - } } // After we've solved for the explicit arguments, we need to @@ -537,7 +613,7 @@ namespace Slang for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() ) { DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef( - genericDeclRef, args.getArrayView(), constraintDecl).as<GenericTypeConstraintDecl>(); + genericDeclRef, args.getArrayView().arrayView, constraintDecl).as<GenericTypeConstraintDecl>(); // Extract the (substituted) sub- and super-type from the constraint. auto sub = getSub(m_astBuilder, constraintDeclRef); @@ -597,7 +673,7 @@ namespace Slang } } - return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView()); + return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView().arrayView); } bool SemanticsVisitor::TryUnifyVals( |
