diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-13 11:48:54 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-13 11:48:54 -0800 |
| commit | 4adc64f2a033ec141df6a16e65131612b30fb23b (patch) | |
| tree | 31e4fabbfcac5e59ee334acb2be0f1df2542d679 /source | |
| parent | 63b874dab2df8950a37e0861d24f322e0ab9bfda (diff) | |
Frontend work for `[BackwardDerivative]` and `[BackwardDerivativeOf]`. (#2589)
* Frontend work for `[BackwardDerivative]` and `[BackwardDerivativeOf]`.
* Fix clang issue.
* Fix.
* fix gcc issue
* fix formatting.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/diff.meta.slang | 6 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 31 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 287 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 27 |
8 files changed, 299 insertions, 75 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index f58648657..e19923c80 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -9,11 +9,17 @@ __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) +attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; + +__attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [BackwardDerivativeOf(function)] : BackwardDerivativeOfAttribute; + __attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index c85464061..666ca77ea 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1066,26 +1066,36 @@ class ForwardDifferentiableAttribute : public DifferentiableAttribute SLANG_AST_CLASS(ForwardDifferentiableAttribute) }; +class UserDefinedDerivativeAttribute : public DifferentiableAttribute +{ + SLANG_AST_CLASS(UserDefinedDerivativeAttribute) + + Expr* funcExpr; +}; + /// The `[ForwardDerivative(function)]` attribute specifies a custom function that should /// be used as the derivative for the decorated function. -class ForwardDerivativeAttribute : public DifferentiableAttribute +class ForwardDerivativeAttribute : public UserDefinedDerivativeAttribute { SLANG_AST_CLASS(ForwardDerivativeAttribute) +}; + +class DerivativeOfAttribute : public DifferentiableAttribute +{ + SLANG_AST_CLASS(DerivativeOfAttribute) Expr* funcExpr; + + Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. }; /// The `[ForwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom /// derivative implementation for `primalFunction`. /// ForwardDerivativeOfAttribute inherits from DifferentiableAttribute because a derivative /// function itself is considered differentiable. -class ForwardDerivativeOfAttribute : public DifferentiableAttribute +class ForwardDerivativeOfAttribute : public DerivativeOfAttribute { SLANG_AST_CLASS(ForwardDerivativeOfAttribute) - - Expr* funcExpr; - - Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. }; /// The `[BackwardDifferentiable]` attribute indicates that a function can be backward-differentiated. @@ -1096,21 +1106,16 @@ class BackwardDifferentiableAttribute : public DifferentiableAttribute /// The `[BackwardDerivative(function)]` attribute specifies a custom function that should /// be used as the backward-derivative for the decorated function. -class BackwardDerivativeAttribute : public DifferentiableAttribute +class BackwardDerivativeAttribute : public UserDefinedDerivativeAttribute { SLANG_AST_CLASS(BackwardDerivativeAttribute) - Expr* funcExpr; }; /// The `[BackwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom /// backward-derivative implementation for `primalFunction`. -class BackwardDerivativeOfAttribute : public DifferentiableAttribute +class BackwardDerivativeOfAttribute : public DerivativeOfAttribute { SLANG_AST_CLASS(BackwardDerivativeOfAttribute) - - Expr* funcExpr; - - Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. }; /// The `[NoDiffThis]` attribute is used to specify that the `this` parameter should not be diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index b8732a67f..f016ae3d8 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -259,10 +259,9 @@ namespace Slang void visitParamDecl(ParamDecl* paramDecl); - void checkDerivativeOfAttribute(FunctionDeclBase* funcDecl); - - void checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr); + void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl); + void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl); }; /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? @@ -4668,90 +4667,273 @@ namespace Slang getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly); } - void SemanticsDeclBodyVisitor::checkDerivativeOfAttribute(FunctionDeclBase* funcDecl) + template<typename TDerivativeAttr> + void checkDerivativeAttributeImpl( + SemanticsVisitor* visitor, + TDerivativeAttr* attr, + const List<Expr*>& imaginaryArguments) { - auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>(); - if (!attr) - return; + auto invokeExpr = visitor->constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); + auto resolved = visitor->ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) + { + attr->funcExpr = calleeDeclRef; + return; + } + } + visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + } + + template<typename TDerivativeAttr> + const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); } + + template<> + const char* getDerivativeAttrName<ForwardDerivativeAttribute>() + { + return "ForwardDerivative"; + } + template<> + const char* getDerivativeAttrName<BackwardDerivativeAttribute>() + { + return "BackwardDerivative"; + } + List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + { List<Expr*> imaginaryArguments; - for (auto param : funcDecl->getParameters()) + for (auto param : originalFuncDecl->getParameters()) { - auto arg = m_astBuilder->create<VarExpr>(); + auto arg = visitor->getASTBuilder()->create<VarExpr>(); arg->declRef.decl = param; arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); - arg->loc = attr->loc; + arg->loc = loc; + if (auto pairType = visitor->getDifferentialPairType(param->getType())) + { + arg->type.type = pairType; + } + imaginaryArguments.add(arg); + } + return imaginaryArguments; + } + + List<Expr*> getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* fwdDiffFunc, SourceLoc loc) + { + List<Expr*> imaginaryArguments; + for (auto param : fwdDiffFunc->getParameters()) + { + auto arg = astBuilder->create<VarExpr>(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; if (auto pairType = as<DifferentialPairType>(param->getType())) { arg->type.type = pairType->getPrimalType(); } imaginaryArguments.add(arg); } - auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); - auto resolved = ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + return imaginaryArguments; + } + + List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + { + List<Expr*> imaginaryArguments; + for (auto param : originalFuncDecl->getParameters()) { - if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) + auto arg = visitor->getASTBuilder()->create<VarExpr>(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; + if (auto pairType = visitor->getDifferentialPairType(param->getType())) { - if (auto existingModifier = calleeDeclRef->declRef.getDecl()->findModifier<ForwardDerivativeAttribute>()) + arg->type.type = pairType; + if (auto diffPairType = as<DifferentialPairType>(pairType)) { - // The primal function already has a `[ForwardDerivative]` attribute, this is invalid. - getSink()->diagnose(attr, Diagnostics::declAlreadyHasAttribute, calleeDeclRef->declRef, "[ForwardDerivative]"); - getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl()); + if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr) + { + arg->type.isLeftValue = false; + arg->type.type = diffPairType->getPrimalType(); + } } - attr->funcExpr = calleeDeclRef; - auto fwdDerivativeAttr = m_astBuilder->create<ForwardDerivativeAttribute>(); - fwdDerivativeAttr->loc = attr->loc; - auto outterGeneric = GetOuterGeneric(funcDecl); - auto declRef = - DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); - auto declRefExpr = ConstructDeclRefExpr(declRef, nullptr, attr->loc, nullptr); - declRefExpr->type.type = nullptr; - fwdDerivativeAttr->args.add(declRefExpr); - fwdDerivativeAttr->funcExpr = declRefExpr; - checkDerivativeAttribute(as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), fwdDerivativeAttr); - attr->backDeclRef = fwdDerivativeAttr->funcExpr; - fwdDerivativeAttr->funcExpr = nullptr; - getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), DeclAssociationKind::ForwardDerivativeFunc, funcDecl); - return; } + imaginaryArguments.add(arg); + } + if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type)) + { + auto arg = visitor->getASTBuilder()->create<InitializerListExpr>(); + arg->type.isLeftValue = false; + arg->type.type = diffReturnType; + arg->loc = loc; + imaginaryArguments.add(arg); } - getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + return imaginaryArguments; } - void SemanticsDeclBodyVisitor::checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) + List<Expr*> getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* bwdDiffFunc, SourceLoc loc) { - if (!attr->funcExpr) - return; - if (attr->funcExpr->type.type) - return; - List<Expr*> imaginaryArguments; - for (auto param : funcDecl->getParameters()) + for (auto param : bwdDiffFunc->getParameters()) { - auto arg = m_astBuilder->create<VarExpr>(); + auto arg = astBuilder->create<VarExpr>(); arg->declRef.decl = param; arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); - arg->loc = attr->loc; - if (auto pairType = getDifferentialPairType(param->getType())) + arg->loc = loc; + if (auto pairType = as<DifferentialPairType>(param->getType())) { - arg->type.type = pairType; + if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr) + { + arg->type.isLeftValue = false; + } + arg->type.type = pairType->getPrimalType(); } imaginaryArguments.add(arg); } - auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); - auto resolved = ResolveInvoke(invokeExpr); + return imaginaryArguments; + } + + // This helper function is needed to workaround a gcc bug. + // Remove when we upgrade to a newer version of gcc. + template <typename T> + static T* _findModifier(Decl* decl) + { + return decl->findModifier<T>(); + } + + template <typename TDerivativeAttr, typename TDerivativeOfAttr> + void checkDerivativeOfAttributeImpl( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + TDerivativeOfAttr* derivativeOfAttr, + DeclAssociationKind assocKind, + const List<Expr*>& imaginaryArgsToOriginal) + { + auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal); + auto resolved = visitor->ResolveInvoke(invokeExpr); if (auto resolvedInvoke = as<InvokeExpr>(resolved)) { if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) { - attr->funcExpr = calleeDeclRef; + auto calleeDecl = calleeDeclRef->declRef.getDecl(); + if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeDecl)) + { + // The primal function already has a `[*Derivative]` attribute, this is invalid. + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::declAlreadyHasAttribute, + calleeDeclRef->declRef, + getDerivativeAttrName<TDerivativeAttr>()); + visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl()); + } + derivativeOfAttr->funcExpr = calleeDeclRef; + auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>(); + derivativeAttr->loc = derivativeOfAttr->loc; + auto outterGeneric = visitor->GetOuterGeneric(funcDecl); + auto declRef = + DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); + auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); + declRefExpr->type.type = nullptr; + derivativeAttr->args.add(declRefExpr); + derivativeAttr->funcExpr = declRefExpr; + checkDerivativeAttribute(visitor, as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), derivativeAttr); + derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; + derivativeAttr->funcExpr = nullptr; + visitor->getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), assocKind, funcDecl); return; } } - getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::invalidCustomDerivative); + } + + static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List<Expr*> imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + } + + static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List<Expr*> imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + } + + template<typename TDerivativeAttr, typename TDerivativeOfAttr> + bool tryCheckDerivativeOfAttributeImpl( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + TDerivativeOfAttr* derivativeOfAttr, + DeclAssociationKind assocKind, + const List<Expr*>& imaginaryArgsToOriginal) + { + DiagnosticSink tempSink(visitor->getSourceManager(), nullptr); + SemanticsVisitor subVisitor(visitor->withSink(&tempSink)); + checkDerivativeOfAttributeImpl<TDerivativeAttr>( + &subVisitor, + funcDecl, + derivativeOfAttr, + assocKind, + imaginaryArgsToOriginal); + return tempSink.getErrorCount() == 0; + } + + void SemanticsDeclBodyVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl) + { + auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>(); + if (!attr) + return; + + List<Expr*> imaginaryArgsToOriginal = getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc); + checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute>( + this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc, imaginaryArgsToOriginal); + } + + void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl) + { + auto attr = funcDecl->findModifier<BackwardDerivativeOfAttribute>(); + if (!attr) + return; + + List<Expr*> imaginaryArguments = getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc); + + // The tricky part here is that we can't easily derive the arguments to original func just + // from the definition of a backward derivative function, because we don't know if the last + // parameter is just a normal parameter of the original func, or if it is the additional + // derivative of the return value. The solution here is to try to resolve the original + // function with or without the last argument. However if the type of the last argument + // isn't differentiable, we know that it can't possibly be the result derivative. + + if (imaginaryArguments.getCount() == 0 || + !tryGetDifferentialType(m_astBuilder, imaginaryArguments.getLast()->type.type)) + { + checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>( + this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments); + return; + } + + // Otherwise, try resolve with all the arguments, if failed, resolve without the last + // argument. + if (tryCheckDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments)) + { + return; + } + + imaginaryArguments.removeLast(); + checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>( + this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments); } void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) @@ -4759,9 +4941,12 @@ namespace Slang auto newContext = withParentFunc(decl); // Run checking on attributes that can't be fully checked in header checking stage. - checkDerivativeOfAttribute(decl); + checkForwardDerivativeOfAttribute(decl); if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>()) - checkDerivativeAttribute(decl, derivativeAttr); + checkDerivativeAttribute(this, decl, derivativeAttr); + checkBackwardDerivativeOfAttribute(decl); + if (auto derivativeAttr = decl->findModifier<BackwardDerivativeAttribute>()) + checkDerivativeAttribute(this, decl, derivativeAttr); if (newContext.getParentDifferentiableAttribute()) { diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 9742e69bb..f505b1321 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -635,7 +635,7 @@ namespace Slang diffExpr->type.type = nullptr; forwardDerivativeAttr->funcExpr = diffExpr; } - else if (auto forwardDerivativeOfAttr = as<ForwardDerivativeOfAttribute>(attr)) + else if (auto derivativeOfAttr = as<DerivativeOfAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); SLANG_ASSERT(as<Decl>(attrTarget)); @@ -648,7 +648,7 @@ namespace Slang getSink()->diagnose(primalFunc, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget)); return false; } - forwardDerivativeOfAttr->funcExpr = primalFunc; + derivativeOfAttr->funcExpr = primalFunc; } else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr)) { diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index db97f4865..8f9327c53 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -345,7 +345,7 @@ DIAGNOSTIC(31142, Error, ambiguousOriginalDefintionOfExternDecl, "`extern` decl DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original definition found for `extern` decl '$0'.") DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative attribute.") -DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '$1'.") +DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.") // Enums diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 06f8b0e5d..ab7453b41 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -734,12 +734,13 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// generated derivative function. INST(BackwardDifferentiableDecoration, backwardDifferentiable, 1, 0) - /// Decorated function is marked for the reverse-mode differentiation pass. + /// Decorations to associate an original function with compiler generated backward derivative functions. INST(BackwardDerivativePrimalDecoration, backwardDiffPrimalReference, 1, 0) INST(BackwardDerivativePropagateDecoration, backwardDiffPropagateReference, 1, 0) INST(BackwardDerivativeIntermediateTypeDecoration, backwardDiffIntermediateTypeReference, 1, 0) INST(BackwardDerivativeDecoration, backwardDiffReference, 1, 0) + INST(UserDefinedBackwardDerivativeDecoration, userDefinedBackwardDiffReference, 1, 0) INST(BackwardDerivativePrimalContextDecoration, BackwardDerivativePrimalContextDecoration, 1, 0) INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 1ff61a774..b30d489dc 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -701,6 +701,15 @@ struct IRBackwardDifferentiableDecoration : IRDecoration IR_LEAF_ISA(BackwardDifferentiableDecoration) }; +struct IRUserDefinedBackwardDerivativeDecoration : IRDecoration +{ + enum + { + kOp = kIROp_UserDefinedBackwardDerivativeDecoration + }; + IR_LEAF_ISA(UserDefinedBackwardDerivativeDecoration) +}; + struct IRTreatAsDifferentiableDecoration : IRDecoration { enum @@ -3497,6 +3506,11 @@ public: addDecoration(value, kIROp_ForwardDerivativeDecoration, fwdFunc); } + void addUserDefinedBackwardDerivativeDecoration(IRInst* value, IRInst* fwdFunc) + { + addDecoration(value, kIROp_UserDefinedBackwardDerivativeDecoration, fwdFunc); + } + void addBackwardDerivativePrimalDecoration(IRInst* value, IRInst* jvpFn) { addDecoration(value, kIROp_BackwardDerivativePrimalDecoration, jvpFn); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 4618b6786..9378a69e8 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8360,7 +8360,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); - if (auto attr = decl->findModifier<ForwardDerivativeAttribute>()) + if (auto attr = decl->findModifier<UserDefinedDerivativeAttribute>()) { // We need to lower the decl ref to the custom derivative function to IR. // The IR insts correspond to the decl ref is not part of the function we @@ -8374,13 +8374,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr); SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); - IRInst* jvpFunc = loweredVal.val; - getBuilder()->addDecoration(irFunc, kIROp_ForwardDerivativeDecoration, jvpFunc); + IRInst* derivativeFunc = loweredVal.val; + + if (as<ForwardDerivativeAttribute>(attr)) + getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc); + else + getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc); // Reset cursor. subContext->irBuilder->setInsertInto(irFunc); } - + // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list @@ -8391,7 +8395,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // the interface's type definition. auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric); - if (auto attr = decl->findModifier<ForwardDerivativeOfAttribute>()) + if (auto attr = decl->findModifier<DerivativeOfAttribute>()) { if (auto originalDeclRefExpr = as<DeclRefExpr>(attr->funcExpr)) { @@ -8412,9 +8416,18 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } originalSubBuilder->setInsertBefore(originalFuncVal); auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef); - originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + if (as<ForwardDerivativeOfAttribute>(attr)) + { + originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + getBuilder()->addForwardDifferentiableDecoration(irFunc); + } + else + { + originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + getBuilder()->addForwardDifferentiableDecoration(irFunc); + getBuilder()->addBackwardDifferentiableDecoration(irFunc); + } } - getBuilder()->addForwardDifferentiableDecoration(irFunc); subContext->irBuilder->setInsertInto(irFunc); finalVal->moveToEnd(); } |
