summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-10-27 12:30:15 -0700
committerGitHub <noreply@github.com>2022-10-27 12:30:15 -0700
commit8dc9efd256bd211d8c446971f09a7c79e644b110 (patch)
tree32612cfbd39531c1f21eab0777cf10a197b269d4 /source
parent0cbef6fd6d7924d37ef3ea5ec7c848c80947d13f (diff)
Rename `JVPDerivativeModifier` -> `ForwardDifferentiableAttribute`. (#2472)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang3
-rw-r--r--source/slang/slang-ast-modifier.h8
-rw-r--r--source/slang/slang-check-decl.cpp2
-rw-r--r--source/slang/slang-check-expr.cpp2
-rw-r--r--source/slang/slang-check-type.cpp2
-rw-r--r--source/slang/slang-lower-to-ir.cpp2
6 files changed, 13 insertions, 6 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 38d7270e4..ca7c1d3bd 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -2,7 +2,8 @@
/// Modifer to mark a function for forward-mode differentiation.
/// i.e. the compiler will automatically generate a new function
/// that computes the jacobian-vector product of the original.
-syntax __differentiate_jvp : JVPDerivativeModifier;
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
// Custom JVP Function reference
__attributeTarget(FunctionDeclBase)
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 6220fcb95..ee350be25 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -30,7 +30,7 @@ class ExportedModifier : public Modifier { SLANG_AST_CLASS(ExportedModifier)};
class ConstExprModifier : public Modifier { SLANG_AST_CLASS(ConstExprModifier)};
class GloballyCoherentModifier : public Modifier { SLANG_AST_CLASS(GloballyCoherentModifier)};
class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)};
-class JVPDerivativeModifier : public Modifier { SLANG_AST_CLASS(JVPDerivativeModifier)};
+
// Marks that the definition of a decl is not yet synthesized.
class ToBeSynthesizedModifier : public Modifier {SLANG_AST_CLASS(ToBeSynthesizedModifier)};
@@ -1015,6 +1015,12 @@ class RequiresNVAPIAttribute : public Attribute
SLANG_AST_CLASS(RequiresNVAPIAttribute)
};
+ /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated.
+class ForwardDifferentiableAttribute : public Attribute
+{
+ SLANG_AST_CLASS(ForwardDifferentiableAttribute)
+};
+
/// The `[__custom_jvp(function)]` attribute specifies a custom function that should
/// be used as the derivative for the decorated function.
class CustomJVPAttribute : public Attribute
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index f28f46deb..457ae229b 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -5230,7 +5230,7 @@ namespace Slang
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
{
- if (decl->findModifier<JVPDerivativeModifier>())
+ if (decl->findModifier<ForwardDifferentiableAttribute>())
{
this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary();
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 0975de985..c7d69262d 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -942,7 +942,7 @@ namespace Slang
// Differentiable type checking.
// TODO: This can be super slow.
if (this->m_parentFunc &&
- this->m_parentFunc->findModifier<JVPDerivativeModifier>())
+ this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>())
{
maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
}
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index 6a8f802f7..6bc4b9d36 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -324,7 +324,7 @@ namespace Slang
// Differentiable type checking.
// TODO: This can be super slow. Switch to caching the result asap.
if (this->m_parentFunc &&
- this->m_parentFunc->findModifier<JVPDerivativeModifier>())
+ this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>())
{
auto diffTypeContext = this->getShared()->innermostDiffTypeContext();
if (auto subtypeWitness = as<SubtypeWitness>(
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 3766a1a5e..386cf2a21 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7805,7 +7805,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
addNameHint(context, irFunc, decl);
addLinkageDecoration(context, irFunc, decl);
- if (decl->findModifier<JVPDerivativeModifier>())
+ if (decl->findModifier<ForwardDifferentiableAttribute>())
{
getBuilder()->addJVPDerivativeMarkerDecoration(irFunc);
}