summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-07-26 17:15:21 -0400
committerGitHub <noreply@github.com>2023-07-26 17:15:21 -0400
commitba89fc84267bfd09f1c8abf10a5b85d09bbc79de (patch)
tree2c79fc5dafb89a030d22fa86cd6fa3d69a89a785 /source/slang/diff.meta.slang
parentb8ade05df10a2774d3da5ef1fb2c7479ff48989a (diff)
Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)` (#3029)
* Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)` - Add AST synthesis support for generic containers - Refactor relevant tests * Merge dmul synthesis with dadd and dzero, and disambiguate using an enum * Fix trailing spaces
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang97
1 files changed, 62 insertions, 35 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index f2fd8e3b0..3e381e55d 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -374,12 +374,13 @@ extension Array<T, N> : IDifferentiable
return result;
}
+ __generic<U : __BuiltinRealType>
[__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
+ static Differential dmul(U a, Differential b)
{
Array<T.Differential, N> result;
for (int i = 0; i < N; i++)
- result[i] = T.dmul(a[i], b[i]);
+ result[i] = T.dmul<U>(a, b[i]);
return result;
}
}
@@ -543,8 +544,8 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair
for (int i = 0; i < N; ++i)
{
result = result + dpx.p[i] * dpy.p[i];
- d_result = T.dadd(d_result, T.dmul(dpx.p[i], __slang_noop_cast<T.Differential>(dpy.d[i])));
- d_result = T.dadd(d_result, T.dmul(dpy.p[i], __slang_noop_cast<T.Differential>(dpx.d[i])));
+ d_result = T.dadd(d_result, __slang_noop_cast<T.Differential>(dpx.p[i] * dpy.d[i]));
+ d_result = T.dadd(d_result, __slang_noop_cast<T.Differential>(dpy.p[i] * dpx.d[i]));
}
return DifferentialPair<T>(result, d_result);
}
@@ -797,7 +798,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
#define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \
__generic<T : __BuiltinFloatingPointType> \
- [BackwardDifferentiable][PreferRecompute] \
+ [BackwardDifferentiable] [PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<T> __d_##NAME(DifferentialPair<T> dpx) \
{ \
@@ -805,7 +806,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<T>(NAME(dpx.p), FWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
- [BackwardDifferentiable][PreferRecompute] \
+ [BackwardDifferentiable] [PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<vector<T, N>> __d_##NAME##_vector(DifferentialPair<vector<T, N>> dpx) \
{ \
@@ -813,21 +814,21 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<ReturnType>(NAME(dpx.p), FWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
- [BackwardDifferentiable][PreferRecompute] \
+ [BackwardDifferentiable] [PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpm) \
{ \
- typealias ReturnType = vector<T,N>; \
- matrix<T,M,N>.Differential diff; \
+ typealias ReturnType = vector<T, N>; \
+ matrix<T, M, N>.Differential diff; \
[ForceUnroll] for (int i = 0; i < M; i++) \
{ \
var dpx = diffPair(dpm.p[i], dpm.d[i]); \
- diff[i] = FWD_DIFF_FUNC; \
+ diff[i] = __slang_noop_cast<vector<T, N>>(FWD_DIFF_FUNC); \
} \
return diffPair(NAME(dpm.p), diff); \
} \
__generic<T : __BuiltinFloatingPointType> \
- [BackwardDifferentiable][PreferRecompute] \
+ [BackwardDifferentiable] [PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME(inout DifferentialPair<T> dpx, T.Differential dOut) \
{ \
@@ -835,31 +836,57 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
- [BackwardDifferentiable][PreferRecompute] \
+ [BackwardDifferentiable] [PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_vector( \
- inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \
+ inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \
{ \
typealias ReturnType = vector<T, N>; \
dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
- [BackwardDifferentiable][PreferRecompute] \
+ [BackwardDifferentiable] [PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_matrix( \
- inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \
+ inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \
{ \
typealias ReturnType = vector<T, N>; \
matrix<T, M, N>.Differential diff; \
[ForceUnroll] for (int i = 0; i < M; i++) \
{ \
var dpx = diffPair(m.p[i], m.d[i]); \
- var dOut = mdOut[i]; \
+ var dOut = __slang_noop_cast<vector<T, N>>(mdOut[i]); \
diff[i] = BWD_DIFF_FUNC; \
} \
m = diffPair(m.p, diff); \
}
-#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, ReturnType.dmul(DIFF_FUNC, dpx.d), ReturnType.dmul(DIFF_FUNC, dOut))
+#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, __mul_p_d(DIFF_FUNC, dpx.d), __mul_p_d(DIFF_FUNC, dOut))
+
+/// Element-wise multiply for scalars and vectors for (T, T.Differential)
+__generic<T : __BuiltinFloatingPointType>
+[__unsafeForceInlineEarly]
+[Differentiable]
+T.Differential __mul_p_d(T a, T.Differential b)
+{
+ return __slang_noop_cast<T.Differential>(a * __slang_noop_cast<T>(b));
+}
+
+__generic<T : __BuiltinFloatingPointType>
+[__unsafeForceInlineEarly]
+[Differentiable]
+T __mul_p_d(T a, T b)
+{
+ return (a * b);
+}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[__unsafeForceInlineEarly]
+[Differentiable]
+vector<T, N> __mul_p_d(vector<T, N> a, vector<T, N> b)
+{
+ return a * b;
+}
+
/// Detach and set derivatives to zero.
__generic<T : IDifferentiable>
@@ -871,14 +898,14 @@ T detach(T x);
#define SLANG_SIGN(x) select(((x)>T(0.0)), ReturnType(T(1.0)), select(((x)==T(0.0)), ReturnType(T(0.0)), ReturnType(T(-1.0))))
// Absolute value
-UNARY_DERIVATIVE_IMPL(abs, ReturnType.dmul(SLANG_SIGN(dpx.p), dpx.d), ReturnType.dmul(SLANG_SIGN(dpx.p), dOut))
+UNARY_DERIVATIVE_IMPL(abs, (__mul_p_d(SLANG_SIGN(dpx.p), (dpx.d))), (__mul_p_d(SLANG_SIGN(dpx.p), (dOut))))
// Saturate
UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut))
// frac
UNARY_DERIVATIVE_IMPL(frac, dpx.d, dOut)
// raidans, degrees
-SIMPLE_UNARY_DERIVATIVE_IMPL(radians, T(0.01745329251994329576923690768489))
-SIMPLE_UNARY_DERIVATIVE_IMPL(degrees, T(57.295779513082320876798154814105))
+SIMPLE_UNARY_DERIVATIVE_IMPL(radians, ReturnType(T(0.01745329251994329576923690768489)))
+SIMPLE_UNARY_DERIVATIVE_IMPL(degrees, ReturnType(T(57.295779513082320876798154814105)))
// Exponent
SIMPLE_UNARY_DERIVATIVE_IMPL(exp, exp(dpx.p))
SIMPLE_UNARY_DERIVATIVE_IMPL(exp2, exp2(dpx.p)* T(50.69314718055994530941723212145818))
@@ -915,8 +942,8 @@ __generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(atan2)]
DifferentialPair<T> __d_atan2(DifferentialPair<T> dpy, DifferentialPair<T> dpx)
{
- T.Differential dx = T.dmul(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d);
- T.Differential dy = T.dmul(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d);
+ T.Differential dx = __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d);
+ T.Differential dy = __mul_p_d(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d);
return DifferentialPair<T>(
atan2(dpy.p, dpx.p),
T.dadd(dx, dy));
@@ -928,8 +955,8 @@ __generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(atan2)]
void __d_atan2(inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpx, T.Differential dOut)
{
- dpx = diffPair(dpx.p, T.dmul(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d));
- dpy = diffPair(dpy.p, T.dmul(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d));
+ dpx = diffPair(dpx.p, __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d));
+ dpy = diffPair(dpy.p, __mul_p_d(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d));
}
VECTOR_MATRIX_BINARY_DIFF_IMPL(atan2)
@@ -968,8 +995,8 @@ DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
}
T val = pow(dpx.p, dpy.p);
- T.Differential d1 = T.dmul(val * log(dpx.p), dpy.d);
- T.Differential d2 = T.dmul(val * dpy.p / dpx.p, dpx.d);
+ T.Differential d1 = __mul_p_d((val * log(dpx.p)), dpy.d);
+ T.Differential d2 = __mul_p_d((val * dpy.p / dpx.p), dpx.d);
return DifferentialPair<T>(
val,
T.dadd(d1, d2)
@@ -993,10 +1020,10 @@ void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Dif
T val = pow(dpx.p, dpy.p);
dpx = diffPair(
dpx.p,
- T.dmul(val * dpy.p / dpx.p, dOut));
+ (__mul_p_d((val * dpy.p / dpx.p), dOut)));
dpy = diffPair(
dpy.p,
- T.dmul(val * log(dpx.p), dOut));
+ (__mul_p_d((val * log(dpx.p)), dOut)));
}
}
@@ -1061,7 +1088,7 @@ DifferentialPair<T> __d_lerp(DifferentialPair<T> dpx, DifferentialPair<T> dpy, D
{
return DifferentialPair<T>(
lerp(dpx.p, dpy.p, dps.p),
- T.dadd(T.dadd(T.dmul((T(1.0) - dps.p), dpx.d), T.dmul(dps.p, dpy.d)), T.dmul(dpy.p - dpx.p, dps.d))
+ T.dadd(T.dadd(__mul_p_d((T(1.0) - dps.p), dpx.d), __mul_p_d(dps.p, dpy.d)), __mul_p_d((dpy.p - dpx.p), dps.d))
);
}
__generic<T : __BuiltinFloatingPointType>
@@ -1070,9 +1097,9 @@ __generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(lerp)]
void __d_lerp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dps, T.Differential dOut)
{
- dpx = diffPair(dpx.p, T.dmul(T(1.0) - dps.p, dOut));
- dpy = diffPair(dpy.p, T.dmul(dps.p, dOut));
- dps = diffPair(dpy.p, T.dmul((dpy.p - dpx.p), dOut));
+ dpx = diffPair(dpx.p, __mul_p_d((T(1.0) - dps.p), dOut));
+ dpy = diffPair(dpy.p, __mul_p_d(dps.p, dOut));
+ dps = diffPair(dpy.p, __mul_p_d((dpy.p - dpx.p), dOut));
}
VECTOR_MATRIX_TERNARY_DIFF_IMPL(lerp)
@@ -1175,7 +1202,7 @@ DifferentialPair<T> __d_mad(DifferentialPair<T> dpx, DifferentialPair<T> dpy, Di
{
return DifferentialPair<T>(
mad(dpx.p, dpy.p, dpz.p),
- T.dadd(T.dadd(T.dmul(dpy.p, dpx.d), T.dmul(dpx.p, dpy.d)), dpz.d));
+ T.dadd(T.dadd(__mul_p_d(dpy.p, dpx.d), __mul_p_d(dpx.p, dpy.d)), dpz.d));
}
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
@@ -1183,8 +1210,8 @@ __generic<T : __BuiltinFloatingPointType>
[PreferRecompute]
void __d_mad(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpz, T.Differential dOut)
{
- dpx = diffPair(dpx.p, T.dmul(dpy.p, dOut));
- dpy = diffPair(dpy.p, T.dmul(dpx.p, dOut));
+ dpx = diffPair(dpx.p, __mul_p_d(dpy.p, dOut));
+ dpy = diffPair(dpy.p, __mul_p_d(dpx.p, dOut));
dpz = diffPair(dpz.p, dOut);
}
VECTOR_MATRIX_TERNARY_DIFF_IMPL(mad)