From 79af29af91fb9601886d539526a4ec87bca3d74c Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 27 Oct 2022 13:08:30 -0700 Subject: Rename `[__custom_jvp]` -> `[ForwardDerivative]`. (#2473) * Rename `[__custom_jvp]` -> `[ForwardDerivative]`. * Rename the classes. * More renaming. Co-authored-by: Yong He --- source/slang/diff.meta.slang | 10 +++++----- source/slang/slang-ast-modifier.h | 6 +++--- source/slang/slang-check-modifier.cpp | 4 ++-- source/slang/slang-lower-to-ir.cpp | 4 ++-- tests/autodiff/arithmetic-jvp.slang | 2 +- tests/autodiff/custom-intrinsic.slang | 8 ++++---- tests/autodiff/generic-custom-jvp.slang | 2 +- tests/autodiff/generic-impl-jvp.slang | 2 +- tests/autodiff/local-redecl-custom-jvp.slang | 2 +- tests/autodiff/nested-jvp.slang | 4 ++-- tests/autodiff/test-intrinsics-jvp.slang | 4 ++-- 11 files changed, 24 insertions(+), 24 deletions(-) diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index ca7c1d3bd..e6ddb1cf6 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -5,9 +5,9 @@ __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; -// Custom JVP Function reference +// Custom Forward Derivative Function reference __attributeTarget(FunctionDeclBase) -attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute; +attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; /// Interface to denote types as differentiable. /// Allows for user-specified differential types as @@ -167,7 +167,7 @@ namespace dstd __target_intrinsic(cuda, "$P_exp($0)") __target_intrinsic(cpp, "$P_exp($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") - [__custom_jvp(d_exp)] + [ForwardDerivative(d_exp)] T exp(T x); __generic @@ -185,7 +185,7 @@ namespace dstd __target_intrinsic(cuda, "$P_sin($0)") __target_intrinsic(cpp, "$P_sin($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") - [__custom_jvp(d_sin)] + [ForwardDerivative(d_sin)] T sin(T x); __generic @@ -203,7 +203,7 @@ namespace dstd __target_intrinsic(cuda, "$P_cos($0)") __target_intrinsic(cpp, "$P_cos($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") - [__custom_jvp(d_cos)] + [ForwardDerivative(d_cos)] T cos(T x); __generic diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index ee350be25..76106074f 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1021,11 +1021,11 @@ class ForwardDifferentiableAttribute : public Attribute SLANG_AST_CLASS(ForwardDifferentiableAttribute) }; - /// The `[__custom_jvp(function)]` attribute specifies a custom function that should + /// The `[ForwardDerivative(function)]` attribute specifies a custom function that should /// be used as the derivative for the decorated function. -class CustomJVPAttribute : public Attribute +class ForwardDerivativeAttribute : public Attribute { - SLANG_AST_CLASS(CustomJVPAttribute) + SLANG_AST_CLASS(ForwardDerivativeAttribute) DeclRefExpr* funcDeclRef; }; diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index e189b9114..91f655a15 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -625,7 +625,7 @@ namespace Slang callablePayloadAttr->location = (int32_t)val->value; } - else if (auto customJVPAttr = as(attr)) + else if (auto forwardDerivativeAttr = as(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); SLANG_ASSERT(as(attrTarget)); @@ -723,7 +723,7 @@ namespace Slang } // TODO: Can possibly just store a DeclRef (no need for DeclRefExpr) - customJVPAttr->funcDeclRef = as(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr)); + forwardDerivativeAttr->funcDeclRef = as(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr)); } else if (auto comInterfaceAttr = as(attr)) { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 386cf2a21..acb7869e0 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8204,10 +8204,10 @@ struct DeclLoweringVisitor : DeclVisitor getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); } - // Register the value now, to avoid any possible infinite recursion when lowering CustomJVPAttribute + // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); - if (auto attr = decl->findModifier()) + if (auto attr = decl->findModifier()) { // TODO(Sai): HACK.. we need to emit a decl-ref to handle this modifier correctly. // If we don't move the cursor to the parent, we sometimes emit supporting diff --git a/tests/autodiff/arithmetic-jvp.slang b/tests/autodiff/arithmetic-jvp.slang index 3b06393d3..ec2c5bc6f 100644 --- a/tests/autodiff/arithmetic-jvp.slang +++ b/tests/autodiff/arithmetic-jvp.slang @@ -18,7 +18,7 @@ dpfloat g_jvp_(dpfloat dpx) return dpfloat(dpx.p(), 2 * dpx.d()); } -[__custom_jvp(g_jvp_)] +[ForwardDerivative(g_jvp_)] float g(float x) { return x + x; diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang index 770a18b1e..7591cd624 100644 --- a/tests/autodiff/custom-intrinsic.slang +++ b/tests/autodiff/custom-intrinsic.slang @@ -16,7 +16,7 @@ namespace myintrinsiclib __target_intrinsic(cuda, "$P_exp($0)") __target_intrinsic(cpp, "$P_exp($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") - [__custom_jvp(d_exp)] + [ForwardDerivative(d_exp)] T exp(T x); __generic @@ -35,7 +35,7 @@ namespace myintrinsiclib __target_intrinsic(cuda, "$P_sin($0)") __target_intrinsic(cpp, "$P_sin($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") - [__custom_jvp(d_sin)] + [ForwardDerivative(d_sin)] T sin(T x); __generic @@ -53,7 +53,7 @@ namespace myintrinsiclib __target_intrinsic(cuda, "$P_cos($0)") __target_intrinsic(cpp, "$P_cos($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") - [__custom_jvp(d_cos)] + [ForwardDerivative(d_cos)] T cos(T x); __generic @@ -68,7 +68,7 @@ namespace myintrinsiclib __generic __target_intrinsic(hlsl) __target_intrinsic(cuda, "$P_sincos($0, $1, $2)") - [__custom_jvp(d_sincos)] + [ForwardDerivative(d_sincos)] void sincos(T x, out T s, out T c) { s = sin(x); diff --git a/tests/autodiff/generic-custom-jvp.slang b/tests/autodiff/generic-custom-jvp.slang index f0b8d3898..5111f0e48 100644 --- a/tests/autodiff/generic-custom-jvp.slang +++ b/tests/autodiff/generic-custom-jvp.slang @@ -17,7 +17,7 @@ dpfloat my_pow_jvp(dpfloat x, dpfloat n) x.d() * n.p() * pow(x.p(), n.p()-1) + n.d() * pow(x.p(), n.p()) * log(x.p())); } -[__custom_jvp(my_pow_jvp)] +[ForwardDerivative(my_pow_jvp)] float _pow(float, float); [numthreads(1, 1, 1)] diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index 3ebbff996..d47da336e 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -105,7 +105,7 @@ myvector operator *(T a, myvector b) } __generic -[__custom_jvp(dot_jvp)] +[ForwardDerivative(dot_jvp)] T dot(myvector a, myvector b) { T curr = (T)0.0; diff --git a/tests/autodiff/local-redecl-custom-jvp.slang b/tests/autodiff/local-redecl-custom-jvp.slang index 79b90bd16..3a6b6f474 100644 --- a/tests/autodiff/local-redecl-custom-jvp.slang +++ b/tests/autodiff/local-redecl-custom-jvp.slang @@ -15,7 +15,7 @@ dpfloat my_pow_jvp(dpfloat x, dpfloat n) x.d() * n.p() * pow(x.p(), n.p()-1) + n.d() * pow(x.p(), n.p()) * log(x.p())); } -[__custom_jvp(my_pow_jvp)] +[ForwardDerivative(my_pow_jvp)] float _pow(float, float); [numthreads(1, 1, 1)] diff --git a/tests/autodiff/nested-jvp.slang b/tests/autodiff/nested-jvp.slang index 40518d44d..0e7d19078 100644 --- a/tests/autodiff/nested-jvp.slang +++ b/tests/autodiff/nested-jvp.slang @@ -7,13 +7,13 @@ RWStructuredBuffer outputBuffer; typedef __DifferentialPair dpfloat; typedef __DifferentialPair dpfloat3; -[__custom_jvp(pow_jvp)] +[ForwardDerivative(pow_jvp)] float pow_(float x, float n) { return pow(x, n); } -[__custom_jvp(max_jvp)] +[ForwardDerivative(max_jvp)] float max_(float x, float y) { return max(x, y); diff --git a/tests/autodiff/test-intrinsics-jvp.slang b/tests/autodiff/test-intrinsics-jvp.slang index cb4c5c6b4..39f2ee495 100644 --- a/tests/autodiff/test-intrinsics-jvp.slang +++ b/tests/autodiff/test-intrinsics-jvp.slang @@ -2,14 +2,14 @@ __exported import test_intrinsics; -[__custom_jvp(pow_jvp)] +[ForwardDerivative(pow_jvp)] float pow_(float x, float n); float pow_jvp(float x, float n, float dx, float dn) { return dx * n * pow(x, n-1) + dn * pow(x, n) * log(x); } -[__custom_jvp(max_jvp)] +[ForwardDerivative(max_jvp)] float max_(float x, float y); float max_jvp(float x, float y, float dx, float dy) { -- cgit v1.2.3