diff options
Diffstat (limited to 'examples/mlp-training/kernels.slang')
| -rw-r--r-- | examples/mlp-training/kernels.slang | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/examples/mlp-training/kernels.slang b/examples/mlp-training/kernels.slang new file mode 100644 index 000000000..5be076879 --- /dev/null +++ b/examples/mlp-training/kernels.slang @@ -0,0 +1,41 @@ +module kernels; + +import common; +import mlp_sw; +import network; +import adam; + +[numthreads(256, 1, 1)] +[require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic)] +void learnGradient( + uint32_t tid : SV_DispatchThreadID, + uniform MyNetwork* network, + uniform Atomic<uint32_t>* lossBuffer, + uniform float2* inputs, + uniform uint32_t count) +{ + if (tid >= count) + return; + + var input = (half2)inputs[tid]; + bwd_diff(loss)(network, input.x, input.y, 1.0h); + let thisLoss = (float)loss(network, input.x, input.y); + let maxLoss = WaveActiveMax(thisLoss); + if (WaveIsFirstLane()) + { + lossBuffer.max(bit_cast<uint32_t>(maxLoss)); + } +} + +[numthreads(256, 1, 1)] +void adjustParameters(uint32_t tid : SV_DispatchThreadID, uniform AdamState* states, uniform NFloat* params, uniform NFloat* gradients, uniform uint32_t count) +{ + if (tid >= count) + return; + if (isnan(gradients[tid])) + { + gradients[tid] = 0.0h; + return; + } + AdamOptimizer::step(states[tid], params[tid], gradients[tid]); +}
\ No newline at end of file |
