summaryrefslogtreecommitdiff
path: root/tests/autodiff/generic-impl-jvp.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-10-24 22:19:38 -0700
committerGitHub <noreply@github.com>2022-10-24 22:19:38 -0700
commit41cb7c13e37ec32ffb6557d21da079d77151e136 (patch)
tree38d2c44938e2679c42c5c0e73f5411e59015df93 /tests/autodiff/generic-impl-jvp.slang
parent1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (diff)
Rework differentiation of member access through `[DerivativeMember(DiffType.field)]` (#2460)
* wip: remove auto-diff for member access, add diff through property accessors. * Fix getter-setter test. * Fix getter-setter-multi test. * Fix nested-jvp test. * Use [DerivativeMember] attribute to differentiate through member access. * Clean up. * More cleanup. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests/autodiff/generic-impl-jvp.slang')
-rw-r--r--tests/autodiff/generic-impl-jvp.slang22
1 files changed, 4 insertions, 18 deletions
diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang
index 5bf3a25c3..fe4ffc426 100644
--- a/tests/autodiff/generic-impl-jvp.slang
+++ b/tests/autodiff/generic-impl-jvp.slang
@@ -17,15 +17,11 @@ struct dvector
__generic<T : IDFloat, let N : int>
struct myvector : IDifferentiable
{
- T values[N];
typedef dvector<T.Differential, N> Differential;
- [__unsafeForceInlineEarly]
- static Ptr<T.Differential[N]> __getDifferentialFor_values(inout Differential d)
- {
- return &(d.values);
- }
-
+ [DerivativeMember(Differential.values)]
+ T values[N];
+
__init(T c)
{
for (int i = 0; i < N; i++)
@@ -158,19 +154,9 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable
{
typedef lineardvector<N> Differential;
+ [DerivativeMember(Differential.val)]
myvector<Real, N> val;
- [__unsafeForceInlineEarly]
- static Ptr<myvector<Real, N>.Differential> __getDifferentialFor_val(inout Differential dvec)
- {
- return &(dvec.val);
- }
-
- static void __setDifferentialForVal(lineardvector<N> dvec, myvector<Real, N>.Differential v)
- {
- dvec.val = v;
- }
-
static __differentiate_jvp linearvector<N> ladd(linearvector<N> a, linearvector<N> b)
{
return linearvector<N>(a.val + b.val);