summaryrefslogtreecommitdiffstats
path: root/examples/mlp-training/network.slang
diff options
context:
space:
mode:
Diffstat (limited to 'examples/mlp-training/network.slang')
-rw-r--r--examples/mlp-training/network.slang59
1 files changed, 59 insertions, 0 deletions
diff --git a/examples/mlp-training/network.slang b/examples/mlp-training/network.slang
new file mode 100644
index 000000000..a48820f11
--- /dev/null
+++ b/examples/mlp-training/network.slang
@@ -0,0 +1,59 @@
+module network;
+
+import common;
+import mlp_sw;
+
+public struct MyNetwork
+{
+ public FeedForwardLayer<4, 16> layer1;
+ public FeedForwardLayer<16, 4> layer2;
+
+ [Differentiable]
+ internal MLVec<4> encodeInput(NFloat x, NFloat y)
+ {
+ return MLVec<4>.fromArray({
+ x,
+ y,
+ x*x,
+ y*y,
+ });
+ }
+
+ [Differentiable]
+ internal MLVec<4> _eval(NFloat x, NFloat y)
+ {
+ let encoding = encodeInput(x, y);
+ let layer1Output = layer1.eval(encoding);
+ let leyer2Output = layer2.eval(layer1Output);
+ return leyer2Output;
+ }
+
+ [Differentiable]
+ public half4 eval(no_diff NFloat x, no_diff NFloat y)
+ {
+ let mlv = _eval(x, y);
+ let arr = mlv.toArray();
+ return half4(arr[0], arr[1], arr[2], arr[3]);
+ }
+}
+
+[Differentiable]
+public half loss(MyNetwork* network, no_diff half x, no_diff half y)
+{
+ let networkResult = network.eval(x, y);
+ let gt = no_diff groundtruth(x, y);
+ let diff = networkResult - gt;
+
+ return dot(diff, diff);
+}
+
+public half4 groundtruth(half x, half y)
+{
+ return {
+ (x + y) / (1 + y * y),
+ 2 * x + y,
+ 0.5 * x * x + 1.2 * y,
+ x + 0.5 * y * y,
+ };
+}
+