diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-07-26 17:15:21 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-26 17:15:21 -0400 |
| commit | ba89fc84267bfd09f1c8abf10a5b85d09bbc79de (patch) | |
| tree | 2c79fc5dafb89a030d22fa86cd6fa3d69a89a785 /tests/autodiff | |
| parent | b8ade05df10a2774d3da5ef1fb2c7479ff48989a (diff) | |
Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)` (#3029)
* Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)`
- Add AST synthesis support for generic containers
- Refactor relevant tests
* Merge dmul synthesis with dadd and dzero, and disambiguate using an enum
* Fix trailing spaces
Diffstat (limited to 'tests/autodiff')
| -rw-r--r-- | tests/autodiff/auto-differential-type.slang | 5 | ||||
| -rw-r--r-- | tests/autodiff/custom-intrinsic.slang | 56 | ||||
| -rw-r--r-- | tests/autodiff/differential-method-synthesis.slang | 8 | ||||
| -rw-r--r-- | tests/autodiff/differential-method-synthesis.slang.expected.txt | 2 | ||||
| -rw-r--r-- | tests/autodiff/generic-impl-jvp.slang | 18 | ||||
| -rw-r--r-- | tests/autodiff/generic-jvp.slang | 8 | ||||
| -rw-r--r-- | tests/autodiff/getter-setter-multi.slang | 4 | ||||
| -rw-r--r-- | tests/autodiff/getter-setter.slang | 4 |
8 files changed, 53 insertions, 52 deletions
diff --git a/tests/autodiff/auto-differential-type.slang b/tests/autodiff/auto-differential-type.slang index efeebb459..a253a25bb 100644 --- a/tests/autodiff/auto-differential-type.slang +++ b/tests/autodiff/auto-differential-type.slang @@ -26,9 +26,10 @@ struct A : IDifferentiable } [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + __generic<T : __BuiltinRealType> + static Differential dmul(T a, Differential b) { - Differential o = {a.x * b.x, 0.0}; + Differential o = { __realCast<float, T>(a * __realCast<T, float>(b.x)), 0.0}; return o; } }; diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang index 8048c60ff..dd122a674 100644 --- a/tests/autodiff/custom-intrinsic.slang +++ b/tests/autodiff/custom-intrinsic.slang @@ -6,81 +6,81 @@ RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; -typealias IDFloat = IFloat & IDifferentiable; +typealias IDFloat = __BuiltinFloatingPointType & IDifferentiable; namespace myintrinsiclib { __generic<T : IDFloat> - __target_intrinsic(hlsl) - __target_intrinsic(glsl) + __target_intrinsic(hlsl, "exp($0)") + __target_intrinsic(glsl, "exp($0)") __target_intrinsic(cuda, "$P_exp($0)") __target_intrinsic(cpp, "$P_exp($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") - [ForwardDerivative(d_exp<T>)] - T exp(T x); + [ForwardDerivative(d_myexp<T>)] + T myexp(T x); __generic<T : IDFloat> - DifferentialPair<T> d_exp(DifferentialPair<T> dpx) + DifferentialPair<T> d_myexp(DifferentialPair<T> dpx) { return DifferentialPair<T>( - exp(dpx.p), - T.dmul(exp(dpx.p), dpx.d)); + myexp(dpx.p), + T.dmul(myexp(dpx.p), dpx.d)); } // Sine __generic<T : IDFloat> - __target_intrinsic(hlsl) - __target_intrinsic(glsl) + __target_intrinsic(hlsl, "sin($0)") + __target_intrinsic(glsl, "sin($0)") __target_intrinsic(cuda, "$P_sin($0)") __target_intrinsic(cpp, "$P_sin($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") - [ForwardDerivative(d_sin<T>)] - T sin(T x); + [ForwardDerivative(d_mysin<T>)] + T mysin(T x); __generic<T : IDFloat> - DifferentialPair<T> d_sin(DifferentialPair<T> dpx) + DifferentialPair<T> d_mysin(DifferentialPair<T> dpx) { return DifferentialPair<T>( - sin(dpx.p), - T.dmul(cos(dpx.p), dpx.d)); + mysin(dpx.p), + T.dmul(mycos(dpx.p), dpx.d)); } // Cosine __generic<T : IDFloat> - __target_intrinsic(hlsl) - __target_intrinsic(glsl) + __target_intrinsic(hlsl, "cos($0)") + __target_intrinsic(glsl, "cos($0)") __target_intrinsic(cuda, "$P_cos($0)") __target_intrinsic(cpp, "$P_cos($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") - [ForwardDerivative(d_cos<T>)] - T cos(T x); + [ForwardDerivative(d_mycos<T>)] + T mycos(T x); __generic<T : IDFloat> - DifferentialPair<T> d_cos(DifferentialPair<T> dpx) + DifferentialPair<T> d_mycos(DifferentialPair<T> dpx) { return DifferentialPair<T>( - cos(dpx.p), + mycos(dpx.p), T.dmul(-sin(dpx.p), dpx.d)); } // Sine and cosine __generic<T : IDFloat> - __target_intrinsic(hlsl) + __target_intrinsic(hlsl, "sincos($0, $1, $2)") __target_intrinsic(cuda, "$P_sincos($0, $1, $2)") - [ForwardDerivative(d_sincos<T>)] - void sincos(T x, out T s, out T c) + [ForwardDerivative(d_mysincos<T>)] + void mysincos(T x, out T s, out T c) { s = sin(x); c = cos(x); } __generic<T : IDFloat> - void d_sincos(DifferentialPair<T> x, out DifferentialPair<T> s, out DifferentialPair<T> c) + void d_mysincos(DifferentialPair<T> x, out DifferentialPair<T> s, out DifferentialPair<T> c) { T _s; T _c; - sincos(x.p, _s, _c); + mysincos(x.p, _s, _c); s = DifferentialPair<T>(_s, T.dmul(_c, x.d)); c = DifferentialPair<T>(_c, T.dmul(-_s, x.d)); @@ -90,7 +90,7 @@ namespace myintrinsiclib [ForwardDifferentiable] float f(float x) { - return myintrinsiclib.exp(x); + return myintrinsiclib.myexp(x); } [ForwardDifferentiable] @@ -98,7 +98,7 @@ float g(float x) { float s; float t; - myintrinsiclib.sincos(x, s, t); + myintrinsiclib.mysincos(x, s, t); return s + t; } diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang index 3220976e7..e9385b78c 100644 --- a/tests/autodiff/differential-method-synthesis.slang +++ b/tests/autodiff/differential-method-synthesis.slang @@ -41,8 +41,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) A a = {1.0, 2.0}; A.Differential b = {0.2}; dpA dpa = dpA(a, b); - outputBuffer[0] = __fwd_diff(f)(dpa).d.b.x; // Expect: 0 - outputBuffer[1] = A.dadd(b, b).b.x; // Expect: 0.4 - outputBuffer[2] = A.dmul(a, b).b.x; // Expect: 0.2 + outputBuffer[0] = __fwd_diff(f)(dpa).d.b.x; // Expect: 0 + outputBuffer[1] = A.dadd(b, b).b.x; // Expect: 0.4 + outputBuffer[2] = A.dmul<float>(2.0, b).b.x; // Expect: 0.4 } -} +} diff --git a/tests/autodiff/differential-method-synthesis.slang.expected.txt b/tests/autodiff/differential-method-synthesis.slang.expected.txt index 5fbff9752..353c35ec8 100644 --- a/tests/autodiff/differential-method-synthesis.slang.expected.txt +++ b/tests/autodiff/differential-method-synthesis.slang.expected.txt @@ -1,6 +1,6 @@ type: float 0.000000 0.400000 -0.200000 +0.400000 0.000000 0.000000 diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index 98adc4a7c..674d5c5ca 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -6,7 +6,7 @@ RWStructuredBuffer<float> outputBuffer; typedef float Real; -typealias IDFloat = IFloat & IDifferentiable; +typealias IDFloat = __BuiltinRealType & IDifferentiable; __generic<T : IDifferentiable, let N : int> struct dvector : IDifferentiable @@ -44,13 +44,13 @@ struct myvector : IDifferentiable } - static Differential dmul(This a, Differential b) + static Differential dmul<U: __BuiltinRealType>(U a, Differential b) { Differential output; for (int i = 0; i < N; i++) { - output.values[i] = T.dmul(a.values[i], b.values[i]); + output.values[i] = T.dmul<U>(a, b.values[i]); } return output; @@ -112,7 +112,7 @@ __generic<T : IDFloat, let N : int> [ForwardDerivative(dot_jvp)] T dot(myvector<T, N> a, myvector<T, N> b) { - T curr = (T)0.0; + T curr = __realCast<T, float>(0.f); [ForceUnroll] for (int i = 0; i < N; i++) { @@ -129,7 +129,7 @@ __generic<T : IDFloat, let N : int> DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b) { T.Differential curr_d = (T.dzero()); - T curr_p = (T)0.0; + T curr_p = __realCast<T, float>(0.f); [ForceUnroll] for (int i = 0; i < N; i++) { @@ -137,8 +137,8 @@ DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b) curr_d = T.dadd( curr_d, T.dadd( - T.dmul(a.p.values[i], b.d.values[i]), - T.dmul(b.p.values[i], a.d.values[i]))); + T.dmul<T>(a.p.values[i], b.d.values[i]), + T.dmul<T>(b.p.values[i], a.d.values[i]))); } return DifferentialPair<T>(curr_p, curr_d); @@ -203,9 +203,9 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable return { myvector<Real, N>.dadd(a.val, b.val) }; } - static Differential dmul(This a, Differential b) + static Differential dmul<T: __BuiltinRealType>(T a, Differential b) { - return { myvector<Real, N>.dmul(a.val, b.val) }; + return { myvector<Real, N>.dmul<T>(a, b.val) }; } [ForwardDifferentiable] diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 2be0045d4..7e5625477 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -113,9 +113,9 @@ extension myfloat3 : IDifferentiable } [ForwardDifferentiable] - static Differential dmul(Differential a, Differential b) + static Differential dmul<T : __BuiltinRealType>(T a, Differential b) { - return a * b; + return { __realCast<Real, T>(a) * b.val }; } }; @@ -139,9 +139,9 @@ extension myfloat4 : IDifferentiable } [ForwardDifferentiable] - static Differential dmul(Differential a, Differential b) + static Differential dmul<T: __BuiltinRealType>(T a, Differential b) { - return a * b; + return { __realCast<Real, T>(a) * b.val }; } }; diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang index 9055e860a..9f03ac4eb 100644 --- a/tests/autodiff/getter-setter-multi.slang +++ b/tests/autodiff/getter-setter-multi.slang @@ -34,9 +34,9 @@ struct A : IDifferentiable } [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + static Differential dmul<T: __BuiltinRealType>(T a, Differential b) { - B o = {a.x * b.z}; + B o = {__realCast<float, T>(a) * b.z}; return o; } }; diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang index 06caadce8..bc7343f27 100644 --- a/tests/autodiff/getter-setter.slang +++ b/tests/autodiff/getter-setter.slang @@ -32,9 +32,9 @@ struct A : IDifferentiable } [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + static Differential dmul<T : __BuiltinRealType>(T a, Differential b) { - B o = {a.x * b.z}; + B o = {__realCast<float, T>(a) * b.z}; return o; } }; |
