summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-overload.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-11-29 11:29:14 -0800
committerGitHub <noreply@github.com>2023-11-29 11:29:14 -0800
commit4fb3b10b81cf8c976ebd1ebb7fcde7708f022957 (patch)
tree394a08e5b744fa85ac98c0b8758e994b0aab3a34 /source/slang/slang-check-overload.cpp
parent62426e94ef11fd6baa213757f87114ec174b406e (diff)
Improve generic type argument inference. (#3370)
* Improve generic type argument inference. * Fix. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-overload.cpp')
-rw-r--r--source/slang/slang-check-overload.cpp63
1 files changed, 38 insertions, 25 deletions
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 2d7315cd2..d7d29a4e1 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1246,11 +1246,14 @@ namespace Slang
void SemanticsVisitor::AddOverloadCandidate(
OverloadResolveContext& context,
- OverloadCandidate& candidate)
+ OverloadCandidate& candidate,
+ ConversionCost baseCost)
{
// Try the candidate out, to see if it is applicable at all.
TryCheckOverloadCandidate(context, candidate);
+ candidate.conversionCostSum += baseCost;
+
// Now (potentially) add it to the set of candidate overloads to consider.
AddOverloadCandidateInner(context, candidate);
}
@@ -1258,7 +1261,8 @@ namespace Slang
void SemanticsVisitor::AddFuncOverloadCandidate(
LookupResultItem item,
DeclRef<CallableDecl> funcDeclRef,
- OverloadResolveContext& context)
+ OverloadResolveContext& context,
+ ConversionCost baseCost)
{
auto funcDecl = funcDeclRef.getDecl();
ensureDecl(funcDecl, DeclCheckState::CanUseFuncSignature);
@@ -1288,25 +1292,27 @@ namespace Slang
candidate.item = item;
candidate.resultType = getResultType(m_astBuilder, funcDeclRef);
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, baseCost);
}
void SemanticsVisitor::AddFuncOverloadCandidate(
FuncType* funcType,
- OverloadResolveContext& context)
+ OverloadResolveContext& context,
+ ConversionCost baseCost)
{
OverloadCandidate candidate;
candidate.flavor = OverloadCandidate::Flavor::Expr;
candidate.funcType = funcType;
candidate.resultType = funcType->getResultType();
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, baseCost);
}
void SemanticsVisitor::AddFuncExprOverloadCandidate(
FuncType* funcType,
OverloadResolveContext& context,
- Expr* expr)
+ Expr* expr,
+ ConversionCost baseCost)
{
SLANG_ASSERT(expr);
OverloadCandidate candidate;
@@ -1315,7 +1321,7 @@ namespace Slang
candidate.resultType = funcType->getResultType();
candidate.exprVal = expr;
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, baseCost);
}
void SemanticsVisitor::AddCtorOverloadCandidate(
@@ -1323,7 +1329,8 @@ namespace Slang
Type* type,
DeclRef<ConstructorDecl> ctorDeclRef,
OverloadResolveContext& context,
- Type* resultType)
+ Type* resultType,
+ ConversionCost baseCost)
{
SLANG_UNUSED(type)
@@ -1346,13 +1353,14 @@ namespace Slang
candidate.item = ctorItem;
candidate.resultType = resultType;
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, baseCost);
}
DeclRef<Decl> SemanticsVisitor::inferGenericArguments(
DeclRef<GenericDecl> genericDeclRef,
OverloadResolveContext& context,
ArrayView<Val*> knownGenericArgs,
+ ConversionCost& outBaseCost,
List<QualType> *innerParameterTypes)
{
// We have been asked to infer zero or more arguments to
@@ -1469,7 +1477,7 @@ namespace Slang
// so that the solver knows to accept those arguments as-is.
//
return trySolveConstraintSystem(
- &constraints, genericDeclRef, knownGenericArgs);
+ &constraints, genericDeclRef, knownGenericArgs, outBaseCost);
}
void SemanticsVisitor::AddTypeOverloadCandidates(
@@ -1517,8 +1525,10 @@ namespace Slang
auto genericDeclRef = genericItem.declRef.as<GenericDecl>();
SLANG_ASSERT(genericDeclRef);
+ ConversionCost baseCost = kConversionCost_None;
+
// Try to infer generic arguments, based on the context
- DeclRef<Decl> innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs);
+ DeclRef<Decl> innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs, baseCost);
if (innerRef)
{
@@ -1528,7 +1538,7 @@ namespace Slang
LookupResultItem innerItem;
innerItem.breadcrumbs = genericItem.breadcrumbs;
innerItem.declRef = innerRef;
- AddDeclRefOverloadCandidates(innerItem, context);
+ AddDeclRefOverloadCandidates(innerItem, context, baseCost);
}
else
{
@@ -1546,11 +1556,12 @@ namespace Slang
void SemanticsVisitor::AddDeclRefOverloadCandidates(
LookupResultItem item,
- OverloadResolveContext& context)
+ OverloadResolveContext& context,
+ ConversionCost baseCost)
{
if (auto funcDeclRef = item.declRef.as<CallableDecl>())
{
- AddFuncOverloadCandidate(item, funcDeclRef, context);
+ AddFuncOverloadCandidate(item, funcDeclRef, context, baseCost);
}
else if (auto aggTypeDeclRef = item.declRef.as<AggTypeDecl>())
{
@@ -1584,7 +1595,7 @@ namespace Slang
const auto type = localDeclRef.getDecl()->getType();
// We can only add overload candidates if this is known to be a function
if(const auto funType = as<FuncType>(type))
- AddFuncExprOverloadCandidate(funType, context, context.originalExpr->functionExpr);
+ AddFuncExprOverloadCandidate(funType, context, context.originalExpr->functionExpr, baseCost);
else
return;
}
@@ -1603,12 +1614,12 @@ namespace Slang
{
for(auto item : result.items)
{
- AddDeclRefOverloadCandidates(item, context);
+ AddDeclRefOverloadCandidates(item, context, kConversionCost_None);
}
}
else
{
- AddDeclRefOverloadCandidates(result.item, context);
+ AddDeclRefOverloadCandidates(result.item, context, kConversionCost_None);
}
}
@@ -1633,17 +1644,17 @@ namespace Slang
// The expression directly referenced a declaration,
// so we can use that declaration directly to look
// for anything applicable.
- AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context);
+ AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context, kConversionCost_None);
}
else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr))
{
// The expression is the result of a higher order function application.
- AddHigherOrderOverloadCandidates(higherOrderExpr, context);
+ AddHigherOrderOverloadCandidates(higherOrderExpr, context, kConversionCost_None);
}
else if (auto funcType = as<FuncType>(funcExprType))
{
// TODO(tfoley): deprecate this path...
- AddFuncOverloadCandidate(funcType, context);
+ AddFuncOverloadCandidate(funcType, context, kConversionCost_None);
}
else if (auto overloadedExpr = as<OverloadedExpr>(funcExpr))
{
@@ -1683,7 +1694,8 @@ namespace Slang
void SemanticsVisitor::AddHigherOrderOverloadCandidates(
Expr* funcExpr,
- OverloadResolveContext& context)
+ OverloadResolveContext& context,
+ ConversionCost baseCost)
{
// Lookup the higher order function and process types accordingly. In the future,
// if there are enough varieties, we can have dispatch logic instead of an
@@ -1705,7 +1717,7 @@ namespace Slang
candidate.resultType = candidate.funcType->getResultType();
candidate.item = LookupResultItem(baseFuncDeclRef);
candidate.exprVal = expr;
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, baseCost);
}
else if (auto baseFuncGenericDeclRef = funcDeclRefExpr->declRef.as<GenericDecl>())
{
@@ -1721,10 +1733,12 @@ namespace Slang
// Try to infer generic arguments, based on the updated context.
OverloadResolveContext subContext = context;
+ ConversionCost baseCost1 = kConversionCost_None;
DeclRef<Decl> innerRef = inferGenericArguments(
baseFuncGenericDeclRef,
context,
ArrayView<Val*>(),
+ baseCost1,
&paramTypes);
if (!innerRef)
@@ -1762,7 +1776,7 @@ namespace Slang
}
candidate.exprVal = expr;
expr->type.type = diffFuncType;
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, baseCost + baseCost1);
}
else
{
@@ -1868,7 +1882,6 @@ namespace Slang
context.originalExpr = expr;
context.funcLoc = funcExpr->loc;
-
context.argCount = expr->arguments.getCount();
context.args = expr->arguments.getBuffer();
context.loc = expr->loc;
@@ -2039,7 +2052,7 @@ namespace Slang
candidate.item = baseItem;
candidate.resultType = nullptr;
- AddOverloadCandidate(context, candidate);
+ AddOverloadCandidate(context, candidate, kConversionCost_None);
}
}