summaryrefslogtreecommitdiffstats
path: root/examples/mlp-training/adam.slang
blob: 8f9b15f010f01c3055dbedb014508b8ba5c38954 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
module adam;

import mlp_sw;
import common;

// Per-parameter running statistics consumed and updated by AdamOptimizer.step.
// NOTE(review): the bias correction in step() divides by (1 - beta^iteration)
// after incrementing `iteration`, so this presumably has to start out
// zero-initialised — confirm at the allocation site.
public struct AdamState
{
    // Exponential moving average of the gradient (decay AdamOptimizer.beta1).
    internal NFloat mean;
    // Exponential moving average of the squared gradient (decay AdamOptimizer.beta2).
    internal NFloat variance;
    // Number of optimizer steps applied so far; incremented at the start of each step
    // and used as the exponent in the bias-correction terms.
    internal int iteration;
}

// Adam optimizer operating on one scalar parameter at a time.
// Hyper-parameters are compile-time constants; the running statistics for each
// parameter live in a caller-owned AdamState.
public struct AdamOptimizer
{
    // Adam parameters
    // Decay rate of the gradient moving average (first moment).
    public static const NFloat beta1 = 0.9h;
    // Decay rate of the squared-gradient moving average (second moment).
    public static const NFloat beta2 = 0.999h;
    // Small constant added under the square root so the denominator never reaches zero.
    public static const NFloat epsilon = 1e-7h;
    // Step size applied to the bias-corrected update.
    public static const NFloat learningRate = 0.01h;

    // Applies one Adam update to `param` using the accumulated gradient `grad`.
    //
    //   state - running mean/variance and step counter for this parameter; updated in place.
    //   param - the parameter value; decremented by the computed step.
    //   grad  - the gradient accumulator; always reset to zero on return so the
    //           caller can start accumulating the next batch.
    //
    // Non-finite gradients are sanitised: +/-infinity is clamped to +/-10000,
    // and a NaN gradient is discarded without touching `state` or `param`.
    public static void step(inout AdamState state, inout NFloat param, inout NFloat grad)
    {
        // A NaN gradient would be folded into both moving averages and into `param`,
        // and the averages could never recover from it. Drop the sample instead:
        // clear the accumulator and leave the optimizer state and step count untouched.
        if (isnan(grad))
        {
            grad = NFloat(0.f);
            return;
        }

        state.iteration++;

        // Replace an infinite gradient with a large finite value of the same sign.
        // NOTE(review): the `h` literal suffixes suggest NFloat is a 16-bit float —
        // confirm in the module that declares it. If so, (1 - beta2) * 1e4 * 1e4 still
        // exceeds the half range (65504) and the variance becomes infinite; a smaller
        // clamp may be needed.
        if (isinf(grad))
        {
            if (grad > 0)
                grad = 10000.0h;
            else
                grad = -10000.0h;
        }

        // Update the biased first and second moment estimates.
        state.mean = beta1 * state.mean + (NFloat(1.f) - beta1) * grad;
        state.variance = beta2 * state.variance + (NFloat(1.f) - beta2) * grad * grad;

        // Bias correction: compensates for the moments being pulled toward their
        // starting value during the first iterations.
        NFloat meanHat = state.mean / (NFloat(1.f) - pow(beta1, NFloat(state.iteration)));
        NFloat varianceHat = state.variance / (NFloat(1.f) - pow(beta2, NFloat(state.iteration)));

        // max(0, ...) is a defensive clamp so sqrt never receives a negative value;
        // epsilon sits inside the square root in this implementation.
        param -= learningRate * meanHat / (sqrt(max(NFloat(0.f), varianceHat) + epsilon));

        // Reset the accumulator for the next batch.
        grad = NFloat(0.f);
    }
}