summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-support-types.h3
-rw-r--r--source/slang/slang-ast-val.cpp29
-rw-r--r--source/slang/slang-ast-val.h11
-rw-r--r--source/slang/slang-check-constraint.cpp65
-rw-r--r--source/slang/slang-check-decl.cpp4
-rw-r--r--source/slang/slang-check-impl.h50
-rw-r--r--source/slang/slang-check-overload.cpp63
-rw-r--r--source/slang/slang-check-resolve-val.cpp10
8 files changed, 192 insertions, 43 deletions
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 93c53a975..5da19f377 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -80,6 +80,9 @@ namespace Slang
// No conversion at all
kConversionCost_None = 0,
+ kConversionCost_GenericParamUpcast = 1,
+ kConversionCost_UnconstraintGenericParam = 20,
+
// Convert between matrices of different layout
kConversionCost_MatrixLayout = 5,
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index e860e1ec6..d1408a3fc 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -286,6 +286,11 @@ Val* DeclaredSubtypeWitness::_resolveImplOverride()
return this;
}
+ConversionCost DeclaredSubtypeWitness::_getOverloadResolutionCostOverride()
+{
+ return kConversionCost_GenericParamUpcast;
+}
+
Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff)
{
if (auto genConstraintDeclRef = getDeclRef().as<GenericTypeConstraintDecl>())
@@ -431,6 +436,11 @@ Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, S
return astBuilder->getTransitiveSubtypeWitness(substSubToMid, substMidToSup);
}
+ConversionCost TransitiveSubtypeWitness::_getOverloadResolutionCostOverride()
+{
+ return getSubToMid()->getOverloadResolutionCost() + getMidToSup()->getOverloadResolutionCost();
+}
+
void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out)
{
// Note: we only print the constituent
@@ -471,6 +481,17 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a
substSub, substSup, substWitness, getIndexInConjunction());
}
+ConversionCost ExtractFromConjunctionSubtypeWitness::_getOverloadResolutionCostOverride()
+{
+ auto witness = as<ConjunctionSubtypeWitness>(getConjunctionWitness());
+ if (!witness)
+ return kConversionCost_None;
+ auto index = getIndexInConjunction();
+ if (index < witness->getComponentCount())
+ return witness->getComponentWitness(index)->getOverloadResolutionCost();
+ return kConversionCost_None;
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void ExtractExistentialSubtypeWitness::_toTextOverride(StringBuilder& out)
@@ -541,6 +562,14 @@ Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder,
return result;
}
+ConversionCost ConjunctionSubtypeWitness::_getOverloadResolutionCostOverride()
+{
+ ConversionCost result = kConversionCost_None;
+ for (Index i = 0; i < getComponentCount(); i++)
+ result += getComponentWitness(i)->getOverloadResolutionCost();
+ return result;
+}
+
void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out)
{
out << "ExtractFromConjunctionSubtypeWitness(";
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index c45c42e02..f85a76187 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -457,6 +457,9 @@ class SubtypeWitness : public Witness
Type* getSub() { return as<Type>(getOperand(0)); }
Type* getSup() { return as<Type>(getOperand(1)); }
+
+ ConversionCost _getOverloadResolutionCostOverride();
+ ConversionCost getOverloadResolutionCost();
};
class TypeEqualityWitness : public SubtypeWitness
@@ -493,6 +496,8 @@ class DeclaredSubtypeWitness : public SubtypeWitness
{
setOperands(inSub, inSup, inDeclRef);
}
+
+ ConversionCost _getOverloadResolutionCostOverride();
};
// A witness that `sub : sup` because `sub : mid` and `mid : sup`
@@ -520,6 +525,8 @@ class TransitiveSubtypeWitness : public SubtypeWitness
{
setOperands(subType, supType, inSubToMid, inMidToSup);
}
+
+ ConversionCost _getOverloadResolutionCostOverride();
};
// A witness that `sub : sup` because `sub` was wrapped into
@@ -580,6 +587,8 @@ class ConjunctionSubtypeWitness : public SubtypeWitness
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+
+ ConversionCost _getOverloadResolutionCostOverride();
};
/// A witness that `T <: L` or `T <: R` because `T <: L&R`
@@ -609,6 +618,8 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+
+ ConversionCost _getOverloadResolutionCostOverride();
};
/// A value that represents a modifier attached to some other value
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 8fd4061db..97dbbcfa3 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -261,8 +261,11 @@ namespace Slang
DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
ConstraintSystem* system,
DeclRef<GenericDecl> genericDeclRef,
- ArrayView<Val*> knownGenericArgs)
+ ArrayView<Val*> knownGenericArgs,
+ ConversionCost& outBaseCost)
{
+ outBaseCost = kConversionCost_None;
+
// For now the "solver" is going to be ridiculously simplistic.
// The generic itself will have some constraints, and for now we add these
@@ -340,6 +343,8 @@ namespace Slang
}
QualType type;
+ bool typeConstraintOptional = true;
+
for (auto& c : system->constraints)
{
if (c.decl != typeParam.getDecl())
@@ -348,11 +353,12 @@ namespace Slang
auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue);
SLANG_RELEASE_ASSERT(cType);
- if (!type)
+ if (!type || (typeConstraintOptional && !c.isOptional))
{
type = cType;
+ typeConstraintOptional = c.isOptional;
}
- else
+ else if (!typeConstraintOptional)
{
auto joinType = TryJoinTypes(type, cType);
if (!joinType)
@@ -397,6 +403,7 @@ namespace Slang
// TODO(tfoley): figure out how this needs to interact with
// compile-time integers that aren't just constants...
IntVal* val = nullptr;
+ bool valOptional = true;
for (auto& c : system->constraints)
{
if (c.decl != valParam.getDecl())
@@ -405,13 +412,14 @@ namespace Slang
auto cVal = as<IntVal>(c.val);
SLANG_RELEASE_ASSERT(cVal);
- if (!val)
+ if (!val || (valOptional && !c.isOptional))
{
val = cVal;
+ valOptional = c.isOptional;
}
else
{
- if(!val->equals(cVal))
+ if(!valOptional && !val->equals(cVal))
{
// failure!
return DeclRef<Decl>();
@@ -450,6 +458,8 @@ namespace Slang
// search for a conformance `Robin : ISidekick`, which involved
// apply the substitutions we already know...
+ HashSet<Decl*> constrainedGenericParams;
+
for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
{
DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef(
@@ -458,6 +468,10 @@ namespace Slang
// Extract the (substituted) sub- and super-type from the constraint.
auto sub = getSub(m_astBuilder, constraintDeclRef);
auto sup = getSup(m_astBuilder, constraintDeclRef);
+
+ // Mark sub type as constrained.
+ if (auto subDeclRefType = as<DeclRefType>(constraintDeclRef.getDecl()->sub.type))
+ constrainedGenericParams.add(subDeclRefType->getDeclRef().getDecl());
if (sub->equals(sup))
{
@@ -475,6 +489,7 @@ namespace Slang
{
// We found a witness, so it will become an (implicit) argument.
args.add(subTypeWitness);
+ outBaseCost += subTypeWitness->getOverloadResolutionCost();
}
else
{
@@ -489,6 +504,13 @@ namespace Slang
// system as being solved now, as a result of the witness we found.
}
+ // Add a flat cost to all unconstrained generic params.
+ for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>())
+ {
+ if (!constrainedGenericParams.contains(typeParamDecl))
+ outBaseCost += kConversionCost_UnconstraintGenericParam;
+ }
+
// Make sure we haven't constructed any spurious constraints
// that we aren't able to satisfy:
for (auto c : system->constraints)
@@ -810,6 +832,29 @@ namespace Slang
return false;
}
+ void SemanticsVisitor::maybeUnifyUnconstraintIntParam(ConstraintSystem& constraints, IntVal* param, IntVal* arg, bool paramIsLVal)
+ {
+ // If `param` is an unconstrained integer val param, and `arg` is a const int val,
+ // we add a constraint to the system that `param` must be equal to `arg`.
+ // If `param` is already constrained, ignore and do nothing.
+ if (auto typeCastParam = as<TypeCastIntVal>(param))
+ {
+ param = as<IntVal>(typeCastParam->getBase());
+ }
+ auto intParam = as<GenericParamIntVal>(param);
+ if (!intParam)
+ return;
+ for (auto c : constraints.constraints)
+ if (c.decl == intParam->getDeclRef().getDecl())
+ return;
+ Constraint c;
+ c.decl = intParam->getDeclRef().getDecl();
+ c.isUsedAsLValue = paramIsLVal;
+ c.val = arg;
+ c.isOptional = true;
+ constraints.constraints.add(c);
+ }
+
bool SemanticsVisitor::TryUnifyTypes(
ConstraintSystem& constraints,
QualType fst,
@@ -880,6 +925,12 @@ namespace Slang
{
if(auto sndScalarType = as<BasicExpressionType>(snd))
{
+ // Try unify the vector count param. In case the vector count is defined by a generic value
+ // parameter, we want to be able to infer that parameter should be 1.
+ // However, we don't want a failed unification to fail the entire generic argument inference,
+ // because a scalar can still be casted into a vector of any length.
+
+ maybeUnifyUnconstraintIntParam(constraints, fstVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), fst.isLeftValue);
return TryUnifyTypes(
constraints,
QualType(fstVectorType->getElementType(), fst.isLeftValue),
@@ -891,15 +942,13 @@ namespace Slang
{
if(auto sndVectorType = as<VectorExpressionType>(snd))
{
+ maybeUnifyUnconstraintIntParam(constraints, sndVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), snd.isLeftValue);
return TryUnifyTypes(
constraints,
QualType(fstScalarType, fst.isLeftValue),
QualType(sndVectorType->getElementType(), snd.isLeftValue));
}
}
-
- // TODO: the same thing for vectors...
-
return false;
}
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 7c36bdd5f..98a2b18a1 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -6630,7 +6630,9 @@ namespace Slang
if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type))
return DeclRef<ExtensionDecl>();
- auto solvedDeclRef = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView<Val*>());
+
+ ConversionCost baseCost;
+ auto solvedDeclRef = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView<Val*>(), baseCost);
if (!solvedDeclRef)
{
return DeclRef<ExtensionDecl>();
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 5b67dc413..31be012d3 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1827,6 +1827,10 @@ namespace Slang
Val* val = nullptr; // the value to which we are constraining it
bool isUsedAsLValue = false; // If this constraint is for a type parameter, is the type used in an l-value parameter?
bool satisfied = false; // Has this constraint been met?
+
+ // Is this constraint optional? An optional constraint provides a hint value to a parameter
+ // if it is otherwise unconstrained, but doesn't take precedence over a constraint that is not optional.
+ bool isOptional = false;
};
// A collection of constraints that will need to be satisfied (solved)
@@ -1944,7 +1948,8 @@ namespace Slang
DeclRef<Decl> trySolveConstraintSystem(
ConstraintSystem* system,
DeclRef<GenericDecl> genericDeclRef,
- ArrayView<Val*> knownGenericArgs);
+ ArrayView<Val*> knownGenericArgs,
+ ConversionCost& outBaseCost);
// State related to overload resolution for a call
@@ -2120,25 +2125,30 @@ namespace Slang
void AddOverloadCandidate(
OverloadResolveContext& context,
- OverloadCandidate& candidate);
+ OverloadCandidate& candidate,
+ ConversionCost baseCost);
void AddHigherOrderOverloadCandidates(
Expr* funcExpr,
- OverloadResolveContext& context);
+ OverloadResolveContext& context,
+ ConversionCost baseCost);
void AddFuncOverloadCandidate(
LookupResultItem item,
DeclRef<CallableDecl> funcDeclRef,
- OverloadResolveContext& context);
+ OverloadResolveContext& context,
+ ConversionCost baseCost);
void AddFuncOverloadCandidate(
FuncType* /*funcType*/,
- OverloadResolveContext& /*context*/);
+ OverloadResolveContext& /*context*/,
+ ConversionCost baseCost);
void AddFuncExprOverloadCandidate(
FuncType* funcType,
OverloadResolveContext& context,
- Expr* expr);
+ Expr* expr,
+ ConversionCost baseCost);
// Add a candidate callee for overload resolution, based on
// calling a particular `ConstructorDecl`.
@@ -2147,7 +2157,8 @@ namespace Slang
Type* type,
DeclRef<ConstructorDecl> ctorDeclRef,
OverloadResolveContext& context,
- Type* resultType);
+ Type* resultType,
+ ConversionCost baseCost);
// If the given declaration has generic parameters, then
// return the corresponding `GenericDecl` that holds the
@@ -2216,6 +2227,12 @@ namespace Slang
QualType fst,
QualType snd);
+ void maybeUnifyUnconstraintIntParam(
+ ConstraintSystem& constraints,
+ IntVal* param,
+ IntVal* arg,
+ bool paramIsLVal);
+
// Is the candidate extension declaration actually applicable to the given type
DeclRef<ExtensionDecl> applyExtensionToType(
ExtensionDecl* extDecl,
@@ -2226,12 +2243,26 @@ namespace Slang
// arguments to form a `DeclRef` to the inner declaration
// that could be applicable in the context of the given
// overloaded call.
+ // Also computes a `baseCost` for the inferred arguments,
+ // so that we can prefer a more specialized generic candidate
+ // when there is ambiguity. For example, given
+ // ```
+ // interface IBase;
+ // interface IDerived : IBase;
+ // struct Derived : IDerived {}
+ // void f1<T:IBase>(T b)
+ // void f2<T:IDerived>(T b);
+ // ```
+ // We will prefer f2 when seeing f(Derived()), because it takes
+ // less steps to upcast `Derived` to `IDerived` than it does
+ // to `IBase`.
//
DeclRef<Decl> inferGenericArguments(
DeclRef<GenericDecl> genericDeclRef,
OverloadResolveContext& context,
ArrayView<Val*> knownGenericArgs,
- List<QualType> *innerParameterTypes = nullptr);
+ ConversionCost &outBaseCost,
+ List<QualType> *innerParameterTypes = nullptr);
void AddTypeOverloadCandidates(
Type* type,
@@ -2239,7 +2270,8 @@ namespace Slang
void AddDeclRefOverloadCandidates(
LookupResultItem item,
- OverloadResolveContext& context);
+ OverloadResolveContext& context,
+ ConversionCost baseCost);
void AddOverloadCandidates(
LookupResult const& result,
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 2d7315cd2..d7d29a4e1 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1246,11 +1246,14 @@ namespace Slang
void SemanticsVisitor::AddOverloadCandidate(
OverloadResolveContext& context,
- OverloadCandidate& candidate)
+ OverloadCandidate& candidate,
+ ConversionCost baseCost)
{
// Try the candidate out, to see if it is applicable at all.
TryCheckOverloadCandidate(context, candidate);
+ candidate.conversionCostSum += baseCost;
+
// Now (potentially) add it to the set of candidate overloads to consider.
AddOverloadCandidateInner(context, candidate);
}
@@ -1258,7 +1261,8 @@ namespace Slang
void SemanticsVisitor::AddFuncOverloadCandidate(
LookupResultItem item,
DeclRef<CallableDecl> funcDeclRef,
- OverloadResolveContext& context)
+ OverloadResolveContext& context,
+ ConversionCost baseCost)
{
auto funcDecl = funcDeclRef.getDecl();
ensureDecl(funcDecl, DeclCheckState::CanUseFuncSignature);
@@ -1288,25 +1292,27 @@ namespace Slang
candidate.item = item;
candidate.resultType = getResultType(m_astBuilder, funcDeclRef);
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, baseCost);
}
void SemanticsVisitor::AddFuncOverloadCandidate(
FuncType* funcType,
- OverloadResolveContext& context)
+ OverloadResolveContext& context,
+ ConversionCost baseCost)
{
OverloadCandidate candidate;
candidate.flavor = OverloadCandidate::Flavor::Expr;
candidate.funcType = funcType;
candidate.resultType = funcType->getResultType();
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, baseCost);
}
void SemanticsVisitor::AddFuncExprOverloadCandidate(
FuncType* funcType,
OverloadResolveContext& context,
- Expr* expr)
+ Expr* expr,
+ ConversionCost baseCost)
{
SLANG_ASSERT(expr);
OverloadCandidate candidate;
@@ -1315,7 +1321,7 @@ namespace Slang
candidate.resultType = funcType->getResultType();
candidate.exprVal = expr;
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, baseCost);
}
void SemanticsVisitor::AddCtorOverloadCandidate(
@@ -1323,7 +1329,8 @@ namespace Slang
Type* type,
DeclRef<ConstructorDecl> ctorDeclRef,
OverloadResolveContext& context,
- Type* resultType)
+ Type* resultType,
+ ConversionCost baseCost)
{
SLANG_UNUSED(type)
@@ -1346,13 +1353,14 @@ namespace Slang
candidate.item = ctorItem;
candidate.resultType = resultType;
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, baseCost);
}
DeclRef<Decl> SemanticsVisitor::inferGenericArguments(
DeclRef<GenericDecl> genericDeclRef,
OverloadResolveContext& context,
ArrayView<Val*> knownGenericArgs,
+ ConversionCost& outBaseCost,
List<QualType> *innerParameterTypes)
{
// We have been asked to infer zero or more arguments to
@@ -1469,7 +1477,7 @@ namespace Slang
// so that the solver knows to accept those arguments as-is.
//
return trySolveConstraintSystem(
- &constraints, genericDeclRef, knownGenericArgs);
+ &constraints, genericDeclRef, knownGenericArgs, outBaseCost);
}
void SemanticsVisitor::AddTypeOverloadCandidates(
@@ -1517,8 +1525,10 @@ namespace Slang
auto genericDeclRef = genericItem.declRef.as<GenericDecl>();
SLANG_ASSERT(genericDeclRef);
+ ConversionCost baseCost = kConversionCost_None;
+
// Try to infer generic arguments, based on the context
- DeclRef<Decl> innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs);
+ DeclRef<Decl> innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs, baseCost);
if (innerRef)
{
@@ -1528,7 +1538,7 @@ namespace Slang
LookupResultItem innerItem;
innerItem.breadcrumbs = genericItem.breadcrumbs;
innerItem.declRef = innerRef;
- AddDeclRefOverloadCandidates(innerItem, context);
+ AddDeclRefOverloadCandidates(innerItem, context, baseCost);
}
else
{
@@ -1546,11 +1556,12 @@ namespace Slang
void SemanticsVisitor::AddDeclRefOverloadCandidates(
LookupResultItem item,
- OverloadResolveContext& context)
+ OverloadResolveContext& context,
+ ConversionCost baseCost)
{
if (auto funcDeclRef = item.declRef.as<CallableDecl>())
{
- AddFuncOverloadCandidate(item, funcDeclRef, context);
+ AddFuncOverloadCandidate(item, funcDeclRef, context, baseCost);
}
else if (auto aggTypeDeclRef = item.declRef.as<AggTypeDecl>())
{
@@ -1584,7 +1595,7 @@ namespace Slang
const auto type = localDeclRef.getDecl()->getType();
// We can only add overload candidates if this is known to be a function
if(const auto funType = as<FuncType>(type))
- AddFuncExprOverloadCandidate(funType, context, context.originalExpr->functionExpr);
+ AddFuncExprOverloadCandidate(funType, context, context.originalExpr->functionExpr, baseCost);
else
return;
}
@@ -1603,12 +1614,12 @@ namespace Slang
{
for(auto item : result.items)
{
- AddDeclRefOverloadCandidates(item, context);
+ AddDeclRefOverloadCandidates(item, context, kConversionCost_None);
}
}
else
{
- AddDeclRefOverloadCandidates(result.item, context);
+ AddDeclRefOverloadCandidates(result.item, context, kConversionCost_None);
}
}
@@ -1633,17 +1644,17 @@ namespace Slang
// The expression directly referenced a declaration,
// so we can use that declaration directly to look
// for anything applicable.
- AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context);
+ AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context, kConversionCost_None);
}
else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr))
{
// The expression is the result of a higher order function application.
- AddHigherOrderOverloadCandidates(higherOrderExpr, context);
+ AddHigherOrderOverloadCandidates(higherOrderExpr, context, kConversionCost_None);
}
else if (auto funcType = as<FuncType>(funcExprType))
{
// TODO(tfoley): deprecate this path...
- AddFuncOverloadCandidate(funcType, context);
+ AddFuncOverloadCandidate(funcType, context, kConversionCost_None);
}
else if (auto overloadedExpr = as<OverloadedExpr>(funcExpr))
{
@@ -1683,7 +1694,8 @@ namespace Slang
void SemanticsVisitor::AddHigherOrderOverloadCandidates(
Expr* funcExpr,
- OverloadResolveContext& context)
+ OverloadResolveContext& context,
+ ConversionCost baseCost)
{
// Lookup the higher order function and process types accordingly. In the future,
// if there are enough varieties, we can have dispatch logic instead of an
@@ -1705,7 +1717,7 @@ namespace Slang
candidate.resultType = candidate.funcType->getResultType();
candidate.item = LookupResultItem(baseFuncDeclRef);
candidate.exprVal = expr;
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, baseCost);
}
else if (auto baseFuncGenericDeclRef = funcDeclRefExpr->declRef.as<GenericDecl>())
{
@@ -1721,10 +1733,12 @@ namespace Slang
// Try to infer generic arguments, based on the updated context.
OverloadResolveContext subContext = context;
+ ConversionCost baseCost1 = kConversionCost_None;
DeclRef<Decl> innerRef = inferGenericArguments(
baseFuncGenericDeclRef,
context,
ArrayView<Val*>(),
+ baseCost1,
&paramTypes);
if (!innerRef)
@@ -1762,7 +1776,7 @@ namespace Slang
}
candidate.exprVal = expr;
expr->type.type = diffFuncType;
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, baseCost + baseCost1);
}
else
{
@@ -1868,7 +1882,6 @@ namespace Slang
context.originalExpr = expr;
context.funcLoc = funcExpr->loc;
-
context.argCount = expr->arguments.getCount();
context.args = expr->arguments.getBuffer();
context.loc = expr->loc;
@@ -2039,7 +2052,7 @@ namespace Slang
candidate.item = baseItem;
candidate.resultType = nullptr;
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, kConversionCost_None);
}
}
diff --git a/source/slang/slang-check-resolve-val.cpp b/source/slang/slang-check-resolve-val.cpp
index 91722f82c..7cd78a1bf 100644
--- a/source/slang/slang-check-resolve-val.cpp
+++ b/source/slang/slang-check-resolve-val.cpp
@@ -45,4 +45,14 @@ Val* SubtypeWitness::_resolveImplOverride()
return as<SubtypeWitness>(defaultResolveImpl());
}
+ConversionCost SubtypeWitness::_getOverloadResolutionCostOverride()
+{
+ return kConversionCost_None;
+}
+
+ConversionCost SubtypeWitness::getOverloadResolutionCost()
+{
+ SLANG_AST_NODE_VIRTUAL_CALL(SubtypeWitness, getOverloadResolutionCost, ());
+}
+
}