summaryrefslogtreecommitdiffstats
path: root/tests/autodiff
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-07-26 17:15:21 -0400
committerGitHub <noreply@github.com>2023-07-26 17:15:21 -0400
commitba89fc84267bfd09f1c8abf10a5b85d09bbc79de (patch)
tree2c79fc5dafb89a030d22fa86cd6fa3d69a89a785 /tests/autodiff
parentb8ade05df10a2774d3da5ef1fb2c7479ff48989a (diff)
Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)` (#3029)
* Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)` - Add AST synthesis support for generic containers - Refactor relevant tests * Merge dmul synthesis with dadd and dzero, and disambiguate using an enum * Fix trailing spaces
Diffstat (limited to 'tests/autodiff')
-rw-r--r--tests/autodiff/auto-differential-type.slang5
-rw-r--r--tests/autodiff/custom-intrinsic.slang56
-rw-r--r--tests/autodiff/differential-method-synthesis.slang8
-rw-r--r--tests/autodiff/differential-method-synthesis.slang.expected.txt2
-rw-r--r--tests/autodiff/generic-impl-jvp.slang18
-rw-r--r--tests/autodiff/generic-jvp.slang8
-rw-r--r--tests/autodiff/getter-setter-multi.slang4
-rw-r--r--tests/autodiff/getter-setter.slang4
8 files changed, 53 insertions, 52 deletions
diff --git a/tests/autodiff/auto-differential-type.slang b/tests/autodiff/auto-differential-type.slang
index efeebb459..a253a25bb 100644
--- a/tests/autodiff/auto-differential-type.slang
+++ b/tests/autodiff/auto-differential-type.slang
@@ -26,9 +26,10 @@ struct A : IDifferentiable
}
[__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
+ __generic<T : __BuiltinRealType>
+ static Differential dmul(T a, Differential b)
{
- Differential o = {a.x * b.x, 0.0};
+ Differential o = { __realCast<float, T>(a * __realCast<T, float>(b.x)), 0.0};
return o;
}
};
diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang
index 8048c60ff..dd122a674 100644
--- a/tests/autodiff/custom-intrinsic.slang
+++ b/tests/autodiff/custom-intrinsic.slang
@@ -6,81 +6,81 @@ RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
-typealias IDFloat = IFloat & IDifferentiable;
+typealias IDFloat = __BuiltinFloatingPointType & IDifferentiable;
namespace myintrinsiclib
{
__generic<T : IDFloat>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
+ __target_intrinsic(hlsl, "exp($0)")
+ __target_intrinsic(glsl, "exp($0)")
__target_intrinsic(cuda, "$P_exp($0)")
__target_intrinsic(cpp, "$P_exp($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0")
- [ForwardDerivative(d_exp<T>)]
- T exp(T x);
+ [ForwardDerivative(d_myexp<T>)]
+ T myexp(T x);
__generic<T : IDFloat>
- DifferentialPair<T> d_exp(DifferentialPair<T> dpx)
+ DifferentialPair<T> d_myexp(DifferentialPair<T> dpx)
{
return DifferentialPair<T>(
- exp(dpx.p),
- T.dmul(exp(dpx.p), dpx.d));
+ myexp(dpx.p),
+ T.dmul(myexp(dpx.p), dpx.d));
}
// Sine
__generic<T : IDFloat>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
+ __target_intrinsic(hlsl, "sin($0)")
+ __target_intrinsic(glsl, "sin($0)")
__target_intrinsic(cuda, "$P_sin($0)")
__target_intrinsic(cpp, "$P_sin($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0")
- [ForwardDerivative(d_sin<T>)]
- T sin(T x);
+ [ForwardDerivative(d_mysin<T>)]
+ T mysin(T x);
__generic<T : IDFloat>
- DifferentialPair<T> d_sin(DifferentialPair<T> dpx)
+ DifferentialPair<T> d_mysin(DifferentialPair<T> dpx)
{
return DifferentialPair<T>(
- sin(dpx.p),
- T.dmul(cos(dpx.p), dpx.d));
+ mysin(dpx.p),
+ T.dmul(mycos(dpx.p), dpx.d));
}
// Cosine
__generic<T : IDFloat>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
+ __target_intrinsic(hlsl, "cos($0)")
+ __target_intrinsic(glsl, "cos($0)")
__target_intrinsic(cuda, "$P_cos($0)")
__target_intrinsic(cpp, "$P_cos($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0")
- [ForwardDerivative(d_cos<T>)]
- T cos(T x);
+ [ForwardDerivative(d_mycos<T>)]
+ T mycos(T x);
__generic<T : IDFloat>
- DifferentialPair<T> d_cos(DifferentialPair<T> dpx)
+ DifferentialPair<T> d_mycos(DifferentialPair<T> dpx)
{
return DifferentialPair<T>(
- cos(dpx.p),
+ mycos(dpx.p),
T.dmul(-sin(dpx.p), dpx.d));
}
// Sine and cosine
__generic<T : IDFloat>
- __target_intrinsic(hlsl)
+ __target_intrinsic(hlsl, "sincos($0, $1, $2)")
__target_intrinsic(cuda, "$P_sincos($0, $1, $2)")
- [ForwardDerivative(d_sincos<T>)]
- void sincos(T x, out T s, out T c)
+ [ForwardDerivative(d_mysincos<T>)]
+ void mysincos(T x, out T s, out T c)
{
s = sin(x);
c = cos(x);
}
__generic<T : IDFloat>
- void d_sincos(DifferentialPair<T> x, out DifferentialPair<T> s, out DifferentialPair<T> c)
+ void d_mysincos(DifferentialPair<T> x, out DifferentialPair<T> s, out DifferentialPair<T> c)
{
T _s;
T _c;
- sincos(x.p, _s, _c);
+ mysincos(x.p, _s, _c);
s = DifferentialPair<T>(_s, T.dmul(_c, x.d));
c = DifferentialPair<T>(_c, T.dmul(-_s, x.d));
@@ -90,7 +90,7 @@ namespace myintrinsiclib
[ForwardDifferentiable]
float f(float x)
{
- return myintrinsiclib.exp(x);
+ return myintrinsiclib.myexp(x);
}
[ForwardDifferentiable]
@@ -98,7 +98,7 @@ float g(float x)
{
float s;
float t;
- myintrinsiclib.sincos(x, s, t);
+ myintrinsiclib.mysincos(x, s, t);
return s + t;
}
diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang
index 3220976e7..e9385b78c 100644
--- a/tests/autodiff/differential-method-synthesis.slang
+++ b/tests/autodiff/differential-method-synthesis.slang
@@ -41,8 +41,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
A a = {1.0, 2.0};
A.Differential b = {0.2};
dpA dpa = dpA(a, b);
- outputBuffer[0] = __fwd_diff(f)(dpa).d.b.x; // Expect: 0
- outputBuffer[1] = A.dadd(b, b).b.x; // Expect: 0.4
- outputBuffer[2] = A.dmul(a, b).b.x; // Expect: 0.2
+ outputBuffer[0] = __fwd_diff(f)(dpa).d.b.x; // Expect: 0
+ outputBuffer[1] = A.dadd(b, b).b.x; // Expect: 0.4
+ outputBuffer[2] = A.dmul<float>(2.0, b).b.x; // Expect: 0.4
}
-}
+}
diff --git a/tests/autodiff/differential-method-synthesis.slang.expected.txt b/tests/autodiff/differential-method-synthesis.slang.expected.txt
index 5fbff9752..353c35ec8 100644
--- a/tests/autodiff/differential-method-synthesis.slang.expected.txt
+++ b/tests/autodiff/differential-method-synthesis.slang.expected.txt
@@ -1,6 +1,6 @@
type: float
0.000000
0.400000
-0.200000
+0.400000
0.000000
0.000000
diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang
index 98adc4a7c..674d5c5ca 100644
--- a/tests/autodiff/generic-impl-jvp.slang
+++ b/tests/autodiff/generic-impl-jvp.slang
@@ -6,7 +6,7 @@ RWStructuredBuffer<float> outputBuffer;
typedef float Real;
-typealias IDFloat = IFloat & IDifferentiable;
+typealias IDFloat = __BuiltinRealType & IDifferentiable;
__generic<T : IDifferentiable, let N : int>
struct dvector : IDifferentiable
@@ -44,13 +44,13 @@ struct myvector : IDifferentiable
}
- static Differential dmul(This a, Differential b)
+ static Differential dmul<U: __BuiltinRealType>(U a, Differential b)
{
Differential output;
for (int i = 0; i < N; i++)
{
- output.values[i] = T.dmul(a.values[i], b.values[i]);
+ output.values[i] = T.dmul<U>(a, b.values[i]);
}
return output;
@@ -112,7 +112,7 @@ __generic<T : IDFloat, let N : int>
[ForwardDerivative(dot_jvp)]
T dot(myvector<T, N> a, myvector<T, N> b)
{
- T curr = (T)0.0;
+ T curr = __realCast<T, float>(0.f);
[ForceUnroll]
for (int i = 0; i < N; i++)
{
@@ -129,7 +129,7 @@ __generic<T : IDFloat, let N : int>
DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b)
{
T.Differential curr_d = (T.dzero());
- T curr_p = (T)0.0;
+ T curr_p = __realCast<T, float>(0.f);
[ForceUnroll]
for (int i = 0; i < N; i++)
{
@@ -137,8 +137,8 @@ DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b)
curr_d = T.dadd(
curr_d,
T.dadd(
- T.dmul(a.p.values[i], b.d.values[i]),
- T.dmul(b.p.values[i], a.d.values[i])));
+ T.dmul<T>(a.p.values[i], b.d.values[i]),
+ T.dmul<T>(b.p.values[i], a.d.values[i])));
}
return DifferentialPair<T>(curr_p, curr_d);
@@ -203,9 +203,9 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable
return { myvector<Real, N>.dadd(a.val, b.val) };
}
- static Differential dmul(This a, Differential b)
+ static Differential dmul<T: __BuiltinRealType>(T a, Differential b)
{
- return { myvector<Real, N>.dmul(a.val, b.val) };
+ return { myvector<Real, N>.dmul<T>(a, b.val) };
}
[ForwardDifferentiable]
diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang
index 2be0045d4..7e5625477 100644
--- a/tests/autodiff/generic-jvp.slang
+++ b/tests/autodiff/generic-jvp.slang
@@ -113,9 +113,9 @@ extension myfloat3 : IDifferentiable
}
[ForwardDifferentiable]
- static Differential dmul(Differential a, Differential b)
+ static Differential dmul<T : __BuiltinRealType>(T a, Differential b)
{
- return a * b;
+ return { __realCast<Real, T>(a) * b.val };
}
};
@@ -139,9 +139,9 @@ extension myfloat4 : IDifferentiable
}
[ForwardDifferentiable]
- static Differential dmul(Differential a, Differential b)
+ static Differential dmul<T: __BuiltinRealType>(T a, Differential b)
{
- return a * b;
+ return { __realCast<Real, T>(a) * b.val };
}
};
diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang
index 9055e860a..9f03ac4eb 100644
--- a/tests/autodiff/getter-setter-multi.slang
+++ b/tests/autodiff/getter-setter-multi.slang
@@ -34,9 +34,9 @@ struct A : IDifferentiable
}
[__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
+ static Differential dmul<T: __BuiltinRealType>(T a, Differential b)
{
- B o = {a.x * b.z};
+ B o = {__realCast<float, T>(a) * b.z};
return o;
}
};
diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang
index 06caadce8..bc7343f27 100644
--- a/tests/autodiff/getter-setter.slang
+++ b/tests/autodiff/getter-setter.slang
@@ -32,9 +32,9 @@ struct A : IDifferentiable
}
[__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
+ static Differential dmul<T : __BuiltinRealType>(T a, Differential b)
{
- B o = {a.x * b.z};
+ B o = {__realCast<float, T>(a) * b.z};
return o;
}
};