summaryrefslogtreecommitdiff
path: root/examples/mlp-training-coopvec/mlvec.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-06-30 14:32:50 -0700
committerGitHub <noreply@github.com>2025-06-30 21:32:50 +0000
commitf28f67d988158d6c46f7ffe967152f98d32a37b2 (patch)
tree2aa620986a87ec69cf1f210c714312e42b62ac9e /examples/mlp-training-coopvec/mlvec.slang
parenta55ff722cae338a8fcf5402858c47cf0650a8e5e (diff)
Add MLP training examples. (#7550)
* Add MLP training examples. * Formatting fix. * Fix. * Improve documentation on coopvector. * Improve doc. * Update doc. * Fix typo. * Cleanup shader. * Cleanup. * Fix test. * Fix type check recursion. * Fix. * Fix. * Fix override check.
Diffstat (limited to 'examples/mlp-training-coopvec/mlvec.slang')
-rw-r--r--examples/mlp-training-coopvec/mlvec.slang63
1 files changed, 63 insertions, 0 deletions
diff --git a/examples/mlp-training-coopvec/mlvec.slang b/examples/mlp-training-coopvec/mlvec.slang
new file mode 100644
index 000000000..ce7ce8352
--- /dev/null
+++ b/examples/mlp-training-coopvec/mlvec.slang
@@ -0,0 +1,63 @@
+implementing mlp;
+
+// A wrapper of CoopVec<T> to allow it being used in differentiable context.
+//
+public struct MLVec<int N> : IDifferentiable
+{
+ public CoopVec<NFloat, N> data;
+ public typealias Differential = MLVec<N>;
+
+ public static MLVec<N> fromArray(NFloat[N] values)
+ {
+ MLVec<N> result;
+ [ForceUnroll]
+ for (int i = 0; i < N; i++)
+ result.data[i] = values[i];
+ return result;
+ }
+
+ internal static NFloat[N] coopVecToArray(CoopVec<NFloat, N> v)
+ {
+ NFloat[N] arr;
+ [ForceUnroll]
+ for (int i = 0; i < N; i++)
+ arr[i] = v[i];
+ return arr;
+ }
+
+ [BackwardDerivativeOf(fromArray)]
+ internal static void fromArrayBwd(inout DifferentialPair<NFloat[N]> values, MLVec<N> dResult)
+ {
+ values = diffPair(values.p, coopVecToArray(dResult.data));
+ }
+
+ internal static NFloat[N] toArray(MLVec<N> vec)
+ {
+ return coopVecToArray(vec.data);
+ }
+
+ [BackwardDerivativeOf(toArray)]
+ internal static void toArrayBwd(inout DifferentialPair<MLVec<N>> vec, NFloat[N] dResult)
+ {
+ vec = diffPair(vec.p, MLVec<N>.fromArray(dResult));
+ }
+
+ [Differentiable]
+ public NFloat[N] toArray()
+ {
+ return toArray(this);
+ }
+
+ public override static Differential dadd(Differential d0, Differential d1)
+ {
+ return {d0.data + d1.data};
+ }
+ public override static Differential dmul<U:__BuiltinRealType>(U s, Differential d)
+ {
+ return {d.data * __realCast<NFloat>(s)};
+ }
+ public override static Differential dzero()
+ {
+ return {};
+ }
+}