summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-10-27 13:08:30 -0700
committerGitHub <noreply@github.com>2022-10-27 13:08:30 -0700
commit79af29af91fb9601886d539526a4ec87bca3d74c (patch)
treed92096aff783cf83cd01673f74a91a1b1372d3ef
parent8dc9efd256bd211d8c446971f09a7c79e644b110 (diff)
Rename `[__custom_jvp]` -> `[ForwardDerivative]`. (#2473)
* Rename `[__custom_jvp]` -> `[ForwardDerivative]`. * Rename the classes. * More renaming. Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/diff.meta.slang10
-rw-r--r--source/slang/slang-ast-modifier.h6
-rw-r--r--source/slang/slang-check-modifier.cpp4
-rw-r--r--source/slang/slang-lower-to-ir.cpp4
-rw-r--r--tests/autodiff/arithmetic-jvp.slang2
-rw-r--r--tests/autodiff/custom-intrinsic.slang8
-rw-r--r--tests/autodiff/generic-custom-jvp.slang2
-rw-r--r--tests/autodiff/generic-impl-jvp.slang2
-rw-r--r--tests/autodiff/local-redecl-custom-jvp.slang2
-rw-r--r--tests/autodiff/nested-jvp.slang4
-rw-r--r--tests/autodiff/test-intrinsics-jvp.slang4
11 files changed, 24 insertions, 24 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index ca7c1d3bd..e6ddb1cf6 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -5,9 +5,9 @@
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
-// Custom JVP Function reference
+// Custom Forward Derivative Function reference
__attributeTarget(FunctionDeclBase)
-attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute;
+attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
/// Interface to denote types as differentiable.
/// Allows for user-specified differential types as
@@ -167,7 +167,7 @@ namespace dstd
__target_intrinsic(cuda, "$P_exp($0)")
__target_intrinsic(cpp, "$P_exp($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0")
- [__custom_jvp(d_exp<T>)]
+ [ForwardDerivative(d_exp<T>)]
T exp(T x);
__generic<T : IDFloat>
@@ -185,7 +185,7 @@ namespace dstd
__target_intrinsic(cuda, "$P_sin($0)")
__target_intrinsic(cpp, "$P_sin($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0")
- [__custom_jvp(d_sin<T>)]
+ [ForwardDerivative(d_sin<T>)]
T sin(T x);
__generic<T : IDFloat>
@@ -203,7 +203,7 @@ namespace dstd
__target_intrinsic(cuda, "$P_cos($0)")
__target_intrinsic(cpp, "$P_cos($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0")
- [__custom_jvp(d_cos<T>)]
+ [ForwardDerivative(d_cos<T>)]
T cos(T x);
__generic<T : IDFloat>
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index ee350be25..76106074f 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1021,11 +1021,11 @@ class ForwardDifferentiableAttribute : public Attribute
SLANG_AST_CLASS(ForwardDifferentiableAttribute)
};
- /// The `[__custom_jvp(function)]` attribute specifies a custom function that should
+ /// The `[ForwardDerivative(function)]` attribute specifies a custom function that should
/// be used as the derivative for the decorated function.
-class CustomJVPAttribute : public Attribute
+class ForwardDerivativeAttribute : public Attribute
{
- SLANG_AST_CLASS(CustomJVPAttribute)
+ SLANG_AST_CLASS(ForwardDerivativeAttribute)
DeclRefExpr* funcDeclRef;
};
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index e189b9114..91f655a15 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -625,7 +625,7 @@ namespace Slang
callablePayloadAttr->location = (int32_t)val->value;
}
- else if (auto customJVPAttr = as<CustomJVPAttribute>(attr))
+ else if (auto forwardDerivativeAttr = as<ForwardDerivativeAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
SLANG_ASSERT(as<Decl>(attrTarget));
@@ -723,7 +723,7 @@ namespace Slang
}
// TODO: Can possibly just store a DeclRef (no need for DeclRefExpr)
- customJVPAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr));
+ forwardDerivativeAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr));
}
else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr))
{
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 386cf2a21..acb7869e0 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8204,10 +8204,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
}
- // Register the value now, to avoid any possible infinite recursion when lowering CustomJVPAttribute
+ // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute
setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc)));
- if (auto attr = decl->findModifier<CustomJVPAttribute>())
+ if (auto attr = decl->findModifier<ForwardDerivativeAttribute>())
{
// TODO(Sai): HACK.. we need to emit a decl-ref to handle this modifier correctly.
// If we don't move the cursor to the parent, we sometimes emit supporting
diff --git a/tests/autodiff/arithmetic-jvp.slang b/tests/autodiff/arithmetic-jvp.slang
index 3b06393d3..ec2c5bc6f 100644
--- a/tests/autodiff/arithmetic-jvp.slang
+++ b/tests/autodiff/arithmetic-jvp.slang
@@ -18,7 +18,7 @@ dpfloat g_jvp_(dpfloat dpx)
return dpfloat(dpx.p(), 2 * dpx.d());
}
-[__custom_jvp(g_jvp_)]
+[ForwardDerivative(g_jvp_)]
float g(float x)
{
return x + x;
diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang
index 770a18b1e..7591cd624 100644
--- a/tests/autodiff/custom-intrinsic.slang
+++ b/tests/autodiff/custom-intrinsic.slang
@@ -16,7 +16,7 @@ namespace myintrinsiclib
__target_intrinsic(cuda, "$P_exp($0)")
__target_intrinsic(cpp, "$P_exp($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0")
- [__custom_jvp(d_exp<T>)]
+ [ForwardDerivative(d_exp<T>)]
T exp(T x);
__generic<T : IDFloat>
@@ -35,7 +35,7 @@ namespace myintrinsiclib
__target_intrinsic(cuda, "$P_sin($0)")
__target_intrinsic(cpp, "$P_sin($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0")
- [__custom_jvp(d_sin<T>)]
+ [ForwardDerivative(d_sin<T>)]
T sin(T x);
__generic<T : IDFloat>
@@ -53,7 +53,7 @@ namespace myintrinsiclib
__target_intrinsic(cuda, "$P_cos($0)")
__target_intrinsic(cpp, "$P_cos($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0")
- [__custom_jvp(d_cos<T>)]
+ [ForwardDerivative(d_cos<T>)]
T cos(T x);
__generic<T : IDFloat>
@@ -68,7 +68,7 @@ namespace myintrinsiclib
__generic<T : IDFloat>
__target_intrinsic(hlsl)
__target_intrinsic(cuda, "$P_sincos($0, $1, $2)")
- [__custom_jvp(d_sincos<T>)]
+ [ForwardDerivative(d_sincos<T>)]
void sincos(T x, out T s, out T c)
{
s = sin(x);
diff --git a/tests/autodiff/generic-custom-jvp.slang b/tests/autodiff/generic-custom-jvp.slang
index f0b8d3898..5111f0e48 100644
--- a/tests/autodiff/generic-custom-jvp.slang
+++ b/tests/autodiff/generic-custom-jvp.slang
@@ -17,7 +17,7 @@ dpfloat my_pow_jvp(dpfloat x, dpfloat n)
x.d() * n.p() * pow(x.p(), n.p()-1) + n.d() * pow(x.p(), n.p()) * log(x.p()));
}
-[__custom_jvp(my_pow_jvp)]
+[ForwardDerivative(my_pow_jvp)]
float _pow(float, float);
[numthreads(1, 1, 1)]
diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang
index 3ebbff996..d47da336e 100644
--- a/tests/autodiff/generic-impl-jvp.slang
+++ b/tests/autodiff/generic-impl-jvp.slang
@@ -105,7 +105,7 @@ myvector<T, N> operator *(T a, myvector<T, N> b)
}
__generic<T : IDFloat, let N : int>
-[__custom_jvp(dot_jvp)]
+[ForwardDerivative(dot_jvp)]
T dot(myvector<T, N> a, myvector<T, N> b)
{
T curr = (T)0.0;
diff --git a/tests/autodiff/local-redecl-custom-jvp.slang b/tests/autodiff/local-redecl-custom-jvp.slang
index 79b90bd16..3a6b6f474 100644
--- a/tests/autodiff/local-redecl-custom-jvp.slang
+++ b/tests/autodiff/local-redecl-custom-jvp.slang
@@ -15,7 +15,7 @@ dpfloat my_pow_jvp(dpfloat x, dpfloat n)
x.d() * n.p() * pow(x.p(), n.p()-1) + n.d() * pow(x.p(), n.p()) * log(x.p()));
}
-[__custom_jvp(my_pow_jvp)]
+[ForwardDerivative(my_pow_jvp)]
float _pow(float, float);
[numthreads(1, 1, 1)]
diff --git a/tests/autodiff/nested-jvp.slang b/tests/autodiff/nested-jvp.slang
index 40518d44d..0e7d19078 100644
--- a/tests/autodiff/nested-jvp.slang
+++ b/tests/autodiff/nested-jvp.slang
@@ -7,13 +7,13 @@ RWStructuredBuffer<float> outputBuffer;
typedef __DifferentialPair<float> dpfloat;
typedef __DifferentialPair<float3> dpfloat3;
-[__custom_jvp(pow_jvp)]
+[ForwardDerivative(pow_jvp)]
float pow_(float x, float n)
{
return pow<float>(x, n);
}
-[__custom_jvp(max_jvp)]
+[ForwardDerivative(max_jvp)]
float max_(float x, float y)
{
return max<float>(x, y);
diff --git a/tests/autodiff/test-intrinsics-jvp.slang b/tests/autodiff/test-intrinsics-jvp.slang
index cb4c5c6b4..39f2ee495 100644
--- a/tests/autodiff/test-intrinsics-jvp.slang
+++ b/tests/autodiff/test-intrinsics-jvp.slang
@@ -2,14 +2,14 @@
__exported import test_intrinsics;
-[__custom_jvp(pow_jvp)]
+[ForwardDerivative(pow_jvp)]
float pow_(float x, float n);
float pow_jvp(float x, float n, float dx, float dn)
{
return dx * n * pow(x, n-1) + dn * pow(x, n) * log(x);
}
-[__custom_jvp(max_jvp)]
+[ForwardDerivative(max_jvp)]
float max_(float x, float y);
float max_jvp(float x, float y, float dx, float dy)
{