summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-08 21:52:34 -0800
committerGitHub <noreply@github.com>2023-03-08 21:52:34 -0800
commit86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (patch)
treeb4f9eb6cb1eea88145fde0bd1f670a8803120257 /source/slang
parent257733f328f38a763c8b0c8830ff4c0d34ec9491 (diff)
Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691)
* Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/diff.meta.slang85
-rw-r--r--source/slang/slang-ast-expr.h6
-rw-r--r--source/slang/slang-ast-modifier.h18
-rw-r--r--source/slang/slang-ast-support-types.h12
-rw-r--r--source/slang/slang-check-decl.cpp614
-rw-r--r--source/slang/slang-check-expr.cpp79
-rw-r--r--source/slang/slang-check-impl.h5
-rw-r--r--source/slang/slang-check-modifier.cpp30
-rw-r--r--source/slang/slang-check-overload.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp47
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp39
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp26
-rw-r--r--source/slang/slang-ir-autodiff.cpp12
-rw-r--r--source/slang/slang-ir-autodiff.h2
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp15
-rw-r--r--source/slang/slang-ir-dce.cpp36
-rw-r--r--source/slang/slang-ir-inst-defs.h6
-rw-r--r--source/slang/slang-ir-insts.h31
-rw-r--r--source/slang/slang-ir-link.cpp1
-rw-r--r--source/slang/slang-ir.cpp11
-rw-r--r--source/slang/slang-lower-to-ir.cpp357
22 files changed, 855 insertions, 581 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 54f927816..4301eda94 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -12,6 +12,9 @@ __attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute;
__attributeTarget(FunctionDeclBase)
+attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute;
+
+__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute;
__attributeTarget(FunctionDeclBase)
@@ -20,6 +23,9 @@ attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDerivativeOf(function)] : BackwardDerivativeOfAttribute;
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [PrimalSubstituteOf(function)] : PrimalSubstituteOfAttribute;
+
__attributeTarget(DeclBase)
attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
@@ -1037,6 +1043,7 @@ void __d_refract(inout DifferentialPair<vector<T, N>> i, inout DifferentialPair<
// Sine and cosine
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PrimalSubstituteOf(sincos)]
void __sincos_impl(T x, out T s, out T c)
{
s = sin(x);
@@ -1045,6 +1052,7 @@ void __sincos_impl(T x, out T s, out T c)
__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
+[PrimalSubstituteOf(sincos)]
void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c)
{
s = sin(x);
@@ -1053,62 +1061,18 @@ void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c)
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[BackwardDifferentiable]
+[PrimalSubstituteOf(sincos)]
void __sincos_impl(matrix<T, N, M> x, out matrix<T, N, M> s, out matrix<T, N, M> c)
{
s = sin(x);
c = cos(x);
}
-__generic<T: __BuiltinFloatingPointType>
-[ForwardDerivativeOf(sincos)]
-[ForceInline]
-void __d_sincos(DifferentialPair<T> x, out DifferentialPair<T> s, out DifferentialPair<T> c)
-{
- __fwd_diff(__sincos_impl)(x, s, c);
-}
-
-__generic<T : __BuiltinFloatingPointType, let N : int>
-[ForwardDerivativeOf(sincos)]
-[ForceInline]
-void __d_sincos(DifferentialPair<vector<T, N>> x, out DifferentialPair<vector<T, N>> s, out DifferentialPair<vector<T, N>> c)
-{
- __fwd_diff(__sincos_impl)(x, s, c);
-}
-
-__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
-[ForwardDerivativeOf(sincos)]
-[ForceInline]
-void __d_sincos(DifferentialPair<matrix<T, N, M>> x, out DifferentialPair<matrix<T, N, M>> s, out DifferentialPair<matrix<T, N, M>> c)
-{
- __fwd_diff(__sincos_impl)(x, s, c);
-}
-
-__generic<T: __BuiltinFloatingPointType>
-[BackwardDerivativeOf(sincos)]
-[ForceInline]
-void __d_sincos(inout DifferentialPair<T> x, T.Differential dS, T.Differential dC)
-{
- __bwd_diff(__sincos_impl)(x, dS, dC);
-}
-__generic<T: __BuiltinFloatingPointType, let N : int>
-[BackwardDerivativeOf(sincos)]
-[ForceInline]
-void __d_sincos(inout DifferentialPair<vector<T, N>> x, vector<T, N>.Differential dS, vector<T, N>.Differential dC)
-{
- __bwd_diff(__sincos_impl)(x, dS, dC);
-}
-
-__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
-[BackwardDerivativeOf(sincos)]
-[ForceInline]
-void __d_sincos(inout DifferentialPair<matrix<T, N, M>> x, matrix<T, N, M>.Differential dS, matrix<T, N, M>.Differential dC)
-{
- __bwd_diff(__sincos_impl)(x, dS, dC);
-}
// dst (obsolete)
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PrimalSubstituteOf(dst)]
vector<T, 4> __dst_impl(vector<T, 4> src0, vector<T, 4> src1)
{
vector<T, 4> dest;
@@ -1118,25 +1082,11 @@ vector<T, 4> __dst_impl(vector<T, 4> src0, vector<T, 4> src1)
dest.w = src1.w; ;
return dest;
}
-__generic<T : __BuiltinFloatingPointType>
-[ForwardDerivativeOf(dst)]
-[ForceInline]
-DifferentialPair<vector<T, 4>> __d_dst(DifferentialPair<vector<T, 4>> src0, DifferentialPair<vector<T, 4>> src1)
-{
- return __fwd_diff(__dst_impl)(src0, src1);
-}
-__generic<T : __BuiltinFloatingPointType>
-[BackwardDerivativeOf(dst)]
-[ForceInline]
-void __d_dst(inout DifferentialPair<vector<T, 4>> src0, inout DifferentialPair<vector<T, 4>> src1, vector<T, 4>.Differential dOut)
-{
- __bwd_diff(__dst_impl)(src0, src1, dOut);
-}
// Legacy lighting function (obsolete)
-__target_intrinsic(hlsl)
[__readNone]
[BackwardDifferentiable]
+[PrimalSubstituteOf(lit)]
float4 __lit_impl(float n_dot_l, float n_dot_h, float m)
{
let ambient = 1.0f;
@@ -1144,19 +1094,6 @@ float4 __lit_impl(float n_dot_l, float n_dot_h, float m)
let specular = ((n_dot_l < 0.0f || n_dot_h < 0.0) ? 0.0 : pow(n_dot_h, m));
return float4(ambient, diffuse, specular, 1.0f);
}
-[ForwardDerivativeOf(lit)]
-[ForceInline]
-DifferentialPair<float4> __d_lit(DifferentialPair<float> n_dot_l, DifferentialPair<float> n_dot_h, DifferentialPair<float> m)
-{
- return __fwd_diff(__lit_impl)(n_dot_l, n_dot_h, m);
-}
-[BackwardDerivativeOf(lit)]
-[ForceInline]
-void __d_lit(inout DifferentialPair<float> n_dot_l, inout DifferentialPair<float> n_dot_h, inout DifferentialPair<float> m, float4 dOut)
-{
- __bwd_diff(__lit_impl)(n_dot_l, n_dot_h, m, dOut);
-}
-
// Matrix determinant
__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index ba0b4ce7a..301dded49 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -452,10 +452,14 @@ class HigherOrderInvokeExpr : public Expr
List<Name*> newParameterNames;
};
+class PrimalSubstituteExpr : public HigherOrderInvokeExpr
+{
+ SLANG_AST_CLASS(PrimalSubstituteExpr)
+};
+
class DifferentiateExpr : public HigherOrderInvokeExpr
{
SLANG_ABSTRACT_AST_CLASS(DifferentiateExpr)
-
};
/// An expression of the form `__fwd_diff(fn)` to access the
/// forward-mode derivative version of the function `fn`
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 7dd0819d8..80c770c3a 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1151,6 +1151,24 @@ class BackwardDerivativeOfAttribute : public DerivativeOfAttribute
SLANG_AST_CLASS(BackwardDerivativeOfAttribute)
};
+ /// The `[PrimalSubstitute(function)]` attribute specifies a custom function that should
+ /// be used as the primal function substitute when differentiating code that calls the primal function.
+class PrimalSubstituteAttribute : public Attribute
+{
+ SLANG_AST_CLASS(PrimalSubstituteAttribute)
+ Expr* funcExpr;
+};
+
+ /// The `[PrimalSubstituteOf(primalFunction)]` attribute marks the decorated function as
+ /// the substitute primal function in a forward or backward derivative function.
+class PrimalSubstituteOfAttribute : public Attribute
+{
+ SLANG_AST_CLASS(PrimalSubstituteOfAttribute)
+
+ Expr* funcExpr;
+ Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction.
+};
+
/// The `[NoDiffThis]` attribute is used to specify that the `this` parameter should not be
/// included for differentiation.
class NoDiffThisAttribute : public Attribute
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 07b3a5eac..78b0ccfcc 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -391,6 +391,10 @@ namespace Slang
/// maybe synthesized and made available only after conformance checking.
TypesFullyResolved,
+ /// All attributes are fully checked. This is the final step before
+ /// checking the function body.
+ AttributesChecked,
+
/// The declaration is fully checked.
///
/// This step includes any validation of the declaration that is
@@ -1500,12 +1504,12 @@ namespace Slang
enum class DeclAssociationKind
{
- ForwardDerivativeFunc, BackwardDerivativeFunc,
+ ForwardDerivativeFunc, BackwardDerivativeFunc, PrimalSubstituteFunc
};
- struct DeclAssociation
+ struct DeclAssociation : SerialRefObject
{
- SLANG_VALUE_CLASS(DeclAssociation)
+ SLANG_OBJ_CLASS(DeclAssociation)
DeclAssociationKind kind;
Decl* decl;
};
@@ -1516,7 +1520,7 @@ namespace Slang
{
SLANG_OBJ_CLASS(DeclAssociationList)
- List<DeclAssociation> associations;
+ List<RefPtr<DeclAssociation>> associations;
};
/// Represents the "direction" that a parameter is being passed (e.g., `in` or `out`
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 7c42c1892..5cd7fba45 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -34,6 +34,26 @@ namespace Slang
}
};
+ struct SemanticsDeclAttributesVisitor
+ : public SemanticsDeclVisitorBase
+ , public DeclVisitor<SemanticsDeclAttributesVisitor>
+ {
+ SemanticsDeclAttributesVisitor(SemanticsContext const& outer)
+ : SemanticsDeclVisitorBase(outer)
+ {}
+
+ void visitDecl(Decl*) {}
+ void visitDeclGroup(DeclGroup*) {}
+
+ void visitFunctionDeclBase(FunctionDeclBase* decl);
+
+ void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeOfAttribute* attr);
+
+ void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, BackwardDerivativeOfAttribute* attr);
+
+ void checkPrimalSubstituteOfAttribute(FunctionDeclBase* funcDecl, PrimalSubstituteOfAttribute* attr);
+ };
+
struct SemanticsDeclHeaderVisitor
: public SemanticsDeclVisitorBase
, public DeclVisitor<SemanticsDeclHeaderVisitor>
@@ -258,10 +278,6 @@ namespace Slang
void visitFunctionDeclBase(FunctionDeclBase* funcDecl);
void visitParamDecl(ParamDecl* paramDecl);
-
- void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl);
-
- void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl);
};
/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration?
@@ -4657,270 +4673,10 @@ namespace Slang
getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly);
}
- template<typename TDerivativeAttr>
- void checkDerivativeAttributeImpl(
- SemanticsVisitor* visitor,
- TDerivativeAttr* attr,
- const List<Expr*>& imaginaryArguments)
- {
- auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, *visitor);
- auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments);
- auto resolved = visitor->ResolveInvoke(invokeExpr);
- if (auto resolvedInvoke = as<InvokeExpr>(resolved))
- {
- if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
- {
- attr->funcExpr = calleeDeclRef;
- return;
- }
- }
- visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
- }
-
- template<typename TDerivativeAttr>
- const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); }
-
- template<>
- const char* getDerivativeAttrName<ForwardDerivativeAttribute>()
- {
- return "ForwardDerivative";
- }
- template<>
- const char* getDerivativeAttrName<BackwardDerivativeAttribute>()
- {
- return "BackwardDerivative";
- }
-
- List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc)
- {
- List<Expr*> imaginaryArguments;
- for (auto param : func->getParameters())
- {
- auto arg = astBuilder->create<VarExpr>();
- arg->declRef.decl = param;
- arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
- arg->type.type = param->getType();
- arg->loc = loc;
- imaginaryArguments.add(arg);
- }
- return imaginaryArguments;
- }
-
- List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
- {
- List<Expr*> imaginaryArguments;
- for (auto param : originalFuncDecl->getParameters())
- {
- auto arg = visitor->getASTBuilder()->create<VarExpr>();
- arg->declRef.decl = param;
- arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
- arg->type.type = param->getType();
- arg->loc = loc;
- if (auto pairType = visitor->getDifferentialPairType(param->getType()))
- {
- arg->type.type = pairType;
- }
- imaginaryArguments.add(arg);
- }
- return imaginaryArguments;
- }
-
- List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
- {
- List<Expr*> imaginaryArguments;
- auto isOutParam = [&](ParamDecl* param)
- {
- return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr;
- };
-
- for (auto param : originalFuncDecl->getParameters())
- {
- auto arg = visitor->getASTBuilder()->create<VarExpr>();
- arg->declRef.decl = param;
- arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
- arg->type.type = param->getType();
- arg->loc = loc;
- if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType())))
- {
- arg->type.type = pairType;
- if (isOutParam(param))
- {
- // out T -> in T.Differential
- arg->type.isLeftValue = false;
- arg->type.type = visitor->tryGetDifferentialType(
- visitor->getASTBuilder(), pairType->getPrimalType());
- }
- }
- else
- {
- if (isOutParam(param))
- {
- // Skip non-differentiable out params.
- continue;
- }
- }
- imaginaryArguments.add(arg);
- }
- if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type))
- {
- auto arg = visitor->getASTBuilder()->create<InitializerListExpr>();
- arg->type.isLeftValue = false;
- arg->type.type = diffReturnType;
- arg->loc = loc;
- imaginaryArguments.add(arg);
- }
- return imaginaryArguments;
- }
-
- // This helper function is needed to workaround a gcc bug.
- // Remove when we upgrade to a newer version of gcc.
- template <typename T>
- static T* _findModifier(Decl* decl)
- {
- return decl->findModifier<T>();
- }
-
- template <typename TDerivativeAttr, typename TDifferentiateExpr, typename TDerivativeOfAttr>
- void checkDerivativeOfAttributeImpl(
- SemanticsVisitor* visitor,
- FunctionDeclBase* funcDecl,
- TDerivativeOfAttr* derivativeOfAttr,
- DeclAssociationKind assocKind)
- {
- DeclRef<Decl> calleeDeclRef;
- DeclRefExpr* calleeDeclRefExpr = nullptr;
- DifferentiateExpr* diffFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>();
- diffFuncExpr->baseFunction = derivativeOfAttr->funcExpr;
- diffFuncExpr->loc = derivativeOfAttr->loc;
- Expr* checkedDiffFuncExpr = visitor->dispatchExpr(diffFuncExpr, *visitor);
- if (!checkedDiffFuncExpr)
- {
- visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
- return;
- }
- List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc);
- auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedDiffFuncExpr, imaginaryArgs);
- auto resolved = visitor->ResolveInvoke(invokeExpr);
- if (auto resolvedInvoke = as<InvokeExpr>(resolved))
- {
- auto resolvedDiffFuncExpr = as<DifferentiateExpr>(resolvedInvoke->functionExpr);
- if (resolvedDiffFuncExpr)
- calleeDeclRefExpr = as<DeclRefExpr>(resolvedDiffFuncExpr->baseFunction);
- }
-
- if (!calleeDeclRefExpr)
- {
- visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
- return;
- }
- calleeDeclRef = calleeDeclRefExpr->declRef;
-
- auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl());
- if (!calleeFunc)
- {
- visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
- return;
- }
-
- if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc))
- {
- // The primal function already has a `[*Derivative]` attribute, this is invalid.
- visitor->getSink()->diagnose(
- derivativeOfAttr,
- Diagnostics::declAlreadyHasAttribute,
- calleeDeclRef,
- getDerivativeAttrName<TDerivativeAttr>());
- visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl());
- }
- derivativeOfAttr->funcExpr = calleeDeclRefExpr;
- auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>();
- derivativeAttr->loc = derivativeOfAttr->loc;
- auto outterGeneric = visitor->GetOuterGeneric(funcDecl);
- auto declRef =
- DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr);
- auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr);
- declRefExpr->type.type = nullptr;
- derivativeAttr->args.add(declRefExpr);
- derivativeAttr->funcExpr = declRefExpr;
- checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr);
- derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr;
- derivativeAttr->funcExpr = nullptr;
- visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl);
- }
-
- static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr)
- {
- if (!attr->funcExpr)
- return;
- if (attr->funcExpr->type.type)
- return;
-
- List<Expr*> imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
- }
-
- static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr)
- {
- if (!attr->funcExpr)
- return;
- if (attr->funcExpr->type.type)
- return;
-
- List<Expr*> imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
- }
-
- template<typename TDerivativeAttr, typename TDerivativeOfAttr>
- bool tryCheckDerivativeOfAttributeImpl(
- SemanticsVisitor* visitor,
- FunctionDeclBase* funcDecl,
- TDerivativeOfAttr* derivativeOfAttr,
- DeclAssociationKind assocKind,
- const List<Expr*>& imaginaryArgsToOriginal)
- {
- DiagnosticSink tempSink(visitor->getSourceManager(), nullptr);
- SemanticsVisitor subVisitor(visitor->withSink(&tempSink));
- checkDerivativeOfAttributeImpl<TDerivativeAttr>(
- &subVisitor,
- funcDecl,
- derivativeOfAttr,
- assocKind,
- imaginaryArgsToOriginal);
- return tempSink.getErrorCount() == 0;
- }
-
- void SemanticsDeclBodyVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl)
- {
- auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>();
- if (!attr)
- return;
-
- checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute, ForwardDifferentiateExpr>(
- this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc);
- }
-
- void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl)
- {
- auto attr = funcDecl->findModifier<BackwardDerivativeOfAttribute>();
- if (!attr)
- return;
-
- checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute, BackwardDifferentiateExpr>(
- this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc);
- }
-
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
{
auto newContext = withParentFunc(decl);
- // Run checking on attributes that can't be fully checked in header checking stage.
- checkForwardDerivativeOfAttribute(decl);
- if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>())
- checkDerivativeAttribute(this, decl, derivativeAttr);
- checkBackwardDerivativeOfAttribute(decl);
- if (auto derivativeAttr = decl->findModifier<BackwardDerivativeAttribute>())
- checkDerivativeAttribute(this, decl, derivativeAttr);
-
if (newContext.getParentDifferentiableAttribute())
{
// Register additional types outside the function body first.
@@ -6762,7 +6518,7 @@ namespace Slang
/// Note: this function creates an empty list of candidates for the given type if
/// a matching entry doesn't exist already.
///
- static List<DeclAssociation>& _getDeclAssociationList(
+ static List<RefPtr<DeclAssociation>>& _getDeclAssociationList(
Decl* decl,
OrderedDictionary<Decl*, RefPtr<DeclAssociationList>>& mapDeclToDeclarations)
{
@@ -6787,14 +6543,16 @@ namespace Slang
void SharedSemanticsContext::registerAssociatedDecl(Decl* original, DeclAssociationKind kind, Decl* associated)
{
auto moduleDecl = getModuleDecl(associated);
- DeclAssociation assoc = {kind, associated};
+ RefPtr<DeclAssociation> assoc = new DeclAssociation();
+ assoc->kind = kind;
+ assoc->decl = associated;
_getDeclAssociationList(original, moduleDecl->mapDeclToAssociatedDecls).add(assoc);
m_associatedDeclListsBuilt = false;
m_mapDeclToAssociatedDecls.Clear();
}
- List<DeclAssociation> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl)
+ List<RefPtr<DeclAssociation>> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl)
{
// This duplicates the exact same logic from `getCandidateExtensionsForTypeDecl`.
// Consider refactoring them into the same framework.
@@ -6838,6 +6596,23 @@ namespace Slang
FunctionDifferentiableLevel SharedSemanticsContext::getFuncDifferentiableLevel(FunctionDeclBase* func)
{
+ return _getFuncDifferentiableLevelImpl(func, 1);
+ }
+
+ FunctionDifferentiableLevel SharedSemanticsContext::_getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit)
+ {
+ if (recurseLimit > 0)
+ {
+ if (auto primalSubst = func->findModifier<PrimalSubstituteAttribute>())
+ {
+ if (auto declRefExpr = as<DeclRefExpr>(primalSubst->funcExpr))
+ {
+ if (auto primalSubstFunc = declRefExpr->declRef.as<FunctionDeclBase>())
+ return _getFuncDifferentiableLevelImpl(primalSubstFunc, recurseLimit - 1);
+ }
+ }
+ }
+
if (func->findModifier<BackwardDifferentiableAttribute>())
return FunctionDifferentiableLevel::Backward;
if (func->findModifier<BackwardDerivativeAttribute>())
@@ -6849,13 +6624,19 @@ namespace Slang
for (auto assocDecl : getAssociatedDeclsForDecl(func))
{
- switch (assocDecl.kind)
+ switch (assocDecl->kind)
{
case DeclAssociationKind::BackwardDerivativeFunc:
return FunctionDifferentiableLevel::Backward;
case DeclAssociationKind::ForwardDerivativeFunc:
diffLevel = FunctionDifferentiableLevel::Forward;
break;
+ case DeclAssociationKind::PrimalSubstituteFunc:
+ if (auto assocFunc = as<FunctionDeclBase>(assocDecl->decl))
+ {
+ return _getFuncDifferentiableLevelImpl(assocFunc, recurseLimit - 1);
+ }
+ break;
default:
break;
}
@@ -6971,6 +6752,10 @@ namespace Slang
SemanticsDeclDifferentialConformanceVisitor(shared).dispatch(decl);
break;
+ case DeclCheckState::AttributesChecked:
+ SemanticsDeclAttributesVisitor(shared).dispatch(decl);
+ break;
+
case DeclCheckState::Checked:
SemanticsDeclBodyVisitor(shared).dispatch(decl);
break;
@@ -7058,4 +6843,297 @@ namespace Slang
return val;
}
+
+ template<typename TDerivativeAttr>
+ void checkDerivativeAttributeImpl(
+ SemanticsVisitor* visitor,
+ TDerivativeAttr* attr,
+ const List<Expr*>& imaginaryArguments)
+ {
+ SemanticsContext::ExprLocalScope scope;
+ auto ctx = visitor->withExprLocalScope(&scope);
+ auto subVisitor = SemanticsVisitor(ctx);
+ auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx);
+ auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments);
+ auto resolved = subVisitor.ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
+ {
+ attr->funcExpr = calleeDeclRef;
+ return;
+ }
+ }
+ visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ }
+
+ template<typename TDerivativeAttr>
+ const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); }
+
+ template<>
+ const char* getDerivativeAttrName<ForwardDerivativeAttribute>()
+ {
+ return "ForwardDerivative";
+ }
+ template<>
+ const char* getDerivativeAttrName<BackwardDerivativeAttribute>()
+ {
+ return "BackwardDerivative";
+ }
+ template<>
+ const char* getDerivativeAttrName<PrimalSubstituteAttribute>()
+ {
+ return "PrimalSubstitute";
+ }
+
+ List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc)
+ {
+ List<Expr*> imaginaryArguments;
+ for (auto param : func->getParameters())
+ {
+ auto arg = astBuilder->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = loc;
+ imaginaryArguments.add(arg);
+ }
+ return imaginaryArguments;
+ }
+
+ List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
+ {
+ List<Expr*> imaginaryArguments;
+ for (auto param : originalFuncDecl->getParameters())
+ {
+ auto arg = visitor->getASTBuilder()->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = loc;
+ if (auto pairType = visitor->getDifferentialPairType(param->getType()))
+ {
+ arg->type.type = pairType;
+ }
+ imaginaryArguments.add(arg);
+ }
+ return imaginaryArguments;
+ }
+
+ List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
+ {
+ List<Expr*> imaginaryArguments;
+ auto isOutParam = [&](ParamDecl* param)
+ {
+ return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr;
+ };
+
+ for (auto param : originalFuncDecl->getParameters())
+ {
+ auto arg = visitor->getASTBuilder()->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = loc;
+ if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType())))
+ {
+ arg->type.type = pairType;
+ if (isOutParam(param))
+ {
+ // out T -> in T.Differential
+ arg->type.isLeftValue = false;
+ arg->type.type = visitor->tryGetDifferentialType(
+ visitor->getASTBuilder(), pairType->getPrimalType());
+ }
+ }
+ else
+ {
+ if (isOutParam(param))
+ {
+ // Skip non-differentiable out params.
+ continue;
+ }
+ }
+ imaginaryArguments.add(arg);
+ }
+ if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type))
+ {
+ auto arg = visitor->getASTBuilder()->create<InitializerListExpr>();
+ arg->type.isLeftValue = false;
+ arg->type.type = diffReturnType;
+ arg->loc = loc;
+ imaginaryArguments.add(arg);
+ }
+ return imaginaryArguments;
+ }
+
+ // This helper function is needed to workaround a gcc bug.
+ // Remove when we upgrade to a newer version of gcc.
+ template <typename T>
+ static T* _findModifier(Decl* decl)
+ {
+ return decl->findModifier<T>();
+ }
+
+ template <typename TDerivativeAttr, typename TDifferentiateExpr, typename TDerivativeOfAttr>
+ void checkDerivativeOfAttributeImpl(
+ SemanticsVisitor* visitor,
+ FunctionDeclBase* funcDecl,
+ TDerivativeOfAttr* derivativeOfAttr,
+ DeclAssociationKind assocKind)
+ {
+ DeclRef<Decl> calleeDeclRef;
+ DeclRefExpr* calleeDeclRefExpr = nullptr;
+ HigherOrderInvokeExpr* higherOrderFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>();
+ higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr;
+ higherOrderFuncExpr->loc = derivativeOfAttr->loc;
+ Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, *visitor);
+ if (!checkedHigherOrderFuncExpr)
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
+ return;
+ }
+ List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc);
+ auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedHigherOrderFuncExpr, imaginaryArgs);
+ SemanticsContext::ExprLocalScope scope;
+ auto ctx = visitor->withExprLocalScope(&scope);
+ auto subVisitor = SemanticsVisitor(ctx);
+ auto resolved = subVisitor.ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ auto resolvedFuncExpr = as<HigherOrderInvokeExpr>(resolvedInvoke->functionExpr);
+ if (resolvedFuncExpr)
+ calleeDeclRefExpr = as<DeclRefExpr>(resolvedFuncExpr->baseFunction);
+ }
+
+ if (!calleeDeclRefExpr)
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
+ return;
+ }
+ calleeDeclRef = calleeDeclRefExpr->declRef;
+
+ auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl());
+ if (!calleeFunc)
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
+ return;
+ }
+
+ if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc))
+ {
+ // The primal function already has a `[*Derivative]` attribute, this is invalid.
+ visitor->getSink()->diagnose(
+ derivativeOfAttr,
+ Diagnostics::declAlreadyHasAttribute,
+ calleeDeclRef,
+ getDerivativeAttrName<TDerivativeAttr>());
+ visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl());
+ }
+
+ derivativeOfAttr->funcExpr = calleeDeclRefExpr;
+ auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>();
+ derivativeAttr->loc = derivativeOfAttr->loc;
+ auto outterGeneric = visitor->GetOuterGeneric(funcDecl);
+ auto declRef =
+ DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr);
+ auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr);
+ declRefExpr->type.type = nullptr;
+ derivativeAttr->args.add(declRefExpr);
+ derivativeAttr->funcExpr = declRefExpr;
+ checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr);
+ derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr;
+ derivativeAttr->funcExpr = nullptr;
+ visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl);
+ }
+
+ static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr)
+ {
+ if (!attr->funcExpr)
+ return;
+ if (attr->funcExpr->type.type)
+ return;
+
+ List<Expr*> imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ }
+
+ static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr)
+ {
+ if (!attr->funcExpr)
+ return;
+ if (attr->funcExpr->type.type)
+ return;
+
+ List<Expr*> imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ }
+
+ static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr)
+ {
+ if (!attr->funcExpr)
+ return;
+ if (attr->funcExpr->type.type)
+ return;
+
+ List<Expr*> imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc);
+ checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments);
+ }
+
+ template<typename TDerivativeAttr, typename TDerivativeOfAttr>
+ bool tryCheckDerivativeOfAttributeImpl(
+ SemanticsVisitor* visitor,
+ FunctionDeclBase* funcDecl,
+ TDerivativeOfAttr* derivativeOfAttr,
+ DeclAssociationKind assocKind,
+ const List<Expr*>& imaginaryArgsToOriginal)
+ {
+ DiagnosticSink tempSink(visitor->getSourceManager(), nullptr);
+ SemanticsVisitor subVisitor(visitor->withSink(&tempSink));
+ checkDerivativeOfAttributeImpl<TDerivativeAttr>(
+ &subVisitor,
+ funcDecl,
+ derivativeOfAttr,
+ assocKind,
+ imaginaryArgsToOriginal);
+ return tempSink.getErrorCount() == 0;
+ }
+
+ void SemanticsDeclAttributesVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeOfAttribute* attr)
+ {
+ checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute, ForwardDifferentiateExpr>(
+ this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc);
+ }
+
+ void SemanticsDeclAttributesVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, BackwardDerivativeOfAttribute* attr)
+ {
+ checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute, BackwardDifferentiateExpr>(
+ this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc);
+ }
+
+ void SemanticsDeclAttributesVisitor::checkPrimalSubstituteOfAttribute(FunctionDeclBase* funcDecl, PrimalSubstituteOfAttribute* attr)
+ {
+ checkDerivativeOfAttributeImpl<PrimalSubstituteAttribute, PrimalSubstituteExpr>(
+ this, funcDecl, attr, DeclAssociationKind::PrimalSubstituteFunc);
+ }
+
+ void SemanticsDeclAttributesVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
+ {
+ // Run checking on attributes that can't be fully checked in header checking stage.
+ for (auto attr : decl->modifiers)
+ {
+ if (auto fwdDerivativeOfAttr = as<ForwardDerivativeOfAttribute>(attr))
+ checkForwardDerivativeOfAttribute(decl, fwdDerivativeOfAttr);
+ else if (auto bwdDerivativeOfAttr = as<BackwardDerivativeOfAttribute>(attr))
+ checkBackwardDerivativeOfAttribute(decl, bwdDerivativeOfAttr);
+ else if (auto primalOfAttr = as<PrimalSubstituteOfAttribute>(attr))
+ checkPrimalSubstituteOfAttribute(decl, primalOfAttr);
+ else if (auto fwdDerivativeAttr = as<ForwardDerivativeAttribute>(attr))
+ checkDerivativeAttribute(this, decl, fwdDerivativeAttr);
+ else if (auto bwdDerivativeAttr = as<BackwardDerivativeAttribute>(attr))
+ checkDerivativeAttribute(this, decl, bwdDerivativeAttr);
+ else if (auto primalAttr = as<PrimalSubstituteAttribute>(attr))
+ checkDerivativeAttribute(this, decl, primalAttr);
+ }
+ }
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index bebaa63a2..f749361d7 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2227,10 +2227,10 @@ namespace Slang
return type;
}
- struct DifferentiateExprCheckingActions
+ struct HigherOrderInvokeExprCheckingActions
{
- virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) = 0;
- virtual void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0;
+ virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) = 0;
+ virtual void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0;
FuncType* getBaseFunctionType(SemanticsVisitor* semantics, Expr* funcExpr)
{
if (auto funcType = as<FuncType>(funcExpr->type.type))
@@ -2255,13 +2255,13 @@ namespace Slang
}
};
- struct ForwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions
+ struct ForwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions
{
- virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override
+ virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override
{
return semantics->getASTBuilder()->create<ForwardDifferentiateExpr>();
}
- void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
+ void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
{
resultDiffExpr->baseFunction = funcExpr;
auto baseFuncType = getBaseFunctionType(semantics, funcExpr);
@@ -2290,13 +2290,13 @@ namespace Slang
}
};
- struct BackwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions
+ struct BackwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions
{
- virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override
+ virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override
{
return semantics->getASTBuilder()->create<BackwardDifferentiateExpr>();
}
- void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
+ void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
{
resultDiffExpr->baseFunction = funcExpr;
auto baseFuncType = getBaseFunctionType(semantics, funcExpr);
@@ -2333,10 +2333,45 @@ namespace Slang
}
};
- static Expr* _checkDifferentiateExpr(
+ struct PrimalSubstituteExprCheckingActions : HigherOrderInvokeExprCheckingActions
+ {
+ virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override
+ {
+ return semantics->getASTBuilder()->create<PrimalSubstituteExpr>();
+ }
+ void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
+ {
+ resultDiffExpr->baseFunction = funcExpr;
+ auto baseFuncType = getBaseFunctionType(semantics, funcExpr);
+ if (!baseFuncType)
+ {
+ resultDiffExpr->type = semantics->getASTBuilder()->getErrorType();
+ semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type);
+ return;
+ }
+ resultDiffExpr->type = baseFuncType;
+ if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr)))
+ {
+ auto funcDecl = declRefExpr->declRef.as<CallableDecl>().getDecl();
+ if (auto genDecl = as<GenericDecl>(declRefExpr->declRef.getDecl()))
+ {
+ funcDecl = as<CallableDecl>(genDecl->inner);
+ }
+ if (funcDecl)
+ {
+ for (auto param : funcDecl->getParameters())
+ {
+ resultDiffExpr->newParameterNames.add(param->getName());
+ }
+ }
+ }
+ }
+ };
+
+ static Expr* _checkHigherOrderInvokeExpr(
SemanticsVisitor* semantics,
- DifferentiateExpr* expr,
- DifferentiateExprCheckingActions* actions)
+ HigherOrderInvokeExpr* expr,
+ HigherOrderInvokeExprCheckingActions* actions)
{
// Check/Resolve inner function declaration.
expr->baseFunction = semantics->CheckTerm(expr->baseFunction);
@@ -2354,8 +2389,8 @@ namespace Slang
nullptr,
overloadedExpr->loc,
nullptr);
- auto candidateExpr = actions->createDifferentiateExpr(semantics);
- actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr);
+ auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics);
+ actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, lookupResultExpr);
candidateExpr->loc = expr->loc;
result->candidiateExprs.add(candidateExpr);
}
@@ -2368,8 +2403,8 @@ namespace Slang
OverloadedExpr2* result = astBuilder->create<OverloadedExpr2>();
for (auto item : overloadedExpr2->candidiateExprs)
{
- auto candidateExpr = actions->createDifferentiateExpr(semantics);
- actions->fillDifferentiateExpr(candidateExpr, semantics, item);
+ auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics);
+ actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, item);
candidateExpr->loc = expr->loc;
result->candidiateExprs.add(candidateExpr);
}
@@ -2378,20 +2413,26 @@ namespace Slang
return result;
}
- actions->fillDifferentiateExpr(expr, semantics, expr->baseFunction);
+ actions->fillHigherOrderInvokeExpr(expr, semantics, expr->baseFunction);
return expr;
}
Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr)
{
ForwardDifferentiateExprCheckingActions actions;
- return _checkDifferentiateExpr(this, expr, &actions);
+ return _checkHigherOrderInvokeExpr(this, expr, &actions);
}
Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr)
{
BackwardDifferentiateExprCheckingActions actions;
- return _checkDifferentiateExpr(this, expr, &actions);
+ return _checkHigherOrderInvokeExpr(this, expr, &actions);
+ }
+
+ Expr* SemanticsExprVisitor::visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr)
+ {
+ PrimalSubstituteExprCheckingActions actions;
+ return _checkHigherOrderInvokeExpr(this, expr, &actions);
}
Expr* SemanticsExprVisitor::visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr)
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index fc1b622cc..3d40b10e9 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -287,10 +287,11 @@ namespace Slang
void registerAssociatedDecl(Decl* original, DeclAssociationKind assoc, Decl* declaration);
- List<DeclAssociation> const& getAssociatedDeclsForDecl(Decl* decl);
+ List<RefPtr<DeclAssociation>> const& getAssociatedDeclsForDecl(Decl* decl);
bool isDifferentiableFunc(FunctionDeclBase* func);
bool isBackwardDifferentiableFunc(FunctionDeclBase* func);
+ FunctionDifferentiableLevel _getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit);
FunctionDifferentiableLevel getFuncDifferentiableLevel(FunctionDeclBase* func);
private:
@@ -1951,6 +1952,8 @@ namespace Slang
Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr);
Expr* visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr);
+ Expr* visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr);
+
Expr* visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr);
Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr);
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index e6a524645..a068f19d6 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -654,33 +654,23 @@ namespace Slang
hitObjectAttributesAttr->location = (int32_t)val->value;
}
- else if (auto derivativeAttr = as<UserDefinedDerivativeAttribute>(attr))
+ else if (as<UserDefinedDerivativeAttribute>(attr) || as<PrimalSubstituteAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
SLANG_ASSERT(as<Decl>(attrTarget));
-
- // Ensure that the argument is a reference to a function definition or declaration.
- auto diffExpr = CheckTerm(attr->args[0]);
- if (diffExpr->type == getASTBuilder()->getErrorType())
- {
- // Could not resolve the term.
- getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget));
- return false;
- }
- // We store the partially checked funcExpr in the attribute, and
- // rely on `ResolveInvoke` to resolve it to the actual function decl.
- // The call to `ResolveInvoke` is deferred until we are checking the
- // body of the function.
- //
- // Set type to null to indicate that this needs expr needs to be further resolved.
- diffExpr->type.type = nullptr;
- derivativeAttr->funcExpr = diffExpr;
+ if (auto derivativeAttr = as<UserDefinedDerivativeAttribute>(attr))
+ derivativeAttr->funcExpr = attr->args[0];
+ else if (auto primalSubstAttr = as<PrimalSubstituteAttribute>(attr))
+ primalSubstAttr->funcExpr = attr->args[0];
}
- else if (auto derivativeOfAttr = as<DerivativeOfAttribute>(attr))
+ else if (as<DerivativeOfAttribute>(attr) || as<PrimalSubstituteOfAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
SLANG_ASSERT(as<Decl>(attrTarget));
- derivativeOfAttr->funcExpr = attr->args[0];
+ if (auto derivativeOfAttr = as<DerivativeOfAttribute>(attr))
+ derivativeOfAttr->funcExpr = attr->args[0];
+ else if (auto primalOfAttr = as<PrimalSubstituteOfAttribute>(attr))
+ primalOfAttr->funcExpr = attr->args[0];
}
else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr))
{
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 91af731ad..f786089f8 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1549,7 +1549,7 @@ namespace Slang
// Base is a normal or fully specialized generic function.
OverloadCandidate candidate;
candidate.flavor = OverloadCandidate::Flavor::Expr;
- if (auto diffExpr = as<DifferentiateExpr>(expr))
+ if (auto diffExpr = as<HigherOrderInvokeExpr>(expr))
{
candidate.funcType = as<FuncType>(diffExpr->type.type);
}
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 27106b6a2..2090cd4dc 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -344,6 +344,7 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst*
static bool _isDifferentiableFunc(IRInst* func)
{
+ func = getResolvedInstForDecorations(func);
for (auto decor = func->getFirstDecoration(); decor; decor = decor->getNextDecoration())
{
switch (decor->getOp())
@@ -369,6 +370,37 @@ static IRFuncType* _getCalleeActualFuncType(IRInst* callee)
return nullptr;
}
+IRInst* tryFindPrimalSubstitute(IRBuilder* builder, IRInst* callee)
+{
+ if (auto func = as<IRFunc>(callee))
+ {
+ if (auto decor = func->findDecoration<IRPrimalSubstituteDecoration>())
+ return decor->getPrimalSubstituteFunc();
+ }
+ else if (auto specialize = as<IRSpecialize>(callee))
+ {
+ auto innerGen = as<IRGeneric>(specialize->getBase());
+ if (!innerGen)
+ return nullptr;
+ auto innerFunc = findGenericReturnVal(innerGen);
+ if (auto decor = innerFunc->findDecoration<IRPrimalSubstituteDecoration>())
+ {
+ auto substSpecialize = as<IRSpecialize>(decor->getPrimalSubstituteFunc());
+ SLANG_RELEASE_ASSERT(substSpecialize);
+ SLANG_RELEASE_ASSERT(substSpecialize->getArgCount() == specialize->getArgCount());
+ List<IRInst*> args;
+ for (UInt i = 0; i < specialize->getArgCount(); i++)
+ args.add(specialize->getArg(i));
+ return builder->emitSpecializeInst(
+ callee->getFullType(),
+ substSpecialize->getBase(),
+ (UInt)args.getCount(),
+ args.getBuffer());
+ }
+ }
+ return callee;
+}
+
// Differentiating a call instruction here is primarily about generating
// an appropriate call list based on whichever parameters have differentials
// in the current transcription context.
@@ -393,10 +425,20 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
}
auto primalCallee = lookupPrimalInst(builder, origCallee, origCallee);
+ auto substPrimalCallee = tryFindPrimalSubstitute(builder, primalCallee);
IRInst* diffCallee = nullptr;
+ if (substPrimalCallee == primalCallee)
+ {
+ instMapD.TryGetValue(origCallee, diffCallee);
+ }
+ else
+ {
+ instMapD.TryGetValue(substPrimalCallee, diffCallee);
+ primalCallee = substPrimalCallee;
+ }
- if (instMapD.TryGetValue(origCallee, diffCallee))
+ if (diffCallee)
{
}
else if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>())
@@ -750,6 +792,9 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec
// Make sure this isn't itself a specialize .
SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc));
+ auto derivativeDecoration = genericInnerVal->findDecoration<IRForwardDerivativeDecoration>();
+ SLANG_RELEASE_ASSERT(derivativeDecoration);
+
return InstPair(primalSpecialize, jvpFunc);
}
else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRForwardDerivativeDecoration>())
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 709968f77..d7cce7c53 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -736,7 +736,7 @@ namespace Slang
// We need to insert a local variable to store this var.
IRInst* operandReplacement = nullptr;
- if (canInstBeStored(operand))
+ if (canTypeBeStored(operand->getDataType()))
{
auto var = storeInstAsLocalVar(operand);
builder.setInsertBefore(inst);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 73d9b6ba6..ed122c862 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -72,6 +72,14 @@ bool AutoDiffTranscriberBase::shouldUseOriginalAsPrimal(IRInst* currentParent, I
return true;
if (isChildInstOf(currentParent, origInst->getParent()))
return true;
+
+ // If origInst is defined in the first block of the same function as current inst (e.g. a param),
+ // we can use it as primal.
+ // More generally, we should test if origInst dominates currentParent, but that requires calculating
+ // a dom tree on the fly. Right now just testing if it is first block for parameters seems sufficient.
+ auto parentFunc = getParentFunc(currentParent);
+ if (parentFunc && origInst->parent == parentFunc->getFirstBlock())
+ return true;
return false;
}
@@ -802,6 +810,7 @@ static void _markGenericChildrenWithoutRelaventUse(IRGeneric* origGeneric, HashS
case kIROp_BackwardDerivativePrimalContextDecoration:
case kIROp_BackwardDerivativePrimalDecoration:
case kIROp_BackwardDerivativePropagateDecoration:
+ case kIROp_PrimalSubstituteDecoration:
break;
default:
if (!outInstsToSkip.Contains(use->getUser()))
@@ -876,6 +885,32 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene
return InstPair(primalGeneric, diffGeneric);
}
+IRInst* getActualInstToTranscribe(IRInst* inst)
+{
+ if (auto gen = as<IRGeneric>(inst))
+ {
+ auto retVal = findGenericReturnVal(gen);
+ if (retVal->getOp() != kIROp_Func)
+ return inst;
+ if (auto primalSubst = retVal->findDecoration<IRPrimalSubstituteDecoration>())
+ {
+ auto spec = as<IRSpecialize>(primalSubst->getPrimalSubstituteFunc());
+ SLANG_RELEASE_ASSERT(spec);
+ return spec->getBase();
+ }
+ }
+ else if (auto func = as<IRFunc>(inst))
+ {
+ if (auto primalSubst = func->findDecoration<IRPrimalSubstituteDecoration>())
+ {
+ auto actualFunc = as<IRFunc>(primalSubst->getPrimalSubstituteFunc());
+ SLANG_RELEASE_ASSERT(actualFunc);
+ return actualFunc;
+ }
+ }
+ return inst;
+}
+
IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst)
{
// If a differential instruction is already mapped for
@@ -891,8 +926,8 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
// depending on the op-code.
//
instsInProgress.Add(origInst);
-
- InstPair pair = transcribeInst(builder, origInst);
+ auto actualInstToTranscribe = getActualInstToTranscribe(origInst);
+ InstPair pair = transcribeInst(builder, actualInstToTranscribe);
instsInProgress.Remove(origInst);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index a05fe7044..2347c7a8f 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -197,8 +197,18 @@ struct ExtractPrimalFuncContext
bool shouldStoreVar(IRVar* var)
{
// Always store intermediate context var.
- if (var->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
+ if (auto typeDecor = var->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
{
+ // If we are specializing a callee's intermediate context with types that can't be stored,
+ // we can't store the entire context.
+ if (auto spec = as<IRSpecialize>(as<IRPtrTypeBase>(var->getDataType())->getValueType()))
+ {
+ for (UInt i = 0; i < spec->getArgCount(); i++)
+ {
+ if (!canTypeBeStored(spec->getArg(i)->getDataType()))
+ return false;
+ }
+ }
return true;
}
@@ -212,7 +222,7 @@ struct ExtractPrimalFuncContext
// 2. Does the var have a store
//
- return (doesInstHaveDiffUse(var) && doesInstHaveStore(var));
+ return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType()));
}
bool shouldStoreInst(IRInst* inst)
@@ -222,7 +232,7 @@ struct ExtractPrimalFuncContext
return false;
}
- if (!canInstBeStored(inst))
+ if (!canTypeBeStored(inst->getDataType()))
return false;
// Never store certain opcodes.
@@ -246,6 +256,9 @@ struct ExtractPrimalFuncContext
case kIROp_MakeOptionalValue:
case kIROp_DifferentialPairGetDifferential:
case kIROp_DifferentialPairGetPrimal:
+ case kIROp_ExtractExistentialValue:
+ case kIROp_ExtractExistentialType:
+ case kIROp_ExtractExistentialWitnessTable:
return false;
case kIROp_GetElement:
case kIROp_FieldExtract:
@@ -560,7 +573,12 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
if (inst->getOp() == kIROp_Call)
{
// The primal calls should be marked as no side effect so they can be DCE'd if possible.
- builder.addSimpleDecoration<IRNoSideEffectDecoration>(inst);
+ // We can only do so if the intermediate context of the callee is stored.
+ if (primalCtx->getBackwardDerivativePrimalContextVar()
+ ->findDecoration<IRPrimalValueStructKeyDecoration>())
+ {
+ builder.addSimpleDecoration<IRNoSideEffectDecoration>(inst);
+ }
}
}
}
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index fcfbf3bee..65e880868 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -563,12 +563,15 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst*
return false;
}
-bool canInstBeStored(IRInst* inst)
+bool canTypeBeStored(IRInst* type)
{
- if (as<IRBasicType>(inst->getDataType()))
+ if (!type)
+ return false;
+
+ if (as<IRBasicType>(type))
return true;
- switch (inst->getDataType()->getOp())
+ switch (type->getOp())
{
case kIROp_StructType:
case kIROp_OptionalType:
@@ -716,6 +719,9 @@ struct AutoDiffPass : public InstPassBase
break;
}
break;
+ case kIROp_PrimalSubstitute:
+ // Explicit primal subst operator is not yet supported.
+ SLANG_UNIMPLEMENTED_X("explicit primal_subst operator.");
default:
break;
}
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index a4eb94461..f757375d8 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -295,7 +295,7 @@ bool isBackwardDifferentiableFunc(IRInst* func);
bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst);
-bool canInstBeStored(IRInst* inst);
+bool canTypeBeStored(IRInst* type);
inline bool isRelevantDifferentialPair(IRType* type)
{
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 186b0cc03..14f6394e2 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -39,12 +39,17 @@ public:
return false;
}
-
bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level)
{
func = getResolvedInstForDecorations(func);
if (!func)
return false;
+ if (auto substDecor = func->findDecoration<IRPrimalSubstituteDecoration>())
+ {
+ func = getResolvedInstForDecorations(substDecor->getPrimalSubstituteFunc());
+ if (!func)
+ return false;
+ }
for (auto decorations : func->getDecorations())
{
@@ -84,7 +89,13 @@ public:
if (!func)
return false;
-
+ if (auto substDecor = func->findDecoration<IRPrimalSubstituteDecoration>())
+ {
+ func = getResolvedInstForDecorations(substDecor->getPrimalSubstituteFunc());
+ if (!func)
+ return false;
+ }
+
if (auto existingLevel = differentiableFunctions.TryGetValue(func))
return *existingLevel >= level;
diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp
index e5c9b1fdb..1fe88e780 100644
--- a/source/slang/slang-ir-dce.cpp
+++ b/source/slang/slang-ir-dce.cpp
@@ -379,25 +379,37 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o
//
if (options.keepExportsAlive)
{
- if (inst->findDecoration<IRExportDecoration>())
+ bool isImported = false;
+ bool shouldKeptAliveIfImported = false;
+ IRInst* innerInst = inst;
+ if (auto genInst = as<IRGeneric>(inst))
{
- return true;
+ innerInst = findInnerMostGenericReturnVal(genInst);
}
- if (inst->findDecoration<IRImportDecoration>())
+ for (auto decor : inst->getDecorations())
{
- if (inst->findDecoration<IRForwardDerivativeDecoration>())
- return true;
- if (inst->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
+ switch (decor->getOp())
+ {
+ case kIROp_ExportDecoration:
return true;
- if (auto genInst = as<IRGeneric>(inst))
+ case kIROp_ImportDecoration:
+ isImported = true;
+ break;
+ }
+ }
+ for (auto decor : innerInst->getDecorations())
+ {
+ switch (decor->getOp())
{
- auto inner = findInnerMostGenericReturnVal(genInst);
- if (inner->findDecoration<IRForwardDerivativeDecoration>())
- return true;
- if (inner->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
- return true;
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ case kIROp_PrimalSubstituteDecoration:
+ shouldKeptAliveIfImported = true;
+ break;
}
}
+ if (isImported && shouldKeptAliveIfImported)
+ return true;
}
if (options.keepLayoutsAlive && inst->findDecoration<IRLayoutDecoration>())
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index c704359e6..7411d031c 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -764,6 +764,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// generated derivative function.
INST(BackwardDifferentiableDecoration, backwardDifferentiable, 1, 0)
+ /// Used by the auto-diff pass to hold a reference to the
+ /// primal substitute function.
+ INST(PrimalSubstituteDecoration, primalSubstFunc, 1, 0)
+
/// Decorations to associate an original function with compiler generated backward derivative functions.
INST(BackwardDerivativePrimalDecoration, backwardDiffPrimalReference, 1, 0)
INST(BackwardDerivativePropagateDecoration, backwardDiffPropagateReference, 1, 0)
@@ -882,6 +886,8 @@ INST(BackwardDifferentiatePropagate, BackwardDifferentiatePropagate, 1,
// replaced with `BackwardDifferentiatePrimal` and `BackwardDifferentiatePropagate`.
INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0)
+INST(PrimalSubstitute, PrimalSubstitute, 1, 0)
+
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 5269ae02f..ae31219bd 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -608,6 +608,17 @@ struct IRForwardDerivativeDecoration : IRDecoration
IRInst* getForwardDerivativeFunc() { return getOperand(0); }
};
+struct IRPrimalSubstituteDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_PrimalSubstituteDecoration
+ };
+ IR_LEAF_ISA(PrimalSubstituteDecoration)
+
+ IRInst* getPrimalSubstituteFunc() { return getOperand(0); }
+};
+
struct IRBackwardDerivativeIntermediateTypeDecoration : IRDecoration
{
enum
@@ -879,6 +890,20 @@ struct IRBackwardDifferentiate : IRInst
IR_LEAF_ISA(BackwardDifferentiate)
};
+// Retrieves the primal substitution function for the given function.
+struct IRPrimalSubstitute : IRInst
+{
+ enum
+ {
+ kOp = kIROp_PrimalSubstitute
+ };
+ // The base function for the call.
+ IRUse base;
+ IRInst* getBaseFn() { return getOperand(0); }
+
+ IR_LEAF_ISA(PrimalSubstitute)
+};
+
// Dictionary item mapping a type with a corresponding
// IDifferentiable witness table
//
@@ -2804,6 +2829,7 @@ public:
IRInst* emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn);
IRInst* emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn);
IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn);
+ IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn);
IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential);
@@ -3623,6 +3649,11 @@ public:
addDecoration(value, kIROp_BackwardDerivativePrimalContextDecoration, ctx);
}
+ void addPrimalSubstituteDecoration(IRInst* value, IRInst* jvpFn)
+ {
+ addDecoration(value, kIROp_PrimalSubstituteDecoration, jvpFn);
+ }
+
void addLoopCounterDecoration(IRInst* value)
{
addDecoration(value, kIROp_LoopCounterDecoration);
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index bcff5621c..b976f4b21 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -443,6 +443,7 @@ static void cloneExtraDecorationsFromInst(
case kIROp_SequentialIDDecoration:
case kIROp_ForwardDerivativeDecoration:
case kIROp_UserDefinedBackwardDerivativeDecoration:
+ case kIROp_PrimalSubstituteDecoration:
case kIROp_IntrinsicOpDecoration:
if (!clonedInst->findDecorationImpl(decoration->getOp()))
{
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 0aa2dc607..2819a6d83 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3106,6 +3106,17 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitPrimalSubstituteInst(IRType* type, IRInst* baseFn)
+ {
+ auto inst = createInst<IRPrimalSubstitute>(
+ this,
+ kIROp_PrimalSubstitute,
+ type,
+ baseFn);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn)
{
auto inst = createInst<IRBackwardDifferentiate>(
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 261e08168..d8912cbd4 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3165,6 +3165,17 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
baseVal.val));
}
+ LoweredValInfo visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr)
+ {
+ auto baseVal = lowerSubExpr(expr->baseFunction);
+ SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ return LoweredValInfo::simple(
+ getBuilder()->emitPrimalSubstituteInst(
+ lowerType(context, expr->type),
+ baseVal.val));
+ }
+
LoweredValInfo visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr)
{
auto baseVal = lowerSubExpr(expr->innerExpr);
@@ -7970,14 +7981,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
addNameHint(subContext, irFunc, decl);
addLinkageDecoration(subContext, irFunc, decl);
- if (decl->findModifier<ForwardDifferentiableAttribute>())
- {
- getBuilder()->addForwardDifferentiableDecoration(irFunc);
- }
- if (decl->findModifier<BackwardDifferentiableAttribute>())
- {
- getBuilder()->addBackwardDifferentiableDecoration(irFunc);
- }
if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>())
{
lowerDifferentiableAttribute(subContext, irFunc, differentialAttr);
@@ -8291,156 +8294,156 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addRequireCUDASMVersionDecoration(irFunc, versionMod->version);
}
- if (decl->findModifier<RequiresNVAPIAttribute>())
- {
- getBuilder()->addSimpleDecoration<IRRequiresNVAPIDecoration>(irFunc);
- }
-
- if (decl->findModifier<AlwaysFoldIntoUseSiteAttribute>())
- {
- getBuilder()->addSimpleDecoration<IRAlwaysFoldIntoUseSiteDecoration>(irFunc);
- }
-
- if (decl->findModifier<NoInlineAttribute>())
- {
- getBuilder()->addSimpleDecoration<IRNoInlineDecoration>(irFunc);
- }
-
- if (auto attr = decl->findModifier<InstanceAttribute>())
- {
- IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr);
- getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit);
- }
-
- if (auto attr = decl->findModifier<MaxVertexCountAttribute>())
- {
- IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr);
- getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit);
- }
-
- if (auto attr = decl->findModifier<NumThreadsAttribute>())
- {
- auto builder = getBuilder();
- IRType* intType = builder->getIntType();
-
- IRInst* operands[3] = {
- builder->getIntValue(intType, attr->x),
- builder->getIntValue(intType, attr->y),
- builder->getIntValue(intType, attr->z)
- };
-
- builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3);
- }
-
- if (decl->findModifier<ReadNoneAttribute>())
- {
- getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc);
- }
-
- if (decl->findModifier<EarlyDepthStencilAttribute>())
- {
- getBuilder()->addSimpleDecoration<IREarlyDepthStencilDecoration>(irFunc);
- }
-
- if (auto attr = decl->findModifier<DomainAttribute>())
- {
- IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr);
- getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit);
- }
+ // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute
+ setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc)));
- if (auto attr = decl->findModifier<PartitioningAttribute>())
+ for (auto modifier : decl->modifiers)
{
- IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr);
- getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit);
- }
+ if (as<RequiresNVAPIAttribute>(modifier))
+ {
+ getBuilder()->addSimpleDecoration<IRRequiresNVAPIDecoration>(irFunc);
+ }
+ else if (as<AlwaysFoldIntoUseSiteAttribute>(modifier))
+ {
+ getBuilder()->addSimpleDecoration<IRAlwaysFoldIntoUseSiteDecoration>(irFunc);
+ }
+ else if (as<NoInlineAttribute>(modifier))
+ {
+ getBuilder()->addSimpleDecoration<IRNoInlineDecoration>(irFunc);
+ }
+ else if (auto instanceAttr = as<InstanceAttribute>(modifier))
+ {
+ IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), instanceAttr);
+ getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit);
+ }
+ else if (auto maxVertCountAttr = as<MaxVertexCountAttribute>(modifier))
+ {
+ IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), maxVertCountAttr);
+ getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit);
+ }
+ else if (auto numThreadsAttr = as<NumThreadsAttribute>(modifier))
+ {
+ auto builder = getBuilder();
+ IRType* intType = builder->getIntType();
- if (auto attr = decl->findModifier<OutputTopologyAttribute>())
- {
- IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr);
- getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit);
- }
+ IRInst* operands[3] = {
+ builder->getIntValue(intType, numThreadsAttr->x),
+ builder->getIntValue(intType, numThreadsAttr->y),
+ builder->getIntValue(intType, numThreadsAttr->z)
+ };
- if (auto attr = decl->findModifier<OutputControlPointsAttribute>())
- {
- IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr);
- getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit);
- }
+ builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3);
+ }
+ else if (as<ReadNoneAttribute>(modifier))
+ {
+ getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc);
+ }
+ else if (as<EarlyDepthStencilAttribute>(modifier))
+ {
+ getBuilder()->addSimpleDecoration<IREarlyDepthStencilDecoration>(irFunc);
+ }
+ else if (auto domainAttr = as<DomainAttribute>(modifier))
+ {
+ IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), domainAttr);
+ getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit);
+ }
+ else if (auto partitionAttr = as<PartitioningAttribute>(modifier))
+ {
+ IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), partitionAttr);
+ getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit);
+ }
+ else if (auto outputTopAttr = as<OutputTopologyAttribute>(modifier))
+ {
+ IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), outputTopAttr);
+ getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit);
+ }
+ else if (auto outputCtrlPtAttr = as<OutputControlPointsAttribute>(modifier))
+ {
+ IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), outputCtrlPtAttr);
+ getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit);
+ }
+ else if (auto spvInstOpAttr = as<SPIRVInstructionOpAttribute>(modifier))
+ {
+ auto builder = getBuilder();
+ IRIntLit* intLit = _getIntLitFromAttribute(builder, spvInstOpAttr, 0);
- if (auto attr = decl->findModifier<SPIRVInstructionOpAttribute>())
- {
- auto builder = getBuilder();
- IRIntLit* intLit = _getIntLitFromAttribute(builder, attr, 0);
+ IRStringLit* setStringLit = nullptr;
+ if (spvInstOpAttr->args.getCount() > 1)
+ {
+ IRStringLit* checkSetStringLit = _getStringLitFromAttribute(builder, spvInstOpAttr, 1);
+ if (checkSetStringLit && checkSetStringLit->getStringSlice().getLength() > 0)
+ {
+ setStringLit = checkSetStringLit;
+ }
+ }
- IRStringLit* setStringLit = nullptr;
- if (attr->args.getCount() > 1)
- {
- IRStringLit* checkSetStringLit = _getStringLitFromAttribute(builder, attr, 1);
- if (checkSetStringLit && checkSetStringLit->getStringSlice().getLength() > 0)
+ // If it has a `set` defined, set it on the decoration
+ if (setStringLit)
{
- setStringLit = checkSetStringLit;
+ builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit, setStringLit);
+ }
+ else
+ {
+ builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit);
}
}
-
- // If it has a `set` defined, set it on the decoration
- if (setStringLit)
+ else if (as<UnsafeForceInlineEarlyAttribute>(modifier))
{
- builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit, setStringLit);
+ getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration);
}
- else
+ else if (as<ForceInlineAttribute>(modifier))
{
- builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit);
+ getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
}
- }
-
- if (decl->findModifier<UnsafeForceInlineEarlyAttribute>())
- {
- getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration);
- }
-
- if (decl->findModifier<ForceInlineAttribute>())
- {
- getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
- }
-
- if (decl->findModifier<TreatAsDifferentiableAttribute>())
- {
- getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration);
- }
-
- if (auto intrinsicOp = decl->findModifier<IntrinsicOpModifier>())
- {
- auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op);
- getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op);
- }
-
- // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute
- setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc)));
-
- if (auto attr = decl->findModifier<UserDefinedDerivativeAttribute>())
- {
- // We need to lower the decl ref to the custom derivative function to IR.
- // The IR insts correspond to the decl ref is not part of the function we
- // are processing. If we emit it directly to within the function, it could
- // mess up the assumption on the form of the IR (e.g. having non decoration insts
- // appearing in the middle of decoration insts). so we emit the decl ref to the
- // function's parent for now.
-
- subContext->irBuilder->setInsertInto(irFunc->getParent());
-
- auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr);
-
- SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
- IRInst* derivativeFunc = loweredVal.val;
-
- if (as<ForwardDerivativeAttribute>(attr))
- getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc);
- else
- getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc);
+ else if (as<TreatAsDifferentiableAttribute>(modifier))
+ {
+ getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration);
+ }
+ else if (auto intrinsicOp = as<IntrinsicOpModifier>(modifier))
+ {
+ auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op);
+ getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op);
+ }
+ else if (as<UserDefinedDerivativeAttribute>(modifier) || as<PrimalSubstituteAttribute>(modifier))
+ {
+ // We need to lower the decl ref to the custom derivative function to IR.
+ // The IR insts correspond to the decl ref is not part of the function we
+ // are processing. If we emit it directly to within the function, it could
+ // mess up the assumption on the form of the IR (e.g. having non decoration insts
+ // appearing in the middle of decoration insts). so we emit the decl ref to the
+ // function's parent for now.
+
+ subContext->irBuilder->setInsertInto(irFunc->getParent());
+ Expr* funcExpr = nullptr;
+ if (auto udAttr = as<UserDefinedDerivativeAttribute>(modifier))
+ funcExpr = udAttr->funcExpr;
+ else if (auto primalAttr = as<PrimalSubstituteAttribute>(modifier))
+ funcExpr = primalAttr->funcExpr;
+
+ auto loweredVal = lowerRValueExpr(subContext, funcExpr);
+
+ SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
+ IRInst* derivativeFunc = loweredVal.val;
+
+ if (as<ForwardDerivativeAttribute>(modifier))
+ getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc);
+ else if (as<BackwardDerivativeAttribute>(modifier))
+ getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc);
+ else
+ getBuilder()->addPrimalSubstituteDecoration(irFunc, derivativeFunc);
- // Reset cursor.
- subContext->irBuilder->setInsertInto(irFunc);
+ // Reset cursor.
+ subContext->irBuilder->setInsertInto(irFunc);
+ }
+ else if (as<ForwardDifferentiableAttribute>(modifier))
+ {
+ getBuilder()->addForwardDifferentiableDecoration(irFunc);
+ }
+ else if (as<BackwardDifferentiableAttribute>(modifier))
+ {
+ getBuilder()->addBackwardDifferentiableDecoration(irFunc);
+ }
}
-
// For convenience, ensure that any additional global
// values that were emitted while outputting the function
// body appear before the function itself in the list
@@ -8451,39 +8454,59 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// the interface's type definition.
auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric);
- if (auto attr = decl->findModifier<DerivativeOfAttribute>())
+ for (auto modifier : decl->modifiers)
{
- if (auto originalDeclRefExpr = as<DeclRefExpr>(attr->funcExpr))
+ if (as<DerivativeOfAttribute>(modifier) || as<PrimalSubstituteOfAttribute>(modifier))
{
- NestedContext originalContextFunc(this);
- auto originalSubBuilder = originalContextFunc.getBuilder();
- auto originalSubContext = originalContextFunc.getContext();
- if (auto outterGeneric = getOuterGeneric(irFunc))
- originalSubBuilder->setInsertBefore(outterGeneric);
- else
- originalSubBuilder->setInsertBefore(irFunc);
- auto originalFuncDecl = as<FunctionDeclBase>(originalDeclRefExpr->declRef.getDecl());
- SLANG_RELEASE_ASSERT(originalFuncDecl);
-
- auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val;
- if (auto originalFuncGeneric = as<IRGeneric>(originalFuncVal))
+ Expr* funcExpr = nullptr;
+ Expr* backDeclRef = nullptr;
+ if (auto attr = as<DerivativeOfAttribute>(modifier))
{
- originalFuncVal = findGenericReturnVal(originalFuncGeneric);
+ funcExpr = attr->funcExpr;
+ backDeclRef = attr->backDeclRef;
}
- originalSubBuilder->setInsertBefore(originalFuncVal);
- auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef);
- if (as<ForwardDerivativeOfAttribute>(attr))
+ else if (auto primalAttr = as<PrimalSubstituteOfAttribute>(modifier))
{
- originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
- getBuilder()->addForwardDifferentiableDecoration(irFunc);
+ funcExpr = primalAttr->funcExpr;
+ backDeclRef = primalAttr->backDeclRef;
}
- else
+
+ if (auto originalDeclRefExpr = as<DeclRefExpr>(funcExpr))
{
- originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ NestedContext originalContextFunc(this);
+ auto originalSubBuilder = originalContextFunc.getBuilder();
+ auto originalSubContext = originalContextFunc.getContext();
+ if (auto outterGeneric = getOuterGeneric(irFunc))
+ originalSubBuilder->setInsertBefore(outterGeneric);
+ else
+ originalSubBuilder->setInsertBefore(irFunc);
+ auto originalFuncDecl = as<FunctionDeclBase>(originalDeclRefExpr->declRef.getDecl());
+ SLANG_RELEASE_ASSERT(originalFuncDecl);
+
+ auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val;
+ if (auto originalFuncGeneric = as<IRGeneric>(originalFuncVal))
+ {
+ originalFuncVal = findGenericReturnVal(originalFuncGeneric);
+ }
+ originalSubBuilder->setInsertBefore(originalFuncVal);
+ auto derivativeFuncVal = lowerRValueExpr(originalSubContext, backDeclRef);
+ if (as<ForwardDerivativeOfAttribute>(modifier))
+ {
+ originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ getBuilder()->addForwardDifferentiableDecoration(irFunc);
+ }
+ else if (as<BackwardDerivativeOfAttribute>(modifier))
+ {
+ originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ }
+ else
+ {
+ originalSubBuilder->addPrimalSubstituteDecoration(originalFuncVal, derivativeFuncVal.val);
+ }
}
+ subContext->irBuilder->setInsertInto(irFunc);
+ finalVal->moveToEnd();
}
- subContext->irBuilder->setInsertInto(irFunc);
- finalVal->moveToEnd();
}
return LoweredValInfo::simple(finalVal);
}