diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-24 22:19:38 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-24 22:19:38 -0700 |
| commit | 41cb7c13e37ec32ffb6557d21da079d77151e136 (patch) | |
| tree | 38d2c44938e2679c42c5c0e73f5411e59015df93 /tests/autodiff/generic-impl-jvp.slang | |
| parent | 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (diff) | |
Rework differentiation of member access through `[DerivativeMember(DiffType.field)]` (#2460)
* wip: remove auto-diff for member access, add diff through property accessors.
* Fix getter-setter test.
* Fix getter-setter-multi test.
* Fix nested-jvp test.
* Use [DerivativeMember] attribute to differentiate through member access.
* Clean up.
* More cleanup.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests/autodiff/generic-impl-jvp.slang')
| -rw-r--r-- | tests/autodiff/generic-impl-jvp.slang | 22 |
1 files changed, 4 insertions, 18 deletions
diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index 5bf3a25c3..fe4ffc426 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -17,15 +17,11 @@ struct dvector __generic<T : IDFloat, let N : int> struct myvector : IDifferentiable { - T values[N]; typedef dvector<T.Differential, N> Differential; - [__unsafeForceInlineEarly] - static Ptr<T.Differential[N]> __getDifferentialFor_values(inout Differential d) - { - return &(d.values); - } - + [DerivativeMember(Differential.values)] + T values[N]; + __init(T c) { for (int i = 0; i < N; i++) @@ -158,19 +154,9 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable { typedef lineardvector<N> Differential; + [DerivativeMember(Differential.val)] myvector<Real, N> val; - [__unsafeForceInlineEarly] - static Ptr<myvector<Real, N>.Differential> __getDifferentialFor_val(inout Differential dvec) - { - return &(dvec.val); - } - - static void __setDifferentialForVal(lineardvector<N> dvec, myvector<Real, N>.Differential v) - { - dvec.val = v; - } - static __differentiate_jvp linearvector<N> ladd(linearvector<N> a, linearvector<N> b) { return linearvector<N>(a.val + b.val); |
