summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-overload.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-16 12:17:49 -0800
committerGitHub <noreply@github.com>2022-11-16 12:17:49 -0800
commit801aa3b44254341018a1acbe754f2ce3b0900e2a (patch)
treeb3066778522edb99bf64c0ac80c91b0b4cb788f8 /source/slang/slang-check-overload.cpp
parent09d8e048d2264d89886cda8e87e8a452d4f913c1 (diff)
Clean up type checking of higher order expressions. (#2519)
* Clean up type checking of higher order expressions. * Replace `goto` with `break` to pacify clang. * Fix. * Fixes. * Fix more tests. * Fix lowerWitnessTable parameter error. * Exclude attributes from ast printing. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-overload.cpp')
-rw-r--r--source/slang/slang-check-overload.cpp187
1 files changed, 62 insertions, 125 deletions
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index fe9de9433..83774303b 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -714,22 +714,7 @@ namespace Slang
callExpr->originalFunctionExpr = callExpr->functionExpr;
callExpr->type = QualType(candidate.resultType);
-
- // If the callee is the result of a higher-order function invocation,
- // set it's base function to the declaration corresponding to the
- // resolved overload.
- //
- if (auto higherOrderInvoke = as<HigherOrderInvokeExpr>(callExpr->functionExpr))
- {
- higherOrderInvoke->baseFunction = ConstructLookupResultExpr(
- candidate.item,
- baseExpr,
- higherOrderInvoke->loc,
- callExpr->functionExpr);
-
- higherOrderInvoke->type = candidate.funcType;
- }
-
+ callExpr->functionExpr = candidate.exprVal;
return callExpr;
}
@@ -1252,10 +1237,19 @@ namespace Slang
// to match it up with the arguments accordingly...
if (auto funcDeclRef = partiallySpecializedInnerRef.as<CallableDecl>())
{
- auto params = getParameters(funcDeclRef).toArray();
+ List<Type*> paramTypes;
+ if (!innerParameterTypes)
+ {
+ auto params = getParameters(funcDeclRef).toArray();
+ for (auto param : params)
+ {
+ paramTypes.add(getType(m_astBuilder, param));
+ }
+ innerParameterTypes = &paramTypes;
+ }
Index valueArgCount = context.getArgCount();
- Index valueParamCount = params.getCount();
+ Index valueParamCount = innerParameterTypes->getCount();
// If there are too many arguments, we cannot possibly have a match.
//
@@ -1295,7 +1289,7 @@ namespace Slang
TryUnifyTypes(
constraints,
context.getArgTypeForInference(aa, this),
- (!innerParameterTypes) ? getType(m_astBuilder, params[aa]) : (*innerParameterTypes)[aa]);
+ (*innerParameterTypes)[aa]);
}
}
else
@@ -1495,6 +1489,11 @@ namespace Slang
// for anything applicable.
AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context);
}
+ else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr))
+ {
+ // The expression is the result of a higher order function application.
+ AddHigherOrderOverloadCandidates(higherOrderExpr, context);
+ }
else if (auto funcType = as<FuncType>(funcExprType))
{
// TODO(tfoley): deprecate this path...
@@ -1511,11 +1510,6 @@ namespace Slang
AddOverloadCandidates(item, context);
}
}
- else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr))
- {
- // The expression is the result of a higher order function application.
- AddHigherOrderOverloadCandidates(higherOrderExpr, context);
- }
else if (auto partiallyAppliedGenericExpr = as<PartiallyAppliedGenericExpr>(funcExpr))
{
// A partially-applied generic is allowed as an overload candidate,
@@ -1550,90 +1544,43 @@ namespace Slang
// if-else ladder.
if (auto expr = as<HigherOrderInvokeExpr>(funcExpr))
{
- if (auto origFuncType = as<FuncType>(expr->baseFunction->type))
+ auto funcDeclRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(expr->baseFunction));
+ if (!funcDeclRefExpr)
+ return;
+ if (auto baseFuncDeclRef = funcDeclRefExpr->declRef.as<CallableDecl>())
{
-
- auto baseFuncDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<CallableDecl>();
- SLANG_ASSERT(baseFuncDeclRef);
-
+ // Base is a normal or fully specialized generic function.
OverloadCandidate candidate;
candidate.flavor = OverloadCandidate::Flavor::Expr;
- if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr))
- {
- // Case: __fwd_diff(name-resolved-to-decl-ref)
- candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType));
- }
- else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr))
+ if (auto diffExpr = as<DifferentiateExpr>(expr))
{
- // Case: __bwd_diff(name-resolved-to-decl-ref)
- candidate.funcType = as<FuncType>(processBackwardDiffFuncType(origFuncType));
+ candidate.funcType = as<FuncType>(diffExpr->type.type);
}
candidate.resultType = candidate.funcType->getResultType();
candidate.item = LookupResultItem(baseFuncDeclRef);
-
+ candidate.exprVal = expr;
AddOverloadCandidate(context, candidate);
}
- else if (auto origOverloadedType = as<OverloadGroupType>(expr->baseFunction->type))
- {
-
- if (auto overloadExpr = as<OverloadedExpr>(expr->baseFunction))
- {
- for (auto item : overloadExpr->lookupResult2.items)
- {
- auto funcType = as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc));
- if (!funcType)
- continue;
- if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr))
- {
- // Case: __fwd_diff(name-resolved-to-decl-ref)
- funcType = as<FuncType>(processJVPFuncType(funcType));
- }
- else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr))
- {
- // Case: __bwd_diff(name-resolved-to-decl-ref)
- funcType = as<FuncType>(processBackwardDiffFuncType(funcType));
- }
- if (!funcType)
- continue;
- OverloadCandidate candidate;
- candidate.flavor = OverloadCandidate::Flavor::Expr;
- candidate.funcType = funcType;
- candidate.resultType = candidate.funcType->getResultType();
- candidate.item = LookupResultItem(item.declRef);
-
- AddOverloadCandidate(context, candidate);
- }
- }
- else
- {
- // Unhandled overload expr.
- funcExpr->type = this->getASTBuilder()->getErrorType();
- getSink()->diagnose(funcExpr->loc,
- Diagnostics::unimplemented,
- funcExpr->type);
- }
- }
- else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<GenericDecl>())
+ else if (auto baseFuncGenericDeclRef = funcDeclRefExpr->declRef.as<GenericDecl>())
{
-
// Get inner function
DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>(
getInner(baseFuncGenericDeclRef),
baseFuncGenericDeclRef.substitutions);
-
- // Pull parameter list of inner function.
- auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>());
// Process func type to generate JVP func type.
- auto jvpFuncType = as<ForwardDifferentiateExpr>(expr) ?
- as<FuncType>(processJVPFuncType(funcType)) :
- as<FuncType>(processBackwardDiffFuncType(funcType));
+ auto diffFuncType = as<FuncType>(expr->type.type);
+ if (!diffFuncType)
+ {
+ // This shouldn't happen, but we check to be safe.
+ return;
+ }
// Extract parameter list from processed type.
List<Type*> paramTypes;
- for(UIndex ii = 0; ii < jvpFuncType->getParamCount(); ii++)
- paramTypes.add(jvpFuncType->getParamType(ii));
+ for (UIndex ii = 0; ii < diffFuncType->getParamCount(); ii++)
+ paramTypes.add(diffFuncType->getParamType(ii));
// Try to infer generic arguments, based on the updated context.
DeclRef<Decl> innerRef = inferGenericArguments(
@@ -1641,39 +1588,39 @@ namespace Slang
context,
nullptr,
&paramTypes);
-
+
+ OverloadCandidate candidate;
+ candidate.flavor = OverloadCandidate::Flavor::Expr;
if (innerRef)
{
- OverloadCandidate candidate;
- candidate.flavor = OverloadCandidate::Flavor::Expr;
-
- // Note that we call processJVPFuncType() again here
- // in order to process the specialized version of the original func type.
- // This could potentially be a declRef.substitute(jvpFuncType)
- //
- if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr))
- {
- // Case: __fwd_diff(name-resolved-to-generic-decl)
- candidate.funcType = as<FuncType>(processJVPFuncType(
- getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>())));
- }
- else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr))
- {
- // Case: __bwd_diff(name-resolved-to-generic-decl)
- candidate.funcType = as<FuncType>(processBackwardDiffFuncType(
- getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>())));
- }
-
- candidate.resultType = candidate.funcType->getResultType();
+ diffFuncType = as<FuncType>(innerRef.substitute(m_astBuilder, diffFuncType));
candidate.item = LookupResultItem(innerRef);
-
- AddOverloadCandidate(context, candidate);
}
else
{
- SLANG_UNEXPECTED("Could not resolve generic candidate");
+ candidate.item = LookupResultItem(funcDeclRefExpr->declRef);
}
+ candidate.funcType = as<FuncType>(diffFuncType);
+ candidate.resultType = candidate.funcType->getResultType();
+ // Substitute all types in the high-order expression chain.
+ Expr* inner = expr;
+ HigherOrderInvokeExpr* lastInner = nullptr;
+ while (auto hoInner = as<HigherOrderInvokeExpr>(inner))
+ {
+ lastInner = hoInner;
+ hoInner->type = innerRef.substitute(m_astBuilder, hoInner->type.type);
+ inner = hoInner->baseFunction;
+ }
+ // Set inner expression to resolved declref expr.
+ if (lastInner)
+ {
+ auto baseExpr = GetBaseExpr(funcDeclRefExpr);
+ lastInner->baseFunction = ConstructLookupResultExpr(candidate.item, baseExpr, funcDeclRefExpr->loc, funcDeclRefExpr);
+ }
+ candidate.exprVal = expr;
+ expr->type.type = diffFuncType;
+ AddOverloadCandidate(context, candidate);
}
else
{
@@ -1683,6 +1630,7 @@ namespace Slang
Diagnostics::expectedFunction,
funcExpr->type);
}
+
}
}
@@ -1769,18 +1717,7 @@ namespace Slang
context.args = expr->arguments.getBuffer();
context.loc = expr->loc;
- if (auto funcMemberExpr = as<MemberExpr>(funcExpr))
- {
- context.baseExpr = funcMemberExpr->baseExpression;
- }
- else if (auto funcOverloadExpr = as<OverloadedExpr>(funcExpr))
- {
- context.baseExpr = funcOverloadExpr->base;
- }
- else if (auto funcOverloadExpr2 = as<OverloadedExpr2>(funcExpr))
- {
- context.baseExpr = funcOverloadExpr2->base;
- }
+ context.baseExpr = GetBaseExpr(funcExpr);
// TODO: We should have a special case here where an `InvokeExpr`
// with a single argument where the base/func expression names