blob: ce7ce835286ea86f2fd4b68515b9dbc15c3775ec (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
|
implementing mlp;
// A wrapper around CoopVec<NFloat, N> that allows it to be used in a
// differentiable context.
//
public struct MLVec<int N> : IDifferentiable
{
    // The wrapped cooperative vector holding the N components.
    public CoopVec<NFloat, N> data;

    // An MLVec is its own differential type (same width N).
    public typealias Differential = MLVec<N>;

    // Builds an MLVec by copying each of the N array elements into `data`.
    public static MLVec<N> fromArray(NFloat[N] values)
    {
        MLVec<N> vec;
        [ForceUnroll]
        for (int idx = 0; idx < N; idx++)
        {
            vec.data[idx] = values[idx];
        }
        return vec;
    }

    // Copies the N components of a cooperative vector out into a plain array.
    internal static NFloat[N] coopVecToArray(CoopVec<NFloat, N> v)
    {
        NFloat[N] elems;
        [ForceUnroll]
        for (int idx = 0; idx < N; idx++)
        {
            elems[idx] = v[idx];
        }
        return elems;
    }

    // Backward derivative of fromArray: fromArray copies element-wise, so the
    // gradient for each input element is the matching component of dResult.
    [BackwardDerivativeOf(fromArray)]
    internal static void fromArrayBwd(inout DifferentialPair<NFloat[N]> values, MLVec<N> dResult)
    {
        let dValues = coopVecToArray(dResult.data);
        values = diffPair(values.p, dValues);
    }

    // Static array conversion; the instance toArray() forwards here, and
    // toArrayBwd supplies its backward derivative.
    internal static NFloat[N] toArray(MLVec<N> vec)
    {
        return coopVecToArray(vec.data);
    }

    // Backward derivative of the static toArray: the incoming array gradient
    // is packed back into an MLVec, element for element.
    [BackwardDerivativeOf(toArray)]
    internal static void toArrayBwd(inout DifferentialPair<MLVec<N>> vec, NFloat[N] dResult)
    {
        let dVec = MLVec<N>.fromArray(dResult);
        vec = diffPair(vec.p, dVec);
    }

    // Differentiable instance accessor returning the components as an array.
    [Differentiable]
    public NFloat[N] toArray()
    {
        return toArray(this);
    }

    // IDifferentiable: sum of two differentials (component-wise on `data`).
    public override static Differential dadd(Differential d0, Differential d1)
    {
        let sum = d0.data + d1.data;
        return { sum };
    }

    // IDifferentiable: scale a differential by a real scalar, converted to NFloat.
    public override static Differential dmul<U:__BuiltinRealType>(U s, Differential d)
    {
        let scale = __realCast<NFloat>(s);
        return { d.data * scale };
    }

    // IDifferentiable: the zero differential (default-initialized via `{}`).
    public override static Differential dzero()
    {
        return {};
    }
}
|