diff options
Diffstat (limited to 'source/slang/slang-check-overload.cpp')
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 142 |
1 files changed, 139 insertions, 3 deletions
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 7dba3986a..eadf2f63d 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -715,6 +715,21 @@ namespace Slang callExpr->originalFunctionExpr = callExpr->functionExpr; callExpr->type = QualType(candidate.resultType); + // If the callee is the result of a higher-order function invocation, + // set it's base function to the declaration corresponding to the + // resolved overload. + // + if (auto higherOrderInvoke = as<HigherOrderInvokeExpr>(callExpr->functionExpr)) + { + higherOrderInvoke->baseFunction = ConstructLookupResultExpr( + candidate.item, + baseExpr, + higherOrderInvoke->loc, + callExpr->functionExpr); + + higherOrderInvoke->type = candidate.funcType; + } + return callExpr; } @@ -1174,7 +1189,8 @@ namespace Slang DeclRef<Decl> SemanticsVisitor::inferGenericArguments( DeclRef<GenericDecl> genericDeclRef, OverloadResolveContext& context, - GenericSubstitution* substWithKnownGenericArgs) + GenericSubstitution* substWithKnownGenericArgs, + List<Type*> *innerParameterTypes) { // We have been asked to infer zero or more arguments to // `genericDeclRef`, in a context where it is being applied @@ -1279,7 +1295,7 @@ namespace Slang TryUnifyTypes( constraints, context.getArgTypeForInference(aa, this), - getType(m_astBuilder, params[aa])); + (!innerParameterTypes) ? getType(m_astBuilder, params[aa]) : (*innerParameterTypes)[aa]); } } else @@ -1495,6 +1511,11 @@ namespace Slang AddOverloadCandidates(item, context); } } + else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr)) + { + // The expression is the result of a higher order function application. + AddHigherOrderOverloadCandidates(higherOrderExpr, context); + } else if (auto partiallyAppliedGenericExpr = as<PartiallyAppliedGenericExpr>(funcExpr)) { // A partially-applied generic is allowed as an overload candidate, @@ -1520,6 +1541,121 @@ namespace Slang } } + void SemanticsVisitor::AddHigherOrderOverloadCandidates( + Expr* funcExpr, + OverloadResolveContext& context) + { + // Lookup the higher order function and process types accordingly. In the future, + // if there are enough varieties, we can have dispatch logic instead of an + // if-else ladder. + if (auto jvpExpr = as<JVPDifferentiateExpr>(funcExpr)) + { + if (auto origFuncType = as<FuncType>(jvpExpr->baseFunction->type)) + { + // Case: __jvp(name-resolved-to-decl-ref) + + auto baseFuncDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<CallableDecl>(); + SLANG_ASSERT(baseFuncDeclRef); + + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + candidate.funcType = as<FuncType>(processJVPFuncType(this->getASTBuilder(), origFuncType)); + candidate.resultType = candidate.funcType->getResultType(); + candidate.item = LookupResultItem(baseFuncDeclRef); + + AddOverloadCandidate(context, candidate); + } + else if (auto origOverloadedType = as<OverloadGroupType>(jvpExpr->baseFunction->type)) + { + // Case: __jvp(name-resolved-to-multiple-decl-ref) + + if (auto overloadExpr = as<OverloadedExpr>(jvpExpr->baseFunction)) + { + for (auto item : overloadExpr->lookupResult2.items) + { + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + candidate.funcType = as<FuncType>(processJVPFuncType( + this->getASTBuilder(), + as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc)))); + candidate.resultType = candidate.funcType->getResultType(); + candidate.item = LookupResultItem(item.declRef); + + AddOverloadCandidate(context, candidate); + } + } + else + { + // Unhandled overload expr. + funcExpr->type = this->getASTBuilder()->getErrorType(); + getSink()->diagnose(funcExpr->loc, + Diagnostics::unimplemented, + funcExpr->type); + } + } + else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<GenericDecl>()) + { + // Case: __jvp(name-resolved-to-generic-decl) + + // Get inner function + DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>( + getInner(baseFuncGenericDeclRef), + baseFuncGenericDeclRef.substitutions); + + // Pull parameter list of inner function. + auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>()); + + // Process func type to generate JVP func type. + auto jvpFuncType = as<FuncType>(processJVPFuncType(this->getASTBuilder(), funcType)); + + // Extract parameter list from processed type. + List<Type*> paramTypes; + + for(UIndex ii = 0; ii < jvpFuncType->getParamCount(); ii++) + paramTypes.add(jvpFuncType->getParamType(ii)); + + // Try to infer generic arguments, based on the updated context. + DeclRef<Decl> innerRef = inferGenericArguments( + baseFuncGenericDeclRef, + context, + nullptr, + ¶mTypes); + + if (innerRef) + { + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + + // Note that we call processJVPFuncType() again here + // in order to process the specialized version of the original func type. + // This could potentially be a declRef.substitute(jvpFuncType) + // + candidate.funcType = as<FuncType>(processJVPFuncType( + this->getASTBuilder(), + getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); + + candidate.resultType = candidate.funcType->getResultType(); + candidate.item = LookupResultItem(innerRef); + + AddOverloadCandidate(context, candidate); + } + else + { + SLANG_UNEXPECTED("Could not resolve generic candidate"); + } + + } + else + { + // Unhandled case for the inner expr. + funcExpr->type = this->getASTBuilder()->getErrorType(); + getSink()->diagnose(funcExpr->loc, + Diagnostics::expectedFunction, + funcExpr->type); + } + } + } + String SemanticsVisitor::getCallSignatureString( OverloadResolveContext& context) { @@ -1627,8 +1763,8 @@ namespace Slang // without needing dummy initializer/constructor declarations. // // Handling that special casing here (rather than in, say, - // `visitTypeCastExpr`) would allow us to continue to ensure // that `(T) expr` and `T(expr)` continue to be semantically + // `visitTypeCastExpr`) would allow us to continue to ensure // equivalent in (almost) all cases. if (!context.bestCandidate) |
