diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-10-20 14:22:00 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-20 11:22:00 -0700 |
| commit | 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch) | |
| tree | e85158637680f783caaf7f4433a6844398cd8f7b /tests | |
| parent | 576c8407e60143682cd40c68101c6eae8563ca3d (diff) | |
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions.
* Changed a few asserts to release asserts to avoid unreferenced variable errors
* Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl
* Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types.
* Moved the auto-diff passes to after the specialization step, added a more complex generics test
* Added a generics stress test and fixed AST-side logic. IR side needs some more work
* Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions
* Changed differential getters to use pointer types, added getter type checking
* Fixed some bugs related to diff type registration and differential getters
* Removed some superfluous code
* Removed some more unused code.
* Fixed an issue with witness substitution
* Minor fix
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/generic-custom-jvp.slang | 35 | ||||
| -rw-r--r-- | tests/autodiff/generic-impl-jvp.slang | 304 | ||||
| -rw-r--r-- | tests/autodiff/generic-impl-jvp.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/generic-jvp.slang | 200 | ||||
| -rw-r--r-- | tests/autodiff/generic-jvp.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/getter-setter-multi.slang | 83 | ||||
| -rw-r--r-- | tests/autodiff/getter-setter-multi.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/getter-setter.slang | 69 | ||||
| -rw-r--r-- | tests/autodiff/getter-setter.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/imported-custom-jvp.slang | 25 | ||||
| -rw-r--r-- | tests/autodiff/overloads-jvp.slang | 40 | ||||
| -rw-r--r-- | tests/autodiff/overloads-jvp.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/test-intrinsics-jvp.slang | 3 | ||||
| -rw-r--r-- | tests/autodiff/vector-arithmetic-jvp.slang | 3 |
14 files changed, 775 insertions, 17 deletions
diff --git a/tests/autodiff/generic-custom-jvp.slang b/tests/autodiff/generic-custom-jvp.slang new file mode 100644 index 000000000..3f0d85b60 --- /dev/null +++ b/tests/autodiff/generic-custom-jvp.slang @@ -0,0 +1,35 @@ +//TEST_IGNORE_FILE: + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typealias IDFloat = IFloat & IDifferentiable; + +__generic<T : IDFloat> +typedef __DifferentialPair<T> dfloat; + +import test_intrinsics; + +dpfloat my_pow_jvp(dpfloat x, dpfloat n) +{ + return dpfloat( + pow(x.p(), n.p()), + x.d() * n.p() * pow(x.p(), n.p()-1) + n.d() * pow(x.p(), n.p()) * log(x.p())); +} + +[__custom_jvp(my_pow_jvp)] +float _pow(float, float); + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(5.0, 1.0); + dpfloat dpn = dpfloat(2, 0.0); + + outputBuffer[0] = __jvp(_pow)(dpa, dpn).d(); // Expect: 10.0 + outputBuffer[1] = __jvp(_pow)( + dpfloat(dpa.p(), 0.0), + dpfloat(dpn.p(), 1.0)).d(); // Expect: 40.23595 + } +} diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang new file mode 100644 index 000000000..5bf3a25c3 --- /dev/null +++ b/tests/autodiff/generic-impl-jvp.slang @@ -0,0 +1,304 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef float Real; + +typealias IDFloat = IFloat & IDifferentiable; + +__generic<T, let N : int> +struct dvector +{ + T values[N]; +}; + +__generic<T : IDFloat, let N : int> +struct myvector : IDifferentiable +{ + T values[N]; + typedef dvector<T.Differential, N> Differential; + + [__unsafeForceInlineEarly] + static Ptr<T.Differential[N]> __getDifferentialFor_values(inout Differential d) + { + return &(d.values); + } + + __init(T c) + { + for (int i = 0; i < N; i++) + { + values[i] = c; + } + } + + static Differential dadd(Differential a, Differential b) + { + Differential output; + + for (int i = 0; i < N; i++) + { + output.values[i] = T.dadd(a.values[i], b.values[i]); + } + + return output; + } + + + static Differential dmul(This a, Differential b) + { + Differential output; + + for (int i = 0; i < N; i++) + { + output.values[i] = T.dmul(a.values[i], b.values[i]); + } + + return output; + } + + static Differential zero() + { + Differential output; + + for (int i = 0; i < N; i++) + { + output.values[i] = T.zero(); + } + + return output; + } +}; + +__generic<T : IDFloat, let N : int> +__differentiate_jvp myvector<T, N> operator +(myvector<T, N> a, myvector<T, N> b) +{ + myvector<T, N> output; + for (int i = 0; i < N; i++) + { + output.values[i] = a.values[i] + b.values[i]; + } + return output; +} + +__generic<T : IDFloat, let N : int> + __differentiate_jvp myvector<T, N> operator *(myvector<T, N> a, myvector<T, N> b) +{ + myvector<T, N> output; + for (int i = 0; i < N; i++) + { + output.values[i] = a.values[i] * b.values[i]; + } + return output; +} + +__generic<T : IDFloat, let N : int> + __differentiate_jvp myvector<T, N> operator *(T a, myvector<T, N> b) +{ + myvector<T, N> output; + for (int i = 0; i < N; i++) + { + output.values[i] = a * b.values[i]; + } + return output; +} + +__generic<T : IDFloat, let N : int> +[__custom_jvp(dot_jvp)] +T dot(myvector<T, N> a, myvector<T, N> b) +{ + T curr = (T)0.0; + for (int i = 0; i < N; i++) + { + curr = curr + (a.values[i] * b.values[i]); + } + + return curr; +} + +__generic<T : IDFloat, let N : int> +typedef __DifferentialPair<myvector<T, N>> dpvector; + +__generic<T : IDFloat, let N : int> +__DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b) +{ + T.Differential curr_d = (T.zero()); + T curr_p = (T)0.0; + for (int i = 0; i < N; i++) + { + curr_p = curr_p + (a.p().values[i] * b.p().values[i]); + curr_d = T.dadd( + curr_d, + T.dadd( + T.dmul(a.p().values[i], b.d().values[i]), + T.dmul(b.p().values[i], a.d().values[i]))); + } + + return __DifferentialPair<T>(curr_p, curr_d); +} + +__generic<let N : int> +struct lineardvector +{ + myvector<Real, N>.Differential val; + + __init(vector<Real.Differential, N> a) + { + for (int i = 0; i < N; i++) + { + val.values[i] = a[i]; + } + } +}; + +__generic<let N : int> +struct linearvector : MyLinearArithmeticType, IDifferentiable +{ + typedef lineardvector<N> Differential; + + myvector<Real, N> val; + + [__unsafeForceInlineEarly] + static Ptr<myvector<Real, N>.Differential> __getDifferentialFor_val(inout Differential dvec) + { + return &(dvec.val); + } + + static void __setDifferentialForVal(lineardvector<N> dvec, myvector<Real, N>.Differential v) + { + dvec.val = v; + } + + static __differentiate_jvp linearvector<N> ladd(linearvector<N> a, linearvector<N> b) + { + return linearvector<N>(a.val + b.val); + } + + static __differentiate_jvp linearvector<N> lmul(linearvector<N> a, linearvector<N> b) + { + return linearvector<N>(a.val * b.val); + } + + static __differentiate_jvp linearvector<N> lscale(float a, linearvector<N> b) + { + return linearvector<N>(a * b.val); + } + + static __differentiate_jvp float ldot(linearvector<N> a, linearvector<N> b) + { + return dot(a.val, b.val); + } + + static Differential zero() + { + lineardvector<N> dout; + dout.val = myvector<Real, N>.zero(); + return dout; + } + + static Differential dadd(Differential a, Differential b) + { + return { myvector<Real, N>.dadd(a.val, b.val) }; + } + + static Differential dmul(This a, Differential b) + { + return { myvector<Real, N>.dmul(a.val, b.val) }; + } + + __differentiate_jvp __init(vector<Real, N> a) + { + for (int i = 0; i < N; i++) + { + val.values[i] = a[i]; + } + } + + __differentiate_jvp __init(myvector<Real, N> a) + { + val = a; + } +}; + +typedef linearvector<3> myfloat3; +typedef linearvector<4> myfloat4; + +typedef lineardvector<3> mydfloat3; +typedef lineardvector<4> mydfloat4; + +typedef __DifferentialPair<Real> dpfloat; + +interface MyLinearArithmeticType +{ + static This ladd(This a, This b); + static This lmul(This a, This b); + static This lscale(Real a, This b); + static Real ldot(This a, This b); +}; + +typedef __DifferentialPair<myfloat4> dpfloat4; +typedef __DifferentialPair<myfloat3> dpfloat3; + +extension float : MyLinearArithmeticType +{ + static __differentiate_jvp float ladd(float a, float b) + { + return a + b; + } + + static __differentiate_jvp float lmul(float a, float b) + { + return a * b; + } + + static __differentiate_jvp float lscale(float a, float b) + { + return a * b; + } + + static __differentiate_jvp float ldot(float a, float b) + { + return a * b; + } +}; + +typealias MyLinearArithmeticDifferentiableType = IDifferentiable & MyLinearArithmeticType; + +__generic<T : MyLinearArithmeticDifferentiableType> +__differentiate_jvp T operator +(T a, T b) +{ + return T.ladd(a, b); +} + +__generic<T : MyLinearArithmeticDifferentiableType> +__differentiate_jvp T operator *(T a, T b) +{ + return T.lmul(a, b); +} + +__generic<G : MyLinearArithmeticDifferentiableType> +__differentiate_jvp G f(G x) +{ + G a = x + x; + G b = x * x; + + return a * a + G.lscale((Real)3.0, x); +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 1.0); + dpfloat4 dpf4 = dpfloat4(myfloat4(float4(1.5, 2.0, 0.5, 1.0)), mydfloat4(float4(0.5, 0.8, 1.6, 2.5))); + dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), mydfloat3(float3(0.5, 1.5, 2.5))); + + outputBuffer[0] = f(dpa.p()); // Expect: 22.0 + outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5 + outputBuffer[2] = __jvp(f)(dpf4).d().val.values[3]; // Expect: 27.5 + outputBuffer[3] = __jvp(f)(dpf3).d().val.values[1]; // Expect: 40.5 + } +} diff --git a/tests/autodiff/generic-impl-jvp.slang.expected.txt b/tests/autodiff/generic-impl-jvp.slang.expected.txt new file mode 100644 index 000000000..ceeaf120e --- /dev/null +++ b/tests/autodiff/generic-impl-jvp.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +22.000000 +9.500000 +27.500000 +40.500000 +0.000000 diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 48993c21c..54a99cae9 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -1,30 +1,202 @@ -//TEST_IGNORE_FILE:(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST_IGNORE_FILE:(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -typedef __DifferentialPair<float> dpfloat; -typedef __DifferentialPair<double> dpdouble; -typedef __DifferentialPair<float3> dpfloat3; +typedef float Real; -__generic<T:__BuiltinArithmeticType> -__differentiate_jvp T g(T x) +__generic<let N : int> +struct myvector { - return x + x; + vector<Real, N> val; } +extension myvector<3> : MyLinearArithmeticType +{ + static __differentiate_jvp myvector<3> ladd(myvector<3> a, myvector<3> b) + { + return myvector<3>(a.val + b.val); + } + + static __differentiate_jvp myvector<3> lmul(myvector<3> a, myvector<3> b) + { + return myvector<3>(a.val * b.val); + } + + static __differentiate_jvp myvector<3> lscale(float a, myvector<3> b) + { + return myvector<3>(a * b.val); + } + + static __differentiate_jvp float ldot(myvector<3> a, myvector<3> b) + { + return dot(a.val, b.val); + } + + __differentiate_jvp __init(vector<Real, 3> a) + { + val = a; + } +}; + + +extension myvector<4> : MyLinearArithmeticType +{ + static __differentiate_jvp myvector<4> ladd(myvector<4> a, myvector<4> b) + { + return myvector<4>(a.val + b.val); + } + + static __differentiate_jvp myvector<4> lmul(myvector<4> a, myvector<4> b) + { + return myvector<4>(a.val * b.val); + } + + static __differentiate_jvp myvector<4> lscale(float a, myvector<4> b) + { + return myvector<4>(a * b.val); + } + + static __differentiate_jvp float ldot(myvector<4> a, myvector<4> b) + { + return dot(a.val, b.val); + } + + __differentiate_jvp __init(vector<Real, 4> a) + { + val = a; + } + +}; + +typedef myvector<3> myfloat3; +typedef myvector<4> myfloat4; + +typedef __DifferentialPair<Real> dpfloat; + +interface MyLinearArithmeticType +{ + static This ladd(This a, This b); + static This lmul(This a, This b); + static This lscale(Real a, This b); + static Real ldot(This a, This b); +}; + +extension myfloat3 : IDifferentiable +{ + typedef myfloat3 Differential; + + [__unsafeForceInlineEarly] + static Ptr<float3> __getDifferentialFor_val(inout Differential dx) + { + return &(dx.val); + } + + static Differential zero() + { + return myfloat3(0); + } + + static __differentiate_jvp Differential dadd(Differential a, Differential b) + { + return a + b; + } + + static __differentiate_jvp Differential dmul(Differential a, Differential b) + { + return a * b; + } + +}; + +extension myfloat4 : IDifferentiable +{ + typedef myfloat4 Differential; + + [__unsafeForceInlineEarly] + static Ptr<float4> __getDifferentialFor_val(inout Differential dx) + { + return &(dx.val); + } + + static Differential zero() + { + return myfloat4(0); + } + + static __differentiate_jvp Differential dadd(Differential a, Differential b) + { + return a + b; + } + + static __differentiate_jvp Differential dmul(Differential a, Differential b) + { + return a * b; + } +}; + +typedef __DifferentialPair<myfloat4> dpfloat4; +typedef __DifferentialPair<myfloat3> dpfloat3; + +extension float : MyLinearArithmeticType +{ + static __differentiate_jvp float ladd(float a, float b) + { + return a + b; + } + + static __differentiate_jvp float lmul(float a, float b) + { + return a * b; + } + + static __differentiate_jvp float lscale(float a, float b) + { + return a * b; + } + + static __differentiate_jvp float ldot(float a, float b) + { + return a * b; + } +}; + +typealias MyLinearArithmeticDifferentiableType = IDifferentiable & MyLinearArithmeticType; + +__generic<T : MyLinearArithmeticDifferentiableType> +__differentiate_jvp T operator +(T a, T b) +{ + return T.ladd(a, b); +} + +__generic<T : MyLinearArithmeticDifferentiableType> +__differentiate_jvp T operator *(T a, T b) +{ + return T.lmul(a, b); +} + +__generic<G : MyLinearArithmeticDifferentiableType> +__differentiate_jvp G f(G x) +{ + G a = x + x; + G b = x * x; + + return a * a + G.lscale((Real)3.0, x); +} + + [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { { dpfloat dpa = dpfloat(2.0, 1.0); - dpdouble dpb = dpdouble(1.5, 2.0); - dpfloat3 dpf3 = dpfloat3(float3(1.0, 3.0, 5.0), float3(0.5, 1.5, 2.5)); + dpfloat4 dpf4 = dpfloat4(myfloat4(float4(1.5, 2.0, 0.5, 1.0)), myfloat4(float4(0.5, 0.8, 1.6, 2.5))); + dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), myfloat3(float3(0.5, 1.5, 2.5))); - outputBuffer[0] = f(dpa.p()); // Expect: 1 - outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.0)).d(); // Expect: 0 - outputBuffer[2] = (float)__jvp(f)(dpb).d(); // Expect: 2 - outputBuffer[3] = __jvp(f)(dpf3).d().y; // Expect: 1.5 + outputBuffer[0] = f(dpa.p()); // Expect: 22.0 + outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5 + outputBuffer[2] = __jvp(f)(dpf4).d().val.w; // Expect: 27.5 + outputBuffer[3] = __jvp(f)(dpf3).d().val.y; // Expect: 40.5 } } diff --git a/tests/autodiff/generic-jvp.slang.expected.txt b/tests/autodiff/generic-jvp.slang.expected.txt new file mode 100644 index 000000000..ceeaf120e --- /dev/null +++ b/tests/autodiff/generic-jvp.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +22.000000 +9.500000 +27.500000 +40.500000 +0.000000 diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang new file mode 100644 index 000000000..61cb96a07 --- /dev/null +++ b/tests/autodiff/getter-setter-multi.slang @@ -0,0 +1,83 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct B +{ + float3 z; + float.Differential k[10]; +}; + +struct A : IDifferentiable +{ + typedef B Differential; + + float3 x; + float y[10]; + + [__unsafeForceInlineEarly] + static Ptr<float3.Differential> __getDifferentialFor_x(inout Differential b) + { + return &(b.z); + } + + [__unsafeForceInlineEarly] + static Ptr<float.Differential[10]> __getDifferentialFor_y(inout Differential b) + { + return &(b.k); + } + + [__unsafeForceInlineEarly] + static Differential zero() + { + B b = {0.0}; + return b; + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + B o = {a.z + b.z}; + return o; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + B o = {a.x * b.z}; + return o; + } +}; + +typedef __DifferentialPair<A> dpA; + +__differentiate_jvp A f(A a) +{ + A aout; + + aout.y[5] = (2 * a.x).y; + aout.y[2] = (3 * a.y[4]); + aout.x = float3(5 * a.x.z, 3 * a.x.y, 0.5 * a.x.x); + + return aout; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + float arr[10] = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 }; + A a = {float3(1.0, 2.0, 3.0), arr}; + + float d_arr[10] = { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0 }; + B b = {float3(1.0, 0.5, 0.3), d_arr}; + + dpA dpa = dpA(a, b); + + outputBuffer[0] = __jvp(f)(dpa).d().z.z; // Expect: 0.5 + outputBuffer[1] = __jvp(f)(dpa).d().k[5]; // Expect: 1 + outputBuffer[2] = __jvp(f)(dpa).d().k[2]; // Expect: 1.5 + } +} diff --git a/tests/autodiff/getter-setter-multi.slang.expected.txt b/tests/autodiff/getter-setter-multi.slang.expected.txt new file mode 100644 index 000000000..ece9872b0 --- /dev/null +++ b/tests/autodiff/getter-setter-multi.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +0.500000 +1.000000 +1.500000 +0.000000 +0.000000 diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang new file mode 100644 index 000000000..6b280433b --- /dev/null +++ b/tests/autodiff/getter-setter.slang @@ -0,0 +1,69 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct B +{ + float z; +}; + +struct A : IDifferentiable +{ + typedef B Differential; + + float x; + float y; + + [__unsafeForceInlineEarly] + static Ptr<float.Differential> __getDifferentialFor_x(inout Differential b) + { + return &(b.z); + } + + [__unsafeForceInlineEarly] + static Differential zero() + { + B b = {0.0}; + return b; + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + B o = {a.z + b.z}; + return o; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + B o = {a.x * b.z}; + return o; + } +}; + +typedef __DifferentialPair<A> dpA; + +__differentiate_jvp A f(A a) +{ + A aout; + aout.y = 2 * a.x; + aout.x = 5 * a.x; + + return aout; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + A a = {1.0, 2.0}; + B b = {0.2}; + + dpA dpa = dpA(a, b); + + outputBuffer[0] = __jvp(f)(dpa).d().z; // Expect: 1 + } +} diff --git a/tests/autodiff/getter-setter.slang.expected.txt b/tests/autodiff/getter-setter.slang.expected.txt new file mode 100644 index 000000000..ca54c9afe --- /dev/null +++ b/tests/autodiff/getter-setter.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +1.000000 +0.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/imported-custom-jvp.slang b/tests/autodiff/imported-custom-jvp.slang new file mode 100644 index 000000000..ee8bdf51d --- /dev/null +++ b/tests/autodiff/imported-custom-jvp.slang @@ -0,0 +1,25 @@ +//TEST_IGNORE_FILE: + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +import test_intrinsics_jvp; + +typedef __DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +__differentiate_jvp float f(float x) +{ + return pow_(x, 2.0); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 1.0); + dpfloat dpb = dpfloat(1.5, 1.0); + + outputBuffer[0] = __jvp(f)(dpa).d(); // Expect: 2 + } +} diff --git a/tests/autodiff/overloads-jvp.slang b/tests/autodiff/overloads-jvp.slang new file mode 100644 index 000000000..26b5c0076 --- /dev/null +++ b/tests/autodiff/overloads-jvp.slang @@ -0,0 +1,40 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef __DifferentialPair<float> dpfloat; +typedef __DifferentialPair<float3> dpfloat3; + +__differentiate_jvp float f(float a) +{ + return a * a + a; +} + +__differentiate_jvp float f(float3 a) +{ + return a.x * a.y + a.z; +} + +__differentiate_jvp float g(float a) +{ + // df((2.0, 4.0, 6.0), (1.0, 2.0, 3.0)) + // 2.0 * 2.0 + 4.0 * 1.0 + 3.0 = 11.0 + return f(float3(a, 2*a, 3*a)); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 1.0); + dpfloat3 dpf3 = dpfloat3(float3(1.0, 3.0, 5.0), float3(0.5, 1.5, 2.5)); + + outputBuffer[0] = f(dpa.p()); // Expect: 6 + outputBuffer[1] = f(dpf3.p()); // Expect: 8 + outputBuffer[2] = __jvp(f)(dpf3).d(); // Expect: 5.5 + outputBuffer[3] = __jvp(f)(dpa).d(); // Expect: 5 + outputBuffer[4] = __jvp(g)(dpa).d(); // Expect: 11.0 + } +} diff --git a/tests/autodiff/overloads-jvp.slang.expected.txt b/tests/autodiff/overloads-jvp.slang.expected.txt new file mode 100644 index 000000000..999777e1e --- /dev/null +++ b/tests/autodiff/overloads-jvp.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +6.0 +8.0 +5.5 +5.0 +11.0
\ No newline at end of file diff --git a/tests/autodiff/test-intrinsics-jvp.slang b/tests/autodiff/test-intrinsics-jvp.slang index 333c89189..cb4c5c6b4 100644 --- a/tests/autodiff/test-intrinsics-jvp.slang +++ b/tests/autodiff/test-intrinsics-jvp.slang @@ -14,4 +14,5 @@ float max_(float x, float y); float max_jvp(float x, float y, float dx, float dy) { return (x > y) ? dx : dy; -}
\ No newline at end of file +} + diff --git a/tests/autodiff/vector-arithmetic-jvp.slang b/tests/autodiff/vector-arithmetic-jvp.slang index 393cc18ec..e05d94733 100644 --- a/tests/autodiff/vector-arithmetic-jvp.slang +++ b/tests/autodiff/vector-arithmetic-jvp.slang @@ -37,11 +37,10 @@ __differentiate_jvp float4 j(float4 x, float4 y) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - { + { float3 a = float3(2.0, 2.0, 2.0); float3 b = float3(1.5, 1.5, 1.5); float3 da = float3(1.0, 1.0, 1.0); - //dpfloat3 dpa = dpfloat3(a, da); float2 a2 = float2(2.0, 1.0); float2 b2 = float2(1.5, -2.0); |
