From f28f67d988158d6c46f7ffe967152f98d32a37b2 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 30 Jun 2025 14:32:50 -0700 Subject: Add MLP training examples. (#7550) * Add MLP training examples. * Formatting fix. * Fix. * Improve documentation on coopvector. * Improve doc. * Update doc. * Fix typo. * Cleanup shader. * Cleanup. * Fix test. * Fix type check recursion. * Fix. * Fix. * Fix override check. --- examples/CMakeLists.txt | 5 + examples/mlp-training-coopvec/README.md | 6 + examples/mlp-training-coopvec/adam.slang | 38 ++ examples/mlp-training-coopvec/common.slang | 1 + examples/mlp-training-coopvec/kernels.slang | 41 ++ .../mlp-training-coopvec/mlp-training-coopvec.cpp | 462 +++++++++++++++++++++ examples/mlp-training-coopvec/mlp.slang | 73 ++++ examples/mlp-training-coopvec/mlvec.slang | 63 +++ examples/mlp-training-coopvec/network.slang | 58 +++ examples/mlp-training/README.md | 6 + examples/mlp-training/adam.slang | 38 ++ examples/mlp-training/common.slang | 1 + examples/mlp-training/kernels.slang | 41 ++ examples/mlp-training/mlp-training.cpp | 389 +++++++++++++++++ examples/mlp-training/mlp_sw.slang | 59 +++ examples/mlp-training/mlvec_sw.slang | 64 +++ examples/mlp-training/network.slang | 59 +++ source/slang/core.meta.slang | 4 +- source/slang/hlsl.meta.slang | 277 +++++++++++- source/slang/slang-ast-modifier.h | 7 + source/slang/slang-check-conversion.cpp | 11 +- source/slang/slang-check-decl.cpp | 28 +- source/slang/slang-check-expr.cpp | 5 +- source/slang/slang-check-shader.cpp | 10 + .../capability/capabilitySimplification2.slang | 10 +- 25 files changed, 1710 insertions(+), 46 deletions(-) create mode 100644 examples/mlp-training-coopvec/README.md create mode 100644 examples/mlp-training-coopvec/adam.slang create mode 100644 examples/mlp-training-coopvec/common.slang create mode 100644 examples/mlp-training-coopvec/kernels.slang create mode 100644 examples/mlp-training-coopvec/mlp-training-coopvec.cpp create mode 100644 
examples/mlp-training-coopvec/mlp.slang create mode 100644 examples/mlp-training-coopvec/mlvec.slang create mode 100644 examples/mlp-training-coopvec/network.slang create mode 100644 examples/mlp-training/README.md create mode 100644 examples/mlp-training/adam.slang create mode 100644 examples/mlp-training/common.slang create mode 100644 examples/mlp-training/kernels.slang create mode 100644 examples/mlp-training/mlp-training.cpp create mode 100644 examples/mlp-training/mlp_sw.slang create mode 100644 examples/mlp-training/mlvec_sw.slang create mode 100644 examples/mlp-training/network.slang diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 9411d10ba..87b0b39f8 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -114,4 +114,9 @@ if(SLANG_ENABLE_EXAMPLES) if(SLANG_ENABLE_AFTERMATH) example(nv-aftermath-example WIN32_EXECUTABLE) endif() + + if(SLANG_ENABLE_SLANG_RHI) + example(mlp-training LINK_WITH_PRIVATE slang-rhi) + example(mlp-training-coopvec LINK_WITH_PRIVATE slang-rhi) + endif() endif() diff --git a/examples/mlp-training-coopvec/README.md b/examples/mlp-training-coopvec/README.md new file mode 100644 index 000000000..8d56fff92 --- /dev/null +++ b/examples/mlp-training-coopvec/README.md @@ -0,0 +1,6 @@ +Slang "MLP-Training-CoopVec" Example +========================== + +This example shows how to use Slang to train a feed-forward neural network +using automatic differentiation and the cooperative vector intrinsics. Also see the +"MLP-Training" example for the same task without using cooperative vectors. 
\ No newline at end of file diff --git a/examples/mlp-training-coopvec/adam.slang b/examples/mlp-training-coopvec/adam.slang new file mode 100644 index 000000000..33a357122 --- /dev/null +++ b/examples/mlp-training-coopvec/adam.slang @@ -0,0 +1,38 @@ +module adam; + +import mlp; +import common; + +public struct AdamState +{ + internal NFloat mean; + internal NFloat variance; + internal int iteration; +} + +public struct AdamOptimizer +{ + // Adam parameters + public static const NFloat beta1 = 0.9h; + public static const NFloat beta2 = 0.999h; + public static const NFloat epsilon = 1e-7h; + public static const NFloat learningRate = 0.01h; + + public static void step(inout AdamState state, inout NFloat param, inout NFloat grad) + { + state.iteration++; + if (isinf(grad)) + { + if (grad > 0) + grad = 10000.0h; + else + grad = -10000.0h; + } + state.mean = beta1 * state.mean + (NFloat(1.f) - beta1) * grad; + state.variance = beta2 * state.variance + (NFloat(1.f) - beta2) * grad * grad; + NFloat meanHat = state.mean / (NFloat(1.f) - pow(beta1, NFloat(state.iteration))); + NFloat varianceHat = state.variance / (NFloat(1.f) - pow(beta2, NFloat(state.iteration))); + param -= learningRate * meanHat / (sqrt(max(NFloat(0.f), varianceHat) + epsilon)); + grad = NFloat(0.f); + } +} diff --git a/examples/mlp-training-coopvec/common.slang b/examples/mlp-training-coopvec/common.slang new file mode 100644 index 000000000..92dc3b563 --- /dev/null +++ b/examples/mlp-training-coopvec/common.slang @@ -0,0 +1 @@ +public typealias NFloat = half; \ No newline at end of file diff --git a/examples/mlp-training-coopvec/kernels.slang b/examples/mlp-training-coopvec/kernels.slang new file mode 100644 index 000000000..712494b1f --- /dev/null +++ b/examples/mlp-training-coopvec/kernels.slang @@ -0,0 +1,41 @@ +module kernels; + +import common; +import mlp; +import network; +import adam; + +[numthreads(256, 1, 1)] +[require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic, 
spvCooperativeVectorNV)] +void learnGradient( + uint32_t tid : SV_DispatchThreadID, + uniform MyNetwork* network, + uniform Atomic* lossBuffer, + uniform float2* inputs, + uniform uint32_t count) +{ + if (tid >= count) + return; + + var input = (half2)inputs[tid]; + bwd_diff(loss)(network, input.x, input.y, 1.0h); + let thisLoss = (float)loss(network, input.x, input.y); + let maxLoss = WaveActiveMax(thisLoss); + if (WaveIsFirstLane()) + { + lossBuffer.max(bit_cast(maxLoss)); + } +} + +[numthreads(256, 1, 1)] +void adjustParameters(uint32_t tid : SV_DispatchThreadID, uniform AdamState* states, uniform NFloat* params, uniform NFloat* gradients, uniform uint32_t count) +{ + if (tid >= count) + return; + if (isnan(gradients[tid])) + { + gradients[tid] = 0.0h; + return; + } + AdamOptimizer::step(states[tid], params[tid], gradients[tid]); +} \ No newline at end of file diff --git a/examples/mlp-training-coopvec/mlp-training-coopvec.cpp b/examples/mlp-training-coopvec/mlp-training-coopvec.cpp new file mode 100644 index 000000000..9b5dde531 --- /dev/null +++ b/examples/mlp-training-coopvec/mlp-training-coopvec.cpp @@ -0,0 +1,462 @@ +// In this example, we implement a simple multi-layer perceptron (MLP) training loop on +// Vulkan (through slang-rhi), using cooperative vector intrinsics. +// +// The simple MLP is trained to approximate a polynomial expression. +// The network contains one hidden layer with 16 neurons. It takes 4 inputs and produces 4 +// outputs. + +#include "core/slang-basic.h" +#include "examples/example-base/example-base.h" +#include "external/slang-rhi/include/slang-rhi.h" +#include "slang-com-ptr.h" +#include "slang.h" + +static const ExampleResources resourceBase("mlp-training-coopvec"); + +typedef uint16_t NFloat; + +// Define the sizes of the layers in the MLP. 
+static const int kLayerSizes[] = {4, 16, 4}; +static const int kLayerCount = sizeof(kLayerSizes) / sizeof(int) - 1; + +using Slang::ComPtr; + +struct Kernel +{ + ComPtr program; + ComPtr pipeline; + explicit operator bool() { return program && pipeline; } +}; + +struct ClearBufferParams +{ + rhi::DeviceAddress buffer; + uint32_t count; +}; + +struct LearnGradParams +{ + rhi::DeviceAddress networkBuffer; + rhi::DeviceAddress lossBuffer; + rhi::DeviceAddress inputs; + uint32_t count; +}; + +struct AdjustParamsParams +{ + rhi::DeviceAddress adamStates; + rhi::DeviceAddress params; + rhi::DeviceAddress gradients; + uint32_t count; +}; + +struct ExampleProgram : public TestBase +{ + ComPtr gDevice; + + ComPtr gSlangSession; + ComPtr gSlangModule; + Kernel gLearnGradProgram; + Kernel gAdjustParamProgram; + + // Sub-allocated buffer range for each network layer's parameters (weights, biases, gradients). + // + struct NetworkParameterAllocation + { + size_t weightsOffset; + size_t weightsSize; + size_t biasOffset; + size_t biasSize; + size_t weightsGradOffset; + size_t biasGradOffset; + size_t weightsGradTrainingOffset; + size_t weightsGradTrainingSize; + }; + + SlangResult execute(int argc, char* argv[]) + { + parseOption(argc, argv); + + rhi::DeviceDesc deviceDesc; + deviceDesc.slang.targetProfile = "spirv_1_6"; + deviceDesc.deviceType = rhi::DeviceType::Vulkan; + + gDevice = rhi::getRHI()->createDevice(deviceDesc); + if (!gDevice) + return SLANG_FAIL; + + SLANG_RETURN_ON_FAIL(loadShaderKernels()); + + // Create a buffer to hold all network parameters (weights, biases, gradients). + // This buffer is arranged as following: + // (segment 1): | weights0 | bias0 | weights1 | bias1 | ... | weightsN | biasN | + // (segment 2): | weightsGrad0 | biasGrad0 | weightsGrad1 | biasGrad1 | ... | + // (segment 3): | weightsGradTraining0 | weightsGradTraining1 | ... | + // + // Where the first segment contains all weights and biases for each layer in row-major + // layout. 
The second segment contains gradients for weights and biases in row-major layout. + // The third segment contains gradients for weights in training-optimal layout. + // The training-optimal layout is used to accumulate gradients for weights with the + // `coopVecOuterProductAccumulate` intrinsic, which requires the destination to be in + // training-optimal layout. + // After accumulating gradients, we will convert them to row-major layout (i.e. copy + // them back into the second segment) so we can read them in the optimization kernel. + + // Total size of all network parameters. + size_t networkParamsBufferSize; + + // Offset for the second segment, where gradients for weights and biases in row-major layout + // start. + size_t networkGraidentOffset; + + // Offset for the third segment, where gradients for weights in training-optimal layout + // start. + size_t networkGradientTrainingOffset; + + // Sub-allocated weight/bias offsets for each layer. + std::vector layerAllocations; + + // Allocate storage for network parameters, filling in `layerAllocations`, + // `networkParamsBufferSize`, `networkGraidentOffset` and `networkGradientTrainingOffset`. + // + allocateNetworkParameterStorage( + layerAllocations, + networkParamsBufferSize, + networkGraidentOffset, + networkGradientTrainingOffset); + + // Initialize weights and biases (segment 1) with random values in the range [-1, 1]. + // The gradient segments start at zero: bias gradients are accumulated in place by the + // `learnGradient` kernel, so they must not begin with arbitrary values. + std::vector initParams; + srand(1072); + for (size_t i = 0; i < networkParamsBufferSize / sizeof(NFloat); i++) + { + if (i < networkGraidentOffset / sizeof(NFloat)) + { + float v = rand() / (float)RAND_MAX; + v = v * 2.0f - 1.0f; // Rescale from [0, 1] to [-1, 1] + initParams.push_back(floatToHalf(v)); + } + else + { + // Gradient storage (segments 2 and 3) is zero-initialized. + initParams.push_back(0); + } + } + auto networkParamsBuffer = createBuffer(networkParamsBufferSize, initParams.data()); + + // Create a buffer for holding the Adam optimizer state for each network parameter. 
+ static const size_t kAdamStateSize = sizeof(NFloat) * 2 + sizeof(int32_t); + auto adamStateBuffer = createBuffer(initParams.size() * kAdamStateSize); + clearBuffer(adamStateBuffer); + + // Prepare buffer for the `network` struct that holds pointers to network parameters for + // each layer. + std::vector networkConstantBufferData; + for (int i = 0; i < kLayerCount; i++) + { + networkConstantBufferData.push_back( + networkParamsBuffer->getDeviceAddress() + layerAllocations[i].weightsOffset); + networkConstantBufferData.push_back( + networkParamsBuffer->getDeviceAddress() + + layerAllocations[i].weightsGradTrainingOffset); + networkConstantBufferData.push_back( + networkParamsBuffer->getDeviceAddress() + layerAllocations[i].biasOffset); + networkConstantBufferData.push_back( + networkParamsBuffer->getDeviceAddress() + layerAllocations[i].biasGradOffset); + } + auto networkConstantBuffer = createBuffer( + networkConstantBufferData.size() * sizeof(uint64_t), + networkConstantBufferData.data()); + + // Create buffer for input data. + static const int inputCount = 32; + std::vector inputBufferData; + for (int i = 0; i < inputCount; i++) + { + inputBufferData.push_back((float)rand() / RAND_MAX); + } + auto inputBuffer = createBuffer(inputCount * sizeof(float), inputBufferData.data()); + + // Create buffer for receiving current loss value. + auto lossBuffer = createBuffer(sizeof(uint64_t)); + + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + + // Run training loop. + for (int k = 0; k < 1000; k++) + { + clearBuffer(lossBuffer); + + // Clear weight gradients in the parameter buffer to 0. + clearBuffer( + networkParamsBuffer, + rhi::BufferRange{ + networkGradientTrainingOffset, + networkParamsBufferSize - networkGradientTrainingOffset}); + // Compute gradients for weights and biases. + // The weight gradients are stored in the training-optimal layout. 
+ { + LearnGradParams entryPointParams = {}; + entryPointParams.inputs = inputBuffer->getDeviceAddress(); + entryPointParams.count = inputCount / 2; + entryPointParams.lossBuffer = lossBuffer->getDeviceAddress(); + entryPointParams.networkBuffer = networkConstantBuffer->getDeviceAddress(); + dispatchKernel( + gLearnGradProgram, + entryPointParams, + (entryPointParams.count + 255) / 256); + } + // Copy weight gradients from training-optimal layout to row-major layout, + // so we can read them in the `adjustParameters` kernel. + { + std::vector matrixDescs; + for (int i = 0; i < kLayerCount; i++) + { + rhi::ConvertCooperativeVectorMatrixDesc desc = {}; + desc.rowCount = kLayerSizes[i + 1]; + desc.colCount = kLayerSizes[i]; + desc.dstComponentType = rhi::CooperativeVectorComponentType::Float16; + desc.dstSize = &layerAllocations[i].weightsSize; + desc.dstData.deviceAddress = networkParamsBuffer->getDeviceAddress() + + layerAllocations[i].weightsGradOffset; + desc.dstLayout = rhi::CooperativeVectorMatrixLayout::RowMajor; + desc.dstStride = getNetworkLayerWeightStride(i); + desc.srcComponentType = rhi::CooperativeVectorComponentType::Float16; + desc.srcSize = layerAllocations[i].weightsGradTrainingSize; + desc.srcData.deviceAddress = networkParamsBuffer->getDeviceAddress() + + layerAllocations[i].weightsGradTrainingOffset; + desc.srcLayout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal; + matrixDescs.push_back(desc); + } + auto encoder = queue->createCommandEncoder(); + encoder->convertCooperativeVectorMatrix( + matrixDescs.data(), + (uint32_t)matrixDescs.size()); + ComPtr commandBuffer; + encoder->finish(commandBuffer.writeRef()); + queue->submit(commandBuffer); + } + // Adjust parameters in row-major buffer (adam optimize). 
+ { + AdjustParamsParams entryPointParams = {}; + entryPointParams.adamStates = adamStateBuffer->getDeviceAddress(); + entryPointParams.params = networkParamsBuffer->getDeviceAddress(); + entryPointParams.count = + (networkGradientTrainingOffset - networkGraidentOffset) / sizeof(NFloat); + entryPointParams.gradients = + networkParamsBuffer->getDeviceAddress() + networkGraidentOffset; + dispatchKernel( + gAdjustParamProgram, + entryPointParams, + (entryPointParams.count + 255) / 256); + } + + // Print loss value every 10 iterations. + if ((k + 1) % 10 == 0) + { + queue->waitOnHost(); + ComPtr blob; + gDevice->readBuffer(lossBuffer, 0, sizeof(float), blob.writeRef()); + printf("Loss after %d iterations: %f\n", k + 1, *(float*)blob->getBufferPointer()); + } + } + return SLANG_OK; + } + + // Allocate storage for network parameters, including weights, biases, and gradients. + void allocateNetworkParameterStorage( + std::vector& paramStorage, + size_t& outParamBufferSize, + size_t& outGradientOffset, + size_t& outGradientTrainingOffset) + { + outParamBufferSize = 0; + + auto allocRowMajorStorage = [&](size_t size) + { + size = (size + 63) / 64 * 64; + size_t offset = outParamBufferSize; + outParamBufferSize += size; + return offset; + }; + + for (int i = 0; i < kLayerCount; i++) + { + size_t biasSize = getNetworkLayerBiasCount(i) * sizeof(NFloat); + NetworkParameterAllocation layerStorage = {}; + layerStorage.weightsSize = getNetworkLayerWeightCount(i) * sizeof(NFloat); + layerStorage.weightsOffset = allocRowMajorStorage(layerStorage.weightsSize); + layerStorage.biasSize = biasSize; + layerStorage.biasOffset = allocRowMajorStorage(biasSize); + paramStorage.push_back(layerStorage); + } + + // Alloc storage for weight and bias gradients (row major layout). 
+ outGradientOffset = outParamBufferSize; + for (int i = 0; i < kLayerCount; i++) + { + paramStorage[i].weightsGradOffset = allocRowMajorStorage(paramStorage[i].weightsSize); + paramStorage[i].biasGradOffset = allocRowMajorStorage(paramStorage[i].biasSize); + } + + // Alloc training-optimal storage for weight gradients. + outGradientTrainingOffset = outParamBufferSize; + for (int i = 0; i < kLayerCount; i++) + { + // Allocate space for gradients in training-optimal layout. + rhi::ConvertCooperativeVectorMatrixDesc matrixDesc = {}; + matrixDesc.srcComponentType = rhi::CooperativeVectorComponentType::Float16; + matrixDesc.srcSize = paramStorage[i].weightsSize; + matrixDesc.srcData.hostAddress = nullptr; + matrixDesc.srcLayout = rhi::CooperativeVectorMatrixLayout::RowMajor; + matrixDesc.srcStride = getNetworkLayerWeightStride(i); + matrixDesc.dstComponentType = rhi::CooperativeVectorComponentType::Float16; + matrixDesc.dstSize = &paramStorage[i].weightsGradTrainingSize; + matrixDesc.dstData.hostAddress = nullptr; + matrixDesc.dstLayout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal; + matrixDesc.dstStride = 0; + matrixDesc.rowCount = kLayerSizes[i + 1]; + matrixDesc.colCount = kLayerSizes[i]; + gDevice->convertCooperativeVectorMatrix(&matrixDesc, 1); + paramStorage[i].weightsGradTrainingOffset = + allocRowMajorStorage(paramStorage[i].weightsGradTrainingSize); + } + } + + // Dispatch a compute kernel with the given arguments and number of work groups. 
+ template + void dispatchKernel(Kernel& kernel, Args& args, size_t numWorkGroups) + { + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + ComPtr encoder; + queue->createCommandEncoder(encoder.writeRef()); + { + auto computeEncoder = encoder->beginComputePass(); + auto rootShaderObject = computeEncoder->bindPipeline(kernel.pipeline.get()); + rootShaderObject->getEntryPoint(0)->setData(rhi::ShaderOffset(), &args, sizeof(args)); + computeEncoder->dispatchCompute(numWorkGroups, 1, 1); + computeEncoder->end(); + } + ComPtr commandBuffer; + encoder->finish(commandBuffer.writeRef()); + queue->submit(commandBuffer); + } + + // Create a buffer with the specified size and optional initial data. + ComPtr createBuffer(size_t size, void* initData = nullptr) + { + rhi::BufferDesc bufferDesc = {}; + bufferDesc.size = size; + bufferDesc.defaultState = rhi::ResourceState::UnorderedAccess; + bufferDesc.usage = rhi::BufferUsage::CopySource | rhi::BufferUsage::CopyDestination | + rhi::BufferUsage::UnorderedAccess; + bufferDesc.memoryType = rhi::MemoryType::DeviceLocal; + return gDevice->createBuffer(bufferDesc, initData); + } + + void clearBuffer(rhi::IBuffer* buffer, rhi::BufferRange range = rhi::kEntireBuffer) + { + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + auto encoder = queue->createCommandEncoder(); + encoder->clearBuffer(buffer, range); + auto cmdBuffer = encoder->finish(); + queue->submit(cmdBuffer); + } + + SlangResult loadShaderKernels() + { + Slang::String path = resourceBase.resolveResource("kernels.slang"); + + gSlangSession = createSlangSession(gDevice); + gSlangModule = compileShaderModuleFromFile(gSlangSession, path.getBuffer()); + if (!gSlangModule) + return SLANG_FAIL; + + gLearnGradProgram = loadComputeProgram(gSlangModule, "learnGradient"); + if (!gLearnGradProgram) + return SLANG_FAIL; + + gAdjustParamProgram = loadComputeProgram(gSlangModule, "adjustParameters"); + if (!gAdjustParamProgram) + return SLANG_FAIL; + + return SLANG_OK; + } 
+ + Kernel loadComputeProgram(slang::IModule* slangModule, char const* entryPointName) + { + ComPtr entryPoint; + slangModule->findEntryPointByName(entryPointName, entryPoint.writeRef()); + + ComPtr linkedProgram; + entryPoint->link(linkedProgram.writeRef()); + + if (isTestMode()) + { + printEntrypointHashes(1, 1, linkedProgram); + } + + Kernel result; + + rhi::ComputePipelineDesc desc; + auto program = gDevice->createShaderProgram(linkedProgram); + desc.program = program.get(); + result.program = program; + result.pipeline = gDevice->createComputePipeline(desc); + return result; + } + + static inline unsigned short floatToHalf(float val) + { + uint32_t x = 0; + memcpy(&x, &val, sizeof(float)); + + unsigned short bits = (x >> 16) & 0x8000; + unsigned short m = (x >> 12) & 0x07ff; + unsigned int e = (x >> 23) & 0xff; + if (e < 103) + return bits; + if (e > 142) + { + bits |= 0x7c00u; + bits |= e == 255 && (x & 0x007fffffu); + return bits; + } + if (e < 113) + { + m |= 0x0800u; + bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1); + return bits; + } + bits |= ((e - 112) << 10) | (m >> 1); + bits += m & 1; + return bits; + } + + int getNetworkLayerWeightStride(int i) { return kLayerSizes[i] * sizeof(NFloat); } + + int getNetworkLayerWeightCount(int i) { return kLayerSizes[i] * kLayerSizes[i + 1]; } + + int getNetworkLayerBiasCount(int i) { return kLayerSizes[i + 1]; } + + ComPtr createSlangSession(rhi::IDevice* device) + { + ComPtr slangSession = device->getSlangSession(); + return slangSession; + } + + ComPtr compileShaderModuleFromFile( + slang::ISession* slangSession, + char const* filePath) + { + ComPtr slangModule; + ComPtr diagnosticBlob; + Slang::String path = resourceBase.resolveResource(filePath); + slangModule = slangSession->loadModule(path.getBuffer(), diagnosticBlob.writeRef()); + diagnoseIfNeeded(diagnosticBlob); + + return slangModule; + } +}; + +int exampleMain(int argc, char** argv) +{ + ExampleProgram app; + if (SLANG_FAILED(app.execute(argc, argv))) 
+ { + return -1; + } + return 0; +} diff --git a/examples/mlp-training-coopvec/mlp.slang b/examples/mlp-training-coopvec/mlp.slang new file mode 100644 index 000000000..44380ad77 --- /dev/null +++ b/examples/mlp-training-coopvec/mlp.slang @@ -0,0 +1,73 @@ +module mlp; + +import common; + +__include mlvec; + +// We use Float16 for the CoopVec component type since it is more widely supported. +// +static const CoopVecComponentType kComponentType = CoopVecComponentType.Float16; + +public struct FeedForwardLayer +{ + internal void* weights; + internal void* weightsGrad; + internal void* biases; + internal void* biasesGrad; + + public MLVec eval(MLVec input) + { + // Compute mul(weights, inputVec) + biases. + // `weights` is treated as an OutputSize(row) x InputSize(col) matrix. + var output = coopVecMatMulAdd( + input.data, kComponentType, // input and format + weights, kComponentType, // weights and format + biases, kComponentType, // biases and format + CoopVecMatrixLayout.RowMajor, // matrix layout + false, // transpose matrix? must be `false` since we specified RowMajor. + InputSize * sizeof(NFloat)); // matrix stride + output = max(output, output * 0.001h); // Leaky ReLU activation (negative slope 0.001) + return {output}; + } + + [BackwardDerivativeOf(eval)] + public void evalBwd( + inout DifferentialPair> input, + MLVec resultGrad) + { + let fwd = eval(input.p); + + // Back-prop resultGrad through activation. + // The scale factor must equal the negative slope used by the leaky ReLU in `eval`, + // otherwise this hand-written derivative does not match the forward function. + [ForceUnroll] + for (int i = 0; i < OutputSize; i++) + { + if (fwd.data[i] < 0.0) + resultGrad.data[i] *= 0.001h; + } + + // Back-prop gradients to the weights matrix. + coopVecOuterProductAccumulate( + resultGrad.data, + input.p.data, + weightsGrad, + 0, // matrixStride, ignored since layout is TrainingOptimal + CoopVecMatrixLayout.TrainingOptimal, // matrix layout, must be TrainingOptimal. + kComponentType); + + // Back-prop gradients to the biases vector. 
+ coopVecReduceSumAccumulate(resultGrad.data, (void*)biasesGrad); + + // Back-prop gradients to the input vector by computing + // mul(transpose(weights), resultGrad). + // By specifying the matrix layout as ColumnMajor, we can + // achieve the effect of transposing the weights matrix. + let dInput = coopVecMatMul( + resultGrad.data, kComponentType, + weights, kComponentType, + CoopVecMatrixLayout.ColumnMajor, + false, // transpose, must be `false` since we specified ColumnMajor. + InputSize * sizeof(NFloat)); + + input = {input.p, {dInput}}; + } +} diff --git a/examples/mlp-training-coopvec/mlvec.slang b/examples/mlp-training-coopvec/mlvec.slang new file mode 100644 index 000000000..ce7ce8352 --- /dev/null +++ b/examples/mlp-training-coopvec/mlvec.slang @@ -0,0 +1,63 @@ +implementing mlp; + +// A wrapper of CoopVec to allow it being used in differentiable context. +// +public struct MLVec : IDifferentiable +{ + public CoopVec data; + public typealias Differential = MLVec; + + public static MLVec fromArray(NFloat[N] values) + { + MLVec result; + [ForceUnroll] + for (int i = 0; i < N; i++) + result.data[i] = values[i]; + return result; + } + + internal static NFloat[N] coopVecToArray(CoopVec v) + { + NFloat[N] arr; + [ForceUnroll] + for (int i = 0; i < N; i++) + arr[i] = v[i]; + return arr; + } + + [BackwardDerivativeOf(fromArray)] + internal static void fromArrayBwd(inout DifferentialPair values, MLVec dResult) + { + values = diffPair(values.p, coopVecToArray(dResult.data)); + } + + internal static NFloat[N] toArray(MLVec vec) + { + return coopVecToArray(vec.data); + } + + [BackwardDerivativeOf(toArray)] + internal static void toArrayBwd(inout DifferentialPair> vec, NFloat[N] dResult) + { + vec = diffPair(vec.p, MLVec.fromArray(dResult)); + } + + [Differentiable] + public NFloat[N] toArray() + { + return toArray(this); + } + + public override static Differential dadd(Differential d0, Differential d1) + { + return {d0.data + d1.data}; + } + public override static 
Differential dmul(U s, Differential d) + { + return {d.data * __realCast(s)}; + } + public override static Differential dzero() + { + return {}; + } +} diff --git a/examples/mlp-training-coopvec/network.slang b/examples/mlp-training-coopvec/network.slang new file mode 100644 index 000000000..5741487c4 --- /dev/null +++ b/examples/mlp-training-coopvec/network.slang @@ -0,0 +1,58 @@ +module network; + +import common; +import mlp; + +public struct MyNetwork +{ + public FeedForwardLayer<4, 16> layer1; + public FeedForwardLayer<16, 4> layer2; + + [Differentiable] + internal MLVec<4> encodeInput(NFloat x, NFloat y) + { + return MLVec<4>.fromArray({ + x, + y, + x*x, + y*y, + }); + } + + [Differentiable] + internal MLVec<4> _eval(NFloat x, NFloat y) + { + let encoding = encodeInput(x, y); + let layer1Output = layer1.eval(encoding); + let leyer2Output = layer2.eval(layer1Output); + return leyer2Output; + } + + [Differentiable] + public half4 eval(no_diff NFloat x, no_diff NFloat y) + { + let mlv = _eval(x, y); + let arr = mlv.toArray(); + return half4(arr[0], arr[1], arr[2], arr[3]); + } +} + +[Differentiable] +public half loss(MyNetwork* network, no_diff half x, no_diff half y) +{ + let networkResult = network.eval(x, y); + let gt = no_diff groundtruth(x, y); + let diff = networkResult - gt; + return dot(diff, diff); +} + +public half4 groundtruth(half x, half y) +{ + return { + (x + y) / (1 + y * y), + 2 * x + y, + 0.5 * x * x + 1.2 * y, + x + 0.5 * y * y, + }; +} + diff --git a/examples/mlp-training/README.md b/examples/mlp-training/README.md new file mode 100644 index 000000000..c5266bbf1 --- /dev/null +++ b/examples/mlp-training/README.md @@ -0,0 +1,6 @@ +Slang "MLP-Training" Example +========================== + +This example shows how to use Slang to train a feed-forward neural network +using automatic differentiation. Also see the "MLP-Training-CoopVec" example +for how to use the cooperative vector intrinsics to speed up training. 
\ No newline at end of file diff --git a/examples/mlp-training/adam.slang b/examples/mlp-training/adam.slang new file mode 100644 index 000000000..8f9b15f01 --- /dev/null +++ b/examples/mlp-training/adam.slang @@ -0,0 +1,38 @@ +module adam; + +import mlp_sw; +import common; + +public struct AdamState +{ + internal NFloat mean; + internal NFloat variance; + internal int iteration; +} + +public struct AdamOptimizer +{ + // Adam parameters + public static const NFloat beta1 = 0.9h; + public static const NFloat beta2 = 0.999h; + public static const NFloat epsilon = 1e-7h; + public static const NFloat learningRate = 0.01h; + + public static void step(inout AdamState state, inout NFloat param, inout NFloat grad) + { + state.iteration++; + if (isinf(grad)) + { + if (grad > 0) + grad = 10000.0h; + else + grad = -10000.0h; + } + state.mean = beta1 * state.mean + (NFloat(1.f) - beta1) * grad; + state.variance = beta2 * state.variance + (NFloat(1.f) - beta2) * grad * grad; + NFloat meanHat = state.mean / (NFloat(1.f) - pow(beta1, NFloat(state.iteration))); + NFloat varianceHat = state.variance / (NFloat(1.f) - pow(beta2, NFloat(state.iteration))); + param -= learningRate * meanHat / (sqrt(max(NFloat(0.f), varianceHat) + epsilon)); + grad = NFloat(0.f); + } +} diff --git a/examples/mlp-training/common.slang b/examples/mlp-training/common.slang new file mode 100644 index 000000000..92dc3b563 --- /dev/null +++ b/examples/mlp-training/common.slang @@ -0,0 +1 @@ +public typealias NFloat = half; \ No newline at end of file diff --git a/examples/mlp-training/kernels.slang b/examples/mlp-training/kernels.slang new file mode 100644 index 000000000..5be076879 --- /dev/null +++ b/examples/mlp-training/kernels.slang @@ -0,0 +1,41 @@ +module kernels; + +import common; +import mlp_sw; +import network; +import adam; + +[numthreads(256, 1, 1)] +[require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic)] +void learnGradient( + uint32_t tid : SV_DispatchThreadID, + uniform MyNetwork* 
network, + uniform Atomic* lossBuffer, + uniform float2* inputs, + uniform uint32_t count) +{ + if (tid >= count) + return; + + var input = (half2)inputs[tid]; + bwd_diff(loss)(network, input.x, input.y, 1.0h); + let thisLoss = (float)loss(network, input.x, input.y); + let maxLoss = WaveActiveMax(thisLoss); + if (WaveIsFirstLane()) + { + lossBuffer.max(bit_cast(maxLoss)); + } +} + +[numthreads(256, 1, 1)] +void adjustParameters(uint32_t tid : SV_DispatchThreadID, uniform AdamState* states, uniform NFloat* params, uniform NFloat* gradients, uniform uint32_t count) +{ + if (tid >= count) + return; + if (isnan(gradients[tid])) + { + gradients[tid] = 0.0h; + return; + } + AdamOptimizer::step(states[tid], params[tid], gradients[tid]); +} \ No newline at end of file diff --git a/examples/mlp-training/mlp-training.cpp b/examples/mlp-training/mlp-training.cpp new file mode 100644 index 000000000..38090578f --- /dev/null +++ b/examples/mlp-training/mlp-training.cpp @@ -0,0 +1,389 @@ +// In this example, we implement a simple multi-layer perceptron (MLP) training loop on +// Vulkan (through slang-rhi). See also the mlp-training-coopvec example, which +// implements the same MLP training loop using cooperative vector intrinsics for better +// performance. +// +// The simple MLP is trained to approximate a polynomial expression. +// The network contains one hidden layer with 16 neurons. It takes 4 inputs and produces 4 +// outputs. 
+ +#include "core/slang-basic.h" +#include "examples/example-base/example-base.h" +#include "external/slang-rhi/include/slang-rhi.h" +#include "slang-com-ptr.h" +#include "slang.h" + +#include + +using Slang::ComPtr; + +static const ExampleResources resourceBase("mlp-training"); + +typedef uint16_t NFloat; + +static const int kLayerSizes[] = {4, 16, 4}; +static const int kLayerCount = sizeof(kLayerSizes) / sizeof(int) - 1; + +int getNetworkLayerWeightStride(int i) +{ + return kLayerSizes[i] * sizeof(NFloat); +} + +int getNetworkLayerWeightCount(int i) +{ + return kLayerSizes[i] * kLayerSizes[i + 1]; +} + +int getNetworkLayerBiasCount(int i) +{ + return kLayerSizes[i + 1]; +} + +struct Kernel +{ + ComPtr program; + ComPtr pipeline; + operator bool() { return program && pipeline; } +}; + +struct ClearBufferParams +{ + rhi::DeviceAddress buffer; + uint32_t count; +}; + +struct LearnGradParams +{ + rhi::DeviceAddress networkBuffer; + rhi::DeviceAddress lossBuffer; + rhi::DeviceAddress inputs; + uint32_t count; +}; + +struct AdjustParamsParams +{ + rhi::DeviceAddress adamStates; + rhi::DeviceAddress params; + rhi::DeviceAddress gradients; + uint32_t count; +}; + +struct ExampleProgram : public TestBase +{ + ComPtr gDevice; + + ComPtr gSlangSession; + ComPtr gSlangModule; + Kernel gLearnGradProgram; + Kernel gAdjustParamProgram; + + // Sub-allocated buffer range for each network layer's parameters (weights, biases, gradients). 
+ // + struct NetworkParameterAllocation + { + size_t weightsOffset; + size_t weightsSize; + size_t biasOffset; + size_t biasSize; + size_t weightsGradOffset; + size_t biasGradOffset; + }; + + SlangResult execute(int argc, char* argv[]) + { + parseOption(argc, argv); + + rhi::DeviceDesc deviceDesc; + deviceDesc.slang.targetProfile = "spirv_1_6"; + deviceDesc.deviceType = rhi::DeviceType::Vulkan; + + gDevice = rhi::getRHI()->createDevice(deviceDesc); + if (!gDevice) + return SLANG_FAIL; + + SLANG_RETURN_ON_FAIL(loadShaderKernels()); + + // Create a buffer to hold all network parameters (weights, biases, gradients). + // This buffer is arranged as following: + // (segment 1): | weights0 | bias0 | weights1 | bias1 | ... | weightsN | biasN | + // (segment 2): | weightsGrad0 | biasGrad0 | weightsGrad1 | biasGrad1 | ... | + // + // Where the first segment contains all weights and biases for each layer in row-major + // layout. The second segment contains gradients for weights and biases in row-major layout. + + // Total size of all network parameters. + size_t paramBufferSize; + + // Offset for the second segment, where gradients for weights and biases in row-major layout + // start. + size_t gradientOffset; + + // Sub-allocated weight/Bias offsets for each layer. + std::vector layerAllocations; + allocateNetworkParameterStorage(layerAllocations, paramBufferSize, gradientOffset); + + std::vector initParams; + srand(1072); + for (int i = 0; i < paramBufferSize / sizeof(NFloat); i++) + { + if (i < gradientOffset / sizeof(NFloat)) + { + float v = rand() / (float)RAND_MAX; + v = v * 2.0f - 1.0f; // Normalize to [-1, 1] + initParams.push_back(floatToHalf(v)); + } + else + { + // Initialize gradients to zero. 
+ initParams.push_back(0); + } + } + auto networkParamsBuffer = createBuffer(paramBufferSize, initParams.data()); + + static const size_t kAdamStateSize = sizeof(NFloat) * 2 + sizeof(int32_t); + auto adamStateBuffer = createBuffer(initParams.size() * kAdamStateSize); + clearBuffer(adamStateBuffer); + + std::vector networkConstantBufferData; + auto paramBufferAddr = networkParamsBuffer->getDeviceAddress(); + for (int i = 0; i < kLayerCount; i++) + { + networkConstantBufferData.push_back( + paramBufferAddr + layerAllocations[i].weightsOffset); + networkConstantBufferData.push_back( + paramBufferAddr + layerAllocations[i].weightsGradOffset); + networkConstantBufferData.push_back(paramBufferAddr + layerAllocations[i].biasOffset); + networkConstantBufferData.push_back( + paramBufferAddr + layerAllocations[i].biasGradOffset); + } + auto networkConstantBuffer = createBuffer( + networkConstantBufferData.size() * sizeof(uint64_t), + networkConstantBufferData.data()); + + static const int inputCount = 32; + std::vector inputBufferData; + for (int i = 0; i < inputCount; i++) + { + inputBufferData.push_back((float)rand() / RAND_MAX); + } + auto inputBuffer = createBuffer(inputCount * sizeof(float), inputBufferData.data()); + + // Create buffer for receiving current loss value. + auto lossBuffer = createBuffer(sizeof(uint64_t)); + + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + + for (int k = 0; k < 1000; k++) + { + clearBuffer(lossBuffer); + + // Compute gradients. + { + LearnGradParams entryPointParams = {}; + entryPointParams.inputs = inputBuffer->getDeviceAddress(); + entryPointParams.count = inputCount / 2; + entryPointParams.lossBuffer = lossBuffer->getDeviceAddress(); + entryPointParams.networkBuffer = networkConstantBuffer->getDeviceAddress(); + dispatchKernel( + gLearnGradProgram, + entryPointParams, + (entryPointParams.count + 255) / 256); + } + // Adjust parameters in row-major buffer (adam optimize). 
+ { + AdjustParamsParams entryPointParams = {}; + entryPointParams.adamStates = adamStateBuffer->getDeviceAddress(); + entryPointParams.params = networkParamsBuffer->getDeviceAddress(); + entryPointParams.count = (paramBufferSize - gradientOffset) / sizeof(NFloat); + entryPointParams.gradients = + networkParamsBuffer->getDeviceAddress() + gradientOffset; + dispatchKernel( + gAdjustParamProgram, + entryPointParams, + (entryPointParams.count + 255) / 256); + } + if ((k + 1) % 10 == 0) + { + queue->waitOnHost(); + ComPtr blob; + gDevice->readBuffer(lossBuffer, 0, sizeof(float), blob.writeRef()); + printf("Loss after %d iterations: %f\n", k + 1, *(float*)blob->getBufferPointer()); + } + } + return SLANG_OK; + } + + // Allocate storage for network parameters, including weights, biases, and gradients. + void allocateNetworkParameterStorage( + std::vector& paramStorage, + size_t& outParamBufferSize, + size_t& outGradientOffset) + { + outParamBufferSize = 0; + + auto allocRowMajorStorage = [&](size_t size) + { + size = (size + 63) / 64 * 64; + size_t offset = outParamBufferSize; + outParamBufferSize += size; + return offset; + }; + + for (int i = 0; i < kLayerCount; i++) + { + size_t biasSize = getNetworkLayerBiasCount(i) * sizeof(NFloat); + NetworkParameterAllocation layer = {}; + layer.weightsSize = getNetworkLayerWeightCount(i) * sizeof(NFloat); + layer.weightsOffset = allocRowMajorStorage(layer.weightsSize); + layer.biasSize = biasSize; + layer.biasOffset = allocRowMajorStorage(biasSize); + paramStorage.push_back(layer); + } + + // Alloc storage for gradients. 
+ outGradientOffset = outParamBufferSize; + for (int i = 0; i < kLayerCount; i++) + { + paramStorage[i].weightsGradOffset = allocRowMajorStorage(paramStorage[i].weightsSize); + paramStorage[i].biasGradOffset = allocRowMajorStorage(paramStorage[i].biasSize); + } + } + + template + void dispatchKernel(Kernel& kernel, Args& args, size_t numWorkGroups) + { + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + ComPtr encoder; + queue->createCommandEncoder(encoder.writeRef()); + { + auto computeEncoder = encoder->beginComputePass(); + auto rootShaderObject = computeEncoder->bindPipeline(kernel.pipeline.get()); + rootShaderObject->getEntryPoint(0)->setData(rhi::ShaderOffset(), &args, sizeof(args)); + computeEncoder->dispatchCompute(numWorkGroups, 1, 1); + computeEncoder->end(); + } + ComPtr commandBuffer; + encoder->finish(commandBuffer.writeRef()); + queue->submit(commandBuffer); + } + + // Create a buffer with the specified size and optional initial data. + ComPtr createBuffer(size_t size, void* initData = nullptr) + { + rhi::BufferDesc bufferDesc = {}; + bufferDesc.size = size; + bufferDesc.defaultState = rhi::ResourceState::UnorderedAccess; + bufferDesc.usage = rhi::BufferUsage::CopySource | rhi::BufferUsage::CopyDestination | + rhi::BufferUsage::UnorderedAccess; + bufferDesc.memoryType = rhi::MemoryType::DeviceLocal; + return gDevice->createBuffer(bufferDesc, initData); + } + + void clearBuffer(rhi::IBuffer* buffer) + { + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + auto encoder = queue->createCommandEncoder(); + encoder->clearBuffer(buffer); + auto cmdBuffer = encoder->finish(); + queue->submit(cmdBuffer); + } + + Kernel loadComputeProgram(slang::IModule* slangModule, char const* entryPointName) + { + ComPtr entryPoint; + slangModule->findEntryPointByName(entryPointName, entryPoint.writeRef()); + + ComPtr linkedProgram; + entryPoint->link(linkedProgram.writeRef()); + + if (isTestMode()) + { + printEntrypointHashes(1, 1, linkedProgram); + } + + 
Kernel result; + + rhi::ComputePipelineDesc desc; + auto program = gDevice->createShaderProgram(linkedProgram); + desc.program = program.get(); + result.program = program; + result.pipeline = gDevice->createComputePipeline(desc); + return result; + } + + inline unsigned short floatToHalf(float val) + { + uint32_t x = 0; + memcpy(&x, &val, sizeof(float)); + + unsigned short bits = (x >> 16) & 0x8000; + unsigned short m = (x >> 12) & 0x07ff; + unsigned int e = (x >> 23) & 0xff; + if (e < 103) + return bits; + if (e > 142) + { + bits |= 0x7c00u; + bits |= e == 255 && (x & 0x007fffffu); + return bits; + } + if (e < 113) + { + m |= 0x0800u; + bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1); + return bits; + } + bits |= ((e - 112) << 10) | (m >> 1); + bits += m & 1; + return bits; + } + + ComPtr createSlangSession(rhi::IDevice* device) + { + ComPtr slangSession = device->getSlangSession(); + return slangSession; + } + + ComPtr compileShaderModuleFromFile( + slang::ISession* slangSession, + char const* filePath) + { + ComPtr slangModule; + ComPtr diagnosticBlob; + Slang::String path = resourceBase.resolveResource(filePath); + slangModule = slangSession->loadModule(path.getBuffer(), diagnosticBlob.writeRef()); + diagnoseIfNeeded(diagnosticBlob); + + return slangModule; + } + + SlangResult loadShaderKernels() + { + Slang::String path = resourceBase.resolveResource("kernels.slang"); + + gSlangSession = createSlangSession(gDevice); + gSlangModule = compileShaderModuleFromFile(gSlangSession, path.getBuffer()); + if (!gSlangModule) + return SLANG_FAIL; + + gLearnGradProgram = loadComputeProgram(gSlangModule, "learnGradient"); + if (!gLearnGradProgram) + return SLANG_FAIL; + + gAdjustParamProgram = loadComputeProgram(gSlangModule, "adjustParameters"); + if (!gAdjustParamProgram) + return SLANG_FAIL; + + return SLANG_OK; + } +}; + +int exampleMain(int argc, char** argv) +{ + ExampleProgram app; + if (SLANG_FAILED(app.execute(argc, argv))) + { + return -1; + } + return 0; +} 
diff --git a/examples/mlp-training/mlp_sw.slang b/examples/mlp-training/mlp_sw.slang new file mode 100644 index 000000000..1e222b99f --- /dev/null +++ b/examples/mlp-training/mlp_sw.slang @@ -0,0 +1,59 @@ +module mlp_sw; + +import common; + +__include mlvec_sw; + +public struct FeedForwardLayer +{ + public NFloat* weights; + public NFloat* weightsGrad; + public NFloat* biases; + public NFloat* biasesGrad; + + [BackwardDerivative(evalBwd)] + public MLVec eval(MLVec input) + { + var output = matMulAdd( + input, + weights, + biases); + // ReLU activation + for (int i = 0; i < OutputSize; i++) + if (output.data[i] < 0.0) + output.data[i] *= 0.001h; + return output; + } + + public void evalBwd( + inout DifferentialPair> input, + MLVec resultGrad) + { + let fwd = eval(input.p); + + // Back-prop resultGrad through activation. + for (int i = 0; i < OutputSize; i++) + { + if (fwd.data[i] < 0.0) + resultGrad.data[i] *= 0.01h; + } + + // Back-prop gradients to the weights matrix. + outerProductAccumulate( + resultGrad, + input.p, + weightsGrad); + + // Back-prop gradients to the biases vector. + for (int i = 0; i < OutputSize; i++) + { + NFloat originalValue; + InterlockedAddF16Emulated(biasesGrad + i, resultGrad.data[i], originalValue); + } + + // Back-prop gradients to the input vector. 
+ let dInput = matMulTransposed(resultGrad, weights); + + input = {input.p, dInput}; + } +} diff --git a/examples/mlp-training/mlvec_sw.slang b/examples/mlp-training/mlvec_sw.slang new file mode 100644 index 000000000..695755706 --- /dev/null +++ b/examples/mlp-training/mlvec_sw.slang @@ -0,0 +1,64 @@ +implementing mlp_sw; + +public struct MLVec : IDifferentiable +{ + public NFloat data[N]; + + [Differentiable] + public NFloat[N] toArray() + { + return data; + } + + [Differentiable] + public static MLVec fromArray(NFloat[N] values) + { + MLVec result; + [ForceUnroll] + for (int i = 0; i < N; i++) + result.data[i] = values[i]; + return result; + } +} + +MLVec matMulAdd(MLVec input, NFloat* matrix, NFloat* bias) +{ + let getMatElem = (int row, int col) => matrix[row*InputSize + col]; + let getBias = (int idx) => bias[idx]; + MLVec result = {}; + for (int i = 0; i < OutputSize; i++) + { + NFloat r = getBias(i); + for (int j = 0; j < InputSize; j++) + r += getMatElem(i, j) * input.data[j]; + result.data[i] = r; + } + return result; +} + +MLVec matMulTransposed(MLVec input, NFloat* matrix) +{ + let getMatElem = (int row, int col) => matrix[col*OutputSize + row]; + MLVec result = {}; + for (int i = 0; i < OutputSize; i++) + { + NFloat r = {}; + for (int j = 0; j < InputSize; j++) + r += getMatElem(i, j) * input.data[j]; + result.data[i] = r; + } + return result; +} + +void outerProductAccumulate(MLVec v0, MLVec v1, NFloat* matrix) +{ + for (int i = 0; i < M; i++) + { + for (int j = 0; j < N; j++) + { + let elem = v0.data[i] * v1.data[j]; + half original; + InterlockedAddF16Emulated(matrix + (i*N + j), elem, original); + } + } +} diff --git a/examples/mlp-training/network.slang b/examples/mlp-training/network.slang new file mode 100644 index 000000000..a48820f11 --- /dev/null +++ b/examples/mlp-training/network.slang @@ -0,0 +1,59 @@ +module network; + +import common; +import mlp_sw; + +public struct MyNetwork +{ + public FeedForwardLayer<4, 16> layer1; + public 
FeedForwardLayer<16, 4> layer2; + + [Differentiable] + internal MLVec<4> encodeInput(NFloat x, NFloat y) + { + return MLVec<4>.fromArray({ + x, + y, + x*x, + y*y, + }); + } + + [Differentiable] + internal MLVec<4> _eval(NFloat x, NFloat y) + { + let encoding = encodeInput(x, y); + let layer1Output = layer1.eval(encoding); + let leyer2Output = layer2.eval(layer1Output); + return leyer2Output; + } + + [Differentiable] + public half4 eval(no_diff NFloat x, no_diff NFloat y) + { + let mlv = _eval(x, y); + let arr = mlv.toArray(); + return half4(arr[0], arr[1], arr[2], arr[3]); + } +} + +[Differentiable] +public half loss(MyNetwork* network, no_diff half x, no_diff half y) +{ + let networkResult = network.eval(x, y); + let gt = no_diff groundtruth(x, y); + let diff = networkResult - gt; + + return dot(diff, diff); +} + +public half4 groundtruth(half x, half y) +{ + return { + (x + y) / (1 + y * y), + 2 * x + y, + 0.5 * x * x + 1.2 * y, + x + 0.5 * y * y, + }; +} + diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index deaeae439..2e56c1082 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1809,12 +1809,12 @@ extension Ptr __init(NativeString nativeStr) { this = nativeStr.getBuffer(); } __generic - __intrinsic_op(0) + __intrinsic_op($(kIROp_BitCast)) __implicit_conversion($(kConversionCost_PtrToVoidPtr)) __init(Ptr ptr); __generic - __intrinsic_op(0) + __intrinsic_op($(kIROp_BitCast)) __implicit_conversion($(kConversionCost_PtrToVoidPtr)) __init(NativeRef ptr); } diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 32d7ea824..38f274984 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -26610,16 +26610,24 @@ CoopVec coopVecMatMulPacked( } } -/// Multiply a cooperative vector with a matrix. -/// @param input The input cooperative vector to multiply with the matrix. +/// Multiply a matrix with a cooperative vector. 
Given a M-row by K-col `matrix`, and a K-element column vector `input`, computes `matrix * input`, and +/// returns a M-element vector. +/// @param input The K-element input cooperative vector to multiply with the matrix. /// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as 8-bit integers, 16-bit floats, etc). -/// @param matrix The matrix to multiply with the input vector. +/// @param matrix The M-by-K matrix to multiply with the input vector. /// @param matrixOffset Byte offset into the matrix buffer. /// @param matrixInterpretation Specifies how to interpret the values in the matrix (e.g. as 8-bit integers, 16-bit floats, etc). /// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major). /// @param transpose Whether to transpose the matrix before multiplication. /// @param matrixStride The stride in bytes between rows/columns of the matrix. /// @return A new cooperative vector containing the result of the matrix multiplication. +/// @remarks Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported. +/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to +/// find out which interpretations are supported. +/// +/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`. +/// Not all component types support transposing. +/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored. [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] @@ -26650,7 +26658,9 @@ CoopVec coopVecMatMul( matrixStride); } -/// Multiply a cooperative vector with a matrix and add a bias vector. +/// Multiply a matrix with a cooperative vector and add a bias vector to the result. 
+/// Given a M-row by K-col `matrix`, a K-element column vector `input`, and a M-element vector `bias`, computes `matrix * input + bias`, and +/// returns a M-element vector. /// @param input The input cooperative vector to multiply with the matrix. /// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values). /// @param k The number of columns in the matrix. @@ -26667,6 +26677,14 @@ CoopVec coopVecMatMul( /// @remarks Unlike coopVecMatMulAdd, this function supports packed input interpretations where multiple values /// can be packed into each element of the input vector. The k parameter specifies the actual number of /// values to use from the packed input. +/// +/// Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported. +/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to +/// find out which interpretations are supported. +/// +/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`. +/// Not all component types support transposing. +/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored. 
[ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] @@ -26804,7 +26822,9 @@ CoopVec coopVecMatMulAddPacked coopVecMatMulAddPacked CoopVec coopVecMatMulPacked( @@ -27288,6 +27323,185 @@ void coopVecReduceSumAccumulate +CoopVec coopVecMatMulPacked( + CoopVec input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + void* matrixPtr, + constexpr CoopVecComponentType matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + __target_switch + { + case spirv: + let m : int32_t = M; + let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); + let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); + let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); + int operands = 0; // NoneKHR + let zero = 0; + let cvtMatPtr = (Ptr)matrixPtr; + if (__isSignedInt()) + { + operands |= 0x08; // MatrixResultSignedComponentsKHR + } + return spirv_asm + { + result:$$CoopVec = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; + }; + } +} + +// specialized coopVecMatMul for non-packed inputs +[require(spirv, cooperative_vector)] +__generic +CoopVec coopVecMatMul( + CoopVec input, + constexpr CoopVecComponentType inputInterpretation, + void* matrix, + constexpr CoopVecComponentType matrixInterpretation, + constexpr 
CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + return coopVecMatMulPacked< + T, M, K, U>( + input, + inputInterpretation, + K, + matrix, + matrixInterpretation, + memoryLayout, + transpose, + matrixStride); +} + +[require(spirv, cooperative_vector)] +CoopVec coopVecMatMulAddPacked( + CoopVec input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + void* matrixPtr, + constexpr CoopVecComponentType matrixInterpretation, + void* biasPtr, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + + __target_switch + { + case spirv: + let m : int32_t = M; + let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); + let biasInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(biasInterpretation); + let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); + let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); + let zero : int32_t = 0; + let cvtMatPtr = (Ptr)matrixPtr; + let cvtBiasPtr = (Ptr)biasPtr; + return spirv_asm + { + result:$$CoopVec = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $cvtBiasPtr $zero 
$biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride; + }; + } +} + +[require(spirv, cooperative_vector)] +CoopVec coopVecMatMulAdd( + CoopVec input, + constexpr CoopVecComponentType inputInterpretation, + void* matrix, + constexpr CoopVecComponentType matrixInterpretation, + void* bias, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + return coopVecMatMulAddPacked< + T, M, K, U>( + input, + inputInterpretation, + K, + matrix, + matrixInterpretation, + bias, + biasInterpretation, + memoryLayout, + transpose, + matrixStride); +} + +[require(spirv, cooperative_vector_training)] +void coopVecOuterProductAccumulate( + CoopVec a, + CoopVec b, + void* matrixPtr, + constexpr uint matrixStride, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr CoopVecComponentType matrixInterpretation, +) +{ + let zero : int32_t = 0; + __target_switch + { + case spirv: + let matrixInterpretationSpirv : int = __getSpvCoopVecComponentType(matrixInterpretation); + let memoryLayoutSpirv : int = __getSpvCoopVecMatrixLayout(memoryLayout); + let cvtMatrixPtr = (Ptr)matrixPtr; + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorOuterProductAccumulateNV $cvtMatrixPtr $zero $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride; + }; + } +} + +[require(spirv, cooperative_vector_training)] +void coopVecReduceSumAccumulate( + CoopVec v, + void* buffer +) +{ + let zero : int32_t = 0; + let bufferPtr = (Ptr)(buffer); + __target_switch + { + case spirv: + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorReduceSumAccumulateNV $bufferPtr $zero $v; + }; + } +} + //@public: /// Mark a variable as being workgroup uniform. 
@@ -28126,3 +28340,24 @@ uint packHalf2x16(half2 unpackedValue) { return packHalf2x16(float2(unpackedValue)); } + +[require(spirv)] +void InterlockedAddF16Emulated(half* dest, half value, out half originalValue) +{ + let buf = (half2*)(dest); + uint64_t byteAddress = (uint64_t)dest; + if ((byteAddress & 3) == 0) + { + originalValue = __atomic_add(*buf, half2(value, half(0.0))).x; + } + else + { + originalValue = __atomic_add(*buf, half2(half(0.0), value)).y; + } +} + +[require(spirv)] +void InterlockedAddF16x2(half2* dest, half2 value, out half2 originalValue) +{ + originalValue = __atomic_add(*dest, value); +} \ No newline at end of file diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 88dea0b7e..56f9d873a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -187,6 +187,13 @@ class SynthesizedStaticLambdaFuncModifier : public Modifier FIDDLE(...) }; +FIDDLE() +class ExplicitlyDeclaredCapabilityModifier : public Modifier +{ + FIDDLE(...) + FIDDLE() CapabilitySet declaredCapabilityRequirements; +}; + // Marks a synthesized variable as local temporary variable. 
FIDDLE() class LocalTempVarModifier : public Modifier diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 506abc1be..6456dbe98 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -419,6 +419,9 @@ bool SemanticsVisitor::createInvokeExprForSynthesizedCtor( if (!structDecl) return false; + if (!structDecl->checkState.isBeingChecked()) + ensureDecl(structDecl, DeclCheckState::AttributesChecked); + HashSet isVisit; bool isCStyle = false; if (!_getSynthesizedConstructor( @@ -656,8 +659,8 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList( auto toMakeArrayFromElementExpr = m_astBuilder->create(); toMakeArrayFromElementExpr->loc = fromInitializerListExpr->loc; toMakeArrayFromElementExpr->type = QualType(toType); - - *outToExpr = toMakeArrayFromElementExpr; + if (outToExpr) + *outToExpr = toMakeArrayFromElementExpr; return true; } for (UInt ee = 0; ee < elementCount; ++ee) @@ -748,8 +751,8 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList( auto defaultConstructExpr = m_astBuilder->create(); defaultConstructExpr->loc = fromInitializerListExpr->loc; defaultConstructExpr->type = QualType(toType); - - *outToExpr = defaultConstructExpr; + if (outToExpr) + *outToExpr = defaultConstructExpr; return true; } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 1a70e25d7..0dd859bb2 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2914,9 +2914,9 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness context->parentDecl->findLastDirectMemberDeclOfName(requirementDeclRef.getName())) { // Remove the `ToBeSynthesizedModifier`. 
- if (as(existingDecl->modifiers.first)) + if (auto mod = existingDecl->modifiers.findModifier()) { - existingDecl->modifiers.first = existingDecl->modifiers.first->next; + removeModifier(existingDecl, mod); } else { @@ -3133,14 +3133,9 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness addModifier(aggTypeDecl, m_astBuilder->create()); - // The visibility of synthesized decl should be the min of the parent decl and the requirement. - if (requirementDeclRef.getDecl()->findModifier()) - { - auto requirementVisibility = getDeclVisibility(requirementDeclRef.getDecl()); - auto thisVisibility = getDeclVisibility(context->parentDecl); - auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(aggTypeDecl, visibility); - } + // The visibility of synthesized decl should be the same of the parent decl. + auto thisVisibility = getDeclVisibility(context->parentDecl); + addVisibilityModifier(aggTypeDecl, thisVisibility); // Synthesize the rest of IDifferential method conformances by recursively checking // conformance on the synthesized decl. @@ -4149,8 +4144,12 @@ bool SemanticsVisitor::doesVarMatchRequirement( return false; } - auto satisfyingVal = - tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr); + IntVal* satisfyingVal = nullptr; + if (isValidCompileTimeConstantType(satisfyingType)) + { + satisfyingVal = + tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr); + } if (satisfyingVal) { witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingVal)); @@ -5125,9 +5124,9 @@ void SemanticsVisitor::markOverridingDecl( return; } + memberDecl = maybeGetInner(memberDecl); if (hasDefaultImpl(requiredMemberDeclRef)) { - memberDecl = maybeGetInner(memberDecl); // If the required member has a default implementation, // we need to make sure the member we found is marked as 'override'. 
if (!memberDecl->hasModifier()) @@ -14290,6 +14289,9 @@ void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* fun } else { + auto declaredCapModifier = m_astBuilder->create(); + declaredCapModifier->declaredCapabilityRequirements = declaredCaps; + addModifier(funcDecl, declaredCapModifier); if (vis == DeclVisibility::Public) { // For public decls, we need to enforce that the function diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 306687bd8..b90081af8 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -682,6 +682,8 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( auto typeDef = m_astBuilder->create(); typeDef->nameAndLoc.name = getName("Differential"); typeDef->parentDecl = structDecl; + addVisibilityModifier(structDecl, getDeclVisibility(parent)); + addVisibilityModifier(typeDef, getDeclVisibility(parent)); auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); @@ -714,6 +716,7 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( typeDef->type.type = calcThisType(subType->getDeclRef().getDecl()->getDefaultDeclRef()); + addVisibilityModifier(typeDef, getDeclVisibility(parent)); synthesizedDecl = parent; parent->addDirectMemberDecl(typeDef); @@ -2085,7 +2088,7 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef( // to not allow such cases. // // Note that float-to-inst casts for non-`IntVal`s are allowed. 
- if (!isScalarIntegerType(decl->getType())) + if (!isValidCompileTimeConstantType(decl->getType())) { getSink()->diagnose(declRef, Diagnostics::intValFromNonIntSpecConstEncountered); return nullptr; diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index e6744071b..a360361f7 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -553,6 +553,16 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) targetOptionSet.hasOption(CompilerOptionName::Capability) && (targetOptionSet.getIntOption(CompilerOptionName::Capability) != SLANG_CAPABILITY_UNKNOWN); + + if (auto declaredCapsMod = + entryPointFuncDecl->findModifier()) + { + // If the entry point has an explicitly declared capability, then we + // will merge that with the target capability set before checking if + // there is an implicit upgrade. + targetCaps.nonDestructiveJoin(declaredCapsMod->declaredCapabilityRequirements); + } + // Only attempt to error if a specific profile or capability is requested if ((specificCapabilityRequested || specificProfileRequested) && targetCaps.atLeastOneSetImpliedInOther( diff --git a/tests/language-feature/capability/capabilitySimplification2.slang b/tests/language-feature/capability/capabilitySimplification2.slang index 8d96884ce..d50960b8e 100644 --- a/tests/language-feature/capability/capabilitySimplification2.slang +++ b/tests/language-feature/capability/capabilitySimplification2.slang @@ -1,27 +1,27 @@ -//TEST:SIMPLE(filecheck=SPIRV): -target spirv -emit-spirv-directly -entry computeMain -stage compute -profile sm_5_0 +//TEST:SIMPLE(filecheck=SPIRV): -target spirv -emit-spirv-directly -entry computeMain -stage compute -profile spirv_1_3 //TEST:SIMPLE(filecheck=GLSL): -target glsl -entry computeMain -stage compute -profile sm_5_0 //TEST:SIMPLE(filecheck=HLSL): -target hlsl -entry computeMain -stage compute -profile sm_5_0 -//TEST:SIMPLE(filecheck=CHECK_IGNORE_CAPS): -target spirv 
-emit-spirv-directly -entry computeMain -stage compute -profile sm_5_0 -ignore-capabilities +//TEST:SIMPLE(filecheck=CHECK_IGNORE_CAPS): -target spirv -emit-spirv-directly -entry computeMain -stage compute -profile spirv_1_3 -ignore-capabilities // CHECK_IGNORE_CAPS-NOT: warning 41012 // SPIRV: warning 41012 // SPIRV-NOT: spirv_1_2 -// SPIRV-NOT: spirv_1_3 // SPIRV-SAME: spvGroupNonUniformBallot // GLSL: warning 41012 // GLSL-NOT: GLSL_400 // GLSL-NOT: GLSL_430 -// GLSL-SAME: GL_KHR_shader_subgroup_ballot +// GLSL-SAME: GL_KHR_shader_subgroup_basic // HLSL: warning 41012 // HLSL-NOT: sm_5_1 // HLSL-SAME: sm_6_0 -[require(sm_6_0)] [numthreads(1,1,1)] void computeMain() { + if (WaveIsFirstLane()) + return; } -- cgit v1.2.3