summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-06-30 14:32:50 -0700
committerGitHub <noreply@github.com>2025-06-30 21:32:50 +0000
commitf28f67d988158d6c46f7ffe967152f98d32a37b2 (patch)
tree2aa620986a87ec69cf1f210c714312e42b62ac9e
parenta55ff722cae338a8fcf5402858c47cf0650a8e5e (diff)
Add MLP training examples. (#7550)
* Add MLP training examples. * Formatting fix. * Fix. * Improve documentation on coopvector. * Improve doc. * Update doc. * Fix typo. * Cleanup shader. * Cleanup. * Fix test. * Fix type check recursion. * Fix. * Fix. * Fix override check.
-rw-r--r--examples/CMakeLists.txt5
-rw-r--r--examples/mlp-training-coopvec/README.md6
-rw-r--r--examples/mlp-training-coopvec/adam.slang38
-rw-r--r--examples/mlp-training-coopvec/common.slang1
-rw-r--r--examples/mlp-training-coopvec/kernels.slang41
-rw-r--r--examples/mlp-training-coopvec/mlp-training-coopvec.cpp462
-rw-r--r--examples/mlp-training-coopvec/mlp.slang73
-rw-r--r--examples/mlp-training-coopvec/mlvec.slang63
-rw-r--r--examples/mlp-training-coopvec/network.slang58
-rw-r--r--examples/mlp-training/README.md6
-rw-r--r--examples/mlp-training/adam.slang38
-rw-r--r--examples/mlp-training/common.slang1
-rw-r--r--examples/mlp-training/kernels.slang41
-rw-r--r--examples/mlp-training/mlp-training.cpp389
-rw-r--r--examples/mlp-training/mlp_sw.slang59
-rw-r--r--examples/mlp-training/mlvec_sw.slang64
-rw-r--r--examples/mlp-training/network.slang59
-rw-r--r--source/slang/core.meta.slang4
-rw-r--r--source/slang/hlsl.meta.slang277
-rw-r--r--source/slang/slang-ast-modifier.h7
-rw-r--r--source/slang/slang-check-conversion.cpp11
-rw-r--r--source/slang/slang-check-decl.cpp28
-rw-r--r--source/slang/slang-check-expr.cpp5
-rw-r--r--source/slang/slang-check-shader.cpp10
-rw-r--r--tests/language-feature/capability/capabilitySimplification2.slang10
25 files changed, 1710 insertions, 46 deletions
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 9411d10ba..87b0b39f8 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -114,4 +114,9 @@ if(SLANG_ENABLE_EXAMPLES)
if(SLANG_ENABLE_AFTERMATH)
example(nv-aftermath-example WIN32_EXECUTABLE)
endif()
+
+ if(SLANG_ENABLE_SLANG_RHI)
+ example(mlp-training LINK_WITH_PRIVATE slang-rhi)
+ example(mlp-training-coopvec LINK_WITH_PRIVATE slang-rhi)
+ endif()
endif()
diff --git a/examples/mlp-training-coopvec/README.md b/examples/mlp-training-coopvec/README.md
new file mode 100644
index 000000000..8d56fff92
--- /dev/null
+++ b/examples/mlp-training-coopvec/README.md
@@ -0,0 +1,6 @@
+Slang "MLP-Training-CoopVec" Example
+==========================
+
+This example shows how to use the Slang to train a feed-forward neural network
+using automatic differentiation and the cooperative vector intrinsics. Also see the
+"MLP-Training" example for the same task without using cooperative vector. \ No newline at end of file
diff --git a/examples/mlp-training-coopvec/adam.slang b/examples/mlp-training-coopvec/adam.slang
new file mode 100644
index 000000000..33a357122
--- /dev/null
+++ b/examples/mlp-training-coopvec/adam.slang
@@ -0,0 +1,38 @@
+module adam;
+
+import mlp;
+import common;
+
+public struct AdamState
+{
+ internal NFloat mean;
+ internal NFloat variance;
+ internal int iteration;
+}
+
+public struct AdamOptimizer
+{
+ // Adam parameters
+ public static const NFloat beta1 = 0.9h;
+ public static const NFloat beta2 = 0.999h;
+ public static const NFloat epsilon = 1e-7h;
+ public static const NFloat learningRate = 0.01h;
+
+ public static void step(inout AdamState state, inout NFloat param, inout NFloat grad)
+ {
+ state.iteration++;
+ if (isinf(grad))
+ {
+ if (grad > 0)
+ grad = 10000.0h;
+ else
+ grad = -10000.0h;
+ }
+ state.mean = beta1 * state.mean + (NFloat(1.f) - beta1) * grad;
+ state.variance = beta2 * state.variance + (NFloat(1.f) - beta2) * grad * grad;
+ NFloat meanHat = state.mean / (NFloat(1.f) - pow(beta1, NFloat(state.iteration)));
+ NFloat varianceHat = state.variance / (NFloat(1.f) - pow(beta2, NFloat(state.iteration)));
+ param -= learningRate * meanHat / (sqrt(max(NFloat(0.f), varianceHat) + epsilon));
+ grad = NFloat(0.f);
+ }
+}
diff --git a/examples/mlp-training-coopvec/common.slang b/examples/mlp-training-coopvec/common.slang
new file mode 100644
index 000000000..92dc3b563
--- /dev/null
+++ b/examples/mlp-training-coopvec/common.slang
@@ -0,0 +1 @@
+public typealias NFloat = half; \ No newline at end of file
diff --git a/examples/mlp-training-coopvec/kernels.slang b/examples/mlp-training-coopvec/kernels.slang
new file mode 100644
index 000000000..712494b1f
--- /dev/null
+++ b/examples/mlp-training-coopvec/kernels.slang
@@ -0,0 +1,41 @@
+module kernels;
+
+import common;
+import mlp;
+import network;
+import adam;
+
+[numthreads(256, 1, 1)]
+[require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic, spvCooperativeVectorNV)]
+void learnGradient(
+ uint32_t tid : SV_DispatchThreadID,
+ uniform MyNetwork* network,
+ uniform Atomic<uint32_t>* lossBuffer,
+ uniform float2* inputs,
+ uniform uint32_t count)
+{
+ if (tid >= count)
+ return;
+
+ var input = (half2)inputs[tid];
+ bwd_diff(loss)(network, input.x, input.y, 1.0h);
+ let thisLoss = (float)loss(network, input.x, input.y);
+ let maxLoss = WaveActiveMax(thisLoss);
+ if (WaveIsFirstLane())
+ {
+ lossBuffer.max(bit_cast<uint32_t>(maxLoss));
+ }
+}
+
+[numthreads(256, 1, 1)]
+void adjustParameters(uint32_t tid : SV_DispatchThreadID, uniform AdamState* states, uniform NFloat* params, uniform NFloat* gradients, uniform uint32_t count)
+{
+ if (tid >= count)
+ return;
+ if (isnan(gradients[tid]))
+ {
+ gradients[tid] = 0.0h;
+ return;
+ }
+ AdamOptimizer::step(states[tid], params[tid], gradients[tid]);
+} \ No newline at end of file
diff --git a/examples/mlp-training-coopvec/mlp-training-coopvec.cpp b/examples/mlp-training-coopvec/mlp-training-coopvec.cpp
new file mode 100644
index 000000000..9b5dde531
--- /dev/null
+++ b/examples/mlp-training-coopvec/mlp-training-coopvec.cpp
@@ -0,0 +1,462 @@
+// In this example, we implement a simple multi-layer perceptron (MLP) training loop on
+// Vulkan (through slang-rhi), using cooperative vector intrinsics.
+//
+// The simple MLP is trained to approximate a polynomial expression.
+// The network contains one hidden layer with 16 neurons. It takes 4 inputs and produces 4
+// outputs.
+
+#include "core/slang-basic.h"
+#include "examples/example-base/example-base.h"
+#include "external/slang-rhi/include/slang-rhi.h"
+#include "slang-com-ptr.h"
+#include "slang.h"
+
+static const ExampleResources resourceBase("mlp-training-coopvec");
+
+typedef uint16_t NFloat;
+
+// Define the sizes of the layers in the MLP.
+static const int kLayerSizes[] = {4, 16, 4};
+static const int kLayerCount = sizeof(kLayerSizes) / sizeof(int) - 1;
+
+using Slang::ComPtr;
+
+struct Kernel
+{
+ ComPtr<rhi::IShaderProgram> program;
+ ComPtr<rhi::IComputePipeline> pipeline;
+ explicit operator bool() { return program && pipeline; }
+};
+
+struct ClearBufferParams
+{
+ rhi::DeviceAddress buffer;
+ uint32_t count;
+};
+
+struct LearnGradParams
+{
+ rhi::DeviceAddress networkBuffer;
+ rhi::DeviceAddress lossBuffer;
+ rhi::DeviceAddress inputs;
+ uint32_t count;
+};
+
+struct AdjustParamsParams
+{
+ rhi::DeviceAddress adamStates;
+ rhi::DeviceAddress params;
+ rhi::DeviceAddress gradients;
+ uint32_t count;
+};
+
+struct ExampleProgram : public TestBase
+{
+ ComPtr<rhi::IDevice> gDevice;
+
+ ComPtr<slang::ISession> gSlangSession;
+ ComPtr<slang::IModule> gSlangModule;
+ Kernel gLearnGradProgram;
+ Kernel gAdjustParamProgram;
+
+ // Sub-allocated buffer range for each network layer's parameters (weights, biases, gradients).
+ //
+ struct NetworkParameterAllocation
+ {
+ size_t weightsOffset;
+ size_t weightsSize;
+ size_t biasOffset;
+ size_t biasSize;
+ size_t weightsGradOffset;
+ size_t biasGradOffset;
+ size_t weightsGradTrainingOffset;
+ size_t weightsGradTrainingSize;
+ };
+
+ SlangResult execute(int argc, char* argv[])
+ {
+ parseOption(argc, argv);
+
+ rhi::DeviceDesc deviceDesc;
+ deviceDesc.slang.targetProfile = "spirv_1_6";
+ deviceDesc.deviceType = rhi::DeviceType::Vulkan;
+
+ gDevice = rhi::getRHI()->createDevice(deviceDesc);
+ if (!gDevice)
+ return SLANG_FAIL;
+
+ SLANG_RETURN_ON_FAIL(loadShaderKernels());
+
+ // Create a buffer to hold all network parameters (weights, biases, gradients).
+ // This buffer is arranged as following:
+ // (segment 1): | weights0 | bias0 | weights1 | bias1 | ... | weightsN | biasN |
+ // (segment 2): | weightsGrad0 | biasGrad0 | weightsGrad1 | biasGrad1 | ... |
+ // (segment 3): | weightsGradTraining0 | weightsGradTraining1 | ... |
+ //
+ // Where the first segment contains all weights and biases for each layer in row-major
+ // layout. The second segment contains gradients for weights and biases in row-major layout.
+ // The third segment contains gradients for weights in training-optimal layout.
+ // The training-optimal layout is used to accumulate gradients for weights with the
+ // `coopVecOuterProductAccumulate` intrinsic, which requires the destination to be in
+ // training-optimal layout.
+ // After accumulating gradients, we will convert them to row-major layout (i.e. copy
+ // them back into the second segment) so we can read them in the optimization kernel.
+
+ // Total size of all network parameters.
+ size_t networkParamsBufferSize;
+
+ // Offset for the second segment, where gradients for weights and biases in row-major layout
+ // start.
+ size_t networkGraidentOffset;
+
+ // Offset for the third segment, where gradients for weights in training-optimal layout
+ // start.
+ size_t networkGradientTrainingOffset;
+
+ // Sub-allocated weight/Bias offsets for each layer.
+ std::vector<NetworkParameterAllocation> layerAllocations;
+
+ // Allocate storage for network parameters, filling in `layerRowMajorAllocations`,
+ // `networkParamsBufferSize`, `networkGraidentOffset` and `networkGradientTrainingOffset`.
+ //
+ allocateNetworkParameterStorage(
+ layerAllocations,
+ networkParamsBufferSize,
+ networkGraidentOffset,
+ networkGradientTrainingOffset);
+
+ // We'll initialize the buffer with random values in the range [-1, 1].
+ std::vector<uint16_t> initParams;
+ srand(1072);
+ for (int i = 0; i < networkParamsBufferSize / sizeof(NFloat); i++)
+ {
+ float v = rand() / (float)RAND_MAX;
+ v = v * 2.0f - 1.0f; // Normalize to [-1, 1]
+ initParams.push_back(floatToHalf(v));
+ }
+ auto networkParamsBuffer = createBuffer(networkParamsBufferSize, initParams.data());
+
+ // Create a buffer for holding the Adam optimizer state for each network parameter.
+ static const size_t kAdamStateSize = sizeof(NFloat) * 2 + sizeof(int32_t);
+ auto adamStateBuffer = createBuffer(initParams.size() * kAdamStateSize);
+ clearBuffer(adamStateBuffer);
+
+ // Prepare buffer for the `network` struct that holds pointers to network parameters for
+ // each layer.
+ std::vector<uint64_t> networkConstantBufferData;
+ for (int i = 0; i < kLayerCount; i++)
+ {
+ networkConstantBufferData.push_back(
+ networkParamsBuffer->getDeviceAddress() + layerAllocations[i].weightsOffset);
+ networkConstantBufferData.push_back(
+ networkParamsBuffer->getDeviceAddress() +
+ layerAllocations[i].weightsGradTrainingOffset);
+ networkConstantBufferData.push_back(
+ networkParamsBuffer->getDeviceAddress() + layerAllocations[i].biasOffset);
+ networkConstantBufferData.push_back(
+ networkParamsBuffer->getDeviceAddress() + layerAllocations[i].biasGradOffset);
+ }
+ auto networkConstantBuffer = createBuffer(
+ networkConstantBufferData.size() * sizeof(uint64_t),
+ networkConstantBufferData.data());
+
+ // Create buffer for input data.
+ static const int inputCount = 32;
+ std::vector<float> inputBufferData;
+ for (int i = 0; i < inputCount; i++)
+ {
+ inputBufferData.push_back((float)rand() / RAND_MAX);
+ }
+ auto inputBuffer = createBuffer(inputCount * sizeof(float), inputBufferData.data());
+
+ // Create buffer for receiving current loss value.
+ auto lossBuffer = createBuffer(sizeof(uint64_t));
+
+ auto queue = gDevice->getQueue(rhi::QueueType::Graphics);
+
+ // Run training loop.
+ for (int k = 0; k < 1000; k++)
+ {
+ clearBuffer(lossBuffer);
+
+ // Clear weight gradients in the parameter buffer to 0.
+ clearBuffer(
+ networkParamsBuffer,
+ rhi::BufferRange{
+ networkGradientTrainingOffset,
+ networkParamsBufferSize - networkGradientTrainingOffset});
+ // Compute gradients for weights and biases.
+ // The weight gradients are stored in the training-optimal layout.
+ {
+ LearnGradParams entryPointParams = {};
+ entryPointParams.inputs = inputBuffer->getDeviceAddress();
+ entryPointParams.count = inputCount / 2;
+ entryPointParams.lossBuffer = lossBuffer->getDeviceAddress();
+ entryPointParams.networkBuffer = networkConstantBuffer->getDeviceAddress();
+ dispatchKernel(
+ gLearnGradProgram,
+ entryPointParams,
+ (entryPointParams.count + 255) / 256);
+ }
+ // Copy weight gradients from training-optimal layout to row-major layout,
+ // so we can read them in the `adjustParameters` kernel.
+ {
+ std::vector<rhi::ConvertCooperativeVectorMatrixDesc> matrixDescs;
+ for (int i = 0; i < kLayerCount; i++)
+ {
+ rhi::ConvertCooperativeVectorMatrixDesc desc = {};
+ desc.rowCount = kLayerSizes[i + 1];
+ desc.colCount = kLayerSizes[i];
+ desc.dstComponentType = rhi::CooperativeVectorComponentType::Float16;
+ desc.dstSize = &layerAllocations[i].weightsSize;
+ desc.dstData.deviceAddress = networkParamsBuffer->getDeviceAddress() +
+ layerAllocations[i].weightsGradOffset;
+ desc.dstLayout = rhi::CooperativeVectorMatrixLayout::RowMajor;
+ desc.dstStride = getNetworkLayerWeightStride(i);
+ desc.srcComponentType = rhi::CooperativeVectorComponentType::Float16;
+ desc.srcSize = layerAllocations[i].weightsGradTrainingSize;
+ desc.srcData.deviceAddress = networkParamsBuffer->getDeviceAddress() +
+ layerAllocations[i].weightsGradTrainingOffset;
+ desc.srcLayout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal;
+ matrixDescs.push_back(desc);
+ }
+ auto encoder = queue->createCommandEncoder();
+ encoder->convertCooperativeVectorMatrix(
+ matrixDescs.data(),
+ (uint32_t)matrixDescs.size());
+ ComPtr<rhi::ICommandBuffer> commandBuffer;
+ encoder->finish(commandBuffer.writeRef());
+ queue->submit(commandBuffer);
+ }
+ // Adjust parameters in row-major buffer (adam optimize).
+ {
+ AdjustParamsParams entryPointParams = {};
+ entryPointParams.adamStates = adamStateBuffer->getDeviceAddress();
+ entryPointParams.params = networkParamsBuffer->getDeviceAddress();
+ entryPointParams.count =
+ (networkGradientTrainingOffset - networkGraidentOffset) / sizeof(NFloat);
+ entryPointParams.gradients =
+ networkParamsBuffer->getDeviceAddress() + networkGraidentOffset;
+ dispatchKernel(
+ gAdjustParamProgram,
+ entryPointParams,
+ (entryPointParams.count + 255) / 256);
+ }
+
+ // Print loss value every 10 iterations.
+ if ((k + 1) % 10 == 0)
+ {
+ queue->waitOnHost();
+ ComPtr<ISlangBlob> blob;
+ gDevice->readBuffer(lossBuffer, 0, sizeof(float), blob.writeRef());
+ printf("Loss after %d iterations: %f\n", k + 1, *(float*)blob->getBufferPointer());
+ }
+ }
+ return SLANG_OK;
+ }
+
+ // Allocate storage for network parameters, including weights, biases, and gradients.
+ void allocateNetworkParameterStorage(
+ std::vector<NetworkParameterAllocation>& paramStorage,
+ size_t& outParamBufferSize,
+ size_t& outGradientOffset,
+ size_t& outGradientTrainingOffset)
+ {
+ outParamBufferSize = 0;
+
+ auto allocRowMajorStorage = [&](size_t size)
+ {
+ size = (size + 63) / 64 * 64;
+ size_t offset = outParamBufferSize;
+ outParamBufferSize += size;
+ return offset;
+ };
+
+ for (int i = 0; i < kLayerCount; i++)
+ {
+ size_t biasSize = getNetworkLayerBiasCount(i) * sizeof(NFloat);
+ NetworkParameterAllocation layerStorage = {};
+ layerStorage.weightsSize = getNetworkLayerWeightCount(i) * sizeof(NFloat);
+ layerStorage.weightsOffset = allocRowMajorStorage(layerStorage.weightsSize);
+ layerStorage.biasSize = biasSize;
+ layerStorage.biasOffset = allocRowMajorStorage(biasSize);
+ paramStorage.push_back(layerStorage);
+ }
+
+ // Alloc storage for weight and bias gradients (row major layout).
+ outGradientOffset = outParamBufferSize;
+ for (int i = 0; i < kLayerCount; i++)
+ {
+ paramStorage[i].weightsGradOffset = allocRowMajorStorage(paramStorage[i].weightsSize);
+ paramStorage[i].biasGradOffset = allocRowMajorStorage(paramStorage[i].biasSize);
+ }
+
+ // Alloc training-optimal storage for weight gradients.
+ outGradientTrainingOffset = outParamBufferSize;
+ for (int i = 0; i < kLayerCount; i++)
+ {
+ // Allocate space for gradients in training-optimal layout.
+ rhi::ConvertCooperativeVectorMatrixDesc matrixDesc = {};
+ matrixDesc.srcComponentType = rhi::CooperativeVectorComponentType::Float16;
+ matrixDesc.srcSize = paramStorage[i].weightsSize;
+ matrixDesc.srcData.hostAddress = nullptr;
+ matrixDesc.srcLayout = rhi::CooperativeVectorMatrixLayout::RowMajor;
+ matrixDesc.srcStride = getNetworkLayerWeightStride(i);
+ matrixDesc.dstComponentType = rhi::CooperativeVectorComponentType::Float16;
+ matrixDesc.dstSize = &paramStorage[i].weightsGradTrainingSize;
+ matrixDesc.dstData.hostAddress = nullptr;
+ matrixDesc.dstLayout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal;
+ matrixDesc.dstStride = 0;
+ matrixDesc.rowCount = kLayerSizes[i + 1];
+ matrixDesc.colCount = kLayerSizes[i];
+ gDevice->convertCooperativeVectorMatrix(&matrixDesc, 1);
+ paramStorage[i].weightsGradTrainingOffset =
+ allocRowMajorStorage(paramStorage[i].weightsGradTrainingSize);
+ }
+ }
+
+ // Dispatch a compute kernel with the given arguments and number of work groups.
+ template<typename Args>
+ void dispatchKernel(Kernel& kernel, Args& args, size_t numWorkGroups)
+ {
+ auto queue = gDevice->getQueue(rhi::QueueType::Graphics);
+ ComPtr<rhi::ICommandEncoder> encoder;
+ queue->createCommandEncoder(encoder.writeRef());
+ {
+ auto computeEncoder = encoder->beginComputePass();
+ auto rootShaderObject = computeEncoder->bindPipeline(kernel.pipeline.get());
+ rootShaderObject->getEntryPoint(0)->setData(rhi::ShaderOffset(), &args, sizeof(args));
+ computeEncoder->dispatchCompute(numWorkGroups, 1, 1);
+ computeEncoder->end();
+ }
+ ComPtr<rhi::ICommandBuffer> commandBuffer;
+ encoder->finish(commandBuffer.writeRef());
+ queue->submit(commandBuffer);
+ }
+
+ // Create a buffer with the specified size and optional initial data.
+ ComPtr<rhi::IBuffer> createBuffer(size_t size, void* initData = nullptr)
+ {
+ rhi::BufferDesc bufferDesc = {};
+ bufferDesc.size = size;
+ bufferDesc.defaultState = rhi::ResourceState::UnorderedAccess;
+ bufferDesc.usage = rhi::BufferUsage::CopySource | rhi::BufferUsage::CopyDestination |
+ rhi::BufferUsage::UnorderedAccess;
+ bufferDesc.memoryType = rhi::MemoryType::DeviceLocal;
+ return gDevice->createBuffer(bufferDesc, initData);
+ }
+
+ void clearBuffer(rhi::IBuffer* buffer, rhi::BufferRange range = rhi::kEntireBuffer)
+ {
+ auto queue = gDevice->getQueue(rhi::QueueType::Graphics);
+ auto encoder = queue->createCommandEncoder();
+ encoder->clearBuffer(buffer, range);
+ auto cmdBuffer = encoder->finish();
+ queue->submit(cmdBuffer);
+ }
+
+ SlangResult loadShaderKernels()
+ {
+ Slang::String path = resourceBase.resolveResource("kernels.slang");
+
+ gSlangSession = createSlangSession(gDevice);
+ gSlangModule = compileShaderModuleFromFile(gSlangSession, path.getBuffer());
+ if (!gSlangModule)
+ return SLANG_FAIL;
+
+ gLearnGradProgram = loadComputeProgram(gSlangModule, "learnGradient");
+ if (!gLearnGradProgram)
+ return SLANG_FAIL;
+
+ gAdjustParamProgram = loadComputeProgram(gSlangModule, "adjustParameters");
+ if (!gAdjustParamProgram)
+ return SLANG_FAIL;
+
+ return SLANG_OK;
+ }
+
+ Kernel loadComputeProgram(slang::IModule* slangModule, char const* entryPointName)
+ {
+ ComPtr<slang::IEntryPoint> entryPoint;
+ slangModule->findEntryPointByName(entryPointName, entryPoint.writeRef());
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ entryPoint->link(linkedProgram.writeRef());
+
+ if (isTestMode())
+ {
+ printEntrypointHashes(1, 1, linkedProgram);
+ }
+
+ Kernel result;
+
+ rhi::ComputePipelineDesc desc;
+ auto program = gDevice->createShaderProgram(linkedProgram);
+ desc.program = program.get();
+ result.program = program;
+ result.pipeline = gDevice->createComputePipeline(desc);
+ return result;
+ }
+
+ static inline unsigned short floatToHalf(float val)
+ {
+ uint32_t x = 0;
+ memcpy(&x, &val, sizeof(float));
+
+ unsigned short bits = (x >> 16) & 0x8000;
+ unsigned short m = (x >> 12) & 0x07ff;
+ unsigned int e = (x >> 23) & 0xff;
+ if (e < 103)
+ return bits;
+ if (e > 142)
+ {
+ bits |= 0x7c00u;
+ bits |= e == 255 && (x & 0x007fffffu);
+ return bits;
+ }
+ if (e < 113)
+ {
+ m |= 0x0800u;
+ bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1);
+ return bits;
+ }
+ bits |= ((e - 112) << 10) | (m >> 1);
+ bits += m & 1;
+ return bits;
+ }
+
+ int getNetworkLayerWeightStride(int i) { return kLayerSizes[i] * sizeof(NFloat); }
+
+ int getNetworkLayerWeightCount(int i) { return kLayerSizes[i] * kLayerSizes[i + 1]; }
+
+ int getNetworkLayerBiasCount(int i) { return kLayerSizes[i + 1]; }
+
+ ComPtr<slang::ISession> createSlangSession(rhi::IDevice* device)
+ {
+ ComPtr<slang::ISession> slangSession = device->getSlangSession();
+ return slangSession;
+ }
+
+ ComPtr<slang::IModule> compileShaderModuleFromFile(
+ slang::ISession* slangSession,
+ char const* filePath)
+ {
+ ComPtr<slang::IModule> slangModule;
+ ComPtr<slang::IBlob> diagnosticBlob;
+ Slang::String path = resourceBase.resolveResource(filePath);
+ slangModule = slangSession->loadModule(path.getBuffer(), diagnosticBlob.writeRef());
+ diagnoseIfNeeded(diagnosticBlob);
+
+ return slangModule;
+ }
+};
+
+int exampleMain(int argc, char** argv)
+{
+ ExampleProgram app;
+ if (SLANG_FAILED(app.execute(argc, argv)))
+ {
+ return -1;
+ }
+ return 0;
+}
diff --git a/examples/mlp-training-coopvec/mlp.slang b/examples/mlp-training-coopvec/mlp.slang
new file mode 100644
index 000000000..44380ad77
--- /dev/null
+++ b/examples/mlp-training-coopvec/mlp.slang
@@ -0,0 +1,73 @@
+module mlp;
+
+import common;
+
+__include mlvec;
+
+// We use Float16 for the CoopVec component type since it is more widely supported.
+//
+static const CoopVecComponentType kComponentType = CoopVecComponentType.Float16;
+
+public struct FeedForwardLayer<int InputSize, int OutputSize>
+{
+ internal void* weights;
+ internal void* weightsGrad;
+ internal void* biases;
+ internal void* biasesGrad;
+
+ public MLVec<OutputSize> eval(MLVec<InputSize> input)
+ {
+ // Compute mul(weights, inputVec) + biases.
+ // `weights` is treated as an OutputSize(row) x InputSize(col) matrix.
+ var output = coopVecMatMulAdd<NFloat, OutputSize>(
+ input.data, kComponentType, // input and format
+ weights, kComponentType, // weights and format
+ biases, kComponentType, // biases and format
+ CoopVecMatrixLayout.RowMajor, // matrix layout
+ false, // transpose matrix? must be `false` since we specified RowMajor.
+ InputSize * sizeof(NFloat)); // matrix stride
+ output = max(output, output * 0.001h); // Leaky ReLU activation
+ return {output};
+ }
+
+ [BackwardDerivativeOf(eval)]
+ public void evalBwd(
+ inout DifferentialPair<MLVec<InputSize>> input,
+ MLVec<OutputSize> resultGrad)
+ {
+ let fwd = eval(input.p);
+
+ // Back-prop resultGrad through activation.
+ [ForceUnroll]
+ for (int i = 0; i < OutputSize; i++)
+ {
+ if (fwd.data[i] < 0.0)
+ resultGrad.data[i] *= 0.01h;
+ }
+
+ // Back-prop gradients to the weights matrix.
+ coopVecOuterProductAccumulate(
+ resultGrad.data,
+ input.p.data,
+ weightsGrad,
+ 0, // matrixStride, ignored since layout is TrainingOptimal
+ CoopVecMatrixLayout.TrainingOptimal, // matrix layout, must be TrainingOptimal.
+ kComponentType);
+
+ // Back-prop gradients to the biases vector.
+ coopVecReduceSumAccumulate(resultGrad.data, (void*)biasesGrad);
+
+ // Back-prop gradients to the input vector by computing
+ // mul(transpose(weights), resultGrad).
+ // By specifying the matrix layout as ColumnMajor, we can
+ // achieve the effect of transposing the weights matrix.
+ let dInput = coopVecMatMul<NFloat, InputSize>(
+ resultGrad.data, kComponentType,
+ weights, kComponentType,
+ CoopVecMatrixLayout.ColumnMajor,
+ false, // transpose, must be `false` since we specified ColumnMajor.
+ InputSize * sizeof(NFloat));
+
+ input = {input.p, {dInput}};
+ }
+}
diff --git a/examples/mlp-training-coopvec/mlvec.slang b/examples/mlp-training-coopvec/mlvec.slang
new file mode 100644
index 000000000..ce7ce8352
--- /dev/null
+++ b/examples/mlp-training-coopvec/mlvec.slang
@@ -0,0 +1,63 @@
+implementing mlp;
+
+// A wrapper of CoopVec<T> to allow it being used in differentiable context.
+//
+public struct MLVec<int N> : IDifferentiable
+{
+ public CoopVec<NFloat, N> data;
+ public typealias Differential = MLVec<N>;
+
+ public static MLVec<N> fromArray(NFloat[N] values)
+ {
+ MLVec<N> result;
+ [ForceUnroll]
+ for (int i = 0; i < N; i++)
+ result.data[i] = values[i];
+ return result;
+ }
+
+ internal static NFloat[N] coopVecToArray(CoopVec<NFloat, N> v)
+ {
+ NFloat[N] arr;
+ [ForceUnroll]
+ for (int i = 0; i < N; i++)
+ arr[i] = v[i];
+ return arr;
+ }
+
+ [BackwardDerivativeOf(fromArray)]
+ internal static void fromArrayBwd(inout DifferentialPair<NFloat[N]> values, MLVec<N> dResult)
+ {
+ values = diffPair(values.p, coopVecToArray(dResult.data));
+ }
+
+ internal static NFloat[N] toArray(MLVec<N> vec)
+ {
+ return coopVecToArray(vec.data);
+ }
+
+ [BackwardDerivativeOf(toArray)]
+ internal static void toArrayBwd(inout DifferentialPair<MLVec<N>> vec, NFloat[N] dResult)
+ {
+ vec = diffPair(vec.p, MLVec<N>.fromArray(dResult));
+ }
+
+ [Differentiable]
+ public NFloat[N] toArray()
+ {
+ return toArray(this);
+ }
+
+ public override static Differential dadd(Differential d0, Differential d1)
+ {
+ return {d0.data + d1.data};
+ }
+ public override static Differential dmul<U:__BuiltinRealType>(U s, Differential d)
+ {
+ return {d.data * __realCast<NFloat>(s)};
+ }
+ public override static Differential dzero()
+ {
+ return {};
+ }
+}
diff --git a/examples/mlp-training-coopvec/network.slang b/examples/mlp-training-coopvec/network.slang
new file mode 100644
index 000000000..5741487c4
--- /dev/null
+++ b/examples/mlp-training-coopvec/network.slang
@@ -0,0 +1,58 @@
+module network;
+
+import common;
+import mlp;
+
+public struct MyNetwork
+{
+ public FeedForwardLayer<4, 16> layer1;
+ public FeedForwardLayer<16, 4> layer2;
+
+ [Differentiable]
+ internal MLVec<4> encodeInput(NFloat x, NFloat y)
+ {
+ return MLVec<4>.fromArray({
+ x,
+ y,
+ x*x,
+ y*y,
+ });
+ }
+
+ [Differentiable]
+ internal MLVec<4> _eval(NFloat x, NFloat y)
+ {
+ let encoding = encodeInput(x, y);
+ let layer1Output = layer1.eval(encoding);
+ let leyer2Output = layer2.eval(layer1Output);
+ return leyer2Output;
+ }
+
+ [Differentiable]
+ public half4 eval(no_diff NFloat x, no_diff NFloat y)
+ {
+ let mlv = _eval(x, y);
+ let arr = mlv.toArray();
+ return half4(arr[0], arr[1], arr[2], arr[3]);
+ }
+}
+
+[Differentiable]
+public half loss(MyNetwork* network, no_diff half x, no_diff half y)
+{
+ let networkResult = network.eval(x, y);
+ let gt = no_diff groundtruth(x, y);
+ let diff = networkResult - gt;
+ return dot(diff, diff);
+}
+
+public half4 groundtruth(half x, half y)
+{
+ return {
+ (x + y) / (1 + y * y),
+ 2 * x + y,
+ 0.5 * x * x + 1.2 * y,
+ x + 0.5 * y * y,
+ };
+}
+
diff --git a/examples/mlp-training/README.md b/examples/mlp-training/README.md
new file mode 100644
index 000000000..c5266bbf1
--- /dev/null
+++ b/examples/mlp-training/README.md
@@ -0,0 +1,6 @@
+Slang "MLP-Training" Example
+==========================
+
+This example shows how to use the Slang to train a feed-forward neural network
+using automatic differentiation. Also see the "MLP-Training-CoopVec" example
+to see how to use the cooperative vector intrinsics to speedup training. \ No newline at end of file
diff --git a/examples/mlp-training/adam.slang b/examples/mlp-training/adam.slang
new file mode 100644
index 000000000..8f9b15f01
--- /dev/null
+++ b/examples/mlp-training/adam.slang
@@ -0,0 +1,38 @@
+module adam;
+
+import mlp_sw;
+import common;
+
+public struct AdamState
+{
+ internal NFloat mean;
+ internal NFloat variance;
+ internal int iteration;
+}
+
+public struct AdamOptimizer
+{
+ // Adam parameters
+ public static const NFloat beta1 = 0.9h;
+ public static const NFloat beta2 = 0.999h;
+ public static const NFloat epsilon = 1e-7h;
+ public static const NFloat learningRate = 0.01h;
+
+ public static void step(inout AdamState state, inout NFloat param, inout NFloat grad)
+ {
+ state.iteration++;
+ if (isinf(grad))
+ {
+ if (grad > 0)
+ grad = 10000.0h;
+ else
+ grad = -10000.0h;
+ }
+ state.mean = beta1 * state.mean + (NFloat(1.f) - beta1) * grad;
+ state.variance = beta2 * state.variance + (NFloat(1.f) - beta2) * grad * grad;
+ NFloat meanHat = state.mean / (NFloat(1.f) - pow(beta1, NFloat(state.iteration)));
+ NFloat varianceHat = state.variance / (NFloat(1.f) - pow(beta2, NFloat(state.iteration)));
+ param -= learningRate * meanHat / (sqrt(max(NFloat(0.f), varianceHat) + epsilon));
+ grad = NFloat(0.f);
+ }
+}
diff --git a/examples/mlp-training/common.slang b/examples/mlp-training/common.slang
new file mode 100644
index 000000000..92dc3b563
--- /dev/null
+++ b/examples/mlp-training/common.slang
@@ -0,0 +1 @@
+public typealias NFloat = half; \ No newline at end of file
diff --git a/examples/mlp-training/kernels.slang b/examples/mlp-training/kernels.slang
new file mode 100644
index 000000000..5be076879
--- /dev/null
+++ b/examples/mlp-training/kernels.slang
@@ -0,0 +1,41 @@
+module kernels;
+
+import common;
+import mlp_sw;
+import network;
+import adam;
+
+[numthreads(256, 1, 1)]
+[require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic)]
+void learnGradient(
+ uint32_t tid : SV_DispatchThreadID,
+ uniform MyNetwork* network,
+ uniform Atomic<uint32_t>* lossBuffer,
+ uniform float2* inputs,
+ uniform uint32_t count)
+{
+ if (tid >= count)
+ return;
+
+ var input = (half2)inputs[tid];
+ bwd_diff(loss)(network, input.x, input.y, 1.0h);
+ let thisLoss = (float)loss(network, input.x, input.y);
+ let maxLoss = WaveActiveMax(thisLoss);
+ if (WaveIsFirstLane())
+ {
+ lossBuffer.max(bit_cast<uint32_t>(maxLoss));
+ }
+}
+
+[numthreads(256, 1, 1)]
+void adjustParameters(uint32_t tid : SV_DispatchThreadID, uniform AdamState* states, uniform NFloat* params, uniform NFloat* gradients, uniform uint32_t count)
+{
+ if (tid >= count)
+ return;
+ if (isnan(gradients[tid]))
+ {
+ gradients[tid] = 0.0h;
+ return;
+ }
+ AdamOptimizer::step(states[tid], params[tid], gradients[tid]);
+} \ No newline at end of file
diff --git a/examples/mlp-training/mlp-training.cpp b/examples/mlp-training/mlp-training.cpp
new file mode 100644
index 000000000..38090578f
--- /dev/null
+++ b/examples/mlp-training/mlp-training.cpp
@@ -0,0 +1,389 @@
+// In this example, we implement a simple multi-layer perceptron (MLP) training loop on
+// Vulkan (through slang-rhi). See also the mlp-training-coopvec example, which
+// implements the same MLP training loop using cooperative vector intrinsics for better
+// performance.
+//
+// The simple MLP is trained to approximate a polynomial expression.
+// The network contains one hidden layer with 16 neurons. It takes 4 inputs and produces 4
+// outputs.
+
+#include "core/slang-basic.h"
+#include "examples/example-base/example-base.h"
+#include "external/slang-rhi/include/slang-rhi.h"
+#include "slang-com-ptr.h"
+#include "slang.h"
+
+#include <string>
+
+using Slang::ComPtr;
+
+static const ExampleResources resourceBase("mlp-training");
+
+typedef uint16_t NFloat;
+
+static const int kLayerSizes[] = {4, 16, 4};
+static const int kLayerCount = sizeof(kLayerSizes) / sizeof(int) - 1;
+
+int getNetworkLayerWeightStride(int i)
+{
+ return kLayerSizes[i] * sizeof(NFloat);
+}
+
+int getNetworkLayerWeightCount(int i)
+{
+ return kLayerSizes[i] * kLayerSizes[i + 1];
+}
+
+int getNetworkLayerBiasCount(int i)
+{
+ return kLayerSizes[i + 1];
+}
+
+struct Kernel
+{
+ ComPtr<rhi::IShaderProgram> program;
+ ComPtr<rhi::IComputePipeline> pipeline;
+ operator bool() { return program && pipeline; }
+};
+
+struct ClearBufferParams
+{
+ rhi::DeviceAddress buffer;
+ uint32_t count;
+};
+
+struct LearnGradParams
+{
+ rhi::DeviceAddress networkBuffer;
+ rhi::DeviceAddress lossBuffer;
+ rhi::DeviceAddress inputs;
+ uint32_t count;
+};
+
+struct AdjustParamsParams
+{
+ rhi::DeviceAddress adamStates;
+ rhi::DeviceAddress params;
+ rhi::DeviceAddress gradients;
+ uint32_t count;
+};
+
+struct ExampleProgram : public TestBase
+{
+ ComPtr<rhi::IDevice> gDevice;
+
+ ComPtr<slang::ISession> gSlangSession;
+ ComPtr<slang::IModule> gSlangModule;
+ Kernel gLearnGradProgram;
+ Kernel gAdjustParamProgram;
+
+ // Sub-allocated buffer range for each network layer's parameters (weights, biases, gradients).
+ //
+ struct NetworkParameterAllocation
+ {
+ size_t weightsOffset;
+ size_t weightsSize;
+ size_t biasOffset;
+ size_t biasSize;
+ size_t weightsGradOffset;
+ size_t biasGradOffset;
+ };
+
+ SlangResult execute(int argc, char* argv[])
+ {
+ parseOption(argc, argv);
+
+ rhi::DeviceDesc deviceDesc;
+ deviceDesc.slang.targetProfile = "spirv_1_6";
+ deviceDesc.deviceType = rhi::DeviceType::Vulkan;
+
+ gDevice = rhi::getRHI()->createDevice(deviceDesc);
+ if (!gDevice)
+ return SLANG_FAIL;
+
+ SLANG_RETURN_ON_FAIL(loadShaderKernels());
+
+ // Create a buffer to hold all network parameters (weights, biases, gradients).
+ // This buffer is arranged as following:
+ // (segment 1): | weights0 | bias0 | weights1 | bias1 | ... | weightsN | biasN |
+ // (segment 2): | weightsGrad0 | biasGrad0 | weightsGrad1 | biasGrad1 | ... |
+ //
+ // Where the first segment contains all weights and biases for each layer in row-major
+ // layout. The second segment contains gradients for weights and biases in row-major layout.
+
+ // Total size of all network parameters.
+ size_t paramBufferSize;
+
+ // Offset for the second segment, where gradients for weights and biases in row-major layout
+ // start.
+ size_t gradientOffset;
+
+ // Sub-allocated weight/Bias offsets for each layer.
+ std::vector<NetworkParameterAllocation> layerAllocations;
+ allocateNetworkParameterStorage(layerAllocations, paramBufferSize, gradientOffset);
+
+ std::vector<uint16_t> initParams;
+ srand(1072);
+ for (int i = 0; i < paramBufferSize / sizeof(NFloat); i++)
+ {
+ if (i < gradientOffset / sizeof(NFloat))
+ {
+ float v = rand() / (float)RAND_MAX;
+ v = v * 2.0f - 1.0f; // Normalize to [-1, 1]
+ initParams.push_back(floatToHalf(v));
+ }
+ else
+ {
+ // Initialize gradients to zero.
+ initParams.push_back(0);
+ }
+ }
+ auto networkParamsBuffer = createBuffer(paramBufferSize, initParams.data());
+
+ static const size_t kAdamStateSize = sizeof(NFloat) * 2 + sizeof(int32_t);
+ auto adamStateBuffer = createBuffer(initParams.size() * kAdamStateSize);
+ clearBuffer(adamStateBuffer);
+
+ std::vector<uint64_t> networkConstantBufferData;
+ auto paramBufferAddr = networkParamsBuffer->getDeviceAddress();
+ for (int i = 0; i < kLayerCount; i++)
+ {
+ networkConstantBufferData.push_back(
+ paramBufferAddr + layerAllocations[i].weightsOffset);
+ networkConstantBufferData.push_back(
+ paramBufferAddr + layerAllocations[i].weightsGradOffset);
+ networkConstantBufferData.push_back(paramBufferAddr + layerAllocations[i].biasOffset);
+ networkConstantBufferData.push_back(
+ paramBufferAddr + layerAllocations[i].biasGradOffset);
+ }
+ auto networkConstantBuffer = createBuffer(
+ networkConstantBufferData.size() * sizeof(uint64_t),
+ networkConstantBufferData.data());
+
+ static const int inputCount = 32;
+ std::vector<float> inputBufferData;
+ for (int i = 0; i < inputCount; i++)
+ {
+ inputBufferData.push_back((float)rand() / RAND_MAX);
+ }
+ auto inputBuffer = createBuffer(inputCount * sizeof(float), inputBufferData.data());
+
+ // Create buffer for receiving current loss value.
+ auto lossBuffer = createBuffer(sizeof(uint64_t));
+
+ auto queue = gDevice->getQueue(rhi::QueueType::Graphics);
+
+ for (int k = 0; k < 1000; k++)
+ {
+ clearBuffer(lossBuffer);
+
+ // Compute gradients.
+ {
+ LearnGradParams entryPointParams = {};
+ entryPointParams.inputs = inputBuffer->getDeviceAddress();
+ entryPointParams.count = inputCount / 2;
+ entryPointParams.lossBuffer = lossBuffer->getDeviceAddress();
+ entryPointParams.networkBuffer = networkConstantBuffer->getDeviceAddress();
+ dispatchKernel(
+ gLearnGradProgram,
+ entryPointParams,
+ (entryPointParams.count + 255) / 256);
+ }
+ // Adjust parameters in row-major buffer (adam optimize).
+ {
+ AdjustParamsParams entryPointParams = {};
+ entryPointParams.adamStates = adamStateBuffer->getDeviceAddress();
+ entryPointParams.params = networkParamsBuffer->getDeviceAddress();
+ entryPointParams.count = (paramBufferSize - gradientOffset) / sizeof(NFloat);
+ entryPointParams.gradients =
+ networkParamsBuffer->getDeviceAddress() + gradientOffset;
+ dispatchKernel(
+ gAdjustParamProgram,
+ entryPointParams,
+ (entryPointParams.count + 255) / 256);
+ }
+ if ((k + 1) % 10 == 0)
+ {
+ queue->waitOnHost();
+ ComPtr<ISlangBlob> blob;
+ gDevice->readBuffer(lossBuffer, 0, sizeof(float), blob.writeRef());
+ printf("Loss after %d iterations: %f\n", k + 1, *(float*)blob->getBufferPointer());
+ }
+ }
+ return SLANG_OK;
+ }
+
+ // Allocate storage for network parameters, including weights, biases, and gradients.
+ void allocateNetworkParameterStorage(
+ std::vector<NetworkParameterAllocation>& paramStorage,
+ size_t& outParamBufferSize,
+ size_t& outGradientOffset)
+ {
+ outParamBufferSize = 0;
+
+ auto allocRowMajorStorage = [&](size_t size)
+ {
+ size = (size + 63) / 64 * 64;
+ size_t offset = outParamBufferSize;
+ outParamBufferSize += size;
+ return offset;
+ };
+
+ for (int i = 0; i < kLayerCount; i++)
+ {
+ size_t biasSize = getNetworkLayerBiasCount(i) * sizeof(NFloat);
+ NetworkParameterAllocation layer = {};
+ layer.weightsSize = getNetworkLayerWeightCount(i) * sizeof(NFloat);
+ layer.weightsOffset = allocRowMajorStorage(layer.weightsSize);
+ layer.biasSize = biasSize;
+ layer.biasOffset = allocRowMajorStorage(biasSize);
+ paramStorage.push_back(layer);
+ }
+
+ // Alloc storage for gradients.
+ outGradientOffset = outParamBufferSize;
+ for (int i = 0; i < kLayerCount; i++)
+ {
+ paramStorage[i].weightsGradOffset = allocRowMajorStorage(paramStorage[i].weightsSize);
+ paramStorage[i].biasGradOffset = allocRowMajorStorage(paramStorage[i].biasSize);
+ }
+ }
+
+ template<typename Args>
+ void dispatchKernel(Kernel& kernel, Args& args, size_t numWorkGroups)
+ {
+ auto queue = gDevice->getQueue(rhi::QueueType::Graphics);
+ ComPtr<rhi::ICommandEncoder> encoder;
+ queue->createCommandEncoder(encoder.writeRef());
+ {
+ auto computeEncoder = encoder->beginComputePass();
+ auto rootShaderObject = computeEncoder->bindPipeline(kernel.pipeline.get());
+ rootShaderObject->getEntryPoint(0)->setData(rhi::ShaderOffset(), &args, sizeof(args));
+ computeEncoder->dispatchCompute(numWorkGroups, 1, 1);
+ computeEncoder->end();
+ }
+ ComPtr<rhi::ICommandBuffer> commandBuffer;
+ encoder->finish(commandBuffer.writeRef());
+ queue->submit(commandBuffer);
+ }
+
+ // Create a buffer with the specified size and optional initial data.
+ ComPtr<rhi::IBuffer> createBuffer(size_t size, void* initData = nullptr)
+ {
+ rhi::BufferDesc bufferDesc = {};
+ bufferDesc.size = size;
+ bufferDesc.defaultState = rhi::ResourceState::UnorderedAccess;
+ bufferDesc.usage = rhi::BufferUsage::CopySource | rhi::BufferUsage::CopyDestination |
+ rhi::BufferUsage::UnorderedAccess;
+ bufferDesc.memoryType = rhi::MemoryType::DeviceLocal;
+ return gDevice->createBuffer(bufferDesc, initData);
+ }
+
+ void clearBuffer(rhi::IBuffer* buffer)
+ {
+ auto queue = gDevice->getQueue(rhi::QueueType::Graphics);
+ auto encoder = queue->createCommandEncoder();
+ encoder->clearBuffer(buffer);
+ auto cmdBuffer = encoder->finish();
+ queue->submit(cmdBuffer);
+ }
+
+ Kernel loadComputeProgram(slang::IModule* slangModule, char const* entryPointName)
+ {
+ ComPtr<slang::IEntryPoint> entryPoint;
+ slangModule->findEntryPointByName(entryPointName, entryPoint.writeRef());
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ entryPoint->link(linkedProgram.writeRef());
+
+ if (isTestMode())
+ {
+ printEntrypointHashes(1, 1, linkedProgram);
+ }
+
+ Kernel result;
+
+ rhi::ComputePipelineDesc desc;
+ auto program = gDevice->createShaderProgram(linkedProgram);
+ desc.program = program.get();
+ result.program = program;
+ result.pipeline = gDevice->createComputePipeline(desc);
+ return result;
+ }
+
+ inline unsigned short floatToHalf(float val)
+ {
+ uint32_t x = 0;
+ memcpy(&x, &val, sizeof(float));
+
+ unsigned short bits = (x >> 16) & 0x8000;
+ unsigned short m = (x >> 12) & 0x07ff;
+ unsigned int e = (x >> 23) & 0xff;
+ if (e < 103)
+ return bits;
+ if (e > 142)
+ {
+ bits |= 0x7c00u;
+ bits |= e == 255 && (x & 0x007fffffu);
+ return bits;
+ }
+ if (e < 113)
+ {
+ m |= 0x0800u;
+ bits |= (m >> (114 - e)) + ((m >> (113 - e)) & 1);
+ return bits;
+ }
+ bits |= ((e - 112) << 10) | (m >> 1);
+ bits += m & 1;
+ return bits;
+ }
+
+ ComPtr<slang::ISession> createSlangSession(rhi::IDevice* device)
+ {
+ ComPtr<slang::ISession> slangSession = device->getSlangSession();
+ return slangSession;
+ }
+
+ ComPtr<slang::IModule> compileShaderModuleFromFile(
+ slang::ISession* slangSession,
+ char const* filePath)
+ {
+ ComPtr<slang::IModule> slangModule;
+ ComPtr<slang::IBlob> diagnosticBlob;
+ Slang::String path = resourceBase.resolveResource(filePath);
+ slangModule = slangSession->loadModule(path.getBuffer(), diagnosticBlob.writeRef());
+ diagnoseIfNeeded(diagnosticBlob);
+
+ return slangModule;
+ }
+
+ SlangResult loadShaderKernels()
+ {
+ Slang::String path = resourceBase.resolveResource("kernels.slang");
+
+ gSlangSession = createSlangSession(gDevice);
+ gSlangModule = compileShaderModuleFromFile(gSlangSession, path.getBuffer());
+ if (!gSlangModule)
+ return SLANG_FAIL;
+
+ gLearnGradProgram = loadComputeProgram(gSlangModule, "learnGradient");
+ if (!gLearnGradProgram)
+ return SLANG_FAIL;
+
+ gAdjustParamProgram = loadComputeProgram(gSlangModule, "adjustParameters");
+ if (!gAdjustParamProgram)
+ return SLANG_FAIL;
+
+ return SLANG_OK;
+ }
+};
+
+int exampleMain(int argc, char** argv)
+{
+ ExampleProgram app;
+ if (SLANG_FAILED(app.execute(argc, argv)))
+ {
+ return -1;
+ }
+ return 0;
+}
diff --git a/examples/mlp-training/mlp_sw.slang b/examples/mlp-training/mlp_sw.slang
new file mode 100644
index 000000000..1e222b99f
--- /dev/null
+++ b/examples/mlp-training/mlp_sw.slang
@@ -0,0 +1,59 @@
+module mlp_sw;
+
+import common;
+
+__include mlvec_sw;
+
+public struct FeedForwardLayer<int InputSize, int OutputSize>
+{
+ public NFloat* weights;
+ public NFloat* weightsGrad;
+ public NFloat* biases;
+ public NFloat* biasesGrad;
+
+ [BackwardDerivative(evalBwd)]
+ public MLVec<OutputSize> eval(MLVec<InputSize> input)
+ {
+ var output = matMulAdd<OutputSize>(
+ input,
+ weights,
+ biases);
+ // ReLU activation
+ for (int i = 0; i < OutputSize; i++)
+ if (output.data[i] < 0.0)
+ output.data[i] *= 0.001h;
+ return output;
+ }
+
+ public void evalBwd(
+ inout DifferentialPair<MLVec<InputSize>> input,
+ MLVec<OutputSize> resultGrad)
+ {
+ let fwd = eval(input.p);
+
+ // Back-prop resultGrad through activation.
+ for (int i = 0; i < OutputSize; i++)
+ {
+ if (fwd.data[i] < 0.0)
+ resultGrad.data[i] *= 0.01h;
+ }
+
+ // Back-prop gradients to the weights matrix.
+ outerProductAccumulate(
+ resultGrad,
+ input.p,
+ weightsGrad);
+
+ // Back-prop gradients to the biases vector.
+ for (int i = 0; i < OutputSize; i++)
+ {
+ NFloat originalValue;
+ InterlockedAddF16Emulated(biasesGrad + i, resultGrad.data[i], originalValue);
+ }
+
+ // Back-prop gradients to the input vector.
+ let dInput = matMulTransposed<InputSize>(resultGrad, weights);
+
+ input = {input.p, dInput};
+ }
+}
diff --git a/examples/mlp-training/mlvec_sw.slang b/examples/mlp-training/mlvec_sw.slang
new file mode 100644
index 000000000..695755706
--- /dev/null
+++ b/examples/mlp-training/mlvec_sw.slang
@@ -0,0 +1,64 @@
+implementing mlp_sw;
+
+public struct MLVec<int N> : IDifferentiable
+{
+ public NFloat data[N];
+
+ [Differentiable]
+ public NFloat[N] toArray()
+ {
+ return data;
+ }
+
+ [Differentiable]
+ public static MLVec<N> fromArray(NFloat[N] values)
+ {
+ MLVec<N> result;
+ [ForceUnroll]
+ for (int i = 0; i < N; i++)
+ result.data[i] = values[i];
+ return result;
+ }
+}
+
+MLVec<OutputSize> matMulAdd<int OutputSize, int InputSize>(MLVec<InputSize> input, NFloat* matrix, NFloat* bias)
+{
+ let getMatElem = (int row, int col) => matrix[row*InputSize + col];
+ let getBias = (int idx) => bias[idx];
+ MLVec<OutputSize> result = {};
+ for (int i = 0; i < OutputSize; i++)
+ {
+ NFloat r = getBias(i);
+ for (int j = 0; j < InputSize; j++)
+ r += getMatElem(i, j) * input.data[j];
+ result.data[i] = r;
+ }
+ return result;
+}
+
+MLVec<OutputSize> matMulTransposed<int OutputSize, int InputSize>(MLVec<InputSize> input, NFloat* matrix)
+{
+ let getMatElem = (int row, int col) => matrix[col*OutputSize + row];
+ MLVec<OutputSize> result = {};
+ for (int i = 0; i < OutputSize; i++)
+ {
+ NFloat r = {};
+ for (int j = 0; j < InputSize; j++)
+ r += getMatElem(i, j) * input.data[j];
+ result.data[i] = r;
+ }
+ return result;
+}
+
+void outerProductAccumulate<int M, int N>(MLVec<M> v0, MLVec<N> v1, NFloat* matrix)
+{
+ for (int i = 0; i < M; i++)
+ {
+ for (int j = 0; j < N; j++)
+ {
+ let elem = v0.data[i] * v1.data[j];
+ half original;
+ InterlockedAddF16Emulated(matrix + (i*N + j), elem, original);
+ }
+ }
+}
diff --git a/examples/mlp-training/network.slang b/examples/mlp-training/network.slang
new file mode 100644
index 000000000..a48820f11
--- /dev/null
+++ b/examples/mlp-training/network.slang
@@ -0,0 +1,59 @@
+module network;
+
+import common;
+import mlp_sw;
+
+public struct MyNetwork
+{
+ public FeedForwardLayer<4, 16> layer1;
+ public FeedForwardLayer<16, 4> layer2;
+
+ [Differentiable]
+ internal MLVec<4> encodeInput(NFloat x, NFloat y)
+ {
+ return MLVec<4>.fromArray({
+ x,
+ y,
+ x*x,
+ y*y,
+ });
+ }
+
+ [Differentiable]
+ internal MLVec<4> _eval(NFloat x, NFloat y)
+ {
+ let encoding = encodeInput(x, y);
+ let layer1Output = layer1.eval(encoding);
+ let leyer2Output = layer2.eval(layer1Output);
+ return leyer2Output;
+ }
+
+ [Differentiable]
+ public half4 eval(no_diff NFloat x, no_diff NFloat y)
+ {
+ let mlv = _eval(x, y);
+ let arr = mlv.toArray();
+ return half4(arr[0], arr[1], arr[2], arr[3]);
+ }
+}
+
+[Differentiable]
+public half loss(MyNetwork* network, no_diff half x, no_diff half y)
+{
+ let networkResult = network.eval(x, y);
+ let gt = no_diff groundtruth(x, y);
+ let diff = networkResult - gt;
+
+ return dot(diff, diff);
+}
+
+public half4 groundtruth(half x, half y)
+{
+ return {
+ (x + y) / (1 + y * y),
+ 2 * x + y,
+ 0.5 * x * x + 1.2 * y,
+ x + 0.5 * y * y,
+ };
+}
+
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index deaeae439..2e56c1082 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -1809,12 +1809,12 @@ extension Ptr<void>
__init(NativeString nativeStr) { this = nativeStr.getBuffer(); }
__generic<T, let addrSpace : uint64_t>
- __intrinsic_op(0)
+ __intrinsic_op($(kIROp_BitCast))
__implicit_conversion($(kConversionCost_PtrToVoidPtr))
__init(Ptr<T, addrSpace> ptr);
__generic<T>
- __intrinsic_op(0)
+ __intrinsic_op($(kIROp_BitCast))
__implicit_conversion($(kConversionCost_PtrToVoidPtr))
__init(NativeRef<T> ptr);
}
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 32d7ea824..38f274984 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -26610,16 +26610,24 @@ CoopVec<T, M> coopVecMatMulPacked(
}
}
-/// Multiply a cooperative vector with a matrix.
-/// @param input The input cooperative vector to multiply with the matrix.
+/// Multiply a matrix with a cooperative vector. Given a M-row by K-col `matrix`, and a K-element column vector `input`, computes `matrix * input`, and
+/// returns a M-element vector.
+/// @param input The K-element input cooperative vector to multiply with the matrix.
/// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as 8-bit integers, 16-bit floats, etc).
-/// @param matrix The matrix to multiply with the input vector.
+/// @param matrix The M-by-K matrix to multiply with the input vector.
/// @param matrixOffset Byte offset into the matrix buffer.
/// @param matrixInterpretation Specifies how to interpret the values in the matrix (e.g. as 8-bit integers, 16-bit floats, etc).
/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
/// @param transpose Whether to transpose the matrix before multiplication.
/// @param matrixStride The stride in bytes between rows/columns of the matrix.
/// @return A new cooperative vector containing the result of the matrix multiplication.
+/// @remarks Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported.
+/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to
+/// find out which interpretations are supported.
+///
+/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`.
+/// Not all component types support transposing.
+/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored.
[ForceInline]
[require(cooperative_vector)]
[require(hlsl_coopvec_poc)]
@@ -26650,7 +26658,9 @@ CoopVec<T, M> coopVecMatMul(
matrixStride);
}
-/// Multiply a cooperative vector with a matrix and add a bias vector.
+/// Multiply a matrix with a cooperative vector and add a bias vector to the result.
+/// Given a M-row by K-col `matrix`, a K-element column vector `input`, and a M-element vector `bias`, computes `matrix * input + bias`, and
+/// returns a M-element vector.
/// @param input The input cooperative vector to multiply with the matrix.
/// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values).
/// @param k The number of columns in the matrix.
@@ -26667,6 +26677,14 @@ CoopVec<T, M> coopVecMatMul(
/// @remarks Unlike coopVecMatMulAdd, this function supports packed input interpretations where multiple values
/// can be packed into each element of the input vector. The k parameter specifies the actual number of
/// values to use from the packed input.
+///
+/// Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported.
+/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to
+/// find out which interpretations are supported.
+///
+/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`.
+/// Not all component types support transposing.
+/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored.
[ForceInline]
[require(cooperative_vector)]
[require(hlsl_coopvec_poc)]
@@ -26804,7 +26822,9 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l
}
}
-/// Multiply a cooperative vector with a matrix and add a bias vector.
+/// Multiply a matrix with a cooperative vector and add a bias vector.
+/// Given a M-row by K-col `matrix`, a K-element column vector `input`, and a M-element vector `bias`, computes `matrix * input + bias`, and
+/// returns a M-element vector.
/// @param input The input cooperative vector to multiply with the matrix.
/// @param inputInterpretation Specifies how to interpret the values in the input vector.
/// @param matrix The matrix buffer to multiply with.
@@ -26817,6 +26837,13 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l
/// @param transpose Whether to transpose the matrix before multiplication.
/// @param matrixStride The stride between matrix rows/columns in bytes.
/// @return A new cooperative vector containing the result of the matrix multiplication plus bias.
+/// @remarks Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported.
+/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to
+/// find out which interpretations are supported.
+///
+/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`.
+/// Not all component types support transposing.
+/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored.
[ForceInline]
[require(cooperative_vector)]
[require(hlsl_coopvec_poc)]
@@ -26862,7 +26889,9 @@ ${{{{
if(buffer.isRW)
{
}}}}
-/// Accumulate the outer product of two cooperative vectors into a matrix.
+/// Atomically accumulates the outer product of two cooperative vectors into a matrix. Given an M-element vector `a`, and an N-element vector `b`,
+/// compute the outer product of `a` and `b`, forming a M-row by N-col matrix. The elements in the matrix is then atomically accumulated
+/// to memory location represented by `matrix`.
/// @param a The first cooperative vector.
/// @param b The second cooperative vector.
/// @param matrix The matrix buffer to accumulate the result into.
@@ -26870,6 +26899,21 @@ if(buffer.isRW)
/// @param matrixStride The stride between matrix rows/columns in bytes.
/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
/// @param matrixInterpretation Specifies how to interpret the values in the matrix.
+/// @remarks On current hardware, `memoryLayout` must be `TrainingOptimal`.
+///
+/// When `memoryLayout` is `RowMajor`, this function is equivalent to:
+///
+/// ```
+/// uint8_t* matrixPtr = matrix + matrixOffset;
+/// for (int i = 0; i < M; i++)
+/// {
+/// for (int j = 0; j < N; j++)
+/// {
+/// let elem = a[i] * b[j];
+/// atomicAdd(matrixPtr + i * matrixStride + j * sizeof(T), elem);
+/// }
+/// }
+/// ```
[require(cooperative_vector)]
[require(hlsl_coopvec_poc)]
[require(optix_coopvec)]
@@ -26959,10 +27003,15 @@ void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let
}
}
-/// Accumulate the sum of a cooperative vector into a buffer at the specified offset.
+/// Atomically accumulates the elements a cooperative vector into a buffer at the specified offset.
/// @param v The cooperative vector to sum.
/// @param buffer The buffer to accumulate the sum into.
/// @param offset Byte offset into the buffer.
+/// @remarks This function is equivalent to:
+/// ```
+/// for (int i = 0; i < N; i++)
+/// atomicAdd(dest[i], v[i]);
+/// ```
[require(cooperative_vector)]
[require(hlsl_coopvec_poc)]
[require(optix_coopvec)]
@@ -27015,20 +27064,6 @@ static const struct {
for(auto buffer : kStructuredBufferCases_) {
}}}}
-/// Multiply a cooperative vector with a matrix and return the result.
-/// @param input The input cooperative vector to multiply with the matrix.
-/// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values).
-/// @param k The number of columns in the matrix.
-/// @param matrix The matrix buffer to multiply with.
-/// @param matrixOffset Byte offset into the matrix buffer.
-/// @param matrixInterpretation Specifies how to interpret the values in the matrix.
-/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
-/// @param transpose Whether to transpose the matrix before multiplication.
-/// @param matrixStride The stride between matrix rows/columns in bytes.
-/// @return A new cooperative vector containing the result of the matrix multiplication.
-/// @remarks Unlike coopVecMatMul, this function supports packed input interpretations where multiple values
-/// can be packed into each element of the input vector. The k parameter specifies the actual number of
-/// values to use from the packed input.
[require(spirv, cooperative_vector)]
__generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType,IgnoredBufferElementType>
CoopVec<T, M> coopVecMatMulPacked(
@@ -27288,6 +27323,185 @@ void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int, U, let
}
}
+// Pointer overloads for coopvector operations.
+
+[require(spirv, cooperative_vector)]
+__generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType>
+CoopVec<T, M> coopVecMatMulPacked(
+ CoopVec<U, PackedK> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ constexpr int k,
+ void* matrixPtr,
+ constexpr CoopVecComponentType matrixInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK
+ , "for non-packed inputInterpretation values k must be equal to the input vector length");
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK
+ , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor");
+ __target_switch
+ {
+ case spirv:
+ let m : int32_t = M;
+ let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation);
+ let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation);
+ let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout);
+ int operands = 0; // NoneKHR
+ let zero = 0;
+ let cvtMatPtr = (Ptr<T[]>)matrixPtr;
+ if (__isSignedInt<T>())
+ {
+ operands |= 0x08; // MatrixResultSignedComponentsKHR
+ }
+ return spirv_asm
+ {
+ result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands;
+ };
+ }
+}
+
+// specialized coopVecMatMul for non-packed inputs
+[require(spirv, cooperative_vector)]
+__generic<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType>
+CoopVec<T, M> coopVecMatMul(
+ CoopVec<U, K> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ void* matrix,
+ constexpr CoopVecComponentType matrixInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually");
+ return coopVecMatMulPacked<
+ T, M, K, U>(
+ input,
+ inputInterpretation,
+ K,
+ matrix,
+ matrixInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride);
+}
+
+[require(spirv, cooperative_vector)]
+CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType>(
+ CoopVec<U, PackedK> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ constexpr int k,
+ void* matrixPtr,
+ constexpr CoopVecComponentType matrixInterpretation,
+ void* biasPtr,
+ constexpr CoopVecComponentType biasInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK
+ , "for non-packed inputInterpretation values k must be equal to the input vector length");
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK
+ , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor");
+
+ __target_switch
+ {
+ case spirv:
+ let m : int32_t = M;
+ let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation);
+ let biasInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(biasInterpretation);
+ let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation);
+ let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout);
+ let zero : int32_t = 0;
+ let cvtMatPtr = (Ptr<T[]>)matrixPtr;
+ let cvtBiasPtr = (Ptr<T[]>)biasPtr;
+ return spirv_asm
+ {
+ result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $cvtBiasPtr $zero $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride;
+ };
+ }
+}
+
+[require(spirv, cooperative_vector)]
+CoopVec<T, M> coopVecMatMulAdd<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType>(
+ CoopVec<U, K> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ void* matrix,
+ constexpr CoopVecComponentType matrixInterpretation,
+ void* bias,
+ constexpr CoopVecComponentType biasInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually");
+ return coopVecMatMulAddPacked<
+ T, M, K, U>(
+ input,
+ inputInterpretation,
+ K,
+ matrix,
+ matrixInterpretation,
+ bias,
+ biasInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride);
+}
+
+[require(spirv, cooperative_vector_training)]
+void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let N : int>(
+ CoopVec<T, M> a,
+ CoopVec<T, N> b,
+ void* matrixPtr,
+ constexpr uint matrixStride,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr CoopVecComponentType matrixInterpretation,
+)
+{
+ let zero : int32_t = 0;
+ __target_switch
+ {
+ case spirv:
+ let matrixInterpretationSpirv : int = __getSpvCoopVecComponentType(matrixInterpretation);
+ let memoryLayoutSpirv : int = __getSpvCoopVecMatrixLayout(memoryLayout);
+ let cvtMatrixPtr = (Ptr<T[]>)matrixPtr;
+ spirv_asm
+ {
+ OpCapability CooperativeVectorTrainingNV;
+ OpCooperativeVectorOuterProductAccumulateNV $cvtMatrixPtr $zero $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride;
+ };
+ }
+}
+
+[require(spirv, cooperative_vector_training)]
+void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int>(
+ CoopVec<T, N> v,
+ void* buffer
+)
+{
+ let zero : int32_t = 0;
+ let bufferPtr = (Ptr<T[]>)(buffer);
+ __target_switch
+ {
+ case spirv:
+ spirv_asm
+ {
+ OpCapability CooperativeVectorTrainingNV;
+ OpCooperativeVectorReduceSumAccumulateNV $bufferPtr $zero $v;
+ };
+ }
+}
+
//@public:
/// Mark a variable as being workgroup uniform.
@@ -28126,3 +28340,24 @@ uint packHalf2x16(half2 unpackedValue)
{
return packHalf2x16(float2(unpackedValue));
}
+
+[require(spirv)]
+void InterlockedAddF16Emulated(half* dest, half value, out half originalValue)
+{
+ let buf = (half2*)(dest);
+ uint64_t byteAddress = (uint64_t)dest;
+ if ((byteAddress & 3) == 0)
+ {
+ originalValue = __atomic_add(*buf, half2(value, half(0.0))).x;
+ }
+ else
+ {
+ originalValue = __atomic_add(*buf, half2(half(0.0), value)).y;
+ }
+}
+
+[require(spirv)]
+void InterlockedAddF16x2(half2* dest, half2 value, out half2 originalValue)
+{
+ originalValue = __atomic_add(*dest, value);
+} \ No newline at end of file
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 88dea0b7e..56f9d873a 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -187,6 +187,13 @@ class SynthesizedStaticLambdaFuncModifier : public Modifier
FIDDLE(...)
};
+FIDDLE()
+class ExplicitlyDeclaredCapabilityModifier : public Modifier
+{
+ FIDDLE(...)
+ FIDDLE() CapabilitySet declaredCapabilityRequirements;
+};
+
// Marks a synthesized variable as local temporary variable.
FIDDLE()
class LocalTempVarModifier : public Modifier
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 506abc1be..6456dbe98 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -419,6 +419,9 @@ bool SemanticsVisitor::createInvokeExprForSynthesizedCtor(
if (!structDecl)
return false;
+ if (!structDecl->checkState.isBeingChecked())
+ ensureDecl(structDecl, DeclCheckState::AttributesChecked);
+
HashSet<Type*> isVisit;
bool isCStyle = false;
if (!_getSynthesizedConstructor(
@@ -656,8 +659,8 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList(
auto toMakeArrayFromElementExpr = m_astBuilder->create<MakeArrayFromElementExpr>();
toMakeArrayFromElementExpr->loc = fromInitializerListExpr->loc;
toMakeArrayFromElementExpr->type = QualType(toType);
-
- *outToExpr = toMakeArrayFromElementExpr;
+ if (outToExpr)
+ *outToExpr = toMakeArrayFromElementExpr;
return true;
}
for (UInt ee = 0; ee < elementCount; ++ee)
@@ -748,8 +751,8 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList(
auto defaultConstructExpr = m_astBuilder->create<DefaultConstructExpr>();
defaultConstructExpr->loc = fromInitializerListExpr->loc;
defaultConstructExpr->type = QualType(toType);
-
- *outToExpr = defaultConstructExpr;
+ if (outToExpr)
+ *outToExpr = defaultConstructExpr;
return true;
}
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 1a70e25d7..0dd859bb2 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2914,9 +2914,9 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness
context->parentDecl->findLastDirectMemberDeclOfName(requirementDeclRef.getName()))
{
// Remove the `ToBeSynthesizedModifier`.
- if (as<ToBeSynthesizedModifier>(existingDecl->modifiers.first))
+ if (auto mod = existingDecl->modifiers.findModifier<ToBeSynthesizedModifier>())
{
- existingDecl->modifiers.first = existingDecl->modifiers.first->next;
+ removeModifier(existingDecl, mod);
}
else
{
@@ -3133,14 +3133,9 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness
addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>());
- // The visibility of synthesized decl should be the min of the parent decl and the requirement.
- if (requirementDeclRef.getDecl()->findModifier<VisibilityModifier>())
- {
- auto requirementVisibility = getDeclVisibility(requirementDeclRef.getDecl());
- auto thisVisibility = getDeclVisibility(context->parentDecl);
- auto visibility = Math::Min(thisVisibility, requirementVisibility);
- addVisibilityModifier(aggTypeDecl, visibility);
- }
+ // The visibility of synthesized decl should be the same of the parent decl.
+ auto thisVisibility = getDeclVisibility(context->parentDecl);
+ addVisibilityModifier(aggTypeDecl, thisVisibility);
// Synthesize the rest of IDifferential method conformances by recursively checking
// conformance on the synthesized decl.
@@ -4149,8 +4144,12 @@ bool SemanticsVisitor::doesVarMatchRequirement(
return false;
}
- auto satisfyingVal =
- tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr);
+ IntVal* satisfyingVal = nullptr;
+ if (isValidCompileTimeConstantType(satisfyingType))
+ {
+ satisfyingVal =
+ tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr);
+ }
if (satisfyingVal)
{
witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingVal));
@@ -5125,9 +5124,9 @@ void SemanticsVisitor::markOverridingDecl(
return;
}
+ memberDecl = maybeGetInner(memberDecl);
if (hasDefaultImpl(requiredMemberDeclRef))
{
- memberDecl = maybeGetInner(memberDecl);
// If the required member has a default implementation,
// we need to make sure the member we found is marked as 'override'.
if (!memberDecl->hasModifier<OverrideModifier>())
@@ -14290,6 +14289,9 @@ void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* fun
}
else
{
+ auto declaredCapModifier = m_astBuilder->create<ExplicitlyDeclaredCapabilityModifier>();
+ declaredCapModifier->declaredCapabilityRequirements = declaredCaps;
+ addModifier(funcDecl, declaredCapModifier);
if (vis == DeclVisibility::Public)
{
// For public decls, we need to enforce that the function
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 306687bd8..b90081af8 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -682,6 +682,8 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(
auto typeDef = m_astBuilder->create<TypeAliasDecl>();
typeDef->nameAndLoc.name = getName("Differential");
typeDef->parentDecl = structDecl;
+ addVisibilityModifier(structDecl, getDeclVisibility(parent));
+ addVisibilityModifier(typeDef, getDeclVisibility(parent));
auto synthDeclRef =
createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
@@ -714,6 +716,7 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(
typeDef->type.type =
calcThisType(subType->getDeclRef().getDecl()->getDefaultDeclRef());
+ addVisibilityModifier(typeDef, getDeclVisibility(parent));
synthesizedDecl = parent;
parent->addDirectMemberDecl(typeDef);
@@ -2085,7 +2088,7 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef(
// to not allow such cases.
//
// Note that float-to-inst casts for non-`IntVal`s are allowed.
- if (!isScalarIntegerType(decl->getType()))
+ if (!isValidCompileTimeConstantType(decl->getType()))
{
getSink()->diagnose(declRef, Diagnostics::intValFromNonIntSpecConstEncountered);
return nullptr;
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index e6744071b..a360361f7 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -553,6 +553,16 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink)
targetOptionSet.hasOption(CompilerOptionName::Capability) &&
(targetOptionSet.getIntOption(CompilerOptionName::Capability) !=
SLANG_CAPABILITY_UNKNOWN);
+
+ if (auto declaredCapsMod =
+ entryPointFuncDecl->findModifier<ExplicitlyDeclaredCapabilityModifier>())
+ {
+ // If the entry point has an explicitly declared capability, then we
+ // will merge that with the target capability set before checking if
+ // there is an implicit upgrade.
+ targetCaps.nonDestructiveJoin(declaredCapsMod->declaredCapabilityRequirements);
+ }
+
// Only attempt to error if a specific profile or capability is requested
if ((specificCapabilityRequested || specificProfileRequested) &&
targetCaps.atLeastOneSetImpliedInOther(
diff --git a/tests/language-feature/capability/capabilitySimplification2.slang b/tests/language-feature/capability/capabilitySimplification2.slang
index 8d96884ce..d50960b8e 100644
--- a/tests/language-feature/capability/capabilitySimplification2.slang
+++ b/tests/language-feature/capability/capabilitySimplification2.slang
@@ -1,27 +1,27 @@
-//TEST:SIMPLE(filecheck=SPIRV): -target spirv -emit-spirv-directly -entry computeMain -stage compute -profile sm_5_0
+//TEST:SIMPLE(filecheck=SPIRV): -target spirv -emit-spirv-directly -entry computeMain -stage compute -profile spirv_1_3
//TEST:SIMPLE(filecheck=GLSL): -target glsl -entry computeMain -stage compute -profile sm_5_0
//TEST:SIMPLE(filecheck=HLSL): -target hlsl -entry computeMain -stage compute -profile sm_5_0
-//TEST:SIMPLE(filecheck=CHECK_IGNORE_CAPS): -target spirv -emit-spirv-directly -entry computeMain -stage compute -profile sm_5_0 -ignore-capabilities
+//TEST:SIMPLE(filecheck=CHECK_IGNORE_CAPS): -target spirv -emit-spirv-directly -entry computeMain -stage compute -profile spirv_1_3 -ignore-capabilities
// CHECK_IGNORE_CAPS-NOT: warning 41012
// SPIRV: warning 41012
// SPIRV-NOT: spirv_1_2
-// SPIRV-NOT: spirv_1_3
// SPIRV-SAME: spvGroupNonUniformBallot
// GLSL: warning 41012
// GLSL-NOT: GLSL_400
// GLSL-NOT: GLSL_430
-// GLSL-SAME: GL_KHR_shader_subgroup_ballot
+// GLSL-SAME: GL_KHR_shader_subgroup_basic
// HLSL: warning 41012
// HLSL-NOT: sm_5_1
// HLSL-SAME: sm_6_0
-[require(sm_6_0)]
[numthreads(1,1,1)]
void computeMain()
{
+ if (WaveIsFirstLane())
+ return;
}