diff options
| author | Yong He <yonghe@outlook.com> | 2024-08-14 18:41:48 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-14 18:41:48 -0700 |
| commit | 071f1b6062b459928ebfd6f2f60a8d6ad021112b (patch) | |
| tree | 2ba65eb40f39701db6fc775a9258ec8079d161a0 /source/slang/slang-check-constraint.cpp | |
| parent | 35a3d32c87f079749f6b100d01b289c3da02d7d6 (diff) | |
Variadic Generics Part 1: parsing and type checking. (#4833)
Diffstat (limited to 'source/slang/slang-check-constraint.cpp')
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 273 |
1 files changed, 239 insertions, 34 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 1195ed1f9..0f6da156d 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -252,6 +252,27 @@ namespace Slang } } + // We can recursively join two TypePacks. + if (auto leftTypePack = as<ConcreteTypePack>(left)) + { + if (auto rightTypePack = as<ConcreteTypePack>(right)) + { + if(leftTypePack->getTypeCount() != rightTypePack->getTypeCount()) + return nullptr; + ShortList<Type*> joinedTypes; + for (Index i = 0; i < leftTypePack->getTypeCount(); ++i) + { + auto joinedType = TryJoinTypes( + QualType(leftTypePack->getElementType(i), left.isLeftValue), + QualType(rightTypePack->getElementType(i), right.isLeftValue)); + if(!joinedType) + return nullptr; + joinedTypes.add(joinedType); + } + return m_astBuilder->getTypePack(joinedTypes.getArrayView().arrayView); + } + } + // TODO: all the cases for vectors apply to matrices too! // Default case is that we just fail. @@ -285,7 +306,7 @@ namespace Slang // that `X<T>.IndexType == T`. for( auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(m_astBuilder, genericDeclRef) ) { - if(!TryUnifyTypes(*system, getSub(m_astBuilder, constraintDeclRef), getSup(m_astBuilder, constraintDeclRef))) + if(!TryUnifyTypes(*system, ValUnificationContext(), getSub(m_astBuilder, constraintDeclRef), getSup(m_astBuilder, constraintDeclRef))) return DeclRef<Decl>(); } @@ -322,8 +343,14 @@ namespace Slang Count paramCounter = 0; for (auto m : getMembers(m_astBuilder, genericDeclRef)) { - if (auto typeParam = m.as<GenericTypeParamDecl>()) + if (auto typeParam = m.as<GenericTypeParamDeclBase>()) { + // If the parameter is a type pack, then we may have + // constraints that apply to invidual elements of the pack. + // We will need to handle the type pack case slightly differently. + // + bool isPack = as<GenericTypePackParamDecl>(typeParam) != nullptr; + // If the parameter is one where we already know // the argument value to use, we don't bother with // trying to solve for it, and treat any constraints @@ -342,13 +369,32 @@ namespace Slang continue; } - QualType type; + + // We will use a temporary list to hold the resolved types + // for this generic parameter. + // For normal type parameters, there should be only one type + // in the list. For type pack parameters, there can be one type + // for each element in the pack. + // + ShortList<QualType> types; + if (!isPack) + types.setCount(1); + bool typeConstraintOptional = true; for (auto& c : system->constraints) { if (c.decl != typeParam.getDecl()) continue; + QualType* ptype = nullptr; + if (isPack) + { + types.setCount(Math::Max(types.getCount(), c.indexInPack + 1)); + ptype = &types[c.indexInPack]; + } + else + ptype = &types[0]; + QualType& type = *ptype; auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue); SLANG_RELEASE_ASSERT(cType); @@ -372,12 +418,40 @@ namespace Slang c.satisfied = true; } - if (!type) + // Fail if any of the resolved type element is empty. + for (auto t: types) { - // failure! - return DeclRef<Decl>(); + if (!t) + return DeclRef<Decl>(); + } + if (!isPack) + { + // If the generic parameter is not a pack, we can simply add the first type. + SLANG_ASSERT(types.getCount() == 1); + args.add(types[0]); + } + else + { + // If the generic parameter is a pack, and we are supplying one single pack argument, + // we can use it as is. + if (types.getCount() == 1 && isTypePack(types[0])) + { + args.add(types[0]); + } + else + { + // If we are supplying 0 or multiple arguments for the pack, we need to create a type pack + // and add it to the argument list. + ShortList<Type*> typeList; + bool isLVal = true; + for (auto t : types) + { + typeList.add(t); + isLVal = isLVal && t.isLeftValue; + } + args.add(QualType(m_astBuilder->getTypePack(typeList.getArrayView().arrayView), isLVal)); + } } - args.add(type); } else if (auto valParam = m.as<GenericValueParamDecl>()) { @@ -472,6 +546,8 @@ namespace Slang // Mark sub type as constrained. if (auto subDeclRefType = as<DeclRefType>(constraintDeclRef.getDecl()->sub.type)) constrainedGenericParams.add(subDeclRefType->getDeclRef().getDecl()); + else if (auto subEachType = as<EachType>(constraintDeclRef.getDecl()->sub.type)) + constrainedGenericParams.add(as<DeclRefType>(subEachType->getElementType())->getDeclRef().getDecl()); if (sub->equals(sup)) { @@ -526,6 +602,7 @@ namespace Slang bool SemanticsVisitor::TryUnifyVals( ConstraintSystem& constraints, + ValUnificationContext unifyCtx, Val* fst, bool fstLVal, Val* snd, @@ -536,7 +613,7 @@ namespace Slang { if (auto sndType = as<Type>(snd)) { - return TryUnifyTypes(constraints, QualType(fstType, fstLVal), QualType(sndType, sndLVal)); + return TryUnifyTypes(constraints, unifyCtx, QualType(fstType, fstLVal), QualType(sndType, sndLVal)); } } @@ -564,9 +641,9 @@ namespace Slang bool okay = false; if (fstParam) - okay |= TryUnifyIntParam(constraints, fstParam->getDeclRef(), sndInt); + okay |= TryUnifyIntParam(constraints, unifyCtx, fstParam->getDeclRef(), sndInt); if (sndParam) - okay |= TryUnifyIntParam(constraints, sndParam->getDeclRef(), fstInt); + okay |= TryUnifyIntParam(constraints, unifyCtx, sndParam->getDeclRef(), fstInt); return okay; } @@ -579,6 +656,7 @@ namespace Slang SLANG_ASSERT(constraintDecl1); SLANG_ASSERT(constraintDecl2); return TryUnifyTypes(constraints, + unifyCtx, constraintDecl1.getDecl()->getSup().type, constraintDecl2.getDecl()->getSup().type); } @@ -592,6 +670,7 @@ namespace Slang if (auto sndWit = as<SubtypeWitness>(snd)) { return TryUnifyTypes(constraints, + unifyCtx, fstWit->getSup(), sndWit->getSup()); } @@ -605,6 +684,7 @@ namespace Slang bool SemanticsVisitor::tryUnifyDeclRef( ConstraintSystem& constraints, + ValUnificationContext unifyCtx, DeclRefBase* fst, bool fstIsLVal, DeclRefBase* snd, @@ -620,11 +700,12 @@ namespace Slang return true; if (fstGen == nullptr || sndGen == nullptr) return false; - return tryUnifyGenericAppDeclRef(constraints, fstGen, fstIsLVal, sndGen, sndIsLVal); + return tryUnifyGenericAppDeclRef(constraints, unifyCtx, fstGen, fstIsLVal, sndGen, sndIsLVal); } bool SemanticsVisitor::tryUnifyGenericAppDeclRef( ConstraintSystem& constraints, + ValUnificationContext unifyCtx, GenericAppDeclRef* fst, bool fstIsLVal, GenericAppDeclRef* snd, @@ -645,7 +726,7 @@ namespace Slang bool okay = true; for (Index aa = 0; aa < argCount; ++aa) { - if (!TryUnifyVals(constraints, fstGen->getArgs()[aa], fstIsLVal, sndGen->getArgs()[aa], sndIsLVal)) + if (!TryUnifyVals(constraints, unifyCtx, fstGen->getArgs()[aa], fstIsLVal, sndGen->getArgs()[aa], sndIsLVal)) { okay = false; } @@ -655,7 +736,7 @@ namespace Slang auto fstBase = fst->getBase(); auto sndBase = snd->getBase(); - if (!tryUnifyDeclRef(constraints, fstBase, fstIsLVal, sndBase, sndIsLVal)) + if (!tryUnifyDeclRef(constraints, unifyCtx, fstBase, fstIsLVal, sndBase, sndIsLVal)) { okay = false; } @@ -665,13 +746,15 @@ namespace Slang bool SemanticsVisitor::TryUnifyTypeParam( ConstraintSystem& constraints, - GenericTypeParamDecl* typeParamDecl, + ValUnificationContext unificationContext, + GenericTypeParamDeclBase* typeParamDecl, QualType type) { // We want to constrain the given type parameter // to equal the given type. Constraint constraint; constraint.decl = typeParamDecl; + constraint.indexInPack = unificationContext.indexInTypePack; constraint.val = type; constraint.isUsedAsLValue = type.isLeftValue; constraints.constraints.add(constraint); @@ -681,9 +764,12 @@ namespace Slang bool SemanticsVisitor::TryUnifyIntParam( ConstraintSystem& constraints, + ValUnificationContext unifyCtx, GenericValueParamDecl* paramDecl, IntVal* val) { + SLANG_UNUSED(unifyCtx); + // We only want to accumulate constraints on // the parameters of the declarations being // specialized (don't accidentially constrain @@ -704,12 +790,13 @@ namespace Slang bool SemanticsVisitor::TryUnifyIntParam( ConstraintSystem& constraints, + ValUnificationContext unifyCtx, DeclRef<VarDeclBase> const& varRef, IntVal* val) { if(auto genericValueParamRef = varRef.as<GenericValueParamDecl>()) { - return TryUnifyIntParam(constraints, genericValueParamRef.getDecl(), val); + return TryUnifyIntParam(constraints, unifyCtx, genericValueParamRef.getDecl(), val); } else { @@ -719,6 +806,7 @@ namespace Slang bool SemanticsVisitor::TryUnifyTypesByStructuralMatch( ConstraintSystem& constraints, + ValUnificationContext unifyCtx, QualType fst, QualType snd) { @@ -728,7 +816,7 @@ namespace Slang if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl())) if (typeParamDecl->parentDecl == constraints.genericDecl) - return TryUnifyTypeParam(constraints, typeParamDecl, snd); + return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, snd); if (auto sndDeclRefType = as<DeclRefType>(snd)) { @@ -736,7 +824,7 @@ namespace Slang if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl())) if (typeParamDecl->parentDecl == constraints.genericDecl) - return TryUnifyTypeParam(constraints, typeParamDecl, fst); + return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, fst); // If they refer to different declarations, we need to check if one type's super type // matches the other type, if so we can unify them. @@ -775,6 +863,7 @@ namespace Slang // to each declaration reference. if (!tryUnifyDeclRef( constraints, + unifyCtx, fstDeclRef, fst.isLeftValue, sndDeclRef, @@ -795,18 +884,46 @@ namespace Slang return false; for(Index i = 0; i < numParams; ++i) { - if(!TryUnifyTypes(constraints, fstFunType->getParamType(i), sndFunType->getParamType(i))) + if(!TryUnifyTypes(constraints, unifyCtx, fstFunType->getParamType(i), sndFunType->getParamType(i))) return false; } - return TryUnifyTypes(constraints, fstFunType->getResultType(), sndFunType->getResultType()); + return TryUnifyTypes(constraints, unifyCtx, fstFunType->getResultType(), sndFunType->getResultType()); + } + } + else if (auto expandType = as<ExpandType>(fst)) + { + if (auto sndExpandType = as<ExpandType>(snd)) + { + return TryUnifyTypes(constraints, unifyCtx, expandType->getPatternType(), sndExpandType->getPatternType()); + } + } + else if (auto eachType = as<EachType>(fst)) + { + if (auto sndEachType = as<EachType>(snd)) + { + return TryUnifyTypes(constraints, unifyCtx, eachType->getElementType(), sndEachType->getElementType()); + } + } + else if (auto typePack = as<ConcreteTypePack>(fst)) + { + if (auto sndTypePack = as<ConcreteTypePack>(snd)) + { + if (typePack->getTypeCount() != sndTypePack->getTypeCount()) + return false; + for (Index i = 0; i < typePack->getTypeCount(); ++i) + { + if (!TryUnifyTypes(constraints, unifyCtx, QualType(typePack->getElementType(i), fst.isLeftValue), QualType(sndTypePack->getElementType(i), snd.isLeftValue))) + return false; + } + return true; } } - return false; } bool SemanticsVisitor::TryUnifyConjunctionType( ConstraintSystem& constraints, + ValUnificationContext unifyCtx, QualType fst, QualType snd) { @@ -820,20 +937,23 @@ namespace Slang // if (auto fstAndType = as<AndType>(fst)) { - return TryUnifyTypes(constraints, QualType(fstAndType->getLeft(), fst.isLeftValue), snd) - && TryUnifyTypes(constraints, QualType(fstAndType->getRight(), fst.isLeftValue), snd); + return TryUnifyTypes(constraints, unifyCtx, QualType(fstAndType->getLeft(), fst.isLeftValue), snd) + && TryUnifyTypes(constraints, unifyCtx, QualType(fstAndType->getRight(), fst.isLeftValue), snd); } else if (auto sndAndType = as<AndType>(snd)) { - return TryUnifyTypes(constraints, fst, QualType(sndAndType->getLeft(), snd.isLeftValue)) - || TryUnifyTypes(constraints, fst, QualType(sndAndType->getRight(), snd.isLeftValue)); + return TryUnifyTypes(constraints, unifyCtx, fst, QualType(sndAndType->getLeft(), snd.isLeftValue)) + || TryUnifyTypes(constraints, unifyCtx, fst, QualType(sndAndType->getRight(), snd.isLeftValue)); } else return false; } - void SemanticsVisitor::maybeUnifyUnconstraintIntParam(ConstraintSystem& constraints, IntVal* param, IntVal* arg, bool paramIsLVal) + void SemanticsVisitor::maybeUnifyUnconstraintIntParam( + ConstraintSystem& constraints, ValUnificationContext unifyCtx, IntVal* param, IntVal* arg, bool paramIsLVal) { + SLANG_UNUSED(unifyCtx); + // If `param` is an unconstrained integer val param, and `arg` is a const int val, // we add a constraint to the system that `param` must be equal to `arg`. // If `param` is already constrained, ignore and do nothing. @@ -857,6 +977,7 @@ namespace Slang bool SemanticsVisitor::TryUnifyTypes( ConstraintSystem& constraints, + ValUnificationContext unifyCtx, QualType fst, QualType snd) { @@ -883,7 +1004,49 @@ namespace Slang // if (as<AndType>(fst) || as<AndType>(snd)) { - return TryUnifyConjunctionType(constraints, fst, snd); + return TryUnifyConjunctionType(constraints, unifyCtx, fst, snd); + } + + // If one of the types is a type pack, we need to recursively unify the element types. + if (auto fstTypePack = as<ConcreteTypePack>(fst)) + { + if (auto sndTypePack = as<ConcreteTypePack>(snd)) + { + if (fstTypePack->getTypeCount() != sndTypePack->getTypeCount()) + return false; + for (Index i = 0; i < fstTypePack->getTypeCount(); ++i) + { + if (!TryUnifyTypes(constraints, unifyCtx,QualType(fstTypePack->getElementType(i), fst.isLeftValue), QualType(sndTypePack->getElementType(i), snd.isLeftValue))) + return false; + } + return true; + } + else if (auto sndExpandType = as<ExpandType>(snd)) + { + for (Index i = 0; i < fstTypePack->getTypeCount(); ++i) + { + ValUnificationContext subUnifyCtx = unifyCtx; + subUnifyCtx.indexInTypePack = i; + if (!TryUnifyTypes(constraints, subUnifyCtx, QualType(fstTypePack->getElementType(i), fst.isLeftValue), QualType(sndExpandType->getPatternType(), snd.isLeftValue))) + return false; + } + return true; + } + } + + if (auto sndTypePack = as<ConcreteTypePack>(snd)) + { + if (auto fstExpandType = as<ExpandType>(fst)) + { + for (Index i = 0; i < sndTypePack->getTypeCount(); ++i) + { + ValUnificationContext subUnifyCtx = unifyCtx; + subUnifyCtx.indexInTypePack = i; + if (!TryUnifyTypes(constraints, subUnifyCtx, QualType(fstExpandType->getPatternType(), fst.isLeftValue), QualType(sndTypePack->getElementType(i), snd.isLeftValue))) + return false; + } + return true; + } } // A generic parameter type can unify with anything. @@ -897,7 +1060,13 @@ namespace Slang if (auto typeParamDecl = as<GenericTypeParamDecl>(fstDeclRef.getDecl())) { if(typeParamDecl->parentDecl == constraints.genericDecl) - return TryUnifyTypeParam(constraints, typeParamDecl, snd); + return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, snd); + } + else if (auto typePackParamDecl = as<GenericTypePackParamDecl>(fstDeclRef.getDecl())) + { + if (typePackParamDecl->parentDecl == constraints.genericDecl + && isTypePack(snd)) + return TryUnifyTypeParam(constraints, unifyCtx, typePackParamDecl, snd); } } @@ -905,15 +1074,21 @@ namespace Slang { auto sndDeclRef = sndDeclRefType->getDeclRef(); - if (auto typeParamDecl = as<GenericTypeParamDecl>(sndDeclRef.getDecl())) + if (auto typeParamDecl = as<GenericTypeParamDeclBase>(sndDeclRef.getDecl())) { if(typeParamDecl->parentDecl == constraints.genericDecl) - return TryUnifyTypeParam(constraints, typeParamDecl, fst); + return TryUnifyTypeParam(constraints, unifyCtx, typeParamDecl, fst); + } + else if (auto typePackParamDecl = as<GenericTypePackParamDecl>(sndDeclRef.getDecl())) + { + if (typePackParamDecl->parentDecl == constraints.genericDecl + && isTypePack(fst)) + return TryUnifyTypeParam(constraints, unifyCtx, typePackParamDecl, fst); } } // If we can unify the types structurally, then we are golden - if(TryUnifyTypesByStructuralMatch(constraints, fst, snd)) + if(TryUnifyTypesByStructuralMatch(constraints, unifyCtx, fst, snd)) return true; // Now we need to consider cases where coercion might @@ -930,9 +1105,10 @@ namespace Slang // However, we don't want a failed unification to fail the entire generic argument inference, // because a scalar can still be casted into a vector of any length. - maybeUnifyUnconstraintIntParam(constraints, fstVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), fst.isLeftValue); + maybeUnifyUnconstraintIntParam(constraints, unifyCtx, fstVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), fst.isLeftValue); return TryUnifyTypes( constraints, + unifyCtx, QualType(fstVectorType->getElementType(), fst.isLeftValue), QualType(sndScalarType, snd.isLeftValue)); } @@ -942,18 +1118,47 @@ namespace Slang { if(auto sndVectorType = as<VectorExpressionType>(snd)) { - maybeUnifyUnconstraintIntParam(constraints, sndVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), snd.isLeftValue); + maybeUnifyUnconstraintIntParam(constraints, unifyCtx, sndVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), snd.isLeftValue); return TryUnifyTypes( constraints, + unifyCtx, QualType(fstScalarType, fst.isLeftValue), QualType(sndVectorType->getElementType(), snd.isLeftValue)); } } if (auto fstUniformParamGroupType = as<UniformParameterGroupType>(fst)) - return TryUnifyTypes(constraints, QualType(fstUniformParamGroupType->getElementType(), fst.isLeftValue), snd); + return TryUnifyTypes(constraints, unifyCtx, QualType(fstUniformParamGroupType->getElementType(), fst.isLeftValue), snd); if (auto sndUniformParamGroupType = as<UniformParameterGroupType>(snd)) - return TryUnifyTypes(constraints, fst, QualType(sndUniformParamGroupType->getElementType(), snd.isLeftValue)); + return TryUnifyTypes(constraints, unifyCtx, fst, QualType(sndUniformParamGroupType->getElementType(), snd.isLeftValue)); + + // Each T can coerce with any DeclRefType. + if (auto eachSnd = as<EachType>(snd)) + { + if (auto innerSnd = eachSnd->getElementDeclRefType()) + { + if (auto sndTypePackParamDecl = as<GenericTypePackParamDecl>(innerSnd->getDeclRef().getDecl())) + { + if (innerSnd->getDeclRef().getDecl()->parentDecl == constraints.genericDecl) + { + return TryUnifyTypeParam(constraints, unifyCtx, sndTypePackParamDecl, fst); + } + } + } + } + if (auto eachFst = as<EachType>(fst)) + { + if (auto innerFst = eachFst->getElementDeclRefType()) + { + if (auto fstTypePackParamDecl = as<GenericTypePackParamDecl>(innerFst->getDeclRef().getDecl())) + { + if (innerFst->getDeclRef().getDecl()->parentDecl == constraints.genericDecl) + { + return TryUnifyTypeParam(constraints, unifyCtx, fstTypePackParamDecl, snd); + } + } + } + } return false; } |
