diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 193 |
1 files changed, 192 insertions, 1 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 0facaac3f..2127141bd 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1802,6 +1802,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } auto valueType = ptrType->getValueType(); + + // Check for 8/16-bit storage capabilities when emitting pointer types + requireCapabilitiesForType(valueType, storageClass); + // If we haven't emitted the inner type yet, we need to emit a forward declaration. bool useForwardDeclaration = (!m_mapIRInstToSpvInst.containsKey(valueType) && as<IRStructType>(valueType) && @@ -3219,6 +3223,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex registerInst(param, systemValInst); return systemValInst; } + auto varInst = emitOpVariable( getSection(SpvLogicalSectionID::GlobalVariables), param, @@ -3243,6 +3248,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex if (ptrType->hasAddressSpace()) storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace()); } + auto varInst = emitOpVariable( getSection(SpvLogicalSectionID::GlobalVariables), globalVar, @@ -7976,6 +7982,191 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } } + // Type and constants used by checkTypeNeedsStorageCapability + typedef uint32_t TypeNeedsStorageFlags; + struct TypeNeedsStorageFlag + { + enum Enum : TypeNeedsStorageFlags + { + kNone = 0, + kElementSize8 = 1 << 0, + kElementSize16 = 1 << 1, + }; + }; + + // Helper function to recursively check if a storage type contains 8/16-bit types. + // Returns true if all requested targets have been found and therefore it can stop + // recursing. + bool checkTypeNeedsStorageCapability( + IRType* type, + const TypeNeedsStorageFlags targets, + TypeNeedsStorageFlags& found, + HashSet<IRType*>& visited) + { + if (visited.contains(type)) + return false; // Cycle detected, break recursion + + visited.add(type); + + switch (type->getOp()) + { + case kIROp_HalfType: + case kIROp_UInt16Type: + case kIROp_Int16Type: + found |= TypeNeedsStorageFlag::kElementSize16; + return (targets == (found & targets)); + + case kIROp_UInt8Type: + case kIROp_Int8Type: + found |= TypeNeedsStorageFlag::kElementSize8; + return (targets == (found & targets)); + + case kIROp_VectorType: + if (auto vectorType = as<IRVectorType>(type)) + return checkTypeNeedsStorageCapability( + vectorType->getElementType(), + targets, + found, + visited); + break; + + case kIROp_MatrixType: + if (auto matrixType = as<IRMatrixType>(type)) + return checkTypeNeedsStorageCapability( + matrixType->getElementType(), + targets, + found, + visited); + break; + + case kIROp_ArrayType: + case kIROp_UnsizedArrayType: + if (auto arrayType = as<IRArrayTypeBase>(type)) + return checkTypeNeedsStorageCapability( + arrayType->getElementType(), + targets, + found, + visited); + break; + + case kIROp_StructType: + if (auto structType = as<IRStructType>(type)) + { + for (auto field : structType->getFields()) + if (checkTypeNeedsStorageCapability( + field->getFieldType(), + targets, + found, + visited)) + return true; + } + break; + + case kIROp_AtomicType: + if (auto atomicType = as<IRAtomicType>(type)) + return checkTypeNeedsStorageCapability( + atomicType->getElementType(), + targets, + found, + visited); + break; + + case kIROp_PtrType: + case kIROp_RefType: + case kIROp_ConstRefType: + case kIROp_OutType: + case kIROp_InOutType: + if (auto ptrType = as<IRPtrTypeBase>(type)) + return checkTypeNeedsStorageCapability( + ptrType->getValueType(), + targets, + found, + visited); + break; + + case kIROp_RateQualifiedType: + if (auto ptrType = as<IRRateQualifiedType>(type)) + return checkTypeNeedsStorageCapability( + ptrType->getValueType(), + targets, + found, + visited); + break; + } + + return false; + } + + // Check and require 8/16-bit storage capabilities based on storage class and type + void requireCapabilitiesForType(IRType* type, SpvStorageClass storageClass) + { + // Search for specific aspects of the type, depending on the class. + TypeNeedsStorageFlags targets = TypeNeedsStorageFlag::kNone; + switch (storageClass) + { + case SpvStorageClassUniform: + case SpvStorageClassStorageBuffer: + if (!m_capabilities.contains(SpvCapabilityUniformAndStorageBuffer8BitAccess)) + targets |= TypeNeedsStorageFlag::kElementSize8; + if (!m_capabilities.contains(SpvCapabilityUniformAndStorageBuffer16BitAccess)) + targets |= TypeNeedsStorageFlag::kElementSize16; + break; + case SpvStorageClassPushConstant: + if (!m_capabilities.contains(SpvCapabilityStoragePushConstant8)) + targets |= TypeNeedsStorageFlag::kElementSize8; + if (!m_capabilities.contains(SpvCapabilityStoragePushConstant16)) + targets |= TypeNeedsStorageFlag::kElementSize16; + break; + case SpvStorageClassInput: + case SpvStorageClassOutput: + if (!m_capabilities.contains(SpvCapabilityStorageInputOutput16)) + targets |= TypeNeedsStorageFlag::kElementSize16; + break; + } + + // If we've already enabled all possible capabilities for this storage class, there's no + // reason to search again. + if (targets == TypeNeedsStorageFlag::kNone) + return; + + HashSet<IRType*> visited; + TypeNeedsStorageFlags found = TypeNeedsStorageFlag::kNone; + checkTypeNeedsStorageCapability(type, targets, found, visited); + + if (found == TypeNeedsStorageFlag::kNone) + return; + + switch (storageClass) + { + case SpvStorageClassUniform: + case SpvStorageClassStorageBuffer: + if (found & TypeNeedsStorageFlag::kElementSize8) + { + ensureExtensionDeclarationBeforeSpv15(UnownedStringSlice("SPV_KHR_8bit_storage")); + requireSPIRVCapability(SpvCapabilityUniformAndStorageBuffer8BitAccess); + } + if (found & TypeNeedsStorageFlag::kElementSize16) + requireSPIRVCapability(SpvCapabilityUniformAndStorageBuffer16BitAccess); + break; + + case SpvStorageClassPushConstant: + if (found & TypeNeedsStorageFlag::kElementSize8) + { + ensureExtensionDeclarationBeforeSpv15(UnownedStringSlice("SPV_KHR_8bit_storage")); + requireSPIRVCapability(SpvCapabilityStoragePushConstant8); + } + if (found & TypeNeedsStorageFlag::kElementSize16) + requireSPIRVCapability(SpvCapabilityStoragePushConstant16); + break; + + case SpvStorageClassInput: + case SpvStorageClassOutput: + if (found & TypeNeedsStorageFlag::kElementSize16) + requireSPIRVCapability(SpvCapabilityStorageInputOutput16); + break; + } + } + SpvInst* emitVectorOrScalarArithmetic( SpvInstParent* parent, IRInst* instToRegister, @@ -8400,7 +8591,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return ensureInst(m_voidType); IRBuilder builder(type); - if (const auto funcType = as<IRFuncType>(type)) + if (as<IRFuncType>(type)) { List<SpvInst*> argTypes; return emitOpDebugTypeFunction( |
