summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-10-27 13:08:30 -0700
committerGitHub <noreply@github.com>2022-10-27 13:08:30 -0700
commit79af29af91fb9601886d539526a4ec87bca3d74c (patch)
treed92096aff783cf83cd01673f74a91a1b1372d3ef /source
parent8dc9efd256bd211d8c446971f09a7c79e644b110 (diff)
Rename `[__custom_jvp]` -> `[ForwardDerivative]`. (#2473)
* Rename `[__custom_jvp]` -> `[ForwardDerivative]`. * Rename the classes. * More renaming. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang10
-rw-r--r--source/slang/slang-ast-modifier.h6
-rw-r--r--source/slang/slang-check-modifier.cpp4
-rw-r--r--source/slang/slang-lower-to-ir.cpp4
4 files changed, 12 insertions, 12 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index ca7c1d3bd..e6ddb1cf6 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -5,9 +5,9 @@
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
-// Custom JVP Function reference
+// Custom Forward Derivative Function reference
__attributeTarget(FunctionDeclBase)
-attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute;
+attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
/// Interface to denote types as differentiable.
/// Allows for user-specified differential types as
@@ -167,7 +167,7 @@ namespace dstd
__target_intrinsic(cuda, "$P_exp($0)")
__target_intrinsic(cpp, "$P_exp($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0")
- [__custom_jvp(d_exp<T>)]
+ [ForwardDerivative(d_exp<T>)]
T exp(T x);
__generic<T : IDFloat>
@@ -185,7 +185,7 @@ namespace dstd
__target_intrinsic(cuda, "$P_sin($0)")
__target_intrinsic(cpp, "$P_sin($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0")
- [__custom_jvp(d_sin<T>)]
+ [ForwardDerivative(d_sin<T>)]
T sin(T x);
__generic<T : IDFloat>
@@ -203,7 +203,7 @@ namespace dstd
__target_intrinsic(cuda, "$P_cos($0)")
__target_intrinsic(cpp, "$P_cos($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0")
- [__custom_jvp(d_cos<T>)]
+ [ForwardDerivative(d_cos<T>)]
T cos(T x);
__generic<T : IDFloat>
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index ee350be25..76106074f 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1021,11 +1021,11 @@ class ForwardDifferentiableAttribute : public Attribute
SLANG_AST_CLASS(ForwardDifferentiableAttribute)
};
- /// The `[__custom_jvp(function)]` attribute specifies a custom function that should
+ /// The `[ForwardDerivative(function)]` attribute specifies a custom function that should
/// be used as the derivative for the decorated function.
-class CustomJVPAttribute : public Attribute
+class ForwardDerivativeAttribute : public Attribute
{
- SLANG_AST_CLASS(CustomJVPAttribute)
+ SLANG_AST_CLASS(ForwardDerivativeAttribute)
DeclRefExpr* funcDeclRef;
};
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index e189b9114..91f655a15 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -625,7 +625,7 @@ namespace Slang
callablePayloadAttr->location = (int32_t)val->value;
}
- else if (auto customJVPAttr = as<CustomJVPAttribute>(attr))
+ else if (auto forwardDerivativeAttr = as<ForwardDerivativeAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
SLANG_ASSERT(as<Decl>(attrTarget));
@@ -723,7 +723,7 @@ namespace Slang
}
// TODO: Can possibly just store a DeclRef (no need for DeclRefExpr)
- customJVPAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr));
+ forwardDerivativeAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr));
}
else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr))
{
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 386cf2a21..acb7869e0 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8204,10 +8204,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
}
- // Register the value now, to avoid any possible infinite recursion when lowering CustomJVPAttribute
+ // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute
setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc)));
- if (auto attr = decl->findModifier<CustomJVPAttribute>())
+ if (auto attr = decl->findModifier<ForwardDerivativeAttribute>())
{
// TODO(Sai): HACK.. we need to emit a decl-ref to handle this modifier correctly.
// If we don't move the cursor to the parent, we sometimes emit supporting