diff options
| author | Yong He <yonghe@outlook.com> | 2025-06-30 14:32:50 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-06-30 21:32:50 +0000 |
| commit | f28f67d988158d6c46f7ffe967152f98d32a37b2 (patch) | |
| tree | 2aa620986a87ec69cf1f210c714312e42b62ac9e /examples/mlp-training-coopvec/network.slang | |
| parent | a55ff722cae338a8fcf5402858c47cf0650a8e5e (diff) | |
Add MLP training examples. (#7550)
* Add MLP training examples.
* Formatting fix.
* Fix.
* Improve documentation on coopvector.
* Improve doc.
* Update doc.
* Fix typo.
* Cleanup shader.
* Cleanup.
* Fix test.
* Fix type check recursion.
* Fix.
* Fix.
* Fix override check.
Diffstat (limited to 'examples/mlp-training-coopvec/network.slang')
| -rw-r--r-- | examples/mlp-training-coopvec/network.slang | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/examples/mlp-training-coopvec/network.slang b/examples/mlp-training-coopvec/network.slang new file mode 100644 index 000000000..5741487c4 --- /dev/null +++ b/examples/mlp-training-coopvec/network.slang @@ -0,0 +1,58 @@ +module network; + +import common; +import mlp; + +public struct MyNetwork +{ + public FeedForwardLayer<4, 16> layer1; + public FeedForwardLayer<16, 4> layer2; + + [Differentiable] + internal MLVec<4> encodeInput(NFloat x, NFloat y) + { + return MLVec<4>.fromArray({ + x, + y, + x*x, + y*y, + }); + } + + [Differentiable] + internal MLVec<4> _eval(NFloat x, NFloat y) + { + let encoding = encodeInput(x, y); + let layer1Output = layer1.eval(encoding); + let leyer2Output = layer2.eval(layer1Output); + return leyer2Output; + } + + [Differentiable] + public half4 eval(no_diff NFloat x, no_diff NFloat y) + { + let mlv = _eval(x, y); + let arr = mlv.toArray(); + return half4(arr[0], arr[1], arr[2], arr[3]); + } +} + +[Differentiable] +public half loss(MyNetwork* network, no_diff half x, no_diff half y) +{ + let networkResult = network.eval(x, y); + let gt = no_diff groundtruth(x, y); + let diff = networkResult - gt; + return dot(diff, diff); +} + +public half4 groundtruth(half x, half y) +{ + return { + (x + y) / (1 + y * y), + 2 * x + y, + 0.5 * x * x + 1.2 * y, + x + 0.5 * y * y, + }; +} + |
