diff options
Diffstat (limited to 'tests/autodiff')
| -rw-r--r-- | tests/autodiff/generic-impl-jvp.slang | 22 | ||||
| -rw-r--r-- | tests/autodiff/generic-jvp.slang | 14 | ||||
| -rw-r--r-- | tests/autodiff/getter-setter-multi.slang | 16 | ||||
| -rw-r--r-- | tests/autodiff/getter-setter.slang | 7 |
4 files changed, 12 insertions, 47 deletions
diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index 5bf3a25c3..fe4ffc426 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -17,15 +17,11 @@ struct dvector __generic<T : IDFloat, let N : int> struct myvector : IDifferentiable { - T values[N]; typedef dvector<T.Differential, N> Differential; - [__unsafeForceInlineEarly] - static Ptr<T.Differential[N]> __getDifferentialFor_values(inout Differential d) - { - return &(d.values); - } - + [DerivativeMember(Differential.values)] + T values[N]; + __init(T c) { for (int i = 0; i < N; i++) @@ -158,19 +154,9 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable { typedef lineardvector<N> Differential; + [DerivativeMember(Differential.val)] myvector<Real, N> val; - [__unsafeForceInlineEarly] - static Ptr<myvector<Real, N>.Differential> __getDifferentialFor_val(inout Differential dvec) - { - return &(dvec.val); - } - - static void __setDifferentialForVal(lineardvector<N> dvec, myvector<Real, N>.Differential v) - { - dvec.val = v; - } - static __differentiate_jvp linearvector<N> ladd(linearvector<N> a, linearvector<N> b) { return linearvector<N>(a.val + b.val); diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 54a99cae9..bcd5e764e 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -87,11 +87,8 @@ extension myfloat3 : IDifferentiable { typedef myfloat3 Differential; - [__unsafeForceInlineEarly] - static Ptr<float3> __getDifferentialFor_val(inout Differential dx) - { - return &(dx.val); - } + [DerivativeMember(Differential.val)] + extern vector<Real, 3> val; static Differential zero() { @@ -114,11 +111,8 @@ extension myfloat4 : IDifferentiable { typedef myfloat4 Differential; - [__unsafeForceInlineEarly] - static Ptr<float4> __getDifferentialFor_val(inout Differential dx) - { - return &(dx.val); - } + [DerivativeMember(Differential.val)] + extern vector<Real, 4> val; static Differential zero() { diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang index 61cb96a07..c19a3f6bb 100644 --- a/tests/autodiff/getter-setter-multi.slang +++ b/tests/autodiff/getter-setter-multi.slang @@ -13,23 +13,13 @@ struct B struct A : IDifferentiable { typedef B Differential; - + + [DerivativeMember(B.z)] float3 x; + [DerivativeMember(B.k)] float y[10]; [__unsafeForceInlineEarly] - static Ptr<float3.Differential> __getDifferentialFor_x(inout Differential b) - { - return &(b.z); - } - - [__unsafeForceInlineEarly] - static Ptr<float.Differential[10]> __getDifferentialFor_y(inout Differential b) - { - return &(b.k); - } - - [__unsafeForceInlineEarly] static Differential zero() { B b = {0.0}; diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang index 6b280433b..0e8cac13b 100644 --- a/tests/autodiff/getter-setter.slang +++ b/tests/autodiff/getter-setter.slang @@ -13,16 +13,11 @@ struct A : IDifferentiable { typedef B Differential; + [DerivativeMember(B.z)] float x; float y; [__unsafeForceInlineEarly] - static Ptr<float.Differential> __getDifferentialFor_x(inout Differential b) - { - return &(b.z); - } - - [__unsafeForceInlineEarly] static Differential zero() { B b = {0.0}; |
