diff options
| author | Yong He <yonghe@outlook.com> | 2023-11-29 11:29:14 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-11-29 11:29:14 -0800 |
| commit | 4fb3b10b81cf8c976ebd1ebb7fcde7708f022957 (patch) | |
| tree | 394a08e5b744fa85ac98c0b8758e994b0aab3a34 /source/slang/slang-check-overload.cpp | |
| parent | 62426e94ef11fd6baa213757f87114ec174b406e (diff) | |
Improve generic type argument inference. (#3370)
* Improve generic type argument inference.
* Fix.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-overload.cpp')
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 63 |
1 files changed, 38 insertions, 25 deletions
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 2d7315cd2..d7d29a4e1 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1246,11 +1246,14 @@ namespace Slang void SemanticsVisitor::AddOverloadCandidate( OverloadResolveContext& context, - OverloadCandidate& candidate) + OverloadCandidate& candidate, + ConversionCost baseCost) { // Try the candidate out, to see if it is applicable at all. TryCheckOverloadCandidate(context, candidate); + candidate.conversionCostSum += baseCost; + // Now (potentially) add it to the set of candidate overloads to consider. AddOverloadCandidateInner(context, candidate); } @@ -1258,7 +1261,8 @@ namespace Slang void SemanticsVisitor::AddFuncOverloadCandidate( LookupResultItem item, DeclRef<CallableDecl> funcDeclRef, - OverloadResolveContext& context) + OverloadResolveContext& context, + ConversionCost baseCost) { auto funcDecl = funcDeclRef.getDecl(); ensureDecl(funcDecl, DeclCheckState::CanUseFuncSignature); @@ -1288,25 +1292,27 @@ namespace Slang candidate.item = item; candidate.resultType = getResultType(m_astBuilder, funcDeclRef); - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } void SemanticsVisitor::AddFuncOverloadCandidate( FuncType* funcType, - OverloadResolveContext& context) + OverloadResolveContext& context, + ConversionCost baseCost) { OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; candidate.funcType = funcType; candidate.resultType = funcType->getResultType(); - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } void SemanticsVisitor::AddFuncExprOverloadCandidate( FuncType* funcType, OverloadResolveContext& context, - Expr* expr) + Expr* expr, + ConversionCost baseCost) { SLANG_ASSERT(expr); OverloadCandidate candidate; @@ -1315,7 +1321,7 @@ namespace Slang candidate.resultType = funcType->getResultType(); candidate.exprVal = expr; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } void SemanticsVisitor::AddCtorOverloadCandidate( @@ -1323,7 +1329,8 @@ namespace Slang Type* type, DeclRef<ConstructorDecl> ctorDeclRef, OverloadResolveContext& context, - Type* resultType) + Type* resultType, + ConversionCost baseCost) { SLANG_UNUSED(type) @@ -1346,13 +1353,14 @@ namespace Slang candidate.item = ctorItem; candidate.resultType = resultType; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } DeclRef<Decl> SemanticsVisitor::inferGenericArguments( DeclRef<GenericDecl> genericDeclRef, OverloadResolveContext& context, ArrayView<Val*> knownGenericArgs, + ConversionCost& outBaseCost, List<QualType> *innerParameterTypes) { // We have been asked to infer zero or more arguments to @@ -1469,7 +1477,7 @@ namespace Slang // so that the solver knows to accept those arguments as-is. // return trySolveConstraintSystem( - &constraints, genericDeclRef, knownGenericArgs); + &constraints, genericDeclRef, knownGenericArgs, outBaseCost); } void SemanticsVisitor::AddTypeOverloadCandidates( @@ -1517,8 +1525,10 @@ namespace Slang auto genericDeclRef = genericItem.declRef.as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); + ConversionCost baseCost = kConversionCost_None; + // Try to infer generic arguments, based on the context - DeclRef<Decl> innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs); + DeclRef<Decl> innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs, baseCost); if (innerRef) { @@ -1528,7 +1538,7 @@ namespace Slang LookupResultItem innerItem; innerItem.breadcrumbs = genericItem.breadcrumbs; innerItem.declRef = innerRef; - AddDeclRefOverloadCandidates(innerItem, context); + AddDeclRefOverloadCandidates(innerItem, context, baseCost); } else { @@ -1546,11 +1556,12 @@ namespace Slang void SemanticsVisitor::AddDeclRefOverloadCandidates( LookupResultItem item, - OverloadResolveContext& context) + OverloadResolveContext& context, + ConversionCost baseCost) { if (auto funcDeclRef = item.declRef.as<CallableDecl>()) { - AddFuncOverloadCandidate(item, funcDeclRef, context); + AddFuncOverloadCandidate(item, funcDeclRef, context, baseCost); } else if (auto aggTypeDeclRef = item.declRef.as<AggTypeDecl>()) { @@ -1584,7 +1595,7 @@ namespace Slang const auto type = localDeclRef.getDecl()->getType(); // We can only add overload candidates if this is known to be a function if(const auto funType = as<FuncType>(type)) - AddFuncExprOverloadCandidate(funType, context, context.originalExpr->functionExpr); + AddFuncExprOverloadCandidate(funType, context, context.originalExpr->functionExpr, baseCost); else return; } @@ -1603,12 +1614,12 @@ namespace Slang { for(auto item : result.items) { - AddDeclRefOverloadCandidates(item, context); + AddDeclRefOverloadCandidates(item, context, kConversionCost_None); } } else { - AddDeclRefOverloadCandidates(result.item, context); + AddDeclRefOverloadCandidates(result.item, context, kConversionCost_None); } } @@ -1633,17 +1644,17 @@ namespace Slang // The expression directly referenced a declaration, // so we can use that declaration directly to look // for anything applicable. - AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context); + AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context, kConversionCost_None); } else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr)) { // The expression is the result of a higher order function application. - AddHigherOrderOverloadCandidates(higherOrderExpr, context); + AddHigherOrderOverloadCandidates(higherOrderExpr, context, kConversionCost_None); } else if (auto funcType = as<FuncType>(funcExprType)) { // TODO(tfoley): deprecate this path... - AddFuncOverloadCandidate(funcType, context); + AddFuncOverloadCandidate(funcType, context, kConversionCost_None); } else if (auto overloadedExpr = as<OverloadedExpr>(funcExpr)) { @@ -1683,7 +1694,8 @@ namespace Slang void SemanticsVisitor::AddHigherOrderOverloadCandidates( Expr* funcExpr, - OverloadResolveContext& context) + OverloadResolveContext& context, + ConversionCost baseCost) { // Lookup the higher order function and process types accordingly. In the future, // if there are enough varieties, we can have dispatch logic instead of an @@ -1705,7 +1717,7 @@ namespace Slang candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(baseFuncDeclRef); candidate.exprVal = expr; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } else if (auto baseFuncGenericDeclRef = funcDeclRefExpr->declRef.as<GenericDecl>()) { @@ -1721,10 +1733,12 @@ namespace Slang // Try to infer generic arguments, based on the updated context. OverloadResolveContext subContext = context; + ConversionCost baseCost1 = kConversionCost_None; DeclRef<Decl> innerRef = inferGenericArguments( baseFuncGenericDeclRef, context, ArrayView<Val*>(), + baseCost1, ¶mTypes); if (!innerRef) @@ -1762,7 +1776,7 @@ namespace Slang } candidate.exprVal = expr; expr->type.type = diffFuncType; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost + baseCost1); } else { @@ -1868,7 +1882,6 @@ namespace Slang context.originalExpr = expr; context.funcLoc = funcExpr->loc; - context.argCount = expr->arguments.getCount(); context.args = expr->arguments.getBuffer(); context.loc = expr->loc; @@ -2039,7 +2052,7 @@ namespace Slang candidate.item = baseItem; candidate.resultType = nullptr; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, kConversionCost_None); } } |
