diff options
| author | Yong He <yonghe@outlook.com> | 2025-06-30 14:32:50 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-06-30 21:32:50 +0000 |
| commit | f28f67d988158d6c46f7ffe967152f98d32a37b2 (patch) | |
| tree | 2aa620986a87ec69cf1f210c714312e42b62ac9e /examples/mlp-training-coopvec/mlp-training-coopvec.cpp | |
| parent | a55ff722cae338a8fcf5402858c47cf0650a8e5e (diff) | |
Add MLP training examples. (#7550)
* Add MLP training examples.
* Formatting fix.
* Fix.
* Improve documentation on coopvector.
* Improve doc.
* Update doc.
* Fix typo.
* Cleanup shader.
* Cleanup.
* Fix test.
* Fix type check recursion.
* Fix.
* Fix.
* Fix override check.
Diffstat (limited to 'examples/mlp-training-coopvec/mlp-training-coopvec.cpp')
| -rw-r--r-- | examples/mlp-training-coopvec/mlp-training-coopvec.cpp | 462 |
1 files changed, 462 insertions, 0 deletions
diff --git a/examples/mlp-training-coopvec/mlp-training-coopvec.cpp b/examples/mlp-training-coopvec/mlp-training-coopvec.cpp new file mode 100644 index 000000000..9b5dde531 --- /dev/null +++ b/examples/mlp-training-coopvec/mlp-training-coopvec.cpp @@ -0,0 +1,462 @@ +// In this example, we implement a simple multi-layer perceptron (MLP) training loop on +// Vulkan (through slang-rhi), using cooperative vector intrinsics. +// +// The simple MLP is trained to approximate a polynomial expression. +// The network contains one hidden layer with 16 neurons. It takes 4 inputs and produces 4 +// outputs. + +#include "core/slang-basic.h" +#include "examples/example-base/example-base.h" +#include "external/slang-rhi/include/slang-rhi.h" +#include "slang-com-ptr.h" +#include "slang.h" + +static const ExampleResources resourceBase("mlp-training-coopvec"); + +typedef uint16_t NFloat; + +// Define the sizes of the layers in the MLP. +static const int kLayerSizes[] = {4, 16, 4}; +static const int kLayerCount = sizeof(kLayerSizes) / sizeof(int) - 1; + +using Slang::ComPtr; + +struct Kernel +{ + ComPtr<rhi::IShaderProgram> program; + ComPtr<rhi::IComputePipeline> pipeline; + explicit operator bool() { return program && pipeline; } +}; + +struct ClearBufferParams +{ + rhi::DeviceAddress buffer; + uint32_t count; +}; + +struct LearnGradParams +{ + rhi::DeviceAddress networkBuffer; + rhi::DeviceAddress lossBuffer; + rhi::DeviceAddress inputs; + uint32_t count; +}; + +struct AdjustParamsParams +{ + rhi::DeviceAddress adamStates; + rhi::DeviceAddress params; + rhi::DeviceAddress gradients; + uint32_t count; +}; + +struct ExampleProgram : public TestBase +{ + ComPtr<rhi::IDevice> gDevice; + + ComPtr<slang::ISession> gSlangSession; + ComPtr<slang::IModule> gSlangModule; + Kernel gLearnGradProgram; + Kernel gAdjustParamProgram; + + // Sub-allocated buffer range for each network layer's parameters (weights, biases, gradients). + // + struct NetworkParameterAllocation + { + size_t weightsOffset; + size_t weightsSize; + size_t biasOffset; + size_t biasSize; + size_t weightsGradOffset; + size_t biasGradOffset; + size_t weightsGradTrainingOffset; + size_t weightsGradTrainingSize; + }; + + SlangResult execute(int argc, char* argv[]) + { + parseOption(argc, argv); + + rhi::DeviceDesc deviceDesc; + deviceDesc.slang.targetProfile = "spirv_1_6"; + deviceDesc.deviceType = rhi::DeviceType::Vulkan; + + gDevice = rhi::getRHI()->createDevice(deviceDesc); + if (!gDevice) + return SLANG_FAIL; + + SLANG_RETURN_ON_FAIL(loadShaderKernels()); + + // Create a buffer to hold all network parameters (weights, biases, gradients). + // This buffer is arranged as following: + // (segment 1): | weights0 | bias0 | weights1 | bias1 | ... | weightsN | biasN | + // (segment 2): | weightsGrad0 | biasGrad0 | weightsGrad1 | biasGrad1 | ... | + // (segment 3): | weightsGradTraining0 | weightsGradTraining1 | ... | + // + // Where the first segment contains all weights and biases for each layer in row-major + // layout. The second segment contains gradients for weights and biases in row-major layout. + // The third segment contains gradients for weights in training-optimal layout. + // The training-optimal layout is used to accumulate gradients for weights with the + // `coopVecOuterProductAccumulate` intrinsic, which requires the destination to be in + // training-optimal layout. + // After accumulating gradients, we will convert them to row-major layout (i.e. copy + // them back into the second segment) so we can read them in the optimization kernel. + + // Total size of all network parameters. + size_t networkParamsBufferSize; + + // Offset for the second segment, where gradients for weights and biases in row-major layout + // start. + size_t networkGraidentOffset; + + // Offset for the third segment, where gradients for weights in training-optimal layout + // start. + size_t networkGradientTrainingOffset; + + // Sub-allocated weight/Bias offsets for each layer. + std::vector<NetworkParameterAllocation> layerAllocations; + + // Allocate storage for network parameters, filling in `layerRowMajorAllocations`, + // `networkParamsBufferSize`, `networkGraidentOffset` and `networkGradientTrainingOffset`. + // + allocateNetworkParameterStorage( + layerAllocations, + networkParamsBufferSize, + networkGraidentOffset, + networkGradientTrainingOffset); + + // We'll initialize the buffer with random values in the range [-1, 1]. + std::vector<uint16_t> initParams; + srand(1072); + for (int i = 0; i < networkParamsBufferSize / sizeof(NFloat); i++) + { + float v = rand() / (float)RAND_MAX; + v = v * 2.0f - 1.0f; // Normalize to [-1, 1] + initParams.push_back(floatToHalf(v)); + } + auto networkParamsBuffer = createBuffer(networkParamsBufferSize, initParams.data()); + + // Create a buffer for holding the Adam optimizer state for each network parameter. + static const size_t kAdamStateSize = sizeof(NFloat) * 2 + sizeof(int32_t); + auto adamStateBuffer = createBuffer(initParams.size() * kAdamStateSize); + clearBuffer(adamStateBuffer); + + // Prepare buffer for the `network` struct that holds pointers to network parameters for + // each layer. + std::vector<uint64_t> networkConstantBufferData; + for (int i = 0; i < kLayerCount; i++) + { + networkConstantBufferData.push_back( + networkParamsBuffer->getDeviceAddress() + layerAllocations[i].weightsOffset); + networkConstantBufferData.push_back( + networkParamsBuffer->getDeviceAddress() + + layerAllocations[i].weightsGradTrainingOffset); + networkConstantBufferData.push_back( + networkParamsBuffer->getDeviceAddress() + layerAllocations[i].biasOffset); + networkConstantBufferData.push_back( + networkParamsBuffer->getDeviceAddress() + layerAllocations[i].biasGradOffset); + } + auto networkConstantBuffer = createBuffer( + networkConstantBufferData.size() * sizeof(uint64_t), + networkConstantBufferData.data()); + + // Create buffer for input data. + static const int inputCount = 32; + std::vector<float> inputBufferData; + for (int i = 0; i < inputCount; i++) + { + inputBufferData.push_back((float)rand() / RAND_MAX); + } + auto inputBuffer = createBuffer(inputCount * sizeof(float), inputBufferData.data()); + + // Create buffer for receiving current loss value. + auto lossBuffer = createBuffer(sizeof(uint64_t)); + + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + + // Run training loop. + for (int k = 0; k < 1000; k++) + { + clearBuffer(lossBuffer); + + // Clear weight gradients in the parameter buffer to 0. + clearBuffer( + networkParamsBuffer, + rhi::BufferRange{ + networkGradientTrainingOffset, + networkParamsBufferSize - networkGradientTrainingOffset}); + // Compute gradients for weights and biases. + // The weight gradients are stored in the training-optimal layout. + { + LearnGradParams entryPointParams = {}; + entryPointParams.inputs = inputBuffer->getDeviceAddress(); + entryPointParams.count = inputCount / 2; + entryPointParams.lossBuffer = lossBuffer->getDeviceAddress(); + entryPointParams.networkBuffer = networkConstantBuffer->getDeviceAddress(); + dispatchKernel( + gLearnGradProgram, + entryPointParams, + (entryPointParams.count + 255) / 256); + } + // Copy weight gradients from training-optimal layout to row-major layout, + // so we can read them in the `adjustParameters` kernel. + { + std::vector<rhi::ConvertCooperativeVectorMatrixDesc> matrixDescs; + for (int i = 0; i < kLayerCount; i++) + { + rhi::ConvertCooperativeVectorMatrixDesc desc = {}; + desc.rowCount = kLayerSizes[i + 1]; + desc.colCount = kLayerSizes[i]; + desc.dstComponentType = rhi::CooperativeVectorComponentType::Float16; + desc.dstSize = &layerAllocations[i].weightsSize; + desc.dstData.deviceAddress = networkParamsBuffer->getDeviceAddress() + + layerAllocations[i].weightsGradOffset; + desc.dstLayout = rhi::CooperativeVectorMatrixLayout::RowMajor; + desc.dstStride = getNetworkLayerWeightStride(i); + desc.srcComponentType = rhi::CooperativeVectorComponentType::Float16; + desc.srcSize = layerAllocations[i].weightsGradTrainingSize; + desc.srcData.deviceAddress = networkParamsBuffer->getDeviceAddress() + + layerAllocations[i].weightsGradTrainingOffset; + desc.srcLayout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal; + matrixDescs.push_back(desc); + } + auto encoder = queue->createCommandEncoder(); + encoder->convertCooperativeVectorMatrix( + matrixDescs.data(), + (uint32_t)matrixDescs.size()); + ComPtr<rhi::ICommandBuffer> commandBuffer; + encoder->finish(commandBuffer.writeRef()); + queue->submit(commandBuffer); + } + // Adjust parameters in row-major buffer (adam optimize). + { + AdjustParamsParams entryPointParams = {}; + entryPointParams.adamStates = adamStateBuffer->getDeviceAddress(); + entryPointParams.params = networkParamsBuffer->getDeviceAddress(); + entryPointParams.count = + (networkGradientTrainingOffset - networkGraidentOffset) / sizeof(NFloat); + entryPointParams.gradients = + networkParamsBuffer->getDeviceAddress() + networkGraidentOffset; + dispatchKernel( + gAdjustParamProgram, + entryPointParams, + (entryPointParams.count + 255) / 256); + } + + // Print loss value every 10 iterations. + if ((k + 1) % 10 == 0) + { + queue->waitOnHost(); + ComPtr<ISlangBlob> blob; + gDevice->readBuffer(lossBuffer, 0, sizeof(float), blob.writeRef()); + printf("Loss after %d iterations: %f\n", k + 1, *(float*)blob->getBufferPointer()); + } + } + return SLANG_OK; + } + + // Allocate storage for network parameters, including weights, biases, and gradients. + void allocateNetworkParameterStorage( + std::vector<NetworkParameterAllocation>& paramStorage, + size_t& outParamBufferSize, + size_t& outGradientOffset, + size_t& outGradientTrainingOffset) + { + outParamBufferSize = 0; + + auto allocRowMajorStorage = [&](size_t size) + { + size = (size + 63) / 64 * 64; + size_t offset = outParamBufferSize; + outParamBufferSize += size; + return offset; + }; + + for (int i = 0; i < kLayerCount; i++) + { + size_t biasSize = getNetworkLayerBiasCount(i) * sizeof(NFloat); + NetworkParameterAllocation layerStorage = {}; + layerStorage.weightsSize = getNetworkLayerWeightCount(i) * sizeof(NFloat); + layerStorage.weightsOffset = allocRowMajorStorage(layerStorage.weightsSize); + layerStorage.biasSize = biasSize; + layerStorage.biasOffset = allocRowMajorStorage(biasSize); + paramStorage.push_back(layerStorage); + } + + // Alloc storage for weight and bias gradients (row major layout). + outGradientOffset = outParamBufferSize; + for (int i = 0; i < kLayerCount; i++) + { + paramStorage[i].weightsGradOffset = allocRowMajorStorage(paramStorage[i].weightsSize); + paramStorage[i].biasGradOffset = allocRowMajorStorage(paramStorage[i].biasSize); + } + + // Alloc training-optimal storage for weight gradients. + outGradientTrainingOffset = outParamBufferSize; + for (int i = 0; i < kLayerCount; i++) + { + // Allocate space for gradients in training-optimal layout. + rhi::ConvertCooperativeVectorMatrixDesc matrixDesc = {}; + matrixDesc.srcComponentType = rhi::CooperativeVectorComponentType::Float16; + matrixDesc.srcSize = paramStorage[i].weightsSize; + matrixDesc.srcData.hostAddress = nullptr; + matrixDesc.srcLayout = rhi::CooperativeVectorMatrixLayout::RowMajor; + matrixDesc.srcStride = getNetworkLayerWeightStride(i); + matrixDesc.dstComponentType = rhi::CooperativeVectorComponentType::Float16; + matrixDesc.dstSize = ¶mStorage[i].weightsGradTrainingSize; + matrixDesc.dstData.hostAddress = nullptr; + matrixDesc.dstLayout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal; + matrixDesc.dstStride = 0; + matrixDesc.rowCount = kLayerSizes[i + 1]; + matrixDesc.colCount = kLayerSizes[i]; + gDevice->convertCooperativeVectorMatrix(&matrixDesc, 1); + paramStorage[i].weightsGradTrainingOffset = + allocRowMajorStorage(paramStorage[i].weightsGradTrainingSize); + } + } + + // Dispatch a compute kernel with the given arguments and number of work groups. + template<typename Args> + void dispatchKernel(Kernel& kernel, Args& args, size_t numWorkGroups) + { + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + ComPtr<rhi::ICommandEncoder> encoder; + queue->createCommandEncoder(encoder.writeRef()); + { + auto computeEncoder = encoder->beginComputePass(); + auto rootShaderObject = computeEncoder->bindPipeline(kernel.pipeline.get()); + rootShaderObject->getEntryPoint(0)->setData(rhi::ShaderOffset(), &args, sizeof(args)); + computeEncoder->dispatchCompute(numWorkGroups, 1, 1); + computeEncoder->end(); + } + ComPtr<rhi::ICommandBuffer> commandBuffer; + encoder->finish(commandBuffer.writeRef()); + queue->submit(commandBuffer); + } + + // Create a buffer with the specified size and optional initial data. + ComPtr<rhi::IBuffer> createBuffer(size_t size, void* initData = nullptr) + { + rhi::BufferDesc bufferDesc = {}; + bufferDesc.size = size; + bufferDesc.defaultState = rhi::ResourceState::UnorderedAccess; + bufferDesc.usage = rhi::BufferUsage::CopySource | rhi::BufferUsage::CopyDestination | + rhi::BufferUsage::UnorderedAccess; + bufferDesc.memoryType = rhi::MemoryType::DeviceLocal; + return gDevice->createBuffer(bufferDesc, initData); + } + + void clearBuffer(rhi::IBuffer* buffer, rhi::BufferRange range = rhi::kEntireBuffer) + { + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + auto encoder = queue->createCommandEncoder(); + encoder->clearBuffer(buffer, range); + auto cmdBuffer = encoder->finish(); + queue->submit(cmdBuffer); + } + + SlangResult loadShaderKernels() + { + Slang::String path = resourceBase.resolveResource("kernels.slang"); + + gSlangSession = createSlangSession(gDevice); + gSlangModule = compileShaderModuleFromFile(gSlangSession, path.getBuffer()); + if (!gSlangModule) + return SLANG_FAIL; + + gLearnGradProgram = loadComputeProgram(gSlangModule, "learnGradient"); + if (!gLearnGradProgram) + return SLANG_FAIL; + + gAdjustParamProgram = loadComputeProgram(gSlangModule, "adjustParameters"); + if (!gAdjustParamProgram) + return SLANG_FAIL; + + return SLANG_OK; + } + + Kernel loadComputeProgram(slang::IModule* slangModule, char const* entryPointName) + { + ComPtr<slang::IEntryPoint> entryPoint; + slangModule->findEntryPointByName(entryPointName, entryPoint.writeRef()); + + ComPtr<slang::IComponentType> linkedProgram; + entryPoint->link(linkedProgram.writeRef()); + + if (isTestMode()) + { + printEntrypointHashes(1, 1, linkedProgram); + } + + Kernel result; + + rhi::ComputePipelineDesc desc; + auto program = gDevice->createShaderProgram(linkedProgram); + desc.program = program.get(); + result.program = program; + result.pipeline = gDevice->createComputePipeline(desc); + return result; + } + + static inline unsigned short floatToHalf(float val) + { + uint32_t x = 0; + memcpy(&x, &val, sizeof(float)); + + unsigned short bits = (x >> 16) & 0x8000; + unsigned short m = (x >> 12) & 0x07ff; + unsigned int e = (x >> 23) & 0xff; + if (e < 103) + return bits; + if (e > 142) + { + bits |= 0x7c00u; + bits |= e == 255 && (x & 0x007fffffu); + return bits; + } + if (e < 113) + { + m |= 0x0800u; + bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1); + return bits; + } + bits |= ((e - 112) << 10) | (m >> 1); + bits += m & 1; + return bits; + } + + int getNetworkLayerWeightStride(int i) { return kLayerSizes[i] * sizeof(NFloat); } + + int getNetworkLayerWeightCount(int i) { return kLayerSizes[i] * kLayerSizes[i + 1]; } + + int getNetworkLayerBiasCount(int i) { return kLayerSizes[i + 1]; } + + ComPtr<slang::ISession> createSlangSession(rhi::IDevice* device) + { + ComPtr<slang::ISession> slangSession = device->getSlangSession(); + return slangSession; + } + + ComPtr<slang::IModule> compileShaderModuleFromFile( + slang::ISession* slangSession, + char const* filePath) + { + ComPtr<slang::IModule> slangModule; + ComPtr<slang::IBlob> diagnosticBlob; + Slang::String path = resourceBase.resolveResource(filePath); + slangModule = slangSession->loadModule(path.getBuffer(), diagnosticBlob.writeRef()); + diagnoseIfNeeded(diagnosticBlob); + + return slangModule; + } +}; + +int exampleMain(int argc, char** argv) +{ + ExampleProgram app; + if (SLANG_FAILED(app.execute(argc, argv))) + { + return -1; + } + return 0; +} |
