diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-27 13:08:30 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-27 13:08:30 -0700 |
| commit | 79af29af91fb9601886d539526a4ec87bca3d74c (patch) | |
| tree | d92096aff783cf83cd01673f74a91a1b1372d3ef /source | |
| parent | 8dc9efd256bd211d8c446971f09a7c79e644b110 (diff) | |
Rename `[__custom_jvp]` -> `[ForwardDerivative]`. (#2473)
* Rename `[__custom_jvp]` -> `[ForwardDerivative]`.
* Rename the classes.
* More renaming.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/diff.meta.slang | 10 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 4 |
4 files changed, 12 insertions, 12 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index ca7c1d3bd..e6ddb1cf6 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -5,9 +5,9 @@ __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; -// Custom JVP Function reference +// Custom Forward Derivative Function reference __attributeTarget(FunctionDeclBase) -attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute; +attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; /// Interface to denote types as differentiable. /// Allows for user-specified differential types as @@ -167,7 +167,7 @@ namespace dstd __target_intrinsic(cuda, "$P_exp($0)") __target_intrinsic(cpp, "$P_exp($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") - [__custom_jvp(d_exp<T>)] + [ForwardDerivative(d_exp<T>)] T exp(T x); __generic<T : IDFloat> @@ -185,7 +185,7 @@ namespace dstd __target_intrinsic(cuda, "$P_sin($0)") __target_intrinsic(cpp, "$P_sin($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") - [__custom_jvp(d_sin<T>)] + [ForwardDerivative(d_sin<T>)] T sin(T x); __generic<T : IDFloat> @@ -203,7 +203,7 @@ namespace dstd __target_intrinsic(cuda, "$P_cos($0)") __target_intrinsic(cpp, "$P_cos($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") - [__custom_jvp(d_cos<T>)] + [ForwardDerivative(d_cos<T>)] T cos(T x); __generic<T : IDFloat> diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index ee350be25..76106074f 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1021,11 +1021,11 @@ class ForwardDifferentiableAttribute : public Attribute SLANG_AST_CLASS(ForwardDifferentiableAttribute) }; - /// The `[__custom_jvp(function)]` attribute specifies a custom function that should + /// The `[ForwardDerivative(function)]` attribute specifies a custom function that should /// be used as the derivative for the decorated function. -class CustomJVPAttribute : public Attribute +class ForwardDerivativeAttribute : public Attribute { - SLANG_AST_CLASS(CustomJVPAttribute) + SLANG_AST_CLASS(ForwardDerivativeAttribute) DeclRefExpr* funcDeclRef; }; diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index e189b9114..91f655a15 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -625,7 +625,7 @@ namespace Slang callablePayloadAttr->location = (int32_t)val->value; } - else if (auto customJVPAttr = as<CustomJVPAttribute>(attr)) + else if (auto forwardDerivativeAttr = as<ForwardDerivativeAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); SLANG_ASSERT(as<Decl>(attrTarget)); @@ -723,7 +723,7 @@ namespace Slang } // TODO: Can possibly just store a DeclRef (no need for DeclRefExpr) - customJVPAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr)); + forwardDerivativeAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr)); } else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr)) { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 386cf2a21..acb7869e0 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8204,10 +8204,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); } - // Register the value now, to avoid any possible infinite recursion when lowering CustomJVPAttribute + // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); - if (auto attr = decl->findModifier<CustomJVPAttribute>()) + if (auto attr = decl->findModifier<ForwardDerivativeAttribute>()) { // TODO(Sai): HACK.. we need to emit a decl-ref to handle this modifier correctly. // If we don't move the cursor to the parent, we sometimes emit supporting |
