summaryrefslogtreecommitdiff
path: root/source/slang/slang-lower-to-ir.cpp
diff options
context:
space:
mode:
authorEdward Liu <shiqiu1105@gmail.com>2022-11-14 12:08:01 -0800
committerGitHub <noreply@github.com>2022-11-14 12:08:01 -0800
commit368ec3116ea0f10f44acbf76b5dc9e34d6ff3d32 (patch)
tree3d9def111db278affb8413bddb5aab9ce3cf73a6 /source/slang/slang-lower-to-ir.cpp
parent623f5c36e0dc8190753aa5fa2e89f1010c367c67 (diff)
Minimum binary arithmetic reverse autodiff working. (#2514)
* Initial plumbing of backward autodiff in the frontend. * More plumbing. * Initial reverse autodiff working. * Bug fixes. * Misc. * Remove redundant code. * More clean up. * Misc. * Rebase and add backward diff test. * Disable test. * Clean up. * Minor fix. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
-rw-r--r--source/slang/slang-lower-to-ir.cpp19
1 files changed, 19 insertions, 0 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 5930875f1..a0158cf38 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3081,6 +3081,21 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
baseVal.val));
}
+ // Emit IR to denote the forward-mode derivative
+ // of the inner func-expr. This will be resolved
+ // to a concrete function during the derivative
+ // pass.
+ LoweredValInfo visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr)
+ {
+ auto baseVal = lowerSubExpr(expr->baseFunction);
+ SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ return LoweredValInfo::simple(
+ getBuilder()->emitBackwardDifferentiateInst(
+ lowerType(context, expr->type),
+ baseVal.val));
+ }
+
LoweredValInfo visitGetArrayLengthExpr(GetArrayLengthExpr* expr)
{
auto baseVal = lowerSubExpr(expr->arrayExpr);
@@ -7799,6 +7814,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
getBuilder()->addForwardDifferentiableDecoration(irFunc);
}
+ if (decl->findModifier<BackwardDifferentiableAttribute>())
+ {
+ getBuilder()->addBackwardDifferentiableDecoration(irFunc);
+ }
if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>())
{
lowerDifferentiableAttribute(subContext, irFunc, differentialAttr);