From ba89fc84267bfd09f1c8abf10a5b85d09bbc79de Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 26 Jul 2023 17:15:21 -0400 Subject: Refactor `dmul(This, Differential)` to `dmul(T, Differential)` (#3029) * Refactor `dmul(This, Differential)` to `dmul(T, Differential)` - Add AST synthesis support for generic containers - Refactor relevant tests * Merge dmul synthesis with dadd and dzero, and disambiguate using an enum * Fix trailing spaces --- tests/autodiff/auto-differential-type.slang | 5 +- tests/autodiff/custom-intrinsic.slang | 56 +++++++++++----------- tests/autodiff/differential-method-synthesis.slang | 8 ++-- ...ifferential-method-synthesis.slang.expected.txt | 2 +- tests/autodiff/generic-impl-jvp.slang | 18 +++---- tests/autodiff/generic-jvp.slang | 8 ++-- tests/autodiff/getter-setter-multi.slang | 4 +- tests/autodiff/getter-setter.slang | 4 +- 8 files changed, 53 insertions(+), 52 deletions(-) (limited to 'tests') diff --git a/tests/autodiff/auto-differential-type.slang b/tests/autodiff/auto-differential-type.slang index efeebb459..a253a25bb 100644 --- a/tests/autodiff/auto-differential-type.slang +++ b/tests/autodiff/auto-differential-type.slang @@ -26,9 +26,10 @@ struct A : IDifferentiable } [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + __generic + static Differential dmul(T a, Differential b) { - Differential o = {a.x * b.x, 0.0}; + Differential o = { __realCast(a * __realCast(b.x)), 0.0}; return o; } }; diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang index 8048c60ff..dd122a674 100644 --- a/tests/autodiff/custom-intrinsic.slang +++ b/tests/autodiff/custom-intrinsic.slang @@ -6,81 +6,81 @@ RWStructuredBuffer outputBuffer; typedef DifferentialPair dpfloat; -typealias IDFloat = IFloat & IDifferentiable; +typealias IDFloat = __BuiltinFloatingPointType & IDifferentiable; namespace myintrinsiclib { __generic - __target_intrinsic(hlsl) - __target_intrinsic(glsl) + __target_intrinsic(hlsl, "exp($0)") + __target_intrinsic(glsl, "exp($0)") __target_intrinsic(cuda, "$P_exp($0)") __target_intrinsic(cpp, "$P_exp($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") - [ForwardDerivative(d_exp)] - T exp(T x); + [ForwardDerivative(d_myexp)] + T myexp(T x); __generic - DifferentialPair d_exp(DifferentialPair dpx) + DifferentialPair d_myexp(DifferentialPair dpx) { return DifferentialPair( - exp(dpx.p), - T.dmul(exp(dpx.p), dpx.d)); + myexp(dpx.p), + T.dmul(myexp(dpx.p), dpx.d)); } // Sine __generic - __target_intrinsic(hlsl) - __target_intrinsic(glsl) + __target_intrinsic(hlsl, "sin($0)") + __target_intrinsic(glsl, "sin($0)") __target_intrinsic(cuda, "$P_sin($0)") __target_intrinsic(cpp, "$P_sin($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") - [ForwardDerivative(d_sin)] - T sin(T x); + [ForwardDerivative(d_mysin)] + T mysin(T x); __generic - DifferentialPair d_sin(DifferentialPair dpx) + DifferentialPair d_mysin(DifferentialPair dpx) { return DifferentialPair( - sin(dpx.p), - T.dmul(cos(dpx.p), dpx.d)); + mysin(dpx.p), + T.dmul(mycos(dpx.p), dpx.d)); } // Cosine __generic - __target_intrinsic(hlsl) - __target_intrinsic(glsl) + __target_intrinsic(hlsl, "cos($0)") + __target_intrinsic(glsl, "cos($0)") __target_intrinsic(cuda, "$P_cos($0)") __target_intrinsic(cpp, "$P_cos($0)") __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") - [ForwardDerivative(d_cos)] - T cos(T x); + [ForwardDerivative(d_mycos)] + T mycos(T x); __generic - DifferentialPair d_cos(DifferentialPair dpx) + DifferentialPair d_mycos(DifferentialPair dpx) { return DifferentialPair( - cos(dpx.p), + mycos(dpx.p), T.dmul(-sin(dpx.p), dpx.d)); } // Sine and cosine __generic - __target_intrinsic(hlsl) + __target_intrinsic(hlsl, "sincos($0, $1, $2)") __target_intrinsic(cuda, "$P_sincos($0, $1, $2)") - [ForwardDerivative(d_sincos)] - void sincos(T x, out T s, out T c) + [ForwardDerivative(d_mysincos)] + void mysincos(T x, out T s, out T c) { s = sin(x); c = cos(x); } __generic - void d_sincos(DifferentialPair x, out DifferentialPair s, out DifferentialPair c) + void d_mysincos(DifferentialPair x, out DifferentialPair s, out DifferentialPair c) { T _s; T _c; - sincos(x.p, _s, _c); + mysincos(x.p, _s, _c); s = DifferentialPair(_s, T.dmul(_c, x.d)); c = DifferentialPair(_c, T.dmul(-_s, x.d)); @@ -90,7 +90,7 @@ namespace myintrinsiclib [ForwardDifferentiable] float f(float x) { - return myintrinsiclib.exp(x); + return myintrinsiclib.myexp(x); } [ForwardDifferentiable] @@ -98,7 +98,7 @@ float g(float x) { float s; float t; - myintrinsiclib.sincos(x, s, t); + myintrinsiclib.mysincos(x, s, t); return s + t; } diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang index 3220976e7..e9385b78c 100644 --- a/tests/autodiff/differential-method-synthesis.slang +++ b/tests/autodiff/differential-method-synthesis.slang @@ -41,8 +41,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) A a = {1.0, 2.0}; A.Differential b = {0.2}; dpA dpa = dpA(a, b); - outputBuffer[0] = __fwd_diff(f)(dpa).d.b.x; // Expect: 0 - outputBuffer[1] = A.dadd(b, b).b.x; // Expect: 0.4 - outputBuffer[2] = A.dmul(a, b).b.x; // Expect: 0.2 + outputBuffer[0] = __fwd_diff(f)(dpa).d.b.x; // Expect: 0 + outputBuffer[1] = A.dadd(b, b).b.x; // Expect: 0.4 + outputBuffer[2] = A.dmul(2.0, b).b.x; // Expect: 0.4 } -} +} diff --git a/tests/autodiff/differential-method-synthesis.slang.expected.txt b/tests/autodiff/differential-method-synthesis.slang.expected.txt index 5fbff9752..353c35ec8 100644 --- a/tests/autodiff/differential-method-synthesis.slang.expected.txt +++ b/tests/autodiff/differential-method-synthesis.slang.expected.txt @@ -1,6 +1,6 @@ type: float 0.000000 0.400000 -0.200000 +0.400000 0.000000 0.000000 diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index 98adc4a7c..674d5c5ca 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -6,7 +6,7 @@ RWStructuredBuffer outputBuffer; typedef float Real; -typealias IDFloat = IFloat & IDifferentiable; +typealias IDFloat = __BuiltinRealType & IDifferentiable; __generic struct dvector : IDifferentiable @@ -44,13 +44,13 @@ struct myvector : IDifferentiable } - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { Differential output; for (int i = 0; i < N; i++) { - output.values[i] = T.dmul(a.values[i], b.values[i]); + output.values[i] = T.dmul(a, b.values[i]); } return output; @@ -112,7 +112,7 @@ __generic [ForwardDerivative(dot_jvp)] T dot(myvector a, myvector b) { - T curr = (T)0.0; + T curr = __realCast(0.f); [ForceUnroll] for (int i = 0; i < N; i++) { @@ -129,7 +129,7 @@ __generic DifferentialPair dot_jvp(dpvector a, dpvector b) { T.Differential curr_d = (T.dzero()); - T curr_p = (T)0.0; + T curr_p = __realCast(0.f); [ForceUnroll] for (int i = 0; i < N; i++) { @@ -137,8 +137,8 @@ DifferentialPair dot_jvp(dpvector a, dpvector b) curr_d = T.dadd( curr_d, T.dadd( - T.dmul(a.p.values[i], b.d.values[i]), - T.dmul(b.p.values[i], a.d.values[i]))); + T.dmul(a.p.values[i], b.d.values[i]), + T.dmul(b.p.values[i], a.d.values[i]))); } return DifferentialPair(curr_p, curr_d); @@ -203,9 +203,9 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable return { myvector.dadd(a.val, b.val) }; } - static Differential dmul(This a, Differential b) + static Differential dmul(T a, Differential b) { - return { myvector.dmul(a.val, b.val) }; + return { myvector.dmul(a, b.val) }; } [ForwardDifferentiable] diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 2be0045d4..7e5625477 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -113,9 +113,9 @@ extension myfloat3 : IDifferentiable } [ForwardDifferentiable] - static Differential dmul(Differential a, Differential b) + static Differential dmul(T a, Differential b) { - return a * b; + return { __realCast(a) * b.val }; } }; @@ -139,9 +139,9 @@ extension myfloat4 : IDifferentiable } [ForwardDifferentiable] - static Differential dmul(Differential a, Differential b) + static Differential dmul(T a, Differential b) { - return a * b; + return { __realCast(a) * b.val }; } }; diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang index 9055e860a..9f03ac4eb 100644 --- a/tests/autodiff/getter-setter-multi.slang +++ b/tests/autodiff/getter-setter-multi.slang @@ -34,9 +34,9 @@ struct A : IDifferentiable } [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + static Differential dmul(T a, Differential b) { - B o = {a.x * b.z}; + B o = {__realCast(a) * b.z}; return o; } }; diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang index 06caadce8..bc7343f27 100644 --- a/tests/autodiff/getter-setter.slang +++ b/tests/autodiff/getter-setter.slang @@ -32,9 +32,9 @@ struct A : IDifferentiable } [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + static Differential dmul(T a, Differential b) { - B o = {a.x * b.z}; + B o = {__realCast(a) * b.z}; return o; } }; -- cgit v1.2.3