diff options
| author | Yong He <yonghe@outlook.com> | 2025-06-30 14:32:50 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-06-30 21:32:50 +0000 |
| commit | f28f67d988158d6c46f7ffe967152f98d32a37b2 (patch) | |
| tree | 2aa620986a87ec69cf1f210c714312e42b62ac9e /examples/mlp-training-coopvec/mlvec.slang | |
| parent | a55ff722cae338a8fcf5402858c47cf0650a8e5e (diff) | |
Add MLP training examples. (#7550)
* Add MLP training examples.
* Formatting fix.
* Fix.
* Improve documentation on coopvector.
* Improve doc.
* Update doc.
* Fix typo.
* Cleanup shader.
* Cleanup.
* Fix test.
* Fix type check recursion.
* Fix.
* Fix.
* Fix override check.
Diffstat (limited to 'examples/mlp-training-coopvec/mlvec.slang')
| -rw-r--r-- | examples/mlp-training-coopvec/mlvec.slang | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/examples/mlp-training-coopvec/mlvec.slang b/examples/mlp-training-coopvec/mlvec.slang new file mode 100644 index 000000000..ce7ce8352 --- /dev/null +++ b/examples/mlp-training-coopvec/mlvec.slang @@ -0,0 +1,63 @@ +implementing mlp; + +// A wrapper of CoopVec<T> to allow it being used in differentiable context. +// +public struct MLVec<int N> : IDifferentiable +{ + public CoopVec<NFloat, N> data; + public typealias Differential = MLVec<N>; + + public static MLVec<N> fromArray(NFloat[N] values) + { + MLVec<N> result; + [ForceUnroll] + for (int i = 0; i < N; i++) + result.data[i] = values[i]; + return result; + } + + internal static NFloat[N] coopVecToArray(CoopVec<NFloat, N> v) + { + NFloat[N] arr; + [ForceUnroll] + for (int i = 0; i < N; i++) + arr[i] = v[i]; + return arr; + } + + [BackwardDerivativeOf(fromArray)] + internal static void fromArrayBwd(inout DifferentialPair<NFloat[N]> values, MLVec<N> dResult) + { + values = diffPair(values.p, coopVecToArray(dResult.data)); + } + + internal static NFloat[N] toArray(MLVec<N> vec) + { + return coopVecToArray(vec.data); + } + + [BackwardDerivativeOf(toArray)] + internal static void toArrayBwd(inout DifferentialPair<MLVec<N>> vec, NFloat[N] dResult) + { + vec = diffPair(vec.p, MLVec<N>.fromArray(dResult)); + } + + [Differentiable] + public NFloat[N] toArray() + { + return toArray(this); + } + + public override static Differential dadd(Differential d0, Differential d1) + { + return {d0.data + d1.data}; + } + public override static Differential dmul<U:__BuiltinRealType>(U s, Differential d) + { + return {d.data * __realCast<NFloat>(s)}; + } + public override static Differential dzero() + { + return {}; + } +} |
