summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-emit-spirv.cpp193
-rw-r--r--tests/spirv/capability-storage-input-output.slang44
-rw-r--r--tests/spirv/capability-storage-push-constant.slang35
-rw-r--r--tests/spirv/capability-uniform-and-storage.slang65
4 files changed, 336 insertions, 1 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 0facaac3f..2127141bd 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1802,6 +1802,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
auto valueType = ptrType->getValueType();
+
+ // Check for 8/16-bit storage capabilities when emitting pointer types
+ requireCapabilitiesForType(valueType, storageClass);
+
// If we haven't emitted the inner type yet, we need to emit a forward declaration.
bool useForwardDeclaration =
(!m_mapIRInstToSpvInst.containsKey(valueType) && as<IRStructType>(valueType) &&
@@ -3219,6 +3223,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
registerInst(param, systemValInst);
return systemValInst;
}
+
auto varInst = emitOpVariable(
getSection(SpvLogicalSectionID::GlobalVariables),
param,
@@ -3243,6 +3248,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
if (ptrType->hasAddressSpace())
storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace());
}
+
auto varInst = emitOpVariable(
getSection(SpvLogicalSectionID::GlobalVariables),
globalVar,
@@ -7976,6 +7982,191 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
}
+ // Type and constants used by checkTypeNeedsStorageCapability
+ typedef uint32_t TypeNeedsStorageFlags;
+ struct TypeNeedsStorageFlag
+ {
+ enum Enum : TypeNeedsStorageFlags
+ {
+ kNone = 0,
+ kElementSize8 = 1 << 0,
+ kElementSize16 = 1 << 1,
+ };
+ };
+
+ // Helper function to recursively check if a storage type contains 8/16-bit types.
+ // Returns true if all requested targets have been found and therefore it can stop
+ // recursing.
+ bool checkTypeNeedsStorageCapability(
+ IRType* type,
+ const TypeNeedsStorageFlags targets,
+ TypeNeedsStorageFlags& found,
+ HashSet<IRType*>& visited)
+ {
+ if (visited.contains(type))
+ return false; // Cycle detected, break recursion
+
+ visited.add(type);
+
+ switch (type->getOp())
+ {
+ case kIROp_HalfType:
+ case kIROp_UInt16Type:
+ case kIROp_Int16Type:
+ found |= TypeNeedsStorageFlag::kElementSize16;
+ return (targets == (found & targets));
+
+ case kIROp_UInt8Type:
+ case kIROp_Int8Type:
+ found |= TypeNeedsStorageFlag::kElementSize8;
+ return (targets == (found & targets));
+
+ case kIROp_VectorType:
+ if (auto vectorType = as<IRVectorType>(type))
+ return checkTypeNeedsStorageCapability(
+ vectorType->getElementType(),
+ targets,
+ found,
+ visited);
+ break;
+
+ case kIROp_MatrixType:
+ if (auto matrixType = as<IRMatrixType>(type))
+ return checkTypeNeedsStorageCapability(
+ matrixType->getElementType(),
+ targets,
+ found,
+ visited);
+ break;
+
+ case kIROp_ArrayType:
+ case kIROp_UnsizedArrayType:
+ if (auto arrayType = as<IRArrayTypeBase>(type))
+ return checkTypeNeedsStorageCapability(
+ arrayType->getElementType(),
+ targets,
+ found,
+ visited);
+ break;
+
+ case kIROp_StructType:
+ if (auto structType = as<IRStructType>(type))
+ {
+ for (auto field : structType->getFields())
+ if (checkTypeNeedsStorageCapability(
+ field->getFieldType(),
+ targets,
+ found,
+ visited))
+ return true;
+ }
+ break;
+
+ case kIROp_AtomicType:
+ if (auto atomicType = as<IRAtomicType>(type))
+ return checkTypeNeedsStorageCapability(
+ atomicType->getElementType(),
+ targets,
+ found,
+ visited);
+ break;
+
+ case kIROp_PtrType:
+ case kIROp_RefType:
+ case kIROp_ConstRefType:
+ case kIROp_OutType:
+ case kIROp_InOutType:
+ if (auto ptrType = as<IRPtrTypeBase>(type))
+ return checkTypeNeedsStorageCapability(
+ ptrType->getValueType(),
+ targets,
+ found,
+ visited);
+ break;
+
+ case kIROp_RateQualifiedType:
+ if (auto ptrType = as<IRRateQualifiedType>(type))
+ return checkTypeNeedsStorageCapability(
+ ptrType->getValueType(),
+ targets,
+ found,
+ visited);
+ break;
+ }
+
+ return false;
+ }
+
+ // Check and require 8/16-bit storage capabilities based on storage class and type
+ void requireCapabilitiesForType(IRType* type, SpvStorageClass storageClass)
+ {
+ // Search for specific aspects of the type, depending on the class.
+ TypeNeedsStorageFlags targets = TypeNeedsStorageFlag::kNone;
+ switch (storageClass)
+ {
+ case SpvStorageClassUniform:
+ case SpvStorageClassStorageBuffer:
+ if (!m_capabilities.contains(SpvCapabilityUniformAndStorageBuffer8BitAccess))
+ targets |= TypeNeedsStorageFlag::kElementSize8;
+ if (!m_capabilities.contains(SpvCapabilityUniformAndStorageBuffer16BitAccess))
+ targets |= TypeNeedsStorageFlag::kElementSize16;
+ break;
+ case SpvStorageClassPushConstant:
+ if (!m_capabilities.contains(SpvCapabilityStoragePushConstant8))
+ targets |= TypeNeedsStorageFlag::kElementSize8;
+ if (!m_capabilities.contains(SpvCapabilityStoragePushConstant16))
+ targets |= TypeNeedsStorageFlag::kElementSize16;
+ break;
+ case SpvStorageClassInput:
+ case SpvStorageClassOutput:
+ if (!m_capabilities.contains(SpvCapabilityStorageInputOutput16))
+ targets |= TypeNeedsStorageFlag::kElementSize16;
+ break;
+ }
+
+ // If we've already enabled all possible capabilities for this storage class, there's no
+ // reason to search again.
+ if (targets == TypeNeedsStorageFlag::kNone)
+ return;
+
+ HashSet<IRType*> visited;
+ TypeNeedsStorageFlags found = TypeNeedsStorageFlag::kNone;
+ checkTypeNeedsStorageCapability(type, targets, found, visited);
+
+ if (found == TypeNeedsStorageFlag::kNone)
+ return;
+
+ switch (storageClass)
+ {
+ case SpvStorageClassUniform:
+ case SpvStorageClassStorageBuffer:
+ if (found & TypeNeedsStorageFlag::kElementSize8)
+ {
+ ensureExtensionDeclarationBeforeSpv15(UnownedStringSlice("SPV_KHR_8bit_storage"));
+ requireSPIRVCapability(SpvCapabilityUniformAndStorageBuffer8BitAccess);
+ }
+ if (found & TypeNeedsStorageFlag::kElementSize16)
+ requireSPIRVCapability(SpvCapabilityUniformAndStorageBuffer16BitAccess);
+ break;
+
+ case SpvStorageClassPushConstant:
+ if (found & TypeNeedsStorageFlag::kElementSize8)
+ {
+ ensureExtensionDeclarationBeforeSpv15(UnownedStringSlice("SPV_KHR_8bit_storage"));
+ requireSPIRVCapability(SpvCapabilityStoragePushConstant8);
+ }
+ if (found & TypeNeedsStorageFlag::kElementSize16)
+ requireSPIRVCapability(SpvCapabilityStoragePushConstant16);
+ break;
+
+ case SpvStorageClassInput:
+ case SpvStorageClassOutput:
+ if (found & TypeNeedsStorageFlag::kElementSize16)
+ requireSPIRVCapability(SpvCapabilityStorageInputOutput16);
+ break;
+ }
+ }
+
SpvInst* emitVectorOrScalarArithmetic(
SpvInstParent* parent,
IRInst* instToRegister,
@@ -8400,7 +8591,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return ensureInst(m_voidType);
IRBuilder builder(type);
- if (const auto funcType = as<IRFuncType>(type))
+ if (as<IRFuncType>(type))
{
List<SpvInst*> argTypes;
return emitOpDebugTypeFunction(
diff --git a/tests/spirv/capability-storage-input-output.slang b/tests/spirv/capability-storage-input-output.slang
new file mode 100644
index 000000000..0cccd26c7
--- /dev/null
+++ b/tests/spirv/capability-storage-input-output.slang
@@ -0,0 +1,44 @@
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DIN_HALF
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DIN_UINT16
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DOUT_HALF
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -profile spirv_1_3
+
+//CHECK16: OpCapability StorageInputOutput16
+//CHECK-NOT: OpCapability StorageInputOutput16
+
+struct VertexInput {
+#ifdef IN_HALF
+ half4 position : POSITION;
+#else
+ float4 position : POSITION;
+#endif
+#ifdef IN_UINT16
+ uint16_t id : ID;
+#else
+ uint32_t id : ID;
+#endif
+};
+
+#ifdef OUT_HALF
+#define OUT_TYPE half4
+#else
+#define OUT_TYPE float4
+#endif
+
+struct VertexOutput {
+ float4 position : SV_POSITION;
+ OUT_TYPE color : COLOR;
+};
+
+[shader("vertex")]
+VertexOutput vertexMain(VertexInput input)
+{
+ VertexOutput output;
+ output.position = float4(input.position);
+ if (input.id == 0) {
+ output.color = OUT_TYPE(input.position);
+ } else {
+ output.color = OUT_TYPE(0);
+ }
+ return output;
+}
diff --git a/tests/spirv/capability-storage-push-constant.slang b/tests/spirv/capability-storage-push-constant.slang
new file mode 100644
index 000000000..6cb540da9
--- /dev/null
+++ b/tests/spirv/capability-storage-push-constant.slang
@@ -0,0 +1,35 @@
+//TEST:SIMPLE(filecheck=CHECK8): -target spirv -profile spirv_1_3 -DCONST_UINT8
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DCONST_UINT16
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DCONST_HALF
+//TEST:SIMPLE(filecheck=CHECKBOTH): -target spirv -profile spirv_1_3 -DCONST_UINT8 -DCONST_HALF
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -profile spirv_1_3
+
+//CHECK8: OpCapability StoragePushConstant8
+//CHECK16: OpCapability StoragePushConstant16
+//CHECKBOTH-DAG: OpCapability StoragePushConstant8
+//CHECKBOTH-DAG: OpCapability StoragePushConstant16
+//CHECK-NOT: OpCapability StoragePushConstant16
+
+struct PushConstants {
+#if defined(CONST_HALF)
+ half4 color;
+#else
+ float4 color;
+#endif
+#if defined(CONST_UINT8)
+ int8_t index;
+#elif defined(CONST_UINT16)
+ int16_t index;
+#else
+ int32_t index;
+#endif
+};
+
+[[vk::push_constant]]
+PushConstants pushConstants;
+
+[shader("vertex")]
+float4 vertexMain() : SV_POSITION
+{
+ return float4(pushConstants.color);
+}
diff --git a/tests/spirv/capability-uniform-and-storage.slang b/tests/spirv/capability-uniform-and-storage.slang
new file mode 100644
index 000000000..665a6b91a
--- /dev/null
+++ b/tests/spirv/capability-uniform-and-storage.slang
@@ -0,0 +1,65 @@
+//TEST:SIMPLE(filecheck=CHECK8): -target spirv -profile spirv_1_3 -DIN_UINT8
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DIN_UINT16
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DIN_HALF
+//TEST:SIMPLE(filecheck=CHECKBOTH): -target spirv -profile spirv_1_3 -DIN_UINT8 -DIN_HALF
+//TEST:SIMPLE(filecheck=CHECK8): -target spirv -profile spirv_1_3 -DOUT_UINT8
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DOUT_UINT16
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DOUT_HALF
+//TEST:SIMPLE(filecheck=CHECK16): -target spirv -profile spirv_1_3 -DOUT_HALF -DATOMIC
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -profile spirv_1_3
+
+//CHECK8: OpCapability UniformAndStorageBuffer8BitAccess
+//CHECK8-NOT: OpCapability UniformAndStorageBuffer16BitAccess
+//CHECK16: OpCapability UniformAndStorageBuffer16BitAccess
+//CHECK16-NOT: OpCapability UniformAndStorageBuffer8BitAccess
+//CHECKBOTH-DAG: OpCapability UniformAndStorageBuffer8BitAccess
+//CHECKBOTH-DAG: OpCapability UniformAndStorageBuffer16BitAccess
+//CHECK-NOT: OpCapability UniformAndStorageBuffer8BitAccess
+//CHECK-NOT: OpCapability UniformAndStorageBuffer16BitAccess
+
+
+uniform struct {
+#if defined(IN_HALF)
+ half4 data;
+#else
+ float4 data;
+#endif
+#if defined(IN_UINT8)
+ uint8_t index;
+#elif defined(IN_UINT16)
+ uint16_t index;
+#else
+ uint32_t index;
+#endif
+} inputBuffer;
+
+#if defined(OUT_HALF)
+#define OUT_FLOAT_TYPE half
+#else
+#define OUT_FLOAT_TYPE float
+#endif
+#if defined(OUT_UINT8)
+#define OUT_UINT_TYPE uint8_t
+#elif defined(OUT_UINT16)
+#define OUT_UINT_TYPE uint16_t
+#else
+#define OUT_UINT_TYPE uint32_t
+#endif
+
+struct st {
+#if defined(ATOMIC)
+ Atomic<OUT_FLOAT_TYPE> data;
+#else
+ OUT_FLOAT_TYPE data;
+#endif
+ OUT_UINT_TYPE index;
+};
+RWStructuredBuffer<st> outputBuffer;
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain()
+{
+ outputBuffer[0].data = OUT_FLOAT_TYPE(inputBuffer.data.x);
+ outputBuffer[1].index = OUT_UINT_TYPE(inputBuffer.index);
+}