summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-constraint.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-11-29 11:29:14 -0800
committerGitHub <noreply@github.com>2023-11-29 11:29:14 -0800
commit4fb3b10b81cf8c976ebd1ebb7fcde7708f022957 (patch)
tree394a08e5b744fa85ac98c0b8758e994b0aab3a34 /source/slang/slang-check-constraint.cpp
parent62426e94ef11fd6baa213757f87114ec174b406e (diff)
Improve generic type argument inference. (#3370)
* Improve generic type argument inference. * Fix. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-constraint.cpp')
-rw-r--r--source/slang/slang-check-constraint.cpp65
1 files changed, 57 insertions, 8 deletions
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 8fd4061db..97dbbcfa3 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -261,8 +261,11 @@ namespace Slang
DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
ConstraintSystem* system,
DeclRef<GenericDecl> genericDeclRef,
- ArrayView<Val*> knownGenericArgs)
+ ArrayView<Val*> knownGenericArgs,
+ ConversionCost& outBaseCost)
{
+ outBaseCost = kConversionCost_None;
+
// For now the "solver" is going to be ridiculously simplistic.
// The generic itself will have some constraints, and for now we add these
@@ -340,6 +343,8 @@ namespace Slang
}
QualType type;
+ bool typeConstraintOptional = true;
+
for (auto& c : system->constraints)
{
if (c.decl != typeParam.getDecl())
@@ -348,11 +353,12 @@ namespace Slang
auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue);
SLANG_RELEASE_ASSERT(cType);
- if (!type)
+ if (!type || (typeConstraintOptional && !c.isOptional))
{
type = cType;
+ typeConstraintOptional = c.isOptional;
}
- else
+ else if (!typeConstraintOptional)
{
auto joinType = TryJoinTypes(type, cType);
if (!joinType)
@@ -397,6 +403,7 @@ namespace Slang
// TODO(tfoley): figure out how this needs to interact with
// compile-time integers that aren't just constants...
IntVal* val = nullptr;
+ bool valOptional = true;
for (auto& c : system->constraints)
{
if (c.decl != valParam.getDecl())
@@ -405,13 +412,14 @@ namespace Slang
auto cVal = as<IntVal>(c.val);
SLANG_RELEASE_ASSERT(cVal);
- if (!val)
+ if (!val || (valOptional && !c.isOptional))
{
val = cVal;
+ valOptional = c.isOptional;
}
else
{
- if(!val->equals(cVal))
+ if(!valOptional && !val->equals(cVal))
{
// failure!
return DeclRef<Decl>();
@@ -450,6 +458,8 @@ namespace Slang
// search for a conformance `Robin : ISidekick`, which involved
// apply the substitutions we already know...
+ HashSet<Decl*> constrainedGenericParams;
+
for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
{
DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef(
@@ -458,6 +468,10 @@ namespace Slang
// Extract the (substituted) sub- and super-type from the constraint.
auto sub = getSub(m_astBuilder, constraintDeclRef);
auto sup = getSup(m_astBuilder, constraintDeclRef);
+
+ // Mark sub type as constrained.
+ if (auto subDeclRefType = as<DeclRefType>(constraintDeclRef.getDecl()->sub.type))
+ constrainedGenericParams.add(subDeclRefType->getDeclRef().getDecl());
if (sub->equals(sup))
{
@@ -475,6 +489,7 @@ namespace Slang
{
// We found a witness, so it will become an (implicit) argument.
args.add(subTypeWitness);
+ outBaseCost += subTypeWitness->getOverloadResolutionCost();
}
else
{
@@ -489,6 +504,13 @@ namespace Slang
// system as being solved now, as a result of the witness we found.
}
+ // Add a flat cost to all unconstrained generic params.
+ for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>())
+ {
+ if (!constrainedGenericParams.contains(typeParamDecl))
+ outBaseCost += kConversionCost_UnconstraintGenericParam;
+ }
+
// Make sure we haven't constructed any spurious constraints
// that we aren't able to satisfy:
for (auto c : system->constraints)
@@ -810,6 +832,29 @@ namespace Slang
return false;
}
+ void SemanticsVisitor::maybeUnifyUnconstraintIntParam(ConstraintSystem& constraints, IntVal* param, IntVal* arg, bool paramIsLVal)
+ {
+ // If `param` is an unconstrained integer val param, and `arg` is a const int val,
+ // we add a constraint to the system that `param` must be equal to `arg`.
+ // If `param` is already constrained, ignore and do nothing.
+ if (auto typeCastParam = as<TypeCastIntVal>(param))
+ {
+ param = as<IntVal>(typeCastParam->getBase());
+ }
+ auto intParam = as<GenericParamIntVal>(param);
+ if (!intParam)
+ return;
+ for (auto c : constraints.constraints)
+ if (c.decl == intParam->getDeclRef().getDecl())
+ return;
+ Constraint c;
+ c.decl = intParam->getDeclRef().getDecl();
+ c.isUsedAsLValue = paramIsLVal;
+ c.val = arg;
+ c.isOptional = true;
+ constraints.constraints.add(c);
+ }
+
bool SemanticsVisitor::TryUnifyTypes(
ConstraintSystem& constraints,
QualType fst,
@@ -880,6 +925,12 @@ namespace Slang
{
if(auto sndScalarType = as<BasicExpressionType>(snd))
{
+ // Try unify the vector count param. In case the vector count is defined by a generic value
+ // parameter, we want to be able to infer that parameter should be 1.
+ // However, we don't want a failed unification to fail the entire generic argument inference,
+ // because a scalar can still be casted into a vector of any length.
+
+ maybeUnifyUnconstraintIntParam(constraints, fstVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), fst.isLeftValue);
return TryUnifyTypes(
constraints,
QualType(fstVectorType->getElementType(), fst.isLeftValue),
@@ -891,15 +942,13 @@ namespace Slang
{
if(auto sndVectorType = as<VectorExpressionType>(snd))
{
+ maybeUnifyUnconstraintIntParam(constraints, sndVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), snd.isLeftValue);
return TryUnifyTypes(
constraints,
QualType(fstScalarType, fst.isLeftValue),
QualType(sndVectorType->getElementType(), snd.isLeftValue));
}
}
-
- // TODO: the same thing for vectors...
-
return false;
}