diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-16 12:17:49 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-16 12:17:49 -0800 |
| commit | 801aa3b44254341018a1acbe754f2ce3b0900e2a (patch) | |
| tree | b3066778522edb99bf64c0ac80c91b0b4cb788f8 /source/slang/slang-check-overload.cpp | |
| parent | 09d8e048d2264d89886cda8e87e8a452d4f913c1 (diff) | |
Clean up type checking of higher order expressions. (#2519)
* Clean up type checking of higher order expressions.
* Replace `goto` with `break` to pacify clang.
* Fix.
* Fixes.
* Fix more tests.
* Fix lowerWitnessTable parameter error.
* Exclude attributes from ast printing.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-overload.cpp')
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 187 |
1 files changed, 62 insertions, 125 deletions
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index fe9de9433..83774303b 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -714,22 +714,7 @@ namespace Slang callExpr->originalFunctionExpr = callExpr->functionExpr; callExpr->type = QualType(candidate.resultType); - - // If the callee is the result of a higher-order function invocation, - // set it's base function to the declaration corresponding to the - // resolved overload. - // - if (auto higherOrderInvoke = as<HigherOrderInvokeExpr>(callExpr->functionExpr)) - { - higherOrderInvoke->baseFunction = ConstructLookupResultExpr( - candidate.item, - baseExpr, - higherOrderInvoke->loc, - callExpr->functionExpr); - - higherOrderInvoke->type = candidate.funcType; - } - + callExpr->functionExpr = candidate.exprVal; return callExpr; } @@ -1252,10 +1237,19 @@ namespace Slang // to match it up with the arguments accordingly... if (auto funcDeclRef = partiallySpecializedInnerRef.as<CallableDecl>()) { - auto params = getParameters(funcDeclRef).toArray(); + List<Type*> paramTypes; + if (!innerParameterTypes) + { + auto params = getParameters(funcDeclRef).toArray(); + for (auto param : params) + { + paramTypes.add(getType(m_astBuilder, param)); + } + innerParameterTypes = ¶mTypes; + } Index valueArgCount = context.getArgCount(); - Index valueParamCount = params.getCount(); + Index valueParamCount = innerParameterTypes->getCount(); // If there are too many arguments, we cannot possibly have a match. // @@ -1295,7 +1289,7 @@ namespace Slang TryUnifyTypes( constraints, context.getArgTypeForInference(aa, this), - (!innerParameterTypes) ? getType(m_astBuilder, params[aa]) : (*innerParameterTypes)[aa]); + (*innerParameterTypes)[aa]); } } else @@ -1495,6 +1489,11 @@ namespace Slang // for anything applicable. AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context); } + else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr)) + { + // The expression is the result of a higher order function application. + AddHigherOrderOverloadCandidates(higherOrderExpr, context); + } else if (auto funcType = as<FuncType>(funcExprType)) { // TODO(tfoley): deprecate this path... @@ -1511,11 +1510,6 @@ namespace Slang AddOverloadCandidates(item, context); } } - else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr)) - { - // The expression is the result of a higher order function application. - AddHigherOrderOverloadCandidates(higherOrderExpr, context); - } else if (auto partiallyAppliedGenericExpr = as<PartiallyAppliedGenericExpr>(funcExpr)) { // A partially-applied generic is allowed as an overload candidate, @@ -1550,90 +1544,43 @@ namespace Slang // if-else ladder. if (auto expr = as<HigherOrderInvokeExpr>(funcExpr)) { - if (auto origFuncType = as<FuncType>(expr->baseFunction->type)) + auto funcDeclRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(expr->baseFunction)); + if (!funcDeclRefExpr) + return; + if (auto baseFuncDeclRef = funcDeclRefExpr->declRef.as<CallableDecl>()) { - - auto baseFuncDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<CallableDecl>(); - SLANG_ASSERT(baseFuncDeclRef); - + // Base is a normal or fully specialized generic function. OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; - if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr)) - { - // Case: __fwd_diff(name-resolved-to-decl-ref) - candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType)); - } - else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr)) + if (auto diffExpr = as<DifferentiateExpr>(expr)) { - // Case: __bwd_diff(name-resolved-to-decl-ref) - candidate.funcType = as<FuncType>(processBackwardDiffFuncType(origFuncType)); + candidate.funcType = as<FuncType>(diffExpr->type.type); } candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(baseFuncDeclRef); - + candidate.exprVal = expr; AddOverloadCandidate(context, candidate); } - else if (auto origOverloadedType = as<OverloadGroupType>(expr->baseFunction->type)) - { - - if (auto overloadExpr = as<OverloadedExpr>(expr->baseFunction)) - { - for (auto item : overloadExpr->lookupResult2.items) - { - auto funcType = as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc)); - if (!funcType) - continue; - if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr)) - { - // Case: __fwd_diff(name-resolved-to-decl-ref) - funcType = as<FuncType>(processJVPFuncType(funcType)); - } - else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr)) - { - // Case: __bwd_diff(name-resolved-to-decl-ref) - funcType = as<FuncType>(processBackwardDiffFuncType(funcType)); - } - if (!funcType) - continue; - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Expr; - candidate.funcType = funcType; - candidate.resultType = candidate.funcType->getResultType(); - candidate.item = LookupResultItem(item.declRef); - - AddOverloadCandidate(context, candidate); - } - } - else - { - // Unhandled overload expr. - funcExpr->type = this->getASTBuilder()->getErrorType(); - getSink()->diagnose(funcExpr->loc, - Diagnostics::unimplemented, - funcExpr->type); - } - } - else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<GenericDecl>()) + else if (auto baseFuncGenericDeclRef = funcDeclRefExpr->declRef.as<GenericDecl>()) { - // Get inner function DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>( getInner(baseFuncGenericDeclRef), baseFuncGenericDeclRef.substitutions); - - // Pull parameter list of inner function. - auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>()); // Process func type to generate JVP func type. - auto jvpFuncType = as<ForwardDifferentiateExpr>(expr) ? - as<FuncType>(processJVPFuncType(funcType)) : - as<FuncType>(processBackwardDiffFuncType(funcType)); + auto diffFuncType = as<FuncType>(expr->type.type); + if (!diffFuncType) + { + // This shouldn't happen, but we check to be safe. + return; + } // Extract parameter list from processed type. List<Type*> paramTypes; - for(UIndex ii = 0; ii < jvpFuncType->getParamCount(); ii++) - paramTypes.add(jvpFuncType->getParamType(ii)); + for (UIndex ii = 0; ii < diffFuncType->getParamCount(); ii++) + paramTypes.add(diffFuncType->getParamType(ii)); // Try to infer generic arguments, based on the updated context. DeclRef<Decl> innerRef = inferGenericArguments( @@ -1641,39 +1588,39 @@ namespace Slang context, nullptr, ¶mTypes); - + + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; if (innerRef) { - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Expr; - - // Note that we call processJVPFuncType() again here - // in order to process the specialized version of the original func type. - // This could potentially be a declRef.substitute(jvpFuncType) - // - if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr)) - { - // Case: __fwd_diff(name-resolved-to-generic-decl) - candidate.funcType = as<FuncType>(processJVPFuncType( - getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); - } - else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr)) - { - // Case: __bwd_diff(name-resolved-to-generic-decl) - candidate.funcType = as<FuncType>(processBackwardDiffFuncType( - getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); - } - - candidate.resultType = candidate.funcType->getResultType(); + diffFuncType = as<FuncType>(innerRef.substitute(m_astBuilder, diffFuncType)); candidate.item = LookupResultItem(innerRef); - - AddOverloadCandidate(context, candidate); } else { - SLANG_UNEXPECTED("Could not resolve generic candidate"); + candidate.item = LookupResultItem(funcDeclRefExpr->declRef); } + candidate.funcType = as<FuncType>(diffFuncType); + candidate.resultType = candidate.funcType->getResultType(); + // Substitute all types in the high-order expression chain. + Expr* inner = expr; + HigherOrderInvokeExpr* lastInner = nullptr; + while (auto hoInner = as<HigherOrderInvokeExpr>(inner)) + { + lastInner = hoInner; + hoInner->type = innerRef.substitute(m_astBuilder, hoInner->type.type); + inner = hoInner->baseFunction; + } + // Set inner expression to resolved declref expr. + if (lastInner) + { + auto baseExpr = GetBaseExpr(funcDeclRefExpr); + lastInner->baseFunction = ConstructLookupResultExpr(candidate.item, baseExpr, funcDeclRefExpr->loc, funcDeclRefExpr); + } + candidate.exprVal = expr; + expr->type.type = diffFuncType; + AddOverloadCandidate(context, candidate); } else { @@ -1683,6 +1630,7 @@ namespace Slang Diagnostics::expectedFunction, funcExpr->type); } + } } @@ -1769,18 +1717,7 @@ namespace Slang context.args = expr->arguments.getBuffer(); context.loc = expr->loc; - if (auto funcMemberExpr = as<MemberExpr>(funcExpr)) - { - context.baseExpr = funcMemberExpr->baseExpression; - } - else if (auto funcOverloadExpr = as<OverloadedExpr>(funcExpr)) - { - context.baseExpr = funcOverloadExpr->base; - } - else if (auto funcOverloadExpr2 = as<OverloadedExpr2>(funcExpr)) - { - context.baseExpr = funcOverloadExpr2->base; - } + context.baseExpr = GetBaseExpr(funcExpr); // TODO: We should have a special case here where an `InvokeExpr` // with a single argument where the base/func expression names |
