From 41cb7c13e37ec32ffb6557d21da079d77151e136 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 24 Oct 2022 22:19:38 -0700 Subject: Rework differentiation of member access through `[DerivativeMember(DiffType.field)]` (#2460) * wip: remove auto-diff for member access, add diff through property accessors. * Fix getter-setter test. * Fix getter-setter-multi test. * Fix nested-jvp test. * Use [DerivativeMember] attribute to differentiate through member access. * Clean up. * More cleanup. Co-authored-by: Yong He --- tests/autodiff/generic-impl-jvp.slang | 22 ++++------------------ tests/autodiff/generic-jvp.slang | 14 ++++---------- tests/autodiff/getter-setter-multi.slang | 16 +++------------- tests/autodiff/getter-setter.slang | 7 +------ 4 files changed, 12 insertions(+), 47 deletions(-) (limited to 'tests/autodiff') diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index 5bf3a25c3..fe4ffc426 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -17,15 +17,11 @@ struct dvector __generic struct myvector : IDifferentiable { - T values[N]; typedef dvector Differential; - [__unsafeForceInlineEarly] - static Ptr __getDifferentialFor_values(inout Differential d) - { - return &(d.values); - } - + [DerivativeMember(Differential.values)] + T values[N]; + __init(T c) { for (int i = 0; i < N; i++) @@ -158,19 +154,9 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable { typedef lineardvector Differential; + [DerivativeMember(Differential.val)] myvector val; - [__unsafeForceInlineEarly] - static Ptr.Differential> __getDifferentialFor_val(inout Differential dvec) - { - return &(dvec.val); - } - - static void __setDifferentialForVal(lineardvector dvec, myvector.Differential v) - { - dvec.val = v; - } - static __differentiate_jvp linearvector ladd(linearvector a, linearvector b) { return linearvector(a.val + b.val); diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 54a99cae9..bcd5e764e 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -87,11 +87,8 @@ extension myfloat3 : IDifferentiable { typedef myfloat3 Differential; - [__unsafeForceInlineEarly] - static Ptr __getDifferentialFor_val(inout Differential dx) - { - return &(dx.val); - } + [DerivativeMember(Differential.val)] + extern vector val; static Differential zero() { @@ -114,11 +111,8 @@ extension myfloat4 : IDifferentiable { typedef myfloat4 Differential; - [__unsafeForceInlineEarly] - static Ptr __getDifferentialFor_val(inout Differential dx) - { - return &(dx.val); - } + [DerivativeMember(Differential.val)] + extern vector val; static Differential zero() { diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang index 61cb96a07..c19a3f6bb 100644 --- a/tests/autodiff/getter-setter-multi.slang +++ b/tests/autodiff/getter-setter-multi.slang @@ -13,22 +13,12 @@ struct B struct A : IDifferentiable { typedef B Differential; - + + [DerivativeMember(B.z)] float3 x; + [DerivativeMember(B.k)] float y[10]; - [__unsafeForceInlineEarly] - static Ptr __getDifferentialFor_x(inout Differential b) - { - return &(b.z); - } - - [__unsafeForceInlineEarly] - static Ptr __getDifferentialFor_y(inout Differential b) - { - return &(b.k); - } - [__unsafeForceInlineEarly] static Differential zero() { diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang index 6b280433b..0e8cac13b 100644 --- a/tests/autodiff/getter-setter.slang +++ b/tests/autodiff/getter-setter.slang @@ -13,15 +13,10 @@ struct A : IDifferentiable { typedef B Differential; + [DerivativeMember(B.z)] float x; float y; - [__unsafeForceInlineEarly] - static Ptr __getDifferentialFor_x(inout Differential b) - { - return &(b.z); - } - [__unsafeForceInlineEarly] static Differential zero() { -- cgit v1.2.3