From 368ec3116ea0f10f44acbf76b5dc9e34d6ff3d32 Mon Sep 17 00:00:00 2001 From: Edward Liu Date: Mon, 14 Nov 2022 12:08:01 -0800 Subject: Minimum binary arithmetic reverse autodiff working. (#2514) * Initial plumbing of backward autodiff in the frontend. * More plumbing. * Initial reverse autodiff working. * Bug fixes. * Misc. * Remove redundant code. * More clean up. * Misc. * Rebase and add backward diff test. * Disable test. * Clean up. * Minor fix. Co-authored-by: Yong He --- source/slang/slang-check-overload.cpp | 55 ++++++++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 14 deletions(-) (limited to 'source/slang/slang-check-overload.cpp') diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 38754d170..fe9de9433 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1548,35 +1548,51 @@ namespace Slang // Lookup the higher order function and process types accordingly. In the future, // if there are enough varieties, we can have dispatch logic instead of an // if-else ladder. - if (auto jvpExpr = as(funcExpr)) + if (auto expr = as(funcExpr)) { - if (auto origFuncType = as(jvpExpr->baseFunction->type)) + if (auto origFuncType = as(expr->baseFunction->type)) { - // Case: __fwd_diff(name-resolved-to-decl-ref) - auto baseFuncDeclRef = as(jvpExpr->baseFunction)->declRef.as(); + auto baseFuncDeclRef = as(expr->baseFunction)->declRef.as(); SLANG_ASSERT(baseFuncDeclRef); OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; - candidate.funcType = as(processJVPFuncType(origFuncType)); + if (auto fwdExpr = as(expr)) + { + // Case: __fwd_diff(name-resolved-to-decl-ref) + candidate.funcType = as(processJVPFuncType(origFuncType)); + } + else if (auto bwdExpr = as(expr)) + { + // Case: __bwd_diff(name-resolved-to-decl-ref) + candidate.funcType = as(processBackwardDiffFuncType(origFuncType)); + } candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(baseFuncDeclRef); AddOverloadCandidate(context, candidate); } - else if (auto origOverloadedType = as(jvpExpr->baseFunction->type)) + else if (auto origOverloadedType = as(expr->baseFunction->type)) { - // Case: __fwd_diff(name-resolved-to-multiple-decl-ref) - if (auto overloadExpr = as(jvpExpr->baseFunction)) + if (auto overloadExpr = as(expr->baseFunction)) { for (auto item : overloadExpr->lookupResult2.items) { auto funcType = as(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc)); if (!funcType) continue; - funcType = as(processJVPFuncType(funcType)); + if (auto fwdExpr = as(expr)) + { + // Case: __fwd_diff(name-resolved-to-decl-ref) + funcType = as(processJVPFuncType(funcType)); + } + else if (auto bwdExpr = as(expr)) + { + // Case: __bwd_diff(name-resolved-to-decl-ref) + funcType = as(processBackwardDiffFuncType(funcType)); + } if (!funcType) continue; OverloadCandidate candidate; @@ -1597,9 +1613,8 @@ namespace Slang funcExpr->type); } } - else if (auto baseFuncGenericDeclRef = as(jvpExpr->baseFunction)->declRef.as()) + else if (auto baseFuncGenericDeclRef = as(expr->baseFunction)->declRef.as()) { - // Case: __fwd_diff(name-resolved-to-generic-decl) // Get inner function DeclRef unspecializedInnerRef = DeclRef( @@ -1610,7 +1625,9 @@ namespace Slang auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as()); // Process func type to generate JVP func type. - auto jvpFuncType = as(processJVPFuncType(funcType)); + auto jvpFuncType = as(expr) ? + as(processJVPFuncType(funcType)) : + as(processBackwardDiffFuncType(funcType)); // Extract parameter list from processed type. List paramTypes; @@ -1634,8 +1651,18 @@ namespace Slang // in order to process the specialized version of the original func type. // This could potentially be a declRef.substitute(jvpFuncType) // - candidate.funcType = as(processJVPFuncType( - getFuncType(this->getASTBuilder(), innerRef.as()))); + if (auto fwdExpr = as(expr)) + { + // Case: __fwd_diff(name-resolved-to-generic-decl) + candidate.funcType = as(processJVPFuncType( + getFuncType(this->getASTBuilder(), innerRef.as()))); + } + else if (auto bwdExpr = as(expr)) + { + // Case: __bwd_diff(name-resolved-to-generic-decl) + candidate.funcType = as(processBackwardDiffFuncType( + getFuncType(this->getASTBuilder(), innerRef.as()))); + } candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(innerRef); -- cgit v1.2.3