diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-04 17:06:29 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-04 17:06:29 -0700 |
| commit | a20f6a03062d72135ae046319c378709fe2a8df6 (patch) | |
| tree | c4a5848e126527c61b5533dd7838ff16d33dbe42 | |
| parent | c6e6b7a9177bf4f7fc2f05da36c5952979006d78 (diff) | |
Use property for `DifferentialPair` accessors. (#2493)
21 files changed, 116 insertions, 104 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index ae4db603e..2625d79b0 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -70,22 +70,34 @@ struct DifferentialPair : IDifferentiable __intrinsic_op($(kIROp_MakeDifferentialPair)) __init(T _primal, T.Differential _differential); - __intrinsic_op($(kIROp_DifferentialPairGetDifferential)) - T.Differential d(); + property p : T + { + __intrinsic_op($(kIROp_DifferentialPairGetPrimal)) + get; + } + + property v : T + { + __intrinsic_op($(kIROp_DifferentialPairGetPrimal)) + get; + } + + property d : T.Differential + { + __intrinsic_op($(kIROp_DifferentialPairGetDifferential)) + get; + } [__unsafeForceInlineEarly] T.Differential getDifferential() { - return d(); + return d; } - __intrinsic_op($(kIROp_DifferentialPairGetPrimal)) - T p(); - [__unsafeForceInlineEarly] T getPrimal() { - return p(); + return p; } [__unsafeForceInlineEarly] @@ -99,18 +111,18 @@ struct DifferentialPair : IDifferentiable { return Differential( T.dadd( - a.p(), - b.p() + a.p, + b.p ), - T.Differential.dadd(a.d(), b.d())); + T.Differential.dadd(a.d, b.d)); } [__unsafeForceInlineEarly] static Differential dmul(This a, Differential b) { return Differential( - T.dmul(a.p(), b.p()), - T.Differential.dmul(a.d(), b.d())); + T.dmul(a.p, b.p), + T.Differential.dmul(a.d, b.d)); } }; @@ -135,8 +147,8 @@ namespace dstd DifferentialPair<T> d_exp(DifferentialPair<T> dpx) { return DifferentialPair<T>( - exp(dpx.p()), - T.dmul(exp(dpx.p()), dpx.d())); + exp(dpx.p), + T.dmul(exp(dpx.p), dpx.d)); } // Sine @@ -153,8 +165,8 @@ namespace dstd DifferentialPair<T> d_sin(DifferentialPair<T> dpx) { return DifferentialPair<T>( - sin(dpx.p()), - T.dmul(cos(dpx.p()), dpx.d())); + sin(dpx.p), + T.dmul(cos(dpx.p), dpx.d)); } // Cosine @@ -171,8 +183,8 @@ namespace dstd DifferentialPair<T> d_cos(DifferentialPair<T> dpx) { return DifferentialPair<T>( - cos(dpx.p()), - T.dmul(-sin(dpx.p()), dpx.d())); + cos(dpx.p), + T.dmul(-sin(dpx.p), dpx.d)); } __generic<let N : int> @@ -192,9 +204,9 @@ namespace dstd vector<float, N>.Differential d_result; for(int i = 0; i < N; ++i) { - DifferentialPair<float> dpexp = dstd.d_exp(DifferentialPair<float>(dpx.p()[i], dpx.d()[i])); - result[i] = dpexp.p(); - d_result[i] = dpexp.d(); + DifferentialPair<float> dpexp = dstd.d_exp(DifferentialPair<float>(dpx.p[i], dpx.d[i])); + result[i] = dpexp.p; + d_result[i] = dpexp.d; } return DifferentialPair<vector<float, N>>(result, d_result); diff --git a/tests/autodiff/arithmetic-jvp.slang b/tests/autodiff/arithmetic-jvp.slang index 134741d4d..70b35c244 100644 --- a/tests/autodiff/arithmetic-jvp.slang +++ b/tests/autodiff/arithmetic-jvp.slang @@ -15,7 +15,7 @@ float f(float x) dpfloat g_jvp_(dpfloat dpx) { - return dpfloat(dpx.p(), 2 * dpx.d()); + return dpfloat(dpx.p, 2 * dpx.d); } [ForwardDerivative(g_jvp_)] @@ -46,10 +46,10 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat dpa = dpfloat(2.0, 1.0); dpfloat dpb = dpfloat(1.5, 1.0); - outputBuffer[0] = __fwd_diff(f)(dpa).d(); // Expect: 1 - outputBuffer[1] = __fwd_diff(f)(dpfloat(dpa.p(), 0.0)).d(); // Expect: 0 - outputBuffer[2] = __fwd_diff(g)(dpa).d(); // Expect: 2 - outputBuffer[3] = __fwd_diff(h)(dpa, dpb).d(); // Expect: 8 - outputBuffer[4] = __fwd_diff(j)(dpa, dpb).d(); // Expect: 1 + outputBuffer[0] = __fwd_diff(f)(dpa).d; // Expect: 1 + outputBuffer[1] = __fwd_diff(f)(dpfloat(dpa.p, 0.0)).d; // Expect: 0 + outputBuffer[2] = __fwd_diff(g)(dpa).d; // Expect: 2 + outputBuffer[3] = __fwd_diff(h)(dpa, dpb).d; // Expect: 8 + outputBuffer[4] = __fwd_diff(j)(dpa, dpb).d; // Expect: 1 } } diff --git a/tests/autodiff/auto-differential-type.slang b/tests/autodiff/auto-differential-type.slang index a4d0b6d89..efeebb459 100644 --- a/tests/autodiff/auto-differential-type.slang +++ b/tests/autodiff/auto-differential-type.slang @@ -54,6 +54,6 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpA dpa = dpA(a, b); - outputBuffer[0] = __fwd_diff(f)(dpa).d().x; // Expect: 1 + outputBuffer[0] = __fwd_diff(f)(dpa).d.x; // Expect: 1 } } diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang index ce6c1024f..8048c60ff 100644 --- a/tests/autodiff/custom-intrinsic.slang +++ b/tests/autodiff/custom-intrinsic.slang @@ -23,8 +23,8 @@ namespace myintrinsiclib DifferentialPair<T> d_exp(DifferentialPair<T> dpx) { return DifferentialPair<T>( - exp(dpx.p()), - T.dmul(exp(dpx.p()), dpx.d())); + exp(dpx.p), + T.dmul(exp(dpx.p), dpx.d)); } @@ -42,8 +42,8 @@ namespace myintrinsiclib DifferentialPair<T> d_sin(DifferentialPair<T> dpx) { return DifferentialPair<T>( - sin(dpx.p()), - T.dmul(cos(dpx.p()), dpx.d())); + sin(dpx.p), + T.dmul(cos(dpx.p), dpx.d)); } // Cosine @@ -60,8 +60,8 @@ namespace myintrinsiclib DifferentialPair<T> d_cos(DifferentialPair<T> dpx) { return DifferentialPair<T>( - cos(dpx.p()), - T.dmul(-sin(dpx.p()), dpx.d())); + cos(dpx.p), + T.dmul(-sin(dpx.p), dpx.d)); } // Sine and cosine @@ -80,10 +80,10 @@ namespace myintrinsiclib { T _s; T _c; - sincos(x.p(), _s, _c); + sincos(x.p, _s, _c); - s = DifferentialPair<T>(_s, T.dmul(_c, x.d())); - c = DifferentialPair<T>(_c, T.dmul(-_s, x.d())); + s = DifferentialPair<T>(_s, T.dmul(_c, x.d)); + c = DifferentialPair<T>(_c, T.dmul(-_s, x.d)); } }; @@ -109,13 +109,13 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) { dpfloat dpa = dpfloat(2.0, 1.0); - outputBuffer[0] = f(dpa.p()); // Expect: 7.389056 - outputBuffer[1] = __fwd_diff(f)(dpa).d(); // Expect: 7.389056 + outputBuffer[0] = f(dpa.p); // Expect: 7.389056 + outputBuffer[1] = __fwd_diff(f)(dpa).d; // Expect: 7.389056 // g() needs additional handling of IRMakeDifferentialPair(PtrType). This needs to // generate a new var, load from the individual vars and store into the pair var. - //outputBuffer[2] = g(dpa.p()); // Expect: 1.381773 - //outputBuffer[3] = __fwd_diff(g)(dpa).d(); // Expect: -0.301168 + //outputBuffer[2] = g(dpa.p); // Expect: 1.381773 + //outputBuffer[3] = __fwd_diff(g)(dpa).d; // Expect: -0.301168 } }
\ No newline at end of file diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang index 433342b52..4c96779f9 100644 --- a/tests/autodiff/differential-method-synthesis.slang +++ b/tests/autodiff/differential-method-synthesis.slang @@ -41,7 +41,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) A a = {1.0, 2.0}; A.Differential b = {0.2}; dpA dpa = dpA(a, b); - outputBuffer[0] = __fwd_diff(f)(dpa).d().b.x; // Expect: 0 + outputBuffer[0] = __fwd_diff(f)(dpa).d.b.x; // Expect: 0 outputBuffer[1] = A.dadd(b, b).b.x; // Expect: 0.4 outputBuffer[2] = A.dmul(a, b).b.x; // Expect: 0.2 } diff --git a/tests/autodiff/dstdlib-vector.slang b/tests/autodiff/dstdlib-vector.slang index 44315daf6..1a1bd0dfa 100644 --- a/tests/autodiff/dstdlib-vector.slang +++ b/tests/autodiff/dstdlib-vector.slang @@ -20,7 +20,7 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) { dpfloat dpa = dpfloat(2.0, 1.0); - outputBuffer[0] = f(dpa.p()); // Expect: 465.415999 - outputBuffer[1] = __fwd_diff(f)(dpa).d(); // Expect: 1326.871736 + outputBuffer[0] = f(dpa.p); // Expect: 465.415999 + outputBuffer[1] = __fwd_diff(f)(dpa).d; // Expect: 1326.871736 } }
\ No newline at end of file diff --git a/tests/autodiff/dstdlib.slang b/tests/autodiff/dstdlib.slang index aef59d445..247200511 100644 --- a/tests/autodiff/dstdlib.slang +++ b/tests/autodiff/dstdlib.slang @@ -30,11 +30,11 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) { dpfloat dpa = dpfloat(2.0, 1.0); - outputBuffer[0] = f(dpa.p()); // Expect: 7.389056 - outputBuffer[1] = __fwd_diff(f)(dpa).d(); // Expect: 7.389056 - outputBuffer[2] = g(dpa.p()); // Expect: 0.909297 - outputBuffer[3] = __fwd_diff(g)(dpa).d(); // Expect: -0.416146 - outputBuffer[4] = h(dpa.p()); // Expect: -0.416146 - outputBuffer[5] = __fwd_diff(h)(dpa).d(); // Expect: -0.909297 + outputBuffer[0] = f(dpa.p); // Expect: 7.389056 + outputBuffer[1] = __fwd_diff(f)(dpa).d; // Expect: 7.389056 + outputBuffer[2] = g(dpa.p); // Expect: 0.909297 + outputBuffer[3] = __fwd_diff(g)(dpa).d; // Expect: -0.416146 + outputBuffer[4] = h(dpa.p); // Expect: -0.416146 + outputBuffer[5] = __fwd_diff(h)(dpa).d; // Expect: -0.909297 } }
\ No newline at end of file diff --git a/tests/autodiff/generic-custom-jvp.slang b/tests/autodiff/generic-custom-jvp.slang index 47e646ccc..6e9e863bb 100644 --- a/tests/autodiff/generic-custom-jvp.slang +++ b/tests/autodiff/generic-custom-jvp.slang @@ -13,8 +13,8 @@ import test_intrinsics; dpfloat my_pow_jvp(dpfloat x, dpfloat n) { return dpfloat( - pow(x.p(), n.p()), - x.d() * n.p() * pow(x.p(), n.p()-1) + n.d() * pow(x.p(), n.p()) * log(x.p())); + pow(x.p, n.p), + x.d * n.p * pow(x.p, n.p-1) + n.d * pow(x.p, n.p) * log(x.p)); } [ForwardDerivative(my_pow_jvp)] @@ -27,9 +27,9 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat dpa = dpfloat(5.0, 1.0); dpfloat dpn = dpfloat(2, 0.0); - outputBuffer[0] = __fwd_diff(_pow)(dpa, dpn).d(); // Expect: 10.0 + outputBuffer[0] = __fwd_diff(_pow)(dpa, dpn).d; // Expect: 10.0 outputBuffer[1] = __fwd_diff(_pow)( - dpfloat(dpa.p(), 0.0), - dpfloat(dpn.p(), 1.0)).d(); // Expect: 40.23595 + dpfloat(dpa.p, 0.0), + dpfloat(dpn.p, 1.0)).d; // Expect: 40.23595 } } diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index 7f4c4313e..a1bc18252 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -127,12 +127,12 @@ DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b) T curr_p = (T)0.0; for (int i = 0; i < N; i++) { - curr_p = curr_p + (a.p().values[i] * b.p().values[i]); + curr_p = curr_p + (a.p.values[i] * b.p.values[i]); curr_d = T.dadd( curr_d, T.dadd( - T.dmul(a.p().values[i], b.d().values[i]), - T.dmul(b.p().values[i], a.d().values[i]))); + T.dmul(a.p.values[i], b.d.values[i]), + T.dmul(b.p.values[i], a.d.values[i]))); } return DifferentialPair<T>(curr_p, curr_d); @@ -298,9 +298,9 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat4 dpf4 = dpfloat4(myfloat4(float4(1.5, 2.0, 0.5, 1.0)), mydfloat4(float4(0.5, 0.8, 1.6, 2.5))); dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), mydfloat3(float3(0.5, 1.5, 2.5))); - outputBuffer[0] = f(dpa.p()); // Expect: 22.0 - outputBuffer[1] = __fwd_diff(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5 - outputBuffer[2] = __fwd_diff(f)(dpf4).d().val.values[3]; // Expect: 27.5 - outputBuffer[3] = __fwd_diff(f)(dpf3).d().val.values[1]; // Expect: 40.5 + outputBuffer[0] = f(dpa.p); // Expect: 22.0 + outputBuffer[1] = __fwd_diff(f)(dpfloat(2.0, 0.5)).d; // Expect: 9.5 + outputBuffer[2] = __fwd_diff(f)(dpf4).d.val.values[3]; // Expect: 27.5 + outputBuffer[3] = __fwd_diff(f)(dpf3).d.val.values[1]; // Expect: 40.5 } } diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 78a292251..61ec077f4 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -209,9 +209,9 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat4 dpf4 = dpfloat4(myfloat4(float4(1.5, 2.0, 0.5, 1.0)), myfloat4(float4(0.5, 0.8, 1.6, 2.5))); dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), myfloat3(float3(0.5, 1.5, 2.5))); - outputBuffer[0] = f(dpa.p()); // Expect: 22.0 - outputBuffer[1] = __fwd_diff(f)(dpfloat(2.0, 0.5)).d(); // Expect: 9.5 - outputBuffer[2] = __fwd_diff(f)(dpf4).d().val.w; // Expect: 27.5 - outputBuffer[3] = __fwd_diff(f)(dpf3).d().val.y; // Expect: 40.5 + outputBuffer[0] = f(dpa.p); // Expect: 22.0 + outputBuffer[1] = __fwd_diff(f)(dpfloat(2.0, 0.5)).d; // Expect: 9.5 + outputBuffer[2] = __fwd_diff(f)(dpf4).d.val.w; // Expect: 27.5 + outputBuffer[3] = __fwd_diff(f)(dpf3).d.val.y; // Expect: 40.5 } } diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang index 85b6a3c63..9055e860a 100644 --- a/tests/autodiff/getter-setter-multi.slang +++ b/tests/autodiff/getter-setter-multi.slang @@ -67,8 +67,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpA dpa = dpA(a, b); - outputBuffer[0] = __fwd_diff(f)(dpa).d().z.z; // Expect: 0.5 - outputBuffer[1] = __fwd_diff(f)(dpa).d().k[5]; // Expect: 1 - outputBuffer[2] = __fwd_diff(f)(dpa).d().k[2]; // Expect: 1.5 + outputBuffer[0] = __fwd_diff(f)(dpa).d.z.z; // Expect: 0.5 + outputBuffer[1] = __fwd_diff(f)(dpa).d.k[5]; // Expect: 1 + outputBuffer[2] = __fwd_diff(f)(dpa).d.k[2]; // Expect: 1.5 } } diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang index a9e01b8c6..705604bbb 100644 --- a/tests/autodiff/getter-setter.slang +++ b/tests/autodiff/getter-setter.slang @@ -60,6 +60,6 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpA dpa = dpA(a, b); - outputBuffer[0] = __fwd_diff(f)(dpa).d().z; // Expect: 1 + outputBuffer[0] = __fwd_diff(f)(dpa).d.z; // Expect: 1 } } diff --git a/tests/autodiff/high-order-forward-diff.slang b/tests/autodiff/high-order-forward-diff.slang index 94b4d2a0d..be0029419 100644 --- a/tests/autodiff/high-order-forward-diff.slang +++ b/tests/autodiff/high-order-forward-diff.slang @@ -19,7 +19,7 @@ float f(float x) [ForwardDifferentiable] float df(float x) { - return __fwd_diff(f)(DifferentialPair<float>(x, 1.0)).d(); + return __fwd_diff(f)(DifferentialPair<float>(x, 1.0)).d; } [numthreads(1, 1, 1)] @@ -28,5 +28,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) // Given f(x) = x^4, // f''(x) = 12 * x^2 // Expect f''(4) = 192 - outputBuffer[0] = __fwd_diff(df)(DifferentialPair<float>(4.0, 1.0)).d(); + outputBuffer[0] = __fwd_diff(df)(DifferentialPair<float>(4.0, 1.0)).d; } diff --git a/tests/autodiff/imported-custom-jvp.slang b/tests/autodiff/imported-custom-jvp.slang index b4a2a9d5f..92f2ff89c 100644 --- a/tests/autodiff/imported-custom-jvp.slang +++ b/tests/autodiff/imported-custom-jvp.slang @@ -21,6 +21,6 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat dpa = dpfloat(2.0, 1.0); dpfloat dpb = dpfloat(1.5, 1.0); - outputBuffer[0] = __fwd_diff(f)(dpa).d(); // Expect: 2 + outputBuffer[0] = __fwd_diff(f)(dpa).d; // Expect: 2 } } diff --git a/tests/autodiff/inout-parameters-jvp.slang b/tests/autodiff/inout-parameters-jvp.slang index 9935a9c59..e5720d412 100644 --- a/tests/autodiff/inout-parameters-jvp.slang +++ b/tests/autodiff/inout-parameters-jvp.slang @@ -37,12 +37,12 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) __fwd_diff(h)(dpfloat(x, dx), dpfloat(y, dy), dpz); - outputBuffer[0] = dpz.d(); // Expect: 12.0 - outputBuffer[1] = dpz.p(); // Expect: 6.75 + outputBuffer[0] = dpz.d; // Expect: 12.0 + outputBuffer[1] = dpz.p; // Expect: 6.75 __fwd_diff(g)(dpfloat(x, dx), dpfloat(y, dy), dpz); - outputBuffer[2] = dpz.d(); // Expect: 21.5 - outputBuffer[3] = dpz.p(); // Expect: 12.5 + outputBuffer[2] = dpz.d; // Expect: 21.5 + outputBuffer[3] = dpz.p; // Expect: 12.5 }
\ No newline at end of file diff --git a/tests/autodiff/local-redecl-custom-jvp.slang b/tests/autodiff/local-redecl-custom-jvp.slang index 450e03b5d..7cf5d64e5 100644 --- a/tests/autodiff/local-redecl-custom-jvp.slang +++ b/tests/autodiff/local-redecl-custom-jvp.slang @@ -11,8 +11,8 @@ import test_intrinsics; dpfloat my_pow_jvp(dpfloat x, dpfloat n) { return dpfloat( - pow(x.p(), n.p()), - x.d() * n.p() * pow(x.p(), n.p()-1) + n.d() * pow(x.p(), n.p()) * log(x.p())); + pow(x.p, n.p), + x.d * n.p * pow(x.p, n.p-1) + n.d * pow(x.p, n.p) * log(x.p)); } [ForwardDerivative(my_pow_jvp)] @@ -25,9 +25,9 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat dpa = dpfloat(5.0, 1.0); dpfloat dpn = dpfloat(2, 0.0); - outputBuffer[0] = __fwd_diff(_pow)(dpa, dpn).d(); // Expect: 10.0 + outputBuffer[0] = __fwd_diff(_pow)(dpa, dpn).d; // Expect: 10.0 outputBuffer[1] = __fwd_diff(_pow)( - dpfloat(dpa.p(), 0.0), - dpfloat(dpn.p(), 1.0)).d(); // Expect: 40.23595 + dpfloat(dpa.p, 0.0), + dpfloat(dpn.p, 1.0)).d; // Expect: 40.23595 } } diff --git a/tests/autodiff/nested-jvp.slang b/tests/autodiff/nested-jvp.slang index a66adaf8e..09a55a88d 100644 --- a/tests/autodiff/nested-jvp.slang +++ b/tests/autodiff/nested-jvp.slang @@ -22,16 +22,16 @@ float max_(float x, float y) dpfloat pow_jvp(dpfloat x, dpfloat n) { return dpfloat( - pow(x.p(), n.p()), - x.d() * n.p() * pow(x.p(), n.p()-1) + - ((n.d() != 0.0) ? (n.d() * pow(x.p(), n.p()) * log(x.p())) : 0.0)); + pow(x.p, n.p), + x.d * n.p * pow(x.p, n.p-1) + + ((n.d != 0.0) ? (n.d * pow(x.p, n.p) * log(x.p)) : 0.0)); } dpfloat max_jvp(dpfloat x, dpfloat y) { return dpfloat( - max(x.p(), y.p()), - (x.p() > y.p()) ? x.d() : y.d()); + max(x.p, y.p), + (x.p > y.p) ? x.d : y.d); } @@ -63,7 +63,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) outputBuffer[0] = __fwd_diff(fresnel)( dpfloat3(f0, d_f0), dpfloat3(f90, d_f90), - dpfloat(cosTheta, d_cosTheta)).d().y; // Expect: -0.031250 + dpfloat(cosTheta, d_cosTheta)).d.y; // Expect: -0.031250 float a = 1.0; float b = -0.4; @@ -76,13 +76,13 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) outputBuffer[1] = __fwd_diff(g)( dpfloat(a, da), dpfloat(b, db), - dpfloat(c, dc)).d(); // Expect: -0.24375 + dpfloat(c, dc)).d; // Expect: -0.24375 outputBuffer[2] = g(a, b, c); // Expect: 0.95625 outputBuffer[3] = __fwd_diff(g)( dpfloat(a, da), dpfloat(b, db), - dpfloat(3.0, dc)).d(); // Expect: -0.4; + dpfloat(3.0, dc)).d; // Expect: -0.4; } } diff --git a/tests/autodiff/out-parameters-jvp.slang b/tests/autodiff/out-parameters-jvp.slang index 31419489c..4faf18555 100644 --- a/tests/autodiff/out-parameters-jvp.slang +++ b/tests/autodiff/out-parameters-jvp.slang @@ -26,6 +26,6 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat dresult; __fwd_diff(h)(dpfloat(x, dx), dpfloat(y, dy), dresult); - outputBuffer[0] = dresult.d(); // Expect: 9.5 + outputBuffer[0] = dresult.d; // Expect: 9.5 }
\ No newline at end of file diff --git a/tests/autodiff/overloads-jvp.slang b/tests/autodiff/overloads-jvp.slang index 730fe6e2d..46dc94ee7 100644 --- a/tests/autodiff/overloads-jvp.slang +++ b/tests/autodiff/overloads-jvp.slang @@ -34,10 +34,10 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat dpa = dpfloat(2.0, 1.0); dpfloat3 dpf3 = dpfloat3(float3(1.0, 3.0, 5.0), float3(0.5, 1.5, 2.5)); - outputBuffer[0] = f(dpa.p()); // Expect: 6 - outputBuffer[1] = f(dpf3.p()); // Expect: 8 - outputBuffer[2] = __fwd_diff(f)(dpf3).d(); // Expect: 5.5 - outputBuffer[3] = __fwd_diff(f)(dpa).d(); // Expect: 5 - outputBuffer[4] = __fwd_diff(g)(dpa).d(); // Expect: 11.0 + outputBuffer[0] = f(dpa.p); // Expect: 6 + outputBuffer[1] = f(dpf3.p); // Expect: 8 + outputBuffer[2] = __fwd_diff(f)(dpf3).d; // Expect: 5.5 + outputBuffer[3] = __fwd_diff(f)(dpa).d; // Expect: 5 + outputBuffer[4] = __fwd_diff(g)(dpa).d; // Expect: 11.0 } } diff --git a/tests/autodiff/vector-arithmetic-jvp.slang b/tests/autodiff/vector-arithmetic-jvp.slang index 62f6e2d50..90e2ceca6 100644 --- a/tests/autodiff/vector-arithmetic-jvp.slang +++ b/tests/autodiff/vector-arithmetic-jvp.slang @@ -52,18 +52,18 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) float4 a4 = float4(2.0, 1.0, 0.0, 2.0); float4 b4 = float4(1.5, -2.0, 1.0, 1.5); - outputBuffer[0] = __fwd_diff(f)(dpfloat3(a, da)).d().z; // Expect: 1 + outputBuffer[0] = __fwd_diff(f)(dpfloat3(a, da)).d.z; // Expect: 1 outputBuffer[1] = __fwd_diff(g)( dpfloat3(a, da), - dpfloat3(b, float3(2.0, 1.0, 0.0))).d().y; // Expect: 8 + dpfloat3(b, float3(2.0, 1.0, 0.0))).d.y; // Expect: 8 outputBuffer[2] = __fwd_diff(h)( dpfloat2(a2, float2(1.0, 0.0)), - dpfloat2(b2, float2(1.0, 1.0))).d().x; // Expect: 8 + dpfloat2(b2, float2(1.0, 1.0))).d.x; // Expect: 8 outputBuffer[3] = __fwd_diff(j)( dpfloat4(a4, float4(1.0)), - dpfloat4(b4, float4(2.0))).d().w; // Expect: 9 + dpfloat4(b4, float4(2.0))).d.w; // Expect: 9 } } diff --git a/tests/autodiff/vector-swizzle-jvp.slang b/tests/autodiff/vector-swizzle-jvp.slang index 99f305425..1bbf94bfc 100644 --- a/tests/autodiff/vector-swizzle-jvp.slang +++ b/tests/autodiff/vector-swizzle-jvp.slang @@ -29,16 +29,16 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) float3 a = float3(2.0, 2.0, 2.0); float3 da = float3(1.0, 0.5, 1.0); - outputBuffer[0] = __fwd_diff(f)(dpfloat3(a, da)).d().x; // Expect: 1 - outputBuffer[1] = __fwd_diff(f)(dpfloat3(a, da)).d().y; // Expect: 0.5 + outputBuffer[0] = __fwd_diff(f)(dpfloat3(a, da)).d.x; // Expect: 1 + outputBuffer[1] = __fwd_diff(f)(dpfloat3(a, da)).d.y; // Expect: 0.5 float3 x = float3(0.5, 2.0, 0.5); float4 y = float4(-1.5, 1.0, 4.0, 2.0); float3 dx = float3(1.0, 0.0, -1.0); float4 dy = float4(0.0, 0.5, -0.25, 1.0); - outputBuffer[2] = __fwd_diff(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d().x; // Expect: -2.25 - outputBuffer[3] = __fwd_diff(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d().y; // Expect: 0.5 + outputBuffer[2] = __fwd_diff(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d.x; // Expect: -2.25 + outputBuffer[3] = __fwd_diff(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d.y; // Expect: 0.5 } } |
