summaryrefslogtreecommitdiff
path: root/source/slang/slang-lower-to-ir.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-13 11:48:54 -0800
committerGitHub <noreply@github.com>2023-01-13 11:48:54 -0800
commit4adc64f2a033ec141df6a16e65131612b30fb23b (patch)
tree31e4fabbfcac5e59ee334acb2be0f1df2542d679 /source/slang/slang-lower-to-ir.cpp
parent63b874dab2df8950a37e0861d24f322e0ab9bfda (diff)
Frontend work for `[BackwardDerivative]` and `[BackwardDerivativeOf]`. (#2589)
* Frontend work for `[BackwardDerivative]` and `[BackwardDerivativeOf]`. * Fix clang issue. * Fix. * fix gcc issue * fix formatting. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
-rw-r--r--source/slang/slang-lower-to-ir.cpp27
1 files changed, 20 insertions, 7 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 4618b6786..9378a69e8 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8360,7 +8360,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute
setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc)));
- if (auto attr = decl->findModifier<ForwardDerivativeAttribute>())
+ if (auto attr = decl->findModifier<UserDefinedDerivativeAttribute>())
{
// We need to lower the decl ref to the custom derivative function to IR.
// The IR insts correspond to the decl ref is not part of the function we
@@ -8374,13 +8374,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr);
SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
- IRInst* jvpFunc = loweredVal.val;
- getBuilder()->addDecoration(irFunc, kIROp_ForwardDerivativeDecoration, jvpFunc);
+ IRInst* derivativeFunc = loweredVal.val;
+
+ if (as<ForwardDerivativeAttribute>(attr))
+ getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc);
+ else
+ getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc);
// Reset cursor.
subContext->irBuilder->setInsertInto(irFunc);
}
-
+
// For convenience, ensure that any additional global
// values that were emitted while outputting the function
// body appear before the function itself in the list
@@ -8391,7 +8395,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// the interface's type definition.
auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric);
- if (auto attr = decl->findModifier<ForwardDerivativeOfAttribute>())
+ if (auto attr = decl->findModifier<DerivativeOfAttribute>())
{
if (auto originalDeclRefExpr = as<DeclRefExpr>(attr->funcExpr))
{
@@ -8412,9 +8416,18 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
originalSubBuilder->setInsertBefore(originalFuncVal);
auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef);
- originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ if (as<ForwardDerivativeOfAttribute>(attr))
+ {
+ originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ getBuilder()->addForwardDifferentiableDecoration(irFunc);
+ }
+ else
+ {
+ originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ getBuilder()->addForwardDifferentiableDecoration(irFunc);
+ getBuilder()->addBackwardDifferentiableDecoration(irFunc);
+ }
}
- getBuilder()->addForwardDifferentiableDecoration(irFunc);
subContext->irBuilder->setInsertInto(irFunc);
finalVal->moveToEnd();
}