diff options
Diffstat (limited to 'tests/autodiff')
| -rw-r--r-- | tests/autodiff/arithmetic-jvp.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/auto-differential-type.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/custom-intrinsic.slang | 20 | ||||
| -rw-r--r-- | tests/autodiff/differential-method-synthesis.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/dstdlib.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/generic-custom-jvp.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/generic-impl-jvp.slang | 18 | ||||
| -rw-r--r-- | tests/autodiff/generic-jvp.slang | 8 | ||||
| -rw-r--r-- | tests/autodiff/getter-setter-multi.slang | 4 | ||||
| -rw-r--r-- | tests/autodiff/getter-setter.slang | 4 | ||||
| -rw-r--r-- | tests/autodiff/high-order-forward-diff.slang | 23 | ||||
| -rw-r--r-- | tests/autodiff/imported-custom-jvp.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/inout-parameters-jvp.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/local-redecl-custom-jvp.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/nested-jvp.slang | 4 | ||||
| -rw-r--r-- | tests/autodiff/out-parameters-jvp.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/overloads-jvp.slang | 4 | ||||
| -rw-r--r-- | tests/autodiff/vector-arithmetic-jvp.slang | 6 | ||||
| -rw-r--r-- | tests/autodiff/vector-swizzle-jvp.slang | 6 |
19 files changed, 69 insertions, 46 deletions
diff --git a/tests/autodiff/arithmetic-jvp.slang b/tests/autodiff/arithmetic-jvp.slang index ec2c5bc6f..134741d4d 100644 --- a/tests/autodiff/arithmetic-jvp.slang +++ b/tests/autodiff/arithmetic-jvp.slang @@ -4,7 +4,7 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -typedef __DifferentialPair<float> dpfloat; +typedef DifferentialPair<float> dpfloat; typedef float.Differential dfloat; [ForwardDifferentiable] diff --git a/tests/autodiff/auto-differential-type.slang b/tests/autodiff/auto-differential-type.slang index 57dd3cb10..a4d0b6d89 100644 --- a/tests/autodiff/auto-differential-type.slang +++ b/tests/autodiff/auto-differential-type.slang @@ -33,7 +33,7 @@ struct A : IDifferentiable } }; -typedef __DifferentialPair<A> dpA; +typedef DifferentialPair<A> dpA; [ForwardDifferentiable] A f(A a) diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang index 7591cd624..ce6c1024f 100644 --- a/tests/autodiff/custom-intrinsic.slang +++ b/tests/autodiff/custom-intrinsic.slang @@ -4,7 +4,7 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -typedef __DifferentialPair<float> dpfloat; +typedef DifferentialPair<float> dpfloat; typealias IDFloat = IFloat & IDifferentiable; @@ -20,9 +20,9 @@ namespace myintrinsiclib T exp(T x); __generic<T : IDFloat> - __DifferentialPair<T> d_exp(__DifferentialPair<T> dpx) + DifferentialPair<T> d_exp(DifferentialPair<T> dpx) { - return __DifferentialPair<T>( + return DifferentialPair<T>( exp(dpx.p()), T.dmul(exp(dpx.p()), dpx.d())); } @@ -39,9 +39,9 @@ namespace myintrinsiclib T sin(T x); __generic<T : IDFloat> - __DifferentialPair<T> d_sin(__DifferentialPair<T> dpx) + DifferentialPair<T> d_sin(DifferentialPair<T> dpx) { - return __DifferentialPair<T>( + return DifferentialPair<T>( sin(dpx.p()), T.dmul(cos(dpx.p()), dpx.d())); } @@ -57,9 +57,9 @@ namespace myintrinsiclib T cos(T x); __generic<T : IDFloat> - __DifferentialPair<T> d_cos(__DifferentialPair<T> dpx) + DifferentialPair<T> d_cos(DifferentialPair<T> dpx) { - return __DifferentialPair<T>( + return DifferentialPair<T>( cos(dpx.p()), T.dmul(-sin(dpx.p()), dpx.d())); } @@ -76,14 +76,14 @@ namespace myintrinsiclib } __generic<T : IDFloat> - void d_sincos(__DifferentialPair<T> x, out __DifferentialPair<T> s, out __DifferentialPair<T> c) + void d_sincos(DifferentialPair<T> x, out DifferentialPair<T> s, out DifferentialPair<T> c) { T _s; T _c; sincos(x.p(), _s, _c); - s = __DifferentialPair<T>(_s, T.dmul(_c, x.d())); - c = __DifferentialPair<T>(_c, T.dmul(-_s, x.d())); + s = DifferentialPair<T>(_s, T.dmul(_c, x.d())); + c = DifferentialPair<T>(_c, T.dmul(-_s, x.d())); } }; diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang index 76c15d5a1..433342b52 100644 --- a/tests/autodiff/differential-method-synthesis.slang +++ b/tests/autodiff/differential-method-synthesis.slang @@ -17,7 +17,7 @@ struct A : IDifferentiable float y; }; -typedef __DifferentialPair<A> dpA; +typedef DifferentialPair<A> dpA; A nonDiff(A a) { diff --git a/tests/autodiff/dstdlib.slang b/tests/autodiff/dstdlib.slang index 05d915c29..aef59d445 100644 --- a/tests/autodiff/dstdlib.slang +++ b/tests/autodiff/dstdlib.slang @@ -4,7 +4,7 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -typedef __DifferentialPair<float> dpfloat; +typedef DifferentialPair<float> dpfloat; [ForwardDifferentiable] float f(float x) diff --git a/tests/autodiff/generic-custom-jvp.slang b/tests/autodiff/generic-custom-jvp.slang index 5111f0e48..47e646ccc 100644 --- a/tests/autodiff/generic-custom-jvp.slang +++ b/tests/autodiff/generic-custom-jvp.slang @@ -6,7 +6,7 @@ RWStructuredBuffer<float> outputBuffer; typealias IDFloat = IFloat & IDifferentiable; __generic<T : IDFloat> -typedef __DifferentialPair<T> dfloat; +typedef DifferentialPair<T> dfloat; import test_intrinsics; diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index 8477aa68f..7f4c4313e 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -8,8 +8,8 @@ typedef float Real; typealias IDFloat = IFloat & IDifferentiable; -__generic<T, let N : int> -struct dvector +__generic<T : IDifferentiable, let N : int> +struct dvector : IDifferentiable { T values[N]; }; @@ -118,10 +118,10 @@ T dot(myvector<T, N> a, myvector<T, N> b) } __generic<T : IDFloat, let N : int> -typedef __DifferentialPair<myvector<T, N>> dpvector; +typedef DifferentialPair<myvector<T, N>> dpvector; __generic<T : IDFloat, let N : int> -__DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b) +DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b) { T.Differential curr_d = (T.dzero()); T curr_p = (T)0.0; @@ -135,11 +135,11 @@ __DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b) T.dmul(b.p().values[i], a.d().values[i]))); } - return __DifferentialPair<T>(curr_p, curr_d); + return DifferentialPair<T>(curr_p, curr_d); } __generic<let N : int> -struct lineardvector +struct lineardvector : IDifferentiable { myvector<Real, N>.Differential val; @@ -223,7 +223,7 @@ typedef linearvector<4> myfloat4; typedef lineardvector<3> mydfloat3; typedef lineardvector<4> mydfloat4; -typedef __DifferentialPair<Real> dpfloat; +typedef DifferentialPair<Real> dpfloat; interface MyLinearArithmeticType { @@ -233,8 +233,8 @@ interface MyLinearArithmeticType static Real ldot(This a, This b); }; -typedef __DifferentialPair<myfloat4> dpfloat4; -typedef __DifferentialPair<myfloat3> dpfloat3; +typedef DifferentialPair<myfloat4> dpfloat4; +typedef DifferentialPair<myfloat3> dpfloat3; extension float : MyLinearArithmeticType { diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 9e0d56f0f..78a292251 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -39,7 +39,7 @@ extension myvector<3> : MyLinearArithmeticType } [ForwardDifferentiable] -__init(vector<Real, 3> a) + __init(vector<Real, 3> a) { val = a; } @@ -83,7 +83,7 @@ extension myvector<4> : MyLinearArithmeticType typedef myvector<3> myfloat3; typedef myvector<4> myfloat4; -typedef __DifferentialPair<Real> dpfloat; +typedef DifferentialPair<Real> dpfloat; interface MyLinearArithmeticType { @@ -144,8 +144,8 @@ extension myfloat4 : IDifferentiable } }; -typedef __DifferentialPair<myfloat4> dpfloat4; -typedef __DifferentialPair<myfloat3> dpfloat3; +typedef DifferentialPair<myfloat4> dpfloat4; +typedef DifferentialPair<myfloat3> dpfloat3; extension float : MyLinearArithmeticType { diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang index 217c475af..85b6a3c63 100644 --- a/tests/autodiff/getter-setter-multi.slang +++ b/tests/autodiff/getter-setter-multi.slang @@ -4,7 +4,7 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -struct B +struct B : IDifferentiable { float3 z; float.Differential k[10]; @@ -41,7 +41,7 @@ struct A : IDifferentiable } }; -typedef __DifferentialPair<A> dpA; +typedef DifferentialPair<A> dpA; [ForwardDifferentiable] A f(A a) diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang index ff4f81a42..a9e01b8c6 100644 --- a/tests/autodiff/getter-setter.slang +++ b/tests/autodiff/getter-setter.slang @@ -4,7 +4,7 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -struct B +struct B : IDifferentiable { float z; }; @@ -39,7 +39,7 @@ struct A : IDifferentiable } }; -typedef __DifferentialPair<A> dpA; +typedef DifferentialPair<A> dpA; [ForwardDifferentiable] A f(A a) diff --git a/tests/autodiff/high-order-forward-diff.slang b/tests/autodiff/high-order-forward-diff.slang new file mode 100644 index 000000000..fde659227 --- /dev/null +++ b/tests/autodiff/high-order-forward-diff.slang @@ -0,0 +1,23 @@ +//DTEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//DTEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[ForwardDifferentiable] +float f(float x) +{ + return x * x; +} + +[ForwardDifferentiable] +float df(float x) +{ + return __fwd_diff(f)(DifferentialPair<float>(x, 1.0)).d(); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = __fwd_diff(df)(DifferentialPair<float>(1.0, 1.0)).d(); // Expect: 2.0 +} diff --git a/tests/autodiff/imported-custom-jvp.slang b/tests/autodiff/imported-custom-jvp.slang index f5251740a..b4a2a9d5f 100644 --- a/tests/autodiff/imported-custom-jvp.slang +++ b/tests/autodiff/imported-custom-jvp.slang @@ -5,7 +5,7 @@ RWStructuredBuffer<float> outputBuffer; import test_intrinsics_jvp; -typedef __DifferentialPair<float> dpfloat; +typedef DifferentialPair<float> dpfloat; typedef float.Differential dfloat; [ForwardDifferentiable] diff --git a/tests/autodiff/inout-parameters-jvp.slang b/tests/autodiff/inout-parameters-jvp.slang index ab4a3c790..9935a9c59 100644 --- a/tests/autodiff/inout-parameters-jvp.slang +++ b/tests/autodiff/inout-parameters-jvp.slang @@ -4,7 +4,7 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -typedef __DifferentialPair<float> dpfloat; +typedef DifferentialPair<float> dpfloat; [ForwardDifferentiable] void g(float x, float y, inout float z) diff --git a/tests/autodiff/local-redecl-custom-jvp.slang b/tests/autodiff/local-redecl-custom-jvp.slang index 3a6b6f474..450e03b5d 100644 --- a/tests/autodiff/local-redecl-custom-jvp.slang +++ b/tests/autodiff/local-redecl-custom-jvp.slang @@ -3,7 +3,7 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -typedef __DifferentialPair<float> dpfloat; +typedef DifferentialPair<float> dpfloat; typedef float.Differential dfloat; import test_intrinsics; diff --git a/tests/autodiff/nested-jvp.slang b/tests/autodiff/nested-jvp.slang index 0e7d19078..a66adaf8e 100644 --- a/tests/autodiff/nested-jvp.slang +++ b/tests/autodiff/nested-jvp.slang @@ -4,8 +4,8 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -typedef __DifferentialPair<float> dpfloat; -typedef __DifferentialPair<float3> dpfloat3; +typedef DifferentialPair<float> dpfloat; +typedef DifferentialPair<float3> dpfloat3; [ForwardDerivative(pow_jvp)] float pow_(float x, float n) diff --git a/tests/autodiff/out-parameters-jvp.slang b/tests/autodiff/out-parameters-jvp.slang index 072c5158b..31419489c 100644 --- a/tests/autodiff/out-parameters-jvp.slang +++ b/tests/autodiff/out-parameters-jvp.slang @@ -4,7 +4,7 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -typedef __DifferentialPair<float> dpfloat; +typedef DifferentialPair<float> dpfloat; [ForwardDifferentiable] void h(float x, float y, out float result) diff --git a/tests/autodiff/overloads-jvp.slang b/tests/autodiff/overloads-jvp.slang index 2577009c3..730fe6e2d 100644 --- a/tests/autodiff/overloads-jvp.slang +++ b/tests/autodiff/overloads-jvp.slang @@ -4,8 +4,8 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -typedef __DifferentialPair<float> dpfloat; -typedef __DifferentialPair<float3> dpfloat3; +typedef DifferentialPair<float> dpfloat; +typedef DifferentialPair<float3> dpfloat3; [ForwardDifferentiable] float f(float a) diff --git a/tests/autodiff/vector-arithmetic-jvp.slang b/tests/autodiff/vector-arithmetic-jvp.slang index cf0eb6170..62f6e2d50 100644 --- a/tests/autodiff/vector-arithmetic-jvp.slang +++ b/tests/autodiff/vector-arithmetic-jvp.slang @@ -4,9 +4,9 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -typedef __DifferentialPair<float2> dpfloat2; -typedef __DifferentialPair<float3> dpfloat3; -typedef __DifferentialPair<float4> dpfloat4; +typedef DifferentialPair<float2> dpfloat2; +typedef DifferentialPair<float3> dpfloat3; +typedef DifferentialPair<float4> dpfloat4; [ForwardDifferentiable] float3 f(float3 x) diff --git a/tests/autodiff/vector-swizzle-jvp.slang b/tests/autodiff/vector-swizzle-jvp.slang index f7a045b25..99f305425 100644 --- a/tests/autodiff/vector-swizzle-jvp.slang +++ b/tests/autodiff/vector-swizzle-jvp.slang @@ -4,9 +4,9 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; -typedef __DifferentialPair<float2> dpfloat2; -typedef __DifferentialPair<float3> dpfloat3; -typedef __DifferentialPair<float4> dpfloat4; +typedef DifferentialPair<float2> dpfloat2; +typedef DifferentialPair<float3> dpfloat3; +typedef DifferentialPair<float4> dpfloat4; [ForwardDifferentiable] float2 f(float3 x) |
