diff options
Diffstat (limited to 'source')
22 files changed, 855 insertions, 581 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 54f927816..4301eda94 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -12,6 +12,9 @@ __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) +attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute; + +__attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) @@ -20,6 +23,9 @@ attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDerivativeOf(function)] : BackwardDerivativeOfAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [PrimalSubstituteOf(function)] : PrimalSubstituteOfAttribute; + __attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; @@ -1037,6 +1043,7 @@ void __d_refract(inout DifferentialPair<vector<T, N>> i, inout DifferentialPair< // Sine and cosine __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PrimalSubstituteOf(sincos)] void __sincos_impl(T x, out T s, out T c) { s = sin(x); @@ -1045,6 +1052,7 @@ void __sincos_impl(T x, out T s, out T c) __generic<T : __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] +[PrimalSubstituteOf(sincos)] void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c) { s = sin(x); @@ -1053,62 +1061,18 @@ void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c) __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [BackwardDifferentiable] +[PrimalSubstituteOf(sincos)] void __sincos_impl(matrix<T, N, M> x, out matrix<T, N, M> s, out matrix<T, N, M> c) { s = sin(x); c = cos(x); } -__generic<T: __BuiltinFloatingPointType> -[ForwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(DifferentialPair<T> x, out DifferentialPair<T> s, out DifferentialPair<T> c) -{ - __fwd_diff(__sincos_impl)(x, s, c); -} - -__generic<T : __BuiltinFloatingPointType, let N : int> -[ForwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(DifferentialPair<vector<T, N>> x, out DifferentialPair<vector<T, N>> s, out DifferentialPair<vector<T, N>> c) -{ - __fwd_diff(__sincos_impl)(x, s, c); -} - -__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> -[ForwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(DifferentialPair<matrix<T, N, M>> x, out DifferentialPair<matrix<T, N, M>> s, out DifferentialPair<matrix<T, N, M>> c) -{ - __fwd_diff(__sincos_impl)(x, s, c); -} - -__generic<T: __BuiltinFloatingPointType> -[BackwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(inout DifferentialPair<T> x, T.Differential dS, T.Differential dC) -{ - __bwd_diff(__sincos_impl)(x, dS, dC); -} -__generic<T: __BuiltinFloatingPointType, let N : int> -[BackwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(inout DifferentialPair<vector<T, N>> x, vector<T, N>.Differential dS, vector<T, N>.Differential dC) -{ - __bwd_diff(__sincos_impl)(x, dS, dC); -} - -__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> -[BackwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(inout DifferentialPair<matrix<T, N, M>> x, matrix<T, N, M>.Differential dS, matrix<T, N, M>.Differential dC) -{ - __bwd_diff(__sincos_impl)(x, dS, dC); -} // dst (obsolete) __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] +[PrimalSubstituteOf(dst)] vector<T, 4> __dst_impl(vector<T, 4> src0, vector<T, 4> src1) { vector<T, 4> dest; @@ -1118,25 +1082,11 @@ vector<T, 4> __dst_impl(vector<T, 4> src0, vector<T, 4> src1) dest.w = src1.w; ; return dest; } -__generic<T : __BuiltinFloatingPointType> -[ForwardDerivativeOf(dst)] -[ForceInline] -DifferentialPair<vector<T, 4>> __d_dst(DifferentialPair<vector<T, 4>> src0, DifferentialPair<vector<T, 4>> src1) -{ - return __fwd_diff(__dst_impl)(src0, src1); -} -__generic<T : __BuiltinFloatingPointType> -[BackwardDerivativeOf(dst)] -[ForceInline] -void __d_dst(inout DifferentialPair<vector<T, 4>> src0, inout DifferentialPair<vector<T, 4>> src1, vector<T, 4>.Differential dOut) -{ - __bwd_diff(__dst_impl)(src0, src1, dOut); -} // Legacy lighting function (obsolete) -__target_intrinsic(hlsl) [__readNone] [BackwardDifferentiable] +[PrimalSubstituteOf(lit)] float4 __lit_impl(float n_dot_l, float n_dot_h, float m) { let ambient = 1.0f; @@ -1144,19 +1094,6 @@ float4 __lit_impl(float n_dot_l, float n_dot_h, float m) let specular = ((n_dot_l < 0.0f || n_dot_h < 0.0) ? 0.0 : pow(n_dot_h, m)); return float4(ambient, diffuse, specular, 1.0f); } -[ForwardDerivativeOf(lit)] -[ForceInline] -DifferentialPair<float4> __d_lit(DifferentialPair<float> n_dot_l, DifferentialPair<float> n_dot_h, DifferentialPair<float> m) -{ - return __fwd_diff(__lit_impl)(n_dot_l, n_dot_h, m); -} -[BackwardDerivativeOf(lit)] -[ForceInline] -void __d_lit(inout DifferentialPair<float> n_dot_l, inout DifferentialPair<float> n_dot_h, inout DifferentialPair<float> m, float4 dOut) -{ - __bwd_diff(__lit_impl)(n_dot_l, n_dot_h, m, dOut); -} - // Matrix determinant __generic<T : __BuiltinFloatingPointType, let N : int> [BackwardDifferentiable] diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index ba0b4ce7a..301dded49 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -452,10 +452,14 @@ class HigherOrderInvokeExpr : public Expr List<Name*> newParameterNames; }; +class PrimalSubstituteExpr : public HigherOrderInvokeExpr +{ + SLANG_AST_CLASS(PrimalSubstituteExpr) +}; + class DifferentiateExpr : public HigherOrderInvokeExpr { SLANG_ABSTRACT_AST_CLASS(DifferentiateExpr) - }; /// An expression of the form `__fwd_diff(fn)` to access the /// forward-mode derivative version of the function `fn` diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 7dd0819d8..80c770c3a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1151,6 +1151,24 @@ class BackwardDerivativeOfAttribute : public DerivativeOfAttribute SLANG_AST_CLASS(BackwardDerivativeOfAttribute) }; + /// The `[PrimalSubstitute(function)]` attribute specifies a custom function that should + /// be used as the primal function substitute when differentiating code that calls the primal function. +class PrimalSubstituteAttribute : public Attribute +{ + SLANG_AST_CLASS(PrimalSubstituteAttribute) + Expr* funcExpr; +}; + + /// The `[PrimalSubstituteOf(primalFunction)]` attribute marks the decorated function as + /// the substitute primal function in a forward or backward derivative function. +class PrimalSubstituteOfAttribute : public Attribute +{ + SLANG_AST_CLASS(PrimalSubstituteOfAttribute) + + Expr* funcExpr; + Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. +}; + /// The `[NoDiffThis]` attribute is used to specify that the `this` parameter should not be /// included for differentiation. class NoDiffThisAttribute : public Attribute diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 07b3a5eac..78b0ccfcc 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -391,6 +391,10 @@ namespace Slang /// maybe synthesized and made available only after conformance checking. TypesFullyResolved, + /// All attributes are fully checked. This is the final step before + /// checking the function body. + AttributesChecked, + /// The declaration is fully checked. /// /// This step includes any validation of the declaration that is @@ -1500,12 +1504,12 @@ namespace Slang enum class DeclAssociationKind { - ForwardDerivativeFunc, BackwardDerivativeFunc, + ForwardDerivativeFunc, BackwardDerivativeFunc, PrimalSubstituteFunc }; - struct DeclAssociation + struct DeclAssociation : SerialRefObject { - SLANG_VALUE_CLASS(DeclAssociation) + SLANG_OBJ_CLASS(DeclAssociation) DeclAssociationKind kind; Decl* decl; }; @@ -1516,7 +1520,7 @@ namespace Slang { SLANG_OBJ_CLASS(DeclAssociationList) - List<DeclAssociation> associations; + List<RefPtr<DeclAssociation>> associations; }; /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7c42c1892..5cd7fba45 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -34,6 +34,26 @@ namespace Slang } }; + struct SemanticsDeclAttributesVisitor + : public SemanticsDeclVisitorBase + , public DeclVisitor<SemanticsDeclAttributesVisitor> + { + SemanticsDeclAttributesVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + {} + + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} + + void visitFunctionDeclBase(FunctionDeclBase* decl); + + void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeOfAttribute* attr); + + void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, BackwardDerivativeOfAttribute* attr); + + void checkPrimalSubstituteOfAttribute(FunctionDeclBase* funcDecl, PrimalSubstituteOfAttribute* attr); + }; + struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase , public DeclVisitor<SemanticsDeclHeaderVisitor> @@ -258,10 +278,6 @@ namespace Slang void visitFunctionDeclBase(FunctionDeclBase* funcDecl); void visitParamDecl(ParamDecl* paramDecl); - - void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl); - - void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl); }; /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? @@ -4657,270 +4673,10 @@ namespace Slang getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly); } - template<typename TDerivativeAttr> - void checkDerivativeAttributeImpl( - SemanticsVisitor* visitor, - TDerivativeAttr* attr, - const List<Expr*>& imaginaryArguments) - { - auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, *visitor); - auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); - auto resolved = visitor->ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as<InvokeExpr>(resolved)) - { - if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) - { - attr->funcExpr = calleeDeclRef; - return; - } - } - visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); - } - - template<typename TDerivativeAttr> - const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); } - - template<> - const char* getDerivativeAttrName<ForwardDerivativeAttribute>() - { - return "ForwardDerivative"; - } - template<> - const char* getDerivativeAttrName<BackwardDerivativeAttribute>() - { - return "BackwardDerivative"; - } - - List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) - { - List<Expr*> imaginaryArguments; - for (auto param : func->getParameters()) - { - auto arg = astBuilder->create<VarExpr>(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - imaginaryArguments.add(arg); - } - return imaginaryArguments; - } - - List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) - { - List<Expr*> imaginaryArguments; - for (auto param : originalFuncDecl->getParameters()) - { - auto arg = visitor->getASTBuilder()->create<VarExpr>(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) - { - arg->type.type = pairType; - } - imaginaryArguments.add(arg); - } - return imaginaryArguments; - } - - List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) - { - List<Expr*> imaginaryArguments; - auto isOutParam = [&](ParamDecl* param) - { - return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr; - }; - - for (auto param : originalFuncDecl->getParameters()) - { - auto arg = visitor->getASTBuilder()->create<VarExpr>(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType()))) - { - arg->type.type = pairType; - if (isOutParam(param)) - { - // out T -> in T.Differential - arg->type.isLeftValue = false; - arg->type.type = visitor->tryGetDifferentialType( - visitor->getASTBuilder(), pairType->getPrimalType()); - } - } - else - { - if (isOutParam(param)) - { - // Skip non-differentiable out params. - continue; - } - } - imaginaryArguments.add(arg); - } - if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type)) - { - auto arg = visitor->getASTBuilder()->create<InitializerListExpr>(); - arg->type.isLeftValue = false; - arg->type.type = diffReturnType; - arg->loc = loc; - imaginaryArguments.add(arg); - } - return imaginaryArguments; - } - - // This helper function is needed to workaround a gcc bug. - // Remove when we upgrade to a newer version of gcc. - template <typename T> - static T* _findModifier(Decl* decl) - { - return decl->findModifier<T>(); - } - - template <typename TDerivativeAttr, typename TDifferentiateExpr, typename TDerivativeOfAttr> - void checkDerivativeOfAttributeImpl( - SemanticsVisitor* visitor, - FunctionDeclBase* funcDecl, - TDerivativeOfAttr* derivativeOfAttr, - DeclAssociationKind assocKind) - { - DeclRef<Decl> calleeDeclRef; - DeclRefExpr* calleeDeclRefExpr = nullptr; - DifferentiateExpr* diffFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>(); - diffFuncExpr->baseFunction = derivativeOfAttr->funcExpr; - diffFuncExpr->loc = derivativeOfAttr->loc; - Expr* checkedDiffFuncExpr = visitor->dispatchExpr(diffFuncExpr, *visitor); - if (!checkedDiffFuncExpr) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc); - auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedDiffFuncExpr, imaginaryArgs); - auto resolved = visitor->ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as<InvokeExpr>(resolved)) - { - auto resolvedDiffFuncExpr = as<DifferentiateExpr>(resolvedInvoke->functionExpr); - if (resolvedDiffFuncExpr) - calleeDeclRefExpr = as<DeclRefExpr>(resolvedDiffFuncExpr->baseFunction); - } - - if (!calleeDeclRefExpr) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - calleeDeclRef = calleeDeclRefExpr->declRef; - - auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl()); - if (!calleeFunc) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - - if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc)) - { - // The primal function already has a `[*Derivative]` attribute, this is invalid. - visitor->getSink()->diagnose( - derivativeOfAttr, - Diagnostics::declAlreadyHasAttribute, - calleeDeclRef, - getDerivativeAttrName<TDerivativeAttr>()); - visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl()); - } - derivativeOfAttr->funcExpr = calleeDeclRefExpr; - auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>(); - derivativeAttr->loc = derivativeOfAttr->loc; - auto outterGeneric = visitor->GetOuterGeneric(funcDecl); - auto declRef = - DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); - auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); - declRefExpr->type.type = nullptr; - derivativeAttr->args.add(declRefExpr); - derivativeAttr->funcExpr = declRefExpr; - checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr); - derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; - derivativeAttr->funcExpr = nullptr; - visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl); - } - - static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) - { - if (!attr->funcExpr) - return; - if (attr->funcExpr->type.type) - return; - - List<Expr*> imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); - } - - static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) - { - if (!attr->funcExpr) - return; - if (attr->funcExpr->type.type) - return; - - List<Expr*> imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); - } - - template<typename TDerivativeAttr, typename TDerivativeOfAttr> - bool tryCheckDerivativeOfAttributeImpl( - SemanticsVisitor* visitor, - FunctionDeclBase* funcDecl, - TDerivativeOfAttr* derivativeOfAttr, - DeclAssociationKind assocKind, - const List<Expr*>& imaginaryArgsToOriginal) - { - DiagnosticSink tempSink(visitor->getSourceManager(), nullptr); - SemanticsVisitor subVisitor(visitor->withSink(&tempSink)); - checkDerivativeOfAttributeImpl<TDerivativeAttr>( - &subVisitor, - funcDecl, - derivativeOfAttr, - assocKind, - imaginaryArgsToOriginal); - return tempSink.getErrorCount() == 0; - } - - void SemanticsDeclBodyVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl) - { - auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>(); - if (!attr) - return; - - checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute, ForwardDifferentiateExpr>( - this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc); - } - - void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl) - { - auto attr = funcDecl->findModifier<BackwardDerivativeOfAttribute>(); - if (!attr) - return; - - checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute, BackwardDifferentiateExpr>( - this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc); - } - void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); - // Run checking on attributes that can't be fully checked in header checking stage. - checkForwardDerivativeOfAttribute(decl); - if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>()) - checkDerivativeAttribute(this, decl, derivativeAttr); - checkBackwardDerivativeOfAttribute(decl); - if (auto derivativeAttr = decl->findModifier<BackwardDerivativeAttribute>()) - checkDerivativeAttribute(this, decl, derivativeAttr); - if (newContext.getParentDifferentiableAttribute()) { // Register additional types outside the function body first. @@ -6762,7 +6518,7 @@ namespace Slang /// Note: this function creates an empty list of candidates for the given type if /// a matching entry doesn't exist already. /// - static List<DeclAssociation>& _getDeclAssociationList( + static List<RefPtr<DeclAssociation>>& _getDeclAssociationList( Decl* decl, OrderedDictionary<Decl*, RefPtr<DeclAssociationList>>& mapDeclToDeclarations) { @@ -6787,14 +6543,16 @@ namespace Slang void SharedSemanticsContext::registerAssociatedDecl(Decl* original, DeclAssociationKind kind, Decl* associated) { auto moduleDecl = getModuleDecl(associated); - DeclAssociation assoc = {kind, associated}; + RefPtr<DeclAssociation> assoc = new DeclAssociation(); + assoc->kind = kind; + assoc->decl = associated; _getDeclAssociationList(original, moduleDecl->mapDeclToAssociatedDecls).add(assoc); m_associatedDeclListsBuilt = false; m_mapDeclToAssociatedDecls.Clear(); } - List<DeclAssociation> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl) + List<RefPtr<DeclAssociation>> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl) { // This duplicates the exact same logic from `getCandidateExtensionsForTypeDecl`. // Consider refactoring them into the same framework. @@ -6838,6 +6596,23 @@ namespace Slang FunctionDifferentiableLevel SharedSemanticsContext::getFuncDifferentiableLevel(FunctionDeclBase* func) { + return _getFuncDifferentiableLevelImpl(func, 1); + } + + FunctionDifferentiableLevel SharedSemanticsContext::_getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit) + { + if (recurseLimit > 0) + { + if (auto primalSubst = func->findModifier<PrimalSubstituteAttribute>()) + { + if (auto declRefExpr = as<DeclRefExpr>(primalSubst->funcExpr)) + { + if (auto primalSubstFunc = declRefExpr->declRef.as<FunctionDeclBase>()) + return _getFuncDifferentiableLevelImpl(primalSubstFunc, recurseLimit - 1); + } + } + } + if (func->findModifier<BackwardDifferentiableAttribute>()) return FunctionDifferentiableLevel::Backward; if (func->findModifier<BackwardDerivativeAttribute>()) @@ -6849,13 +6624,19 @@ namespace Slang for (auto assocDecl : getAssociatedDeclsForDecl(func)) { - switch (assocDecl.kind) + switch (assocDecl->kind) { case DeclAssociationKind::BackwardDerivativeFunc: return FunctionDifferentiableLevel::Backward; case DeclAssociationKind::ForwardDerivativeFunc: diffLevel = FunctionDifferentiableLevel::Forward; break; + case DeclAssociationKind::PrimalSubstituteFunc: + if (auto assocFunc = as<FunctionDeclBase>(assocDecl->decl)) + { + return _getFuncDifferentiableLevelImpl(assocFunc, recurseLimit - 1); + } + break; default: break; } @@ -6971,6 +6752,10 @@ namespace Slang SemanticsDeclDifferentialConformanceVisitor(shared).dispatch(decl); break; + case DeclCheckState::AttributesChecked: + SemanticsDeclAttributesVisitor(shared).dispatch(decl); + break; + case DeclCheckState::Checked: SemanticsDeclBodyVisitor(shared).dispatch(decl); break; @@ -7058,4 +6843,297 @@ namespace Slang return val; } + + template<typename TDerivativeAttr> + void checkDerivativeAttributeImpl( + SemanticsVisitor* visitor, + TDerivativeAttr* attr, + const List<Expr*>& imaginaryArguments) + { + SemanticsContext::ExprLocalScope scope; + auto ctx = visitor->withExprLocalScope(&scope); + auto subVisitor = SemanticsVisitor(ctx); + auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); + auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); + auto resolved = subVisitor.ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) + { + attr->funcExpr = calleeDeclRef; + return; + } + } + visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + } + + template<typename TDerivativeAttr> + const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); } + + template<> + const char* getDerivativeAttrName<ForwardDerivativeAttribute>() + { + return "ForwardDerivative"; + } + template<> + const char* getDerivativeAttrName<BackwardDerivativeAttribute>() + { + return "BackwardDerivative"; + } + template<> + const char* getDerivativeAttrName<PrimalSubstituteAttribute>() + { + return "PrimalSubstitute"; + } + + List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) + { + List<Expr*> imaginaryArguments; + for (auto param : func->getParameters()) + { + auto arg = astBuilder->create<VarExpr>(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; + imaginaryArguments.add(arg); + } + return imaginaryArguments; + } + + List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + { + List<Expr*> imaginaryArguments; + for (auto param : originalFuncDecl->getParameters()) + { + auto arg = visitor->getASTBuilder()->create<VarExpr>(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; + if (auto pairType = visitor->getDifferentialPairType(param->getType())) + { + arg->type.type = pairType; + } + imaginaryArguments.add(arg); + } + return imaginaryArguments; + } + + List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + { + List<Expr*> imaginaryArguments; + auto isOutParam = [&](ParamDecl* param) + { + return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr; + }; + + for (auto param : originalFuncDecl->getParameters()) + { + auto arg = visitor->getASTBuilder()->create<VarExpr>(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; + if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType()))) + { + arg->type.type = pairType; + if (isOutParam(param)) + { + // out T -> in T.Differential + arg->type.isLeftValue = false; + arg->type.type = visitor->tryGetDifferentialType( + visitor->getASTBuilder(), pairType->getPrimalType()); + } + } + else + { + if (isOutParam(param)) + { + // Skip non-differentiable out params. + continue; + } + } + imaginaryArguments.add(arg); + } + if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type)) + { + auto arg = visitor->getASTBuilder()->create<InitializerListExpr>(); + arg->type.isLeftValue = false; + arg->type.type = diffReturnType; + arg->loc = loc; + imaginaryArguments.add(arg); + } + return imaginaryArguments; + } + + // This helper function is needed to workaround a gcc bug. + // Remove when we upgrade to a newer version of gcc. + template <typename T> + static T* _findModifier(Decl* decl) + { + return decl->findModifier<T>(); + } + + template <typename TDerivativeAttr, typename TDifferentiateExpr, typename TDerivativeOfAttr> + void checkDerivativeOfAttributeImpl( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + TDerivativeOfAttr* derivativeOfAttr, + DeclAssociationKind assocKind) + { + DeclRef<Decl> calleeDeclRef; + DeclRefExpr* calleeDeclRefExpr = nullptr; + HigherOrderInvokeExpr* higherOrderFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>(); + higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr; + higherOrderFuncExpr->loc = derivativeOfAttr->loc; + Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, *visitor); + if (!checkedHigherOrderFuncExpr) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc); + auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedHigherOrderFuncExpr, imaginaryArgs); + SemanticsContext::ExprLocalScope scope; + auto ctx = visitor->withExprLocalScope(&scope); + auto subVisitor = SemanticsVisitor(ctx); + auto resolved = subVisitor.ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + auto resolvedFuncExpr = as<HigherOrderInvokeExpr>(resolvedInvoke->functionExpr); + if (resolvedFuncExpr) + calleeDeclRefExpr = as<DeclRefExpr>(resolvedFuncExpr->baseFunction); + } + + if (!calleeDeclRefExpr) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + calleeDeclRef = calleeDeclRefExpr->declRef; + + auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl()); + if (!calleeFunc) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + + if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc)) + { + // The primal function already has a `[*Derivative]` attribute, this is invalid. + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::declAlreadyHasAttribute, + calleeDeclRef, + getDerivativeAttrName<TDerivativeAttr>()); + visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl()); + } + + derivativeOfAttr->funcExpr = calleeDeclRefExpr; + auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>(); + derivativeAttr->loc = derivativeOfAttr->loc; + auto outterGeneric = visitor->GetOuterGeneric(funcDecl); + auto declRef = + DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); + auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); + declRefExpr->type.type = nullptr; + derivativeAttr->args.add(declRefExpr); + derivativeAttr->funcExpr = declRefExpr; + checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr); + derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; + derivativeAttr->funcExpr = nullptr; + visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl); + } + + static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List<Expr*> imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + } + + static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List<Expr*> imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + } + + static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List<Expr*> imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + } + + template<typename TDerivativeAttr, typename TDerivativeOfAttr> + bool tryCheckDerivativeOfAttributeImpl( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + TDerivativeOfAttr* derivativeOfAttr, + DeclAssociationKind assocKind, + const List<Expr*>& imaginaryArgsToOriginal) + { + DiagnosticSink tempSink(visitor->getSourceManager(), nullptr); + SemanticsVisitor subVisitor(visitor->withSink(&tempSink)); + checkDerivativeOfAttributeImpl<TDerivativeAttr>( + &subVisitor, + funcDecl, + derivativeOfAttr, + assocKind, + imaginaryArgsToOriginal); + return tempSink.getErrorCount() == 0; + } + + void SemanticsDeclAttributesVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeOfAttribute* attr) + { + checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute, ForwardDifferentiateExpr>( + this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc); + } + + void SemanticsDeclAttributesVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, BackwardDerivativeOfAttribute* attr) + { + checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute, BackwardDifferentiateExpr>( + this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc); + } + + void SemanticsDeclAttributesVisitor::checkPrimalSubstituteOfAttribute(FunctionDeclBase* funcDecl, PrimalSubstituteOfAttribute* attr) + { + checkDerivativeOfAttributeImpl<PrimalSubstituteAttribute, PrimalSubstituteExpr>( + this, funcDecl, attr, DeclAssociationKind::PrimalSubstituteFunc); + } + + void SemanticsDeclAttributesVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) + { + // Run checking on attributes that can't be fully checked in header checking stage. + for (auto attr : decl->modifiers) + { + if (auto fwdDerivativeOfAttr = as<ForwardDerivativeOfAttribute>(attr)) + checkForwardDerivativeOfAttribute(decl, fwdDerivativeOfAttr); + else if (auto bwdDerivativeOfAttr = as<BackwardDerivativeOfAttribute>(attr)) + checkBackwardDerivativeOfAttribute(decl, bwdDerivativeOfAttr); + else if (auto primalOfAttr = as<PrimalSubstituteOfAttribute>(attr)) + checkPrimalSubstituteOfAttribute(decl, primalOfAttr); + else if (auto fwdDerivativeAttr = as<ForwardDerivativeAttribute>(attr)) + checkDerivativeAttribute(this, decl, fwdDerivativeAttr); + else if (auto bwdDerivativeAttr = as<BackwardDerivativeAttribute>(attr)) + checkDerivativeAttribute(this, decl, bwdDerivativeAttr); + else if (auto primalAttr = as<PrimalSubstituteAttribute>(attr)) + checkDerivativeAttribute(this, decl, primalAttr); + } + } } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index bebaa63a2..f749361d7 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2227,10 +2227,10 @@ namespace Slang return type; } - struct DifferentiateExprCheckingActions + struct HigherOrderInvokeExprCheckingActions { - virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) = 0; - virtual void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0; + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) = 0; + virtual void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0; FuncType* getBaseFunctionType(SemanticsVisitor* semantics, Expr* funcExpr) { if (auto funcType = as<FuncType>(funcExpr->type.type)) @@ -2255,13 +2255,13 @@ namespace Slang } }; - struct ForwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions + struct ForwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions { - virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override { return semantics->getASTBuilder()->create<ForwardDifferentiateExpr>(); } - void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override { resultDiffExpr->baseFunction = funcExpr; auto baseFuncType = getBaseFunctionType(semantics, funcExpr); @@ -2290,13 +2290,13 @@ namespace Slang } }; - struct BackwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions + struct BackwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions { - virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override { return semantics->getASTBuilder()->create<BackwardDifferentiateExpr>(); } - void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override { resultDiffExpr->baseFunction = funcExpr; auto baseFuncType = getBaseFunctionType(semantics, funcExpr); @@ -2333,10 +2333,45 @@ namespace Slang } }; - static Expr* _checkDifferentiateExpr( + struct PrimalSubstituteExprCheckingActions : HigherOrderInvokeExprCheckingActions + { + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override + { + return semantics->getASTBuilder()->create<PrimalSubstituteExpr>(); + } + void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + { + resultDiffExpr->baseFunction = funcExpr; + auto baseFuncType = getBaseFunctionType(semantics, funcExpr); + if (!baseFuncType) + { + resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); + semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type); + return; + } + resultDiffExpr->type = baseFuncType; + if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr))) + { + auto funcDecl = declRefExpr->declRef.as<CallableDecl>().getDecl(); + if (auto genDecl = as<GenericDecl>(declRefExpr->declRef.getDecl())) + { + funcDecl = as<CallableDecl>(genDecl->inner); + } + if (funcDecl) + { + for (auto param : funcDecl->getParameters()) + { + resultDiffExpr->newParameterNames.add(param->getName()); + } + } + } + } + }; + + static Expr* _checkHigherOrderInvokeExpr( SemanticsVisitor* semantics, - DifferentiateExpr* expr, - DifferentiateExprCheckingActions* actions) + HigherOrderInvokeExpr* expr, + HigherOrderInvokeExprCheckingActions* actions) { // Check/Resolve inner function declaration. expr->baseFunction = semantics->CheckTerm(expr->baseFunction); @@ -2354,8 +2389,8 @@ namespace Slang nullptr, overloadedExpr->loc, nullptr); - auto candidateExpr = actions->createDifferentiateExpr(semantics); - actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr); + auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics); + actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, lookupResultExpr); candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } @@ -2368,8 +2403,8 @@ namespace Slang OverloadedExpr2* result = astBuilder->create<OverloadedExpr2>(); for (auto item : overloadedExpr2->candidiateExprs) { - auto candidateExpr = actions->createDifferentiateExpr(semantics); - actions->fillDifferentiateExpr(candidateExpr, semantics, item); + auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics); + actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, item); candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } @@ -2378,20 +2413,26 @@ namespace Slang return result; } - actions->fillDifferentiateExpr(expr, semantics, expr->baseFunction); + actions->fillHigherOrderInvokeExpr(expr, semantics, expr->baseFunction); return expr; } Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) { ForwardDifferentiateExprCheckingActions actions; - return _checkDifferentiateExpr(this, expr, &actions); + return _checkHigherOrderInvokeExpr(this, expr, &actions); } Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) { BackwardDifferentiateExprCheckingActions actions; - return _checkDifferentiateExpr(this, expr, &actions); + return _checkHigherOrderInvokeExpr(this, expr, &actions); + } + + Expr* SemanticsExprVisitor::visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) + { + PrimalSubstituteExprCheckingActions actions; + return _checkHigherOrderInvokeExpr(this, expr, &actions); } Expr* SemanticsExprVisitor::visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index fc1b622cc..3d40b10e9 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -287,10 +287,11 @@ namespace Slang void registerAssociatedDecl(Decl* original, DeclAssociationKind assoc, Decl* declaration); - List<DeclAssociation> const& getAssociatedDeclsForDecl(Decl* decl); + List<RefPtr<DeclAssociation>> const& getAssociatedDeclsForDecl(Decl* decl); bool isDifferentiableFunc(FunctionDeclBase* func); bool isBackwardDifferentiableFunc(FunctionDeclBase* func); + FunctionDifferentiableLevel _getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit); FunctionDifferentiableLevel getFuncDifferentiableLevel(FunctionDeclBase* func); private: @@ -1951,6 +1952,8 @@ namespace Slang Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr); Expr* visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr); + Expr* visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr); + Expr* visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr); Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index e6a524645..a068f19d6 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -654,33 +654,23 @@ namespace Slang hitObjectAttributesAttr->location = (int32_t)val->value; } - else if (auto derivativeAttr = as<UserDefinedDerivativeAttribute>(attr)) + else if (as<UserDefinedDerivativeAttribute>(attr) || as<PrimalSubstituteAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); SLANG_ASSERT(as<Decl>(attrTarget)); - - // Ensure that the argument is a reference to a function definition or declaration. - auto diffExpr = CheckTerm(attr->args[0]); - if (diffExpr->type == getASTBuilder()->getErrorType()) - { - // Could not resolve the term. - getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget)); - return false; - } - // We store the partially checked funcExpr in the attribute, and - // rely on `ResolveInvoke` to resolve it to the actual function decl. - // The call to `ResolveInvoke` is deferred until we are checking the - // body of the function. - // - // Set type to null to indicate that this needs expr needs to be further resolved. - diffExpr->type.type = nullptr; - derivativeAttr->funcExpr = diffExpr; + if (auto derivativeAttr = as<UserDefinedDerivativeAttribute>(attr)) + derivativeAttr->funcExpr = attr->args[0]; + else if (auto primalSubstAttr = as<PrimalSubstituteAttribute>(attr)) + primalSubstAttr->funcExpr = attr->args[0]; } - else if (auto derivativeOfAttr = as<DerivativeOfAttribute>(attr)) + else if (as<DerivativeOfAttribute>(attr) || as<PrimalSubstituteOfAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); SLANG_ASSERT(as<Decl>(attrTarget)); - derivativeOfAttr->funcExpr = attr->args[0]; + if (auto derivativeOfAttr = as<DerivativeOfAttribute>(attr)) + derivativeOfAttr->funcExpr = attr->args[0]; + else if (auto primalOfAttr = as<PrimalSubstituteOfAttribute>(attr)) + primalOfAttr->funcExpr = attr->args[0]; } else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr)) { diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 91af731ad..f786089f8 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1549,7 +1549,7 @@ namespace Slang // Base is a normal or fully specialized generic function. OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; - if (auto diffExpr = as<DifferentiateExpr>(expr)) + if (auto diffExpr = as<HigherOrderInvokeExpr>(expr)) { candidate.funcType = as<FuncType>(diffExpr->type.type); } diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 27106b6a2..2090cd4dc 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -344,6 +344,7 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* static bool _isDifferentiableFunc(IRInst* func) { + func = getResolvedInstForDecorations(func); for (auto decor = func->getFirstDecoration(); decor; decor = decor->getNextDecoration()) { switch (decor->getOp()) @@ -369,6 +370,37 @@ static IRFuncType* _getCalleeActualFuncType(IRInst* callee) return nullptr; } +IRInst* tryFindPrimalSubstitute(IRBuilder* builder, IRInst* callee) +{ + if (auto func = as<IRFunc>(callee)) + { + if (auto decor = func->findDecoration<IRPrimalSubstituteDecoration>()) + return decor->getPrimalSubstituteFunc(); + } + else if (auto specialize = as<IRSpecialize>(callee)) + { + auto innerGen = as<IRGeneric>(specialize->getBase()); + if (!innerGen) + return nullptr; + auto innerFunc = findGenericReturnVal(innerGen); + if (auto decor = innerFunc->findDecoration<IRPrimalSubstituteDecoration>()) + { + auto substSpecialize = as<IRSpecialize>(decor->getPrimalSubstituteFunc()); + SLANG_RELEASE_ASSERT(substSpecialize); + SLANG_RELEASE_ASSERT(substSpecialize->getArgCount() == specialize->getArgCount()); + List<IRInst*> args; + for (UInt i = 0; i < specialize->getArgCount(); i++) + args.add(specialize->getArg(i)); + return builder->emitSpecializeInst( + callee->getFullType(), + substSpecialize->getBase(), + (UInt)args.getCount(), + args.getBuffer()); + } + } + return callee; +} + // Differentiating a call instruction here is primarily about generating // an appropriate call list based on whichever parameters have differentials // in the current transcription context. @@ -393,10 +425,20 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig } auto primalCallee = lookupPrimalInst(builder, origCallee, origCallee); + auto substPrimalCallee = tryFindPrimalSubstitute(builder, primalCallee); IRInst* diffCallee = nullptr; + if (substPrimalCallee == primalCallee) + { + instMapD.TryGetValue(origCallee, diffCallee); + } + else + { + instMapD.TryGetValue(substPrimalCallee, diffCallee); + primalCallee = substPrimalCallee; + } - if (instMapD.TryGetValue(origCallee, diffCallee)) + if (diffCallee) { } else if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>()) @@ -750,6 +792,9 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec // Make sure this isn't itself a specialize . SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc)); + auto derivativeDecoration = genericInnerVal->findDecoration<IRForwardDerivativeDecoration>(); + SLANG_RELEASE_ASSERT(derivativeDecoration); + return InstPair(primalSpecialize, jvpFunc); } else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRForwardDerivativeDecoration>()) diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 709968f77..d7cce7c53 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -736,7 +736,7 @@ namespace Slang // We need to insert a local variable to store this var. IRInst* operandReplacement = nullptr; - if (canInstBeStored(operand)) + if (canTypeBeStored(operand->getDataType())) { auto var = storeInstAsLocalVar(operand); builder.setInsertBefore(inst); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 73d9b6ba6..ed122c862 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -72,6 +72,14 @@ bool AutoDiffTranscriberBase::shouldUseOriginalAsPrimal(IRInst* currentParent, I return true; if (isChildInstOf(currentParent, origInst->getParent())) return true; + + // If origInst is defined in the first block of the same function as current inst (e.g. a param), + // we can use it as primal. + // More generally, we should test if origInst dominates currentParent, but that requires calculating + // a dom tree on the fly. Right now just testing if it is first block for parameters seems sufficient. + auto parentFunc = getParentFunc(currentParent); + if (parentFunc && origInst->parent == parentFunc->getFirstBlock()) + return true; return false; } @@ -802,6 +810,7 @@ static void _markGenericChildrenWithoutRelaventUse(IRGeneric* origGeneric, HashS case kIROp_BackwardDerivativePrimalContextDecoration: case kIROp_BackwardDerivativePrimalDecoration: case kIROp_BackwardDerivativePropagateDecoration: + case kIROp_PrimalSubstituteDecoration: break; default: if (!outInstsToSkip.Contains(use->getUser())) @@ -876,6 +885,32 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene return InstPair(primalGeneric, diffGeneric); } +IRInst* getActualInstToTranscribe(IRInst* inst) +{ + if (auto gen = as<IRGeneric>(inst)) + { + auto retVal = findGenericReturnVal(gen); + if (retVal->getOp() != kIROp_Func) + return inst; + if (auto primalSubst = retVal->findDecoration<IRPrimalSubstituteDecoration>()) + { + auto spec = as<IRSpecialize>(primalSubst->getPrimalSubstituteFunc()); + SLANG_RELEASE_ASSERT(spec); + return spec->getBase(); + } + } + else if (auto func = as<IRFunc>(inst)) + { + if (auto primalSubst = func->findDecoration<IRPrimalSubstituteDecoration>()) + { + auto actualFunc = as<IRFunc>(primalSubst->getPrimalSubstituteFunc()); + SLANG_RELEASE_ASSERT(actualFunc); + return actualFunc; + } + } + return inst; +} + IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst) { // If a differential instruction is already mapped for @@ -891,8 +926,8 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst // depending on the op-code. // instsInProgress.Add(origInst); - - InstPair pair = transcribeInst(builder, origInst); + auto actualInstToTranscribe = getActualInstToTranscribe(origInst); + InstPair pair = transcribeInst(builder, actualInstToTranscribe); instsInProgress.Remove(origInst); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index a05fe7044..2347c7a8f 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -197,8 +197,18 @@ struct ExtractPrimalFuncContext bool shouldStoreVar(IRVar* var) { // Always store intermediate context var. - if (var->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) + if (auto typeDecor = var->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) { + // If we are specializing a callee's intermediate context with types that can't be stored, + // we can't store the entire context. + if (auto spec = as<IRSpecialize>(as<IRPtrTypeBase>(var->getDataType())->getValueType())) + { + for (UInt i = 0; i < spec->getArgCount(); i++) + { + if (!canTypeBeStored(spec->getArg(i)->getDataType())) + return false; + } + } return true; } @@ -212,7 +222,7 @@ struct ExtractPrimalFuncContext // 2. Does the var have a store // - return (doesInstHaveDiffUse(var) && doesInstHaveStore(var)); + return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType())); } bool shouldStoreInst(IRInst* inst) @@ -222,7 +232,7 @@ struct ExtractPrimalFuncContext return false; } - if (!canInstBeStored(inst)) + if (!canTypeBeStored(inst->getDataType())) return false; // Never store certain opcodes. @@ -246,6 +256,9 @@ struct ExtractPrimalFuncContext case kIROp_MakeOptionalValue: case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialWitnessTable: return false; case kIROp_GetElement: case kIROp_FieldExtract: @@ -560,7 +573,12 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( if (inst->getOp() == kIROp_Call) { // The primal calls should be marked as no side effect so they can be DCE'd if possible. - builder.addSimpleDecoration<IRNoSideEffectDecoration>(inst); + // We can only do so if the intermediate context of the callee is stored. + if (primalCtx->getBackwardDerivativePrimalContextVar() + ->findDecoration<IRPrimalValueStructKeyDecoration>()) + { + builder.addSimpleDecoration<IRNoSideEffectDecoration>(inst); + } } } } diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index fcfbf3bee..65e880868 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -563,12 +563,15 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* return false; } -bool canInstBeStored(IRInst* inst) +bool canTypeBeStored(IRInst* type) { - if (as<IRBasicType>(inst->getDataType())) + if (!type) + return false; + + if (as<IRBasicType>(type)) return true; - switch (inst->getDataType()->getOp()) + switch (type->getOp()) { case kIROp_StructType: case kIROp_OptionalType: @@ -716,6 +719,9 @@ struct AutoDiffPass : public InstPassBase break; } break; + case kIROp_PrimalSubstitute: + // Explicit primal subst operator is not yet supported. + SLANG_UNIMPLEMENTED_X("explicit primal_subst operator."); default: break; } diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index a4eb94461..f757375d8 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -295,7 +295,7 @@ bool isBackwardDifferentiableFunc(IRInst* func); bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst); -bool canInstBeStored(IRInst* inst); +bool canTypeBeStored(IRInst* type); inline bool isRelevantDifferentialPair(IRType* type) { diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 186b0cc03..14f6394e2 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -39,12 +39,17 @@ public: return false; } - bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level) { func = getResolvedInstForDecorations(func); if (!func) return false; + if (auto substDecor = func->findDecoration<IRPrimalSubstituteDecoration>()) + { + func = getResolvedInstForDecorations(substDecor->getPrimalSubstituteFunc()); + if (!func) + return false; + } for (auto decorations : func->getDecorations()) { @@ -84,7 +89,13 @@ public: if (!func) return false; - + if (auto substDecor = func->findDecoration<IRPrimalSubstituteDecoration>()) + { + func = getResolvedInstForDecorations(substDecor->getPrimalSubstituteFunc()); + if (!func) + return false; + } + if (auto existingLevel = differentiableFunctions.TryGetValue(func)) return *existingLevel >= level; diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index e5c9b1fdb..1fe88e780 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -379,25 +379,37 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o // if (options.keepExportsAlive) { - if (inst->findDecoration<IRExportDecoration>()) + bool isImported = false; + bool shouldKeptAliveIfImported = false; + IRInst* innerInst = inst; + if (auto genInst = as<IRGeneric>(inst)) { - return true; + innerInst = findInnerMostGenericReturnVal(genInst); } - if (inst->findDecoration<IRImportDecoration>()) + for (auto decor : inst->getDecorations()) { - if (inst->findDecoration<IRForwardDerivativeDecoration>()) - return true; - if (inst->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) + switch (decor->getOp()) + { + case kIROp_ExportDecoration: return true; - if (auto genInst = as<IRGeneric>(inst)) + case kIROp_ImportDecoration: + isImported = true; + break; + } + } + for (auto decor : innerInst->getDecorations()) + { + switch (decor->getOp()) { - auto inner = findInnerMostGenericReturnVal(genInst); - if (inner->findDecoration<IRForwardDerivativeDecoration>()) - return true; - if (inner->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) - return true; + case kIROp_ForwardDerivativeDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_PrimalSubstituteDecoration: + shouldKeptAliveIfImported = true; + break; } } + if (isImported && shouldKeptAliveIfImported) + return true; } if (options.keepLayoutsAlive && inst->findDecoration<IRLayoutDecoration>()) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index c704359e6..7411d031c 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -764,6 +764,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// generated derivative function. INST(BackwardDifferentiableDecoration, backwardDifferentiable, 1, 0) + /// Used by the auto-diff pass to hold a reference to the + /// primal substitute function. + INST(PrimalSubstituteDecoration, primalSubstFunc, 1, 0) + /// Decorations to associate an original function with compiler generated backward derivative functions. INST(BackwardDerivativePrimalDecoration, backwardDiffPrimalReference, 1, 0) INST(BackwardDerivativePropagateDecoration, backwardDiffPropagateReference, 1, 0) @@ -882,6 +886,8 @@ INST(BackwardDifferentiatePropagate, BackwardDifferentiatePropagate, 1, // replaced with `BackwardDifferentiatePrimal` and `BackwardDifferentiatePropagate`. INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0) +INST(PrimalSubstitute, PrimalSubstitute, 1, 0) + // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5269ae02f..ae31219bd 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -608,6 +608,17 @@ struct IRForwardDerivativeDecoration : IRDecoration IRInst* getForwardDerivativeFunc() { return getOperand(0); } }; +struct IRPrimalSubstituteDecoration : IRDecoration +{ + enum + { + kOp = kIROp_PrimalSubstituteDecoration + }; + IR_LEAF_ISA(PrimalSubstituteDecoration) + + IRInst* getPrimalSubstituteFunc() { return getOperand(0); } +}; + struct IRBackwardDerivativeIntermediateTypeDecoration : IRDecoration { enum @@ -879,6 +890,20 @@ struct IRBackwardDifferentiate : IRInst IR_LEAF_ISA(BackwardDifferentiate) }; +// Retrieves the primal substitution function for the given function. +struct IRPrimalSubstitute : IRInst +{ + enum + { + kOp = kIROp_PrimalSubstitute + }; + // The base function for the call. + IRUse base; + IRInst* getBaseFn() { return getOperand(0); } + + IR_LEAF_ISA(PrimalSubstitute) +}; + // Dictionary item mapping a type with a corresponding // IDifferentiable witness table // @@ -2804,6 +2829,7 @@ public: IRInst* emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn); IRInst* emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn); IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn); + IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); @@ -3623,6 +3649,11 @@ public: addDecoration(value, kIROp_BackwardDerivativePrimalContextDecoration, ctx); } + void addPrimalSubstituteDecoration(IRInst* value, IRInst* jvpFn) + { + addDecoration(value, kIROp_PrimalSubstituteDecoration, jvpFn); + } + void addLoopCounterDecoration(IRInst* value) { addDecoration(value, kIROp_LoopCounterDecoration); diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index bcff5621c..b976f4b21 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -443,6 +443,7 @@ static void cloneExtraDecorationsFromInst( case kIROp_SequentialIDDecoration: case kIROp_ForwardDerivativeDecoration: case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_PrimalSubstituteDecoration: case kIROp_IntrinsicOpDecoration: if (!clonedInst->findDecorationImpl(decoration->getOp())) { diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 0aa2dc607..2819a6d83 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3106,6 +3106,17 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitPrimalSubstituteInst(IRType* type, IRInst* baseFn) + { + auto inst = createInst<IRPrimalSubstitute>( + this, + kIROp_PrimalSubstitute, + type, + baseFn); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn) { auto inst = createInst<IRBackwardDifferentiate>( diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 261e08168..d8912cbd4 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3165,6 +3165,17 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> baseVal.val)); } + LoweredValInfo visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) + { + auto baseVal = lowerSubExpr(expr->baseFunction); + SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); + + return LoweredValInfo::simple( + getBuilder()->emitPrimalSubstituteInst( + lowerType(context, expr->type), + baseVal.val)); + } + LoweredValInfo visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) { auto baseVal = lowerSubExpr(expr->innerExpr); @@ -7970,14 +7981,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> addNameHint(subContext, irFunc, decl); addLinkageDecoration(subContext, irFunc, decl); - if (decl->findModifier<ForwardDifferentiableAttribute>()) - { - getBuilder()->addForwardDifferentiableDecoration(irFunc); - } - if (decl->findModifier<BackwardDifferentiableAttribute>()) - { - getBuilder()->addBackwardDifferentiableDecoration(irFunc); - } if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>()) { lowerDifferentiableAttribute(subContext, irFunc, differentialAttr); @@ -8291,156 +8294,156 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addRequireCUDASMVersionDecoration(irFunc, versionMod->version); } - if (decl->findModifier<RequiresNVAPIAttribute>()) - { - getBuilder()->addSimpleDecoration<IRRequiresNVAPIDecoration>(irFunc); - } - - if (decl->findModifier<AlwaysFoldIntoUseSiteAttribute>()) - { - getBuilder()->addSimpleDecoration<IRAlwaysFoldIntoUseSiteDecoration>(irFunc); - } - - if (decl->findModifier<NoInlineAttribute>()) - { - getBuilder()->addSimpleDecoration<IRNoInlineDecoration>(irFunc); - } - - if (auto attr = decl->findModifier<InstanceAttribute>()) - { - IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit); - } - - if (auto attr = decl->findModifier<MaxVertexCountAttribute>()) - { - IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit); - } - - if (auto attr = decl->findModifier<NumThreadsAttribute>()) - { - auto builder = getBuilder(); - IRType* intType = builder->getIntType(); - - IRInst* operands[3] = { - builder->getIntValue(intType, attr->x), - builder->getIntValue(intType, attr->y), - builder->getIntValue(intType, attr->z) - }; - - builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3); - } - - if (decl->findModifier<ReadNoneAttribute>()) - { - getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc); - } - - if (decl->findModifier<EarlyDepthStencilAttribute>()) - { - getBuilder()->addSimpleDecoration<IREarlyDepthStencilDecoration>(irFunc); - } - - if (auto attr = decl->findModifier<DomainAttribute>()) - { - IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit); - } + // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute + setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); - if (auto attr = decl->findModifier<PartitioningAttribute>()) + for (auto modifier : decl->modifiers) { - IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit); - } + if (as<RequiresNVAPIAttribute>(modifier)) + { + getBuilder()->addSimpleDecoration<IRRequiresNVAPIDecoration>(irFunc); + } + else if (as<AlwaysFoldIntoUseSiteAttribute>(modifier)) + { + getBuilder()->addSimpleDecoration<IRAlwaysFoldIntoUseSiteDecoration>(irFunc); + } + else if (as<NoInlineAttribute>(modifier)) + { + getBuilder()->addSimpleDecoration<IRNoInlineDecoration>(irFunc); + } + else if (auto instanceAttr = as<InstanceAttribute>(modifier)) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), instanceAttr); + getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit); + } + else if (auto maxVertCountAttr = as<MaxVertexCountAttribute>(modifier)) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), maxVertCountAttr); + getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit); + } + else if (auto numThreadsAttr = as<NumThreadsAttribute>(modifier)) + { + auto builder = getBuilder(); + IRType* intType = builder->getIntType(); - if (auto attr = decl->findModifier<OutputTopologyAttribute>()) - { - IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit); - } + IRInst* operands[3] = { + builder->getIntValue(intType, numThreadsAttr->x), + builder->getIntValue(intType, numThreadsAttr->y), + builder->getIntValue(intType, numThreadsAttr->z) + }; - if (auto attr = decl->findModifier<OutputControlPointsAttribute>()) - { - IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit); - } + builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3); + } + else if (as<ReadNoneAttribute>(modifier)) + { + getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc); + } + else if (as<EarlyDepthStencilAttribute>(modifier)) + { + getBuilder()->addSimpleDecoration<IREarlyDepthStencilDecoration>(irFunc); + } + else if (auto domainAttr = as<DomainAttribute>(modifier)) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), domainAttr); + getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit); + } + else if (auto partitionAttr = as<PartitioningAttribute>(modifier)) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), partitionAttr); + getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit); + } + else if (auto outputTopAttr = as<OutputTopologyAttribute>(modifier)) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), outputTopAttr); + getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit); + } + else if (auto outputCtrlPtAttr = as<OutputControlPointsAttribute>(modifier)) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), outputCtrlPtAttr); + getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit); + } + else if (auto spvInstOpAttr = as<SPIRVInstructionOpAttribute>(modifier)) + { + auto builder = getBuilder(); + IRIntLit* intLit = _getIntLitFromAttribute(builder, spvInstOpAttr, 0); - if (auto attr = decl->findModifier<SPIRVInstructionOpAttribute>()) - { - auto builder = getBuilder(); - IRIntLit* intLit = _getIntLitFromAttribute(builder, attr, 0); + IRStringLit* setStringLit = nullptr; + if (spvInstOpAttr->args.getCount() > 1) + { + IRStringLit* checkSetStringLit = _getStringLitFromAttribute(builder, spvInstOpAttr, 1); + if (checkSetStringLit && checkSetStringLit->getStringSlice().getLength() > 0) + { + setStringLit = checkSetStringLit; + } + } - IRStringLit* setStringLit = nullptr; - if (attr->args.getCount() > 1) - { - IRStringLit* checkSetStringLit = _getStringLitFromAttribute(builder, attr, 1); - if (checkSetStringLit && checkSetStringLit->getStringSlice().getLength() > 0) + // If it has a `set` defined, set it on the decoration + if (setStringLit) { - setStringLit = checkSetStringLit; + builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit, setStringLit); + } + else + { + builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit); } } - - // If it has a `set` defined, set it on the decoration - if (setStringLit) + else if (as<UnsafeForceInlineEarlyAttribute>(modifier)) { - builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit, setStringLit); + getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration); } - else + else if (as<ForceInlineAttribute>(modifier)) { - builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit); + getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); } - } - - if (decl->findModifier<UnsafeForceInlineEarlyAttribute>()) - { - getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration); - } - - if (decl->findModifier<ForceInlineAttribute>()) - { - getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); - } - - if (decl->findModifier<TreatAsDifferentiableAttribute>()) - { - getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); - } - - if (auto intrinsicOp = decl->findModifier<IntrinsicOpModifier>()) - { - auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op); - getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op); - } - - // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute - setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); - - if (auto attr = decl->findModifier<UserDefinedDerivativeAttribute>()) - { - // We need to lower the decl ref to the custom derivative function to IR. - // The IR insts correspond to the decl ref is not part of the function we - // are processing. If we emit it directly to within the function, it could - // mess up the assumption on the form of the IR (e.g. having non decoration insts - // appearing in the middle of decoration insts). so we emit the decl ref to the - // function's parent for now. - - subContext->irBuilder->setInsertInto(irFunc->getParent()); - - auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr); - - SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); - IRInst* derivativeFunc = loweredVal.val; - - if (as<ForwardDerivativeAttribute>(attr)) - getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc); - else - getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc); + else if (as<TreatAsDifferentiableAttribute>(modifier)) + { + getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); + } + else if (auto intrinsicOp = as<IntrinsicOpModifier>(modifier)) + { + auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op); + getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op); + } + else if (as<UserDefinedDerivativeAttribute>(modifier) || as<PrimalSubstituteAttribute>(modifier)) + { + // We need to lower the decl ref to the custom derivative function to IR. + // The IR insts correspond to the decl ref is not part of the function we + // are processing. If we emit it directly to within the function, it could + // mess up the assumption on the form of the IR (e.g. having non decoration insts + // appearing in the middle of decoration insts). so we emit the decl ref to the + // function's parent for now. + + subContext->irBuilder->setInsertInto(irFunc->getParent()); + Expr* funcExpr = nullptr; + if (auto udAttr = as<UserDefinedDerivativeAttribute>(modifier)) + funcExpr = udAttr->funcExpr; + else if (auto primalAttr = as<PrimalSubstituteAttribute>(modifier)) + funcExpr = primalAttr->funcExpr; + + auto loweredVal = lowerRValueExpr(subContext, funcExpr); + + SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); + IRInst* derivativeFunc = loweredVal.val; + + if (as<ForwardDerivativeAttribute>(modifier)) + getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc); + else if (as<BackwardDerivativeAttribute>(modifier)) + getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc); + else + getBuilder()->addPrimalSubstituteDecoration(irFunc, derivativeFunc); - // Reset cursor. - subContext->irBuilder->setInsertInto(irFunc); + // Reset cursor. + subContext->irBuilder->setInsertInto(irFunc); + } + else if (as<ForwardDifferentiableAttribute>(modifier)) + { + getBuilder()->addForwardDifferentiableDecoration(irFunc); + } + else if (as<BackwardDifferentiableAttribute>(modifier)) + { + getBuilder()->addBackwardDifferentiableDecoration(irFunc); + } } - // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list @@ -8451,39 +8454,59 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // the interface's type definition. auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric); - if (auto attr = decl->findModifier<DerivativeOfAttribute>()) + for (auto modifier : decl->modifiers) { - if (auto originalDeclRefExpr = as<DeclRefExpr>(attr->funcExpr)) + if (as<DerivativeOfAttribute>(modifier) || as<PrimalSubstituteOfAttribute>(modifier)) { - NestedContext originalContextFunc(this); - auto originalSubBuilder = originalContextFunc.getBuilder(); - auto originalSubContext = originalContextFunc.getContext(); - if (auto outterGeneric = getOuterGeneric(irFunc)) - originalSubBuilder->setInsertBefore(outterGeneric); - else - originalSubBuilder->setInsertBefore(irFunc); - auto originalFuncDecl = as<FunctionDeclBase>(originalDeclRefExpr->declRef.getDecl()); - SLANG_RELEASE_ASSERT(originalFuncDecl); - - auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val; - if (auto originalFuncGeneric = as<IRGeneric>(originalFuncVal)) + Expr* funcExpr = nullptr; + Expr* backDeclRef = nullptr; + if (auto attr = as<DerivativeOfAttribute>(modifier)) { - originalFuncVal = findGenericReturnVal(originalFuncGeneric); + funcExpr = attr->funcExpr; + backDeclRef = attr->backDeclRef; } - originalSubBuilder->setInsertBefore(originalFuncVal); - auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef); - if (as<ForwardDerivativeOfAttribute>(attr)) + else if (auto primalAttr = as<PrimalSubstituteOfAttribute>(modifier)) { - originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); - getBuilder()->addForwardDifferentiableDecoration(irFunc); + funcExpr = primalAttr->funcExpr; + backDeclRef = primalAttr->backDeclRef; } - else + + if (auto originalDeclRefExpr = as<DeclRefExpr>(funcExpr)) { - originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + NestedContext originalContextFunc(this); + auto originalSubBuilder = originalContextFunc.getBuilder(); + auto originalSubContext = originalContextFunc.getContext(); + if (auto outterGeneric = getOuterGeneric(irFunc)) + originalSubBuilder->setInsertBefore(outterGeneric); + else + originalSubBuilder->setInsertBefore(irFunc); + auto originalFuncDecl = as<FunctionDeclBase>(originalDeclRefExpr->declRef.getDecl()); + SLANG_RELEASE_ASSERT(originalFuncDecl); + + auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val; + if (auto originalFuncGeneric = as<IRGeneric>(originalFuncVal)) + { + originalFuncVal = findGenericReturnVal(originalFuncGeneric); + } + originalSubBuilder->setInsertBefore(originalFuncVal); + auto derivativeFuncVal = lowerRValueExpr(originalSubContext, backDeclRef); + if (as<ForwardDerivativeOfAttribute>(modifier)) + { + originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + getBuilder()->addForwardDifferentiableDecoration(irFunc); + } + else if (as<BackwardDerivativeOfAttribute>(modifier)) + { + originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + } + else + { + originalSubBuilder->addPrimalSubstituteDecoration(originalFuncVal, derivativeFuncVal.val); + } } + subContext->irBuilder->setInsertInto(irFunc); + finalVal->moveToEnd(); } - subContext->irBuilder->setInsertInto(irFunc); - finalVal->moveToEnd(); } return LoweredValInfo::simple(finalVal); } |
