summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-13 11:48:54 -0800
committerGitHub <noreply@github.com>2023-01-13 11:48:54 -0800
commit4adc64f2a033ec141df6a16e65131612b30fb23b (patch)
tree31e4fabbfcac5e59ee334acb2be0f1df2542d679 /source
parent63b874dab2df8950a37e0861d24f322e0ab9bfda (diff)
Frontend work for `[BackwardDerivative]` and `[BackwardDerivativeOf]`. (#2589)
* Frontend work for `[BackwardDerivative]` and `[BackwardDerivativeOf]`. * Fix clang issue. * Fix. * fix gcc issue * fix formatting. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang6
-rw-r--r--source/slang/slang-ast-modifier.h31
-rw-r--r--source/slang/slang-check-decl.cpp287
-rw-r--r--source/slang/slang-check-modifier.cpp4
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h14
-rw-r--r--source/slang/slang-lower-to-ir.cpp27
8 files changed, 299 insertions, 75 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index f58648657..e19923c80 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -9,11 +9,17 @@ __attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
__attributeTarget(FunctionDeclBase)
+attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute;
+
+__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [BackwardDerivativeOf(function)] : BackwardDerivativeOfAttribute;
+
__attributeTarget(DeclBase)
attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index c85464061..666ca77ea 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1066,26 +1066,36 @@ class ForwardDifferentiableAttribute : public DifferentiableAttribute
SLANG_AST_CLASS(ForwardDifferentiableAttribute)
};
+class UserDefinedDerivativeAttribute : public DifferentiableAttribute
+{
+ SLANG_AST_CLASS(UserDefinedDerivativeAttribute)
+
+ Expr* funcExpr;
+};
+
/// The `[ForwardDerivative(function)]` attribute specifies a custom function that should
/// be used as the derivative for the decorated function.
-class ForwardDerivativeAttribute : public DifferentiableAttribute
+class ForwardDerivativeAttribute : public UserDefinedDerivativeAttribute
{
SLANG_AST_CLASS(ForwardDerivativeAttribute)
+};
+
+class DerivativeOfAttribute : public DifferentiableAttribute
+{
+ SLANG_AST_CLASS(DerivativeOfAttribute)
Expr* funcExpr;
+
+ Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction.
};
/// The `[ForwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom
/// derivative implementation for `primalFunction`.
/// ForwardDerivativeOfAttribute inherits from DifferentiableAttribute because a derivative
/// function itself is considered differentiable.
-class ForwardDerivativeOfAttribute : public DifferentiableAttribute
+class ForwardDerivativeOfAttribute : public DerivativeOfAttribute
{
SLANG_AST_CLASS(ForwardDerivativeOfAttribute)
-
- Expr* funcExpr;
-
- Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction.
};
/// The `[BackwardDifferentiable]` attribute indicates that a function can be backward-differentiated.
@@ -1096,21 +1106,16 @@ class BackwardDifferentiableAttribute : public DifferentiableAttribute
/// The `[BackwardDerivative(function)]` attribute specifies a custom function that should
/// be used as the backward-derivative for the decorated function.
-class BackwardDerivativeAttribute : public DifferentiableAttribute
+class BackwardDerivativeAttribute : public UserDefinedDerivativeAttribute
{
SLANG_AST_CLASS(BackwardDerivativeAttribute)
- Expr* funcExpr;
};
/// The `[BackwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom
/// backward-derivative implementation for `primalFunction`.
-class BackwardDerivativeOfAttribute : public DifferentiableAttribute
+class BackwardDerivativeOfAttribute : public DerivativeOfAttribute
{
SLANG_AST_CLASS(BackwardDerivativeOfAttribute)
-
- Expr* funcExpr;
-
- Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction.
};
/// The `[NoDiffThis]` attribute is used to specify that the `this` parameter should not be
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index b8732a67f..f016ae3d8 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -259,10 +259,9 @@ namespace Slang
void visitParamDecl(ParamDecl* paramDecl);
- void checkDerivativeOfAttribute(FunctionDeclBase* funcDecl);
-
- void checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr);
+ void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl);
+ void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl);
};
/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration?
@@ -4668,90 +4667,273 @@ namespace Slang
getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly);
}
- void SemanticsDeclBodyVisitor::checkDerivativeOfAttribute(FunctionDeclBase* funcDecl)
+ template<typename TDerivativeAttr>
+ void checkDerivativeAttributeImpl(
+ SemanticsVisitor* visitor,
+ TDerivativeAttr* attr,
+ const List<Expr*>& imaginaryArguments)
{
- auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>();
- if (!attr)
- return;
+ auto invokeExpr = visitor->constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments);
+ auto resolved = visitor->ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
+ {
+ attr->funcExpr = calleeDeclRef;
+ return;
+ }
+ }
+ visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ }
+
+ template<typename TDerivativeAttr>
+ const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); }
+
+ template<>
+ const char* getDerivativeAttrName<ForwardDerivativeAttribute>()
+ {
+ return "ForwardDerivative";
+ }
+ template<>
+ const char* getDerivativeAttrName<BackwardDerivativeAttribute>()
+ {
+ return "BackwardDerivative";
+ }
+ List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
+ {
List<Expr*> imaginaryArguments;
- for (auto param : funcDecl->getParameters())
+ for (auto param : originalFuncDecl->getParameters())
{
- auto arg = m_astBuilder->create<VarExpr>();
+ auto arg = visitor->getASTBuilder()->create<VarExpr>();
arg->declRef.decl = param;
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
- arg->loc = attr->loc;
+ arg->loc = loc;
+ if (auto pairType = visitor->getDifferentialPairType(param->getType()))
+ {
+ arg->type.type = pairType;
+ }
+ imaginaryArguments.add(arg);
+ }
+ return imaginaryArguments;
+ }
+
+ List<Expr*> getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* fwdDiffFunc, SourceLoc loc)
+ {
+ List<Expr*> imaginaryArguments;
+ for (auto param : fwdDiffFunc->getParameters())
+ {
+ auto arg = astBuilder->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = loc;
if (auto pairType = as<DifferentialPairType>(param->getType()))
{
arg->type.type = pairType->getPrimalType();
}
imaginaryArguments.add(arg);
}
- auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments);
- auto resolved = ResolveInvoke(invokeExpr);
- if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ return imaginaryArguments;
+ }
+
+ List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
+ {
+ List<Expr*> imaginaryArguments;
+ for (auto param : originalFuncDecl->getParameters())
{
- if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
+ auto arg = visitor->getASTBuilder()->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = loc;
+ if (auto pairType = visitor->getDifferentialPairType(param->getType()))
{
- if (auto existingModifier = calleeDeclRef->declRef.getDecl()->findModifier<ForwardDerivativeAttribute>())
+ arg->type.type = pairType;
+ if (auto diffPairType = as<DifferentialPairType>(pairType))
{
- // The primal function already has a `[ForwardDerivative]` attribute, this is invalid.
- getSink()->diagnose(attr, Diagnostics::declAlreadyHasAttribute, calleeDeclRef->declRef, "[ForwardDerivative]");
- getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl());
+ if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr)
+ {
+ arg->type.isLeftValue = false;
+ arg->type.type = diffPairType->getPrimalType();
+ }
}
- attr->funcExpr = calleeDeclRef;
- auto fwdDerivativeAttr = m_astBuilder->create<ForwardDerivativeAttribute>();
- fwdDerivativeAttr->loc = attr->loc;
- auto outterGeneric = GetOuterGeneric(funcDecl);
- auto declRef =
- DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr);
- auto declRefExpr = ConstructDeclRefExpr(declRef, nullptr, attr->loc, nullptr);
- declRefExpr->type.type = nullptr;
- fwdDerivativeAttr->args.add(declRefExpr);
- fwdDerivativeAttr->funcExpr = declRefExpr;
- checkDerivativeAttribute(as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), fwdDerivativeAttr);
- attr->backDeclRef = fwdDerivativeAttr->funcExpr;
- fwdDerivativeAttr->funcExpr = nullptr;
- getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), DeclAssociationKind::ForwardDerivativeFunc, funcDecl);
- return;
}
+ imaginaryArguments.add(arg);
+ }
+ if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type))
+ {
+ auto arg = visitor->getASTBuilder()->create<InitializerListExpr>();
+ arg->type.isLeftValue = false;
+ arg->type.type = diffReturnType;
+ arg->loc = loc;
+ imaginaryArguments.add(arg);
}
- getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ return imaginaryArguments;
}
- void SemanticsDeclBodyVisitor::checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr)
+ List<Expr*> getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* bwdDiffFunc, SourceLoc loc)
{
- if (!attr->funcExpr)
- return;
- if (attr->funcExpr->type.type)
- return;
-
List<Expr*> imaginaryArguments;
- for (auto param : funcDecl->getParameters())
+ for (auto param : bwdDiffFunc->getParameters())
{
- auto arg = m_astBuilder->create<VarExpr>();
+ auto arg = astBuilder->create<VarExpr>();
arg->declRef.decl = param;
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
- arg->loc = attr->loc;
- if (auto pairType = getDifferentialPairType(param->getType()))
+ arg->loc = loc;
+ if (auto pairType = as<DifferentialPairType>(param->getType()))
{
- arg->type.type = pairType;
+ if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr)
+ {
+ arg->type.isLeftValue = false;
+ }
+ arg->type.type = pairType->getPrimalType();
}
imaginaryArguments.add(arg);
}
- auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments);
- auto resolved = ResolveInvoke(invokeExpr);
+ return imaginaryArguments;
+ }
+
+ // This helper function is needed to workaround a gcc bug.
+ // Remove when we upgrade to a newer version of gcc.
+ template <typename T>
+ static T* _findModifier(Decl* decl)
+ {
+ return decl->findModifier<T>();
+ }
+
+ template <typename TDerivativeAttr, typename TDerivativeOfAttr>
+ void checkDerivativeOfAttributeImpl(
+ SemanticsVisitor* visitor,
+ FunctionDeclBase* funcDecl,
+ TDerivativeOfAttr* derivativeOfAttr,
+ DeclAssociationKind assocKind,
+ const List<Expr*>& imaginaryArgsToOriginal)
+ {
+ auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal);
+ auto resolved = visitor->ResolveInvoke(invokeExpr);
if (auto resolvedInvoke = as<InvokeExpr>(resolved))
{
if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
{
- attr->funcExpr = calleeDeclRef;
+ auto calleeDecl = calleeDeclRef->declRef.getDecl();
+ if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeDecl))
+ {
+ // The primal function already has a `[*Derivative]` attribute, this is invalid.
+ visitor->getSink()->diagnose(
+ derivativeOfAttr,
+ Diagnostics::declAlreadyHasAttribute,
+ calleeDeclRef->declRef,
+ getDerivativeAttrName<TDerivativeAttr>());
+ visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl());
+ }
+ derivativeOfAttr->funcExpr = calleeDeclRef;
+ auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>();
+ derivativeAttr->loc = derivativeOfAttr->loc;
+ auto outterGeneric = visitor->GetOuterGeneric(funcDecl);
+ auto declRef =
+ DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr);
+ auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr);
+ declRefExpr->type.type = nullptr;
+ derivativeAttr->args.add(declRefExpr);
+ derivativeAttr->funcExpr = declRefExpr;
+ checkDerivativeAttribute(visitor, as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), derivativeAttr);
+ derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr;
+ derivativeAttr->funcExpr = nullptr;
+ visitor->getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), assocKind, funcDecl);
return;
}
}
- getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::invalidCustomDerivative);
+ }
+
+ static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr)
+ {
+ if (!attr->funcExpr)
+ return;
+ if (attr->funcExpr->type.type)
+ return;
+
+ List<Expr*> imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ }
+
+ static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr)
+ {
+ if (!attr->funcExpr)
+ return;
+ if (attr->funcExpr->type.type)
+ return;
+
+ List<Expr*> imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ }
+
+ template<typename TDerivativeAttr, typename TDerivativeOfAttr>
+ bool tryCheckDerivativeOfAttributeImpl(
+ SemanticsVisitor* visitor,
+ FunctionDeclBase* funcDecl,
+ TDerivativeOfAttr* derivativeOfAttr,
+ DeclAssociationKind assocKind,
+ const List<Expr*>& imaginaryArgsToOriginal)
+ {
+ DiagnosticSink tempSink(visitor->getSourceManager(), nullptr);
+ SemanticsVisitor subVisitor(visitor->withSink(&tempSink));
+ checkDerivativeOfAttributeImpl<TDerivativeAttr>(
+ &subVisitor,
+ funcDecl,
+ derivativeOfAttr,
+ assocKind,
+ imaginaryArgsToOriginal);
+ return tempSink.getErrorCount() == 0;
+ }
+
+ void SemanticsDeclBodyVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl)
+ {
+ auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>();
+ if (!attr)
+ return;
+
+ List<Expr*> imaginaryArgsToOriginal = getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc);
+ checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute>(
+ this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc, imaginaryArgsToOriginal);
+ }
+
+ void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl)
+ {
+ auto attr = funcDecl->findModifier<BackwardDerivativeOfAttribute>();
+ if (!attr)
+ return;
+
+ List<Expr*> imaginaryArguments = getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc);
+
+ // The tricky part here is that we can't easily derive the arguments to original func just
+ // from the definition of a backward derivative function, because we don't know if the last
+ // parameter is just a normal parameter of the original func, or if it is the additional
+ // derivative of the return value. The solution here is to try to resolve the original
+ // function with or without the last argument. However if the type of the last argument
+ // isn't differentiable, we know that it can't possibly be the result derivative.
+
+ if (imaginaryArguments.getCount() == 0 ||
+ !tryGetDifferentialType(m_astBuilder, imaginaryArguments.getLast()->type.type))
+ {
+ checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(
+ this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments);
+ return;
+ }
+
+ // Otherwise, try resolve with all the arguments, if failed, resolve without the last
+ // argument.
+ if (tryCheckDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments))
+ {
+ return;
+ }
+
+ imaginaryArguments.removeLast();
+ checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(
+ this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments);
}
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
@@ -4759,9 +4941,12 @@ namespace Slang
auto newContext = withParentFunc(decl);
// Run checking on attributes that can't be fully checked in header checking stage.
- checkDerivativeOfAttribute(decl);
+ checkForwardDerivativeOfAttribute(decl);
if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>())
- checkDerivativeAttribute(decl, derivativeAttr);
+ checkDerivativeAttribute(this, decl, derivativeAttr);
+ checkBackwardDerivativeOfAttribute(decl);
+ if (auto derivativeAttr = decl->findModifier<BackwardDerivativeAttribute>())
+ checkDerivativeAttribute(this, decl, derivativeAttr);
if (newContext.getParentDifferentiableAttribute())
{
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 9742e69bb..f505b1321 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -635,7 +635,7 @@ namespace Slang
diffExpr->type.type = nullptr;
forwardDerivativeAttr->funcExpr = diffExpr;
}
- else if (auto forwardDerivativeOfAttr = as<ForwardDerivativeOfAttribute>(attr))
+ else if (auto derivativeOfAttr = as<DerivativeOfAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
SLANG_ASSERT(as<Decl>(attrTarget));
@@ -648,7 +648,7 @@ namespace Slang
getSink()->diagnose(primalFunc, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget));
return false;
}
- forwardDerivativeOfAttr->funcExpr = primalFunc;
+ derivativeOfAttr->funcExpr = primalFunc;
}
else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr))
{
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index db97f4865..8f9327c53 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -345,7 +345,7 @@ DIAGNOSTIC(31142, Error, ambiguousOriginalDefintionOfExternDecl, "`extern` decl
DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original definition found for `extern` decl '$0'.")
DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative attribute.")
-DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '$1'.")
+DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.")
// Enums
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 06f8b0e5d..ab7453b41 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -734,12 +734,13 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// generated derivative function.
INST(BackwardDifferentiableDecoration, backwardDifferentiable, 1, 0)
- /// Decorated function is marked for the reverse-mode differentiation pass.
+ /// Decorations to associate an original function with compiler generated backward derivative functions.
INST(BackwardDerivativePrimalDecoration, backwardDiffPrimalReference, 1, 0)
INST(BackwardDerivativePropagateDecoration, backwardDiffPropagateReference, 1, 0)
INST(BackwardDerivativeIntermediateTypeDecoration, backwardDiffIntermediateTypeReference, 1, 0)
INST(BackwardDerivativeDecoration, backwardDiffReference, 1, 0)
+ INST(UserDefinedBackwardDerivativeDecoration, userDefinedBackwardDiffReference, 1, 0)
INST(BackwardDerivativePrimalContextDecoration, BackwardDerivativePrimalContextDecoration, 1, 0)
INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 1ff61a774..b30d489dc 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -701,6 +701,15 @@ struct IRBackwardDifferentiableDecoration : IRDecoration
IR_LEAF_ISA(BackwardDifferentiableDecoration)
};
+struct IRUserDefinedBackwardDerivativeDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_UserDefinedBackwardDerivativeDecoration
+ };
+ IR_LEAF_ISA(UserDefinedBackwardDerivativeDecoration)
+};
+
struct IRTreatAsDifferentiableDecoration : IRDecoration
{
enum
@@ -3497,6 +3506,11 @@ public:
addDecoration(value, kIROp_ForwardDerivativeDecoration, fwdFunc);
}
+ void addUserDefinedBackwardDerivativeDecoration(IRInst* value, IRInst* fwdFunc)
+ {
+ addDecoration(value, kIROp_UserDefinedBackwardDerivativeDecoration, fwdFunc);
+ }
+
void addBackwardDerivativePrimalDecoration(IRInst* value, IRInst* jvpFn)
{
addDecoration(value, kIROp_BackwardDerivativePrimalDecoration, jvpFn);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 4618b6786..9378a69e8 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8360,7 +8360,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute
setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc)));
- if (auto attr = decl->findModifier<ForwardDerivativeAttribute>())
+ if (auto attr = decl->findModifier<UserDefinedDerivativeAttribute>())
{
// We need to lower the decl ref to the custom derivative function to IR.
// The IR insts correspond to the decl ref is not part of the function we
@@ -8374,13 +8374,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr);
SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
- IRInst* jvpFunc = loweredVal.val;
- getBuilder()->addDecoration(irFunc, kIROp_ForwardDerivativeDecoration, jvpFunc);
+ IRInst* derivativeFunc = loweredVal.val;
+
+ if (as<ForwardDerivativeAttribute>(attr))
+ getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc);
+ else
+ getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc);
// Reset cursor.
subContext->irBuilder->setInsertInto(irFunc);
}
-
+
// For convenience, ensure that any additional global
// values that were emitted while outputting the function
// body appear before the function itself in the list
@@ -8391,7 +8395,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// the interface's type definition.
auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric);
- if (auto attr = decl->findModifier<ForwardDerivativeOfAttribute>())
+ if (auto attr = decl->findModifier<DerivativeOfAttribute>())
{
if (auto originalDeclRefExpr = as<DeclRefExpr>(attr->funcExpr))
{
@@ -8412,9 +8416,18 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
originalSubBuilder->setInsertBefore(originalFuncVal);
auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef);
- originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ if (as<ForwardDerivativeOfAttribute>(attr))
+ {
+ originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ getBuilder()->addForwardDifferentiableDecoration(irFunc);
+ }
+ else
+ {
+ originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ getBuilder()->addForwardDifferentiableDecoration(irFunc);
+ getBuilder()->addBackwardDifferentiableDecoration(irFunc);
+ }
}
- getBuilder()->addForwardDifferentiableDecoration(irFunc);
subContext->irBuilder->setInsertInto(irFunc);
finalVal->moveToEnd();
}