From 86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 8 Mar 2023 21:52:34 -0800 Subject: Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691) * Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He --- source/slang/diff.meta.slang | 85 +-- source/slang/slang-ast-expr.h | 6 +- source/slang/slang-ast-modifier.h | 18 + source/slang/slang-ast-support-types.h | 12 +- source/slang/slang-check-decl.cpp | 624 ++++++++++++--------- source/slang/slang-check-expr.cpp | 79 ++- source/slang/slang-check-impl.h | 5 +- source/slang/slang-check-modifier.cpp | 30 +- source/slang/slang-check-overload.cpp | 2 +- source/slang/slang-ir-autodiff-fwd.cpp | 47 +- source/slang/slang-ir-autodiff-rev.cpp | 2 +- .../slang/slang-ir-autodiff-transcriber-base.cpp | 39 +- source/slang/slang-ir-autodiff-unzip.cpp | 26 +- source/slang/slang-ir-autodiff.cpp | 12 +- source/slang/slang-ir-autodiff.h | 2 +- source/slang/slang-ir-check-differentiability.cpp | 15 +- source/slang/slang-ir-dce.cpp | 36 +- source/slang/slang-ir-inst-defs.h | 6 + source/slang/slang-ir-insts.h | 31 + source/slang/slang-ir-link.cpp | 1 + source/slang/slang-ir.cpp | 11 + source/slang/slang-lower-to-ir.cpp | 357 ++++++------ 22 files changed, 860 insertions(+), 586 deletions(-) (limited to 'source/slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 54f927816..4301eda94 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -11,6 +11,9 @@ attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute; + __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; @@ -20,6 +23,9 @@ attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDerivativeOf(function)] : BackwardDerivativeOfAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [PrimalSubstituteOf(function)] : PrimalSubstituteOfAttribute; + __attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; @@ -1037,6 +1043,7 @@ void __d_refract(inout DifferentialPair> i, inout DifferentialPair< // Sine and cosine __generic [BackwardDifferentiable] +[PrimalSubstituteOf(sincos)] void __sincos_impl(T x, out T s, out T c) { s = sin(x); @@ -1045,6 +1052,7 @@ void __sincos_impl(T x, out T s, out T c) __generic [BackwardDifferentiable] +[PrimalSubstituteOf(sincos)] void __sincos_impl(vector x, out vector s, out vector c) { s = sin(x); @@ -1053,62 +1061,18 @@ void __sincos_impl(vector x, out vector s, out vector c) __generic [BackwardDifferentiable] +[PrimalSubstituteOf(sincos)] void __sincos_impl(matrix x, out matrix s, out matrix c) { s = sin(x); c = cos(x); } -__generic -[ForwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(DifferentialPair x, out DifferentialPair s, out DifferentialPair c) -{ - __fwd_diff(__sincos_impl)(x, s, c); -} - -__generic -[ForwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(DifferentialPair> x, out DifferentialPair> s, out DifferentialPair> c) -{ - __fwd_diff(__sincos_impl)(x, s, c); -} - -__generic -[ForwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(DifferentialPair> x, out DifferentialPair> s, out DifferentialPair> c) -{ - __fwd_diff(__sincos_impl)(x, s, c); -} - -__generic -[BackwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(inout DifferentialPair x, T.Differential dS, T.Differential dC) -{ - __bwd_diff(__sincos_impl)(x, dS, dC); -} -__generic -[BackwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(inout DifferentialPair> x, vector.Differential dS, vector.Differential dC) -{ - __bwd_diff(__sincos_impl)(x, dS, dC); -} - -__generic -[BackwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(inout DifferentialPair> x, matrix.Differential dS, matrix.Differential dC) -{ - __bwd_diff(__sincos_impl)(x, dS, dC); -} // dst (obsolete) __generic [BackwardDifferentiable] +[PrimalSubstituteOf(dst)] vector __dst_impl(vector src0, vector src1) { vector dest; @@ -1118,25 +1082,11 @@ vector __dst_impl(vector src0, vector src1) dest.w = src1.w; ; return dest; } -__generic -[ForwardDerivativeOf(dst)] -[ForceInline] -DifferentialPair> __d_dst(DifferentialPair> src0, DifferentialPair> src1) -{ - return __fwd_diff(__dst_impl)(src0, src1); -} -__generic -[BackwardDerivativeOf(dst)] -[ForceInline] -void __d_dst(inout DifferentialPair> src0, inout DifferentialPair> src1, vector.Differential dOut) -{ - __bwd_diff(__dst_impl)(src0, src1, dOut); -} // Legacy lighting function (obsolete) -__target_intrinsic(hlsl) [__readNone] [BackwardDifferentiable] +[PrimalSubstituteOf(lit)] float4 __lit_impl(float n_dot_l, float n_dot_h, float m) { let ambient = 1.0f; @@ -1144,19 +1094,6 @@ float4 __lit_impl(float n_dot_l, float n_dot_h, float m) let specular = ((n_dot_l < 0.0f || n_dot_h < 0.0) ? 0.0 : pow(n_dot_h, m)); return float4(ambient, diffuse, specular, 1.0f); } -[ForwardDerivativeOf(lit)] -[ForceInline] -DifferentialPair __d_lit(DifferentialPair n_dot_l, DifferentialPair n_dot_h, DifferentialPair m) -{ - return __fwd_diff(__lit_impl)(n_dot_l, n_dot_h, m); -} -[BackwardDerivativeOf(lit)] -[ForceInline] -void __d_lit(inout DifferentialPair n_dot_l, inout DifferentialPair n_dot_h, inout DifferentialPair m, float4 dOut) -{ - __bwd_diff(__lit_impl)(n_dot_l, n_dot_h, m, dOut); -} - // Matrix determinant __generic [BackwardDifferentiable] diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index ba0b4ce7a..301dded49 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -452,10 +452,14 @@ class HigherOrderInvokeExpr : public Expr List newParameterNames; }; +class PrimalSubstituteExpr : public HigherOrderInvokeExpr +{ + SLANG_AST_CLASS(PrimalSubstituteExpr) +}; + class DifferentiateExpr : public HigherOrderInvokeExpr { SLANG_ABSTRACT_AST_CLASS(DifferentiateExpr) - }; /// An expression of the form `__fwd_diff(fn)` to access the /// forward-mode derivative version of the function `fn` diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 7dd0819d8..80c770c3a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1151,6 +1151,24 @@ class BackwardDerivativeOfAttribute : public DerivativeOfAttribute SLANG_AST_CLASS(BackwardDerivativeOfAttribute) }; + /// The `[PrimalSubstitute(function)]` attribute specifies a custom function that should + /// be used as the primal function substitute when differentiating code that calls the primal function. +class PrimalSubstituteAttribute : public Attribute +{ + SLANG_AST_CLASS(PrimalSubstituteAttribute) + Expr* funcExpr; +}; + + /// The `[PrimalSubstituteOf(primalFunction)]` attribute marks the decorated function as + /// the substitute primal function in a forward or backward derivative function. +class PrimalSubstituteOfAttribute : public Attribute +{ + SLANG_AST_CLASS(PrimalSubstituteOfAttribute) + + Expr* funcExpr; + Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. +}; + /// The `[NoDiffThis]` attribute is used to specify that the `this` parameter should not be /// included for differentiation. class NoDiffThisAttribute : public Attribute diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 07b3a5eac..78b0ccfcc 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -391,6 +391,10 @@ namespace Slang /// maybe synthesized and made available only after conformance checking. TypesFullyResolved, + /// All attributes are fully checked. This is the final step before + /// checking the function body. + AttributesChecked, + /// The declaration is fully checked. /// /// This step includes any validation of the declaration that is @@ -1500,12 +1504,12 @@ namespace Slang enum class DeclAssociationKind { - ForwardDerivativeFunc, BackwardDerivativeFunc, + ForwardDerivativeFunc, BackwardDerivativeFunc, PrimalSubstituteFunc }; - struct DeclAssociation + struct DeclAssociation : SerialRefObject { - SLANG_VALUE_CLASS(DeclAssociation) + SLANG_OBJ_CLASS(DeclAssociation) DeclAssociationKind kind; Decl* decl; }; @@ -1516,7 +1520,7 @@ namespace Slang { SLANG_OBJ_CLASS(DeclAssociationList) - List associations; + List> associations; }; /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7c42c1892..5cd7fba45 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -34,6 +34,26 @@ namespace Slang } }; + struct SemanticsDeclAttributesVisitor + : public SemanticsDeclVisitorBase + , public DeclVisitor + { + SemanticsDeclAttributesVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + {} + + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} + + void visitFunctionDeclBase(FunctionDeclBase* decl); + + void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeOfAttribute* attr); + + void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, BackwardDerivativeOfAttribute* attr); + + void checkPrimalSubstituteOfAttribute(FunctionDeclBase* funcDecl, PrimalSubstituteOfAttribute* attr); + }; + struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase , public DeclVisitor @@ -258,10 +278,6 @@ namespace Slang void visitFunctionDeclBase(FunctionDeclBase* funcDecl); void visitParamDecl(ParamDecl* paramDecl); - - void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl); - - void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl); }; /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? @@ -4657,270 +4673,10 @@ namespace Slang getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly); } - template - void checkDerivativeAttributeImpl( - SemanticsVisitor* visitor, - TDerivativeAttr* attr, - const List& imaginaryArguments) - { - auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, *visitor); - auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); - auto resolved = visitor->ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as(resolved)) - { - if (auto calleeDeclRef = as(resolvedInvoke->functionExpr)) - { - attr->funcExpr = calleeDeclRef; - return; - } - } - visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); - } - - template - const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); } - - template<> - const char* getDerivativeAttrName() - { - return "ForwardDerivative"; - } - template<> - const char* getDerivativeAttrName() - { - return "BackwardDerivative"; - } - - List getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) - { - List imaginaryArguments; - for (auto param : func->getParameters()) - { - auto arg = astBuilder->create(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - imaginaryArguments.add(arg); - } - return imaginaryArguments; - } - - List getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) - { - List imaginaryArguments; - for (auto param : originalFuncDecl->getParameters()) - { - auto arg = visitor->getASTBuilder()->create(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) - { - arg->type.type = pairType; - } - imaginaryArguments.add(arg); - } - return imaginaryArguments; - } - - List getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) - { - List imaginaryArguments; - auto isOutParam = [&](ParamDecl* param) - { - return param->findModifier() != nullptr && param->findModifier() == nullptr; - }; - - for (auto param : originalFuncDecl->getParameters()) - { - auto arg = visitor->getASTBuilder()->create(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - if (auto pairType = as(visitor->getDifferentialPairType(param->getType()))) - { - arg->type.type = pairType; - if (isOutParam(param)) - { - // out T -> in T.Differential - arg->type.isLeftValue = false; - arg->type.type = visitor->tryGetDifferentialType( - visitor->getASTBuilder(), pairType->getPrimalType()); - } - } - else - { - if (isOutParam(param)) - { - // Skip non-differentiable out params. - continue; - } - } - imaginaryArguments.add(arg); - } - if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type)) - { - auto arg = visitor->getASTBuilder()->create(); - arg->type.isLeftValue = false; - arg->type.type = diffReturnType; - arg->loc = loc; - imaginaryArguments.add(arg); - } - return imaginaryArguments; - } - - // This helper function is needed to workaround a gcc bug. - // Remove when we upgrade to a newer version of gcc. - template - static T* _findModifier(Decl* decl) - { - return decl->findModifier(); - } - - template - void checkDerivativeOfAttributeImpl( - SemanticsVisitor* visitor, - FunctionDeclBase* funcDecl, - TDerivativeOfAttr* derivativeOfAttr, - DeclAssociationKind assocKind) - { - DeclRef calleeDeclRef; - DeclRefExpr* calleeDeclRefExpr = nullptr; - DifferentiateExpr* diffFuncExpr = visitor->getASTBuilder()->create(); - diffFuncExpr->baseFunction = derivativeOfAttr->funcExpr; - diffFuncExpr->loc = derivativeOfAttr->loc; - Expr* checkedDiffFuncExpr = visitor->dispatchExpr(diffFuncExpr, *visitor); - if (!checkedDiffFuncExpr) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - List imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc); - auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedDiffFuncExpr, imaginaryArgs); - auto resolved = visitor->ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as(resolved)) - { - auto resolvedDiffFuncExpr = as(resolvedInvoke->functionExpr); - if (resolvedDiffFuncExpr) - calleeDeclRefExpr = as(resolvedDiffFuncExpr->baseFunction); - } - - if (!calleeDeclRefExpr) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - calleeDeclRef = calleeDeclRefExpr->declRef; - - auto calleeFunc = as(calleeDeclRef.getDecl()); - if (!calleeFunc) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - - if (auto existingModifier = _findModifier(calleeFunc)) - { - // The primal function already has a `[*Derivative]` attribute, this is invalid. - visitor->getSink()->diagnose( - derivativeOfAttr, - Diagnostics::declAlreadyHasAttribute, - calleeDeclRef, - getDerivativeAttrName()); - visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl()); - } - derivativeOfAttr->funcExpr = calleeDeclRefExpr; - auto derivativeAttr = visitor->getASTBuilder()->create(); - derivativeAttr->loc = derivativeOfAttr->loc; - auto outterGeneric = visitor->GetOuterGeneric(funcDecl); - auto declRef = - DeclRef((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); - auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); - declRefExpr->type.type = nullptr; - derivativeAttr->args.add(declRefExpr); - derivativeAttr->funcExpr = declRefExpr; - checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr); - derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; - derivativeAttr->funcExpr = nullptr; - visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl); - } - - static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) - { - if (!attr->funcExpr) - return; - if (attr->funcExpr->type.type) - return; - - List imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); - } - - static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) - { - if (!attr->funcExpr) - return; - if (attr->funcExpr->type.type) - return; - - List imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); - } - - template - bool tryCheckDerivativeOfAttributeImpl( - SemanticsVisitor* visitor, - FunctionDeclBase* funcDecl, - TDerivativeOfAttr* derivativeOfAttr, - DeclAssociationKind assocKind, - const List& imaginaryArgsToOriginal) - { - DiagnosticSink tempSink(visitor->getSourceManager(), nullptr); - SemanticsVisitor subVisitor(visitor->withSink(&tempSink)); - checkDerivativeOfAttributeImpl( - &subVisitor, - funcDecl, - derivativeOfAttr, - assocKind, - imaginaryArgsToOriginal); - return tempSink.getErrorCount() == 0; - } - - void SemanticsDeclBodyVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl) - { - auto attr = funcDecl->findModifier(); - if (!attr) - return; - - checkDerivativeOfAttributeImpl( - this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc); - } - - void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl) - { - auto attr = funcDecl->findModifier(); - if (!attr) - return; - - checkDerivativeOfAttributeImpl( - this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc); - } - void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); - // Run checking on attributes that can't be fully checked in header checking stage. - checkForwardDerivativeOfAttribute(decl); - if (auto derivativeAttr = decl->findModifier()) - checkDerivativeAttribute(this, decl, derivativeAttr); - checkBackwardDerivativeOfAttribute(decl); - if (auto derivativeAttr = decl->findModifier()) - checkDerivativeAttribute(this, decl, derivativeAttr); - if (newContext.getParentDifferentiableAttribute()) { // Register additional types outside the function body first. @@ -6762,7 +6518,7 @@ namespace Slang /// Note: this function creates an empty list of candidates for the given type if /// a matching entry doesn't exist already. /// - static List& _getDeclAssociationList( + static List>& _getDeclAssociationList( Decl* decl, OrderedDictionary>& mapDeclToDeclarations) { @@ -6787,14 +6543,16 @@ namespace Slang void SharedSemanticsContext::registerAssociatedDecl(Decl* original, DeclAssociationKind kind, Decl* associated) { auto moduleDecl = getModuleDecl(associated); - DeclAssociation assoc = {kind, associated}; + RefPtr assoc = new DeclAssociation(); + assoc->kind = kind; + assoc->decl = associated; _getDeclAssociationList(original, moduleDecl->mapDeclToAssociatedDecls).add(assoc); m_associatedDeclListsBuilt = false; m_mapDeclToAssociatedDecls.Clear(); } - List const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl) + List> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl) { // This duplicates the exact same logic from `getCandidateExtensionsForTypeDecl`. // Consider refactoring them into the same framework. @@ -6838,24 +6596,47 @@ namespace Slang FunctionDifferentiableLevel SharedSemanticsContext::getFuncDifferentiableLevel(FunctionDeclBase* func) { - if (func->findModifier()) - return FunctionDifferentiableLevel::Backward; - if (func->findModifier()) - return FunctionDifferentiableLevel::Backward; + return _getFuncDifferentiableLevelImpl(func, 1); + } - FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None; + FunctionDifferentiableLevel SharedSemanticsContext::_getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit) + { + if (recurseLimit > 0) + { + if (auto primalSubst = func->findModifier()) + { + if (auto declRefExpr = as(primalSubst->funcExpr)) + { + if (auto primalSubstFunc = declRefExpr->declRef.as()) + return _getFuncDifferentiableLevelImpl(primalSubstFunc, recurseLimit - 1); + } + } + } + + if (func->findModifier()) + return FunctionDifferentiableLevel::Backward; + if (func->findModifier()) + return FunctionDifferentiableLevel::Backward; + + FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None; if (func->findModifier()) diffLevel = FunctionDifferentiableLevel::Forward; for (auto assocDecl : getAssociatedDeclsForDecl(func)) { - switch (assocDecl.kind) + switch (assocDecl->kind) { case DeclAssociationKind::BackwardDerivativeFunc: return FunctionDifferentiableLevel::Backward; case DeclAssociationKind::ForwardDerivativeFunc: diffLevel = FunctionDifferentiableLevel::Forward; break; + case DeclAssociationKind::PrimalSubstituteFunc: + if (auto assocFunc = as(assocDecl->decl)) + { + return _getFuncDifferentiableLevelImpl(assocFunc, recurseLimit - 1); + } + break; default: break; } @@ -6971,6 +6752,10 @@ namespace Slang SemanticsDeclDifferentialConformanceVisitor(shared).dispatch(decl); break; + case DeclCheckState::AttributesChecked: + SemanticsDeclAttributesVisitor(shared).dispatch(decl); + break; + case DeclCheckState::Checked: SemanticsDeclBodyVisitor(shared).dispatch(decl); break; @@ -7058,4 +6843,297 @@ namespace Slang return val; } + + template + void checkDerivativeAttributeImpl( + SemanticsVisitor* visitor, + TDerivativeAttr* attr, + const List& imaginaryArguments) + { + SemanticsContext::ExprLocalScope scope; + auto ctx = visitor->withExprLocalScope(&scope); + auto subVisitor = SemanticsVisitor(ctx); + auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); + auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); + auto resolved = subVisitor.ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as(resolved)) + { + if (auto calleeDeclRef = as(resolvedInvoke->functionExpr)) + { + attr->funcExpr = calleeDeclRef; + return; + } + } + visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + } + + template + const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); } + + template<> + const char* getDerivativeAttrName() + { + return "ForwardDerivative"; + } + template<> + const char* getDerivativeAttrName() + { + return "BackwardDerivative"; + } + template<> + const char* getDerivativeAttrName() + { + return "PrimalSubstitute"; + } + + List getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) + { + List imaginaryArguments; + for (auto param : func->getParameters()) + { + auto arg = astBuilder->create(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; + imaginaryArguments.add(arg); + } + return imaginaryArguments; + } + + List getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + { + List imaginaryArguments; + for (auto param : originalFuncDecl->getParameters()) + { + auto arg = visitor->getASTBuilder()->create(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; + if (auto pairType = visitor->getDifferentialPairType(param->getType())) + { + arg->type.type = pairType; + } + imaginaryArguments.add(arg); + } + return imaginaryArguments; + } + + List getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + { + List imaginaryArguments; + auto isOutParam = [&](ParamDecl* param) + { + return param->findModifier() != nullptr && param->findModifier() == nullptr; + }; + + for (auto param : originalFuncDecl->getParameters()) + { + auto arg = visitor->getASTBuilder()->create(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; + if (auto pairType = as(visitor->getDifferentialPairType(param->getType()))) + { + arg->type.type = pairType; + if (isOutParam(param)) + { + // out T -> in T.Differential + arg->type.isLeftValue = false; + arg->type.type = visitor->tryGetDifferentialType( + visitor->getASTBuilder(), pairType->getPrimalType()); + } + } + else + { + if (isOutParam(param)) + { + // Skip non-differentiable out params. + continue; + } + } + imaginaryArguments.add(arg); + } + if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type)) + { + auto arg = visitor->getASTBuilder()->create(); + arg->type.isLeftValue = false; + arg->type.type = diffReturnType; + arg->loc = loc; + imaginaryArguments.add(arg); + } + return imaginaryArguments; + } + + // This helper function is needed to workaround a gcc bug. + // Remove when we upgrade to a newer version of gcc. + template + static T* _findModifier(Decl* decl) + { + return decl->findModifier(); + } + + template + void checkDerivativeOfAttributeImpl( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + TDerivativeOfAttr* derivativeOfAttr, + DeclAssociationKind assocKind) + { + DeclRef calleeDeclRef; + DeclRefExpr* calleeDeclRefExpr = nullptr; + HigherOrderInvokeExpr* higherOrderFuncExpr = visitor->getASTBuilder()->create(); + higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr; + higherOrderFuncExpr->loc = derivativeOfAttr->loc; + Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, *visitor); + if (!checkedHigherOrderFuncExpr) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + List imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc); + auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedHigherOrderFuncExpr, imaginaryArgs); + SemanticsContext::ExprLocalScope scope; + auto ctx = visitor->withExprLocalScope(&scope); + auto subVisitor = SemanticsVisitor(ctx); + auto resolved = subVisitor.ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as(resolved)) + { + auto resolvedFuncExpr = as(resolvedInvoke->functionExpr); + if (resolvedFuncExpr) + calleeDeclRefExpr = as(resolvedFuncExpr->baseFunction); + } + + if (!calleeDeclRefExpr) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + calleeDeclRef = calleeDeclRefExpr->declRef; + + auto calleeFunc = as(calleeDeclRef.getDecl()); + if (!calleeFunc) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + + if (auto existingModifier = _findModifier(calleeFunc)) + { + // The primal function already has a `[*Derivative]` attribute, this is invalid. + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::declAlreadyHasAttribute, + calleeDeclRef, + getDerivativeAttrName()); + visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl()); + } + + derivativeOfAttr->funcExpr = calleeDeclRefExpr; + auto derivativeAttr = visitor->getASTBuilder()->create(); + derivativeAttr->loc = derivativeOfAttr->loc; + auto outterGeneric = visitor->GetOuterGeneric(funcDecl); + auto declRef = + DeclRef((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); + auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); + declRefExpr->type.type = nullptr; + derivativeAttr->args.add(declRefExpr); + derivativeAttr->funcExpr = declRefExpr; + checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr); + derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; + derivativeAttr->funcExpr = nullptr; + visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl); + } + + static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + } + + static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + } + + static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + } + + template + bool tryCheckDerivativeOfAttributeImpl( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + TDerivativeOfAttr* derivativeOfAttr, + DeclAssociationKind assocKind, + const List& imaginaryArgsToOriginal) + { + DiagnosticSink tempSink(visitor->getSourceManager(), nullptr); + SemanticsVisitor subVisitor(visitor->withSink(&tempSink)); + checkDerivativeOfAttributeImpl( + &subVisitor, + funcDecl, + derivativeOfAttr, + assocKind, + imaginaryArgsToOriginal); + return tempSink.getErrorCount() == 0; + } + + void SemanticsDeclAttributesVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeOfAttribute* attr) + { + checkDerivativeOfAttributeImpl( + this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc); + } + + void SemanticsDeclAttributesVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, BackwardDerivativeOfAttribute* attr) + { + checkDerivativeOfAttributeImpl( + this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc); + } + + void SemanticsDeclAttributesVisitor::checkPrimalSubstituteOfAttribute(FunctionDeclBase* funcDecl, PrimalSubstituteOfAttribute* attr) + { + checkDerivativeOfAttributeImpl( + this, funcDecl, attr, DeclAssociationKind::PrimalSubstituteFunc); + } + + void SemanticsDeclAttributesVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) + { + // Run checking on attributes that can't be fully checked in header checking stage. + for (auto attr : decl->modifiers) + { + if (auto fwdDerivativeOfAttr = as(attr)) + checkForwardDerivativeOfAttribute(decl, fwdDerivativeOfAttr); + else if (auto bwdDerivativeOfAttr = as(attr)) + checkBackwardDerivativeOfAttribute(decl, bwdDerivativeOfAttr); + else if (auto primalOfAttr = as(attr)) + checkPrimalSubstituteOfAttribute(decl, primalOfAttr); + else if (auto fwdDerivativeAttr = as(attr)) + checkDerivativeAttribute(this, decl, fwdDerivativeAttr); + else if (auto bwdDerivativeAttr = as(attr)) + checkDerivativeAttribute(this, decl, bwdDerivativeAttr); + else if (auto primalAttr = as(attr)) + checkDerivativeAttribute(this, decl, primalAttr); + } + } } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index bebaa63a2..f749361d7 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2227,10 +2227,10 @@ namespace Slang return type; } - struct DifferentiateExprCheckingActions + struct HigherOrderInvokeExprCheckingActions { - virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) = 0; - virtual void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0; + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) = 0; + virtual void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0; FuncType* getBaseFunctionType(SemanticsVisitor* semantics, Expr* funcExpr) { if (auto funcType = as(funcExpr->type.type)) @@ -2255,13 +2255,13 @@ namespace Slang } }; - struct ForwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions + struct ForwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions { - virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override { return semantics->getASTBuilder()->create(); } - void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override { resultDiffExpr->baseFunction = funcExpr; auto baseFuncType = getBaseFunctionType(semantics, funcExpr); @@ -2290,13 +2290,13 @@ namespace Slang } }; - struct BackwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions + struct BackwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions { - virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override { return semantics->getASTBuilder()->create(); } - void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override { resultDiffExpr->baseFunction = funcExpr; auto baseFuncType = getBaseFunctionType(semantics, funcExpr); @@ -2333,10 +2333,45 @@ namespace Slang } }; - static Expr* _checkDifferentiateExpr( + struct PrimalSubstituteExprCheckingActions : HigherOrderInvokeExprCheckingActions + { + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override + { + return semantics->getASTBuilder()->create(); + } + void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + { + resultDiffExpr->baseFunction = funcExpr; + auto baseFuncType = getBaseFunctionType(semantics, funcExpr); + if (!baseFuncType) + { + resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); + semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type); + return; + } + resultDiffExpr->type = baseFuncType; + if (auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(funcExpr))) + { + auto funcDecl = declRefExpr->declRef.as().getDecl(); + if (auto genDecl = as(declRefExpr->declRef.getDecl())) + { + funcDecl = as(genDecl->inner); + } + if (funcDecl) + { + for (auto param : funcDecl->getParameters()) + { + resultDiffExpr->newParameterNames.add(param->getName()); + } + } + } + } + }; + + static Expr* _checkHigherOrderInvokeExpr( SemanticsVisitor* semantics, - DifferentiateExpr* expr, - DifferentiateExprCheckingActions* actions) + HigherOrderInvokeExpr* expr, + HigherOrderInvokeExprCheckingActions* actions) { // Check/Resolve inner function declaration. expr->baseFunction = semantics->CheckTerm(expr->baseFunction); @@ -2354,8 +2389,8 @@ namespace Slang nullptr, overloadedExpr->loc, nullptr); - auto candidateExpr = actions->createDifferentiateExpr(semantics); - actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr); + auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics); + actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, lookupResultExpr); candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } @@ -2368,8 +2403,8 @@ namespace Slang OverloadedExpr2* result = astBuilder->create(); for (auto item : overloadedExpr2->candidiateExprs) { - auto candidateExpr = actions->createDifferentiateExpr(semantics); - actions->fillDifferentiateExpr(candidateExpr, semantics, item); + auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics); + actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, item); candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } @@ -2378,20 +2413,26 @@ namespace Slang return result; } - actions->fillDifferentiateExpr(expr, semantics, expr->baseFunction); + actions->fillHigherOrderInvokeExpr(expr, semantics, expr->baseFunction); return expr; } Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) { ForwardDifferentiateExprCheckingActions actions; - return _checkDifferentiateExpr(this, expr, &actions); + return _checkHigherOrderInvokeExpr(this, expr, &actions); } Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) { BackwardDifferentiateExprCheckingActions actions; - return _checkDifferentiateExpr(this, expr, &actions); + return _checkHigherOrderInvokeExpr(this, expr, &actions); + } + + Expr* SemanticsExprVisitor::visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) + { + PrimalSubstituteExprCheckingActions actions; + return _checkHigherOrderInvokeExpr(this, expr, &actions); } Expr* SemanticsExprVisitor::visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index fc1b622cc..3d40b10e9 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -287,10 +287,11 @@ namespace Slang void registerAssociatedDecl(Decl* original, DeclAssociationKind assoc, Decl* declaration); - List const& getAssociatedDeclsForDecl(Decl* decl); + List> const& getAssociatedDeclsForDecl(Decl* decl); bool isDifferentiableFunc(FunctionDeclBase* func); bool isBackwardDifferentiableFunc(FunctionDeclBase* func); + FunctionDifferentiableLevel _getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit); FunctionDifferentiableLevel getFuncDifferentiableLevel(FunctionDeclBase* func); private: @@ -1951,6 +1952,8 @@ namespace Slang Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr); Expr* visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr); + Expr* visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr); + Expr* visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr); Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index e6a524645..a068f19d6 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -654,33 +654,23 @@ namespace Slang hitObjectAttributesAttr->location = (int32_t)val->value; } - else if (auto derivativeAttr = as(attr)) + else if (as(attr) || as(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); SLANG_ASSERT(as(attrTarget)); - - // Ensure that the argument is a reference to a function definition or declaration. - auto diffExpr = CheckTerm(attr->args[0]); - if (diffExpr->type == getASTBuilder()->getErrorType()) - { - // Could not resolve the term. - getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as(attrTarget)); - return false; - } - // We store the partially checked funcExpr in the attribute, and - // rely on `ResolveInvoke` to resolve it to the actual function decl. - // The call to `ResolveInvoke` is deferred until we are checking the - // body of the function. - // - // Set type to null to indicate that this needs expr needs to be further resolved. - diffExpr->type.type = nullptr; - derivativeAttr->funcExpr = diffExpr; + if (auto derivativeAttr = as(attr)) + derivativeAttr->funcExpr = attr->args[0]; + else if (auto primalSubstAttr = as(attr)) + primalSubstAttr->funcExpr = attr->args[0]; } - else if (auto derivativeOfAttr = as(attr)) + else if (as(attr) || as(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); SLANG_ASSERT(as(attrTarget)); - derivativeOfAttr->funcExpr = attr->args[0]; + if (auto derivativeOfAttr = as(attr)) + derivativeOfAttr->funcExpr = attr->args[0]; + else if (auto primalOfAttr = as(attr)) + primalOfAttr->funcExpr = attr->args[0]; } else if (auto comInterfaceAttr = as(attr)) { diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 91af731ad..f786089f8 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1549,7 +1549,7 @@ namespace Slang // Base is a normal or fully specialized generic function. OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; - if (auto diffExpr = as(expr)) + if (auto diffExpr = as(expr)) { candidate.funcType = as(diffExpr->type.type); } diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 27106b6a2..2090cd4dc 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -344,6 +344,7 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* static bool _isDifferentiableFunc(IRInst* func) { + func = getResolvedInstForDecorations(func); for (auto decor = func->getFirstDecoration(); decor; decor = decor->getNextDecoration()) { switch (decor->getOp()) @@ -369,6 +370,37 @@ static IRFuncType* _getCalleeActualFuncType(IRInst* callee) return nullptr; } +IRInst* tryFindPrimalSubstitute(IRBuilder* builder, IRInst* callee) +{ + if (auto func = as(callee)) + { + if (auto decor = func->findDecoration()) + return decor->getPrimalSubstituteFunc(); + } + else if (auto specialize = as(callee)) + { + auto innerGen = as(specialize->getBase()); + if (!innerGen) + return nullptr; + auto innerFunc = findGenericReturnVal(innerGen); + if (auto decor = innerFunc->findDecoration()) + { + auto substSpecialize = as(decor->getPrimalSubstituteFunc()); + SLANG_RELEASE_ASSERT(substSpecialize); + SLANG_RELEASE_ASSERT(substSpecialize->getArgCount() == specialize->getArgCount()); + List args; + for (UInt i = 0; i < specialize->getArgCount(); i++) + args.add(specialize->getArg(i)); + return builder->emitSpecializeInst( + callee->getFullType(), + substSpecialize->getBase(), + (UInt)args.getCount(), + args.getBuffer()); + } + } + return callee; +} + // Differentiating a call instruction here is primarily about generating // an appropriate call list based on whichever parameters have differentials // in the current transcription context. @@ -393,10 +425,20 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig } auto primalCallee = lookupPrimalInst(builder, origCallee, origCallee); + auto substPrimalCallee = tryFindPrimalSubstitute(builder, primalCallee); IRInst* diffCallee = nullptr; + if (substPrimalCallee == primalCallee) + { + instMapD.TryGetValue(origCallee, diffCallee); + } + else + { + instMapD.TryGetValue(substPrimalCallee, diffCallee); + primalCallee = substPrimalCallee; + } - if (instMapD.TryGetValue(origCallee, diffCallee)) + if (diffCallee) { } else if (auto derivativeReferenceDecor = primalCallee->findDecoration()) @@ -750,6 +792,9 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec // Make sure this isn't itself a specialize . SLANG_RELEASE_ASSERT(!as(jvpFunc)); + auto derivativeDecoration = genericInnerVal->findDecoration(); + SLANG_RELEASE_ASSERT(derivativeDecoration); + return InstPair(primalSpecialize, jvpFunc); } else if (auto derivativeDecoration = genericInnerVal->findDecoration()) diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 709968f77..d7cce7c53 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -736,7 +736,7 @@ namespace Slang // We need to insert a local variable to store this var. IRInst* operandReplacement = nullptr; - if (canInstBeStored(operand)) + if (canTypeBeStored(operand->getDataType())) { auto var = storeInstAsLocalVar(operand); builder.setInsertBefore(inst); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 73d9b6ba6..ed122c862 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -72,6 +72,14 @@ bool AutoDiffTranscriberBase::shouldUseOriginalAsPrimal(IRInst* currentParent, I return true; if (isChildInstOf(currentParent, origInst->getParent())) return true; + + // If origInst is defined in the first block of the same function as current inst (e.g. a param), + // we can use it as primal. + // More generally, we should test if origInst dominates currentParent, but that requires calculating + // a dom tree on the fly. Right now just testing if it is first block for parameters seems sufficient. + auto parentFunc = getParentFunc(currentParent); + if (parentFunc && origInst->parent == parentFunc->getFirstBlock()) + return true; return false; } @@ -802,6 +810,7 @@ static void _markGenericChildrenWithoutRelaventUse(IRGeneric* origGeneric, HashS case kIROp_BackwardDerivativePrimalContextDecoration: case kIROp_BackwardDerivativePrimalDecoration: case kIROp_BackwardDerivativePropagateDecoration: + case kIROp_PrimalSubstituteDecoration: break; default: if (!outInstsToSkip.Contains(use->getUser())) @@ -876,6 +885,32 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene return InstPair(primalGeneric, diffGeneric); } +IRInst* getActualInstToTranscribe(IRInst* inst) +{ + if (auto gen = as(inst)) + { + auto retVal = findGenericReturnVal(gen); + if (retVal->getOp() != kIROp_Func) + return inst; + if (auto primalSubst = retVal->findDecoration()) + { + auto spec = as(primalSubst->getPrimalSubstituteFunc()); + SLANG_RELEASE_ASSERT(spec); + return spec->getBase(); + } + } + else if (auto func = as(inst)) + { + if (auto primalSubst = func->findDecoration()) + { + auto actualFunc = as(primalSubst->getPrimalSubstituteFunc()); + SLANG_RELEASE_ASSERT(actualFunc); + return actualFunc; + } + } + return inst; +} + IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst) { // If a differential instruction is already mapped for @@ -891,8 +926,8 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst // depending on the op-code. // instsInProgress.Add(origInst); - - InstPair pair = transcribeInst(builder, origInst); + auto actualInstToTranscribe = getActualInstToTranscribe(origInst); + InstPair pair = transcribeInst(builder, actualInstToTranscribe); instsInProgress.Remove(origInst); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index a05fe7044..2347c7a8f 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -197,8 +197,18 @@ struct ExtractPrimalFuncContext bool shouldStoreVar(IRVar* var) { // Always store intermediate context var. - if (var->findDecoration()) + if (auto typeDecor = var->findDecoration()) { + // If we are specializing a callee's intermediate context with types that can't be stored, + // we can't store the entire context. + if (auto spec = as(as(var->getDataType())->getValueType())) + { + for (UInt i = 0; i < spec->getArgCount(); i++) + { + if (!canTypeBeStored(spec->getArg(i)->getDataType())) + return false; + } + } return true; } @@ -212,7 +222,7 @@ struct ExtractPrimalFuncContext // 2. Does the var have a store // - return (doesInstHaveDiffUse(var) && doesInstHaveStore(var)); + return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as(var->getDataType())->getValueType())); } bool shouldStoreInst(IRInst* inst) @@ -222,7 +232,7 @@ struct ExtractPrimalFuncContext return false; } - if (!canInstBeStored(inst)) + if (!canTypeBeStored(inst->getDataType())) return false; // Never store certain opcodes. @@ -246,6 +256,9 @@ struct ExtractPrimalFuncContext case kIROp_MakeOptionalValue: case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialWitnessTable: return false; case kIROp_GetElement: case kIROp_FieldExtract: @@ -560,7 +573,12 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( if (inst->getOp() == kIROp_Call) { // The primal calls should be marked as no side effect so they can be DCE'd if possible. - builder.addSimpleDecoration(inst); + // We can only do so if the intermediate context of the callee is stored. + if (primalCtx->getBackwardDerivativePrimalContextVar() + ->findDecoration()) + { + builder.addSimpleDecoration(inst); + } } } } diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index fcfbf3bee..65e880868 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -563,12 +563,15 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* return false; } -bool canInstBeStored(IRInst* inst) +bool canTypeBeStored(IRInst* type) { - if (as(inst->getDataType())) + if (!type) + return false; + + if (as(type)) return true; - switch (inst->getDataType()->getOp()) + switch (type->getOp()) { case kIROp_StructType: case kIROp_OptionalType: @@ -716,6 +719,9 @@ struct AutoDiffPass : public InstPassBase break; } break; + case kIROp_PrimalSubstitute: + // Explicit primal subst operator is not yet supported. + SLANG_UNIMPLEMENTED_X("explicit primal_subst operator."); default: break; } diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index a4eb94461..f757375d8 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -295,7 +295,7 @@ bool isBackwardDifferentiableFunc(IRInst* func); bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst); -bool canInstBeStored(IRInst* inst); +bool canTypeBeStored(IRInst* type); inline bool isRelevantDifferentialPair(IRType* type) { diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 186b0cc03..14f6394e2 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -39,12 +39,17 @@ public: return false; } - bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level) { func = getResolvedInstForDecorations(func); if (!func) return false; + if (auto substDecor = func->findDecoration()) + { + func = getResolvedInstForDecorations(substDecor->getPrimalSubstituteFunc()); + if (!func) + return false; + } for (auto decorations : func->getDecorations()) { @@ -84,7 +89,13 @@ public: if (!func) return false; - + if (auto substDecor = func->findDecoration()) + { + func = getResolvedInstForDecorations(substDecor->getPrimalSubstituteFunc()); + if (!func) + return false; + } + if (auto existingLevel = differentiableFunctions.TryGetValue(func)) return *existingLevel >= level; diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index e5c9b1fdb..1fe88e780 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -379,25 +379,37 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o // if (options.keepExportsAlive) { - if (inst->findDecoration()) + bool isImported = false; + bool shouldKeptAliveIfImported = false; + IRInst* innerInst = inst; + if (auto genInst = as(inst)) { - return true; + innerInst = findInnerMostGenericReturnVal(genInst); } - if (inst->findDecoration()) + for (auto decor : inst->getDecorations()) { - if (inst->findDecoration()) - return true; - if (inst->findDecoration()) + switch (decor->getOp()) + { + case kIROp_ExportDecoration: return true; - if (auto genInst = as(inst)) + case kIROp_ImportDecoration: + isImported = true; + break; + } + } + for (auto decor : innerInst->getDecorations()) + { + switch (decor->getOp()) { - auto inner = findInnerMostGenericReturnVal(genInst); - if (inner->findDecoration()) - return true; - if (inner->findDecoration()) - return true; + case kIROp_ForwardDerivativeDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_PrimalSubstituteDecoration: + shouldKeptAliveIfImported = true; + break; } } + if (isImported && shouldKeptAliveIfImported) + return true; } if (options.keepLayoutsAlive && inst->findDecoration()) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index c704359e6..7411d031c 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -764,6 +764,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// generated derivative function. INST(BackwardDifferentiableDecoration, backwardDifferentiable, 1, 0) + /// Used by the auto-diff pass to hold a reference to the + /// primal substitute function. + INST(PrimalSubstituteDecoration, primalSubstFunc, 1, 0) + /// Decorations to associate an original function with compiler generated backward derivative functions. INST(BackwardDerivativePrimalDecoration, backwardDiffPrimalReference, 1, 0) INST(BackwardDerivativePropagateDecoration, backwardDiffPropagateReference, 1, 0) @@ -882,6 +886,8 @@ INST(BackwardDifferentiatePropagate, BackwardDifferentiatePropagate, 1, // replaced with `BackwardDifferentiatePrimal` and `BackwardDifferentiatePropagate`. INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0) +INST(PrimalSubstitute, PrimalSubstitute, 1, 0) + // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5269ae02f..ae31219bd 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -608,6 +608,17 @@ struct IRForwardDerivativeDecoration : IRDecoration IRInst* getForwardDerivativeFunc() { return getOperand(0); } }; +struct IRPrimalSubstituteDecoration : IRDecoration +{ + enum + { + kOp = kIROp_PrimalSubstituteDecoration + }; + IR_LEAF_ISA(PrimalSubstituteDecoration) + + IRInst* getPrimalSubstituteFunc() { return getOperand(0); } +}; + struct IRBackwardDerivativeIntermediateTypeDecoration : IRDecoration { enum @@ -879,6 +890,20 @@ struct IRBackwardDifferentiate : IRInst IR_LEAF_ISA(BackwardDifferentiate) }; +// Retrieves the primal substitution function for the given function. +struct IRPrimalSubstitute : IRInst +{ + enum + { + kOp = kIROp_PrimalSubstitute + }; + // The base function for the call. + IRUse base; + IRInst* getBaseFn() { return getOperand(0); } + + IR_LEAF_ISA(PrimalSubstitute) +}; + // Dictionary item mapping a type with a corresponding // IDifferentiable witness table // @@ -2804,6 +2829,7 @@ public: IRInst* emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn); IRInst* emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn); IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn); + IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); @@ -3623,6 +3649,11 @@ public: addDecoration(value, kIROp_BackwardDerivativePrimalContextDecoration, ctx); } + void addPrimalSubstituteDecoration(IRInst* value, IRInst* jvpFn) + { + addDecoration(value, kIROp_PrimalSubstituteDecoration, jvpFn); + } + void addLoopCounterDecoration(IRInst* value) { addDecoration(value, kIROp_LoopCounterDecoration); diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index bcff5621c..b976f4b21 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -443,6 +443,7 @@ static void cloneExtraDecorationsFromInst( case kIROp_SequentialIDDecoration: case kIROp_ForwardDerivativeDecoration: case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_PrimalSubstituteDecoration: case kIROp_IntrinsicOpDecoration: if (!clonedInst->findDecorationImpl(decoration->getOp())) { diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 0aa2dc607..2819a6d83 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3106,6 +3106,17 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitPrimalSubstituteInst(IRType* type, IRInst* baseFn) + { + auto inst = createInst( + this, + kIROp_PrimalSubstitute, + type, + baseFn); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn) { auto inst = createInst( diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 261e08168..d8912cbd4 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3165,6 +3165,17 @@ struct ExprLoweringVisitorBase : ExprVisitor baseVal.val)); } + LoweredValInfo visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) + { + auto baseVal = lowerSubExpr(expr->baseFunction); + SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); + + return LoweredValInfo::simple( + getBuilder()->emitPrimalSubstituteInst( + lowerType(context, expr->type), + baseVal.val)); + } + LoweredValInfo visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) { auto baseVal = lowerSubExpr(expr->innerExpr); @@ -7970,14 +7981,6 @@ struct DeclLoweringVisitor : DeclVisitor addNameHint(subContext, irFunc, decl); addLinkageDecoration(subContext, irFunc, decl); - if (decl->findModifier()) - { - getBuilder()->addForwardDifferentiableDecoration(irFunc); - } - if (decl->findModifier()) - { - getBuilder()->addBackwardDifferentiableDecoration(irFunc); - } if (auto differentialAttr = decl->findModifier()) { lowerDifferentiableAttribute(subContext, irFunc, differentialAttr); @@ -8291,156 +8294,156 @@ struct DeclLoweringVisitor : DeclVisitor getBuilder()->addRequireCUDASMVersionDecoration(irFunc, versionMod->version); } - if (decl->findModifier()) - { - getBuilder()->addSimpleDecoration(irFunc); - } - - if (decl->findModifier()) - { - getBuilder()->addSimpleDecoration(irFunc); - } - - if (decl->findModifier()) - { - getBuilder()->addSimpleDecoration(irFunc); - } - - if (auto attr = decl->findModifier()) - { - IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit); - } - - if (auto attr = decl->findModifier()) - { - IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit); - } - - if (auto attr = decl->findModifier()) - { - auto builder = getBuilder(); - IRType* intType = builder->getIntType(); - - IRInst* operands[3] = { - builder->getIntValue(intType, attr->x), - builder->getIntValue(intType, attr->y), - builder->getIntValue(intType, attr->z) - }; - - builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3); - } - - if (decl->findModifier()) - { - getBuilder()->addSimpleDecoration(irFunc); - } - - if (decl->findModifier()) - { - getBuilder()->addSimpleDecoration(irFunc); - } - - if (auto attr = decl->findModifier()) - { - IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit); - } + // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute + setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); - if (auto attr = decl->findModifier()) + for (auto modifier : decl->modifiers) { - IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit); - } + if (as(modifier)) + { + getBuilder()->addSimpleDecoration(irFunc); + } + else if (as(modifier)) + { + getBuilder()->addSimpleDecoration(irFunc); + } + else if (as(modifier)) + { + getBuilder()->addSimpleDecoration(irFunc); + } + else if (auto instanceAttr = as(modifier)) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), instanceAttr); + getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit); + } + else if (auto maxVertCountAttr = as(modifier)) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), maxVertCountAttr); + getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit); + } + else if (auto numThreadsAttr = as(modifier)) + { + auto builder = getBuilder(); + IRType* intType = builder->getIntType(); - if (auto attr = decl->findModifier()) - { - IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit); - } + IRInst* operands[3] = { + builder->getIntValue(intType, numThreadsAttr->x), + builder->getIntValue(intType, numThreadsAttr->y), + builder->getIntValue(intType, numThreadsAttr->z) + }; - if (auto attr = decl->findModifier()) - { - IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit); - } + builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3); + } + else if (as(modifier)) + { + getBuilder()->addSimpleDecoration(irFunc); + } + else if (as(modifier)) + { + getBuilder()->addSimpleDecoration(irFunc); + } + else if (auto domainAttr = as(modifier)) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), domainAttr); + getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit); + } + else if (auto partitionAttr = as(modifier)) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), partitionAttr); + getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit); + } + else if (auto outputTopAttr = as(modifier)) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), outputTopAttr); + getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit); + } + else if (auto outputCtrlPtAttr = as(modifier)) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), outputCtrlPtAttr); + getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit); + } + else if (auto spvInstOpAttr = as(modifier)) + { + auto builder = getBuilder(); + IRIntLit* intLit = _getIntLitFromAttribute(builder, spvInstOpAttr, 0); - if (auto attr = decl->findModifier()) - { - auto builder = getBuilder(); - IRIntLit* intLit = _getIntLitFromAttribute(builder, attr, 0); + IRStringLit* setStringLit = nullptr; + if (spvInstOpAttr->args.getCount() > 1) + { + IRStringLit* checkSetStringLit = _getStringLitFromAttribute(builder, spvInstOpAttr, 1); + if (checkSetStringLit && checkSetStringLit->getStringSlice().getLength() > 0) + { + setStringLit = checkSetStringLit; + } + } - IRStringLit* setStringLit = nullptr; - if (attr->args.getCount() > 1) - { - IRStringLit* checkSetStringLit = _getStringLitFromAttribute(builder, attr, 1); - if (checkSetStringLit && checkSetStringLit->getStringSlice().getLength() > 0) + // If it has a `set` defined, set it on the decoration + if (setStringLit) { - setStringLit = checkSetStringLit; + builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit, setStringLit); + } + else + { + builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit); } } - - // If it has a `set` defined, set it on the decoration - if (setStringLit) + else if (as(modifier)) { - builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit, setStringLit); + getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration); } - else + else if (as(modifier)) { - builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit); + getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); } - } - - if (decl->findModifier()) - { - getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration); - } - - if (decl->findModifier()) - { - getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); - } - - if (decl->findModifier()) - { - getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); - } - - if (auto intrinsicOp = decl->findModifier()) - { - auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op); - getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op); - } - - // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute - setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); - - if (auto attr = decl->findModifier()) - { - // We need to lower the decl ref to the custom derivative function to IR. - // The IR insts correspond to the decl ref is not part of the function we - // are processing. If we emit it directly to within the function, it could - // mess up the assumption on the form of the IR (e.g. having non decoration insts - // appearing in the middle of decoration insts). so we emit the decl ref to the - // function's parent for now. - - subContext->irBuilder->setInsertInto(irFunc->getParent()); - - auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr); - - SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); - IRInst* derivativeFunc = loweredVal.val; - - if (as(attr)) - getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc); - else - getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc); + else if (as(modifier)) + { + getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); + } + else if (auto intrinsicOp = as(modifier)) + { + auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op); + getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op); + } + else if (as(modifier) || as(modifier)) + { + // We need to lower the decl ref to the custom derivative function to IR. + // The IR insts correspond to the decl ref is not part of the function we + // are processing. If we emit it directly to within the function, it could + // mess up the assumption on the form of the IR (e.g. having non decoration insts + // appearing in the middle of decoration insts). so we emit the decl ref to the + // function's parent for now. + + subContext->irBuilder->setInsertInto(irFunc->getParent()); + Expr* funcExpr = nullptr; + if (auto udAttr = as(modifier)) + funcExpr = udAttr->funcExpr; + else if (auto primalAttr = as(modifier)) + funcExpr = primalAttr->funcExpr; + + auto loweredVal = lowerRValueExpr(subContext, funcExpr); + + SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); + IRInst* derivativeFunc = loweredVal.val; + + if (as(modifier)) + getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc); + else if (as(modifier)) + getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc); + else + getBuilder()->addPrimalSubstituteDecoration(irFunc, derivativeFunc); - // Reset cursor. - subContext->irBuilder->setInsertInto(irFunc); + // Reset cursor. + subContext->irBuilder->setInsertInto(irFunc); + } + else if (as(modifier)) + { + getBuilder()->addForwardDifferentiableDecoration(irFunc); + } + else if (as(modifier)) + { + getBuilder()->addBackwardDifferentiableDecoration(irFunc); + } } - // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list @@ -8451,39 +8454,59 @@ struct DeclLoweringVisitor : DeclVisitor // the interface's type definition. auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric); - if (auto attr = decl->findModifier()) + for (auto modifier : decl->modifiers) { - if (auto originalDeclRefExpr = as(attr->funcExpr)) + if (as(modifier) || as(modifier)) { - NestedContext originalContextFunc(this); - auto originalSubBuilder = originalContextFunc.getBuilder(); - auto originalSubContext = originalContextFunc.getContext(); - if (auto outterGeneric = getOuterGeneric(irFunc)) - originalSubBuilder->setInsertBefore(outterGeneric); - else - originalSubBuilder->setInsertBefore(irFunc); - auto originalFuncDecl = as(originalDeclRefExpr->declRef.getDecl()); - SLANG_RELEASE_ASSERT(originalFuncDecl); - - auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val; - if (auto originalFuncGeneric = as(originalFuncVal)) + Expr* funcExpr = nullptr; + Expr* backDeclRef = nullptr; + if (auto attr = as(modifier)) { - originalFuncVal = findGenericReturnVal(originalFuncGeneric); + funcExpr = attr->funcExpr; + backDeclRef = attr->backDeclRef; } - originalSubBuilder->setInsertBefore(originalFuncVal); - auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef); - if (as(attr)) + else if (auto primalAttr = as(modifier)) { - originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); - getBuilder()->addForwardDifferentiableDecoration(irFunc); + funcExpr = primalAttr->funcExpr; + backDeclRef = primalAttr->backDeclRef; } - else + + if (auto originalDeclRefExpr = as(funcExpr)) { - originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + NestedContext originalContextFunc(this); + auto originalSubBuilder = originalContextFunc.getBuilder(); + auto originalSubContext = originalContextFunc.getContext(); + if (auto outterGeneric = getOuterGeneric(irFunc)) + originalSubBuilder->setInsertBefore(outterGeneric); + else + originalSubBuilder->setInsertBefore(irFunc); + auto originalFuncDecl = as(originalDeclRefExpr->declRef.getDecl()); + SLANG_RELEASE_ASSERT(originalFuncDecl); + + auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val; + if (auto originalFuncGeneric = as(originalFuncVal)) + { + originalFuncVal = findGenericReturnVal(originalFuncGeneric); + } + originalSubBuilder->setInsertBefore(originalFuncVal); + auto derivativeFuncVal = lowerRValueExpr(originalSubContext, backDeclRef); + if (as(modifier)) + { + originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + getBuilder()->addForwardDifferentiableDecoration(irFunc); + } + else if (as(modifier)) + { + originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + } + else + { + originalSubBuilder->addPrimalSubstituteDecoration(originalFuncVal, derivativeFuncVal.val); + } } + subContext->irBuilder->setInsertInto(irFunc); + finalVal->moveToEnd(); } - subContext->irBuilder->setInsertInto(irFunc); - finalVal->moveToEnd(); } return LoweredValInfo::simple(finalVal); } -- cgit v1.2.3