diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-24 22:19:38 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-24 22:19:38 -0700 |
| commit | 41cb7c13e37ec32ffb6557d21da079d77151e136 (patch) | |
| tree | 38d2c44938e2679c42c5c0e73f5411e59015df93 /tests/autodiff | |
| parent | 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (diff) | |
Rework differentiation of member access through `[DerivativeMember(DiffType.field)]` (#2460)
* wip: remove auto-diff for member access, add diff through property accessors.
* Fix getter-setter test.
* Fix getter-setter-multi test.
* Fix nested-jvp test.
* Use [DerivativeMember] attribute to differentiate through member access.
* Clean up.
* More cleanup.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests/autodiff')
| -rw-r--r-- | tests/autodiff/generic-impl-jvp.slang | 22 | ||||
| -rw-r--r-- | tests/autodiff/generic-jvp.slang | 14 | ||||
| -rw-r--r-- | tests/autodiff/getter-setter-multi.slang | 16 | ||||
| -rw-r--r-- | tests/autodiff/getter-setter.slang | 7 |
4 files changed, 12 insertions, 47 deletions
diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index 5bf3a25c3..fe4ffc426 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -17,15 +17,11 @@ struct dvector __generic<T : IDFloat, let N : int> struct myvector : IDifferentiable { - T values[N]; typedef dvector<T.Differential, N> Differential; - [__unsafeForceInlineEarly] - static Ptr<T.Differential[N]> __getDifferentialFor_values(inout Differential d) - { - return &(d.values); - } - + [DerivativeMember(Differential.values)] + T values[N]; + __init(T c) { for (int i = 0; i < N; i++) @@ -158,19 +154,9 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable { typedef lineardvector<N> Differential; + [DerivativeMember(Differential.val)] myvector<Real, N> val; - [__unsafeForceInlineEarly] - static Ptr<myvector<Real, N>.Differential> __getDifferentialFor_val(inout Differential dvec) - { - return &(dvec.val); - } - - static void __setDifferentialForVal(lineardvector<N> dvec, myvector<Real, N>.Differential v) - { - dvec.val = v; - } - static __differentiate_jvp linearvector<N> ladd(linearvector<N> a, linearvector<N> b) { return linearvector<N>(a.val + b.val); diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 54a99cae9..bcd5e764e 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -87,11 +87,8 @@ extension myfloat3 : IDifferentiable { typedef myfloat3 Differential; - [__unsafeForceInlineEarly] - static Ptr<float3> __getDifferentialFor_val(inout Differential dx) - { - return &(dx.val); - } + [DerivativeMember(Differential.val)] + extern vector<Real, 3> val; static Differential zero() { @@ -114,11 +111,8 @@ extension myfloat4 : IDifferentiable { typedef myfloat4 Differential; - [__unsafeForceInlineEarly] - static Ptr<float4> __getDifferentialFor_val(inout Differential dx) - { - return &(dx.val); - } + [DerivativeMember(Differential.val)] + extern vector<Real, 4> val; static Differential zero() { diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang index 61cb96a07..c19a3f6bb 100644 --- a/tests/autodiff/getter-setter-multi.slang +++ b/tests/autodiff/getter-setter-multi.slang @@ -13,23 +13,13 @@ struct B struct A : IDifferentiable { typedef B Differential; - + + [DerivativeMember(B.z)] float3 x; + [DerivativeMember(B.k)] float y[10]; [__unsafeForceInlineEarly] - static Ptr<float3.Differential> __getDifferentialFor_x(inout Differential b) - { - return &(b.z); - } - - [__unsafeForceInlineEarly] - static Ptr<float.Differential[10]> __getDifferentialFor_y(inout Differential b) - { - return &(b.k); - } - - [__unsafeForceInlineEarly] static Differential zero() { B b = {0.0}; diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang index 6b280433b..0e8cac13b 100644 --- a/tests/autodiff/getter-setter.slang +++ b/tests/autodiff/getter-setter.slang @@ -13,16 +13,11 @@ struct A : IDifferentiable { typedef B Differential; + [DerivativeMember(B.z)] float x; float y; [__unsafeForceInlineEarly] - static Ptr<float.Differential> __getDifferentialFor_x(inout Differential b) - { - return &(b.z); - } - - [__unsafeForceInlineEarly] static Differential zero() { B b = {0.0}; |
