summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-constraint.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-14 18:41:48 -0700
committerGitHub <noreply@github.com>2024-08-14 18:41:48 -0700
commit071f1b6062b459928ebfd6f2f60a8d6ad021112b (patch)
tree2ba65eb40f39701db6fc775a9258ec8079d161a0 /source/slang/slang-check-constraint.cpp
parent35a3d32c87f079749f6b100d01b289c3da02d7d6 (diff)
Variadic Generics Part 1: parsing and type checking. (#4833)
Diffstat (limited to 'source/slang/slang-check-constraint.cpp')
-rw-r--r--source/slang/slang-check-constraint.cpp273
1 files changed, 239 insertions, 34 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 1195ed1f9..0f6da156d 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -252,6 +252,27 @@ namespace Slang
}
}
+ // We can recursively join two TypePacks.
+ if (auto leftTypePack = as<ConcreteTypePack>(left))
+ {
+ if (auto rightTypePack = as<ConcreteTypePack>(right))
+ {
+ if(leftTypePack->getTypeCount() != rightTypePack->getTypeCount())
+ return nullptr;
+ ShortList<Type*> joinedTypes;
+ for (Index i = 0; i < leftTypePack->getTypeCount(); ++i)
+ {
+ auto joinedType = TryJoinTypes(
+ QualType(leftTypePack->getElementType(i), left.isLeftValue),
+ QualType(rightTypePack->getElementType(i), right.isLeftValue));
+ if(!joinedType)
+ return nullptr;
+ joinedTypes.add(joinedType);
+ }
+ return m_astBuilder->getTypePack(joinedTypes.getArrayView().arrayView);
+ }
+ }
+
// TODO: all the cases for vectors apply to matrices too!
// Default case is that we just fail.
@@ -285,7 +306,7 @@ namespace Slang
// that `X<T>.IndexType == T`.
for( auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(m_astBuilder, genericDeclRef) )
{
- if(!TryUnifyTypes(*system, getSub(m_astBuilder, constraintDeclRef), getSup(m_astBuilder, constraintDeclRef)))
+ if(!TryUnifyTypes(*system, ValUnificationContext(), getSub(m_astBuilder, constraintDeclRef), getSup(m_astBuilder, constraintDeclRef)))
return DeclRef<Decl>();
}
@@ -322,8 +343,14 @@ namespace Slang
Count paramCounter = 0;
for (auto m : getMembers(m_astBuilder, genericDeclRef))
{
- if (auto typeParam = m.as<GenericTypeParamDecl>())
+ if (auto typeParam = m.as<GenericTypeParamDeclBase>())
{
+ // If the parameter is a type pack, then we may have
+ // constraints that apply to invidual elements of the pack.
+ // We will need to handle the type pack case slightly differently.
+ //
+ bool isPack = as<GenericTypePackParamDecl>(typeParam) != nullptr;
+
// If the parameter is one where we already know
// the argument value to use, we don't bother with
// trying to solve for it, and treat any constraints
@@ -342,13 +369,32 @@ namespace Slang
continue;
}
- QualType type;
+
+ // We will use a temporary list to hold the resolved types
+ // for this generic parameter.
+ // For normal type parameters, there should be only one type
+ // in the list. For type pack parameters, there can be one type
+ // for each element in the pack.
+ //
+ ShortList<QualType> types;
+ if (!isPack)
+ types.setCount(1);
+
bool typeConstraintOptional = true;
for (auto& c : system->constraints)
{
if (c.decl != typeParam.getDecl())
continue;
+ QualType* ptype = nullptr;
+ if (isPack)
+ {
+ types.setCount(Math::Max(types.getCount(), c.indexInPack + 1));
+ ptype = &types[c.indexInPack];
+ }
+ else
+ ptype = &types[0];
+ QualType& type = *ptype;
auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue);
SLANG_RELEASE_ASSERT(cType);
@@ -372,12 +418,40 @@ namespace Slang
c.satisfied = true;
}
- if (!type)
+ // Fail if any of the resolved type element is empty.
+ for (auto t: types)
{
- // failure!
- return DeclRef<Decl>();
+ if (!t)
+ return DeclRef<Decl>();
+ }
+ if (!isPack)
+ {
+ // If the generic parameter is not a pack, we can simply add the first type.
+ SLANG_ASSERT(types.getCount() == 1);
+ args.add(types[0]);
+ }
+ else
+ {
+ // If the generic parameter is a pack, and we are supplying one single pack argument,
+ // we can use it as is.
+ if (types.getCount() == 1 && isTypePack(types[0]))
+ {
+ args.add(types[0]);
+ }
+ else
+ {
+ // If we are supplying 0 or multiple arguments for the pack, we need to create a type pack
+ // and add it to the argument list.
+ ShortList<Type*> typeList;
+ bool isLVal = true;
+ for (auto t : types)
+ {
+ typeList.add(t);
+ isLVal = isLVal && t.isLeftValue;
+ }
+ args.add(QualType(m_astBuilder->getTypePack(typeList.getArrayView().arrayView), isLVal));
+ }
}
- args.add(type);
}
else if (auto valParam = m.as<GenericValueParamDecl>())
{
@@ -472,6 +546,8 @@ namespace Slang
// Mark sub type as constrained.
if (auto subDeclRefType = as<DeclRefType>(constraintDeclRef.getDecl()->sub.type))
constrainedGenericParams.add(subDeclRefType->getDeclRef().getDecl());
+ else if (auto subEachType = as<EachType>(constraintDeclRef.getDecl()->sub.type))
+ constrainedGenericParams.add(as<DeclRefType>(subEachType->getElementType())->getDeclRef().getDecl());
if (sub->equals(sup))
{
@@ -526,6 +602,7 @@ namespace Slang
bool SemanticsVisitor::TryUnifyVals(
ConstraintSystem& constraints,
+ ValUnificationContext unifyCtx,
Val* fst,
bool fstLVal,
Val* snd,
@@ -536,7 +613,7 @@ namespace Slang
{
if (auto sndType = as<Type>(snd))
{
- return TryUnifyTypes(constraints, QualType(fstType, fstLVal), QualType(sndType, sndLVal));
+ return TryUnifyTypes(constraints, unifyCtx, QualType(fstType, fstLVal), QualType(sndType, sndLVal));
}
}
@@ -564,9 +641,9 @@ namespace Slang
bool okay = false;
if (fstParam)
- okay |= TryUnifyIntParam(constraints, fstParam->getDeclRef(), sndInt);
+ okay |= TryUnifyIntParam(constraints, unifyCtx, fstParam->getDeclRef(), sndInt);
if (sndParam)
- okay |= TryUnifyIntParam(constraints, sndParam->getDeclRef(), fstInt);
+ okay |= TryUnifyIntParam(constraints, unifyCtx, sndParam->getDeclRef(), fstInt);
return okay;
}
@@ -579,6 +656,7 @@ namespace Slang
SLANG_ASSERT(constraintDecl1);
SLANG_ASSERT(constraintDecl2);
return TryUnifyTypes(constraints,
+ unifyCtx,
constraintDecl1.getDecl()->getSup().type,
constraintDecl2.getDecl()->getSup().type);
}
@@ -592,6 +670,7 @@ namespace Slang
if (auto sndWit = as<SubtypeWitness>(snd))
{
return TryUnifyTypes(constraints,
+ unifyCtx,
fstWit->getSup(),
sndWit->getSup());
}
@@ -605,6 +684,7 @@ namespace Slang
bool SemanticsVisitor::tryUnifyDeclRef(
ConstraintSystem& constraints,
+ ValUnificationContext unifyCtx,
DeclRefBase* fst,
bool fstIsLVal,
DeclRefBase* snd,
@@ -620,11 +700,12 @@ namespace Slang
return true;
if (fstGen == nullptr || sndGen == nullptr)
return false;
- return tryUnifyGenericAppDeclRef(constraints, fstGen, fstIsLVal, sndGen, sndIsLVal);
+ return tryUnifyGenericAppDeclRef(constraints, unifyCtx, fstGen, fstIsLVal, sndGen, sndIsLVal);
}
bool SemanticsVisitor::tryUnifyGenericAppDeclRef(
ConstraintSystem& constraints,
+ ValUnificationContext unifyCtx,
GenericAppDeclRef* fst,
bool fstIsLVal,
GenericAppDeclRef* snd,
@@ -645,7 +726,7 @@ namespace Slang
bool okay = true;
for (Index aa = 0; aa < argCount; ++aa)
{
- if (!TryUnifyVals(constraints, fstGen->getArgs()[aa], fstIsLVal, sndGen->getArgs()[aa], sndIsLVal))
+ if (!TryUnifyVals(constraints, unifyCtx, fstGen->getArgs()[aa], fstIsLVal, sndGen->getArgs()[aa], sndIsLVal))
{
okay = false;
}
@@ -655,7 +736,7 @@ namespace Slang
auto fstBase = fst->getBase();
auto sndBase = snd->getBase();
- if (!tryUnifyDeclRef(constraints, fstBase, fstIsLVal, sndBase, sndIsLVal))
+ if (!tryUnifyDeclRef(constraints, unifyCtx, fstBase, fstIsLVal, sndBase, sndIsLVal))
{
okay = false;
}
@@ -665,13 +746,15 @@ namespace Slang
bool SemanticsVisitor::TryUnifyTypeParam(
ConstraintSystem& constraints,
- GenericTypeParamDecl* typeParamDecl,
+ ValUnificationContext unificationContext,
+ GenericTypeParamDeclBase* typeParamDecl,
QualType type)
{
// We want to constrain the given type parameter
// to equal the given type.
Constraint constraint;
constraint.decl = typeParamDecl;
+ constraint.indexInPack = unificationContext.indexInTypePack;
constraint.val = type;
constraint.isUsedAsLValue = type.isLeftValue;
constraints.constraints.add(constraint);
@@ -681,9 +764,12 @@ namespace Slang
bool SemanticsVisitor::TryUnifyIntParam(
ConstraintSystem& constraints,
+ ValUnificationContext unifyCtx,
GenericValueParamDecl* paramDecl,
IntVal* val)
{
+ SLANG_UNUSED(unifyCtx);
+
// We only want to accumulate constraints on
// the parameters of the declarations being
// specialized (don't accidentially constrain
@@ -704,12 +790,13 @@ namespace Slang
bool SemanticsVisitor::TryUnifyIntParam(
ConstraintSystem& constraints,
+ ValUnificationContext unifyCtx,
DeclRef<VarDeclBase> const& varRef,
IntVal* val)
{
if(auto genericValueParamRef = varRef.as<GenericValueParamDecl>())
{
- return TryUnifyIntParam(constraints, genericValueParamRef.getDecl(), val);
+ return TryUnifyIntParam(constraints, unifyCtx, genericValueParamRef.getDecl(), val);
}
else
{
@@ -719,6 +806,7 @@ namespace Slang
bool SemanticsVisitor::TryUnifyTypesByStructuralMatch(
ConstraintSystem& constraints,
+ ValUnificationContext unifyCtx,
QualType fst,
QualType snd)
{
@@ -728,7 +816,7 @@ namespace Slang
if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl()))
if (typeParamDecl->parentDecl == constraints.genericDecl)
- return TryUnifyTypeParam(constraints, typeParamDecl, snd);
+ return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, snd);
if (auto sndDeclRefType = as<DeclRefType>(snd))
{
@@ -736,7 +824,7 @@ namespace Slang
if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl()))
if (typeParamDecl->parentDecl == constraints.genericDecl)
- return TryUnifyTypeParam(constraints, typeParamDecl, fst);
+ return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, fst);
// If they refer to different declarations, we need to check if one type's super type
// matches the other type, if so we can unify them.
@@ -775,6 +863,7 @@ namespace Slang
// to each declaration reference.
if (!tryUnifyDeclRef(
constraints,
+ unifyCtx,
fstDeclRef,
fst.isLeftValue,
sndDeclRef,
@@ -795,18 +884,46 @@ namespace Slang
return false;
for(Index i = 0; i < numParams; ++i)
{
- if(!TryUnifyTypes(constraints, fstFunType->getParamType(i), sndFunType->getParamType(i)))
+ if(!TryUnifyTypes(constraints, unifyCtx, fstFunType->getParamType(i), sndFunType->getParamType(i)))
return false;
}
- return TryUnifyTypes(constraints, fstFunType->getResultType(), sndFunType->getResultType());
+ return TryUnifyTypes(constraints, unifyCtx, fstFunType->getResultType(), sndFunType->getResultType());
+ }
+ }
+ else if (auto expandType = as<ExpandType>(fst))
+ {
+ if (auto sndExpandType = as<ExpandType>(snd))
+ {
+ return TryUnifyTypes(constraints, unifyCtx, expandType->getPatternType(), sndExpandType->getPatternType());
+ }
+ }
+ else if (auto eachType = as<EachType>(fst))
+ {
+ if (auto sndEachType = as<EachType>(snd))
+ {
+ return TryUnifyTypes(constraints, unifyCtx, eachType->getElementType(), sndEachType->getElementType());
+ }
+ }
+ else if (auto typePack = as<ConcreteTypePack>(fst))
+ {
+ if (auto sndTypePack = as<ConcreteTypePack>(snd))
+ {
+ if (typePack->getTypeCount() != sndTypePack->getTypeCount())
+ return false;
+ for (Index i = 0; i < typePack->getTypeCount(); ++i)
+ {
+ if (!TryUnifyTypes(constraints, unifyCtx, QualType(typePack->getElementType(i), fst.isLeftValue), QualType(sndTypePack->getElementType(i), snd.isLeftValue)))
+ return false;
+ }
+ return true;
}
}
-
return false;
}
bool SemanticsVisitor::TryUnifyConjunctionType(
ConstraintSystem& constraints,
+ ValUnificationContext unifyCtx,
QualType fst,
QualType snd)
{
@@ -820,20 +937,23 @@ namespace Slang
//
if (auto fstAndType = as<AndType>(fst))
{
- return TryUnifyTypes(constraints, QualType(fstAndType->getLeft(), fst.isLeftValue), snd)
- && TryUnifyTypes(constraints, QualType(fstAndType->getRight(), fst.isLeftValue), snd);
+ return TryUnifyTypes(constraints, unifyCtx, QualType(fstAndType->getLeft(), fst.isLeftValue), snd)
+ && TryUnifyTypes(constraints, unifyCtx, QualType(fstAndType->getRight(), fst.isLeftValue), snd);
}
else if (auto sndAndType = as<AndType>(snd))
{
- return TryUnifyTypes(constraints, fst, QualType(sndAndType->getLeft(), snd.isLeftValue))
- || TryUnifyTypes(constraints, fst, QualType(sndAndType->getRight(), snd.isLeftValue));
+ return TryUnifyTypes(constraints, unifyCtx, fst, QualType(sndAndType->getLeft(), snd.isLeftValue))
+ || TryUnifyTypes(constraints, unifyCtx, fst, QualType(sndAndType->getRight(), snd.isLeftValue));
}
else
return false;
}
- void SemanticsVisitor::maybeUnifyUnconstraintIntParam(ConstraintSystem& constraints, IntVal* param, IntVal* arg, bool paramIsLVal)
+ void SemanticsVisitor::maybeUnifyUnconstraintIntParam(
+ ConstraintSystem& constraints, ValUnificationContext unifyCtx, IntVal* param, IntVal* arg, bool paramIsLVal)
{
+ SLANG_UNUSED(unifyCtx);
+
// If `param` is an unconstrained integer val param, and `arg` is a const int val,
// we add a constraint to the system that `param` must be equal to `arg`.
// If `param` is already constrained, ignore and do nothing.
@@ -857,6 +977,7 @@ namespace Slang
bool SemanticsVisitor::TryUnifyTypes(
ConstraintSystem& constraints,
+ ValUnificationContext unifyCtx,
QualType fst,
QualType snd)
{
@@ -883,7 +1004,49 @@ namespace Slang
//
if (as<AndType>(fst) || as<AndType>(snd))
{
- return TryUnifyConjunctionType(constraints, fst, snd);
+ return TryUnifyConjunctionType(constraints, unifyCtx, fst, snd);
+ }
+
+ // If one of the types is a type pack, we need to recursively unify the element types.
+ if (auto fstTypePack = as<ConcreteTypePack>(fst))
+ {
+ if (auto sndTypePack = as<ConcreteTypePack>(snd))
+ {
+ if (fstTypePack->getTypeCount() != sndTypePack->getTypeCount())
+ return false;
+ for (Index i = 0; i < fstTypePack->getTypeCount(); ++i)
+ {
+ if (!TryUnifyTypes(constraints, unifyCtx,QualType(fstTypePack->getElementType(i), fst.isLeftValue), QualType(sndTypePack->getElementType(i), snd.isLeftValue)))
+ return false;
+ }
+ return true;
+ }
+ else if (auto sndExpandType = as<ExpandType>(snd))
+ {
+ for (Index i = 0; i < fstTypePack->getTypeCount(); ++i)
+ {
+ ValUnificationContext subUnifyCtx = unifyCtx;
+ subUnifyCtx.indexInTypePack = i;
+ if (!TryUnifyTypes(constraints, subUnifyCtx, QualType(fstTypePack->getElementType(i), fst.isLeftValue), QualType(sndExpandType->getPatternType(), snd.isLeftValue)))
+ return false;
+ }
+ return true;
+ }
+ }
+
+ if (auto sndTypePack = as<ConcreteTypePack>(snd))
+ {
+ if (auto fstExpandType = as<ExpandType>(fst))
+ {
+ for (Index i = 0; i < sndTypePack->getTypeCount(); ++i)
+ {
+ ValUnificationContext subUnifyCtx = unifyCtx;
+ subUnifyCtx.indexInTypePack = i;
+ if (!TryUnifyTypes(constraints, subUnifyCtx, QualType(fstExpandType->getPatternType(), fst.isLeftValue), QualType(sndTypePack->getElementType(i), snd.isLeftValue)))
+ return false;
+ }
+ return true;
+ }
}
// A generic parameter type can unify with anything.
@@ -897,7 +1060,13 @@ namespace Slang
if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl()))
{
if(typeParamDecl->parentDecl == constraints.genericDecl)
- return TryUnifyTypeParam(constraints, typeParamDecl, snd);
+ return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, snd);
+ }
+ else if (auto typePackParamDecl = as<GenericTypePackParamDecl>(fstDeclRef.getDecl()))
+ {
+ if (typePackParamDecl->parentDecl == constraints.genericDecl
+ && isTypePack(snd))
+ return TryUnifyTypeParam(constraints, unifyCtx, typePackParamDecl, snd);
}
}
@@ -905,15 +1074,21 @@ namespace Slang
{
auto sndDeclRef = sndDeclRefType->getDeclRef();
- if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl()))
+ if (auto typeParamDecl = as<GenericTypeParamDeclBase>(sndDeclRef.getDecl()))
{
if(typeParamDecl->parentDecl == constraints.genericDecl)
- return TryUnifyTypeParam(constraints, typeParamDecl, fst);
+ return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, fst);
+ }
+ else if (auto typePackParamDecl = as<GenericTypePackParamDecl>(sndDeclRef.getDecl()))
+ {
+ if (typePackParamDecl->parentDecl == constraints.genericDecl
+ && isTypePack(fst))
+ return TryUnifyTypeParam(constraints, unifyCtx, typePackParamDecl, fst);
}
}
// If we can unify the types structurally, then we are golden
- if(TryUnifyTypesByStructuralMatch(constraints, fst, snd))
+ if(TryUnifyTypesByStructuralMatch(constraints, unifyCtx, fst, snd))
return true;
// Now we need to consider cases where coercion might
@@ -930,9 +1105,10 @@ namespace Slang
// However, we don't want a failed unification to fail the entire generic argument inference,
// because a scalar can still be casted into a vector of any length.
- maybeUnifyUnconstraintIntParam(constraints, fstVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), fst.isLeftValue);
+ maybeUnifyUnconstraintIntParam(constraints, unifyCtx, fstVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), fst.isLeftValue);
return TryUnifyTypes(
constraints,
+ unifyCtx,
QualType(fstVectorType->getElementType(), fst.isLeftValue),
QualType(sndScalarType, snd.isLeftValue));
}
@@ -942,18 +1118,47 @@ namespace Slang
{
if(auto sndVectorType = as<VectorExpressionType>(snd))
{
- maybeUnifyUnconstraintIntParam(constraints, sndVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), snd.isLeftValue);
+ maybeUnifyUnconstraintIntParam(constraints, unifyCtx, sndVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), snd.isLeftValue);
return TryUnifyTypes(
constraints,
+ unifyCtx,
QualType(fstScalarType, fst.isLeftValue),
QualType(sndVectorType->getElementType(), snd.isLeftValue));
}
}
if (auto fstUniformParamGroupType = as<UniformParameterGroupType>(fst))
- return TryUnifyTypes(constraints, QualType(fstUniformParamGroupType->getElementType(), fst.isLeftValue), snd);
+ return TryUnifyTypes(constraints, unifyCtx, QualType(fstUniformParamGroupType->getElementType(), fst.isLeftValue), snd);
if (auto sndUniformParamGroupType = as<UniformParameterGroupType>(snd))
- return TryUnifyTypes(constraints, fst, QualType(sndUniformParamGroupType->getElementType(), snd.isLeftValue));
+ return TryUnifyTypes(constraints, unifyCtx, fst, QualType(sndUniformParamGroupType->getElementType(), snd.isLeftValue));
+
+ // Each T can coerce with any DeclRefType.
+ if (auto eachSnd = as<EachType>(snd))
+ {
+ if (auto innerSnd = eachSnd->getElementDeclRefType())
+ {
+ if (auto sndTypePackParamDecl = as<GenericTypePackParamDecl>(innerSnd->getDeclRef().getDecl()))
+ {
+ if (innerSnd->getDeclRef().getDecl()->parentDecl == constraints.genericDecl)
+ {
+ return TryUnifyTypeParam(constraints, unifyCtx, sndTypePackParamDecl, fst);
+ }
+ }
+ }
+ }
+ if (auto eachFst = as<EachType>(fst))
+ {
+ if (auto innerFst = eachFst->getElementDeclRefType())
+ {
+ if (auto fstTypePackParamDecl = as<GenericTypePackParamDecl>(innerFst->getDeclRef().getDecl()))
+ {
+ if (innerFst->getDeclRef().getDecl()->parentDecl == constraints.genericDecl)
+ {
+ return TryUnifyTypeParam(constraints, unifyCtx, fstTypePackParamDecl, snd);
+ }
+ }
+ }
+ }
return false;
}