summaryrefslogtreecommitdiffstats
path: root/examples/mlp-training/mlp-training.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/mlp-training/mlp-training.cpp')
-rw-r--r--examples/mlp-training/mlp-training.cpp389
1 files changed, 389 insertions, 0 deletions
diff --git a/examples/mlp-training/mlp-training.cpp b/examples/mlp-training/mlp-training.cpp
new file mode 100644
index 000000000..38090578f
--- /dev/null
+++ b/examples/mlp-training/mlp-training.cpp
@@ -0,0 +1,389 @@
+// In this example, we implement a simple multi-layer perceptron (MLP) training loop on
+// Vulkan (through slang-rhi). See also the mlp-training-coopvec example, which
+// implements the same MLP training loop using cooperative vector intrinsics for better
+// performance.
+//
+// The simple MLP is trained to approximate a polynomial expression.
+// The network contains one hidden layer with 16 neurons. It takes 4 inputs and produces 4
+// outputs.
+
+#include "core/slang-basic.h"
+#include "examples/example-base/example-base.h"
+#include "external/slang-rhi/include/slang-rhi.h"
+#include "slang-com-ptr.h"
+#include "slang.h"
+
+#include <string>
+
+using Slang::ComPtr;
+
+static const ExampleResources resourceBase("mlp-training");
+
+typedef uint16_t NFloat;
+
+static const int kLayerSizes[] = {4, 16, 4};
+static const int kLayerCount = sizeof(kLayerSizes) / sizeof(int) - 1;
+
+int getNetworkLayerWeightStride(int i)
+{
+ return kLayerSizes[i] * sizeof(NFloat);
+}
+
+int getNetworkLayerWeightCount(int i)
+{
+ return kLayerSizes[i] * kLayerSizes[i + 1];
+}
+
+int getNetworkLayerBiasCount(int i)
+{
+ return kLayerSizes[i + 1];
+}
+
+struct Kernel
+{
+ ComPtr<rhi::IShaderProgram> program;
+ ComPtr<rhi::IComputePipeline> pipeline;
+ operator bool() { return program && pipeline; }
+};
+
+struct ClearBufferParams
+{
+ rhi::DeviceAddress buffer;
+ uint32_t count;
+};
+
+struct LearnGradParams
+{
+ rhi::DeviceAddress networkBuffer;
+ rhi::DeviceAddress lossBuffer;
+ rhi::DeviceAddress inputs;
+ uint32_t count;
+};
+
+struct AdjustParamsParams
+{
+ rhi::DeviceAddress adamStates;
+ rhi::DeviceAddress params;
+ rhi::DeviceAddress gradients;
+ uint32_t count;
+};
+
+struct ExampleProgram : public TestBase
+{
+ ComPtr<rhi::IDevice> gDevice;
+
+ ComPtr<slang::ISession> gSlangSession;
+ ComPtr<slang::IModule> gSlangModule;
+ Kernel gLearnGradProgram;
+ Kernel gAdjustParamProgram;
+
+ // Sub-allocated buffer range for each network layer's parameters (weights, biases, gradients).
+ //
+ struct NetworkParameterAllocation
+ {
+ size_t weightsOffset;
+ size_t weightsSize;
+ size_t biasOffset;
+ size_t biasSize;
+ size_t weightsGradOffset;
+ size_t biasGradOffset;
+ };
+
+ SlangResult execute(int argc, char* argv[])
+ {
+ parseOption(argc, argv);
+
+ rhi::DeviceDesc deviceDesc;
+ deviceDesc.slang.targetProfile = "spirv_1_6";
+ deviceDesc.deviceType = rhi::DeviceType::Vulkan;
+
+ gDevice = rhi::getRHI()->createDevice(deviceDesc);
+ if (!gDevice)
+ return SLANG_FAIL;
+
+ SLANG_RETURN_ON_FAIL(loadShaderKernels());
+
+ // Create a buffer to hold all network parameters (weights, biases, gradients).
+ // This buffer is arranged as following:
+ // (segment 1): | weights0 | bias0 | weights1 | bias1 | ... | weightsN | biasN |
+ // (segment 2): | weightsGrad0 | biasGrad0 | weightsGrad1 | biasGrad1 | ... |
+ //
+ // Where the first segment contains all weights and biases for each layer in row-major
+ // layout. The second segment contains gradients for weights and biases in row-major layout.
+
+ // Total size of all network parameters.
+ size_t paramBufferSize;
+
+ // Offset for the second segment, where gradients for weights and biases in row-major layout
+ // start.
+ size_t gradientOffset;
+
+ // Sub-allocated weight/Bias offsets for each layer.
+ std::vector<NetworkParameterAllocation> layerAllocations;
+ allocateNetworkParameterStorage(layerAllocations, paramBufferSize, gradientOffset);
+
+ std::vector<uint16_t> initParams;
+ srand(1072);
+ for (int i = 0; i < paramBufferSize / sizeof(NFloat); i++)
+ {
+ if (i < gradientOffset / sizeof(NFloat))
+ {
+ float v = rand() / (float)RAND_MAX;
+ v = v * 2.0f - 1.0f; // Normalize to [-1, 1]
+ initParams.push_back(floatToHalf(v));
+ }
+ else
+ {
+ // Initialize gradients to zero.
+ initParams.push_back(0);
+ }
+ }
+ auto networkParamsBuffer = createBuffer(paramBufferSize, initParams.data());
+
+ static const size_t kAdamStateSize = sizeof(NFloat) * 2 + sizeof(int32_t);
+ auto adamStateBuffer = createBuffer(initParams.size() * kAdamStateSize);
+ clearBuffer(adamStateBuffer);
+
+ std::vector<uint64_t> networkConstantBufferData;
+ auto paramBufferAddr = networkParamsBuffer->getDeviceAddress();
+ for (int i = 0; i < kLayerCount; i++)
+ {
+ networkConstantBufferData.push_back(
+ paramBufferAddr + layerAllocations[i].weightsOffset);
+ networkConstantBufferData.push_back(
+ paramBufferAddr + layerAllocations[i].weightsGradOffset);
+ networkConstantBufferData.push_back(paramBufferAddr + layerAllocations[i].biasOffset);
+ networkConstantBufferData.push_back(
+ paramBufferAddr + layerAllocations[i].biasGradOffset);
+ }
+ auto networkConstantBuffer = createBuffer(
+ networkConstantBufferData.size() * sizeof(uint64_t),
+ networkConstantBufferData.data());
+
+ static const int inputCount = 32;
+ std::vector<float> inputBufferData;
+ for (int i = 0; i < inputCount; i++)
+ {
+ inputBufferData.push_back((float)rand() / RAND_MAX);
+ }
+ auto inputBuffer = createBuffer(inputCount * sizeof(float), inputBufferData.data());
+
+ // Create buffer for receiving current loss value.
+ auto lossBuffer = createBuffer(sizeof(uint64_t));
+
+ auto queue = gDevice->getQueue(rhi::QueueType::Graphics);
+
+ for (int k = 0; k < 1000; k++)
+ {
+ clearBuffer(lossBuffer);
+
+ // Compute gradients.
+ {
+ LearnGradParams entryPointParams = {};
+ entryPointParams.inputs = inputBuffer->getDeviceAddress();
+ entryPointParams.count = inputCount / 2;
+ entryPointParams.lossBuffer = lossBuffer->getDeviceAddress();
+ entryPointParams.networkBuffer = networkConstantBuffer->getDeviceAddress();
+ dispatchKernel(
+ gLearnGradProgram,
+ entryPointParams,
+ (entryPointParams.count + 255) / 256);
+ }
+ // Adjust parameters in row-major buffer (adam optimize).
+ {
+ AdjustParamsParams entryPointParams = {};
+ entryPointParams.adamStates = adamStateBuffer->getDeviceAddress();
+ entryPointParams.params = networkParamsBuffer->getDeviceAddress();
+ entryPointParams.count = (paramBufferSize - gradientOffset) / sizeof(NFloat);
+ entryPointParams.gradients =
+ networkParamsBuffer->getDeviceAddress() + gradientOffset;
+ dispatchKernel(
+ gAdjustParamProgram,
+ entryPointParams,
+ (entryPointParams.count + 255) / 256);
+ }
+ if ((k + 1) % 10 == 0)
+ {
+ queue->waitOnHost();
+ ComPtr<ISlangBlob> blob;
+ gDevice->readBuffer(lossBuffer, 0, sizeof(float), blob.writeRef());
+ printf("Loss after %d iterations: %f\n", k + 1, *(float*)blob->getBufferPointer());
+ }
+ }
+ return SLANG_OK;
+ }
+
+ // Allocate storage for network parameters, including weights, biases, and gradients.
+ void allocateNetworkParameterStorage(
+ std::vector<NetworkParameterAllocation>& paramStorage,
+ size_t& outParamBufferSize,
+ size_t& outGradientOffset)
+ {
+ outParamBufferSize = 0;
+
+ auto allocRowMajorStorage = [&](size_t size)
+ {
+ size = (size + 63) / 64 * 64;
+ size_t offset = outParamBufferSize;
+ outParamBufferSize += size;
+ return offset;
+ };
+
+ for (int i = 0; i < kLayerCount; i++)
+ {
+ size_t biasSize = getNetworkLayerBiasCount(i) * sizeof(NFloat);
+ NetworkParameterAllocation layer = {};
+ layer.weightsSize = getNetworkLayerWeightCount(i) * sizeof(NFloat);
+ layer.weightsOffset = allocRowMajorStorage(layer.weightsSize);
+ layer.biasSize = biasSize;
+ layer.biasOffset = allocRowMajorStorage(biasSize);
+ paramStorage.push_back(layer);
+ }
+
+ // Alloc storage for gradients.
+ outGradientOffset = outParamBufferSize;
+ for (int i = 0; i < kLayerCount; i++)
+ {
+ paramStorage[i].weightsGradOffset = allocRowMajorStorage(paramStorage[i].weightsSize);
+ paramStorage[i].biasGradOffset = allocRowMajorStorage(paramStorage[i].biasSize);
+ }
+ }
+
+ template<typename Args>
+ void dispatchKernel(Kernel& kernel, Args& args, size_t numWorkGroups)
+ {
+ auto queue = gDevice->getQueue(rhi::QueueType::Graphics);
+ ComPtr<rhi::ICommandEncoder> encoder;
+ queue->createCommandEncoder(encoder.writeRef());
+ {
+ auto computeEncoder = encoder->beginComputePass();
+ auto rootShaderObject = computeEncoder->bindPipeline(kernel.pipeline.get());
+ rootShaderObject->getEntryPoint(0)->setData(rhi::ShaderOffset(), &args, sizeof(args));
+ computeEncoder->dispatchCompute(numWorkGroups, 1, 1);
+ computeEncoder->end();
+ }
+ ComPtr<rhi::ICommandBuffer> commandBuffer;
+ encoder->finish(commandBuffer.writeRef());
+ queue->submit(commandBuffer);
+ }
+
+ // Create a buffer with the specified size and optional initial data.
+ ComPtr<rhi::IBuffer> createBuffer(size_t size, void* initData = nullptr)
+ {
+ rhi::BufferDesc bufferDesc = {};
+ bufferDesc.size = size;
+ bufferDesc.defaultState = rhi::ResourceState::UnorderedAccess;
+ bufferDesc.usage = rhi::BufferUsage::CopySource | rhi::BufferUsage::CopyDestination |
+ rhi::BufferUsage::UnorderedAccess;
+ bufferDesc.memoryType = rhi::MemoryType::DeviceLocal;
+ return gDevice->createBuffer(bufferDesc, initData);
+ }
+
+ void clearBuffer(rhi::IBuffer* buffer)
+ {
+ auto queue = gDevice->getQueue(rhi::QueueType::Graphics);
+ auto encoder = queue->createCommandEncoder();
+ encoder->clearBuffer(buffer);
+ auto cmdBuffer = encoder->finish();
+ queue->submit(cmdBuffer);
+ }
+
+ Kernel loadComputeProgram(slang::IModule* slangModule, char const* entryPointName)
+ {
+ ComPtr<slang::IEntryPoint> entryPoint;
+ slangModule->findEntryPointByName(entryPointName, entryPoint.writeRef());
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ entryPoint->link(linkedProgram.writeRef());
+
+ if (isTestMode())
+ {
+ printEntrypointHashes(1, 1, linkedProgram);
+ }
+
+ Kernel result;
+
+ rhi::ComputePipelineDesc desc;
+ auto program = gDevice->createShaderProgram(linkedProgram);
+ desc.program = program.get();
+ result.program = program;
+ result.pipeline = gDevice->createComputePipeline(desc);
+ return result;
+ }
+
+ inline unsigned short floatToHalf(float val)
+ {
+ uint32_t x = 0;
+ memcpy(&x, &val, sizeof(float));
+
+ unsigned short bits = (x >> 16) & 0x8000;
+ unsigned short m = (x >> 12) & 0x07ff;
+ unsigned int e = (x >> 23) & 0xff;
+ if (e < 103)
+ return bits;
+ if (e > 142)
+ {
+ bits |= 0x7c00u;
+ bits |= e == 255 && (x & 0x007fffffu);
+ return bits;
+ }
+ if (e < 113)
+ {
+ m |= 0x0800u;
+ bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1);
+ return bits;
+ }
+ bits |= ((e - 112) << 10) | (m >> 1);
+ bits += m & 1;
+ return bits;
+ }
+
+ ComPtr<slang::ISession> createSlangSession(rhi::IDevice* device)
+ {
+ ComPtr<slang::ISession> slangSession = device->getSlangSession();
+ return slangSession;
+ }
+
+ ComPtr<slang::IModule> compileShaderModuleFromFile(
+ slang::ISession* slangSession,
+ char const* filePath)
+ {
+ ComPtr<slang::IModule> slangModule;
+ ComPtr<slang::IBlob> diagnosticBlob;
+ Slang::String path = resourceBase.resolveResource(filePath);
+ slangModule = slangSession->loadModule(path.getBuffer(), diagnosticBlob.writeRef());
+ diagnoseIfNeeded(diagnosticBlob);
+
+ return slangModule;
+ }
+
+ SlangResult loadShaderKernels()
+ {
+ Slang::String path = resourceBase.resolveResource("kernels.slang");
+
+ gSlangSession = createSlangSession(gDevice);
+ gSlangModule = compileShaderModuleFromFile(gSlangSession, path.getBuffer());
+ if (!gSlangModule)
+ return SLANG_FAIL;
+
+ gLearnGradProgram = loadComputeProgram(gSlangModule, "learnGradient");
+ if (!gLearnGradProgram)
+ return SLANG_FAIL;
+
+ gAdjustParamProgram = loadComputeProgram(gSlangModule, "adjustParameters");
+ if (!gAdjustParamProgram)
+ return SLANG_FAIL;
+
+ return SLANG_OK;
+ }
+};
+
+int exampleMain(int argc, char** argv)
+{
+ ExampleProgram app;
+ if (SLANG_FAILED(app.execute(argc, argv)))
+ {
+ return -1;
+ }
+ return 0;
+}