summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/diff.meta.slang3
-rw-r--r--source/slang/slang-ast-modifier.h8
-rw-r--r--source/slang/slang-check-decl.cpp2
-rw-r--r--source/slang/slang-check-expr.cpp2
-rw-r--r--source/slang/slang-check-type.cpp2
-rw-r--r--source/slang/slang-lower-to-ir.cpp2
-rw-r--r--tests/autodiff/arithmetic-jvp.slang9
-rw-r--r--tests/autodiff/auto-differential-type.slang3
-rw-r--r--tests/autodiff/custom-intrinsic.slang6
-rw-r--r--tests/autodiff/differential-method-synthesis.slang3
-rw-r--r--tests/autodiff/dstdlib.slang9
-rw-r--r--tests/autodiff/generic-impl-jvp.slang48
-rw-r--r--tests/autodiff/generic-jvp.slang65
-rw-r--r--tests/autodiff/getter-setter-multi.slang3
-rw-r--r--tests/autodiff/getter-setter.slang3
-rw-r--r--tests/autodiff/imported-custom-jvp.slang3
-rw-r--r--tests/autodiff/inout-parameters-jvp.slang6
-rw-r--r--tests/autodiff/nested-jvp.slang6
-rw-r--r--tests/autodiff/out-parameters-jvp.slang3
-rw-r--r--tests/autodiff/overloads-jvp.slang9
-rw-r--r--tests/autodiff/vector-arithmetic-jvp.slang12
-rw-r--r--tests/autodiff/vector-swizzle-jvp.slang6
22 files changed, 142 insertions, 71 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 38d7270e4..ca7c1d3bd 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -2,7 +2,8 @@
/// Modifer to mark a function for forward-mode differentiation.
/// i.e. the compiler will automatically generate a new function
/// that computes the jacobian-vector product of the original.
-syntax __differentiate_jvp : JVPDerivativeModifier;
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
// Custom JVP Function reference
__attributeTarget(FunctionDeclBase)
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 6220fcb95..ee350be25 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -30,7 +30,7 @@ class ExportedModifier : public Modifier { SLANG_AST_CLASS(ExportedModifier)};
class ConstExprModifier : public Modifier { SLANG_AST_CLASS(ConstExprModifier)};
class GloballyCoherentModifier : public Modifier { SLANG_AST_CLASS(GloballyCoherentModifier)};
class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)};
-class JVPDerivativeModifier : public Modifier { SLANG_AST_CLASS(JVPDerivativeModifier)};
+
// Marks that the definition of a decl is not yet synthesized.
class ToBeSynthesizedModifier : public Modifier {SLANG_AST_CLASS(ToBeSynthesizedModifier)};
@@ -1015,6 +1015,12 @@ class RequiresNVAPIAttribute : public Attribute
SLANG_AST_CLASS(RequiresNVAPIAttribute)
};
+ /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated.
+class ForwardDifferentiableAttribute : public Attribute
+{
+ SLANG_AST_CLASS(ForwardDifferentiableAttribute)
+};
+
/// The `[__custom_jvp(function)]` attribute specifies a custom function that should
/// be used as the derivative for the decorated function.
class CustomJVPAttribute : public Attribute
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index f28f46deb..457ae229b 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -5230,7 +5230,7 @@ namespace Slang
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
{
- if (decl->findModifier<JVPDerivativeModifier>())
+ if (decl->findModifier<ForwardDifferentiableAttribute>())
{
this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary();
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 0975de985..c7d69262d 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -942,7 +942,7 @@ namespace Slang
// Differentiable type checking.
// TODO: This can be super slow.
if (this->m_parentFunc &&
- this->m_parentFunc->findModifier<JVPDerivativeModifier>())
+ this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>())
{
maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
}
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index 6a8f802f7..6bc4b9d36 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -324,7 +324,7 @@ namespace Slang
// Differentiable type checking.
// TODO: This can be super slow. Switch to caching the result asap.
if (this->m_parentFunc &&
- this->m_parentFunc->findModifier<JVPDerivativeModifier>())
+ this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>())
{
auto diffTypeContext = this->getShared()->innermostDiffTypeContext();
if (auto subtypeWitness = as<SubtypeWitness>(
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 3766a1a5e..386cf2a21 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7805,7 +7805,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
addNameHint(context, irFunc, decl);
addLinkageDecoration(context, irFunc, decl);
- if (decl->findModifier<JVPDerivativeModifier>())
+ if (decl->findModifier<ForwardDifferentiableAttribute>())
{
getBuilder()->addJVPDerivativeMarkerDecoration(irFunc);
}
diff --git a/tests/autodiff/arithmetic-jvp.slang b/tests/autodiff/arithmetic-jvp.slang
index 0c7dd039d..3b06393d3 100644
--- a/tests/autodiff/arithmetic-jvp.slang
+++ b/tests/autodiff/arithmetic-jvp.slang
@@ -7,7 +7,8 @@ RWStructuredBuffer<float> outputBuffer;
typedef __DifferentialPair<float> dpfloat;
typedef float.Differential dfloat;
-__differentiate_jvp float f(float x)
+[ForwardDifferentiable]
+float f(float x)
{
return x;
}
@@ -23,14 +24,16 @@ float g(float x)
return x + x;
}
-__differentiate_jvp float h(float x, float y)
+[ForwardDifferentiable]
+float h(float x, float y)
{
float m = x + y;
float n = x - y;
return m * n + 2 * x * y;
}
-__differentiate_jvp float j(float x, float y)
+[ForwardDifferentiable]
+float j(float x, float y)
{
float m = x / y;
return m * y;
diff --git a/tests/autodiff/auto-differential-type.slang b/tests/autodiff/auto-differential-type.slang
index b551db4ab..a687c823e 100644
--- a/tests/autodiff/auto-differential-type.slang
+++ b/tests/autodiff/auto-differential-type.slang
@@ -35,7 +35,8 @@ struct A : IDifferentiable
typedef __DifferentialPair<A> dpA;
-__differentiate_jvp A f(A a)
+[ForwardDifferentiable]
+A f(A a)
{
A aout;
aout.y = 2 * a.x;
diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang
index 02f6541f5..770a18b1e 100644
--- a/tests/autodiff/custom-intrinsic.slang
+++ b/tests/autodiff/custom-intrinsic.slang
@@ -87,12 +87,14 @@ namespace myintrinsiclib
}
};
-__differentiate_jvp float f(float x)
+[ForwardDifferentiable]
+float f(float x)
{
return myintrinsiclib.exp(x);
}
-__differentiate_jvp float g(float x)
+[ForwardDifferentiable]
+float g(float x)
{
float s;
float t;
diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang
index 3ecd636e9..76c15d5a1 100644
--- a/tests/autodiff/differential-method-synthesis.slang
+++ b/tests/autodiff/differential-method-synthesis.slang
@@ -24,7 +24,8 @@ A nonDiff(A a)
return a;
}
-__differentiate_jvp A f(A a)
+[ForwardDifferentiable]
+A f(A a)
{
A aout;
aout.y = 2 * a.b.x;
diff --git a/tests/autodiff/dstdlib.slang b/tests/autodiff/dstdlib.slang
index 614de54f6..05d915c29 100644
--- a/tests/autodiff/dstdlib.slang
+++ b/tests/autodiff/dstdlib.slang
@@ -6,17 +6,20 @@ RWStructuredBuffer<float> outputBuffer;
typedef __DifferentialPair<float> dpfloat;
-__differentiate_jvp float f(float x)
+[ForwardDifferentiable]
+float f(float x)
{
return dstd.exp(x);
}
-__differentiate_jvp float g(float x)
+[ForwardDifferentiable]
+float g(float x)
{
return dstd.sin(x);
}
-__differentiate_jvp float h(float x)
+[ForwardDifferentiable]
+float h(float x)
{
return dstd.cos(x);
}
diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang
index e14f851ac..3ebbff996 100644
--- a/tests/autodiff/generic-impl-jvp.slang
+++ b/tests/autodiff/generic-impl-jvp.slang
@@ -68,8 +68,9 @@ struct myvector : IDifferentiable
}
};
+[ForwardDifferentiable]
__generic<T : IDFloat, let N : int>
-__differentiate_jvp myvector<T, N> operator +(myvector<T, N> a, myvector<T, N> b)
+myvector<T, N> operator +(myvector<T, N> a, myvector<T, N> b)
{
myvector<T, N> output;
for (int i = 0; i < N; i++)
@@ -79,8 +80,9 @@ __differentiate_jvp myvector<T, N> operator +(myvector<T, N> a, myvector<T, N> b
return output;
}
+[ForwardDifferentiable]
__generic<T : IDFloat, let N : int>
- __differentiate_jvp myvector<T, N> operator *(myvector<T, N> a, myvector<T, N> b)
+myvector<T, N> operator *(myvector<T, N> a, myvector<T, N> b)
{
myvector<T, N> output;
for (int i = 0; i < N; i++)
@@ -90,8 +92,9 @@ __generic<T : IDFloat, let N : int>
return output;
}
+[ForwardDifferentiable]
__generic<T : IDFloat, let N : int>
- __differentiate_jvp myvector<T, N> operator *(T a, myvector<T, N> b)
+myvector<T, N> operator *(T a, myvector<T, N> b)
{
myvector<T, N> output;
for (int i = 0; i < N; i++)
@@ -157,22 +160,26 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable
[DerivativeMember(Differential.val)]
myvector<Real, N> val;
- static __differentiate_jvp linearvector<N> ladd(linearvector<N> a, linearvector<N> b)
+ [ForwardDifferentiable]
+ static linearvector<N> ladd(linearvector<N> a, linearvector<N> b)
{
return linearvector<N>(a.val + b.val);
}
- static __differentiate_jvp linearvector<N> lmul(linearvector<N> a, linearvector<N> b)
+ [ForwardDifferentiable]
+ static linearvector<N> lmul(linearvector<N> a, linearvector<N> b)
{
return linearvector<N>(a.val * b.val);
}
- static __differentiate_jvp linearvector<N> lscale(float a, linearvector<N> b)
+ [ForwardDifferentiable]
+ static linearvector<N> lscale(float a, linearvector<N> b)
{
return linearvector<N>(a * b.val);
}
- static __differentiate_jvp float ldot(linearvector<N> a, linearvector<N> b)
+ [ForwardDifferentiable]
+ static float ldot(linearvector<N> a, linearvector<N> b)
{
return dot(a.val, b.val);
}
@@ -194,7 +201,8 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable
return { myvector<Real, N>.dmul(a.val, b.val) };
}
- __differentiate_jvp __init(vector<Real, N> a)
+ [ForwardDifferentiable]
+ __init(vector<Real, N> a)
{
for (int i = 0; i < N; i++)
{
@@ -202,7 +210,8 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable
}
}
- __differentiate_jvp __init(myvector<Real, N> a)
+ [ForwardDifferentiable]
+ __init(myvector<Real, N> a)
{
val = a;
}
@@ -229,22 +238,26 @@ typedef __DifferentialPair<myfloat3> dpfloat3;
extension float : MyLinearArithmeticType
{
- static __differentiate_jvp float ladd(float a, float b)
+ [ForwardDifferentiable]
+ static float ladd(float a, float b)
{
return a + b;
}
- static __differentiate_jvp float lmul(float a, float b)
+ [ForwardDifferentiable]
+ static float lmul(float a, float b)
{
return a * b;
}
- static __differentiate_jvp float lscale(float a, float b)
+ [ForwardDifferentiable]
+ static float lscale(float a, float b)
{
return a * b;
}
- static __differentiate_jvp float ldot(float a, float b)
+ [ForwardDifferentiable]
+ static float ldot(float a, float b)
{
return a * b;
}
@@ -253,19 +266,22 @@ extension float : MyLinearArithmeticType
typealias MyLinearArithmeticDifferentiableType = IDifferentiable & MyLinearArithmeticType;
__generic<T : MyLinearArithmeticDifferentiableType>
-__differentiate_jvp T operator +(T a, T b)
+[ForwardDifferentiable]
+T operator +(T a, T b)
{
return T.ladd(a, b);
}
__generic<T : MyLinearArithmeticDifferentiableType>
-__differentiate_jvp T operator *(T a, T b)
+[ForwardDifferentiable]
+T operator *(T a, T b)
{
return T.lmul(a, b);
}
__generic<G : MyLinearArithmeticDifferentiableType>
-__differentiate_jvp G f(G x)
+[ForwardDifferentiable]
+G f(G x)
{
G a = x + x;
G b = x * x;
diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang
index 365be45aa..c22c228a6 100644
--- a/tests/autodiff/generic-jvp.slang
+++ b/tests/autodiff/generic-jvp.slang
@@ -14,27 +14,32 @@ struct myvector
extension myvector<3> : MyLinearArithmeticType
{
- static __differentiate_jvp myvector<3> ladd(myvector<3> a, myvector<3> b)
+ [ForwardDifferentiable]
+ static myvector<3> ladd(myvector<3> a, myvector<3> b)
{
return myvector<3>(a.val + b.val);
}
- static __differentiate_jvp myvector<3> lmul(myvector<3> a, myvector<3> b)
+ [ForwardDifferentiable]
+ static myvector<3> lmul(myvector<3> a, myvector<3> b)
{
return myvector<3>(a.val * b.val);
}
- static __differentiate_jvp myvector<3> lscale(float a, myvector<3> b)
+ [ForwardDifferentiable]
+ static myvector<3> lscale(float a, myvector<3> b)
{
return myvector<3>(a * b.val);
}
- static __differentiate_jvp float ldot(myvector<3> a, myvector<3> b)
+ [ForwardDifferentiable]
+ static float ldot(myvector<3> a, myvector<3> b)
{
return dot(a.val, b.val);
}
- __differentiate_jvp __init(vector<Real, 3> a)
+ [ForwardDifferentiable]
+__init(vector<Real, 3> a)
{
val = a;
}
@@ -43,27 +48,32 @@ extension myvector<3> : MyLinearArithmeticType
extension myvector<4> : MyLinearArithmeticType
{
- static __differentiate_jvp myvector<4> ladd(myvector<4> a, myvector<4> b)
+ [ForwardDifferentiable]
+ static myvector<4> ladd(myvector<4> a, myvector<4> b)
{
return myvector<4>(a.val + b.val);
}
- static __differentiate_jvp myvector<4> lmul(myvector<4> a, myvector<4> b)
+ [ForwardDifferentiable]
+ static myvector<4> lmul(myvector<4> a, myvector<4> b)
{
return myvector<4>(a.val * b.val);
}
- static __differentiate_jvp myvector<4> lscale(float a, myvector<4> b)
+ [ForwardDifferentiable]
+ static myvector<4> lscale(float a, myvector<4> b)
{
return myvector<4>(a * b.val);
}
- static __differentiate_jvp float ldot(myvector<4> a, myvector<4> b)
+ [ForwardDifferentiable]
+ static float ldot(myvector<4> a, myvector<4> b)
{
return dot(a.val, b.val);
}
- __differentiate_jvp __init(vector<Real, 4> a)
+ [ForwardDifferentiable]
+ __init(vector<Real, 4> a)
{
val = a;
}
@@ -95,12 +105,14 @@ extension myfloat3 : IDifferentiable
return myfloat3(0);
}
- static __differentiate_jvp Differential dadd(Differential a, Differential b)
+ [ForwardDifferentiable]
+ static Differential dadd(Differential a, Differential b)
{
return a + b;
}
- static __differentiate_jvp Differential dmul(Differential a, Differential b)
+ [ForwardDifferentiable]
+ static Differential dmul(Differential a, Differential b)
{
return a * b;
}
@@ -119,12 +131,14 @@ extension myfloat4 : IDifferentiable
return myfloat4(0);
}
- static __differentiate_jvp Differential dadd(Differential a, Differential b)
+ [ForwardDifferentiable]
+ static Differential dadd(Differential a, Differential b)
{
return a + b;
}
- static __differentiate_jvp Differential dmul(Differential a, Differential b)
+ [ForwardDifferentiable]
+ static Differential dmul(Differential a, Differential b)
{
return a * b;
}
@@ -135,22 +149,26 @@ typedef __DifferentialPair<myfloat3> dpfloat3;
extension float : MyLinearArithmeticType
{
- static __differentiate_jvp float ladd(float a, float b)
+ [ForwardDifferentiable]
+ static float ladd(float a, float b)
{
return a + b;
}
- static __differentiate_jvp float lmul(float a, float b)
+ [ForwardDifferentiable]
+ static float lmul(float a, float b)
{
return a * b;
}
- static __differentiate_jvp float lscale(float a, float b)
+ [ForwardDifferentiable]
+ static float lscale(float a, float b)
{
return a * b;
}
-
- static __differentiate_jvp float ldot(float a, float b)
+
+ [ForwardDifferentiable]
+ static float ldot(float a, float b)
{
return a * b;
}
@@ -159,19 +177,22 @@ extension float : MyLinearArithmeticType
typealias MyLinearArithmeticDifferentiableType = IDifferentiable & MyLinearArithmeticType;
__generic<T : MyLinearArithmeticDifferentiableType>
-__differentiate_jvp T operator +(T a, T b)
+[ForwardDifferentiable]
+T operator +(T a, T b)
{
return T.ladd(a, b);
}
__generic<T : MyLinearArithmeticDifferentiableType>
-__differentiate_jvp T operator *(T a, T b)
+[ForwardDifferentiable]
+T operator *(T a, T b)
{
return T.lmul(a, b);
}
__generic<G : MyLinearArithmeticDifferentiableType>
-__differentiate_jvp G f(G x)
+[ForwardDifferentiable]
+G f(G x)
{
G a = x + x;
G b = x * x;
diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang
index 08816c5bc..bc1f571eb 100644
--- a/tests/autodiff/getter-setter-multi.slang
+++ b/tests/autodiff/getter-setter-multi.slang
@@ -43,7 +43,8 @@ struct A : IDifferentiable
typedef __DifferentialPair<A> dpA;
-__differentiate_jvp A f(A a)
+[ForwardDifferentiable]
+A f(A a)
{
A aout;
diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang
index 2f385b87f..1d7847b41 100644
--- a/tests/autodiff/getter-setter.slang
+++ b/tests/autodiff/getter-setter.slang
@@ -41,7 +41,8 @@ struct A : IDifferentiable
typedef __DifferentialPair<A> dpA;
-__differentiate_jvp A f(A a)
+[ForwardDifferentiable]
+A f(A a)
{
A aout;
aout.y = 2 * a.x;
diff --git a/tests/autodiff/imported-custom-jvp.slang b/tests/autodiff/imported-custom-jvp.slang
index 8adcdee25..f5251740a 100644
--- a/tests/autodiff/imported-custom-jvp.slang
+++ b/tests/autodiff/imported-custom-jvp.slang
@@ -8,7 +8,8 @@ import test_intrinsics_jvp;
typedef __DifferentialPair<float> dpfloat;
typedef float.Differential dfloat;
-__differentiate_jvp float f(float x)
+[ForwardDifferentiable]
+float f(float x)
{
return pow_(x, 2.0);
}
diff --git a/tests/autodiff/inout-parameters-jvp.slang b/tests/autodiff/inout-parameters-jvp.slang
index e53e5db7c..ab4a3c790 100644
--- a/tests/autodiff/inout-parameters-jvp.slang
+++ b/tests/autodiff/inout-parameters-jvp.slang
@@ -6,14 +6,16 @@ RWStructuredBuffer<float> outputBuffer;
typedef __DifferentialPair<float> dpfloat;
-__differentiate_jvp void g(float x, float y, inout float z)
+[ForwardDifferentiable]
+void g(float x, float y, inout float z)
{
float m = x + y;
float n = x - y;
z += m * n + 2 * x * y;
}
-__differentiate_jvp void h(float x, float y, inout float z)
+[ForwardDifferentiable]
+void h(float x, float y, inout float z)
{
float m = x + y;
float n = x - y;
diff --git a/tests/autodiff/nested-jvp.slang b/tests/autodiff/nested-jvp.slang
index 96648d861..40518d44d 100644
--- a/tests/autodiff/nested-jvp.slang
+++ b/tests/autodiff/nested-jvp.slang
@@ -36,12 +36,14 @@ dpfloat max_jvp(dpfloat x, dpfloat y)
/* Fresnel Schlick example */
-__differentiate_jvp float3 fresnel(float3 f0, float3 f90, float cosTheta)
+[ForwardDifferentiable]
+float3 fresnel(float3 f0, float3 f90, float cosTheta)
{
return f0 + (f90 - f0) * pow_(max_(1 - cosTheta, 0.0), 5);
}
-__differentiate_jvp float g(float a, float b, float c)
+[ForwardDifferentiable]
+float g(float a, float b, float c)
{
return fresnel(float3(a), float3(b), 2 * c * c).y;
}
diff --git a/tests/autodiff/out-parameters-jvp.slang b/tests/autodiff/out-parameters-jvp.slang
index 9a311ed31..072c5158b 100644
--- a/tests/autodiff/out-parameters-jvp.slang
+++ b/tests/autodiff/out-parameters-jvp.slang
@@ -6,7 +6,8 @@ RWStructuredBuffer<float> outputBuffer;
typedef __DifferentialPair<float> dpfloat;
-__differentiate_jvp void h(float x, float y, out float result)
+[ForwardDifferentiable]
+void h(float x, float y, out float result)
{
float m = x + y;
float n = x - y;
diff --git a/tests/autodiff/overloads-jvp.slang b/tests/autodiff/overloads-jvp.slang
index 95b9cadd3..2577009c3 100644
--- a/tests/autodiff/overloads-jvp.slang
+++ b/tests/autodiff/overloads-jvp.slang
@@ -7,17 +7,20 @@ RWStructuredBuffer<float> outputBuffer;
typedef __DifferentialPair<float> dpfloat;
typedef __DifferentialPair<float3> dpfloat3;
-__differentiate_jvp float f(float a)
+[ForwardDifferentiable]
+float f(float a)
{
return a * a + a;
}
-__differentiate_jvp float f(float3 a)
+[ForwardDifferentiable]
+float f(float3 a)
{
return a.x * a.y + a.z;
}
-__differentiate_jvp float g(float a)
+[ForwardDifferentiable]
+float g(float a)
{
// df((2.0, 4.0, 6.0), (1.0, 2.0, 3.0))
// 2.0 * 2.0 + 4.0 * 1.0 + 3.0 = 11.0
diff --git a/tests/autodiff/vector-arithmetic-jvp.slang b/tests/autodiff/vector-arithmetic-jvp.slang
index b79b3e764..cf0eb6170 100644
--- a/tests/autodiff/vector-arithmetic-jvp.slang
+++ b/tests/autodiff/vector-arithmetic-jvp.slang
@@ -8,26 +8,30 @@ typedef __DifferentialPair<float2> dpfloat2;
typedef __DifferentialPair<float3> dpfloat3;
typedef __DifferentialPair<float4> dpfloat4;
-__differentiate_jvp float3 f(float3 x)
+[ForwardDifferentiable]
+float3 f(float3 x)
{
return x;
}
-__differentiate_jvp float3 g(float3 x, float3 y)
+[ForwardDifferentiable]
+float3 g(float3 x, float3 y)
{
float3 a = x + y;
float3 b = x - y;
return a * b + 2 * x * y;
}
-__differentiate_jvp float2 h(float2 x, float2 y)
+[ForwardDifferentiable]
+float2 h(float2 x, float2 y)
{
float2 a = x + y;
float2 b = x - y;
return a * b + 2 * x * y;
}
-__differentiate_jvp float4 j(float4 x, float4 y)
+[ForwardDifferentiable]
+float4 j(float4 x, float4 y)
{
float4 a = x + y;
float4 b = x - y;
diff --git a/tests/autodiff/vector-swizzle-jvp.slang b/tests/autodiff/vector-swizzle-jvp.slang
index fc726d067..f7a045b25 100644
--- a/tests/autodiff/vector-swizzle-jvp.slang
+++ b/tests/autodiff/vector-swizzle-jvp.slang
@@ -8,12 +8,14 @@ typedef __DifferentialPair<float2> dpfloat2;
typedef __DifferentialPair<float3> dpfloat3;
typedef __DifferentialPair<float4> dpfloat4;
-__differentiate_jvp float2 f(float3 x)
+[ForwardDifferentiable]
+float2 f(float3 x)
{
return x.zy;
}
-__differentiate_jvp float2 g(float3 x, float4 y)
+[ForwardDifferentiable]
+float2 g(float3 x, float4 y)
{
float3 a = x + y.zyx;
float2 b = x.zx - y.yw;