diff options
| author | Edward Liu <shiqiu1105@gmail.com> | 2022-11-14 12:08:01 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-14 12:08:01 -0800 |
| commit | 368ec3116ea0f10f44acbf76b5dc9e34d6ff3d32 (patch) | |
| tree | 3d9def111db278affb8413bddb5aab9ce3cf73a6 /source/slang/slang-check-overload.cpp | |
| parent | 623f5c36e0dc8190753aa5fa2e89f1010c367c67 (diff) | |
Minimum binary arithmetic reverse autodiff working. (#2514)
* Initial plumbing of backward autodiff in the frontend.
* More plumbing.
* Initial reverse autodiff working.
* Bug fixes.
* Misc.
* Remove redundant code.
* More clean up.
* Misc.
* Rebase and add backward diff test.
* Disable test.
* Clean up.
* Minor fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-overload.cpp')
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 55 |
1 files changed, 41 insertions, 14 deletions
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 38754d170..fe9de9433 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1548,35 +1548,51 @@ namespace Slang // Lookup the higher order function and process types accordingly. In the future, // if there are enough varieties, we can have dispatch logic instead of an // if-else ladder. - if (auto jvpExpr = as<ForwardDifferentiateExpr>(funcExpr)) + if (auto expr = as<HigherOrderInvokeExpr>(funcExpr)) { - if (auto origFuncType = as<FuncType>(jvpExpr->baseFunction->type)) + if (auto origFuncType = as<FuncType>(expr->baseFunction->type)) { - // Case: __fwd_diff(name-resolved-to-decl-ref) - auto baseFuncDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<CallableDecl>(); + auto baseFuncDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<CallableDecl>(); SLANG_ASSERT(baseFuncDeclRef); OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; - candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType)); + if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr)) + { + // Case: __fwd_diff(name-resolved-to-decl-ref) + candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType)); + } + else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr)) + { + // Case: __bwd_diff(name-resolved-to-decl-ref) + candidate.funcType = as<FuncType>(processBackwardDiffFuncType(origFuncType)); + } candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(baseFuncDeclRef); AddOverloadCandidate(context, candidate); } - else if (auto origOverloadedType = as<OverloadGroupType>(jvpExpr->baseFunction->type)) + else if (auto origOverloadedType = as<OverloadGroupType>(expr->baseFunction->type)) { - // Case: __fwd_diff(name-resolved-to-multiple-decl-ref) - if (auto overloadExpr = as<OverloadedExpr>(jvpExpr->baseFunction)) + if (auto overloadExpr = as<OverloadedExpr>(expr->baseFunction)) { for (auto item : overloadExpr->lookupResult2.items) { auto funcType = as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc)); if (!funcType) continue; - funcType = as<FuncType>(processJVPFuncType(funcType)); + if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr)) + { + // Case: __fwd_diff(name-resolved-to-decl-ref) + funcType = as<FuncType>(processJVPFuncType(funcType)); + } + else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr)) + { + // Case: __bwd_diff(name-resolved-to-decl-ref) + funcType = as<FuncType>(processBackwardDiffFuncType(funcType)); + } if (!funcType) continue; OverloadCandidate candidate; @@ -1597,9 +1613,8 @@ namespace Slang funcExpr->type); } } - else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<GenericDecl>()) + else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<GenericDecl>()) { - // Case: __fwd_diff(name-resolved-to-generic-decl) // Get inner function DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>( @@ -1610,7 +1625,9 @@ namespace Slang auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>()); // Process func type to generate JVP func type. - auto jvpFuncType = as<FuncType>(processJVPFuncType(funcType)); + auto jvpFuncType = as<ForwardDifferentiateExpr>(expr) ? + as<FuncType>(processJVPFuncType(funcType)) : + as<FuncType>(processBackwardDiffFuncType(funcType)); // Extract parameter list from processed type. List<Type*> paramTypes; @@ -1634,8 +1651,18 @@ namespace Slang // in order to process the specialized version of the original func type. // This could potentially be a declRef.substitute(jvpFuncType) // - candidate.funcType = as<FuncType>(processJVPFuncType( - getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); + if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr)) + { + // Case: __fwd_diff(name-resolved-to-generic-decl) + candidate.funcType = as<FuncType>(processJVPFuncType( + getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); + } + else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr)) + { + // Case: __bwd_diff(name-resolved-to-generic-decl) + candidate.funcType = as<FuncType>(processBackwardDiffFuncType( + getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); + } candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(innerRef); |
