summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-20 14:22:00 -0400
committerGitHub <noreply@github.com>2022-10-20 11:22:00 -0700
commit1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch)
treee85158637680f783caaf7f4433a6844398cd8f7b /tests
parent576c8407e60143682cd40c68101c6eae8563ca3d (diff)
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/autodiff/generic-custom-jvp.slang35
-rw-r--r--tests/autodiff/generic-impl-jvp.slang304
-rw-r--r--tests/autodiff/generic-impl-jvp.slang.expected.txt6
-rw-r--r--tests/autodiff/generic-jvp.slang200
-rw-r--r--tests/autodiff/generic-jvp.slang.expected.txt6
-rw-r--r--tests/autodiff/getter-setter-multi.slang83
-rw-r--r--tests/autodiff/getter-setter-multi.slang.expected.txt6
-rw-r--r--tests/autodiff/getter-setter.slang69
-rw-r--r--tests/autodiff/getter-setter.slang.expected.txt6
-rw-r--r--tests/autodiff/imported-custom-jvp.slang25
-rw-r--r--tests/autodiff/overloads-jvp.slang40
-rw-r--r--tests/autodiff/overloads-jvp.slang.expected.txt6
-rw-r--r--tests/autodiff/test-intrinsics-jvp.slang3
-rw-r--r--tests/autodiff/vector-arithmetic-jvp.slang3
14 files changed, 775 insertions, 17 deletions
diff --git a/tests/autodiff/generic-custom-jvp.slang b/tests/autodiff/generic-custom-jvp.slang
new file mode 100644
index 000000000..3f0d85b60
--- /dev/null
+++ b/tests/autodiff/generic-custom-jvp.slang
@@ -0,0 +1,35 @@
+//TEST_IGNORE_FILE:
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typealias IDFloat = IFloat & IDifferentiable;
+
+__generic<T : IDFloat>
+typedef __DifferentialPair<T> dfloat;
+
+import test_intrinsics;
+
+dpfloat my_pow_jvp(dpfloat x, dpfloat n)
+{
+ return dpfloat(
+ pow(x.p(), n.p()),
+ x.d() * n.p() * pow(x.p(), n.p()-1) + n.d() * pow(x.p(), n.p()) * log(x.p()));
+}
+
+[__custom_jvp(my_pow_jvp)]
+float _pow(float, float);
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(5.0, 1.0);
+ dpfloat dpn = dpfloat(2, 0.0);
+
+ outputBuffer[0] = __jvp(_pow)(dpa, dpn).d(); // Expect: 10.0
+ outputBuffer[1] = __jvp(_pow)(
+ dpfloat(dpa.p(), 0.0),
+ dpfloat(dpn.p(), 1.0)).d(); // Expect: 40.23595
+ }
+}
diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang
new file mode 100644
index 000000000..5bf3a25c3
--- /dev/null
+++ b/tests/autodiff/generic-impl-jvp.slang
@@ -0,0 +1,304 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef float Real;
+
+typealias IDFloat = IFloat & IDifferentiable;
+
+__generic<T, let N : int>
+struct dvector
+{
+ T values[N];
+};
+
+__generic<T : IDFloat, let N : int>
+struct myvector : IDifferentiable
+{
+ T values[N];
+ typedef dvector<T.Differential, N> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Ptr<T.Differential[N]> __getDifferentialFor_values(inout Differential d)
+ {
+ return &(d.values);
+ }
+
+ __init(T c)
+ {
+ for (int i = 0; i < N; i++)
+ {
+ values[i] = c;
+ }
+ }
+
+ static Differential dadd(Differential a, Differential b)
+ {
+ Differential output;
+
+ for (int i = 0; i < N; i++)
+ {
+ output.values[i] = T.dadd(a.values[i], b.values[i]);
+ }
+
+ return output;
+ }
+
+
+ static Differential dmul(This a, Differential b)
+ {
+ Differential output;
+
+ for (int i = 0; i < N; i++)
+ {
+ output.values[i] = T.dmul(a.values[i], b.values[i]);
+ }
+
+ return output;
+ }
+
+ static Differential zero()
+ {
+ Differential output;
+
+ for (int i = 0; i < N; i++)
+ {
+ output.values[i] = T.zero();
+ }
+
+ return output;
+ }
+};
+
+__generic<T : IDFloat, let N : int>
+__differentiate_jvp myvector<T, N> operator +(myvector<T, N> a, myvector<T, N> b)
+{
+ myvector<T, N> output;
+ for (int i = 0; i < N; i++)
+ {
+ output.values[i] = a.values[i] + b.values[i];
+ }
+ return output;
+}
+
+__generic<T : IDFloat, let N : int>
+ __differentiate_jvp myvector<T, N> operator *(myvector<T, N> a, myvector<T, N> b)
+{
+ myvector<T, N> output;
+ for (int i = 0; i < N; i++)
+ {
+ output.values[i] = a.values[i] * b.values[i];
+ }
+ return output;
+}
+
+__generic<T : IDFloat, let N : int>
+ __differentiate_jvp myvector<T, N> operator *(T a, myvector<T, N> b)
+{
+ myvector<T, N> output;
+ for (int i = 0; i < N; i++)
+ {
+ output.values[i] = a * b.values[i];
+ }
+ return output;
+}
+
+__generic<T : IDFloat, let N : int>
+[__custom_jvp(dot_jvp)]
+T dot(myvector<T, N> a, myvector<T, N> b)
+{
+ T curr = (T)0.0;
+ for (int i = 0; i < N; i++)
+ {
+ curr = curr + (a.values[i] * b.values[i]);
+ }
+
+ return curr;
+}
+
+__generic<T : IDFloat, let N : int>
+typedef __DifferentialPair<myvector<T, N>> dpvector;
+
+__generic<T : IDFloat, let N : int>
+__DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b)
+{
+ T.Differential curr_d = (T.zero());
+ T curr_p = (T)0.0;
+ for (int i = 0; i < N; i++)
+ {
+ curr_p = curr_p + (a.p().values[i] * b.p().values[i]);
+ curr_d = T.dadd(
+ curr_d,
+ T.dadd(
+ T.dmul(a.p().values[i], b.d().values[i]),
+ T.dmul(b.p().values[i], a.d().values[i])));
+ }
+
+ return __DifferentialPair<T>(curr_p, curr_d);
+}
+
+__generic<let N : int>
+struct lineardvector
+{
+ myvector<Real, N>.Differential val;
+
+ __init(vector<Real.Differential, N> a)
+ {
+ for (int i = 0; i < N; i++)
+ {
+ val.values[i] = a[i];
+ }
+ }
+};
+
+__generic<let N : int>
+struct linearvector : MyLinearArithmeticType, IDifferentiable
+{
+ typedef lineardvector<N> Differential;
+
+ myvector<Real, N> val;
+
+ [__unsafeForceInlineEarly]
+ static Ptr<myvector<Real, N>.Differential> __getDifferentialFor_val(inout Differential dvec)
+ {
+ return &(dvec.val);
+ }
+
+ static void __setDifferentialForVal(lineardvector<N> dvec, myvector<Real, N>.Differential v)
+ {
+ dvec.val = v;
+ }
+
+ static __differentiate_jvp linearvector<N> ladd(linearvector<N> a, linearvector<N> b)
+ {
+ return linearvector<N>(a.val + b.val);
+ }
+
+ static __differentiate_jvp linearvector<N> lmul(linearvector<N> a, linearvector<N> b)
+ {
+ return linearvector<N>(a.val * b.val);
+ }
+
+ static __differentiate_jvp linearvector<N> lscale(float a, linearvector<N> b)
+ {
+ return linearvector<N>(a * b.val);
+ }
+
+ static __differentiate_jvp float ldot(linearvector<N> a, linearvector<N> b)
+ {
+ return dot(a.val, b.val);
+ }
+
+ static Differential zero()
+ {
+ lineardvector<N> dout;
+ dout.val = myvector<Real, N>.zero();
+ return dout;
+ }
+
+ static Differential dadd(Differential a, Differential b)
+ {
+ return { myvector<Real, N>.dadd(a.val, b.val) };
+ }
+
+ static Differential dmul(This a, Differential b)
+ {
+ return { myvector<Real, N>.dmul(a.val, b.val) };
+ }
+
+ __differentiate_jvp __init(vector<Real, N> a)
+ {
+ for (int i = 0; i < N; i++)
+ {
+ val.values[i] = a[i];
+ }
+ }
+
+ __differentiate_jvp __init(myvector<Real, N> a)
+ {
+ val = a;
+ }
+};
+
+typedef linearvector<3> myfloat3;
+typedef linearvector<4> myfloat4;
+
+typedef lineardvector<3> mydfloat3;
+typedef lineardvector<4> mydfloat4;
+
+typedef __DifferentialPair<Real> dpfloat;
+
+interface MyLinearArithmeticType
+{
+ static This ladd(This a, This b);
+ static This lmul(This a, This b);
+ static This lscale(Real a, This b);
+ static Real ldot(This a, This b);
+};
+
+typedef __DifferentialPair<myfloat4> dpfloat4;
+typedef __DifferentialPair<myfloat3> dpfloat3;
+
+extension float : MyLinearArithmeticType
+{
+ static __differentiate_jvp float ladd(float a, float b)
+ {
+ return a + b;
+ }
+
+ static __differentiate_jvp float lmul(float a, float b)
+ {
+ return a * b;
+ }
+
+ static __differentiate_jvp float lscale(float a, float b)
+ {
+ return a * b;
+ }
+
+ static __differentiate_jvp float ldot(float a, float b)
+ {
+ return a * b;
+ }
+};
+
+typealias MyLinearArithmeticDifferentiableType = IDifferentiable & MyLinearArithmeticType;
+
+__generic<T : MyLinearArithmeticDifferentiableType>
+__differentiate_jvp T operator +(T a, T b)
+{
+ return T.ladd(a, b);
+}
+
+__generic<T : MyLinearArithmeticDifferentiableType>
+__differentiate_jvp T operator *(T a, T b)
+{
+ return T.lmul(a, b);
+}
+
+__generic<G : MyLinearArithmeticDifferentiableType>
+__differentiate_jvp G f(G x)
+{
+ G a = x + x;
+ G b = x * x;
+
+ return a * a + G.lscale((Real)3.0, x);
+}
+
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(2.0, 1.0);
+ dpfloat4 dpf4 = dpfloat4(myfloat4(float4(1.5, 2.0, 0.5, 1.0)), mydfloat4(float4(0.5, 0.8, 1.6, 2.5)));
+ dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), mydfloat3(float3(0.5, 1.5, 2.5)));
+
+ outputBuffer[0] = f(dpa.p()); // Expect: 22.0
+ outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5
+ outputBuffer[2] = __jvp(f)(dpf4).d().val.values[3]; // Expect: 27.5
+ outputBuffer[3] = __jvp(f)(dpf3).d().val.values[1]; // Expect: 40.5
+ }
+}
diff --git a/tests/autodiff/generic-impl-jvp.slang.expected.txt b/tests/autodiff/generic-impl-jvp.slang.expected.txt
new file mode 100644
index 000000000..ceeaf120e
--- /dev/null
+++ b/tests/autodiff/generic-impl-jvp.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+22.000000
+9.500000
+27.500000
+40.500000
+0.000000
diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang
index 48993c21c..54a99cae9 100644
--- a/tests/autodiff/generic-jvp.slang
+++ b/tests/autodiff/generic-jvp.slang
@@ -1,30 +1,202 @@
-//TEST_IGNORE_FILE:(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST_IGNORE_FILE:(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
-typedef __DifferentialPair<float> dpfloat;
-typedef __DifferentialPair<double> dpdouble;
-typedef __DifferentialPair<float3> dpfloat3;
+typedef float Real;
-__generic<T:__BuiltinArithmeticType>
-__differentiate_jvp T g(T x)
+__generic<let N : int>
+struct myvector
{
- return x + x;
+ vector<Real, N> val;
}
+extension myvector<3> : MyLinearArithmeticType
+{
+ static __differentiate_jvp myvector<3> ladd(myvector<3> a, myvector<3> b)
+ {
+ return myvector<3>(a.val + b.val);
+ }
+
+ static __differentiate_jvp myvector<3> lmul(myvector<3> a, myvector<3> b)
+ {
+ return myvector<3>(a.val * b.val);
+ }
+
+ static __differentiate_jvp myvector<3> lscale(float a, myvector<3> b)
+ {
+ return myvector<3>(a * b.val);
+ }
+
+ static __differentiate_jvp float ldot(myvector<3> a, myvector<3> b)
+ {
+ return dot(a.val, b.val);
+ }
+
+ __differentiate_jvp __init(vector<Real, 3> a)
+ {
+ val = a;
+ }
+};
+
+
+extension myvector<4> : MyLinearArithmeticType
+{
+ static __differentiate_jvp myvector<4> ladd(myvector<4> a, myvector<4> b)
+ {
+ return myvector<4>(a.val + b.val);
+ }
+
+ static __differentiate_jvp myvector<4> lmul(myvector<4> a, myvector<4> b)
+ {
+ return myvector<4>(a.val * b.val);
+ }
+
+ static __differentiate_jvp myvector<4> lscale(float a, myvector<4> b)
+ {
+ return myvector<4>(a * b.val);
+ }
+
+ static __differentiate_jvp float ldot(myvector<4> a, myvector<4> b)
+ {
+ return dot(a.val, b.val);
+ }
+
+ __differentiate_jvp __init(vector<Real, 4> a)
+ {
+ val = a;
+ }
+
+};
+
+typedef myvector<3> myfloat3;
+typedef myvector<4> myfloat4;
+
+typedef __DifferentialPair<Real> dpfloat;
+
+interface MyLinearArithmeticType
+{
+ static This ladd(This a, This b);
+ static This lmul(This a, This b);
+ static This lscale(Real a, This b);
+ static Real ldot(This a, This b);
+};
+
+extension myfloat3 : IDifferentiable
+{
+ typedef myfloat3 Differential;
+
+ [__unsafeForceInlineEarly]
+ static Ptr<float3> __getDifferentialFor_val(inout Differential dx)
+ {
+ return &(dx.val);
+ }
+
+ static Differential zero()
+ {
+ return myfloat3(0);
+ }
+
+ static __differentiate_jvp Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ static __differentiate_jvp Differential dmul(Differential a, Differential b)
+ {
+ return a * b;
+ }
+
+};
+
+extension myfloat4 : IDifferentiable
+{
+ typedef myfloat4 Differential;
+
+ [__unsafeForceInlineEarly]
+ static Ptr<float4> __getDifferentialFor_val(inout Differential dx)
+ {
+ return &(dx.val);
+ }
+
+ static Differential zero()
+ {
+ return myfloat4(0);
+ }
+
+ static __differentiate_jvp Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ static __differentiate_jvp Differential dmul(Differential a, Differential b)
+ {
+ return a * b;
+ }
+};
+
+typedef __DifferentialPair<myfloat4> dpfloat4;
+typedef __DifferentialPair<myfloat3> dpfloat3;
+
+extension float : MyLinearArithmeticType
+{
+ static __differentiate_jvp float ladd(float a, float b)
+ {
+ return a + b;
+ }
+
+ static __differentiate_jvp float lmul(float a, float b)
+ {
+ return a * b;
+ }
+
+ static __differentiate_jvp float lscale(float a, float b)
+ {
+ return a * b;
+ }
+
+ static __differentiate_jvp float ldot(float a, float b)
+ {
+ return a * b;
+ }
+};
+
+typealias MyLinearArithmeticDifferentiableType = IDifferentiable & MyLinearArithmeticType;
+
+__generic<T : MyLinearArithmeticDifferentiableType>
+__differentiate_jvp T operator +(T a, T b)
+{
+ return T.ladd(a, b);
+}
+
+__generic<T : MyLinearArithmeticDifferentiableType>
+__differentiate_jvp T operator *(T a, T b)
+{
+ return T.lmul(a, b);
+}
+
+__generic<G : MyLinearArithmeticDifferentiableType>
+__differentiate_jvp G f(G x)
+{
+ G a = x + x;
+ G b = x * x;
+
+ return a * a + G.lscale((Real)3.0, x);
+}
+
+
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
{
dpfloat dpa = dpfloat(2.0, 1.0);
- dpdouble dpb = dpdouble(1.5, 2.0);
- dpfloat3 dpf3 = dpfloat3(float3(1.0, 3.0, 5.0), float3(0.5, 1.5, 2.5));
+ dpfloat4 dpf4 = dpfloat4(myfloat4(float4(1.5, 2.0, 0.5, 1.0)), myfloat4(float4(0.5, 0.8, 1.6, 2.5)));
+ dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), myfloat3(float3(0.5, 1.5, 2.5)));
- outputBuffer[0] = f(dpa.p()); // Expect: 1
- outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.0)).d(); // Expect: 0
- outputBuffer[2] = (float)__jvp(f)(dpb).d(); // Expect: 2
- outputBuffer[3] = __jvp(f)(dpf3).d().y; // Expect: 1.5
+ outputBuffer[0] = f(dpa.p()); // Expect: 22.0
+ outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5
+ outputBuffer[2] = __jvp(f)(dpf4).d().val.w; // Expect: 27.5
+ outputBuffer[3] = __jvp(f)(dpf3).d().val.y; // Expect: 40.5
}
}
diff --git a/tests/autodiff/generic-jvp.slang.expected.txt b/tests/autodiff/generic-jvp.slang.expected.txt
new file mode 100644
index 000000000..ceeaf120e
--- /dev/null
+++ b/tests/autodiff/generic-jvp.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+22.000000
+9.500000
+27.500000
+40.500000
+0.000000
diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang
new file mode 100644
index 000000000..61cb96a07
--- /dev/null
+++ b/tests/autodiff/getter-setter-multi.slang
@@ -0,0 +1,83 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct B
+{
+ float3 z;
+ float.Differential k[10];
+};
+
+struct A : IDifferentiable
+{
+ typedef B Differential;
+
+ float3 x;
+ float y[10];
+
+ [__unsafeForceInlineEarly]
+ static Ptr<float3.Differential> __getDifferentialFor_x(inout Differential b)
+ {
+ return &(b.z);
+ }
+
+ [__unsafeForceInlineEarly]
+ static Ptr<float.Differential[10]> __getDifferentialFor_y(inout Differential b)
+ {
+ return &(b.k);
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential zero()
+ {
+ B b = {0.0};
+ return b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ B o = {a.z + b.z};
+ return o;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ B o = {a.x * b.z};
+ return o;
+ }
+};
+
+typedef __DifferentialPair<A> dpA;
+
+__differentiate_jvp A f(A a)
+{
+ A aout;
+
+ aout.y[5] = (2 * a.x).y;
+ aout.y[2] = (3 * a.y[4]);
+ aout.x = float3(5 * a.x.z, 3 * a.x.y, 0.5 * a.x.x);
+
+ return aout;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ float arr[10] = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 };
+ A a = {float3(1.0, 2.0, 3.0), arr};
+
+ float d_arr[10] = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0 };
+ B b = {float3(1.0, 0.5, 0.3), d_arr};
+
+ dpA dpa = dpA(a, b);
+
+ outputBuffer[0] = __jvp(f)(dpa).d().z.z; // Expect: 0.5
+ outputBuffer[1] = __jvp(f)(dpa).d().k[5]; // Expect: 1
+ outputBuffer[2] = __jvp(f)(dpa).d().k[2]; // Expect: 1.5
+ }
+}
diff --git a/tests/autodiff/getter-setter-multi.slang.expected.txt b/tests/autodiff/getter-setter-multi.slang.expected.txt
new file mode 100644
index 000000000..ece9872b0
--- /dev/null
+++ b/tests/autodiff/getter-setter-multi.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+0.500000
+1.000000
+1.500000
+0.000000
+0.000000
diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang
new file mode 100644
index 000000000..6b280433b
--- /dev/null
+++ b/tests/autodiff/getter-setter.slang
@@ -0,0 +1,69 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct B
+{
+ float z;
+};
+
+struct A : IDifferentiable
+{
+ typedef B Differential;
+
+ float x;
+ float y;
+
+ [__unsafeForceInlineEarly]
+ static Ptr<float.Differential> __getDifferentialFor_x(inout Differential b)
+ {
+ return &(b.z);
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential zero()
+ {
+ B b = {0.0};
+ return b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ B o = {a.z + b.z};
+ return o;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ B o = {a.x * b.z};
+ return o;
+ }
+};
+
+typedef __DifferentialPair<A> dpA;
+
+__differentiate_jvp A f(A a)
+{
+ A aout;
+ aout.y = 2 * a.x;
+ aout.x = 5 * a.x;
+
+ return aout;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ A a = {1.0, 2.0};
+ B b = {0.2};
+
+ dpA dpa = dpA(a, b);
+
+ outputBuffer[0] = __jvp(f)(dpa).d().z; // Expect: 1
+ }
+}
diff --git a/tests/autodiff/getter-setter.slang.expected.txt b/tests/autodiff/getter-setter.slang.expected.txt
new file mode 100644
index 000000000..ca54c9afe
--- /dev/null
+++ b/tests/autodiff/getter-setter.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+1.000000
+0.000000
+0.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/imported-custom-jvp.slang b/tests/autodiff/imported-custom-jvp.slang
new file mode 100644
index 000000000..ee8bdf51d
--- /dev/null
+++ b/tests/autodiff/imported-custom-jvp.slang
@@ -0,0 +1,25 @@
+//TEST_IGNORE_FILE:
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+import test_intrinsics_jvp;
+
+typedef __DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+__differentiate_jvp float f(float x)
+{
+ return pow_(x, 2.0);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(2.0, 1.0);
+ dpfloat dpb = dpfloat(1.5, 1.0);
+
+ outputBuffer[0] = __jvp(f)(dpa).d(); // Expect: 2
+ }
+}
diff --git a/tests/autodiff/overloads-jvp.slang b/tests/autodiff/overloads-jvp.slang
new file mode 100644
index 000000000..26b5c0076
--- /dev/null
+++ b/tests/autodiff/overloads-jvp.slang
@@ -0,0 +1,40 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef __DifferentialPair<float> dpfloat;
+typedef __DifferentialPair<float3> dpfloat3;
+
+__differentiate_jvp float f(float a)
+{
+ return a * a + a;
+}
+
+__differentiate_jvp float f(float3 a)
+{
+ return a.x * a.y + a.z;
+}
+
+__differentiate_jvp float g(float a)
+{
+ // df((2.0, 4.0, 6.0), (1.0, 2.0, 3.0))
+ // 2.0 * 2.0 + 4.0 * 1.0 + 3.0 = 11.0
+ return f(float3(a, 2*a, 3*a));
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(2.0, 1.0);
+ dpfloat3 dpf3 = dpfloat3(float3(1.0, 3.0, 5.0), float3(0.5, 1.5, 2.5));
+
+ outputBuffer[0] = f(dpa.p()); // Expect: 6
+ outputBuffer[1] = f(dpf3.p()); // Expect: 8
+ outputBuffer[2] = __jvp(f)(dpf3).d(); // Expect: 5.5
+ outputBuffer[3] = __jvp(f)(dpa).d(); // Expect: 5
+ outputBuffer[4] = __jvp(g)(dpa).d(); // Expect: 11.0
+ }
+}
diff --git a/tests/autodiff/overloads-jvp.slang.expected.txt b/tests/autodiff/overloads-jvp.slang.expected.txt
new file mode 100644
index 000000000..999777e1e
--- /dev/null
+++ b/tests/autodiff/overloads-jvp.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+6.0
+8.0
+5.5
+5.0
+11.0 \ No newline at end of file
diff --git a/tests/autodiff/test-intrinsics-jvp.slang b/tests/autodiff/test-intrinsics-jvp.slang
index 333c89189..cb4c5c6b4 100644
--- a/tests/autodiff/test-intrinsics-jvp.slang
+++ b/tests/autodiff/test-intrinsics-jvp.slang
@@ -14,4 +14,5 @@ float max_(float x, float y);
float max_jvp(float x, float y, float dx, float dy)
{
return (x > y) ? dx : dy;
-} \ No newline at end of file
+}
+
diff --git a/tests/autodiff/vector-arithmetic-jvp.slang b/tests/autodiff/vector-arithmetic-jvp.slang
index 393cc18ec..e05d94733 100644
--- a/tests/autodiff/vector-arithmetic-jvp.slang
+++ b/tests/autodiff/vector-arithmetic-jvp.slang
@@ -37,11 +37,10 @@ __differentiate_jvp float4 j(float4 x, float4 y)
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
- {
+ {
float3 a = float3(2.0, 2.0, 2.0);
float3 b = float3(1.5, 1.5, 1.5);
float3 da = float3(1.0, 1.0, 1.0);
- //dpfloat3 dpa = dpfloat3(a, da);
float2 a2 = float2(2.0, 1.0);
float2 b2 = float2(1.5, -2.0);