summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-overload.cpp
diff options
context:
space:
mode:
authorEdward Liu <shiqiu1105@gmail.com>2022-11-14 12:08:01 -0800
committerGitHub <noreply@github.com>2022-11-14 12:08:01 -0800
commit368ec3116ea0f10f44acbf76b5dc9e34d6ff3d32 (patch)
tree3d9def111db278affb8413bddb5aab9ce3cf73a6 /source/slang/slang-check-overload.cpp
parent623f5c36e0dc8190753aa5fa2e89f1010c367c67 (diff)
Minimum binary arithmetic reverse autodiff working. (#2514)
* Initial plumbing of backward autodiff in the frontend. * More plumbing. * Initial reverse autodiff working. * Bug fixes. * Misc. * Remove redundant code. * More clean up. * Misc. * Rebase and add backward diff test. * Disable test. * Clean up. * Minor fix. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-overload.cpp')
-rw-r--r--source/slang/slang-check-overload.cpp55
1 files changed, 41 insertions, 14 deletions
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 38754d170..fe9de9433 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1548,35 +1548,51 @@ namespace Slang
// Lookup the higher order function and process types accordingly. In the future,
// if there are enough varieties, we can have dispatch logic instead of an
// if-else ladder.
- if (auto jvpExpr = as<ForwardDifferentiateExpr>(funcExpr))
+ if (auto expr = as<HigherOrderInvokeExpr>(funcExpr))
{
- if (auto origFuncType = as<FuncType>(jvpExpr->baseFunction->type))
+ if (auto origFuncType = as<FuncType>(expr->baseFunction->type))
{
- // Case: __fwd_diff(name-resolved-to-decl-ref)
- auto baseFuncDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<CallableDecl>();
+ auto baseFuncDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<CallableDecl>();
SLANG_ASSERT(baseFuncDeclRef);
OverloadCandidate candidate;
candidate.flavor = OverloadCandidate::Flavor::Expr;
- candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType));
+ if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr))
+ {
+ // Case: __fwd_diff(name-resolved-to-decl-ref)
+ candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType));
+ }
+ else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr))
+ {
+ // Case: __bwd_diff(name-resolved-to-decl-ref)
+ candidate.funcType = as<FuncType>(processBackwardDiffFuncType(origFuncType));
+ }
candidate.resultType = candidate.funcType->getResultType();
candidate.item = LookupResultItem(baseFuncDeclRef);
AddOverloadCandidate(context, candidate);
}
- else if (auto origOverloadedType = as<OverloadGroupType>(jvpExpr->baseFunction->type))
+ else if (auto origOverloadedType = as<OverloadGroupType>(expr->baseFunction->type))
{
- // Case: __fwd_diff(name-resolved-to-multiple-decl-ref)
- if (auto overloadExpr = as<OverloadedExpr>(jvpExpr->baseFunction))
+ if (auto overloadExpr = as<OverloadedExpr>(expr->baseFunction))
{
for (auto item : overloadExpr->lookupResult2.items)
{
auto funcType = as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc));
if (!funcType)
continue;
- funcType = as<FuncType>(processJVPFuncType(funcType));
+ if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr))
+ {
+ // Case: __fwd_diff(name-resolved-to-decl-ref)
+ funcType = as<FuncType>(processJVPFuncType(funcType));
+ }
+ else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr))
+ {
+ // Case: __bwd_diff(name-resolved-to-decl-ref)
+ funcType = as<FuncType>(processBackwardDiffFuncType(funcType));
+ }
if (!funcType)
continue;
OverloadCandidate candidate;
@@ -1597,9 +1613,8 @@ namespace Slang
funcExpr->type);
}
}
- else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<GenericDecl>())
+ else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<GenericDecl>())
{
- // Case: __fwd_diff(name-resolved-to-generic-decl)
// Get inner function
DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>(
@@ -1610,7 +1625,9 @@ namespace Slang
auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>());
// Process func type to generate JVP func type.
- auto jvpFuncType = as<FuncType>(processJVPFuncType(funcType));
+ auto jvpFuncType = as<ForwardDifferentiateExpr>(expr) ?
+ as<FuncType>(processJVPFuncType(funcType)) :
+ as<FuncType>(processBackwardDiffFuncType(funcType));
// Extract parameter list from processed type.
List<Type*> paramTypes;
@@ -1634,8 +1651,18 @@ namespace Slang
// in order to process the specialized version of the original func type.
// This could potentially be a declRef.substitute(jvpFuncType)
//
- candidate.funcType = as<FuncType>(processJVPFuncType(
- getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>())));
+ if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr))
+ {
+ // Case: __fwd_diff(name-resolved-to-generic-decl)
+ candidate.funcType = as<FuncType>(processJVPFuncType(
+ getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>())));
+ }
+ else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr))
+ {
+ // Case: __bwd_diff(name-resolved-to-generic-decl)
+ candidate.funcType = as<FuncType>(processBackwardDiffFuncType(
+ getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>())));
+ }
candidate.resultType = candidate.funcType->getResultType();
candidate.item = LookupResultItem(innerRef);