diff options
| author | Yong He <yonghe@outlook.com> | 2023-11-29 11:29:14 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-11-29 11:29:14 -0800 |
| commit | 4fb3b10b81cf8c976ebd1ebb7fcde7708f022957 (patch) | |
| tree | 394a08e5b744fa85ac98c0b8758e994b0aab3a34 /source/slang/slang-check-constraint.cpp | |
| parent | 62426e94ef11fd6baa213757f87114ec174b406e (diff) | |
Improve generic type argument inference. (#3370)
* Improve generic type argument inference.
* Fix.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-constraint.cpp')
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 65 |
1 files changed, 57 insertions, 8 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 8fd4061db..97dbbcfa3 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -261,8 +261,11 @@ namespace Slang DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem( ConstraintSystem* system, DeclRef<GenericDecl> genericDeclRef, - ArrayView<Val*> knownGenericArgs) + ArrayView<Val*> knownGenericArgs, + ConversionCost& outBaseCost) { + outBaseCost = kConversionCost_None; + // For now the "solver" is going to be ridiculously simplistic. // The generic itself will have some constraints, and for now we add these @@ -340,6 +343,8 @@ namespace Slang } QualType type; + bool typeConstraintOptional = true; + for (auto& c : system->constraints) { if (c.decl != typeParam.getDecl()) @@ -348,11 +353,12 @@ namespace Slang auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue); SLANG_RELEASE_ASSERT(cType); - if (!type) + if (!type || (typeConstraintOptional && !c.isOptional)) { type = cType; + typeConstraintOptional = c.isOptional; } - else + else if (!typeConstraintOptional) { auto joinType = TryJoinTypes(type, cType); if (!joinType) @@ -397,6 +403,7 @@ namespace Slang // TODO(tfoley): figure out how this needs to interact with // compile-time integers that aren't just constants... IntVal* val = nullptr; + bool valOptional = true; for (auto& c : system->constraints) { if (c.decl != valParam.getDecl()) @@ -405,13 +412,14 @@ namespace Slang auto cVal = as<IntVal>(c.val); SLANG_RELEASE_ASSERT(cVal); - if (!val) + if (!val || (valOptional && !c.isOptional)) { val = cVal; + valOptional = c.isOptional; } else { - if(!val->equals(cVal)) + if(!valOptional && !val->equals(cVal)) { // failure! return DeclRef<Decl>(); @@ -450,6 +458,8 @@ namespace Slang // search for a conformance `Robin : ISidekick`, which involved // apply the substitutions we already know... + HashSet<Decl*> constrainedGenericParams; + for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() ) { DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef( @@ -458,6 +468,10 @@ namespace Slang // Extract the (substituted) sub- and super-type from the constraint. auto sub = getSub(m_astBuilder, constraintDeclRef); auto sup = getSup(m_astBuilder, constraintDeclRef); + + // Mark sub type as constrained. + if (auto subDeclRefType = as<DeclRefType>(constraintDeclRef.getDecl()->sub.type)) + constrainedGenericParams.add(subDeclRefType->getDeclRef().getDecl()); if (sub->equals(sup)) { @@ -475,6 +489,7 @@ namespace Slang { // We found a witness, so it will become an (implicit) argument. args.add(subTypeWitness); + outBaseCost += subTypeWitness->getOverloadResolutionCost(); } else { @@ -489,6 +504,13 @@ namespace Slang // system as being solved now, as a result of the witness we found. } + // Add a flat cost to all unconstrained generic params. + for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>()) + { + if (!constrainedGenericParams.contains(typeParamDecl)) + outBaseCost += kConversionCost_UnconstraintGenericParam; + } + // Make sure we haven't constructed any spurious constraints // that we aren't able to satisfy: for (auto c : system->constraints) @@ -810,6 +832,29 @@ namespace Slang return false; } + void SemanticsVisitor::maybeUnifyUnconstraintIntParam(ConstraintSystem& constraints, IntVal* param, IntVal* arg, bool paramIsLVal) + { + // If `param` is an unconstrained integer val param, and `arg` is a const int val, + // we add a constraint to the system that `param` must be equal to `arg`. + // If `param` is already constrained, ignore and do nothing. + if (auto typeCastParam = as<TypeCastIntVal>(param)) + { + param = as<IntVal>(typeCastParam->getBase()); + } + auto intParam = as<GenericParamIntVal>(param); + if (!intParam) + return; + for (auto c : constraints.constraints) + if (c.decl == intParam->getDeclRef().getDecl()) + return; + Constraint c; + c.decl = intParam->getDeclRef().getDecl(); + c.isUsedAsLValue = paramIsLVal; + c.val = arg; + c.isOptional = true; + constraints.constraints.add(c); + } + bool SemanticsVisitor::TryUnifyTypes( ConstraintSystem& constraints, QualType fst, @@ -880,6 +925,12 @@ namespace Slang { if(auto sndScalarType = as<BasicExpressionType>(snd)) { + // Try unify the vector count param. In case the vector count is defined by a generic value + // parameter, we want to be able to infer that parameter should be 1. + // However, we don't want a failed unification to fail the entire generic argument inference, + // because a scalar can still be casted into a vector of any length. + + maybeUnifyUnconstraintIntParam(constraints, fstVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), fst.isLeftValue); return TryUnifyTypes( constraints, QualType(fstVectorType->getElementType(), fst.isLeftValue), @@ -891,15 +942,13 @@ namespace Slang { if(auto sndVectorType = as<VectorExpressionType>(snd)) { + maybeUnifyUnconstraintIntParam(constraints, sndVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), snd.isLeftValue); return TryUnifyTypes( constraints, QualType(fstScalarType, fst.isLeftValue), QualType(sndVectorType->getElementType(), snd.isLeftValue)); } } - - // TODO: the same thing for vectors... - return false; } |
