diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2024-08-20 06:06:34 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-19 15:06:34 -0700 |
| commit | f77a5ac9d1547a4394bba4ab8e94d905972c79b7 (patch) | |
| tree | 0d66b3c8386d8cb1e75970c93914fe2a60f03c61 /source | |
| parent | 453683bf44f2112719802eaac2b332d49eebd640 (diff) | |
Remove using SpvStorageClass values casted into AddressSpace values (#4861)
* Remove using SpvStorageClass values casted into AddressSpace values
Also removes support for specific storage classes in __target_intrinsic snippets
* remove SLANG_RETURN_NEVER macro
* squash warnings
* Make nonexhaustive switch statement error on gcc
* Add SLANG_EXHAUSTIVE_SWITCH_BEGIN/END macros
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/core/slang-common.h | 24 | ||||
| -rw-r--r-- | source/core/slang-signal.cpp | 2 | ||||
| -rw-r--r-- | source/core/slang-signal.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-glsl.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 97 | ||||
| -rw-r--r-- | source/slang/slang-ir-explicit-global-context.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-glsl-legalize.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-address-space.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 141 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-legalize-types.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-type-system-shared.h | 41 |
15 files changed, 208 insertions, 138 deletions
diff --git a/source/core/slang-common.h b/source/core/slang-common.h index 362a509a7..907bf4593 100644 --- a/source/core/slang-common.h +++ b/source/core/slang-common.h @@ -158,6 +158,30 @@ public: #define UNREACHABLE_RETURN(x) return x; #endif +#if SLANG_GCC +# define SLANG_EXHAUSTIVE_SWITCH_BEGIN \ + _Pragma("GCC diagnostic push"); \ + _Pragma("GCC diagnostic error \"-Wswitch-enum\""); +# define SLANG_EXHAUSTIVE_SWITCH_END \ + _Pragma("GCC diagnostic pop"); +#elif SLANG_CLANG +# define SLANG_EXHAUSTIVE_SWITCH_BEGIN \ + _Pragma("clang diagnostic push"); \ + _Pragma("clang diagnostic error \"-Wswitch-enum\""); +# define SLANG_EXHAUSTIVE_SWITCH_END \ + _Pragma("clang diagnostic pop"); +#elif SLANG_VC +# define SLANG_EXHAUSTIVE_SWITCH_BEGIN \ + _Pragma("warning(push)"); \ + _Pragma("warning(error : 4062)"); +# define SLANG_EXHAUSTIVE_SWITCH_END \ + _Pragma("warning(pop)"); +#else +# define SLANG_EXHAUSTIVE_SWITCH_BEGIN +# define SLANG_EXHAUSTIVE_SWITCH_END +#endif + + // // Use `SLANG_ASSUME(myBoolExpression);` to inform the compiler that the condition is true. // Do not rely on side effects of the condition being performed. diff --git a/source/core/slang-signal.cpp b/source/core/slang-signal.cpp index d8218b379..5f53cba93 100644 --- a/source/core/slang-signal.cpp +++ b/source/core/slang-signal.cpp @@ -36,7 +36,7 @@ String _getMessage(SignalType type, char const* message) // One point of having as a single function is a choke point both for handling (allowing different // handling scenarios) as well as a choke point to set a breakpoint to catch 'signal' types -SLANG_RETURN_NEVER void handleSignal(SignalType type, char const* message) +[[noreturn]] void handleSignal(SignalType type, char const* message) { StringBuilder buf; const char*const typeText = _getSignalTypeAsText(type); diff --git a/source/core/slang-signal.h b/source/core/slang-signal.h index 2151bdcfe..759581ee2 100644 --- a/source/core/slang-signal.h +++ b/source/core/slang-signal.h @@ -18,7 +18,7 @@ enum class SignalType // Note that message can be passed as nullptr for no message. -SLANG_RETURN_NEVER void handleSignal(SignalType type, char const* message); +[[noreturn]] void handleSignal(SignalType type, char const* message); #define SLANG_UNEXPECTED(reason) \ ::Slang::handleSignal(::Slang::SignalType::Unexpected, reason) diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 86ad0be6e..50d71b6b6 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -2639,7 +2639,7 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) void GLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) { - if(addressSpace == (AddressSpace)SpvStorageClassTaskPayloadWorkgroupEXT) + if(addressSpace == AddressSpace::TaskPayloadWorkgroup) { m_writer->emit("taskPayloadSharedEXT "); } diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index 302e2b6b7..0b282c4d9 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -700,7 +700,7 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) { auto ptrType = cast<IRPtrTypeBase>(type); emitType((IRType*)ptrType->getValueType()); - switch ((AddressSpace)ptrType->getAddressSpace()) + switch (ptrType->getAddressSpace()) { case AddressSpace::Global: m_writer->emit(" device"); diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index bb0a2565c..20cbbac02 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -324,7 +324,6 @@ struct SpvSnippetEmitContext bool isResultTypeFloat; // True if resultType is signed. bool isResultTypeSigned; - Dictionary<SpvStorageClass, IRInst*> qualifiedResultTypes; List<SpvWord> argumentIds; }; @@ -1223,17 +1222,59 @@ struct SPIRVEmitContext return m_NonSemanticDebugPrintfExtInst; } - SpvStorageClass addressSpaceToStorageClass(AddressSpace addrSpace) + static SpvStorageClass addressSpaceToStorageClass(AddressSpace addrSpace) { + SLANG_EXHAUSTIVE_SWITCH_BEGIN switch (addrSpace) { case AddressSpace::Generic: return SpvStorageClassMax; + case AddressSpace::ThreadLocal: + return SpvStorageClassPrivate; + case AddressSpace::GroupShared: + return SpvStorageClassWorkgroup; + case AddressSpace::Uniform: + return SpvStorageClassUniform; + case AddressSpace::Input: + return SpvStorageClassInput; + case AddressSpace::Output: + return SpvStorageClassOutput; + case AddressSpace::TaskPayloadWorkgroup: + return SpvStorageClassTaskPayloadWorkgroupEXT; + case AddressSpace::Function: + return SpvStorageClassFunction; + case AddressSpace::StorageBuffer: + return SpvStorageClassStorageBuffer; + case AddressSpace::PushConstant: + return SpvStorageClassPushConstant; + case AddressSpace::RayPayloadKHR: + return SpvStorageClassRayPayloadKHR; + case AddressSpace::IncomingRayPayload: + return SpvStorageClassIncomingRayPayloadKHR; + case AddressSpace::CallableDataKHR: + return SpvStorageClassCallableDataKHR; + case AddressSpace::IncomingCallableData: + return SpvStorageClassIncomingCallableDataKHR; + case AddressSpace::HitObjectAttribute: + return SpvStorageClassHitObjectAttributeNV; + case AddressSpace::HitAttribute: + return SpvStorageClassHitAttributeKHR; + case AddressSpace::ShaderRecordBuffer: + return SpvStorageClassShaderRecordBufferKHR; + case AddressSpace::UniformConstant: + return SpvStorageClassUniformConstant; + case AddressSpace::Image: + return SpvStorageClassImage; case AddressSpace::UserPointer: return SpvStorageClassPhysicalStorageBuffer; - default: - return (SpvStorageClass)addrSpace; + case AddressSpace::Global: + case AddressSpace::MetalObjectData: + // msvc is limiting us from putting the UNEXPECTED macro here, so + // just fall out + ; } + SLANG_UNEXPECTED("Unhandled AddressSpace in addressSpaceToStorageClass"); + SLANG_EXHAUSTIVE_SWITCH_END } // Now that we've gotten the core infrastructure out of the way, @@ -1829,7 +1870,7 @@ struct SPIRVEmitContext SpvStorageClass storageClass = SpvStorageClassFunction; if (ptrType && ptrType->hasAddressSpace()) { - storageClass = (SpvStorageClass)ptrType->getAddressSpace(); + storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace()); } return storageClass; } @@ -2369,7 +2410,7 @@ struct SPIRVEmitContext if (auto ptrType = as<IRPtrTypeBase>(param->getDataType())) { if (ptrType->hasAddressSpace()) - storageClass = (SpvStorageClass)ptrType->getAddressSpace(); + storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace()); } if (auto systemValInst = maybeEmitSystemVal(param)) { @@ -2398,7 +2439,7 @@ struct SPIRVEmitContext if (auto ptrType = as<IRPtrTypeBase>(globalVar->getDataType())) { if (ptrType->hasAddressSpace()) - storageClass = (SpvStorageClass)ptrType->getAddressSpace(); + storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace()); } auto varInst = emitOpVariable( getSection(SpvLogicalSectionID::GlobalVariables), @@ -2418,7 +2459,7 @@ struct SPIRVEmitContext const auto kind = (SpvBuiltIn)(getIntVal(spvAsmBuiltinVar->getOperand(0))); IRBuilder builder(spvAsmBuiltinVar); builder.setInsertBefore(spvAsmBuiltinVar); - auto varInst = getBuiltinGlobalVar(builder.getPtrType(kIROp_PtrType, spvAsmBuiltinVar->getDataType(), SpvStorageClassInput), kind, spvAsmBuiltinVar); + auto varInst = getBuiltinGlobalVar(builder.getPtrType(kIROp_PtrType, spvAsmBuiltinVar->getDataType(), AddressSpace::Input), kind, spvAsmBuiltinVar); registerInst(spvAsmBuiltinVar, varInst); return varInst; } @@ -3373,7 +3414,7 @@ struct SPIRVEmitContext if (auto ptrType = as<IRPtrTypeBase>(globalInst->getDataType())) { auto addrSpace = ptrType->getAddressSpace(); - if (addrSpace != AddressSpace(SpvStorageClassInput) && addrSpace != AddressSpace(SpvStorageClassOutput)) + if (addrSpace != AddressSpace::Input && addrSpace != AddressSpace::Output) continue; } } @@ -4076,7 +4117,7 @@ struct SPIRVEmitContext if (!ptrType) return; auto addrSpace = ptrType->getAddressSpace(); - if (addrSpace == AddressSpace(SpvStorageClassInput)) + if (addrSpace == AddressSpace::Input) { if (isIntegralScalarOrCompositeType(ptrType->getValueType())) { @@ -4091,7 +4132,7 @@ struct SPIRVEmitContext SpvInst* result = nullptr; auto ptrType = as<IRPtrTypeBase>(type); SLANG_ASSERT(ptrType && "`getBuiltinGlobalVar`: `type` must be ptr type."); - auto storageClass = static_cast<SpvStorageClass>(ptrType->getAddressSpace()); + auto storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace()); auto key = BuiltinSpvVarKey(builtinVal, storageClass); if (m_builtinGlobalVars.tryGetValue(key, result)) { @@ -4103,7 +4144,7 @@ struct SPIRVEmitContext getSection(SpvLogicalSectionID::GlobalVariables), nullptr, type, - static_cast<SpvStorageClass>(ptrType->getAddressSpace()) + addressSpaceToStorageClass(ptrType->getAddressSpace()) ); emitOpDecorateBuiltIn( getSection(SpvLogicalSectionID::Annotations), @@ -4652,15 +4693,8 @@ struct SPIRVEmitContext { IRBuilder builder(m_irModule); builder.setInsertBefore(inst); - for (auto storageClass : snippet->usedPtrResultTypeStorageClasses) - { - auto newPtrType = builder.getPtrType( - kIROp_PtrType, - inst->getDataType(), - storageClass - ); - context.qualifiedResultTypes[storageClass] = newPtrType; - } + if(snippet->usedPtrResultTypeStorageClasses.getCount()) + SLANG_UNIMPLEMENTED_X("specifying storage classes in __target_intrinsic modifiers"); } return emitSpvSnippet(parent, inst, context, snippet); } @@ -4780,11 +4814,6 @@ struct SPIRVEmitContext emitOperand(kResultID); break; case SpvSnippet::ASMOperandType::ResultTypeId: - if (operand.content != 0xFFFFFFFF) - { - emitOperand(context.qualifiedResultTypes.getValue((SpvStorageClass)operand.content)); - } - else { emitOperand(context.resultType); } @@ -5067,9 +5096,9 @@ struct SPIRVEmitContext IRBuilder builder(inst); builder.setInsertBefore(inst); auto destPtrType = as<IRPtrTypeBase>(inst->getDest()->getDataType()); - SpvStorageClass addrSpace = SpvStorageClassFunction; + auto addrSpace = AddressSpace::Function; if (destPtrType->hasAddressSpace()) - addrSpace = (SpvStorageClass)destPtrType->getAddressSpace(); + addrSpace = destPtrType->getAddressSpace(); auto ptrElementType = builder.getPtrType(kIROp_PtrType, sourceElementType, addrSpace); for (UInt i = 0; i < inst->getElementCount(); i++) { @@ -5117,8 +5146,8 @@ struct SPIRVEmitContext if(ptrTypeWithNoAddressSpace->getAddressSpace() == addressSpace) return ptrTypeWithNoAddressSpace; - // It has an address space, but it doesn't match fail, this indicates a - // problem with whatever's creating these types + // It has an address space, but it doesn't match then fail, this + // indicates a problem with whatever's creating these types SLANG_ASSERT(!ptrTypeWithNoAddressSpace->hasAddressSpace()); IRBuilder builder(ptrTypeWithNoAddressSpace); @@ -5133,12 +5162,12 @@ struct SPIRVEmitContext { //"%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1;" IRBuilder builder(inst); - auto storageClass = isSpirv14OrLater()? SpvStorageClassStorageBuffer : SpvStorageClassUniform; + auto addressSpace = isSpirv14OrLater() ? AddressSpace::StorageBuffer : AddressSpace::Uniform; return emitOpAccessChain( parent, inst, // Make sure the resulting pointer has the correct storage class - getPtrTypeWithAddressSpace(cast<IRPtrTypeBase>(inst->getDataType()), AddressSpace(storageClass)), + getPtrTypeWithAddressSpace(cast<IRPtrTypeBase>(inst->getDataType()), addressSpace), inst->getOperand(0), makeArray(emitIntConstant(0, builder.getIntType()), ensureInst(inst->getOperand(1))) ); @@ -5943,7 +5972,7 @@ struct SPIRVEmitContext SpvStorageClass storageClass = SpvStorageClassFunction; if (fieldPtrType->hasAddressSpace()) - storageClass = (SpvStorageClass)fieldPtrType->getAddressSpace(); + storageClass = addressSpaceToStorageClass(fieldPtrType->getAddressSpace()); spvFieldType = emitOpDebugTypePointer( getSection(SpvLogicalSectionID::ConstantsAndTypes), @@ -6108,7 +6137,7 @@ struct SPIRVEmitContext SpvInst* debugBaseType = emitDebugType(baseType); SpvStorageClass storageClass = SpvStorageClassFunction; if (ptrType->hasAddressSpace()) - storageClass = (SpvStorageClass)ptrType->getAddressSpace(); + storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace()); return emitOpDebugTypePointer( getSection(SpvLogicalSectionID::ConstantsAndTypes), diff --git a/source/slang/slang-ir-explicit-global-context.cpp b/source/slang/slang-ir-explicit-global-context.cpp index 3dc3d3ad4..f68877d7d 100644 --- a/source/slang/slang-ir-explicit-global-context.cpp +++ b/source/slang/slang-ir-explicit-global-context.cpp @@ -47,7 +47,7 @@ struct IntroduceExplicitGlobalContextPass case CodeGenTarget::SPIRVAssembly: hoistableGlobalObjectKind = GlobalObjectKind::GlobalVar; requiresFuncTypeCorrectionPass = true; - addressSpaceOfLocals = (AddressSpace)SpvStorageClassFunction; + addressSpaceOfLocals = AddressSpace::Function; hoistGlobalVarOptions = HoistGlobalVarOptions::PlainGlobal; break; case CodeGenTarget::CUDASource: @@ -284,7 +284,7 @@ struct IntroduceExplicitGlobalContextPass // The context will usually be passed around by pointer, // so we get and cache that pointer type up front. // - m_contextStructPtrType = builder.getPtrType(kIROp_PtrType, m_contextStructType, (IRIntegerValue)getAddressSpaceOfLocal()); + m_contextStructPtrType = builder.getPtrType(kIROp_PtrType, m_contextStructType, getAddressSpaceOfLocal()); // The first step will be to create fields in the `KernelContext` @@ -505,7 +505,7 @@ struct IntroduceExplicitGlobalContextPass auto fieldInfo = m_mapInstToContextFieldInfo[globalVar]; if (fieldInfo.needDereference) { - auto var = builder.emitVar(globalVar->getDataType()->getValueType(), (IRIntegerValue)AddressSpace::GroupShared); + auto var = builder.emitVar(globalVar->getDataType()->getValueType(), AddressSpace::GroupShared); if (auto nameDecor = globalVar->findDecoration<IRNameHintDecoration>()) { builder.addNameHintDecoration(var, nameDecor->getName()); diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 404111a85..f2e2fa7cb 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -2179,7 +2179,7 @@ static void legalizeMeshPayloadInputParam( builder->setInsertInto(builder->getModule()); const auto ptrType = cast<IRPtrTypeBase>(pp->getDataType()); - const auto g = builder->createGlobalVar(ptrType->getValueType(), SpvStorageClassTaskPayloadWorkgroupEXT); + const auto g = builder->createGlobalVar(ptrType->getValueType(), AddressSpace::TaskPayloadWorkgroup); g->setFullType(builder->getRateQualifiedType(builder->getGroupSharedRate(), g->getFullType())); // moveValueBefore(g, builder->getFunc()); builder->addNameHintDecoration(g, pp->findDecoration<IRNameHintDecoration>()->getName()); @@ -3589,7 +3589,7 @@ void legalizeDispatchMeshPayloadForGLSL(IRModule* module) builder.getPtrType( payloadPtrType->getOp(), payloadPtrType->getValueType(), - SpvStorageClassTaskPayloadWorkgroupEXT + AddressSpace::TaskPayloadWorkgroup ) ); payload->setFullType(payloadSharedPtrType); @@ -3601,7 +3601,7 @@ void legalizeDispatchMeshPayloadForGLSL(IRModule* module) // parameter and store into the value being passed to this // call. builder.setInsertInto(module->getModuleInst()); - const auto v = builder.createGlobalVar(payloadType, SpvStorageClassTaskPayloadWorkgroupEXT); + const auto v = builder.createGlobalVar(payloadType, AddressSpace::TaskPayloadWorkgroup); v->setFullType(builder.getRateQualifiedType(builder.getGroupSharedRate(), v->getFullType())); builder.setInsertBefore(call); builder.emitStore(v, builder.emitLoad(payload)); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 4111fb983..3236bb2e6 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3537,10 +3537,9 @@ public: IRRefType* getRefType(IRType* valueType, AddressSpace addrSpace); IRConstRefType* getConstRefType(IRType* valueType); IRPtrTypeBase* getPtrType(IROp op, IRType* valueType); - IRPtrType* getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace); + IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace); IRPtrType* getPtrType(IROp op, IRType* valueType, IRInst* addressSpace); - IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace) { return getPtrType(op, valueType, (IRIntegerValue)addressSpace); } - IRPtrType* getPtrType(IRType* valueType, AddressSpace addressSpace) { return getPtrType(kIROp_PtrType, valueType, (IRIntegerValue)addressSpace); } + IRPtrType* getPtrType(IRType* valueType, AddressSpace addressSpace) { return getPtrType(kIROp_PtrType, valueType, addressSpace); } IRTextureTypeBase* getTextureType( IRType* elementType, @@ -4030,7 +4029,7 @@ public: IRType* valueType); IRGlobalVar* createGlobalVar( IRType* valueType, - IRIntegerValue addressSpace); + AddressSpace addressSpace); IRGlobalParam* createGlobalParam( IRType* valueType); @@ -4124,7 +4123,7 @@ public: IRType* type); IRVar* emitVar( IRType* type, - IRIntegerValue addressSpace); + AddressSpace addressSpace); IRInst* emitLoad( IRType* type, diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index bff7363b3..435abc369 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -1704,7 +1704,7 @@ namespace Slang if (auto ptrType = as<IRPtrTypeBase>(type)) { if (ptrType->hasAddressSpace()) - return (AddressSpace)ptrType->getAddressSpace(); + return ptrType->getAddressSpace(); return AddressSpace::Global; } return AddressSpace::Generic; diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp index 24eee3d76..60aa1f10a 100644 --- a/source/slang/slang-ir-specialize-address-space.cpp +++ b/source/slang/slang-ir-specialize-address-space.cpp @@ -155,7 +155,7 @@ namespace Slang { if (ptrType->hasAddressSpace()) { - mapInstToAddrSpace[inst] = (AddressSpace)ptrType->getAddressSpace(); + mapInstToAddrSpace[inst] = ptrType->getAddressSpace(); continue; } } diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 528fe2331..11fbb5b4b 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -203,13 +203,13 @@ struct SPIRVLegalizationContext : public SourceEmitterBase traverseUses(cbParamInst, [&](IRUse* use) { builder.setInsertBefore(use->getUser()); - auto addr = builder.emitFieldAddress(builder.getPtrType(kIROp_PtrType, innerType, SpvStorageClassUniform), cbParamInst, key); + auto addr = builder.emitFieldAddress(builder.getPtrType(kIROp_PtrType, innerType, AddressSpace::Uniform), cbParamInst, key); use->set(addr); }); return structType; } - static void insertLoadAtLatestLocation(IRInst* addrInst, IRUse* inUse, SpvStorageClass storageClass) + static void insertLoadAtLatestLocation(IRInst* addrInst, IRUse* inUse, AddressSpace addressSpace) { struct WorkItem { IRInst* addr; IRUse* use; }; List<WorkItem> workList; @@ -257,7 +257,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // passing as is, which needs to be a pointer (pass as is). if (user->getDataType() && user->getDataType()->getOp() == kIROp_RefType - && storageClass == SpvStorageClassInput) + && addressSpace == AddressSpace::Input) { builder.replaceOperand(use, addr); continue; @@ -552,19 +552,19 @@ struct SPIRVLegalizationContext : public SourceEmitterBase innerType = arrayType->getElementType(); } - SpvStorageClass storageClass = SpvStorageClassPrivate; + AddressSpace addressSpace = AddressSpace::ThreadLocal; // Figure out storage class based on var layout. if (auto layout = getVarLayout(inst)) { - auto cls = getGlobalParamStorageClass(layout); - if (cls != SpvStorageClassMax) - storageClass = cls; + auto cls = getGlobalParamAddressSpace(layout); + if (cls != AddressSpace::Generic) + addressSpace = cls; else if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>()) { String semanticName = systemValueAttr->getName(); semanticName = semanticName.toLower(); if (semanticName == "sv_pointsize") - storageClass = SpvStorageClassInput; + addressSpace = AddressSpace::Input; } } @@ -572,7 +572,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // placed here then put them in UniformConstant instead if (isSpirvUniformConstantType(inst->getDataType())) { - storageClass = SpvStorageClassUniformConstant; + addressSpace = AddressSpace::UniformConstant; } // Strip any HLSL wrappers @@ -582,8 +582,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase if (cbufferType || paramBlockType) { innerType = as<IRUniformParameterGroupType>(innerType)->getElementType(); - if (storageClass == SpvStorageClassPrivate) - storageClass = SpvStorageClassUniform; + if (addressSpace == AddressSpace::ThreadLocal) + addressSpace = AddressSpace::Uniform; // Constant buffer is already treated like a pointer type, and // we are not adding another layer of indirection when replacing it // with a pointer type. Therefore we don't need to insert a load at @@ -623,7 +623,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } } - else if (storageClass == SpvStorageClassPushConstant) + else if (addressSpace == AddressSpace::PushConstant) { // Push constant params does not need a VarLayout. varLayoutInst->removeAndDeallocate(); @@ -632,7 +632,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase else if (auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(innerType)) { innerType = lowerStructuredBufferType(structuredBufferType).structType; - storageClass = getStorageBufferStorageClass(); + addressSpace = getStorageBufferAddressSpace(); needLoad = false; auto memoryFlags = MemoryQualifierSetModifier::Flags::kNone; @@ -652,12 +652,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase if (m_sharedContext->isSpirv14OrLater()) { builder.addDecorationIfNotExist(innerType, kIROp_SPIRVBlockDecoration); - storageClass = SpvStorageClassStorageBuffer; + addressSpace = AddressSpace::StorageBuffer; } else { builder.addDecorationIfNotExist(innerType, kIROp_SPIRVBufferBlockDecoration); - storageClass = SpvStorageClassUniform; + addressSpace = AddressSpace::Uniform; } needLoad = false; } @@ -674,14 +674,14 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // Make a pointer type of storageClass. builder.setInsertBefore(inst); - ptrType = builder.getPtrType(kIROp_PtrType, innerType, storageClass); + ptrType = builder.getPtrType(kIROp_PtrType, innerType, addressSpace); inst->setFullType(ptrType); if (needLoad) { // Insert an explicit load at each use site. traverseUses(inst, [&](IRUse* use) { - insertLoadAtLatestLocation(inst, use, storageClass); + insertLoadAtLatestLocation(inst, use, addressSpace); }); } else if (arrayType) @@ -694,7 +694,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // For array resources, getElement(r, index) ==> getElementPtr(r, index). IRBuilder builder(getElement); builder.setInsertBefore(user); - auto newAddr = builder.emitElementAddress(builder.getPtrType(kIROp_PtrType, innerElementType, storageClass), inst, getElement->getIndex()); + auto newAddr = builder.emitElementAddress(builder.getPtrType(kIROp_PtrType, innerElementType, addressSpace), inst, getElement->getIndex()); user->replaceUsesWith(newAddr); user->removeAndDeallocate(); return; @@ -705,48 +705,48 @@ struct SPIRVLegalizationContext : public SourceEmitterBase processGlobalVar(inst); } - SpvStorageClass getStorageClassFromGlobalParamResourceKind(LayoutResourceKind kind) + AddressSpace getAddressSpaceFromGlobalParamResourceKind(LayoutResourceKind kind) { - SpvStorageClass storageClass = SpvStorageClassMax; + AddressSpace addressSpace = AddressSpace::Generic; switch (kind) { case LayoutResourceKind::Uniform: case LayoutResourceKind::DescriptorTableSlot: case LayoutResourceKind::ConstantBuffer: - storageClass = SpvStorageClassUniform; + addressSpace = AddressSpace::Uniform; break; case LayoutResourceKind::VaryingInput: - storageClass = SpvStorageClassInput; + addressSpace = AddressSpace::Input; break; case LayoutResourceKind::VaryingOutput: - storageClass = SpvStorageClassOutput; + addressSpace = AddressSpace::Output; break; case LayoutResourceKind::ShaderResource: case LayoutResourceKind::UnorderedAccess: - storageClass = getStorageBufferStorageClass(); + addressSpace = getStorageBufferAddressSpace(); break; case LayoutResourceKind::PushConstantBuffer: - storageClass = SpvStorageClassPushConstant; + addressSpace = AddressSpace::PushConstant; break; case LayoutResourceKind::RayPayload: - storageClass = SpvStorageClassIncomingRayPayloadKHR; + addressSpace = AddressSpace::IncomingRayPayload; break; case LayoutResourceKind::CallablePayload: - storageClass = SpvStorageClassIncomingCallableDataKHR; + addressSpace = AddressSpace::IncomingCallableData; break; case LayoutResourceKind::HitAttributes: - storageClass = SpvStorageClassHitAttributeKHR; + addressSpace = AddressSpace::HitAttribute; break; case LayoutResourceKind::ShaderRecord: - storageClass = SpvStorageClassShaderRecordBufferKHR; + addressSpace = AddressSpace::ShaderRecordBuffer; break; default: break; } - return storageClass; + return addressSpace; } - SpvStorageClass getGlobalParamStorageClass(IRVarLayout* varLayout) + AddressSpace getGlobalParamAddressSpace(IRVarLayout* varLayout) { auto typeLayout = varLayout->getTypeLayout()->unwrapArray(); if (auto parameterGroupTypeLayout = as<IRParameterGroupTypeLayout>(typeLayout)) @@ -754,22 +754,22 @@ struct SPIRVLegalizationContext : public SourceEmitterBase varLayout = parameterGroupTypeLayout->getContainerVarLayout(); } - SpvStorageClass result = SpvStorageClassMax; + auto result = AddressSpace::Generic; for (auto rr : varLayout->getOffsetAttrs()) { - auto storageClass = getStorageClassFromGlobalParamResourceKind(rr->getResourceKind()); + auto addressSpace = getAddressSpaceFromGlobalParamResourceKind(rr->getResourceKind()); // If we haven't inferred a storage class yet, use the one we just found. - if (result == SpvStorageClassMax) - result = storageClass; - else if (result != storageClass) + if (result == AddressSpace::Generic) + result = addressSpace; + else if (result != addressSpace) { - // If we have inferred a storage class, and it is different from the one we just found, - // then we have conflicting uses of the resource, and we cannot infer a storage class. + // If we have inferred an address space, and it is different from the one we just found, + // then we have conflicting uses of the resource, and we cannot infer an address space. // An exception is that a uniform storage class can be further specialized by PushConstants. - if (result == SpvStorageClassUniform) - result = storageClass; + if (result == AddressSpace::Uniform) + result = addressSpace; else - SLANG_UNEXPECTED("Var layout contains conflicting resource uses, cannot resolve a storage class."); + SLANG_UNEXPECTED("Var layout contains conflicting resource uses, cannot resolve a storage class address space."); } } return result; @@ -783,7 +783,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(inst); builder.setInsertBefore(inst); auto newPtrType = builder.getPtrType( - oldPtrType->getOp(), oldPtrType->getValueType(), SpvStorageClassFunction); + oldPtrType->getOp(), oldPtrType->getValueType(), AddressSpace::Function); inst->setFullType(newPtrType); addUsersToWorkList(inst); } @@ -862,45 +862,45 @@ struct SPIRVLegalizationContext : public SourceEmitterBase return; } - SpvStorageClass storageClass = SpvStorageClassPrivate; + auto addressSpace = AddressSpace::ThreadLocal; if (as<IRGroupSharedRate>(inst->getRate())) { - storageClass = SpvStorageClassWorkgroup; + addressSpace = AddressSpace::GroupShared; } else if (const auto varLayout = getVarLayout(inst)) { - auto cls = getGlobalParamStorageClass(varLayout); - if (cls != SpvStorageClassMax) - storageClass = cls; + auto cls = getGlobalParamAddressSpace(varLayout); + if (cls != AddressSpace::Generic) + addressSpace = cls; } for (auto decor : inst->getDecorations()) { switch (decor->getOp()) { case kIROp_VulkanRayPayloadDecoration: - storageClass = SpvStorageClassRayPayloadKHR; + addressSpace = AddressSpace::RayPayloadKHR; break; case kIROp_VulkanRayPayloadInDecoration: - storageClass = SpvStorageClassIncomingRayPayloadKHR; + addressSpace = AddressSpace::IncomingRayPayload; break; case kIROp_VulkanCallablePayloadDecoration: - storageClass = SpvStorageClassCallableDataKHR; + addressSpace = AddressSpace::CallableDataKHR; break; case kIROp_VulkanCallablePayloadInDecoration: - storageClass = SpvStorageClassIncomingCallableDataKHR; + addressSpace = AddressSpace::IncomingCallableData; break; case kIROp_VulkanHitObjectAttributesDecoration: - storageClass = SpvStorageClassHitObjectAttributeNV; + addressSpace = AddressSpace::HitObjectAttribute; break; case kIROp_VulkanHitAttributesDecoration: - storageClass = SpvStorageClassHitAttributeKHR; + addressSpace = AddressSpace::HitAttribute; break; } } IRBuilder builder(m_sharedContext->m_irModule); builder.setInsertBefore(inst); auto newPtrType = - builder.getPtrType(oldPtrType->getOp(), oldPtrType->getValueType(), storageClass); + builder.getPtrType(oldPtrType->getOp(), oldPtrType->getValueType(), addressSpace); inst->setFullType(newPtrType); addUsersToWorkList(inst); return; @@ -916,22 +916,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase if (!snippet) return; if (snippet->resultStorageClass != SpvStorageClassMax) - { - auto ptrType = as<IRPtrTypeBase>(inst->getDataType()); - if (!ptrType) - return; - IRBuilder builder(m_sharedContext->m_irModule); - builder.setInsertBefore(inst); - auto qualPtrType = builder.getPtrType( - ptrType->getOp(), ptrType->getValueType(), snippet->resultStorageClass); - List<IRInst*> args; - for (UInt i = 0; i < inst->getArgCount(); i++) - args.add(inst->getArg(i)); - auto newCall = builder.emitCallInst(qualPtrType, funcValue, args); - inst->replaceUsesWith(newCall); - inst->removeAndDeallocate(); - addUsersToWorkList(newCall); - } + SLANG_UNIMPLEMENTED_X("Specifying storage classes in spirv __target_intrinsic snippets"); return; } @@ -1065,7 +1050,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase builder.setInsertBefore(inst); else setInsertAfterOrdinaryInst(&builder, x); - y = builder.emitVar(x->getDataType(), SpvStorageClassFunction); + y = builder.emitVar(x->getDataType(), AddressSpace::Function); builder.emitStore(y, x); if (x->getParent()->getOp() != kIROp_Module) m_mapArrayValueToVar.set(x, y); @@ -1152,9 +1137,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } - SpvStorageClass getStorageBufferStorageClass() + AddressSpace getStorageBufferAddressSpace() { - return m_sharedContext->isSpirv14OrLater() ? SpvStorageClassStorageBuffer : SpvStorageClassUniform; + return m_sharedContext->isSpirv14OrLater() ? AddressSpace::StorageBuffer : AddressSpace::Uniform; } void processStructuredBufferLoad(IRInst* loadInst) @@ -1165,7 +1150,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase builder.setInsertBefore(loadInst); IRInst* args[] = { sb, index }; auto addrInst = builder.emitIntrinsicInst( - builder.getPtrType(kIROp_PtrType, loadInst->getFullType(), getStorageBufferStorageClass()), + builder.getPtrType(kIROp_PtrType, loadInst->getFullType(), getStorageBufferAddressSpace()), kIROp_RWStructuredBufferGetElementPtr, 2, args); @@ -1184,7 +1169,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase builder.setInsertBefore(storeInst); IRInst* args[] = { sb, index }; auto addrInst = builder.emitIntrinsicInst( - builder.getPtrType(kIROp_PtrType, value->getFullType(), getStorageBufferStorageClass()), + builder.getPtrType(kIROp_PtrType, value->getFullType(), getStorageBufferAddressSpace()), kIROp_RWStructuredBufferGetElementPtr, 2, args); @@ -1324,7 +1309,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto newPtrType = builder.getPtrType( ptrType->getOp(), ptrType->getValueType(), - SpvStorageClassImage); + AddressSpace::Image); subscript->setFullType(newPtrType); // HACK: assumes the image operand is a load and replace it with @@ -2175,7 +2160,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase if (auto ptrType = as<IRPtrTypeBase>(type)) { if (ptrType->hasAddressSpace()) - return (AddressSpace)ptrType->getAddressSpace(); + return ptrType->getAddressSpace(); } return AddressSpace::Generic; } @@ -2240,7 +2225,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto lowered = lowerStructuredBufferType(t); IRBuilder builder(t); builder.setInsertBefore(t); - t->replaceUsesWith(builder.getPtrType(kIROp_PtrType, lowered.structType, getStorageBufferStorageClass())); + t->replaceUsesWith(builder.getPtrType(kIROp_PtrType, lowered.structType, getStorageBufferAddressSpace())); } for (auto t : textureFootprintTypes) { diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 1441b0567..e0769686c 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2881,9 +2881,9 @@ namespace Slang operands); } - IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace) + IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace) { - return (IRPtrType*)getPtrType(op, valueType, getIntValue(getUInt64Type(), addressSpace)); + return (IRPtrType*)getPtrType(op, valueType, getIntValue(getUInt64Type(), static_cast<IRIntegerValue>(addressSpace))); } IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, IRInst* addressSpace) @@ -4530,7 +4530,7 @@ namespace Slang IRGlobalVar* IRBuilder::createGlobalVar( IRType* valueType, - IRIntegerValue addressSpace) + AddressSpace addressSpace) { auto ptrType = getPtrType(kIROp_PtrType, valueType, addressSpace); IRGlobalVar* globalVar = createInst<IRGlobalVar>( @@ -4807,7 +4807,7 @@ namespace Slang IRVar* IRBuilder::emitVar( IRType* type, - IRIntegerValue addressSpace) + AddressSpace addressSpace) { auto allocatedType = getPtrType(kIROp_PtrType, type, addressSpace); auto inst = createInst<IRVar>( diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp index 6009ef33e..b10c88a15 100644 --- a/source/slang/slang-legalize-types.cpp +++ b/source/slang/slang-legalize-types.cpp @@ -209,7 +209,7 @@ bool isPointerToResourceType(IRType* type) { while (auto ptrType = as<IRPtrTypeBase>(type)) { - if (ptrType->getAddressSpace() == AddressSpace(SpvStorageClassStorageBuffer) || + if (ptrType->getAddressSpace() == AddressSpace::StorageBuffer || ptrType->getAddressSpace() == AddressSpace::UserPointer) return true; type = ptrType->getValueType(); diff --git a/source/slang/slang-type-system-shared.h b/source/slang/slang-type-system-shared.h index 404c84cf4..adf6e26f8 100644 --- a/source/slang/slang-type-system-shared.h +++ b/source/slang/slang-type-system-shared.h @@ -62,12 +62,45 @@ FOREACH_BASE_TYPE(DEFINE_BASE_TYPE) enum class AddressSpace : uint64_t { Generic = 0x7fffffff, + // Corresponds to SPIR-V's SpvStorageClassPrivate ThreadLocal = 1, - Global = 2, - GroupShared = 3, - Uniform = 4, + Global, + // Corresponds to SPIR-V's SpvStorageClassWorkgroup + GroupShared, + // Corresponds to SPIR-V's SpvStorageClassUniform + Uniform, // specific address space for payload data in metal - MetalObjectData = 5, + MetalObjectData, + // Corresponds to SPIR-V's SpvStorageClassInput + Input, + // Corresponds to SPIR-V's SpvStorageClassOutput + Output, + // Corresponds to SPIR-V's SpvStorageClassTaskPayloadWorkgroupEXT + TaskPayloadWorkgroup, + // Corresponds to SPIR-V's SpvStorageClassFunction + Function, + // Corresponds to SPIR-V's SpvStorageClassStorageBuffer + StorageBuffer, + // Corresponds to SPIR-V's SpvStorageClassPushConstant, + PushConstant, + // Corresponds to SPIR-V's SpvStorageClassRayPayloadKHR, + RayPayloadKHR, + // Corresponds to SPIR-V's SpvStorageClassIncomingRayPayloadKHR, + IncomingRayPayload, + // Corresponds to SPIR-V's SpvStorageClassCallableDataKHR + CallableDataKHR, + // Corresponds to SPIR-V's SpvStorageClassIncomingCallableDataKHR + IncomingCallableData, + // Corresponds to SPIR-V's SpvStorageClassHitObjectAttributeNV, + HitObjectAttribute, + // Corresponds to SPIR-V's SpvStorageClassHitAttributeKHR, + HitAttribute, + // Corresponds to SPIR-V's SpvStorageClassShaderRecordBufferKHR, + ShaderRecordBuffer, + // Corresponds to SPIR-V's SpvStorageClassUniformConstant, + UniformConstant, + // Corresponds to SPIR-V's SpvStorageClassImage + Image, // Default address space for a user-defined pointer UserPointer = 0x100000001ULL, |
