diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-13 11:48:54 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-13 11:48:54 -0800 |
| commit | 4adc64f2a033ec141df6a16e65131612b30fb23b (patch) | |
| tree | 31e4fabbfcac5e59ee334acb2be0f1df2542d679 /source/slang/slang-lower-to-ir.cpp | |
| parent | 63b874dab2df8950a37e0861d24f322e0ab9bfda (diff) | |
Frontend work for `[BackwardDerivative]` and `[BackwardDerivativeOf]`. (#2589)
* Frontend work for `[BackwardDerivative]` and `[BackwardDerivativeOf]`.
* Fix clang issue.
* Fix.
* fix gcc issue
* fix formatting.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 27 |
1 files changed, 20 insertions, 7 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 4618b6786..9378a69e8 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8360,7 +8360,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); - if (auto attr = decl->findModifier<ForwardDerivativeAttribute>()) + if (auto attr = decl->findModifier<UserDefinedDerivativeAttribute>()) { // We need to lower the decl ref to the custom derivative function to IR. // The IR insts correspond to the decl ref is not part of the function we @@ -8374,13 +8374,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr); SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); - IRInst* jvpFunc = loweredVal.val; - getBuilder()->addDecoration(irFunc, kIROp_ForwardDerivativeDecoration, jvpFunc); + IRInst* derivativeFunc = loweredVal.val; + + if (as<ForwardDerivativeAttribute>(attr)) + getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc); + else + getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc); // Reset cursor. subContext->irBuilder->setInsertInto(irFunc); } - + // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list @@ -8391,7 +8395,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // the interface's type definition. auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric); - if (auto attr = decl->findModifier<ForwardDerivativeOfAttribute>()) + if (auto attr = decl->findModifier<DerivativeOfAttribute>()) { if (auto originalDeclRefExpr = as<DeclRefExpr>(attr->funcExpr)) { @@ -8412,9 +8416,18 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } originalSubBuilder->setInsertBefore(originalFuncVal); auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef); - originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + if (as<ForwardDerivativeOfAttribute>(attr)) + { + originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + getBuilder()->addForwardDifferentiableDecoration(irFunc); + } + else + { + originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + getBuilder()->addForwardDifferentiableDecoration(irFunc); + getBuilder()->addBackwardDifferentiableDecoration(irFunc); + } } - getBuilder()->addForwardDifferentiableDecoration(irFunc); subContext->irBuilder->setInsertInto(irFunc); finalVal->moveToEnd(); } |
