diff options
22 files changed, 142 insertions, 71 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 38d7270e4..ca7c1d3bd 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -2,7 +2,8 @@ /// Modifer to mark a function for forward-mode differentiation. /// i.e. the compiler will automatically generate a new function /// that computes the jacobian-vector product of the original. -syntax __differentiate_jvp : JVPDerivativeModifier; +__attributeTarget(FunctionDeclBase) +attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; // Custom JVP Function reference __attributeTarget(FunctionDeclBase) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 6220fcb95..ee350be25 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -30,7 +30,7 @@ class ExportedModifier : public Modifier { SLANG_AST_CLASS(ExportedModifier)}; class ConstExprModifier : public Modifier { SLANG_AST_CLASS(ConstExprModifier)}; class GloballyCoherentModifier : public Modifier { SLANG_AST_CLASS(GloballyCoherentModifier)}; class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)}; -class JVPDerivativeModifier : public Modifier { SLANG_AST_CLASS(JVPDerivativeModifier)}; + // Marks that the definition of a decl is not yet synthesized. class ToBeSynthesizedModifier : public Modifier {SLANG_AST_CLASS(ToBeSynthesizedModifier)}; @@ -1015,6 +1015,12 @@ class RequiresNVAPIAttribute : public Attribute SLANG_AST_CLASS(RequiresNVAPIAttribute) }; + /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated. +class ForwardDifferentiableAttribute : public Attribute +{ + SLANG_AST_CLASS(ForwardDifferentiableAttribute) +}; + /// The `[__custom_jvp(function)]` attribute specifies a custom function that should /// be used as the derivative for the decorated function. class CustomJVPAttribute : public Attribute diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index f28f46deb..457ae229b 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -5230,7 +5230,7 @@ namespace Slang void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) { - if (decl->findModifier<JVPDerivativeModifier>()) + if (decl->findModifier<ForwardDifferentiableAttribute>()) { this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 0975de985..c7d69262d 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -942,7 +942,7 @@ namespace Slang // Differentiable type checking. // TODO: This can be super slow. if (this->m_parentFunc && - this->m_parentFunc->findModifier<JVPDerivativeModifier>()) + this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>()) { maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); } diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index 6a8f802f7..6bc4b9d36 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -324,7 +324,7 @@ namespace Slang // Differentiable type checking. // TODO: This can be super slow. Switch to caching the result asap. if (this->m_parentFunc && - this->m_parentFunc->findModifier<JVPDerivativeModifier>()) + this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>()) { auto diffTypeContext = this->getShared()->innermostDiffTypeContext(); if (auto subtypeWitness = as<SubtypeWitness>( diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 3766a1a5e..386cf2a21 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7805,7 +7805,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> addNameHint(context, irFunc, decl); addLinkageDecoration(context, irFunc, decl); - if (decl->findModifier<JVPDerivativeModifier>()) + if (decl->findModifier<ForwardDifferentiableAttribute>()) { getBuilder()->addJVPDerivativeMarkerDecoration(irFunc); } diff --git a/tests/autodiff/arithmetic-jvp.slang b/tests/autodiff/arithmetic-jvp.slang index 0c7dd039d..3b06393d3 100644 --- a/tests/autodiff/arithmetic-jvp.slang +++ b/tests/autodiff/arithmetic-jvp.slang @@ -7,7 +7,8 @@ RWStructuredBuffer<float> outputBuffer; typedef __DifferentialPair<float> dpfloat; typedef float.Differential dfloat; -__differentiate_jvp float f(float x) +[ForwardDifferentiable] +float f(float x) { return x; } @@ -23,14 +24,16 @@ float g(float x) return x + x; } -__differentiate_jvp float h(float x, float y) +[ForwardDifferentiable] +float h(float x, float y) { float m = x + y; float n = x - y; return m * n + 2 * x * y; } -__differentiate_jvp float j(float x, float y) +[ForwardDifferentiable] +float j(float x, float y) { float m = x / y; return m * y; diff --git a/tests/autodiff/auto-differential-type.slang b/tests/autodiff/auto-differential-type.slang index b551db4ab..a687c823e 100644 --- a/tests/autodiff/auto-differential-type.slang +++ b/tests/autodiff/auto-differential-type.slang @@ -35,7 +35,8 @@ struct A : IDifferentiable typedef __DifferentialPair<A> dpA; -__differentiate_jvp A f(A a) +[ForwardDifferentiable] +A f(A a) { A aout; aout.y = 2 * a.x; diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang index 02f6541f5..770a18b1e 100644 --- a/tests/autodiff/custom-intrinsic.slang +++ b/tests/autodiff/custom-intrinsic.slang @@ -87,12 +87,14 @@ namespace myintrinsiclib } }; -__differentiate_jvp float f(float x) +[ForwardDifferentiable] +float f(float x) { return myintrinsiclib.exp(x); } -__differentiate_jvp float g(float x) +[ForwardDifferentiable] +float g(float x) { float s; float t; diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang index 3ecd636e9..76c15d5a1 100644 --- a/tests/autodiff/differential-method-synthesis.slang +++ b/tests/autodiff/differential-method-synthesis.slang @@ -24,7 +24,8 @@ A nonDiff(A a) return a; } -__differentiate_jvp A f(A a) +[ForwardDifferentiable] +A f(A a) { A aout; aout.y = 2 * a.b.x; diff --git a/tests/autodiff/dstdlib.slang b/tests/autodiff/dstdlib.slang index 614de54f6..05d915c29 100644 --- a/tests/autodiff/dstdlib.slang +++ b/tests/autodiff/dstdlib.slang @@ -6,17 +6,20 @@ RWStructuredBuffer<float> outputBuffer; typedef __DifferentialPair<float> dpfloat; -__differentiate_jvp float f(float x) +[ForwardDifferentiable] +float f(float x) { return dstd.exp(x); } -__differentiate_jvp float g(float x) +[ForwardDifferentiable] +float g(float x) { return dstd.sin(x); } -__differentiate_jvp float h(float x) +[ForwardDifferentiable] +float h(float x) { return dstd.cos(x); } diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index e14f851ac..3ebbff996 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -68,8 +68,9 @@ struct myvector : IDifferentiable } }; +[ForwardDifferentiable] __generic<T : IDFloat, let N : int> -__differentiate_jvp myvector<T, N> operator +(myvector<T, N> a, myvector<T, N> b) +myvector<T, N> operator +(myvector<T, N> a, myvector<T, N> b) { myvector<T, N> output; for (int i = 0; i < N; i++) @@ -79,8 +80,9 @@ __differentiate_jvp myvector<T, N> operator +(myvector<T, N> a, myvector<T, N> b return output; } +[ForwardDifferentiable] __generic<T : IDFloat, let N : int> - __differentiate_jvp myvector<T, N> operator *(myvector<T, N> a, myvector<T, N> b) +myvector<T, N> operator *(myvector<T, N> a, myvector<T, N> b) { myvector<T, N> output; for (int i = 0; i < N; i++) @@ -90,8 +92,9 @@ __generic<T : IDFloat, let N : int> return output; } +[ForwardDifferentiable] __generic<T : IDFloat, let N : int> - __differentiate_jvp myvector<T, N> operator *(T a, myvector<T, N> b) +myvector<T, N> operator *(T a, myvector<T, N> b) { myvector<T, N> output; for (int i = 0; i < N; i++) @@ -157,22 +160,26 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable [DerivativeMember(Differential.val)] myvector<Real, N> val; - static __differentiate_jvp linearvector<N> ladd(linearvector<N> a, linearvector<N> b) + [ForwardDifferentiable] + static linearvector<N> ladd(linearvector<N> a, linearvector<N> b) { return linearvector<N>(a.val + b.val); } - static __differentiate_jvp linearvector<N> lmul(linearvector<N> a, linearvector<N> b) + [ForwardDifferentiable] + static linearvector<N> lmul(linearvector<N> a, linearvector<N> b) { return linearvector<N>(a.val * b.val); } - static __differentiate_jvp linearvector<N> lscale(float a, linearvector<N> b) + [ForwardDifferentiable] + static linearvector<N> lscale(float a, linearvector<N> b) { return linearvector<N>(a * b.val); } - static __differentiate_jvp float ldot(linearvector<N> a, linearvector<N> b) + [ForwardDifferentiable] + static float ldot(linearvector<N> a, linearvector<N> b) { return dot(a.val, b.val); } @@ -194,7 +201,8 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable return { myvector<Real, N>.dmul(a.val, b.val) }; } - __differentiate_jvp __init(vector<Real, N> a) + [ForwardDifferentiable] + __init(vector<Real, N> a) { for (int i = 0; i < N; i++) { @@ -202,7 +210,8 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable } } - __differentiate_jvp __init(myvector<Real, N> a) + [ForwardDifferentiable] + __init(myvector<Real, N> a) { val = a; } @@ -229,22 +238,26 @@ typedef __DifferentialPair<myfloat3> dpfloat3; extension float : MyLinearArithmeticType { - static __differentiate_jvp float ladd(float a, float b) + [ForwardDifferentiable] + static float ladd(float a, float b) { return a + b; } - static __differentiate_jvp float lmul(float a, float b) + [ForwardDifferentiable] + static float lmul(float a, float b) { return a * b; } - static __differentiate_jvp float lscale(float a, float b) + [ForwardDifferentiable] + static float lscale(float a, float b) { return a * b; } - static __differentiate_jvp float ldot(float a, float b) + [ForwardDifferentiable] + static float ldot(float a, float b) { return a * b; } @@ -253,19 +266,22 @@ extension float : MyLinearArithmeticType typealias MyLinearArithmeticDifferentiableType = IDifferentiable & MyLinearArithmeticType; __generic<T : MyLinearArithmeticDifferentiableType> -__differentiate_jvp T operator +(T a, T b) +[ForwardDifferentiable] +T operator +(T a, T b) { return T.ladd(a, b); } __generic<T : MyLinearArithmeticDifferentiableType> -__differentiate_jvp T operator *(T a, T b) +[ForwardDifferentiable] +T operator *(T a, T b) { return T.lmul(a, b); } __generic<G : MyLinearArithmeticDifferentiableType> -__differentiate_jvp G f(G x) +[ForwardDifferentiable] +G f(G x) { G a = x + x; G b = x * x; diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 365be45aa..c22c228a6 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -14,27 +14,32 @@ struct myvector extension myvector<3> : MyLinearArithmeticType { - static __differentiate_jvp myvector<3> ladd(myvector<3> a, myvector<3> b) + [ForwardDifferentiable] + static myvector<3> ladd(myvector<3> a, myvector<3> b) { return myvector<3>(a.val + b.val); } - static __differentiate_jvp myvector<3> lmul(myvector<3> a, myvector<3> b) + [ForwardDifferentiable] + static myvector<3> lmul(myvector<3> a, myvector<3> b) { return myvector<3>(a.val * b.val); } - static __differentiate_jvp myvector<3> lscale(float a, myvector<3> b) + [ForwardDifferentiable] + static myvector<3> lscale(float a, myvector<3> b) { return myvector<3>(a * b.val); } - static __differentiate_jvp float ldot(myvector<3> a, myvector<3> b) + [ForwardDifferentiable] + static float ldot(myvector<3> a, myvector<3> b) { return dot(a.val, b.val); } - __differentiate_jvp __init(vector<Real, 3> a) + [ForwardDifferentiable] +__init(vector<Real, 3> a) { val = a; } @@ -43,27 +48,32 @@ extension myvector<3> : MyLinearArithmeticType extension myvector<4> : MyLinearArithmeticType { - static __differentiate_jvp myvector<4> ladd(myvector<4> a, myvector<4> b) + [ForwardDifferentiable] + static myvector<4> ladd(myvector<4> a, myvector<4> b) { return myvector<4>(a.val + b.val); } - static __differentiate_jvp myvector<4> lmul(myvector<4> a, myvector<4> b) + [ForwardDifferentiable] + static myvector<4> lmul(myvector<4> a, myvector<4> b) { return myvector<4>(a.val * b.val); } - static __differentiate_jvp myvector<4> lscale(float a, myvector<4> b) + [ForwardDifferentiable] + static myvector<4> lscale(float a, myvector<4> b) { return myvector<4>(a * b.val); } - static __differentiate_jvp float ldot(myvector<4> a, myvector<4> b) + [ForwardDifferentiable] + static float ldot(myvector<4> a, myvector<4> b) { return dot(a.val, b.val); } - __differentiate_jvp __init(vector<Real, 4> a) + [ForwardDifferentiable] + __init(vector<Real, 4> a) { val = a; } @@ -95,12 +105,14 @@ extension myfloat3 : IDifferentiable return myfloat3(0); } - static __differentiate_jvp Differential dadd(Differential a, Differential b) + [ForwardDifferentiable] + static Differential dadd(Differential a, Differential b) { return a + b; } - static __differentiate_jvp Differential dmul(Differential a, Differential b) + [ForwardDifferentiable] + static Differential dmul(Differential a, Differential b) { return a * b; } @@ -119,12 +131,14 @@ extension myfloat4 : IDifferentiable return myfloat4(0); } - static __differentiate_jvp Differential dadd(Differential a, Differential b) + [ForwardDifferentiable] + static Differential dadd(Differential a, Differential b) { return a + b; } - static __differentiate_jvp Differential dmul(Differential a, Differential b) + [ForwardDifferentiable] + static Differential dmul(Differential a, Differential b) { return a * b; } @@ -135,22 +149,26 @@ typedef __DifferentialPair<myfloat3> dpfloat3; extension float : MyLinearArithmeticType { - static __differentiate_jvp float ladd(float a, float b) + [ForwardDifferentiable] + static float ladd(float a, float b) { return a + b; } - static __differentiate_jvp float lmul(float a, float b) + [ForwardDifferentiable] + static float lmul(float a, float b) { return a * b; } - static __differentiate_jvp float lscale(float a, float b) + [ForwardDifferentiable] + static float lscale(float a, float b) { return a * b; } - - static __differentiate_jvp float ldot(float a, float b) + + [ForwardDifferentiable] + static float ldot(float a, float b) { return a * b; } @@ -159,19 +177,22 @@ extension float : MyLinearArithmeticType typealias MyLinearArithmeticDifferentiableType = IDifferentiable & MyLinearArithmeticType; __generic<T : MyLinearArithmeticDifferentiableType> -__differentiate_jvp T operator +(T a, T b) +[ForwardDifferentiable] +T operator +(T a, T b) { return T.ladd(a, b); } __generic<T : MyLinearArithmeticDifferentiableType> -__differentiate_jvp T operator *(T a, T b) +[ForwardDifferentiable] +T operator *(T a, T b) { return T.lmul(a, b); } __generic<G : MyLinearArithmeticDifferentiableType> -__differentiate_jvp G f(G x) +[ForwardDifferentiable] +G f(G x) { G a = x + x; G b = x * x; diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang index 08816c5bc..bc1f571eb 100644 --- a/tests/autodiff/getter-setter-multi.slang +++ b/tests/autodiff/getter-setter-multi.slang @@ -43,7 +43,8 @@ struct A : IDifferentiable typedef __DifferentialPair<A> dpA; -__differentiate_jvp A f(A a) +[ForwardDifferentiable] +A f(A a) { A aout; diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang index 2f385b87f..1d7847b41 100644 --- a/tests/autodiff/getter-setter.slang +++ b/tests/autodiff/getter-setter.slang @@ -41,7 +41,8 @@ struct A : IDifferentiable typedef __DifferentialPair<A> dpA; -__differentiate_jvp A f(A a) +[ForwardDifferentiable] +A f(A a) { A aout; aout.y = 2 * a.x; diff --git a/tests/autodiff/imported-custom-jvp.slang b/tests/autodiff/imported-custom-jvp.slang index 8adcdee25..f5251740a 100644 --- a/tests/autodiff/imported-custom-jvp.slang +++ b/tests/autodiff/imported-custom-jvp.slang @@ -8,7 +8,8 @@ import test_intrinsics_jvp; typedef __DifferentialPair<float> dpfloat; typedef float.Differential dfloat; -__differentiate_jvp float f(float x) +[ForwardDifferentiable] +float f(float x) { return pow_(x, 2.0); } diff --git a/tests/autodiff/inout-parameters-jvp.slang b/tests/autodiff/inout-parameters-jvp.slang index e53e5db7c..ab4a3c790 100644 --- a/tests/autodiff/inout-parameters-jvp.slang +++ b/tests/autodiff/inout-parameters-jvp.slang @@ -6,14 +6,16 @@ RWStructuredBuffer<float> outputBuffer; typedef __DifferentialPair<float> dpfloat; -__differentiate_jvp void g(float x, float y, inout float z) +[ForwardDifferentiable] +void g(float x, float y, inout float z) { float m = x + y; float n = x - y; z += m * n + 2 * x * y; } -__differentiate_jvp void h(float x, float y, inout float z) +[ForwardDifferentiable] +void h(float x, float y, inout float z) { float m = x + y; float n = x - y; diff --git a/tests/autodiff/nested-jvp.slang b/tests/autodiff/nested-jvp.slang index 96648d861..40518d44d 100644 --- a/tests/autodiff/nested-jvp.slang +++ b/tests/autodiff/nested-jvp.slang @@ -36,12 +36,14 @@ dpfloat max_jvp(dpfloat x, dpfloat y) /* Fresnel Schlick example */ -__differentiate_jvp float3 fresnel(float3 f0, float3 f90, float cosTheta) +[ForwardDifferentiable] +float3 fresnel(float3 f0, float3 f90, float cosTheta) { return f0 + (f90 - f0) * pow_(max_(1 - cosTheta, 0.0), 5); } -__differentiate_jvp float g(float a, float b, float c) +[ForwardDifferentiable] +float g(float a, float b, float c) { return fresnel(float3(a), float3(b), 2 * c * c).y; } diff --git a/tests/autodiff/out-parameters-jvp.slang b/tests/autodiff/out-parameters-jvp.slang index 9a311ed31..072c5158b 100644 --- a/tests/autodiff/out-parameters-jvp.slang +++ b/tests/autodiff/out-parameters-jvp.slang @@ -6,7 +6,8 @@ RWStructuredBuffer<float> outputBuffer; typedef __DifferentialPair<float> dpfloat; -__differentiate_jvp void h(float x, float y, out float result) +[ForwardDifferentiable] +void h(float x, float y, out float result) { float m = x + y; float n = x - y; diff --git a/tests/autodiff/overloads-jvp.slang b/tests/autodiff/overloads-jvp.slang index 95b9cadd3..2577009c3 100644 --- a/tests/autodiff/overloads-jvp.slang +++ b/tests/autodiff/overloads-jvp.slang @@ -7,17 +7,20 @@ RWStructuredBuffer<float> outputBuffer; typedef __DifferentialPair<float> dpfloat; typedef __DifferentialPair<float3> dpfloat3; -__differentiate_jvp float f(float a) +[ForwardDifferentiable] +float f(float a) { return a * a + a; } -__differentiate_jvp float f(float3 a) +[ForwardDifferentiable] +float f(float3 a) { return a.x * a.y + a.z; } -__differentiate_jvp float g(float a) +[ForwardDifferentiable] +float g(float a) { // df((2.0, 4.0, 6.0), (1.0, 2.0, 3.0)) // 2.0 * 2.0 + 4.0 * 1.0 + 3.0 = 11.0 diff --git a/tests/autodiff/vector-arithmetic-jvp.slang b/tests/autodiff/vector-arithmetic-jvp.slang index b79b3e764..cf0eb6170 100644 --- a/tests/autodiff/vector-arithmetic-jvp.slang +++ b/tests/autodiff/vector-arithmetic-jvp.slang @@ -8,26 +8,30 @@ typedef __DifferentialPair<float2> dpfloat2; typedef __DifferentialPair<float3> dpfloat3; typedef __DifferentialPair<float4> dpfloat4; -__differentiate_jvp float3 f(float3 x) +[ForwardDifferentiable] +float3 f(float3 x) { return x; } -__differentiate_jvp float3 g(float3 x, float3 y) +[ForwardDifferentiable] +float3 g(float3 x, float3 y) { float3 a = x + y; float3 b = x - y; return a * b + 2 * x * y; } -__differentiate_jvp float2 h(float2 x, float2 y) +[ForwardDifferentiable] +float2 h(float2 x, float2 y) { float2 a = x + y; float2 b = x - y; return a * b + 2 * x * y; } -__differentiate_jvp float4 j(float4 x, float4 y) +[ForwardDifferentiable] +float4 j(float4 x, float4 y) { float4 a = x + y; float4 b = x - y; diff --git a/tests/autodiff/vector-swizzle-jvp.slang b/tests/autodiff/vector-swizzle-jvp.slang index fc726d067..f7a045b25 100644 --- a/tests/autodiff/vector-swizzle-jvp.slang +++ b/tests/autodiff/vector-swizzle-jvp.slang @@ -8,12 +8,14 @@ typedef __DifferentialPair<float2> dpfloat2; typedef __DifferentialPair<float3> dpfloat3; typedef __DifferentialPair<float4> dpfloat4; -__differentiate_jvp float2 f(float3 x) +[ForwardDifferentiable] +float2 f(float3 x) { return x.zy; } -__differentiate_jvp float2 g(float3 x, float4 y) +[ForwardDifferentiable] +float2 g(float3 x, float4 y) { float3 a = x + y.zyx; float2 b = x.zx - y.yw; |
