diff options
| author | James Helferty (NVIDIA) <jhelferty@nvidia.com> | 2025-08-27 20:08:54 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-08-28 03:08:54 +0000 |
| commit | 80ddf40274fdca93f2ae95a247ff3af122aec6ac (patch) | |
| tree | 859c48522f93cf4c1119221c1679c11b1160f857 | |
| parent | 1681bc67fbae57b54b66c5dcfcbf315d1efa831b (diff) | |
Add SPIRV OpCapability for 8/16bit use in storage (#8194)
Emits the appropriate OpCapability for 8- and 16-bit type usage:
- UniformAndStorageBuffer8BitAccess: for 8-bit types in
SpvStorageClassUniform and SpvStorageClassStorageBuffer
- UniformAndStorageBuffer16BitAccess: for 16-bit types in
SpvStorageClassUniform and SpvStorageClassStorageBuffer
- StoragePushConstant8: for 8-bit types in SpvStorageClassPushConstant
- StoragePushConstant16: for 16-bit types in SpvStorageClassPushConstant
- StorageInputOutput16: for 16-bit types in SpvStorageClassInput and
SpvStorageClassOutput
Generated with Claude Code, with revisions.
Fixes #7879.
---------
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: James Helferty (NVIDIA) <jhelferty-nv@users.noreply.github.com>
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 193 | ||||
| -rw-r--r-- | tests/spirv/capability-storage-input-output.slang | 44 | ||||
| -rw-r--r-- | tests/spirv/capability-storage-push-constant.slang | 35 | ||||
| -rw-r--r-- | tests/spirv/capability-uniform-and-storage.slang | 65 |
4 files changed, 336 insertions, 1 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 0facaac3f..2127141bd 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1802,6 +1802,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } auto valueType = ptrType->getValueType(); + + // Check for 8/16-bit storage capabilities when emitting pointer types + requireCapabilitiesForType(valueType, storageClass); + // If we haven't emitted the inner type yet, we need to emit a forward declaration. bool useForwardDeclaration = (!m_mapIRInstToSpvInst.containsKey(valueType) && as<IRStructType>(valueType) && @@ -3219,6 +3223,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex registerInst(param, systemValInst); return systemValInst; } + auto varInst = emitOpVariable( getSection(SpvLogicalSectionID::GlobalVariables), param, @@ -3243,6 +3248,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex if (ptrType->hasAddressSpace()) storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace()); } + auto varInst = emitOpVariable( getSection(SpvLogicalSectionID::GlobalVariables), globalVar, @@ -7976,6 +7982,191 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } } + // Type and constants used by checkTypeNeedsStorageCapability + typedef uint32_t TypeNeedsStorageFlags; + struct TypeNeedsStorageFlag + { + enum Enum : TypeNeedsStorageFlags + { + kNone = 0, + kElementSize8 = 1 << 0, + kElementSize16 = 1 << 1, + }; + }; + + // Helper function to recursively check if a storage type contains 8/16-bit types. + // Returns true if all requested targets have been found and therefore it can stop + // recursing. 
+ bool checkTypeNeedsStorageCapability( + IRType* type, + const TypeNeedsStorageFlags targets, + TypeNeedsStorageFlags& found, + HashSet<IRType*>& visited) + { + if (visited.contains(type)) + return false; // Cycle detected, break recursion + + visited.add(type); + + switch (type->getOp()) + { + case kIROp_HalfType: + case kIROp_UInt16Type: + case kIROp_Int16Type: + found |= TypeNeedsStorageFlag::kElementSize16; + return (targets == (found & targets)); + + case kIROp_UInt8Type: + case kIROp_Int8Type: + found |= TypeNeedsStorageFlag::kElementSize8; + return (targets == (found & targets)); + + case kIROp_VectorType: + if (auto vectorType = as<IRVectorType>(type)) + return checkTypeNeedsStorageCapability( + vectorType->getElementType(), + targets, + found, + visited); + break; + + case kIROp_MatrixType: + if (auto matrixType = as<IRMatrixType>(type)) + return checkTypeNeedsStorageCapability( + matrixType->getElementType(), + targets, + found, + visited); + break; + + case kIROp_ArrayType: + case kIROp_UnsizedArrayType: + if (auto arrayType = as<IRArrayTypeBase>(type)) + return checkTypeNeedsStorageCapability( + arrayType->getElementType(), + targets, + found, + visited); + break; + + case kIROp_StructType: + if (auto structType = as<IRStructType>(type)) + { + for (auto field : structType->getFields()) + if (checkTypeNeedsStorageCapability( + field->getFieldType(), + targets, + found, + visited)) + return true; + } + break; + + case kIROp_AtomicType: + if (auto atomicType = as<IRAtomicType>(type)) + return checkTypeNeedsStorageCapability( + atomicType->getElementType(), + targets, + found, + visited); + break; + + case kIROp_PtrType: + case kIROp_RefType: + case kIROp_ConstRefType: + case kIROp_OutType: + case kIROp_InOutType: + if (auto ptrType = as<IRPtrTypeBase>(type)) + return checkTypeNeedsStorageCapability( + ptrType->getValueType(), + targets, + found, + visited); + break; + + case kIROp_RateQualifiedType: + if (auto ptrType = as<IRRateQualifiedType>(type)) + 
return checkTypeNeedsStorageCapability( + ptrType->getValueType(), + targets, + found, + visited); + break; + } + + return false; + } + + // Check and require 8/16-bit storage capabilities based on storage class and type + void requireCapabilitiesForType(IRType* type, SpvStorageClass storageClass) + { + // Search for specific aspects of the type, depending on the class. + TypeNeedsStorageFlags targets = TypeNeedsStorageFlag::kNone; + switch (storageClass) + { + case SpvStorageClassUniform: + case SpvStorageClassStorageBuffer: + if (!m_capabilities.contains(SpvCapabilityUniformAndStorageBuffer8BitAccess)) + targets |= TypeNeedsStorageFlag::kElementSize8; + if (!m_capabilities.contains(SpvCapabilityUniformAndStorageBuffer16BitAccess)) + targets |= TypeNeedsStorageFlag::kElementSize16; + break; + case SpvStorageClassPushConstant: + if (!m_capabilities.contains(SpvCapabilityStoragePushConstant8)) + targets |= TypeNeedsStorageFlag::kElementSize8; + if (!m_capabilities.contains(SpvCapabilityStoragePushConstant16)) + targets |= TypeNeedsStorageFlag::kElementSize16; + break; + case SpvStorageClassInput: + case SpvStorageClassOutput: + if (!m_capabilities.contains(SpvCapabilityStorageInputOutput16)) + targets |= TypeNeedsStorageFlag::kElementSize16; + break; + } + + // If we've already enabled all possible capabilities for this storage class, there's no + // reason to search again. 
+ if (targets == TypeNeedsStorageFlag::kNone) + return; + + HashSet<IRType*> visited; + TypeNeedsStorageFlags found = TypeNeedsStorageFlag::kNone; + checkTypeNeedsStorageCapability(type, targets, found, visited); + + if (found == TypeNeedsStorageFlag::kNone) + return; + + switch (storageClass) + { + case SpvStorageClassUniform: + case SpvStorageClassStorageBuffer: + if (found & TypeNeedsStorageFlag::kElementSize8) + { + ensureExtensionDeclarationBeforeSpv15(UnownedStringSlice("SPV_KHR_8bit_storage")); + requireSPIRVCapability(SpvCapabilityUniformAndStorageBuffer8BitAccess); + } + if (found & TypeNeedsStorageFlag::kElementSize16) + requireSPIRVCapability(SpvCapabilityUniformAndStorageBuffer16BitAccess); + break; + + case SpvStorageClassPushConstant: + if (found & TypeNeedsStorageFlag::kElementSize8) + { + ensureExtensionDeclarationBeforeSpv15(UnownedStringSlice("SPV_KHR_8bit_storage")); + requireSPIRVCapability(SpvCapabilityStoragePushConstant8); + } + if (found & TypeNeedsStorageFlag::kElementSize16) + requireSPIRVCapability(SpvCapabilityStoragePushConstant16); + break; + + case SpvStorageClassInput: + case SpvStorageClassOutput: + if (found & TypeNeedsStorageFlag::kElementSize16) + requireSPIRVCapability(SpvCapabilityStorageInputOutput16); + break; + } + } + SpvInst* emitVectorOrScalarArithmetic( SpvInstParent* parent, IRInst* instToRegister, @@ -8400,7 +8591,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return ensureInst(m_voidType); IRBuilder builder(type); - if (const auto funcType = as<IRFuncType>(type)) + if (as<IRFuncType>(type)) { List<SpvInst*> argTypes; return emitOpDebugTypeFunction( diff --git a/tests/spirv/capability-storage-input-output.slang b/tests/spirv/capability-storage-input-output.slang new file mode 100644 index 000000000..0cccd26c7 --- /dev/null +++ b/tests/spirv/capability-storage-input-output.slang @@ -0,0 +1,44 @@ +//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DIN_HALF
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DIN_UINT16
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DOUT_HALF
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -profile spirv_1_3
+
+//CHECK16: OpCapability StorageInputOutput16
+//CHECK-NOT: OpCapability StorageInputOutput16
+
+struct VertexInput {
+#ifdef IN_HALF
+ half4 position : POSITION;
+#else
+ float4 position : POSITION;
+#endif
+#ifdef IN_UINT16
+ uint16_t id : ID;
+#else
+ uint32_t id : ID;
+#endif
+};
+
+#ifdef OUT_HALF
+#define OUT_TYPE half4
+#else
+#define OUT_TYPE float4
+#endif
+
+struct VertexOutput {
+ float4 position : SV_POSITION;
+ OUT_TYPE color : COLOR;
+};
+
+[shader("vertex")]
+VertexOutput vertexMain(VertexInput input)
+{
+ VertexOutput output;
+ output.position = float4(input.position);
+ if (input.id == 0) {
+ output.color = OUT_TYPE(input.position);
+ } else {
+ output.color = OUT_TYPE(0);
+ }
+ return output;
+}
diff --git a/tests/spirv/capability-storage-push-constant.slang b/tests/spirv/capability-storage-push-constant.slang new file mode 100644 index 000000000..6cb540da9 --- /dev/null +++ b/tests/spirv/capability-storage-push-constant.slang @@ -0,0 +1,35 @@ +//TEST:SIMPLE(filecheck=CHECK8): -target spirv -profile spirv_1_3 -DCONST_UINT8
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DCONST_UINT16
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DCONST_HALF
+//TEST:SIMPLE(filecheck=CHECKBOTH): -target spirv -profile spirv_1_3 -DCONST_UINT8 -DCONST_HALF
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -profile spirv_1_3
+
+//CHECK8: OpCapability StoragePushConstant8
+//CHECK16: OpCapability StoragePushConstant16
+//CHECKBOTH-DAG: OpCapability StoragePushConstant8
+//CHECKBOTH-DAG: OpCapability StoragePushConstant16
+//CHECK-NOT: OpCapability StoragePushConstant16
+
+struct PushConstants {
+#if defined(CONST_HALF)
+ half4 color;
+#else
+ float4 color;
+#endif
+#if defined(CONST_UINT8)
+ int8_t index;
+#elif defined(CONST_UINT16)
+ int16_t index;
+#else
+ int32_t index;
+#endif
+};
+
+[[vk::push_constant]]
+PushConstants pushConstants;
+
+[shader("vertex")]
+float4 vertexMain() : SV_POSITION
+{
+ return float4(pushConstants.color);
+}
diff --git a/tests/spirv/capability-uniform-and-storage.slang b/tests/spirv/capability-uniform-and-storage.slang new file mode 100644 index 000000000..665a6b91a --- /dev/null +++ b/tests/spirv/capability-uniform-and-storage.slang @@ -0,0 +1,65 @@ +//TEST:SIMPLE(filecheck=CHECK8): -target spirv -profile spirv_1_3 -DIN_UINT8
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DIN_UINT16
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DIN_HALF
+//TEST:SIMPLE(filecheck=CHECKBOTH): -target spirv -profile spirv_1_3 -DIN_UINT8 -DIN_HALF
+//TEST:SIMPLE(filecheck=CHECK8): -target spirv -profile spirv_1_3 -DOUT_UINT8
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DOUT_UINT16
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DOUT_HALF
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DOUT_HALF -DATOMIC
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -profile spirv_1_3
+
+//CHECK8: OpCapability UniformAndStorageBuffer8BitAccess
+//CHECK8-NOT: OpCapability UniformAndStorageBuffer16BitAccess
+//CHECK16: OpCapability UniformAndStorageBuffer16BitAccess
+//CHECK16-NOT: OpCapability UniformAndStorageBuffer8BitAccess
+//CHECKBOTH-DAG: OpCapability UniformAndStorageBuffer8BitAccess
+//CHECKBOTH-DAG: OpCapability UniformAndStorageBuffer16BitAccess
+//CHECK-NOT: OpCapability UniformAndStorageBuffer8BitAccess
+//CHECK-NOT: OpCapability UniformAndStorageBuffer16BitAccess
+
+
+uniform struct {
+#if defined(IN_HALF)
+ half4 data;
+#else
+ float4 data;
+#endif
+#if defined(IN_UINT8)
+ uint8_t index;
+#elif defined(IN_UINT16)
+ uint16_t index;
+#else
+ uint32_t index;
+#endif
+} inputBuffer;
+
+#if defined(OUT_HALF)
+#define OUT_FLOAT_TYPE half
+#else
+#define OUT_FLOAT_TYPE float
+#endif
+#if defined(OUT_UINT8)
+#define OUT_UINT_TYPE uint8_t
+#elif defined(OUT_UINT16)
+#define OUT_UINT_TYPE uint16_t
+#else
+#define OUT_UINT_TYPE uint32_t
+#endif
+
+struct st {
+#if defined(ATOMIC)
+ Atomic<OUT_FLOAT_TYPE> data;
+#else
+ OUT_FLOAT_TYPE data;
+#endif
+ OUT_UINT_TYPE index;
+};
+RWStructuredBuffer<st> outputBuffer;
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain()
+{
+ outputBuffer[0].data = OUT_FLOAT_TYPE(inputBuffer.data.x);
+ outputBuffer[1].index = OUT_UINT_TYPE(inputBuffer.index);
+}
|
