From 8dc9efd256bd211d8c446971f09a7c79e644b110 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 27 Oct 2022 12:30:15 -0700 Subject: Rename `JVPDerivativeModifier` -> `ForwardDifferentiableAttribute`. (#2472) Co-authored-by: Yong He --- source/slang/diff.meta.slang | 3 ++- source/slang/slang-ast-modifier.h | 8 +++++++- source/slang/slang-check-decl.cpp | 2 +- source/slang/slang-check-expr.cpp | 2 +- source/slang/slang-check-type.cpp | 2 +- source/slang/slang-lower-to-ir.cpp | 2 +- 6 files changed, 13 insertions(+), 6 deletions(-) (limited to 'source/slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 38d7270e4..ca7c1d3bd 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -2,7 +2,8 @@ /// Modifer to mark a function for forward-mode differentiation. /// i.e. the compiler will automatically generate a new function /// that computes the jacobian-vector product of the original. -syntax __differentiate_jvp : JVPDerivativeModifier; +__attributeTarget(FunctionDeclBase) +attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; // Custom JVP Function reference __attributeTarget(FunctionDeclBase) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 6220fcb95..ee350be25 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -30,7 +30,7 @@ class ExportedModifier : public Modifier { SLANG_AST_CLASS(ExportedModifier)}; class ConstExprModifier : public Modifier { SLANG_AST_CLASS(ConstExprModifier)}; class GloballyCoherentModifier : public Modifier { SLANG_AST_CLASS(GloballyCoherentModifier)}; class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)}; -class JVPDerivativeModifier : public Modifier { SLANG_AST_CLASS(JVPDerivativeModifier)}; + // Marks that the definition of a decl is not yet synthesized. class ToBeSynthesizedModifier : public Modifier {SLANG_AST_CLASS(ToBeSynthesizedModifier)}; @@ -1015,6 +1015,12 @@ class RequiresNVAPIAttribute : public Attribute SLANG_AST_CLASS(RequiresNVAPIAttribute) }; + /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated. +class ForwardDifferentiableAttribute : public Attribute +{ + SLANG_AST_CLASS(ForwardDifferentiableAttribute) +}; + /// The `[__custom_jvp(function)]` attribute specifies a custom function that should /// be used as the derivative for the decorated function. class CustomJVPAttribute : public Attribute diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index f28f46deb..457ae229b 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -5230,7 +5230,7 @@ namespace Slang void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) { - if (decl->findModifier()) + if (decl->findModifier()) { this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 0975de985..c7d69262d 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -942,7 +942,7 @@ namespace Slang // Differentiable type checking. // TODO: This can be super slow. if (this->m_parentFunc && - this->m_parentFunc->findModifier()) + this->m_parentFunc->findModifier()) { maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); } diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index 6a8f802f7..6bc4b9d36 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -324,7 +324,7 @@ namespace Slang // Differentiable type checking. // TODO: This can be super slow. Switch to caching the result asap. if (this->m_parentFunc && - this->m_parentFunc->findModifier()) + this->m_parentFunc->findModifier()) { auto diffTypeContext = this->getShared()->innermostDiffTypeContext(); if (auto subtypeWitness = as( diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 3766a1a5e..386cf2a21 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7805,7 +7805,7 @@ struct DeclLoweringVisitor : DeclVisitor addNameHint(context, irFunc, decl); addLinkageDecoration(context, irFunc, decl); - if (decl->findModifier()) + if (decl->findModifier()) { getBuilder()->addJVPDerivativeMarkerDecoration(irFunc); } -- cgit v1.2.3