//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type

// Forward-mode autodiff test over user-defined generic vector types.
//
// `computeMain` evaluates `f` and `__fwd_diff(f)` on a scalar and on 3-/4-wide
// `linearvector` values and writes four results into `outputBuffer`.
//
// NOTE(review): several declarations below look as if their angle-bracket
// generic parameter/argument lists were stripped by whatever tool flattened
// this file (a bare `__generic` with no `<...>`, `DifferentialPair> dpvector`,
// `RWStructuredBuffer outputBuffer`, `vector a`, undeclared `T`/`U`/`N`).
// They are reproduced here exactly as found, because the original constraints
// cannot be recovered from this file alone. Restore them from the upstream
// test before compiling -- TODO confirm.

//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer outputBuffer;

typedef float Real;

typealias IDFloat = __BuiltinRealType & IDifferentiable;

// Differential storage for `myvector`: an N-element array of T that is its own
// `Differential` type.
__generic
struct dvector : IDifferentiable
{
    typedef dvector Differential;

    [DerivativeMember(Differential.values)]
    T values[N];
};

// N-element array of T with hand-written differential arithmetic
// (`dadd` / `dmul` / `dzero`) implemented element by element.
__generic
struct myvector : IDifferentiable
{
    typedef dvector Differential;

    [DerivativeMember(Differential.values)]
    T values[N];

    // Broadcast constructor: every element is set to `c`.
    __init(T c)
    {
        [ForceUnroll]
        for (int i = 0; i < N; i++)
        {
            values[i] = c;
        }
    }

    // Element-wise `T.dadd` of two differentials.
    static Differential dadd(Differential a, Differential b)
    {
        Differential output;
        for (int i = 0; i < N; i++)
        {
            output.values[i] = T.dadd(a.values[i], b.values[i]);
        }
        return output;
    }

    // Applies `T.dmul(a, .)` to every element of `b`.
    // NOTE(review): `U` is not declared anywhere visible in this file;
    // presumably a stripped generic parameter on this method -- TODO confirm.
    static Differential dmul(U a, Differential b)
    {
        Differential output;
        for (int i = 0; i < N; i++)
        {
            output.values[i] = T.dmul(a, b.values[i]);
        }
        return output;
    }

    // Differential whose elements are all `T.dzero()`.
    static Differential dzero()
    {
        Differential output;
        for (int i = 0; i < N; i++)
        {
            output.values[i] = T.dzero();
        }
        return output;
    }
};

// Element-wise sum of two `myvector`s.
[ForwardDifferentiable]
__generic
myvector operator +(myvector a, myvector b)
{
    myvector output;
    [ForceUnroll]
    for (int i = 0; i < N; i++)
    {
        output.values[i] = a.values[i] + b.values[i];
    }
    return output;
}

// Element-wise product of two `myvector`s.
[ForwardDifferentiable]
__generic
myvector operator *(myvector a, myvector b)
{
    myvector output;
    [ForceUnroll]
    for (int i = 0; i < N; i++)
    {
        output.values[i] = a.values[i] * b.values[i];
    }
    return output;
}

// Scales every element of `b` by the scalar `a`.
[ForwardDifferentiable]
__generic
myvector operator *(T a, myvector b)
{
    myvector output;
    [ForceUnroll]
    for (int i = 0; i < N; i++)
    {
        output.values[i] = a * b.values[i];
    }
    return output;
}

// Dot product: sum over i of a.values[i] * b.values[i].
// Its forward derivative is supplied explicitly by `dot_jvp` rather than
// being generated from this body.
__generic
[ForwardDerivative(dot_jvp)]
T dot(myvector a, myvector b)
{
    T curr = __realCast(0.f);
    [ForceUnroll]
    for (int i = 0; i < N; i++)
    {
        curr = curr + (a.values[i] * b.values[i]);
    }
    return curr;
}

__generic
typedef DifferentialPair> dpvector;

// Hand-written forward derivative of `dot`.
// Accumulates the primal sum of products in `curr_p` and, per element, the
// product-rule terms a.p[i]*b.d[i] and b.p[i]*a.d[i] (via `T.dmul`/`T.dadd`)
// in `curr_d`, then returns both as a pair.
__generic
DifferentialPair dot_jvp(dpvector a, dpvector b)
{
    T.Differential curr_d = (T.dzero());
    T curr_p = __realCast(0.f);
    [ForceUnroll]
    for (int i = 0; i < N; i++)
    {
        curr_p = curr_p + (a.p.values[i] * b.p.values[i]);
        curr_d = T.dadd(
            curr_d,
            T.dadd(
                T.dmul(a.p.values[i], b.d.values[i]),
                T.dmul(b.p.values[i], a.d.values[i])));
    }
    return DifferentialPair(curr_p, curr_d);
}

// Differential type of `linearvector`: wraps a `myvector.Differential`.
__generic
struct lineardvector : IDifferentiable
{
    typedef lineardvector Differential;

    myvector.Differential val;

    // Copies the N components of `a` into `val.values`.
    __init(vector a)
    {
        [ForceUnroll]
        for (int i = 0; i < N; i++)
        {
            val.values[i] = a[i];
        }
    }

    // Add a new constructor for dadd() function.
    // Copies the N elements of the array `a` into `val.values`.
    __init(Real a[N])
    {
        [ForceUnroll]
        for (int i = 0; i < N; i++)
        {
            val.values[i] = a[i];
        }
    }
};

// Vector type that satisfies `MyLinearArithmeticType` by delegating to the
// `myvector` operators and `dot`.
__generic
struct linearvector : MyLinearArithmeticType, IDifferentiable
{
    typedef lineardvector Differential;

    [DerivativeMember(Differential.val)]
    myvector val;

    [ForwardDifferentiable]
    static linearvector ladd(linearvector a, linearvector b)
    {
        return linearvector(a.val + b.val);
    }

    [ForwardDifferentiable]
    static linearvector lmul(linearvector a, linearvector b)
    {
        return linearvector(a.val * b.val);
    }

    [ForwardDifferentiable]
    static linearvector lscale(float a, linearvector b)
    {
        return linearvector(a * b.val);
    }

    [ForwardDifferentiable]
    static float ldot(linearvector a, linearvector b)
    {
        return dot(a.val, b.val);
    }

    static Differential dzero()
    {
        lineardvector dout;
        dout.val = myvector.dzero();
        return dout;
    }

    static Differential dadd(Differential a, Differential b)
    {
        // return { myvector.dadd(a.val, b.val) };
        //
        // The code above will not work, because
        // myvector.dadd returns a dvector type
        // while Differential == lineardvector type,
        // and the constructor of lineardvector requires a vector type,
        // and dvector != vector, even though they have the
        // same members.
        //
        // In our new design, a generic is no longer a C-style struct.
        dvector d = myvector.dadd(a.val, b.val);
        // Presumably resolves to lineardvector's array-taking constructor --
        // TODO confirm.
        return {d.values};
    }

    // NOTE(review): `T` is not declared in the visible header of this struct;
    // presumably a stripped generic parameter on this method -- TODO confirm.
    static Differential dmul(T a, Differential b)
    {
        dvector d = myvector.dmul(a, b.val);
        return {d.values};
    }

    // Copies the N components of `a` into `val.values`.
    [ForwardDifferentiable]
    __init(vector a)
    {
        [ForceUnroll]
        for (int i = 0; i < N; i++)
        {
            val.values[i] = a[i];
        }
    }

    // Wraps an existing `myvector`.
    [ForwardDifferentiable]
    __init(myvector a)
    {
        val = a;
    }
};

typedef linearvector<3> myfloat3;
typedef linearvector<4> myfloat4;

typedef lineardvector<3> mydfloat3;
typedef lineardvector<4> mydfloat4;

typedef DifferentialPair dpfloat;

// Arithmetic interface used by the generic operators and by `f` below.
[TreatAsDifferentiable]
interface MyLinearArithmeticType
{
    static This ladd(This a, This b);
    static This lmul(This a, This b);
    static This lscale(Real a, This b);
    static Real ldot(This a, This b);
};

typedef DifferentialPair dpfloat4;
typedef DifferentialPair dpfloat3;

// Scalar conformance: every operation on `float` reduces to `+` or `*`.
extension float : MyLinearArithmeticType
{
    [ForwardDifferentiable]
    static float ladd(float a, float b)
    {
        return a + b;
    }

    [ForwardDifferentiable]
    static float lmul(float a, float b)
    {
        return a * b;
    }

    [ForwardDifferentiable]
    static float lscale(float a, float b)
    {
        return a * b;
    }

    [ForwardDifferentiable]
    static float ldot(float a, float b)
    {
        return a * b;
    }
};

typealias MyLinearArithmeticDifferentiableType = IDifferentiable & MyLinearArithmeticType;

// Generic `+` forwarding to the type's `ladd`.
__generic
[ForwardDifferentiable]
T operator +(T a, T b)
{
    return T.ladd(a, b);
}

// Generic `*` forwarding to the type's `lmul`.
__generic
[ForwardDifferentiable]
T operator *(T a, T b)
{
    return T.lmul(a, b);
}

// Function under test: returns (x + x) * (x + x) + 3 * x, expressed through
// the operators above and `G.lscale`.
__generic
[ForwardDifferentiable]
G f(G x)
{
    G a = x + x;
    // `b` is never read; kept as in the original test.
    G b = x * x;

    return a * a + G.lscale((Real)3.0, x);
}

// Entry point: writes one primal value and three forward derivatives of `f`
// into outputBuffer[0..3]. Each `Expect` value follows from
// f(x) = 4x^2 + 3x and df = (8x + 3) * dx, applied per element.
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    {
        dpfloat dpa = dpfloat(2.0, 1.0);
        dpfloat4 dpf4 = dpfloat4(myfloat4(float4(1.5, 2.0, 0.5, 1.0)), mydfloat4(float4(0.5, 0.8, 1.6, 2.5)));
        dpfloat3 dpf3 = dpfloat3(myfloat3(float3(1.0, 3.0, 5.0)), mydfloat3(float3(0.5, 1.5, 2.5)));

        outputBuffer[0] = f(dpa.p);                              // Expect: 22.0
        outputBuffer[1] = __fwd_diff(f)(dpfloat(2.0, 0.5)).d;    // Expect: 9.5
        outputBuffer[2] = __fwd_diff(f)(dpf4).d.val.values[3];   // Expect: 27.5
        outputBuffer[3] = __fwd_diff(f)(dpf3).d.val.values[1];   // Expect: 40.5
    }
}