diff options
| author | Edward Liu <shiqiu1105@gmail.com> | 2022-11-14 12:08:01 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-14 12:08:01 -0800 |
| commit | 368ec3116ea0f10f44acbf76b5dc9e34d6ff3d32 (patch) | |
| tree | 3d9def111db278affb8413bddb5aab9ce3cf73a6 /source/slang/slang-lower-to-ir.cpp | |
| parent | 623f5c36e0dc8190753aa5fa2e89f1010c367c67 (diff) | |
Minimum binary arithmetic reverse autodiff working. (#2514)
* Initial plumbing of backward autodiff in the frontend.
* More plumbing.
* Initial reverse autodiff working.
* Bug fixes.
* Misc.
* Remove redundant code.
* More clean up.
* Misc.
* Rebase and add backward diff test.
* Disable test.
* Clean up.
* Minor fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 5930875f1..a0158cf38 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3081,6 +3081,21 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> baseVal.val)); } + // Emit IR to denote the forward-mode derivative + // of the inner func-expr. This will be resolved + // to a concrete function during the derivative + // pass. + LoweredValInfo visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) + { + auto baseVal = lowerSubExpr(expr->baseFunction); + SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); + + return LoweredValInfo::simple( + getBuilder()->emitBackwardDifferentiateInst( + lowerType(context, expr->type), + baseVal.val)); + } + LoweredValInfo visitGetArrayLengthExpr(GetArrayLengthExpr* expr) { auto baseVal = lowerSubExpr(expr->arrayExpr); @@ -7799,6 +7814,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addForwardDifferentiableDecoration(irFunc); } + if (decl->findModifier<BackwardDifferentiableAttribute>()) + { + getBuilder()->addBackwardDifferentiableDecoration(irFunc); + } if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>()) { lowerDifferentiableAttribute(subContext, irFunc, differentialAttr); |
