summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-overload.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-20 14:22:00 -0400
committerGitHub <noreply@github.com>2022-10-20 11:22:00 -0700
commit1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch)
treee85158637680f783caaf7f4433a6844398cd8f7b /source/slang/slang-check-overload.cpp
parent576c8407e60143682cd40c68101c6eae8563ca3d (diff)
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-overload.cpp')
-rw-r--r--source/slang/slang-check-overload.cpp142
1 files changed, 139 insertions, 3 deletions
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 7dba3986a..eadf2f63d 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -715,6 +715,21 @@ namespace Slang
callExpr->originalFunctionExpr = callExpr->functionExpr;
callExpr->type = QualType(candidate.resultType);
+ // If the callee is the result of a higher-order function invocation,
+ // set it's base function to the declaration corresponding to the
+ // resolved overload.
+ //
+ if (auto higherOrderInvoke = as<HigherOrderInvokeExpr>(callExpr->functionExpr))
+ {
+ higherOrderInvoke->baseFunction = ConstructLookupResultExpr(
+ candidate.item,
+ baseExpr,
+ higherOrderInvoke->loc,
+ callExpr->functionExpr);
+
+ higherOrderInvoke->type = candidate.funcType;
+ }
+
return callExpr;
}
@@ -1174,7 +1189,8 @@ namespace Slang
DeclRef<Decl> SemanticsVisitor::inferGenericArguments(
DeclRef<GenericDecl> genericDeclRef,
OverloadResolveContext& context,
- GenericSubstitution* substWithKnownGenericArgs)
+ GenericSubstitution* substWithKnownGenericArgs,
+ List<Type*> *innerParameterTypes)
{
// We have been asked to infer zero or more arguments to
// `genericDeclRef`, in a context where it is being applied
@@ -1279,7 +1295,7 @@ namespace Slang
TryUnifyTypes(
constraints,
context.getArgTypeForInference(aa, this),
- getType(m_astBuilder, params[aa]));
+ (!innerParameterTypes) ? getType(m_astBuilder, params[aa]) : (*innerParameterTypes)[aa]);
}
}
else
@@ -1495,6 +1511,11 @@ namespace Slang
AddOverloadCandidates(item, context);
}
}
+ else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr))
+ {
+ // The expression is the result of a higher order function application.
+ AddHigherOrderOverloadCandidates(higherOrderExpr, context);
+ }
else if (auto partiallyAppliedGenericExpr = as<PartiallyAppliedGenericExpr>(funcExpr))
{
// A partially-applied generic is allowed as an overload candidate,
@@ -1520,6 +1541,121 @@ namespace Slang
}
}
+ void SemanticsVisitor::AddHigherOrderOverloadCandidates(
+ Expr* funcExpr,
+ OverloadResolveContext& context)
+ {
+ // Lookup the higher order function and process types accordingly. In the future,
+ // if there are enough varieties, we can have dispatch logic instead of an
+ // if-else ladder.
+ if (auto jvpExpr = as<JVPDifferentiateExpr>(funcExpr))
+ {
+ if (auto origFuncType = as<FuncType>(jvpExpr->baseFunction->type))
+ {
+ // Case: __jvp(name-resolved-to-decl-ref)
+
+ auto baseFuncDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<CallableDecl>();
+ SLANG_ASSERT(baseFuncDeclRef);
+
+ OverloadCandidate candidate;
+ candidate.flavor = OverloadCandidate::Flavor::Expr;
+ candidate.funcType = as<FuncType>(processJVPFuncType(this->getASTBuilder(), origFuncType));
+ candidate.resultType = candidate.funcType->getResultType();
+ candidate.item = LookupResultItem(baseFuncDeclRef);
+
+ AddOverloadCandidate(context, candidate);
+ }
+ else if (auto origOverloadedType = as<OverloadGroupType>(jvpExpr->baseFunction->type))
+ {
+ // Case: __jvp(name-resolved-to-multiple-decl-ref)
+
+ if (auto overloadExpr = as<OverloadedExpr>(jvpExpr->baseFunction))
+ {
+ for (auto item : overloadExpr->lookupResult2.items)
+ {
+ OverloadCandidate candidate;
+ candidate.flavor = OverloadCandidate::Flavor::Expr;
+ candidate.funcType = as<FuncType>(processJVPFuncType(
+ this->getASTBuilder(),
+ as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc))));
+ candidate.resultType = candidate.funcType->getResultType();
+ candidate.item = LookupResultItem(item.declRef);
+
+ AddOverloadCandidate(context, candidate);
+ }
+ }
+ else
+ {
+ // Unhandled overload expr.
+ funcExpr->type = this->getASTBuilder()->getErrorType();
+ getSink()->diagnose(funcExpr->loc,
+ Diagnostics::unimplemented,
+ funcExpr->type);
+ }
+ }
+ else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<GenericDecl>())
+ {
+ // Case: __jvp(name-resolved-to-generic-decl)
+
+ // Get inner function
+ DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>(
+ getInner(baseFuncGenericDeclRef),
+ baseFuncGenericDeclRef.substitutions);
+
+ // Pull parameter list of inner function.
+ auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>());
+
+ // Process func type to generate JVP func type.
+ auto jvpFuncType = as<FuncType>(processJVPFuncType(this->getASTBuilder(), funcType));
+
+ // Extract parameter list from processed type.
+ List<Type*> paramTypes;
+
+ for(UIndex ii = 0; ii < jvpFuncType->getParamCount(); ii++)
+ paramTypes.add(jvpFuncType->getParamType(ii));
+
+ // Try to infer generic arguments, based on the updated context.
+ DeclRef<Decl> innerRef = inferGenericArguments(
+ baseFuncGenericDeclRef,
+ context,
+ nullptr,
+ &paramTypes);
+
+ if (innerRef)
+ {
+ OverloadCandidate candidate;
+ candidate.flavor = OverloadCandidate::Flavor::Expr;
+
+ // Note that we call processJVPFuncType() again here
+ // in order to process the specialized version of the original func type.
+ // This could potentially be a declRef.substitute(jvpFuncType)
+ //
+ candidate.funcType = as<FuncType>(processJVPFuncType(
+ this->getASTBuilder(),
+ getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>())));
+
+ candidate.resultType = candidate.funcType->getResultType();
+ candidate.item = LookupResultItem(innerRef);
+
+ AddOverloadCandidate(context, candidate);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Could not resolve generic candidate");
+ }
+
+ }
+ else
+ {
+ // Unhandled case for the inner expr.
+ funcExpr->type = this->getASTBuilder()->getErrorType();
+ getSink()->diagnose(funcExpr->loc,
+ Diagnostics::expectedFunction,
+ funcExpr->type);
+ }
+ }
+ }
+
String SemanticsVisitor::getCallSignatureString(
OverloadResolveContext& context)
{
@@ -1627,8 +1763,8 @@ namespace Slang
// without needing dummy initializer/constructor declarations.
//
// Handling that special casing here (rather than in, say,
- // `visitTypeCastExpr`) would allow us to continue to ensure
// that `(T) expr` and `T(expr)` continue to be semantically
+ // `visitTypeCastExpr`) would allow us to continue to ensure
// equivalent in (almost) all cases.
if (!context.bestCandidate)