summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-constraint.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-20 20:51:57 -0700
committerGitHub <noreply@github.com>2024-08-20 20:51:57 -0700
commitf9f6a28df40f418ddd0c8ff3b9cacccdb085e202 (patch)
treea6bafa63cee4f9bbcfe496de54af6e5727bb021e /source/slang/slang-check-constraint.cpp
parent03e1e17745920c8e3a7b6f4e3b1e64062589604a (diff)
Support dependent generic constraints. (#4870)
* Support dependent generic constraints. * Fix warning. * Update comment. * Fix. * Add a test case to verify fix of #3804. * Address review.
Diffstat (limited to 'source/slang/slang-check-constraint.cpp')
-rw-r--r--source/slang/slang-check-constraint.cpp294
1 files changed, 185 insertions, 109 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 0f6da156d..afcde8a5b 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -57,6 +57,7 @@
namespace Slang
{
Type* SemanticsVisitor::TryJoinVectorAndScalarType(
+ ConstraintSystem* constraints,
VectorExpressionType* vectorType,
BasicExpressionType* scalarType)
{
@@ -65,6 +66,7 @@ namespace Slang
// That is, the join of a vector and a scalar type is
// a vector type with a joined element type.
auto joinElementType = TryJoinTypes(
+ constraints,
vectorType->getElementType(),
scalarType);
if(!joinElementType)
@@ -76,6 +78,7 @@ namespace Slang
}
Type* SemanticsVisitor::_tryJoinTypeWithInterface(
+ ConstraintSystem* constraints,
Type* type,
Type* interfaceType)
{
@@ -158,6 +161,31 @@ namespace Slang
return bestType;
}
+ // If `interfaceType` represents some generic interface type, such as `IFoo<T>`, and `type` conforms to
+ // some `IFoo<X>`, then we should attempt to unify the them to discover constraints for
+ // `T`.
+ if (auto interfaceDeclRef = isDeclRefTypeOf<InterfaceDecl>(interfaceType))
+ {
+ if (as<GenericAppDeclRef>(interfaceDeclRef.declRefBase))
+ {
+ auto inheritanceInfo = getShared()->getInheritanceInfo(type);
+ for (auto facet : inheritanceInfo.facets)
+ {
+ if (facet->origin.declRef.getDecl() == interfaceDeclRef.getDecl())
+ {
+ auto unificationResult = TryUnifyTypes(
+ *constraints,
+ ValUnificationContext(),
+ QualType(facet->getType()),
+ interfaceType);
+
+ if (unificationResult)
+ return type;
+ }
+ }
+ }
+ }
+
// For all other cases, we will just bail out for now.
//
// TODO: In the future we should build some kind of side data structure
@@ -174,6 +202,7 @@ namespace Slang
}
Type* SemanticsVisitor::TryJoinTypes(
+ ConstraintSystem* constraints,
QualType left,
QualType right)
{
@@ -201,7 +230,7 @@ namespace Slang
// We can also join a vector and a scalar
if(auto rightVector = as<VectorExpressionType>(right))
{
- return TryJoinVectorAndScalarType(rightVector, leftBasic);
+ return TryJoinVectorAndScalarType(constraints, rightVector, leftBasic);
}
}
@@ -217,6 +246,7 @@ namespace Slang
// Try to join the element types
auto joinElementType = TryJoinTypes(
+ constraints,
QualType(leftVector->getElementType(), left.isLeftValue),
QualType(rightVector->getElementType(), right.isLeftValue));
if(!joinElementType)
@@ -230,7 +260,7 @@ namespace Slang
// We can also join a vector and a scalar
if(auto rightBasic = as<BasicExpressionType>(right))
{
- return TryJoinVectorAndScalarType(leftVector, rightBasic);
+ return TryJoinVectorAndScalarType(constraints, leftVector, rightBasic);
}
}
@@ -240,7 +270,7 @@ namespace Slang
if( auto leftInterfaceRef = leftDeclRefType->getDeclRef().as<InterfaceDecl>() )
{
//
- return _tryJoinTypeWithInterface(right, left);
+ return _tryJoinTypeWithInterface(constraints, right, left);
}
}
if(auto rightDeclRefType = as<DeclRefType>(right))
@@ -248,7 +278,7 @@ namespace Slang
if( auto rightInterfaceRef = rightDeclRefType->getDeclRef().as<InterfaceDecl>() )
{
//
- return _tryJoinTypeWithInterface(left, right);
+ return _tryJoinTypeWithInterface(constraints, left, right);
}
}
@@ -263,6 +293,7 @@ namespace Slang
for (Index i = 0; i < leftTypePack->getTypeCount(); ++i)
{
auto joinedType = TryJoinTypes(
+ constraints,
QualType(leftTypePack->getElementType(i), left.isLeftValue),
QualType(rightTypePack->getElementType(i), right.isLeftValue));
if(!joinedType)
@@ -285,6 +316,8 @@ namespace Slang
ArrayView<Val*> knownGenericArgs,
ConversionCost& outBaseCost)
{
+ ensureDecl(genericDeclRef.getDecl(), DeclCheckState::ReadyForLookup);
+
outBaseCost = kConversionCost_None;
// For now the "solver" is going to be ridiculously simplistic.
@@ -310,7 +343,7 @@ namespace Slang
return DeclRef<Decl>();
}
- // Once have built up the full list of constraints we are trying to satisfy,
+ // Once have built up the initial list of constraints we are trying to satisfy,
// we will attempt to solve for each parameter in a way that satisfies all
// the constraints that apply to that parameter.
//
@@ -321,7 +354,7 @@ namespace Slang
// solution for how to assign the parameters in a way that satisfies all
// the constraints.
//
- List<Val*> args;
+ ShortList<Val*> args;
// If the context is such that some of the arguments are already specified
// or known, we need to go ahead and use those arguments direclty (whether
@@ -337,38 +370,44 @@ namespace Slang
}
}
- // We will then iterate over the explicit parameters of the generic
- // and try to solve for each.
- //
- Count paramCounter = 0;
- for (auto m : getMembers(m_astBuilder, genericDeclRef))
+ // The state of currently solved arguments.
+ struct SolvedArg
+ {
+ IntVal* val = nullptr;
+ bool isOptional = true;
+ ShortList<QualType, 8> types;
+ };
+ ShortList<SolvedArg> solvedArgs;
+
+ // We will then iterate over the constraints trying to solve all generic parameters.
+ // Note that we do not use ranged for here, because processing one constraint may lead to
+ // new constraints being discovered.
+ for (Index constraintIndex = 0; constraintIndex < system->constraints.getCount(); constraintIndex++)
{
- if (auto typeParam = m.as<GenericTypeParamDeclBase>())
+ // Note: it is important to keep a copy of the constraint here instead of
+ // using a reference, because the constraint list may be modified during the
+ // loop as we discover new constraints.
+ //
+ auto c = system->constraints[constraintIndex];
+ if (auto typeParam = as<GenericTypeParamDeclBase>(c.decl))
{
- // If the parameter is a type pack, then we may have
- // constraints that apply to invidual elements of the pack.
- // We will need to handle the type pack case slightly differently.
- //
- bool isPack = as<GenericTypePackParamDecl>(typeParam) != nullptr;
-
+ SLANG_ASSERT(typeParam->parameterIndex != -1);
// If the parameter is one where we already know
// the argument value to use, we don't bother with
// trying to solve for it, and treat any constraints
// on such a parameter as implicitly solved-for.
//
- Index paramIndex = paramCounter++;
- if (paramIndex < knownGenericArgCount)
+ if (typeParam->parameterIndex < knownGenericArgCount)
{
- for (auto& c : system->constraints)
- {
- if (c.decl != typeParam.getDecl())
- continue;
-
- c.satisfied = true;
- }
+ system->constraints[constraintIndex].satisfied = true;
continue;
}
+ // If the parameter is a type pack, then we may have
+ // constraints that apply to invidual elements of the pack.
+ // We will need to handle the type pack case slightly differently.
+ //
+ bool isPack = as<GenericTypePackParamDecl>(typeParam) != nullptr;
// We will use a temporary list to hold the resolved types
// for this generic parameter.
@@ -376,50 +415,128 @@ namespace Slang
// in the list. For type pack parameters, there can be one type
// for each element in the pack.
//
- ShortList<QualType> types;
+ if (solvedArgs.getCount() <= typeParam->parameterIndex)
+ {
+ solvedArgs.setCount(typeParam->parameterIndex + 1);
+ }
+ auto& types = solvedArgs[typeParam->parameterIndex].types;
if (!isPack)
types.setCount(1);
- bool typeConstraintOptional = true;
+ bool& typeConstraintOptional = solvedArgs[typeParam->parameterIndex].isOptional;
- for (auto& c : system->constraints)
+ QualType* ptype = nullptr;
+ if (isPack)
{
- if (c.decl != typeParam.getDecl())
- continue;
- QualType* ptype = nullptr;
- if (isPack)
- {
- types.setCount(Math::Max(types.getCount(), c.indexInPack + 1));
- ptype = &types[c.indexInPack];
- }
- else
- ptype = &types[0];
- QualType& type = *ptype;
+ types.setCount(Math::Max(types.getCount(), c.indexInPack + 1));
+ ptype = &types[c.indexInPack];
+ }
+ else
+ ptype = &types[0];
+ QualType& type = *ptype;
- auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue);
- SLANG_RELEASE_ASSERT(cType);
+ auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue);
+ SLANG_RELEASE_ASSERT(cType);
- if (!type || (typeConstraintOptional && !c.isOptional))
+ if (!type || (typeConstraintOptional && !c.isOptional))
+ {
+ type = cType;
+ typeConstraintOptional = c.isOptional;
+ }
+ else if (!typeConstraintOptional)
+ {
+ // If the type parameter is already constrained to a known type,
+ // we need to make sure our resolved type can satisfy both constraints.
+ // We do so by updating the resolved type to be the "join" of the current
+ // solution and the type in the new constraint. If such join cannot be found,
+ // it means it is not possible to have a compatible solution that meets all
+ // constraints and we should fail.
+ //
+ // Another detail here is that during type joining, we may discover
+ // new constraints from the base types of the types being joined.
+ // We will pass the constraint system to `TryJoinTypes` which can
+ // add new constraints to the system, and we will process the new constraints
+ // in the next iteration.
+ //
+ auto joinType = TryJoinTypes(system, type, cType);
+ if (!joinType)
{
- type = cType;
- typeConstraintOptional = c.isOptional;
+ // failure!
+ return DeclRef<Decl>();
}
- else if (!typeConstraintOptional)
+ type = QualType(joinType, type.isLeftValue || cType.isLeftValue);
+ }
+
+ c.satisfied = true;
+ }
+ else if (auto valParam = as<GenericValueParamDecl>(c.decl))
+ {
+ SLANG_ASSERT(valParam->parameterIndex != -1);
+
+ // If the parameter is one where we already know
+ // the argument value to use, we don't bother with
+ // trying to solve for it, and treat any constraints
+ // on such a parameter as implicitly solved-for.
+ //
+ if (valParam->parameterIndex < knownGenericArgCount)
+ {
+ system->constraints[constraintIndex].satisfied = true;
+ continue;
+ }
+
+ if (solvedArgs.getCount() <= valParam->parameterIndex)
+ solvedArgs.setCount(valParam->parameterIndex + 1);
+ IntVal*& val = solvedArgs[valParam->parameterIndex].val;
+ bool& valOptional = solvedArgs[valParam->parameterIndex].isOptional;
+
+ auto cVal = as<IntVal>(c.val);
+ SLANG_RELEASE_ASSERT(cVal);
+
+ if (!val || (valOptional && !c.isOptional))
+ {
+ val = cVal;
+ valOptional = c.isOptional;
+ }
+ else
+ {
+ if(!valOptional && !val->equals(cVal))
{
- auto joinType = TryJoinTypes(type, cType);
- if (!joinType)
- {
- // failure!
- return DeclRef<Decl>();
- }
- type = QualType(joinType, type.isLeftValue || cType.isLeftValue);
+ // failure!
+ return DeclRef<Decl>();
}
-
- c.satisfied = true;
}
+ c.satisfied = true;
+ }
+ system->constraints[constraintIndex].satisfied = c.satisfied;
+ }
+
+ // After we processed all constraints, `solvedTypes` and `solvedVals`
+ // should have been filled with the resolved types and values for the
+ // generic parameters. We can now verify if they are complete and consolidate
+ // them into final argument list.
+ for (auto member : genericDeclRef.getDecl()->members)
+ {
+ if (auto typeParam = as<GenericTypeParamDeclBase>(member))
+ {
+ SLANG_ASSERT(typeParam->parameterIndex != -1);
+
+ if (typeParam->parameterIndex < knownGenericArgCount)
+ continue;
+ bool isPack = as<GenericTypePackParamDecl>(typeParam) != nullptr;
+ if (typeParam->parameterIndex >= solvedArgs.getCount())
+ {
+ // If the parameter is not a type pack and we don't have a
+ // resolved type for it, we should fail.
+ if (!isPack)
+ return DeclRef<Decl>();
+ // If the parameter is a type pack, we should add an empty
+ // type list to solvedTypes.
+ solvedArgs.setCount(typeParam->parameterIndex + 1);
+ }
+ auto& types = solvedArgs[typeParam->parameterIndex].types;
// Fail if any of the resolved type element is empty.
- for (auto t: types)
+ for (auto t : types)
{
if (!t)
return DeclRef<Decl>();
@@ -427,7 +544,9 @@ namespace Slang
if (!isPack)
{
// If the generic parameter is not a pack, we can simply add the first type.
- SLANG_ASSERT(types.getCount() == 1);
+ if (types.getCount() != 1)
+ return DeclRef<Decl>();
+
args.add(types[0]);
}
else
@@ -453,56 +572,17 @@ namespace Slang
}
}
}
- else if (auto valParam = m.as<GenericValueParamDecl>())
+ else if (auto valParam = as<GenericValueParamDecl>(member))
{
- // If the parameter is one where we already know
- // the argument value to use, we don't bother with
- // trying to solve for it, and treat any constraints
- // on such a parameter as implicitly solved-for.
- //
- Index paramIndex = paramCounter++;
- if (paramIndex < knownGenericArgCount)
- {
- for (auto& c : system->constraints)
- {
- if (c.decl != typeParam.getDecl())
- continue;
+ SLANG_ASSERT(valParam->parameterIndex != -1);
- c.satisfied = true;
- }
+ if (valParam->parameterIndex < knownGenericArgCount)
continue;
- }
-
- // TODO(tfoley): maybe support more than integers some day?
- // TODO(tfoley): figure out how this needs to interact with
- // compile-time integers that aren't just constants...
- IntVal* val = nullptr;
- bool valOptional = true;
- for (auto& c : system->constraints)
- {
- if (c.decl != valParam.getDecl())
- continue;
-
- auto cVal = as<IntVal>(c.val);
- SLANG_RELEASE_ASSERT(cVal);
- if (!val || (valOptional && !c.isOptional))
- {
- val = cVal;
- valOptional = c.isOptional;
- }
- else
- {
- if(!valOptional && !val->equals(cVal))
- {
- // failure!
- return DeclRef<Decl>();
- }
- }
-
- c.satisfied = true;
- }
+ if (valParam->parameterIndex >= solvedArgs.getCount())
+ return DeclRef<Decl>();
+ auto val = solvedArgs[valParam->parameterIndex].val;
if (!val)
{
// failure!
@@ -510,10 +590,6 @@ namespace Slang
}
args.add(val);
}
- else
- {
- // ignore anything that isn't a generic parameter
- }
}
// After we've solved for the explicit arguments, we need to
@@ -537,7 +613,7 @@ namespace Slang
for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
{
DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef(
- genericDeclRef, args.getArrayView(), constraintDecl).as<GenericTypeConstraintDecl>();
+ genericDeclRef, args.getArrayView().arrayView, constraintDecl).as<GenericTypeConstraintDecl>();
// Extract the (substituted) sub- and super-type from the constraint.
auto sub = getSub(m_astBuilder, constraintDeclRef);
@@ -597,7 +673,7 @@ namespace Slang
}
}
- return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView());
+ return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView().arrayView);
}
bool SemanticsVisitor::TryUnifyVals(