diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-27 12:30:15 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-27 12:30:15 -0700 |
| commit | 8dc9efd256bd211d8c446971f09a7c79e644b110 (patch) | |
| tree | 32612cfbd39531c1f21eab0777cf10a197b269d4 /source | |
| parent | 0cbef6fd6d7924d37ef3ea5ec7c848c80947d13f (diff) | |
Rename `JVPDerivativeModifier` -> `ForwardDifferentiableAttribute`. (#2472)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/diff.meta.slang | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-type.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 2 |
6 files changed, 13 insertions, 6 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 38d7270e4..ca7c1d3bd 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -2,7 +2,8 @@ /// Modifer to mark a function for forward-mode differentiation. /// i.e. the compiler will automatically generate a new function /// that computes the jacobian-vector product of the original. -syntax __differentiate_jvp : JVPDerivativeModifier; +__attributeTarget(FunctionDeclBase) +attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; // Custom JVP Function reference __attributeTarget(FunctionDeclBase) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 6220fcb95..ee350be25 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -30,7 +30,7 @@ class ExportedModifier : public Modifier { SLANG_AST_CLASS(ExportedModifier)}; class ConstExprModifier : public Modifier { SLANG_AST_CLASS(ConstExprModifier)}; class GloballyCoherentModifier : public Modifier { SLANG_AST_CLASS(GloballyCoherentModifier)}; class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)}; -class JVPDerivativeModifier : public Modifier { SLANG_AST_CLASS(JVPDerivativeModifier)}; + // Marks that the definition of a decl is not yet synthesized. class ToBeSynthesizedModifier : public Modifier {SLANG_AST_CLASS(ToBeSynthesizedModifier)}; @@ -1015,6 +1015,12 @@ class RequiresNVAPIAttribute : public Attribute SLANG_AST_CLASS(RequiresNVAPIAttribute) }; + /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated. +class ForwardDifferentiableAttribute : public Attribute +{ + SLANG_AST_CLASS(ForwardDifferentiableAttribute) +}; + /// The `[__custom_jvp(function)]` attribute specifies a custom function that should /// be used as the derivative for the decorated function. class CustomJVPAttribute : public Attribute diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index f28f46deb..457ae229b 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -5230,7 +5230,7 @@ namespace Slang void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) { - if (decl->findModifier<JVPDerivativeModifier>()) + if (decl->findModifier<ForwardDifferentiableAttribute>()) { this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 0975de985..c7d69262d 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -942,7 +942,7 @@ namespace Slang // Differentiable type checking. // TODO: This can be super slow. if (this->m_parentFunc && - this->m_parentFunc->findModifier<JVPDerivativeModifier>()) + this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>()) { maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); } diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index 6a8f802f7..6bc4b9d36 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -324,7 +324,7 @@ namespace Slang // Differentiable type checking. // TODO: This can be super slow. Switch to caching the result asap. if (this->m_parentFunc && - this->m_parentFunc->findModifier<JVPDerivativeModifier>()) + this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>()) { auto diffTypeContext = this->getShared()->innermostDiffTypeContext(); if (auto subtypeWitness = as<SubtypeWitness>( diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 3766a1a5e..386cf2a21 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7805,7 +7805,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> addNameHint(context, irFunc, decl); addLinkageDecoration(context, irFunc, decl); - if (decl->findModifier<JVPDerivativeModifier>()) + if (decl->findModifier<ForwardDifferentiableAttribute>()) { getBuilder()->addJVPDerivativeMarkerDecoration(irFunc); } |
