summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp193
1 files changed, 192 insertions, 1 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 0facaac3f..2127141bd 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1802,6 +1802,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
auto valueType = ptrType->getValueType();
+
+ // Check for 8/16-bit storage capabilities when emitting pointer types
+ requireCapabilitiesForType(valueType, storageClass);
+
// If we haven't emitted the inner type yet, we need to emit a forward declaration.
bool useForwardDeclaration =
(!m_mapIRInstToSpvInst.containsKey(valueType) && as<IRStructType>(valueType) &&
@@ -3219,6 +3223,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
registerInst(param, systemValInst);
return systemValInst;
}
+
auto varInst = emitOpVariable(
getSection(SpvLogicalSectionID::GlobalVariables),
param,
@@ -3243,6 +3248,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
if (ptrType->hasAddressSpace())
storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace());
}
+
auto varInst = emitOpVariable(
getSection(SpvLogicalSectionID::GlobalVariables),
globalVar,
@@ -7976,6 +7982,191 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
}
+ // Type and constants used by checkTypeNeedsStorageCapability
+ typedef uint32_t TypeNeedsStorageFlags;
+ struct TypeNeedsStorageFlag
+ {
+ enum Enum : TypeNeedsStorageFlags
+ {
+ kNone = 0,
+ kElementSize8 = 1 << 0,
+ kElementSize16 = 1 << 1,
+ };
+ };
+
+ // Helper function to recursively check if a storage type contains 8/16-bit types.
+ // Returns true if all requested targets have been found and therefore it can stop
+ // recursing.
+ bool checkTypeNeedsStorageCapability(
+ IRType* type,
+ const TypeNeedsStorageFlags targets,
+ TypeNeedsStorageFlags& found,
+ HashSet<IRType*>& visited)
+ {
+ if (visited.contains(type))
+ return false; // Cycle detected, break recursion
+
+ visited.add(type);
+
+ switch (type->getOp())
+ {
+ case kIROp_HalfType:
+ case kIROp_UInt16Type:
+ case kIROp_Int16Type:
+ found |= TypeNeedsStorageFlag::kElementSize16;
+ return (targets == (found & targets));
+
+ case kIROp_UInt8Type:
+ case kIROp_Int8Type:
+ found |= TypeNeedsStorageFlag::kElementSize8;
+ return (targets == (found & targets));
+
+ case kIROp_VectorType:
+ if (auto vectorType = as<IRVectorType>(type))
+ return checkTypeNeedsStorageCapability(
+ vectorType->getElementType(),
+ targets,
+ found,
+ visited);
+ break;
+
+ case kIROp_MatrixType:
+ if (auto matrixType = as<IRMatrixType>(type))
+ return checkTypeNeedsStorageCapability(
+ matrixType->getElementType(),
+ targets,
+ found,
+ visited);
+ break;
+
+ case kIROp_ArrayType:
+ case kIROp_UnsizedArrayType:
+ if (auto arrayType = as<IRArrayTypeBase>(type))
+ return checkTypeNeedsStorageCapability(
+ arrayType->getElementType(),
+ targets,
+ found,
+ visited);
+ break;
+
+ case kIROp_StructType:
+ if (auto structType = as<IRStructType>(type))
+ {
+ for (auto field : structType->getFields())
+ if (checkTypeNeedsStorageCapability(
+ field->getFieldType(),
+ targets,
+ found,
+ visited))
+ return true;
+ }
+ break;
+
+ case kIROp_AtomicType:
+ if (auto atomicType = as<IRAtomicType>(type))
+ return checkTypeNeedsStorageCapability(
+ atomicType->getElementType(),
+ targets,
+ found,
+ visited);
+ break;
+
+ case kIROp_PtrType:
+ case kIROp_RefType:
+ case kIROp_ConstRefType:
+ case kIROp_OutType:
+ case kIROp_InOutType:
+ if (auto ptrType = as<IRPtrTypeBase>(type))
+ return checkTypeNeedsStorageCapability(
+ ptrType->getValueType(),
+ targets,
+ found,
+ visited);
+ break;
+
+ case kIROp_RateQualifiedType:
+ if (auto ptrType = as<IRRateQualifiedType>(type))
+ return checkTypeNeedsStorageCapability(
+ ptrType->getValueType(),
+ targets,
+ found,
+ visited);
+ break;
+ }
+
+ return false;
+ }
+
+ // Check and require 8/16-bit storage capabilities based on storage class and type
+ void requireCapabilitiesForType(IRType* type, SpvStorageClass storageClass)
+ {
+ // Search for specific aspects of the type, depending on the class.
+ TypeNeedsStorageFlags targets = TypeNeedsStorageFlag::kNone;
+ switch (storageClass)
+ {
+ case SpvStorageClassUniform:
+ case SpvStorageClassStorageBuffer:
+ if (!m_capabilities.contains(SpvCapabilityUniformAndStorageBuffer8BitAccess))
+ targets |= TypeNeedsStorageFlag::kElementSize8;
+ if (!m_capabilities.contains(SpvCapabilityUniformAndStorageBuffer16BitAccess))
+ targets |= TypeNeedsStorageFlag::kElementSize16;
+ break;
+ case SpvStorageClassPushConstant:
+ if (!m_capabilities.contains(SpvCapabilityStoragePushConstant8))
+ targets |= TypeNeedsStorageFlag::kElementSize8;
+ if (!m_capabilities.contains(SpvCapabilityStoragePushConstant16))
+ targets |= TypeNeedsStorageFlag::kElementSize16;
+ break;
+ case SpvStorageClassInput:
+ case SpvStorageClassOutput:
+ if (!m_capabilities.contains(SpvCapabilityStorageInputOutput16))
+ targets |= TypeNeedsStorageFlag::kElementSize16;
+ break;
+ }
+
+ // If we've already enabled all possible capabilities for this storage class, there's no
+ // reason to search again.
+ if (targets == TypeNeedsStorageFlag::kNone)
+ return;
+
+ HashSet<IRType*> visited;
+ TypeNeedsStorageFlags found = TypeNeedsStorageFlag::kNone;
+ checkTypeNeedsStorageCapability(type, targets, found, visited);
+
+ if (found == TypeNeedsStorageFlag::kNone)
+ return;
+
+ switch (storageClass)
+ {
+ case SpvStorageClassUniform:
+ case SpvStorageClassStorageBuffer:
+ if (found & TypeNeedsStorageFlag::kElementSize8)
+ {
+ ensureExtensionDeclarationBeforeSpv15(UnownedStringSlice("SPV_KHR_8bit_storage"));
+ requireSPIRVCapability(SpvCapabilityUniformAndStorageBuffer8BitAccess);
+ }
+ if (found & TypeNeedsStorageFlag::kElementSize16)
+ requireSPIRVCapability(SpvCapabilityUniformAndStorageBuffer16BitAccess);
+ break;
+
+ case SpvStorageClassPushConstant:
+ if (found & TypeNeedsStorageFlag::kElementSize8)
+ {
+ ensureExtensionDeclarationBeforeSpv15(UnownedStringSlice("SPV_KHR_8bit_storage"));
+ requireSPIRVCapability(SpvCapabilityStoragePushConstant8);
+ }
+ if (found & TypeNeedsStorageFlag::kElementSize16)
+ requireSPIRVCapability(SpvCapabilityStoragePushConstant16);
+ break;
+
+ case SpvStorageClassInput:
+ case SpvStorageClassOutput:
+ if (found & TypeNeedsStorageFlag::kElementSize16)
+ requireSPIRVCapability(SpvCapabilityStorageInputOutput16);
+ break;
+ }
+ }
+
SpvInst* emitVectorOrScalarArithmetic(
SpvInstParent* parent,
IRInst* instToRegister,
@@ -8400,7 +8591,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return ensureInst(m_voidType);
IRBuilder builder(type);
- if (const auto funcType = as<IRFuncType>(type))
+ if (as<IRFuncType>(type))
{
List<SpvInst*> argTypes;
return emitOpDebugTypeFunction(