summaryrefslogtreecommitdiff
path: root/tests/autodiff
diff options
context:
space:
mode:
Diffstat (limited to 'tests/autodiff')
-rw-r--r--tests/autodiff/generic-impl-jvp.slang22
-rw-r--r--tests/autodiff/generic-jvp.slang14
-rw-r--r--tests/autodiff/getter-setter-multi.slang16
-rw-r--r--tests/autodiff/getter-setter.slang7
4 files changed, 12 insertions, 47 deletions
diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang
index 5bf3a25c3..fe4ffc426 100644
--- a/tests/autodiff/generic-impl-jvp.slang
+++ b/tests/autodiff/generic-impl-jvp.slang
@@ -17,15 +17,11 @@ struct dvector
__generic<T : IDFloat, let N : int>
struct myvector : IDifferentiable
{
- T values[N];
typedef dvector<T.Differential, N> Differential;
- [__unsafeForceInlineEarly]
- static Ptr<T.Differential[N]> __getDifferentialFor_values(inout Differential d)
- {
- return &(d.values);
- }
-
+ [DerivativeMember(Differential.values)]
+ T values[N];
+
__init(T c)
{
for (int i = 0; i < N; i++)
@@ -158,19 +154,9 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable
{
typedef lineardvector<N> Differential;
+ [DerivativeMember(Differential.val)]
myvector<Real, N> val;
- [__unsafeForceInlineEarly]
- static Ptr<myvector<Real, N>.Differential> __getDifferentialFor_val(inout Differential dvec)
- {
- return &(dvec.val);
- }
-
- static void __setDifferentialForVal(lineardvector<N> dvec, myvector<Real, N>.Differential v)
- {
- dvec.val = v;
- }
-
static __differentiate_jvp linearvector<N> ladd(linearvector<N> a, linearvector<N> b)
{
return linearvector<N>(a.val + b.val);
diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang
index 54a99cae9..bcd5e764e 100644
--- a/tests/autodiff/generic-jvp.slang
+++ b/tests/autodiff/generic-jvp.slang
@@ -87,11 +87,8 @@ extension myfloat3 : IDifferentiable
{
typedef myfloat3 Differential;
- [__unsafeForceInlineEarly]
- static Ptr<float3> __getDifferentialFor_val(inout Differential dx)
- {
- return &(dx.val);
- }
+ [DerivativeMember(Differential.val)]
+ extern vector<Real, 3> val;
static Differential zero()
{
@@ -114,11 +111,8 @@ extension myfloat4 : IDifferentiable
{
typedef myfloat4 Differential;
- [__unsafeForceInlineEarly]
- static Ptr<float4> __getDifferentialFor_val(inout Differential dx)
- {
- return &(dx.val);
- }
+ [DerivativeMember(Differential.val)]
+ extern vector<Real, 4> val;
static Differential zero()
{
diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang
index 61cb96a07..c19a3f6bb 100644
--- a/tests/autodiff/getter-setter-multi.slang
+++ b/tests/autodiff/getter-setter-multi.slang
@@ -13,23 +13,13 @@ struct B
struct A : IDifferentiable
{
typedef B Differential;
-
+
+ [DerivativeMember(B.z)]
float3 x;
+ [DerivativeMember(B.k)]
float y[10];
[__unsafeForceInlineEarly]
- static Ptr<float3.Differential> __getDifferentialFor_x(inout Differential b)
- {
- return &(b.z);
- }
-
- [__unsafeForceInlineEarly]
- static Ptr<float.Differential[10]> __getDifferentialFor_y(inout Differential b)
- {
- return &(b.k);
- }
-
- [__unsafeForceInlineEarly]
static Differential zero()
{
B b = {0.0};
diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang
index 6b280433b..0e8cac13b 100644
--- a/tests/autodiff/getter-setter.slang
+++ b/tests/autodiff/getter-setter.slang
@@ -13,16 +13,11 @@ struct A : IDifferentiable
{
typedef B Differential;
+ [DerivativeMember(B.z)]
float x;
float y;
[__unsafeForceInlineEarly]
- static Ptr<float.Differential> __getDifferentialFor_x(inout Differential b)
- {
- return &(b.z);
- }
-
- [__unsafeForceInlineEarly]
static Differential zero()
{
B b = {0.0};