From 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 20 Oct 2022 14:22:00 -0400 Subject: Modified the new type system to support generic differentiable types … (#2413) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He --- source/slang/slang-check-overload.cpp | 142 +++++++++++++++++++++++++++++++++- 1 file changed, 139 insertions(+), 3 deletions(-) (limited to 'source/slang/slang-check-overload.cpp') diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 7dba3986a..eadf2f63d 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -715,6 +715,21 @@ namespace Slang callExpr->originalFunctionExpr = callExpr->functionExpr; callExpr->type = QualType(candidate.resultType); + // If the callee is the result of a higher-order function invocation, + // set it's base function to the declaration corresponding to the + // resolved overload. + // + if (auto higherOrderInvoke = as(callExpr->functionExpr)) + { + higherOrderInvoke->baseFunction = ConstructLookupResultExpr( + candidate.item, + baseExpr, + higherOrderInvoke->loc, + callExpr->functionExpr); + + higherOrderInvoke->type = candidate.funcType; + } + return callExpr; } @@ -1174,7 +1189,8 @@ namespace Slang DeclRef SemanticsVisitor::inferGenericArguments( DeclRef genericDeclRef, OverloadResolveContext& context, - GenericSubstitution* substWithKnownGenericArgs) + GenericSubstitution* substWithKnownGenericArgs, + List *innerParameterTypes) { // We have been asked to infer zero or more arguments to // `genericDeclRef`, in a context where it is being applied @@ -1279,7 +1295,7 @@ namespace Slang TryUnifyTypes( constraints, context.getArgTypeForInference(aa, this), - getType(m_astBuilder, params[aa])); + (!innerParameterTypes) ? getType(m_astBuilder, params[aa]) : (*innerParameterTypes)[aa]); } } else @@ -1495,6 +1511,11 @@ namespace Slang AddOverloadCandidates(item, context); } } + else if (auto higherOrderExpr = as(funcExpr)) + { + // The expression is the result of a higher order function application. + AddHigherOrderOverloadCandidates(higherOrderExpr, context); + } else if (auto partiallyAppliedGenericExpr = as(funcExpr)) { // A partially-applied generic is allowed as an overload candidate, @@ -1520,6 +1541,121 @@ namespace Slang } } + void SemanticsVisitor::AddHigherOrderOverloadCandidates( + Expr* funcExpr, + OverloadResolveContext& context) + { + // Lookup the higher order function and process types accordingly. In the future, + // if there are enough varieties, we can have dispatch logic instead of an + // if-else ladder. + if (auto jvpExpr = as(funcExpr)) + { + if (auto origFuncType = as(jvpExpr->baseFunction->type)) + { + // Case: __jvp(name-resolved-to-decl-ref) + + auto baseFuncDeclRef = as(jvpExpr->baseFunction)->declRef.as(); + SLANG_ASSERT(baseFuncDeclRef); + + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + candidate.funcType = as(processJVPFuncType(this->getASTBuilder(), origFuncType)); + candidate.resultType = candidate.funcType->getResultType(); + candidate.item = LookupResultItem(baseFuncDeclRef); + + AddOverloadCandidate(context, candidate); + } + else if (auto origOverloadedType = as(jvpExpr->baseFunction->type)) + { + // Case: __jvp(name-resolved-to-multiple-decl-ref) + + if (auto overloadExpr = as(jvpExpr->baseFunction)) + { + for (auto item : overloadExpr->lookupResult2.items) + { + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + candidate.funcType = as(processJVPFuncType( + this->getASTBuilder(), + as(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc)))); + candidate.resultType = candidate.funcType->getResultType(); + candidate.item = LookupResultItem(item.declRef); + + AddOverloadCandidate(context, candidate); + } + } + else + { + // Unhandled overload expr. + funcExpr->type = this->getASTBuilder()->getErrorType(); + getSink()->diagnose(funcExpr->loc, + Diagnostics::unimplemented, + funcExpr->type); + } + } + else if (auto baseFuncGenericDeclRef = as(jvpExpr->baseFunction)->declRef.as()) + { + // Case: __jvp(name-resolved-to-generic-decl) + + // Get inner function + DeclRef unspecializedInnerRef = DeclRef( + getInner(baseFuncGenericDeclRef), + baseFuncGenericDeclRef.substitutions); + + // Pull parameter list of inner function. + auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as()); + + // Process func type to generate JVP func type. + auto jvpFuncType = as(processJVPFuncType(this->getASTBuilder(), funcType)); + + // Extract parameter list from processed type. + List paramTypes; + + for(UIndex ii = 0; ii < jvpFuncType->getParamCount(); ii++) + paramTypes.add(jvpFuncType->getParamType(ii)); + + // Try to infer generic arguments, based on the updated context. + DeclRef innerRef = inferGenericArguments( + baseFuncGenericDeclRef, + context, + nullptr, + ¶mTypes); + + if (innerRef) + { + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + + // Note that we call processJVPFuncType() again here + // in order to process the specialized version of the original func type. + // This could potentially be a declRef.substitute(jvpFuncType) + // + candidate.funcType = as(processJVPFuncType( + this->getASTBuilder(), + getFuncType(this->getASTBuilder(), innerRef.as()))); + + candidate.resultType = candidate.funcType->getResultType(); + candidate.item = LookupResultItem(innerRef); + + AddOverloadCandidate(context, candidate); + } + else + { + SLANG_UNEXPECTED("Could not resolve generic candidate"); + } + + } + else + { + // Unhandled case for the inner expr. + funcExpr->type = this->getASTBuilder()->getErrorType(); + getSink()->diagnose(funcExpr->loc, + Diagnostics::expectedFunction, + funcExpr->type); + } + } + } + String SemanticsVisitor::getCallSignatureString( OverloadResolveContext& context) { @@ -1627,8 +1763,8 @@ namespace Slang // without needing dummy initializer/constructor declarations. // // Handling that special casing here (rather than in, say, - // `visitTypeCastExpr`) would allow us to continue to ensure // that `(T) expr` and `T(expr)` continue to be semantically + // `visitTypeCastExpr`) would allow us to continue to ensure // equivalent in (almost) all cases. if (!context.bestCandidate) -- cgit v1.2.3