summary | refs | log | tree | commit | diff | stats
path: root/examples/mlp-training/kernels.slang
diff options
context:
space:
mode:
Diffstat (limited to 'examples/mlp-training/kernels.slang')
-rw-r--r-- examples/mlp-training/kernels.slang 41
1 files changed, 41 insertions, 0 deletions
diff --git a/examples/mlp-training/kernels.slang b/examples/mlp-training/kernels.slang
new file mode 100644
index 000000000..5be076879
--- /dev/null
+++ b/examples/mlp-training/kernels.slang
@@ -0,0 +1,41 @@
+module kernels;
+
+import common;
+import mlp_sw;
+import network;
+import adam;
+
+// Compute kernel, 256 threads per group, one entry of `inputs` per thread.
+//
+// For each in-range thread this kernel:
+//   1. runs the backward derivative of `loss` on its sample with an output
+//      gradient of 1.0h (presumably accumulating parameter gradients through
+//      `network`; the destination is defined by `loss`/`MyNetwork` in the
+//      imported modules -- confirm there),
+//   2. evaluates `loss` again in the forward direction, and
+//   3. folds the largest loss among the active lanes of the wave into
+//      `lossBuffer` with an atomic max.
+//
+// Parameters:
+//   tid        - dispatch thread id; used as the index into `inputs`.
+//   network    - network passed to `loss` and its backward derivative.
+//   lossBuffer - atomic uint32 that receives the bit pattern of the max loss.
+//   inputs     - float2 samples; .x and .y are forwarded to `loss` as halfs.
+//   count      - number of valid entries in `inputs`.
+[numthreads(256, 1, 1)]
+[require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic)]
+void learnGradient(
+    uint32_t tid : SV_DispatchThreadID,
+    uniform MyNetwork* network,
+    uniform Atomic<uint32_t>* lossBuffer,
+    uniform float2* inputs,
+    uniform uint32_t count)
+{
+    // Threads past the end of the input array do nothing. They exit before
+    // the wave intrinsics below, so they do not contribute to the reduction.
+    if (tid >= count)
+    {
+        return;
+    }
+
+    // Narrow the sample to half precision before handing it to the network.
+    var sample = (half2)inputs[tid];
+
+    // Backward pass, seeded with an output gradient of 1.
+    bwd_diff(loss)(network, sample.x, sample.y, 1.0h);
+
+    // Forward pass, used only for loss reporting.
+    let laneLoss = (float)loss(network, sample.x, sample.y);
+    let waveMaxLoss = WaveActiveMax(laneLoss);
+
+    // A single lane per wave publishes the result to limit atomic traffic.
+    if (WaveIsFirstLane())
+    {
+        // The float is compared by its raw bits as an unsigned integer.
+        // NOTE(review): that ordering matches float ordering only for
+        // non-negative values -- assumes `loss` never returns a negative
+        // number; confirm against its definition.
+        let maxLossBits = bit_cast<uint32_t>(waveMaxLoss);
+        lossBuffer.max(maxLossBits);
+    }
+}
+
+// Compute kernel, 256 threads per group, one parameter element per thread.
+//
+// For every index below `count`, either applies one `AdamOptimizer::step` to
+// that element or, when its gradient is NaN, resets the gradient to zero and
+// leaves the parameter and optimizer state untouched.
+//
+// Parameters:
+//   tid       - dispatch thread id; used as the element index.
+//   states    - per-element optimizer state handed to `AdamOptimizer::step`.
+//   params    - per-element parameter values handed to `AdamOptimizer::step`.
+//   gradients - per-element gradients; NaN entries are overwritten with 0.
+//   count     - number of elements to process.
+[numthreads(256, 1, 1)]
+void adjustParameters(
+    uint32_t tid : SV_DispatchThreadID,
+    uniform AdamState* states,
+    uniform NFloat* params,
+    uniform NFloat* gradients,
+    uniform uint32_t count)
+{
+    // Only threads that map to a real element do any work.
+    if (tid < count)
+    {
+        if (isnan(gradients[tid]))
+        {
+            // Drop the invalid gradient and skip the optimizer update.
+            gradients[tid] = 0.0h;
+        }
+        else
+        {
+            AdamOptimizer::step(states[tid], params[tid], gradients[tid]);
+        }
+    }
+} \ No newline at end of file