From 4fb3b10b81cf8c976ebd1ebb7fcde7708f022957 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 29 Nov 2023 11:29:14 -0800 Subject: Improve generic type argument inference. (#3370) * Improve generic type argument inference. * Fix. * Fix. --------- Co-authored-by: Yong He --- source/slang/slang-check-overload.cpp | 63 +++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 25 deletions(-) (limited to 'source/slang/slang-check-overload.cpp') diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 2d7315cd2..d7d29a4e1 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1246,11 +1246,14 @@ namespace Slang void SemanticsVisitor::AddOverloadCandidate( OverloadResolveContext& context, - OverloadCandidate& candidate) + OverloadCandidate& candidate, + ConversionCost baseCost) { // Try the candidate out, to see if it is applicable at all. TryCheckOverloadCandidate(context, candidate); + candidate.conversionCostSum += baseCost; + // Now (potentially) add it to the set of candidate overloads to consider. AddOverloadCandidateInner(context, candidate); } @@ -1258,7 +1261,8 @@ namespace Slang void SemanticsVisitor::AddFuncOverloadCandidate( LookupResultItem item, DeclRef funcDeclRef, - OverloadResolveContext& context) + OverloadResolveContext& context, + ConversionCost baseCost) { auto funcDecl = funcDeclRef.getDecl(); ensureDecl(funcDecl, DeclCheckState::CanUseFuncSignature); @@ -1288,25 +1292,27 @@ namespace Slang candidate.item = item; candidate.resultType = getResultType(m_astBuilder, funcDeclRef); - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } void SemanticsVisitor::AddFuncOverloadCandidate( FuncType* funcType, - OverloadResolveContext& context) + OverloadResolveContext& context, + ConversionCost baseCost) { OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; candidate.funcType = funcType; candidate.resultType = funcType->getResultType(); - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } void SemanticsVisitor::AddFuncExprOverloadCandidate( FuncType* funcType, OverloadResolveContext& context, - Expr* expr) + Expr* expr, + ConversionCost baseCost) { SLANG_ASSERT(expr); OverloadCandidate candidate; @@ -1315,7 +1321,7 @@ namespace Slang candidate.resultType = funcType->getResultType(); candidate.exprVal = expr; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } void SemanticsVisitor::AddCtorOverloadCandidate( @@ -1323,7 +1329,8 @@ namespace Slang Type* type, DeclRef ctorDeclRef, OverloadResolveContext& context, - Type* resultType) + Type* resultType, + ConversionCost baseCost) { SLANG_UNUSED(type) @@ -1346,13 +1353,14 @@ namespace Slang candidate.item = ctorItem; candidate.resultType = resultType; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } DeclRef SemanticsVisitor::inferGenericArguments( DeclRef genericDeclRef, OverloadResolveContext& context, ArrayView knownGenericArgs, + ConversionCost& outBaseCost, List *innerParameterTypes) { // We have been asked to infer zero or more arguments to @@ -1469,7 +1477,7 @@ namespace Slang // so that the solver knows to accept those arguments as-is. // return trySolveConstraintSystem( - &constraints, genericDeclRef, knownGenericArgs); + &constraints, genericDeclRef, knownGenericArgs, outBaseCost); } void SemanticsVisitor::AddTypeOverloadCandidates( @@ -1517,8 +1525,10 @@ namespace Slang auto genericDeclRef = genericItem.declRef.as(); SLANG_ASSERT(genericDeclRef); + ConversionCost baseCost = kConversionCost_None; + // Try to infer generic arguments, based on the context - DeclRef innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs); + DeclRef innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs, baseCost); if (innerRef) { @@ -1528,7 +1538,7 @@ namespace Slang LookupResultItem innerItem; innerItem.breadcrumbs = genericItem.breadcrumbs; innerItem.declRef = innerRef; - AddDeclRefOverloadCandidates(innerItem, context); + AddDeclRefOverloadCandidates(innerItem, context, baseCost); } else { @@ -1546,11 +1556,12 @@ namespace Slang void SemanticsVisitor::AddDeclRefOverloadCandidates( LookupResultItem item, - OverloadResolveContext& context) + OverloadResolveContext& context, + ConversionCost baseCost) { if (auto funcDeclRef = item.declRef.as()) { - AddFuncOverloadCandidate(item, funcDeclRef, context); + AddFuncOverloadCandidate(item, funcDeclRef, context, baseCost); } else if (auto aggTypeDeclRef = item.declRef.as()) { @@ -1584,7 +1595,7 @@ namespace Slang const auto type = localDeclRef.getDecl()->getType(); // We can only add overload candidates if this is known to be a function if(const auto funType = as(type)) - AddFuncExprOverloadCandidate(funType, context, context.originalExpr->functionExpr); + AddFuncExprOverloadCandidate(funType, context, context.originalExpr->functionExpr, baseCost); else return; } @@ -1603,12 +1614,12 @@ namespace Slang { for(auto item : result.items) { - AddDeclRefOverloadCandidates(item, context); + AddDeclRefOverloadCandidates(item, context, kConversionCost_None); } } else { - AddDeclRefOverloadCandidates(result.item, context); + AddDeclRefOverloadCandidates(result.item, context, kConversionCost_None); } } @@ -1633,17 +1644,17 @@ namespace Slang // The expression directly referenced a declaration, // so we can use that declaration directly to look // for anything applicable. - AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context); + AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context, kConversionCost_None); } else if (auto higherOrderExpr = as(funcExpr)) { // The expression is the result of a higher order function application. - AddHigherOrderOverloadCandidates(higherOrderExpr, context); + AddHigherOrderOverloadCandidates(higherOrderExpr, context, kConversionCost_None); } else if (auto funcType = as(funcExprType)) { // TODO(tfoley): deprecate this path... - AddFuncOverloadCandidate(funcType, context); + AddFuncOverloadCandidate(funcType, context, kConversionCost_None); } else if (auto overloadedExpr = as(funcExpr)) { @@ -1683,7 +1694,8 @@ namespace Slang void SemanticsVisitor::AddHigherOrderOverloadCandidates( Expr* funcExpr, - OverloadResolveContext& context) + OverloadResolveContext& context, + ConversionCost baseCost) { // Lookup the higher order function and process types accordingly. In the future, // if there are enough varieties, we can have dispatch logic instead of an @@ -1705,7 +1717,7 @@ namespace Slang candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(baseFuncDeclRef); candidate.exprVal = expr; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } else if (auto baseFuncGenericDeclRef = funcDeclRefExpr->declRef.as()) { @@ -1721,10 +1733,12 @@ namespace Slang // Try to infer generic arguments, based on the updated context. OverloadResolveContext subContext = context; + ConversionCost baseCost1 = kConversionCost_None; DeclRef innerRef = inferGenericArguments( baseFuncGenericDeclRef, context, ArrayView(), + baseCost1, ¶mTypes); if (!innerRef) @@ -1762,7 +1776,7 @@ namespace Slang } candidate.exprVal = expr; expr->type.type = diffFuncType; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost + baseCost1); } else { @@ -1868,7 +1882,6 @@ namespace Slang context.originalExpr = expr; context.funcLoc = funcExpr->loc; - context.argCount = expr->arguments.getCount(); context.args = expr->arguments.getBuffer(); context.loc = expr->loc; @@ -2039,7 +2052,7 @@ namespace Slang candidate.item = baseItem; candidate.resultType = nullptr; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, kConversionCost_None); } } -- cgit v1.2.3