summaryrefslogtreecommitdiff
path: root/examples/mlp-training-coopvec/kernels.slang
diff options
context:
space:
mode:
author Yong He <yonghe@outlook.com> 2025-06-30 14:32:50 -0700
committer GitHub <noreply@github.com> 2025-06-30 21:32:50 +0000
commit f28f67d988158d6c46f7ffe967152f98d32a37b2 (patch)
tree 2aa620986a87ec69cf1f210c714312e42b62ac9e /examples/mlp-training-coopvec/kernels.slang
parent a55ff722cae338a8fcf5402858c47cf0650a8e5e (diff)
Add MLP training examples. (#7550)
* Add MLP training examples. * Formatting fix. * Fix. * Improve documentation on coopvector. * Improve doc. * Update doc. * Fix typo. * Cleanup shader. * Cleanup. * Fix test. * Fix type check recursion. * Fix. * Fix. * Fix override check.
Diffstat (limited to 'examples/mlp-training-coopvec/kernels.slang')
-rw-r--r-- examples/mlp-training-coopvec/kernels.slang 41
1 files changed, 41 insertions, 0 deletions
diff --git a/examples/mlp-training-coopvec/kernels.slang b/examples/mlp-training-coopvec/kernels.slang
new file mode 100644
index 000000000..712494b1f
--- /dev/null
+++ b/examples/mlp-training-coopvec/kernels.slang
@@ -0,0 +1,41 @@
+module kernels;
+
+import common;
+import mlp;
+import network;
+import adam;
+
+/// Compute kernel: one thread per training sample. Runs the backward pass of
+/// `loss` for the sample, then publishes the largest loss seen by the wave.
+///
+/// - tid:        dispatch-wide thread index; also the index into `inputs`.
+/// - network:    network handed to both `loss` and `bwd_diff(loss)`.
+/// - lossBuffer: atomic 32-bit cell updated with the wave maximum via `max`.
+/// - inputs:     per-sample float2 values, converted to half2 before use.
+/// - count:      number of valid samples; threads with tid >= count exit early.
+[numthreads(256, 1, 1)]
+[require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic, spvCooperativeVectorNV)]
+void learnGradient(
+    uint32_t tid : SV_DispatchThreadID,
+    uniform MyNetwork* network,
+    uniform Atomic<uint32_t>* lossBuffer,
+    uniform float2* inputs,
+    uniform uint32_t count)
+{
+    // Threads past the end of the sample array do no work.
+    if (tid >= count)
+        return;
+
+    var sample = (half2)inputs[tid];
+
+    // Backward pass seeded with 1.0h. Presumably this accumulates parameter
+    // gradients through `network` -- confirm against the `network` module.
+    bwd_diff(loss)(network, sample.x, sample.y, 1.0h);
+
+    // Separate forward evaluation; its result is only used for reporting.
+    let laneLoss = (float)loss(network, sample.x, sample.y);
+    let waveMaxLoss = WaveActiveMax(laneLoss);
+
+    // Only the first active lane issues the atomic. The float is compared by
+    // its raw bit pattern, which orders correctly only for non-negative
+    // values -- assumes the loss is never negative, TODO confirm.
+    if (WaveIsFirstLane())
+        lossBuffer.max(bit_cast<uint32_t>(waveMaxLoss));
+}
+
+/// Compute kernel: one thread per parameter element. Applies
+/// `AdamOptimizer::step` to the element at `tid`, or clears the gradient
+/// instead when that gradient is NaN.
+///
+/// - tid:       dispatch-wide thread index; indexes all three arrays.
+/// - states:    per-element optimizer state passed to `AdamOptimizer::step`.
+/// - params:    per-element parameter values passed to `AdamOptimizer::step`.
+/// - gradients: per-element gradients; a NaN entry is overwritten with 0.
+/// - count:     number of valid elements; threads with tid >= count do nothing.
+[numthreads(256, 1, 1)]
+void adjustParameters(
+    uint32_t tid : SV_DispatchThreadID,
+    uniform AdamState* states,
+    uniform NFloat* params,
+    uniform NFloat* gradients,
+    uniform uint32_t count)
+{
+    if (tid < count)
+    {
+        if (isnan(gradients[tid]))
+        {
+            // A NaN gradient is dropped: reset it and skip the optimizer
+            // step, so this element's state and parameter are not touched here.
+            gradients[tid] = 0.0h;
+        }
+        else
+        {
+            AdamOptimizer::step(states[tid], params[tid], gradients[tid]);
+        }
+    }
+} \ No newline at end of file