diff options
| author | Yong He <yonghe@outlook.com> | 2025-06-30 14:32:50 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-06-30 21:32:50 +0000 |
| commit | f28f67d988158d6c46f7ffe967152f98d32a37b2 (patch) | |
| tree | 2aa620986a87ec69cf1f210c714312e42b62ac9e | |
| parent | a55ff722cae338a8fcf5402858c47cf0650a8e5e (diff) | |
Add MLP training examples. (#7550)
* Add MLP training examples.
* Formatting fix.
* Fix.
* Improve documentation on coopvector.
* Improve doc.
* Update doc.
* Fix typo.
* Cleanup shader.
* Cleanup.
* Fix test.
* Fix type check recursion.
* Fix.
* Fix.
* Fix override check.
25 files changed, 1710 insertions, 46 deletions
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 9411d10ba..87b0b39f8 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -114,4 +114,9 @@ if(SLANG_ENABLE_EXAMPLES) if(SLANG_ENABLE_AFTERMATH) example(nv-aftermath-example WIN32_EXECUTABLE) endif() + + if(SLANG_ENABLE_SLANG_RHI) + example(mlp-training LINK_WITH_PRIVATE slang-rhi) + example(mlp-training-coopvec LINK_WITH_PRIVATE slang-rhi) + endif() endif() diff --git a/examples/mlp-training-coopvec/README.md b/examples/mlp-training-coopvec/README.md new file mode 100644 index 000000000..8d56fff92 --- /dev/null +++ b/examples/mlp-training-coopvec/README.md @@ -0,0 +1,6 @@ +Slang "MLP-Training-CoopVec" Example +========================== + +This example shows how to use Slang to train a feed-forward neural network +using automatic differentiation and the cooperative vector intrinsics. Also see the +"MLP-Training" example for the same task without using cooperative vectors.
\ No newline at end of file diff --git a/examples/mlp-training-coopvec/adam.slang b/examples/mlp-training-coopvec/adam.slang new file mode 100644 index 000000000..33a357122 --- /dev/null +++ b/examples/mlp-training-coopvec/adam.slang @@ -0,0 +1,38 @@ +module adam; + +import mlp; +import common; + +public struct AdamState +{ + internal NFloat mean; + internal NFloat variance; + internal int iteration; +} + +public struct AdamOptimizer +{ + // Adam parameters + public static const NFloat beta1 = 0.9h; + public static const NFloat beta2 = 0.999h; + public static const NFloat epsilon = 1e-7h; + public static const NFloat learningRate = 0.01h; + + public static void step(inout AdamState state, inout NFloat param, inout NFloat grad) + { + state.iteration++; + if (isinf(grad)) + { + if (grad > 0) + grad = 10000.0h; + else + grad = -10000.0h; + } + state.mean = beta1 * state.mean + (NFloat(1.f) - beta1) * grad; + state.variance = beta2 * state.variance + (NFloat(1.f) - beta2) * grad * grad; + NFloat meanHat = state.mean / (NFloat(1.f) - pow(beta1, NFloat(state.iteration))); + NFloat varianceHat = state.variance / (NFloat(1.f) - pow(beta2, NFloat(state.iteration))); + param -= learningRate * meanHat / (sqrt(max(NFloat(0.f), varianceHat) + epsilon)); + grad = NFloat(0.f); + } +} diff --git a/examples/mlp-training-coopvec/common.slang b/examples/mlp-training-coopvec/common.slang new file mode 100644 index 000000000..92dc3b563 --- /dev/null +++ b/examples/mlp-training-coopvec/common.slang @@ -0,0 +1 @@ +public typealias NFloat = half;
\ No newline at end of file diff --git a/examples/mlp-training-coopvec/kernels.slang b/examples/mlp-training-coopvec/kernels.slang new file mode 100644 index 000000000..712494b1f --- /dev/null +++ b/examples/mlp-training-coopvec/kernels.slang @@ -0,0 +1,41 @@ +module kernels; + +import common; +import mlp; +import network; +import adam; + +[numthreads(256, 1, 1)] +[require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic, spvCooperativeVectorNV)] +void learnGradient( + uint32_t tid : SV_DispatchThreadID, + uniform MyNetwork* network, + uniform Atomic<uint32_t>* lossBuffer, + uniform float2* inputs, + uniform uint32_t count) +{ + if (tid >= count) + return; + + var input = (half2)inputs[tid]; + bwd_diff(loss)(network, input.x, input.y, 1.0h); + let thisLoss = (float)loss(network, input.x, input.y); + let maxLoss = WaveActiveMax(thisLoss); + if (WaveIsFirstLane()) + { + lossBuffer.max(bit_cast<uint32_t>(maxLoss)); + } +} + +[numthreads(256, 1, 1)] +void adjustParameters(uint32_t tid : SV_DispatchThreadID, uniform AdamState* states, uniform NFloat* params, uniform NFloat* gradients, uniform uint32_t count) +{ + if (tid >= count) + return; + if (isnan(gradients[tid])) + { + gradients[tid] = 0.0h; + return; + } + AdamOptimizer::step(states[tid], params[tid], gradients[tid]); +}
\ No newline at end of file diff --git a/examples/mlp-training-coopvec/mlp-training-coopvec.cpp b/examples/mlp-training-coopvec/mlp-training-coopvec.cpp new file mode 100644 index 000000000..9b5dde531 --- /dev/null +++ b/examples/mlp-training-coopvec/mlp-training-coopvec.cpp @@ -0,0 +1,462 @@ +// In this example, we implement a simple multi-layer perceptron (MLP) training loop on +// Vulkan (through slang-rhi), using cooperative vector intrinsics. +// +// The simple MLP is trained to approximate a polynomial expression. +// The network contains one hidden layer with 16 neurons. It takes 4 inputs and produces 4 +// outputs. + +#include "core/slang-basic.h" +#include "examples/example-base/example-base.h" +#include "external/slang-rhi/include/slang-rhi.h" +#include "slang-com-ptr.h" +#include "slang.h" + +static const ExampleResources resourceBase("mlp-training-coopvec"); + +typedef uint16_t NFloat; + +// Define the sizes of the layers in the MLP. +static const int kLayerSizes[] = {4, 16, 4}; +static const int kLayerCount = sizeof(kLayerSizes) / sizeof(int) - 1; + +using Slang::ComPtr; + +struct Kernel +{ + ComPtr<rhi::IShaderProgram> program; + ComPtr<rhi::IComputePipeline> pipeline; + explicit operator bool() { return program && pipeline; } +}; + +struct ClearBufferParams +{ + rhi::DeviceAddress buffer; + uint32_t count; +}; + +struct LearnGradParams +{ + rhi::DeviceAddress networkBuffer; + rhi::DeviceAddress lossBuffer; + rhi::DeviceAddress inputs; + uint32_t count; +}; + +struct AdjustParamsParams +{ + rhi::DeviceAddress adamStates; + rhi::DeviceAddress params; + rhi::DeviceAddress gradients; + uint32_t count; +}; + +struct ExampleProgram : public TestBase +{ + ComPtr<rhi::IDevice> gDevice; + + ComPtr<slang::ISession> gSlangSession; + ComPtr<slang::IModule> gSlangModule; + Kernel gLearnGradProgram; + Kernel gAdjustParamProgram; + + // Sub-allocated buffer range for each network layer's parameters (weights, biases, gradients). 
+ // + struct NetworkParameterAllocation + { + size_t weightsOffset; + size_t weightsSize; + size_t biasOffset; + size_t biasSize; + size_t weightsGradOffset; + size_t biasGradOffset; + size_t weightsGradTrainingOffset; + size_t weightsGradTrainingSize; + }; + + SlangResult execute(int argc, char* argv[]) + { + parseOption(argc, argv); + + rhi::DeviceDesc deviceDesc; + deviceDesc.slang.targetProfile = "spirv_1_6"; + deviceDesc.deviceType = rhi::DeviceType::Vulkan; + + gDevice = rhi::getRHI()->createDevice(deviceDesc); + if (!gDevice) + return SLANG_FAIL; + + SLANG_RETURN_ON_FAIL(loadShaderKernels()); + + // Create a buffer to hold all network parameters (weights, biases, gradients). + // This buffer is arranged as following: + // (segment 1): | weights0 | bias0 | weights1 | bias1 | ... | weightsN | biasN | + // (segment 2): | weightsGrad0 | biasGrad0 | weightsGrad1 | biasGrad1 | ... | + // (segment 3): | weightsGradTraining0 | weightsGradTraining1 | ... | + // + // Where the first segment contains all weights and biases for each layer in row-major + // layout. The second segment contains gradients for weights and biases in row-major layout. + // The third segment contains gradients for weights in training-optimal layout. + // The training-optimal layout is used to accumulate gradients for weights with the + // `coopVecOuterProductAccumulate` intrinsic, which requires the destination to be in + // training-optimal layout. + // After accumulating gradients, we will convert them to row-major layout (i.e. copy + // them back into the second segment) so we can read them in the optimization kernel. + + // Total size of all network parameters. + size_t networkParamsBufferSize; + + // Offset for the second segment, where gradients for weights and biases in row-major layout + // start. + size_t networkGraidentOffset; + + // Offset for the third segment, where gradients for weights in training-optimal layout + // start. 
+ size_t networkGradientTrainingOffset; + + // Sub-allocated weight/bias offsets for each layer. + std::vector<NetworkParameterAllocation> layerAllocations; + + // Allocate storage for network parameters, filling in `layerAllocations`, + // `networkParamsBufferSize`, `networkGraidentOffset` and `networkGradientTrainingOffset`. + // + allocateNetworkParameterStorage( + layerAllocations, + networkParamsBufferSize, + networkGraidentOffset, + networkGradientTrainingOffset); + + // We'll initialize the buffer with random values in the range [-1, 1]. + std::vector<uint16_t> initParams; + srand(1072); + for (int i = 0; i < networkParamsBufferSize / sizeof(NFloat); i++) + { + float v = rand() / (float)RAND_MAX; + v = v * 2.0f - 1.0f; // Normalize to [-1, 1] + initParams.push_back(floatToHalf(v)); + } + auto networkParamsBuffer = createBuffer(networkParamsBufferSize, initParams.data()); + + // Create a buffer for holding the Adam optimizer state for each network parameter. + static const size_t kAdamStateSize = sizeof(NFloat) * 2 + sizeof(int32_t); + auto adamStateBuffer = createBuffer(initParams.size() * kAdamStateSize); + clearBuffer(adamStateBuffer); + + // Prepare buffer for the `network` struct that holds pointers to network parameters for + // each layer. 
+ std::vector<uint64_t> networkConstantBufferData; + for (int i = 0; i < kLayerCount; i++) + { + networkConstantBufferData.push_back( + networkParamsBuffer->getDeviceAddress() + layerAllocations[i].weightsOffset); + networkConstantBufferData.push_back( + networkParamsBuffer->getDeviceAddress() + + layerAllocations[i].weightsGradTrainingOffset); + networkConstantBufferData.push_back( + networkParamsBuffer->getDeviceAddress() + layerAllocations[i].biasOffset); + networkConstantBufferData.push_back( + networkParamsBuffer->getDeviceAddress() + layerAllocations[i].biasGradOffset); + } + auto networkConstantBuffer = createBuffer( + networkConstantBufferData.size() * sizeof(uint64_t), + networkConstantBufferData.data()); + + // Create buffer for input data. + static const int inputCount = 32; + std::vector<float> inputBufferData; + for (int i = 0; i < inputCount; i++) + { + inputBufferData.push_back((float)rand() / RAND_MAX); + } + auto inputBuffer = createBuffer(inputCount * sizeof(float), inputBufferData.data()); + + // Create buffer for receiving current loss value. + auto lossBuffer = createBuffer(sizeof(uint64_t)); + + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + + // Run training loop. + for (int k = 0; k < 1000; k++) + { + clearBuffer(lossBuffer); + + // Clear weight gradients in the parameter buffer to 0. + clearBuffer( + networkParamsBuffer, + rhi::BufferRange{ + networkGradientTrainingOffset, + networkParamsBufferSize - networkGradientTrainingOffset}); + // Compute gradients for weights and biases. + // The weight gradients are stored in the training-optimal layout. 
+ { + LearnGradParams entryPointParams = {}; + entryPointParams.inputs = inputBuffer->getDeviceAddress(); + entryPointParams.count = inputCount / 2; + entryPointParams.lossBuffer = lossBuffer->getDeviceAddress(); + entryPointParams.networkBuffer = networkConstantBuffer->getDeviceAddress(); + dispatchKernel( + gLearnGradProgram, + entryPointParams, + (entryPointParams.count + 255) / 256); + } + // Copy weight gradients from training-optimal layout to row-major layout, + // so we can read them in the `adjustParameters` kernel. + { + std::vector<rhi::ConvertCooperativeVectorMatrixDesc> matrixDescs; + for (int i = 0; i < kLayerCount; i++) + { + rhi::ConvertCooperativeVectorMatrixDesc desc = {}; + desc.rowCount = kLayerSizes[i + 1]; + desc.colCount = kLayerSizes[i]; + desc.dstComponentType = rhi::CooperativeVectorComponentType::Float16; + desc.dstSize = &layerAllocations[i].weightsSize; + desc.dstData.deviceAddress = networkParamsBuffer->getDeviceAddress() + + layerAllocations[i].weightsGradOffset; + desc.dstLayout = rhi::CooperativeVectorMatrixLayout::RowMajor; + desc.dstStride = getNetworkLayerWeightStride(i); + desc.srcComponentType = rhi::CooperativeVectorComponentType::Float16; + desc.srcSize = layerAllocations[i].weightsGradTrainingSize; + desc.srcData.deviceAddress = networkParamsBuffer->getDeviceAddress() + + layerAllocations[i].weightsGradTrainingOffset; + desc.srcLayout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal; + matrixDescs.push_back(desc); + } + auto encoder = queue->createCommandEncoder(); + encoder->convertCooperativeVectorMatrix( + matrixDescs.data(), + (uint32_t)matrixDescs.size()); + ComPtr<rhi::ICommandBuffer> commandBuffer; + encoder->finish(commandBuffer.writeRef()); + queue->submit(commandBuffer); + } + // Adjust parameters in row-major buffer (adam optimize). 
+ { + AdjustParamsParams entryPointParams = {}; + entryPointParams.adamStates = adamStateBuffer->getDeviceAddress(); + entryPointParams.params = networkParamsBuffer->getDeviceAddress(); + entryPointParams.count = + (networkGradientTrainingOffset - networkGraidentOffset) / sizeof(NFloat); + entryPointParams.gradients = + networkParamsBuffer->getDeviceAddress() + networkGraidentOffset; + dispatchKernel( + gAdjustParamProgram, + entryPointParams, + (entryPointParams.count + 255) / 256); + } + + // Print loss value every 10 iterations. + if ((k + 1) % 10 == 0) + { + queue->waitOnHost(); + ComPtr<ISlangBlob> blob; + gDevice->readBuffer(lossBuffer, 0, sizeof(float), blob.writeRef()); + printf("Loss after %d iterations: %f\n", k + 1, *(float*)blob->getBufferPointer()); + } + } + return SLANG_OK; + } + + // Allocate storage for network parameters, including weights, biases, and gradients. + void allocateNetworkParameterStorage( + std::vector<NetworkParameterAllocation>& paramStorage, + size_t& outParamBufferSize, + size_t& outGradientOffset, + size_t& outGradientTrainingOffset) + { + outParamBufferSize = 0; + + auto allocRowMajorStorage = [&](size_t size) + { + size = (size + 63) / 64 * 64; + size_t offset = outParamBufferSize; + outParamBufferSize += size; + return offset; + }; + + for (int i = 0; i < kLayerCount; i++) + { + size_t biasSize = getNetworkLayerBiasCount(i) * sizeof(NFloat); + NetworkParameterAllocation layerStorage = {}; + layerStorage.weightsSize = getNetworkLayerWeightCount(i) * sizeof(NFloat); + layerStorage.weightsOffset = allocRowMajorStorage(layerStorage.weightsSize); + layerStorage.biasSize = biasSize; + layerStorage.biasOffset = allocRowMajorStorage(biasSize); + paramStorage.push_back(layerStorage); + } + + // Alloc storage for weight and bias gradients (row major layout). 
+ outGradientOffset = outParamBufferSize; + for (int i = 0; i < kLayerCount; i++) + { + paramStorage[i].weightsGradOffset = allocRowMajorStorage(paramStorage[i].weightsSize); + paramStorage[i].biasGradOffset = allocRowMajorStorage(paramStorage[i].biasSize); + } + + // Alloc training-optimal storage for weight gradients. + outGradientTrainingOffset = outParamBufferSize; + for (int i = 0; i < kLayerCount; i++) + { + // Allocate space for gradients in training-optimal layout. + rhi::ConvertCooperativeVectorMatrixDesc matrixDesc = {}; + matrixDesc.srcComponentType = rhi::CooperativeVectorComponentType::Float16; + matrixDesc.srcSize = paramStorage[i].weightsSize; + matrixDesc.srcData.hostAddress = nullptr; + matrixDesc.srcLayout = rhi::CooperativeVectorMatrixLayout::RowMajor; + matrixDesc.srcStride = getNetworkLayerWeightStride(i); + matrixDesc.dstComponentType = rhi::CooperativeVectorComponentType::Float16; + matrixDesc.dstSize = ¶mStorage[i].weightsGradTrainingSize; + matrixDesc.dstData.hostAddress = nullptr; + matrixDesc.dstLayout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal; + matrixDesc.dstStride = 0; + matrixDesc.rowCount = kLayerSizes[i + 1]; + matrixDesc.colCount = kLayerSizes[i]; + gDevice->convertCooperativeVectorMatrix(&matrixDesc, 1); + paramStorage[i].weightsGradTrainingOffset = + allocRowMajorStorage(paramStorage[i].weightsGradTrainingSize); + } + } + + // Dispatch a compute kernel with the given arguments and number of work groups. 
+ template<typename Args> + void dispatchKernel(Kernel& kernel, Args& args, size_t numWorkGroups) + { + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + ComPtr<rhi::ICommandEncoder> encoder; + queue->createCommandEncoder(encoder.writeRef()); + { + auto computeEncoder = encoder->beginComputePass(); + auto rootShaderObject = computeEncoder->bindPipeline(kernel.pipeline.get()); + rootShaderObject->getEntryPoint(0)->setData(rhi::ShaderOffset(), &args, sizeof(args)); + computeEncoder->dispatchCompute(numWorkGroups, 1, 1); + computeEncoder->end(); + } + ComPtr<rhi::ICommandBuffer> commandBuffer; + encoder->finish(commandBuffer.writeRef()); + queue->submit(commandBuffer); + } + + // Create a buffer with the specified size and optional initial data. + ComPtr<rhi::IBuffer> createBuffer(size_t size, void* initData = nullptr) + { + rhi::BufferDesc bufferDesc = {}; + bufferDesc.size = size; + bufferDesc.defaultState = rhi::ResourceState::UnorderedAccess; + bufferDesc.usage = rhi::BufferUsage::CopySource | rhi::BufferUsage::CopyDestination | + rhi::BufferUsage::UnorderedAccess; + bufferDesc.memoryType = rhi::MemoryType::DeviceLocal; + return gDevice->createBuffer(bufferDesc, initData); + } + + void clearBuffer(rhi::IBuffer* buffer, rhi::BufferRange range = rhi::kEntireBuffer) + { + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + auto encoder = queue->createCommandEncoder(); + encoder->clearBuffer(buffer, range); + auto cmdBuffer = encoder->finish(); + queue->submit(cmdBuffer); + } + + SlangResult loadShaderKernels() + { + Slang::String path = resourceBase.resolveResource("kernels.slang"); + + gSlangSession = createSlangSession(gDevice); + gSlangModule = compileShaderModuleFromFile(gSlangSession, path.getBuffer()); + if (!gSlangModule) + return SLANG_FAIL; + + gLearnGradProgram = loadComputeProgram(gSlangModule, "learnGradient"); + if (!gLearnGradProgram) + return SLANG_FAIL; + + gAdjustParamProgram = loadComputeProgram(gSlangModule, "adjustParameters"); + 
if (!gAdjustParamProgram) + return SLANG_FAIL; + + return SLANG_OK; + } + + Kernel loadComputeProgram(slang::IModule* slangModule, char const* entryPointName) + { + ComPtr<slang::IEntryPoint> entryPoint; + slangModule->findEntryPointByName(entryPointName, entryPoint.writeRef()); + + ComPtr<slang::IComponentType> linkedProgram; + entryPoint->link(linkedProgram.writeRef()); + + if (isTestMode()) + { + printEntrypointHashes(1, 1, linkedProgram); + } + + Kernel result; + + rhi::ComputePipelineDesc desc; + auto program = gDevice->createShaderProgram(linkedProgram); + desc.program = program.get(); + result.program = program; + result.pipeline = gDevice->createComputePipeline(desc); + return result; + } + + static inline unsigned short floatToHalf(float val) + { + uint32_t x = 0; + memcpy(&x, &val, sizeof(float)); + + unsigned short bits = (x >> 16) & 0x8000; + unsigned short m = (x >> 12) & 0x07ff; + unsigned int e = (x >> 23) & 0xff; + if (e < 103) + return bits; + if (e > 142) + { + bits |= 0x7c00u; + bits |= e == 255 && (x & 0x007fffffu); + return bits; + } + if (e < 113) + { + m |= 0x0800u; + bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1); + return bits; + } + bits |= ((e - 112) << 10) | (m >> 1); + bits += m & 1; + return bits; + } + + int getNetworkLayerWeightStride(int i) { return kLayerSizes[i] * sizeof(NFloat); } + + int getNetworkLayerWeightCount(int i) { return kLayerSizes[i] * kLayerSizes[i + 1]; } + + int getNetworkLayerBiasCount(int i) { return kLayerSizes[i + 1]; } + + ComPtr<slang::ISession> createSlangSession(rhi::IDevice* device) + { + ComPtr<slang::ISession> slangSession = device->getSlangSession(); + return slangSession; + } + + ComPtr<slang::IModule> compileShaderModuleFromFile( + slang::ISession* slangSession, + char const* filePath) + { + ComPtr<slang::IModule> slangModule; + ComPtr<slang::IBlob> diagnosticBlob; + Slang::String path = resourceBase.resolveResource(filePath); + slangModule = slangSession->loadModule(path.getBuffer(), 
diagnosticBlob.writeRef()); + diagnoseIfNeeded(diagnosticBlob); + + return slangModule; + } +}; + +int exampleMain(int argc, char** argv) +{ + ExampleProgram app; + if (SLANG_FAILED(app.execute(argc, argv))) + { + return -1; + } + return 0; +} diff --git a/examples/mlp-training-coopvec/mlp.slang b/examples/mlp-training-coopvec/mlp.slang new file mode 100644 index 000000000..44380ad77 --- /dev/null +++ b/examples/mlp-training-coopvec/mlp.slang @@ -0,0 +1,73 @@ +module mlp; + +import common; + +__include mlvec; + +// We use Float16 for the CoopVec component type since it is more widely supported. +// +static const CoopVecComponentType kComponentType = CoopVecComponentType.Float16; + +public struct FeedForwardLayer<int InputSize, int OutputSize> +{ + internal void* weights; + internal void* weightsGrad; + internal void* biases; + internal void* biasesGrad; + + public MLVec<OutputSize> eval(MLVec<InputSize> input) + { + // Compute mul(weights, inputVec) + biases. + // `weights` is treated as an OutputSize(row) x InputSize(col) matrix. + var output = coopVecMatMulAdd<NFloat, OutputSize>( + input.data, kComponentType, // input and format + weights, kComponentType, // weights and format + biases, kComponentType, // biases and format + CoopVecMatrixLayout.RowMajor, // matrix layout + false, // transpose matrix? must be `false` since we specified RowMajor. + InputSize * sizeof(NFloat)); // matrix stride + output = max(output, output * 0.001h); // Leaky ReLU activation + return {output}; + } + + [BackwardDerivativeOf(eval)] + public void evalBwd( + inout DifferentialPair<MLVec<InputSize>> input, + MLVec<OutputSize> resultGrad) + { + let fwd = eval(input.p); + + // Back-prop resultGrad through activation. + [ForceUnroll] + for (int i = 0; i < OutputSize; i++) + { + if (fwd.data[i] < 0.0) + resultGrad.data[i] *= 0.01h; + } + + // Back-prop gradients to the weights matrix. 
+ coopVecOuterProductAccumulate( + resultGrad.data, + input.p.data, + weightsGrad, + 0, // matrixStride, ignored since layout is TrainingOptimal + CoopVecMatrixLayout.TrainingOptimal, // matrix layout, must be TrainingOptimal. + kComponentType); + + // Back-prop gradients to the biases vector. + coopVecReduceSumAccumulate(resultGrad.data, (void*)biasesGrad); + + // Back-prop gradients to the input vector by computing + // mul(transpose(weights), resultGrad). + // By specifying the matrix layout as ColumnMajor, we can + // achieve the effect of transposing the weights matrix. + let dInput = coopVecMatMul<NFloat, InputSize>( + resultGrad.data, kComponentType, + weights, kComponentType, + CoopVecMatrixLayout.ColumnMajor, + false, // transpose, must be `false` since we specified ColumnMajor. + InputSize * sizeof(NFloat)); + + input = {input.p, {dInput}}; + } +} diff --git a/examples/mlp-training-coopvec/mlvec.slang b/examples/mlp-training-coopvec/mlvec.slang new file mode 100644 index 000000000..ce7ce8352 --- /dev/null +++ b/examples/mlp-training-coopvec/mlvec.slang @@ -0,0 +1,63 @@ +implementing mlp; + +// A wrapper of CoopVec<T> to allow it being used in differentiable context. 
+// +public struct MLVec<int N> : IDifferentiable +{ + public CoopVec<NFloat, N> data; + public typealias Differential = MLVec<N>; + + public static MLVec<N> fromArray(NFloat[N] values) + { + MLVec<N> result; + [ForceUnroll] + for (int i = 0; i < N; i++) + result.data[i] = values[i]; + return result; + } + + internal static NFloat[N] coopVecToArray(CoopVec<NFloat, N> v) + { + NFloat[N] arr; + [ForceUnroll] + for (int i = 0; i < N; i++) + arr[i] = v[i]; + return arr; + } + + [BackwardDerivativeOf(fromArray)] + internal static void fromArrayBwd(inout DifferentialPair<NFloat[N]> values, MLVec<N> dResult) + { + values = diffPair(values.p, coopVecToArray(dResult.data)); + } + + internal static NFloat[N] toArray(MLVec<N> vec) + { + return coopVecToArray(vec.data); + } + + [BackwardDerivativeOf(toArray)] + internal static void toArrayBwd(inout DifferentialPair<MLVec<N>> vec, NFloat[N] dResult) + { + vec = diffPair(vec.p, MLVec<N>.fromArray(dResult)); + } + + [Differentiable] + public NFloat[N] toArray() + { + return toArray(this); + } + + public override static Differential dadd(Differential d0, Differential d1) + { + return {d0.data + d1.data}; + } + public override static Differential dmul<U:__BuiltinRealType>(U s, Differential d) + { + return {d.data * __realCast<NFloat>(s)}; + } + public override static Differential dzero() + { + return {}; + } +} diff --git a/examples/mlp-training-coopvec/network.slang b/examples/mlp-training-coopvec/network.slang new file mode 100644 index 000000000..5741487c4 --- /dev/null +++ b/examples/mlp-training-coopvec/network.slang @@ -0,0 +1,58 @@ +module network; + +import common; +import mlp; + +public struct MyNetwork +{ + public FeedForwardLayer<4, 16> layer1; + public FeedForwardLayer<16, 4> layer2; + + [Differentiable] + internal MLVec<4> encodeInput(NFloat x, NFloat y) + { + return MLVec<4>.fromArray({ + x, + y, + x*x, + y*y, + }); + } + + [Differentiable] + internal MLVec<4> _eval(NFloat x, NFloat y) + { + let encoding = 
encodeInput(x, y); + let layer1Output = layer1.eval(encoding); + let leyer2Output = layer2.eval(layer1Output); + return leyer2Output; + } + + [Differentiable] + public half4 eval(no_diff NFloat x, no_diff NFloat y) + { + let mlv = _eval(x, y); + let arr = mlv.toArray(); + return half4(arr[0], arr[1], arr[2], arr[3]); + } +} + +[Differentiable] +public half loss(MyNetwork* network, no_diff half x, no_diff half y) +{ + let networkResult = network.eval(x, y); + let gt = no_diff groundtruth(x, y); + let diff = networkResult - gt; + return dot(diff, diff); +} + +public half4 groundtruth(half x, half y) +{ + return { + (x + y) / (1 + y * y), + 2 * x + y, + 0.5 * x * x + 1.2 * y, + x + 0.5 * y * y, + }; +} + diff --git a/examples/mlp-training/README.md b/examples/mlp-training/README.md new file mode 100644 index 000000000..c5266bbf1 --- /dev/null +++ b/examples/mlp-training/README.md @@ -0,0 +1,6 @@ +Slang "MLP-Training" Example +========================== + +This example shows how to use Slang to train a feed-forward neural network +using automatic differentiation. Also see the "MLP-Training-CoopVec" example +for how to use the cooperative vector intrinsics to speed up training.
\ No newline at end of file diff --git a/examples/mlp-training/adam.slang b/examples/mlp-training/adam.slang new file mode 100644 index 000000000..8f9b15f01 --- /dev/null +++ b/examples/mlp-training/adam.slang @@ -0,0 +1,38 @@ +module adam; + +import mlp_sw; +import common; + +public struct AdamState +{ + internal NFloat mean; + internal NFloat variance; + internal int iteration; +} + +public struct AdamOptimizer +{ + // Adam parameters + public static const NFloat beta1 = 0.9h; + public static const NFloat beta2 = 0.999h; + public static const NFloat epsilon = 1e-7h; + public static const NFloat learningRate = 0.01h; + + public static void step(inout AdamState state, inout NFloat param, inout NFloat grad) + { + state.iteration++; + if (isinf(grad)) + { + if (grad > 0) + grad = 10000.0h; + else + grad = -10000.0h; + } + state.mean = beta1 * state.mean + (NFloat(1.f) - beta1) * grad; + state.variance = beta2 * state.variance + (NFloat(1.f) - beta2) * grad * grad; + NFloat meanHat = state.mean / (NFloat(1.f) - pow(beta1, NFloat(state.iteration))); + NFloat varianceHat = state.variance / (NFloat(1.f) - pow(beta2, NFloat(state.iteration))); + param -= learningRate * meanHat / (sqrt(max(NFloat(0.f), varianceHat) + epsilon)); + grad = NFloat(0.f); + } +} diff --git a/examples/mlp-training/common.slang b/examples/mlp-training/common.slang new file mode 100644 index 000000000..92dc3b563 --- /dev/null +++ b/examples/mlp-training/common.slang @@ -0,0 +1 @@ +public typealias NFloat = half;
\ No newline at end of file diff --git a/examples/mlp-training/kernels.slang b/examples/mlp-training/kernels.slang new file mode 100644 index 000000000..5be076879 --- /dev/null +++ b/examples/mlp-training/kernels.slang @@ -0,0 +1,41 @@ +module kernels; + +import common; +import mlp_sw; +import network; +import adam; + +[numthreads(256, 1, 1)] +[require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic)] +void learnGradient( + uint32_t tid : SV_DispatchThreadID, + uniform MyNetwork* network, + uniform Atomic<uint32_t>* lossBuffer, + uniform float2* inputs, + uniform uint32_t count) +{ + if (tid >= count) + return; + + var input = (half2)inputs[tid]; + bwd_diff(loss)(network, input.x, input.y, 1.0h); + let thisLoss = (float)loss(network, input.x, input.y); + let maxLoss = WaveActiveMax(thisLoss); + if (WaveIsFirstLane()) + { + lossBuffer.max(bit_cast<uint32_t>(maxLoss)); + } +} + +[numthreads(256, 1, 1)] +void adjustParameters(uint32_t tid : SV_DispatchThreadID, uniform AdamState* states, uniform NFloat* params, uniform NFloat* gradients, uniform uint32_t count) +{ + if (tid >= count) + return; + if (isnan(gradients[tid])) + { + gradients[tid] = 0.0h; + return; + } + AdamOptimizer::step(states[tid], params[tid], gradients[tid]); +}
\ No newline at end of file diff --git a/examples/mlp-training/mlp-training.cpp b/examples/mlp-training/mlp-training.cpp new file mode 100644 index 000000000..38090578f --- /dev/null +++ b/examples/mlp-training/mlp-training.cpp @@ -0,0 +1,389 @@ +// In this example, we implement a simple multi-layer perceptron (MLP) training loop on +// Vulkan (through slang-rhi). See also the mlp-training-coopvec example, which +// implements the same MLP training loop using cooperative vector intrinsics for better +// performance. +// +// The simple MLP is trained to approximate a polynomial expression. +// The network contains one hidden layer with 16 neurons. It takes 4 inputs and produces 4 +// outputs. + +#include "core/slang-basic.h" +#include "examples/example-base/example-base.h" +#include "external/slang-rhi/include/slang-rhi.h" +#include "slang-com-ptr.h" +#include "slang.h" + +#include <string> + +using Slang::ComPtr; + +static const ExampleResources resourceBase("mlp-training"); + +typedef uint16_t NFloat; + +static const int kLayerSizes[] = {4, 16, 4}; +static const int kLayerCount = sizeof(kLayerSizes) / sizeof(int) - 1; + +int getNetworkLayerWeightStride(int i) +{ + return kLayerSizes[i] * sizeof(NFloat); +} + +int getNetworkLayerWeightCount(int i) +{ + return kLayerSizes[i] * kLayerSizes[i + 1]; +} + +int getNetworkLayerBiasCount(int i) +{ + return kLayerSizes[i + 1]; +} + +struct Kernel +{ + ComPtr<rhi::IShaderProgram> program; + ComPtr<rhi::IComputePipeline> pipeline; + operator bool() { return program && pipeline; } +}; + +struct ClearBufferParams +{ + rhi::DeviceAddress buffer; + uint32_t count; +}; + +struct LearnGradParams +{ + rhi::DeviceAddress networkBuffer; + rhi::DeviceAddress lossBuffer; + rhi::DeviceAddress inputs; + uint32_t count; +}; + +struct AdjustParamsParams +{ + rhi::DeviceAddress adamStates; + rhi::DeviceAddress params; + rhi::DeviceAddress gradients; + uint32_t count; +}; + +struct ExampleProgram : public TestBase +{ + ComPtr<rhi::IDevice> 
gDevice; + + ComPtr<slang::ISession> gSlangSession; + ComPtr<slang::IModule> gSlangModule; + Kernel gLearnGradProgram; + Kernel gAdjustParamProgram; + + // Sub-allocated buffer range for each network layer's parameters (weights, biases, gradients). + // + struct NetworkParameterAllocation + { + size_t weightsOffset; + size_t weightsSize; + size_t biasOffset; + size_t biasSize; + size_t weightsGradOffset; + size_t biasGradOffset; + }; + + SlangResult execute(int argc, char* argv[]) + { + parseOption(argc, argv); + + rhi::DeviceDesc deviceDesc; + deviceDesc.slang.targetProfile = "spirv_1_6"; + deviceDesc.deviceType = rhi::DeviceType::Vulkan; + + gDevice = rhi::getRHI()->createDevice(deviceDesc); + if (!gDevice) + return SLANG_FAIL; + + SLANG_RETURN_ON_FAIL(loadShaderKernels()); + + // Create a buffer to hold all network parameters (weights, biases, gradients). + // This buffer is arranged as following: + // (segment 1): | weights0 | bias0 | weights1 | bias1 | ... | weightsN | biasN | + // (segment 2): | weightsGrad0 | biasGrad0 | weightsGrad1 | biasGrad1 | ... | + // + // Where the first segment contains all weights and biases for each layer in row-major + // layout. The second segment contains gradients for weights and biases in row-major layout. + + // Total size of all network parameters. + size_t paramBufferSize; + + // Offset for the second segment, where gradients for weights and biases in row-major layout + // start. + size_t gradientOffset; + + // Sub-allocated weight/Bias offsets for each layer. 
+ std::vector<NetworkParameterAllocation> layerAllocations; + allocateNetworkParameterStorage(layerAllocations, paramBufferSize, gradientOffset); + + std::vector<uint16_t> initParams; + srand(1072); + for (int i = 0; i < paramBufferSize / sizeof(NFloat); i++) + { + if (i < gradientOffset / sizeof(NFloat)) + { + float v = rand() / (float)RAND_MAX; + v = v * 2.0f - 1.0f; // Normalize to [-1, 1] + initParams.push_back(floatToHalf(v)); + } + else + { + // Initialize gradients to zero. + initParams.push_back(0); + } + } + auto networkParamsBuffer = createBuffer(paramBufferSize, initParams.data()); + + static const size_t kAdamStateSize = sizeof(NFloat) * 2 + sizeof(int32_t); + auto adamStateBuffer = createBuffer(initParams.size() * kAdamStateSize); + clearBuffer(adamStateBuffer); + + std::vector<uint64_t> networkConstantBufferData; + auto paramBufferAddr = networkParamsBuffer->getDeviceAddress(); + for (int i = 0; i < kLayerCount; i++) + { + networkConstantBufferData.push_back( + paramBufferAddr + layerAllocations[i].weightsOffset); + networkConstantBufferData.push_back( + paramBufferAddr + layerAllocations[i].weightsGradOffset); + networkConstantBufferData.push_back(paramBufferAddr + layerAllocations[i].biasOffset); + networkConstantBufferData.push_back( + paramBufferAddr + layerAllocations[i].biasGradOffset); + } + auto networkConstantBuffer = createBuffer( + networkConstantBufferData.size() * sizeof(uint64_t), + networkConstantBufferData.data()); + + static const int inputCount = 32; + std::vector<float> inputBufferData; + for (int i = 0; i < inputCount; i++) + { + inputBufferData.push_back((float)rand() / RAND_MAX); + } + auto inputBuffer = createBuffer(inputCount * sizeof(float), inputBufferData.data()); + + // Create buffer for receiving current loss value. + auto lossBuffer = createBuffer(sizeof(uint64_t)); + + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + + for (int k = 0; k < 1000; k++) + { + clearBuffer(lossBuffer); + + // Compute gradients. 
+ { + LearnGradParams entryPointParams = {}; + entryPointParams.inputs = inputBuffer->getDeviceAddress(); + entryPointParams.count = inputCount / 2; + entryPointParams.lossBuffer = lossBuffer->getDeviceAddress(); + entryPointParams.networkBuffer = networkConstantBuffer->getDeviceAddress(); + dispatchKernel( + gLearnGradProgram, + entryPointParams, + (entryPointParams.count + 255) / 256); + } + // Adjust parameters in row-major buffer (adam optimize). + { + AdjustParamsParams entryPointParams = {}; + entryPointParams.adamStates = adamStateBuffer->getDeviceAddress(); + entryPointParams.params = networkParamsBuffer->getDeviceAddress(); + entryPointParams.count = (paramBufferSize - gradientOffset) / sizeof(NFloat); + entryPointParams.gradients = + networkParamsBuffer->getDeviceAddress() + gradientOffset; + dispatchKernel( + gAdjustParamProgram, + entryPointParams, + (entryPointParams.count + 255) / 256); + } + if ((k + 1) % 10 == 0) + { + queue->waitOnHost(); + ComPtr<ISlangBlob> blob; + gDevice->readBuffer(lossBuffer, 0, sizeof(float), blob.writeRef()); + printf("Loss after %d iterations: %f\n", k + 1, *(float*)blob->getBufferPointer()); + } + } + return SLANG_OK; + } + + // Allocate storage for network parameters, including weights, biases, and gradients. 
+ void allocateNetworkParameterStorage( + std::vector<NetworkParameterAllocation>& paramStorage, + size_t& outParamBufferSize, + size_t& outGradientOffset) + { + outParamBufferSize = 0; + + auto allocRowMajorStorage = [&](size_t size) + { + size = (size + 63) / 64 * 64; + size_t offset = outParamBufferSize; + outParamBufferSize += size; + return offset; + }; + + for (int i = 0; i < kLayerCount; i++) + { + size_t biasSize = getNetworkLayerBiasCount(i) * sizeof(NFloat); + NetworkParameterAllocation layer = {}; + layer.weightsSize = getNetworkLayerWeightCount(i) * sizeof(NFloat); + layer.weightsOffset = allocRowMajorStorage(layer.weightsSize); + layer.biasSize = biasSize; + layer.biasOffset = allocRowMajorStorage(biasSize); + paramStorage.push_back(layer); + } + + // Alloc storage for gradients. + // allocRowMajorStorage rounds every sub-allocation up to a multiple of 64 bytes, and the + // gradient segment is allocated with the same sizes in the same order as the parameter + // segment, so both segments have the same layout relative to their start offsets. + outGradientOffset = outParamBufferSize; + for (int i = 0; i < kLayerCount; i++) + { + paramStorage[i].weightsGradOffset = allocRowMajorStorage(paramStorage[i].weightsSize); + paramStorage[i].biasGradOffset = allocRowMajorStorage(paramStorage[i].biasSize); + } + } + + // Records one compute pass on the graphics queue: binds `kernel.pipeline`, copies the bytes of + // `args` into the parameter data of entry point 0, dispatches `numWorkGroups` x 1 x 1 groups, + // then finishes the command buffer and submits it. This function itself issues no wait. + template<typename Args> + void dispatchKernel(Kernel& kernel, Args& args, size_t numWorkGroups) + { + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + ComPtr<rhi::ICommandEncoder> encoder; + queue->createCommandEncoder(encoder.writeRef()); + { + auto computeEncoder = encoder->beginComputePass(); + auto rootShaderObject = computeEncoder->bindPipeline(kernel.pipeline.get()); + rootShaderObject->getEntryPoint(0)->setData(rhi::ShaderOffset(), &args, sizeof(args)); + computeEncoder->dispatchCompute(numWorkGroups, 1, 1); + computeEncoder->end(); + } + ComPtr<rhi::ICommandBuffer> commandBuffer; + encoder->finish(commandBuffer.writeRef()); + queue->submit(commandBuffer); + } + + // Create a buffer with the specified size and optional initial data.
+ ComPtr<rhi::IBuffer> createBuffer(size_t size, void* initData = nullptr) + { + rhi::BufferDesc bufferDesc = {}; + bufferDesc.size = size; + bufferDesc.defaultState = rhi::ResourceState::UnorderedAccess; + bufferDesc.usage = rhi::BufferUsage::CopySource | rhi::BufferUsage::CopyDestination | + rhi::BufferUsage::UnorderedAccess; + bufferDesc.memoryType = rhi::MemoryType::DeviceLocal; + return gDevice->createBuffer(bufferDesc, initData); + } + + void clearBuffer(rhi::IBuffer* buffer) + { + auto queue = gDevice->getQueue(rhi::QueueType::Graphics); + auto encoder = queue->createCommandEncoder(); + encoder->clearBuffer(buffer); + auto cmdBuffer = encoder->finish(); + queue->submit(cmdBuffer); + } + + // Looks up `entryPointName` in `slangModule`, links it, and creates the shader program and + // compute pipeline for it. Returns an empty Kernel (which converts to false) when the entry + // point cannot be found or linked, so callers can test the result with `if (!kernel)`. + Kernel loadComputeProgram(slang::IModule* slangModule, char const* entryPointName) + { + ComPtr<slang::IEntryPoint> entryPoint; + slangModule->findEntryPointByName(entryPointName, entryPoint.writeRef()); + // A failed lookup leaves `entryPoint` null; bail out instead of dereferencing it below. + if (!entryPoint) + return Kernel(); + + ComPtr<slang::IComponentType> linkedProgram; + ComPtr<slang::IBlob> diagnosticBlob; + entryPoint->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + // Report link diagnostics the same way compileShaderModuleFromFile reports load diagnostics. + diagnoseIfNeeded(diagnosticBlob); + if (!linkedProgram) + return Kernel(); + + if (isTestMode()) + { + printEntrypointHashes(1, 1, linkedProgram); + } + + Kernel result; + + rhi::ComputePipelineDesc desc; + auto program = gDevice->createShaderProgram(linkedProgram); + desc.program = program.get(); + result.program = program; + result.pipeline = gDevice->createComputePipeline(desc); + return result; + } + + inline unsigned short floatToHalf(float val) + { + uint32_t x = 0; + memcpy(&x, &val, sizeof(float)); + + unsigned short bits = (x >> 16) & 0x8000; + unsigned short m = (x >> 12) & 0x07ff; + unsigned int e = (x >> 23) & 0xff; + if (e < 103) + return bits; + if (e > 142) + { + bits |= 0x7c00u; + bits |= e == 255 && (x & 0x007fffffu); + return bits; + } + if (e < 113) + { + m |= 0x0800u; + bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1); + return bits; + } + bits |= ((e - 112) << 10) | (m >> 1); + bits += m & 1; + return bits; + } + + ComPtr<slang::ISession> createSlangSession(rhi::IDevice* device) + { + ComPtr<slang::ISession> slangSession = 
device->getSlangSession(); + return slangSession; + } + + ComPtr<slang::IModule> compileShaderModuleFromFile( + slang::ISession* slangSession, + char const* filePath) + { + ComPtr<slang::IModule> slangModule; + ComPtr<slang::IBlob> diagnosticBlob; + Slang::String path = resourceBase.resolveResource(filePath); + slangModule = slangSession->loadModule(path.getBuffer(), diagnosticBlob.writeRef()); + diagnoseIfNeeded(diagnosticBlob); + + return slangModule; + } + + SlangResult loadShaderKernels() + { + Slang::String path = resourceBase.resolveResource("kernels.slang"); + + gSlangSession = createSlangSession(gDevice); + gSlangModule = compileShaderModuleFromFile(gSlangSession, path.getBuffer()); + if (!gSlangModule) + return SLANG_FAIL; + + gLearnGradProgram = loadComputeProgram(gSlangModule, "learnGradient"); + if (!gLearnGradProgram) + return SLANG_FAIL; + + gAdjustParamProgram = loadComputeProgram(gSlangModule, "adjustParameters"); + if (!gAdjustParamProgram) + return SLANG_FAIL; + + return SLANG_OK; + } +}; + +int exampleMain(int argc, char** argv) +{ + ExampleProgram app; + if (SLANG_FAILED(app.execute(argc, argv))) + { + return -1; + } + return 0; +} diff --git a/examples/mlp-training/mlp_sw.slang b/examples/mlp-training/mlp_sw.slang new file mode 100644 index 000000000..1e222b99f --- /dev/null +++ b/examples/mlp-training/mlp_sw.slang @@ -0,0 +1,59 @@ +module mlp_sw; + +import common; + +__include mlvec_sw; + +public struct FeedForwardLayer<int InputSize, int OutputSize> +{ + public NFloat* weights; + public NFloat* weightsGrad; + public NFloat* biases; + public NFloat* biasesGrad; + + [BackwardDerivative(evalBwd)] + public MLVec<OutputSize> eval(MLVec<InputSize> input) + { + var output = matMulAdd<OutputSize>( + input, + weights, + biases); + // ReLU activation + for (int i = 0; i < OutputSize; i++) + if (output.data[i] < 0.0) + output.data[i] *= 0.001h; + return output; + } + + public void evalBwd( + inout DifferentialPair<MLVec<InputSize>> input, + 
MLVec<OutputSize> resultGrad) + { + let fwd = eval(input.p); + + // Back-prop resultGrad through activation. + // eval() scales negative outputs by 0.001h, so the derivative on that side is 0.001h as well. + // The factor here must stay identical to the one in eval(); change both together. + for (int i = 0; i < OutputSize; i++) + { + if (fwd.data[i] < 0.0) + resultGrad.data[i] *= 0.001h; + } + + // Back-prop gradients to the weights matrix. + outerProductAccumulate( + resultGrad, + input.p, + weightsGrad); + + // Back-prop gradients to the biases vector. + for (int i = 0; i < OutputSize; i++) + { + NFloat originalValue; + InterlockedAddF16Emulated(biasesGrad + i, resultGrad.data[i], originalValue); + } + + // Back-prop gradients to the input vector. + let dInput = matMulTransposed<InputSize>(resultGrad, weights); + + input = {input.p, dInput}; + } +} diff --git a/examples/mlp-training/mlvec_sw.slang b/examples/mlp-training/mlvec_sw.slang new file mode 100644 index 000000000..695755706 --- /dev/null +++ b/examples/mlp-training/mlvec_sw.slang @@ -0,0 +1,64 @@ +implementing mlp_sw; + +public struct MLVec<int N> : IDifferentiable +{ + public NFloat data[N]; + + [Differentiable] + public NFloat[N] toArray() + { + return data; + } + + [Differentiable] + public static MLVec<N> fromArray(NFloat[N] values) + { + MLVec<N> result; + [ForceUnroll] + for (int i = 0; i < N; i++) + result.data[i] = values[i]; + return result; + } +} + +MLVec<OutputSize> matMulAdd<int OutputSize, int InputSize>(MLVec<InputSize> input, NFloat* matrix, NFloat* bias) +{ + let getMatElem = (int row, int col) => matrix[row*InputSize + col]; + let getBias = (int idx) => bias[idx]; + MLVec<OutputSize> result = {}; + for (int i = 0; i < OutputSize; i++) + { + NFloat r = getBias(i); + for (int j = 0; j < InputSize; j++) + r += getMatElem(i, j) * input.data[j]; + result.data[i] = r; + } + return result; +} + +MLVec<OutputSize> matMulTransposed<int OutputSize, int InputSize>(MLVec<InputSize> input, NFloat* matrix) +{ + let getMatElem = (int row, int col) => matrix[col*OutputSize + row]; + MLVec<OutputSize> result = {}; + for (int i = 0; i < OutputSize; i++) + { + NFloat r = {}; + for 
(int j = 0; j < InputSize; j++) + r += getMatElem(i, j) * input.data[j]; + result.data[i] = r; + } + return result; +} + +void outerProductAccumulate<int M, int N>(MLVec<M> v0, MLVec<N> v1, NFloat* matrix) +{ + for (int i = 0; i < M; i++) + { + for (int j = 0; j < N; j++) + { + let elem = v0.data[i] * v1.data[j]; + half original; + InterlockedAddF16Emulated(matrix + (i*N + j), elem, original); + } + } +} diff --git a/examples/mlp-training/network.slang b/examples/mlp-training/network.slang new file mode 100644 index 000000000..a48820f11 --- /dev/null +++ b/examples/mlp-training/network.slang @@ -0,0 +1,59 @@ +module network; + +import common; +import mlp_sw; + +public struct MyNetwork +{ + public FeedForwardLayer<4, 16> layer1; + public FeedForwardLayer<16, 4> layer2; + + [Differentiable] + internal MLVec<4> encodeInput(NFloat x, NFloat y) + { + return MLVec<4>.fromArray({ + x, + y, + x*x, + y*y, + }); + } + + [Differentiable] + internal MLVec<4> _eval(NFloat x, NFloat y) + { + let encoding = encodeInput(x, y); + let layer1Output = layer1.eval(encoding); + let leyer2Output = layer2.eval(layer1Output); + return leyer2Output; + } + + [Differentiable] + public half4 eval(no_diff NFloat x, no_diff NFloat y) + { + let mlv = _eval(x, y); + let arr = mlv.toArray(); + return half4(arr[0], arr[1], arr[2], arr[3]); + } +} + +[Differentiable] +public half loss(MyNetwork* network, no_diff half x, no_diff half y) +{ + let networkResult = network.eval(x, y); + let gt = no_diff groundtruth(x, y); + let diff = networkResult - gt; + + return dot(diff, diff); +} + +public half4 groundtruth(half x, half y) +{ + return { + (x + y) / (1 + y * y), + 2 * x + y, + 0.5 * x * x + 1.2 * y, + x + 0.5 * y * y, + }; +} + diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index deaeae439..2e56c1082 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1809,12 +1809,12 @@ extension Ptr<void> __init(NativeString nativeStr) { this = 
nativeStr.getBuffer(); } __generic<T, let addrSpace : uint64_t> - __intrinsic_op(0) + __intrinsic_op($(kIROp_BitCast)) __implicit_conversion($(kConversionCost_PtrToVoidPtr)) __init(Ptr<T, addrSpace> ptr); __generic<T> - __intrinsic_op(0) + __intrinsic_op($(kIROp_BitCast)) __implicit_conversion($(kConversionCost_PtrToVoidPtr)) __init(NativeRef<T> ptr); } diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 32d7ea824..38f274984 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -26610,16 +26610,24 @@ CoopVec<T, M> coopVecMatMulPacked( } } -/// Multiply a cooperative vector with a matrix. -/// @param input The input cooperative vector to multiply with the matrix. +/// Multiply a matrix with a cooperative vector. Given a M-row by K-col `matrix`, and a K-element column vector `input`, computes `matrix * input`, and +/// returns a M-element vector. +/// @param input The K-element input cooperative vector to multiply with the matrix. /// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as 8-bit integers, 16-bit floats, etc). -/// @param matrix The matrix to multiply with the input vector. +/// @param matrix The M-by-K matrix to multiply with the input vector. /// @param matrixOffset Byte offset into the matrix buffer. /// @param matrixInterpretation Specifies how to interpret the values in the matrix (e.g. as 8-bit integers, 16-bit floats, etc). /// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major). /// @param transpose Whether to transpose the matrix before multiplication. /// @param matrixStride The stride in bytes between rows/columns of the matrix. /// @return A new cooperative vector containing the result of the matrix multiplication. +/// @remarks Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported. 
+/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to +/// find out which interpretations are supported. +/// +/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`. +/// Not all component types support transposing. +/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored. [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] @@ -26650,7 +26658,9 @@ CoopVec<T, M> coopVecMatMul( matrixStride); } -/// Multiply a cooperative vector with a matrix and add a bias vector. +/// Multiply a matrix with a cooperative vector and add a bias vector to the result. +/// Given a M-row by K-col `matrix`, a K-element column vector `input`, and a M-element vector `bias`, computes `matrix * input + bias`, and +/// returns a M-element vector. /// @param input The input cooperative vector to multiply with the matrix. /// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values). /// @param k The number of columns in the matrix. @@ -26667,6 +26677,14 @@ CoopVec<T, M> coopVecMatMul( /// @remarks Unlike coopVecMatMulAdd, this function supports packed input interpretations where multiple values /// can be packed into each element of the input vector. The k parameter specifies the actual number of /// values to use from the packed input. +/// +/// Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported. +/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to +/// find out which interpretations are supported. +/// +/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`. 
+/// Not all component types support transposing. +/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored. [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] @@ -26804,7 +26822,9 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l } } -/// Multiply a cooperative vector with a matrix and add a bias vector. +/// Multiply a matrix with a cooperative vector and add a bias vector. +/// Given a M-row by K-col `matrix`, a K-element column vector `input`, and a M-element vector `bias`, computes `matrix * input + bias`, and +/// returns a M-element vector. /// @param input The input cooperative vector to multiply with the matrix. /// @param inputInterpretation Specifies how to interpret the values in the input vector. /// @param matrix The matrix buffer to multiply with. @@ -26817,6 +26837,13 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l /// @param transpose Whether to transpose the matrix before multiplication. /// @param matrixStride The stride between matrix rows/columns in bytes. /// @return A new cooperative vector containing the result of the matrix multiplication plus bias. +/// @remarks Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported. +/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to +/// find out which interpretations are supported. +/// +/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`. +/// Not all component types support transposing. +/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored. 
[ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] @@ -26862,7 +26889,9 @@ ${{{{ if(buffer.isRW) { }}}} -/// Accumulate the outer product of two cooperative vectors into a matrix. +/// Atomically accumulates the outer product of two cooperative vectors into a matrix. Given an M-element vector `a`, and an N-element vector `b`, +/// computes the outer product of `a` and `b`, forming an M-row by N-col matrix. The elements of the matrix are then atomically accumulated +/// to the memory location represented by `matrix`. /// @param a The first cooperative vector. /// @param b The second cooperative vector. /// @param matrix The matrix buffer to accumulate the result into. @@ -26870,6 +26899,21 @@ if(buffer.isRW) /// @param matrixStride The stride between matrix rows/columns in bytes. /// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major). /// @param matrixInterpretation Specifies how to interpret the values in the matrix. +/// @remarks On current hardware, `memoryLayout` must be `TrainingOptimal`. +/// +/// When `memoryLayout` is `RowMajor`, this function is equivalent to: +/// +/// ``` +/// uint8_t* matrixPtr = matrix + matrixOffset; +/// for (int i = 0; i < M; i++) +/// { +/// for (int j = 0; j < N; j++) +/// { +/// let elem = a[i] * b[j]; +/// atomicAdd(matrixPtr + i * matrixStride + j * sizeof(T), elem); +/// } +/// } +/// ``` [require(cooperative_vector)] [require(hlsl_coopvec_poc)] [require(optix_coopvec)] @@ -26959,10 +27003,15 @@ void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let } } -/// Accumulate the sum of a cooperative vector into a buffer at the specified offset. +/// Atomically accumulates the elements of a cooperative vector into a buffer at the specified offset. /// @param v The cooperative vector to sum. /// @param buffer The buffer to accumulate the sum into. /// @param offset Byte offset into the buffer. 
+/// @remarks This function is equivalent to: +/// ``` +/// for (int i = 0; i < N; i++) +/// atomicAdd(dest[i], v[i]); +/// ``` [require(cooperative_vector)] [require(hlsl_coopvec_poc)] [require(optix_coopvec)] @@ -27015,20 +27064,6 @@ static const struct { for(auto buffer : kStructuredBufferCases_) { }}}} -/// Multiply a cooperative vector with a matrix and return the result. -/// @param input The input cooperative vector to multiply with the matrix. -/// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values). -/// @param k The number of columns in the matrix. -/// @param matrix The matrix buffer to multiply with. -/// @param matrixOffset Byte offset into the matrix buffer. -/// @param matrixInterpretation Specifies how to interpret the values in the matrix. -/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major). -/// @param transpose Whether to transpose the matrix before multiplication. -/// @param matrixStride The stride between matrix rows/columns in bytes. -/// @return A new cooperative vector containing the result of the matrix multiplication. -/// @remarks Unlike coopVecMatMul, this function supports packed input interpretations where multiple values -/// can be packed into each element of the input vector. The k parameter specifies the actual number of -/// values to use from the packed input. [require(spirv, cooperative_vector)] __generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType,IgnoredBufferElementType> CoopVec<T, M> coopVecMatMulPacked( @@ -27288,6 +27323,185 @@ void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int, U, let } } +// Pointer overloads for coopvector operations. 
+ +[require(spirv, cooperative_vector)] +__generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType> +CoopVec<T, M> coopVecMatMulPacked( + CoopVec<U, PackedK> input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + void* matrixPtr, + constexpr CoopVecComponentType matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + __target_switch + { + case spirv: + let m : int32_t = M; + let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); + let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); + let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); + int operands = 0; // NoneKHR + let zero = 0; + let cvtMatPtr = (Ptr<T[]>)matrixPtr; + if (__isSignedInt<T>()) + { + operands |= 0x08; // MatrixResultSignedComponentsKHR + } + return spirv_asm + { + result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; + }; + } +} + +// specialized coopVecMatMul for non-packed inputs +[require(spirv, cooperative_vector)] +__generic<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType> +CoopVec<T, M> coopVecMatMul( + CoopVec<U, K> input, + constexpr CoopVecComponentType inputInterpretation, + void* matrix, + constexpr CoopVecComponentType 
matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + return coopVecMatMulPacked< + T, M, K, U>( + input, + inputInterpretation, + K, + matrix, + matrixInterpretation, + memoryLayout, + transpose, + matrixStride); +} + +[require(spirv, cooperative_vector)] +CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType>( + CoopVec<U, PackedK> input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + void* matrixPtr, + constexpr CoopVecComponentType matrixInterpretation, + void* biasPtr, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + + __target_switch + { + case spirv: + let m : int32_t = M; + let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); + let biasInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(biasInterpretation); + let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); + let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); + let zero : int32_t = 0; + let cvtMatPtr = (Ptr<T[]>)matrixPtr; + let cvtBiasPtr = (Ptr<T[]>)biasPtr; + return spirv_asm + { + 
result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $cvtBiasPtr $zero $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride; + }; + } +} + +[require(spirv, cooperative_vector)] +CoopVec<T, M> coopVecMatMulAdd<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType>( + CoopVec<U, K> input, + constexpr CoopVecComponentType inputInterpretation, + void* matrix, + constexpr CoopVecComponentType matrixInterpretation, + void* bias, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + return coopVecMatMulAddPacked< + T, M, K, U>( + input, + inputInterpretation, + K, + matrix, + matrixInterpretation, + bias, + biasInterpretation, + memoryLayout, + transpose, + matrixStride); +} + +[require(spirv, cooperative_vector_training)] +void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let N : int>( + CoopVec<T, M> a, + CoopVec<T, N> b, + void* matrixPtr, + constexpr uint matrixStride, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr CoopVecComponentType matrixInterpretation, +) +{ + let zero : int32_t = 0; + __target_switch + { + case spirv: + let matrixInterpretationSpirv : int = __getSpvCoopVecComponentType(matrixInterpretation); + let memoryLayoutSpirv : int = __getSpvCoopVecMatrixLayout(memoryLayout); + let cvtMatrixPtr = (Ptr<T[]>)matrixPtr; + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorOuterProductAccumulateNV $cvtMatrixPtr $zero $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride; + }; + } +} + +[require(spirv, cooperative_vector_training)] +void coopVecReduceSumAccumulate<T : 
__BuiltinArithmeticType, let N : int>( + CoopVec<T, N> v, + void* buffer +) +{ + let zero : int32_t = 0; + let bufferPtr = (Ptr<T[]>)(buffer); + __target_switch + { + case spirv: + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorReduceSumAccumulateNV $bufferPtr $zero $v; + }; + } +} + //@public: /// Mark a variable as being workgroup uniform. @@ -28126,3 +28340,24 @@ uint packHalf2x16(half2 unpackedValue) { return packHalf2x16(float2(unpackedValue)); } + +// Emulates an atomic add on a single `half` by issuing a `half2` atomic add on the 4-byte-aligned +// pair that contains `dest`, adding zero to the neighbouring element. `originalValue` receives the +// component of the atomic's result that corresponds to `dest`. +[require(spirv)] +void InterlockedAddF16Emulated(half* dest, half value, out half originalValue) +{ + uint64_t byteAddress = (uint64_t)dest; + if ((byteAddress & 3) == 0) + { + // `dest` is the first (`.x`) element of its aligned pair. + let buf = (half2*)(dest); + originalValue = __atomic_add(*buf, half2(value, half(0.0))).x; + } + else + { + // `dest` is the second (`.y`) element: the aligned pair starts one element earlier. + // Casting `dest` itself would make `.y` refer to the element after `dest`. + let buf = (half2*)(dest - 1); + originalValue = __atomic_add(*buf, half2(half(0.0), value)).y; + } +} + +[require(spirv)] +void InterlockedAddF16x2(half2* dest, half2 value, out half2 originalValue) +{ + originalValue = __atomic_add(*dest, value); +}
\ No newline at end of file diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 88dea0b7e..56f9d873a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -187,6 +187,13 @@ class SynthesizedStaticLambdaFuncModifier : public Modifier FIDDLE(...) }; +FIDDLE() +class ExplicitlyDeclaredCapabilityModifier : public Modifier +{ + FIDDLE(...) + FIDDLE() CapabilitySet declaredCapabilityRequirements; +}; + // Marks a synthesized variable as local temporary variable. FIDDLE() class LocalTempVarModifier : public Modifier diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 506abc1be..6456dbe98 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -419,6 +419,9 @@ bool SemanticsVisitor::createInvokeExprForSynthesizedCtor( if (!structDecl) return false; + if (!structDecl->checkState.isBeingChecked()) + ensureDecl(structDecl, DeclCheckState::AttributesChecked); + HashSet<Type*> isVisit; bool isCStyle = false; if (!_getSynthesizedConstructor( @@ -656,8 +659,8 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList( auto toMakeArrayFromElementExpr = m_astBuilder->create<MakeArrayFromElementExpr>(); toMakeArrayFromElementExpr->loc = fromInitializerListExpr->loc; toMakeArrayFromElementExpr->type = QualType(toType); - - *outToExpr = toMakeArrayFromElementExpr; + if (outToExpr) + *outToExpr = toMakeArrayFromElementExpr; return true; } for (UInt ee = 0; ee < elementCount; ++ee) @@ -748,8 +751,8 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList( auto defaultConstructExpr = m_astBuilder->create<DefaultConstructExpr>(); defaultConstructExpr->loc = fromInitializerListExpr->loc; defaultConstructExpr->type = QualType(toType); - - *outToExpr = defaultConstructExpr; + if (outToExpr) + *outToExpr = defaultConstructExpr; return true; } diff --git a/source/slang/slang-check-decl.cpp 
b/source/slang/slang-check-decl.cpp index 1a70e25d7..0dd859bb2 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2914,9 +2914,9 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness context->parentDecl->findLastDirectMemberDeclOfName(requirementDeclRef.getName())) { // Remove the `ToBeSynthesizedModifier`. - if (as<ToBeSynthesizedModifier>(existingDecl->modifiers.first)) + if (auto mod = existingDecl->modifiers.findModifier<ToBeSynthesizedModifier>()) { - existingDecl->modifiers.first = existingDecl->modifiers.first->next; + removeModifier(existingDecl, mod); } else { @@ -3133,14 +3133,9 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>()); - // The visibility of synthesized decl should be the min of the parent decl and the requirement. - if (requirementDeclRef.getDecl()->findModifier<VisibilityModifier>()) - { - auto requirementVisibility = getDeclVisibility(requirementDeclRef.getDecl()); - auto thisVisibility = getDeclVisibility(context->parentDecl); - auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(aggTypeDecl, visibility); - } + // The visibility of synthesized decl should be the same of the parent decl. + auto thisVisibility = getDeclVisibility(context->parentDecl); + addVisibilityModifier(aggTypeDecl, thisVisibility); // Synthesize the rest of IDifferential method conformances by recursively checking // conformance on the synthesized decl. 
@@ -4149,8 +4144,12 @@ bool SemanticsVisitor::doesVarMatchRequirement( return false; } - auto satisfyingVal = - tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr); + IntVal* satisfyingVal = nullptr; + if (isValidCompileTimeConstantType(satisfyingType)) + { + satisfyingVal = + tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr); + } if (satisfyingVal) { witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingVal)); @@ -5125,9 +5124,9 @@ void SemanticsVisitor::markOverridingDecl( return; } + memberDecl = maybeGetInner(memberDecl); if (hasDefaultImpl(requiredMemberDeclRef)) { - memberDecl = maybeGetInner(memberDecl); // If the required member has a default implementation, // we need to make sure the member we found is marked as 'override'. if (!memberDecl->hasModifier<OverrideModifier>()) @@ -14290,6 +14289,9 @@ void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* fun } else { + auto declaredCapModifier = m_astBuilder->create<ExplicitlyDeclaredCapabilityModifier>(); + declaredCapModifier->declaredCapabilityRequirements = declaredCaps; + addModifier(funcDecl, declaredCapModifier); if (vis == DeclVisibility::Public) { // For public decls, we need to enforce that the function diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 306687bd8..b90081af8 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -682,6 +682,8 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( auto typeDef = m_astBuilder->create<TypeAliasDecl>(); typeDef->nameAndLoc.name = getName("Differential"); typeDef->parentDecl = structDecl; + addVisibilityModifier(structDecl, getDeclVisibility(parent)); + addVisibilityModifier(typeDef, getDeclVisibility(parent)); auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); @@ -714,6 +716,7 @@ Expr* 
SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( typeDef->type.type = calcThisType(subType->getDeclRef().getDecl()->getDefaultDeclRef()); + addVisibilityModifier(typeDef, getDeclVisibility(parent)); synthesizedDecl = parent; parent->addDirectMemberDecl(typeDef); @@ -2085,7 +2088,7 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef( // to not allow such cases. // // Note that float-to-inst casts for non-`IntVal`s are allowed. - if (!isScalarIntegerType(decl->getType())) + if (!isValidCompileTimeConstantType(decl->getType())) { getSink()->diagnose(declRef, Diagnostics::intValFromNonIntSpecConstEncountered); return nullptr; diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index e6744071b..a360361f7 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -553,6 +553,16 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) targetOptionSet.hasOption(CompilerOptionName::Capability) && (targetOptionSet.getIntOption(CompilerOptionName::Capability) != SLANG_CAPABILITY_UNKNOWN); + + if (auto declaredCapsMod = + entryPointFuncDecl->findModifier<ExplicitlyDeclaredCapabilityModifier>()) + { + // If the entry point has an explicitly declared capability, then we + // will merge that with the target capability set before checking if + // there is an implicit upgrade. 
+ targetCaps.nonDestructiveJoin(declaredCapsMod->declaredCapabilityRequirements); + } + // Only attempt to error if a specific profile or capability is requested if ((specificCapabilityRequested || specificProfileRequested) && targetCaps.atLeastOneSetImpliedInOther( diff --git a/tests/language-feature/capability/capabilitySimplification2.slang b/tests/language-feature/capability/capabilitySimplification2.slang index 8d96884ce..d50960b8e 100644 --- a/tests/language-feature/capability/capabilitySimplification2.slang +++ b/tests/language-feature/capability/capabilitySimplification2.slang @@ -1,27 +1,27 @@ -//TEST:SIMPLE(filecheck=SPIRV): -target spirv -emit-spirv-directly -entry computeMain -stage compute -profile sm_5_0 +//TEST:SIMPLE(filecheck=SPIRV): -target spirv -emit-spirv-directly -entry computeMain -stage compute -profile spirv_1_3 //TEST:SIMPLE(filecheck=GLSL): -target glsl -entry computeMain -stage compute -profile sm_5_0 //TEST:SIMPLE(filecheck=HLSL): -target hlsl -entry computeMain -stage compute -profile sm_5_0 -//TEST:SIMPLE(filecheck=CHECK_IGNORE_CAPS): -target spirv -emit-spirv-directly -entry computeMain -stage compute -profile sm_5_0 -ignore-capabilities +//TEST:SIMPLE(filecheck=CHECK_IGNORE_CAPS): -target spirv -emit-spirv-directly -entry computeMain -stage compute -profile spirv_1_3 -ignore-capabilities // CHECK_IGNORE_CAPS-NOT: warning 41012 // SPIRV: warning 41012 // SPIRV-NOT: spirv_1_2 -// SPIRV-NOT: spirv_1_3 // SPIRV-SAME: spvGroupNonUniformBallot // GLSL: warning 41012 // GLSL-NOT: GLSL_400 // GLSL-NOT: GLSL_430 -// GLSL-SAME: GL_KHR_shader_subgroup_ballot +// GLSL-SAME: GL_KHR_shader_subgroup_basic // HLSL: warning 41012 // HLSL-NOT: sm_5_1 // HLSL-SAME: sm_6_0 -[require(sm_6_0)] [numthreads(1,1,1)] void computeMain() { + if (WaveIsFirstLane()) + return; } |
