summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2024-08-20 06:06:34 +0800
committerGitHub <noreply@github.com>2024-08-19 15:06:34 -0700
commitf77a5ac9d1547a4394bba4ab8e94d905972c79b7 (patch)
tree0d66b3c8386d8cb1e75970c93914fe2a60f03c61 /source
parent453683bf44f2112719802eaac2b332d49eebd640 (diff)
Remove using SpvStorageClass values casted into AddressSpace values (#4861)
* Remove using SpvStorageClass values casted into AddressSpace values Also removes support for specific storage classes in __target_intrinsic snippets * remove SLANG_RETURN_NEVER macro * squash warnings * Make nonexhaustive switch statement error on gcc * Add SLANG_EXHAUSTIVE_SWITCH_BEGIN/END macros --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/core/slang-common.h24
-rw-r--r--source/core/slang-signal.cpp2
-rw-r--r--source/core/slang-signal.h2
-rw-r--r--source/slang/slang-emit-glsl.cpp2
-rw-r--r--source/slang/slang-emit-metal.cpp2
-rw-r--r--source/slang/slang-emit-spirv.cpp97
-rw-r--r--source/slang/slang-ir-explicit-global-context.cpp6
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp6
-rw-r--r--source/slang/slang-ir-insts.h9
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp2
-rw-r--r--source/slang/slang-ir-specialize-address-space.cpp2
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp141
-rw-r--r--source/slang/slang-ir.cpp8
-rw-r--r--source/slang/slang-legalize-types.cpp2
-rw-r--r--source/slang/slang-type-system-shared.h41
15 files changed, 208 insertions, 138 deletions
diff --git a/source/core/slang-common.h b/source/core/slang-common.h
index 362a509a7..907bf4593 100644
--- a/source/core/slang-common.h
+++ b/source/core/slang-common.h
@@ -158,6 +158,30 @@ public:
#define UNREACHABLE_RETURN(x) return x;
#endif
+#if SLANG_GCC
+# define SLANG_EXHAUSTIVE_SWITCH_BEGIN \
+ _Pragma("GCC diagnostic push"); \
+ _Pragma("GCC diagnostic error \"-Wswitch-enum\"");
+# define SLANG_EXHAUSTIVE_SWITCH_END \
+ _Pragma("GCC diagnostic pop");
+#elif SLANG_CLANG
+# define SLANG_EXHAUSTIVE_SWITCH_BEGIN \
+ _Pragma("clang diagnostic push"); \
+ _Pragma("clang diagnostic error \"-Wswitch-enum\"");
+# define SLANG_EXHAUSTIVE_SWITCH_END \
+ _Pragma("clang diagnostic pop");
+#elif SLANG_VC
+# define SLANG_EXHAUSTIVE_SWITCH_BEGIN \
+ _Pragma("warning(push)"); \
+ _Pragma("warning(error : 4062)");
+# define SLANG_EXHAUSTIVE_SWITCH_END \
+ _Pragma("warning(pop)");
+#else
+# define SLANG_EXHAUSTIVE_SWITCH_BEGIN
+# define SLANG_EXHAUSTIVE_SWITCH_END
+#endif
+
+
//
// Use `SLANG_ASSUME(myBoolExpression);` to inform the compiler that the condition is true.
// Do not rely on side effects of the condition being performed.
diff --git a/source/core/slang-signal.cpp b/source/core/slang-signal.cpp
index d8218b379..5f53cba93 100644
--- a/source/core/slang-signal.cpp
+++ b/source/core/slang-signal.cpp
@@ -36,7 +36,7 @@ String _getMessage(SignalType type, char const* message)
// One point of having as a single function is a choke point both for handling (allowing different
// handling scenarios) as well as a choke point to set a breakpoint to catch 'signal' types
-SLANG_RETURN_NEVER void handleSignal(SignalType type, char const* message)
+[[noreturn]] void handleSignal(SignalType type, char const* message)
{
StringBuilder buf;
const char*const typeText = _getSignalTypeAsText(type);
diff --git a/source/core/slang-signal.h b/source/core/slang-signal.h
index 2151bdcfe..759581ee2 100644
--- a/source/core/slang-signal.h
+++ b/source/core/slang-signal.h
@@ -18,7 +18,7 @@ enum class SignalType
// Note that message can be passed as nullptr for no message.
-SLANG_RETURN_NEVER void handleSignal(SignalType type, char const* message);
+[[noreturn]] void handleSignal(SignalType type, char const* message);
#define SLANG_UNEXPECTED(reason) \
::Slang::handleSignal(::Slang::SignalType::Unexpected, reason)
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index 86ad0be6e..50d71b6b6 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -2639,7 +2639,7 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
void GLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace)
{
- if(addressSpace == (AddressSpace)SpvStorageClassTaskPayloadWorkgroupEXT)
+ if(addressSpace == AddressSpace::TaskPayloadWorkgroup)
{
m_writer->emit("taskPayloadSharedEXT ");
}
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp
index 302e2b6b7..0b282c4d9 100644
--- a/source/slang/slang-emit-metal.cpp
+++ b/source/slang/slang-emit-metal.cpp
@@ -700,7 +700,7 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
{
auto ptrType = cast<IRPtrTypeBase>(type);
emitType((IRType*)ptrType->getValueType());
- switch ((AddressSpace)ptrType->getAddressSpace())
+ switch (ptrType->getAddressSpace())
{
case AddressSpace::Global:
m_writer->emit(" device");
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index bb0a2565c..20cbbac02 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -324,7 +324,6 @@ struct SpvSnippetEmitContext
bool isResultTypeFloat;
// True if resultType is signed.
bool isResultTypeSigned;
- Dictionary<SpvStorageClass, IRInst*> qualifiedResultTypes;
List<SpvWord> argumentIds;
};
@@ -1223,17 +1222,59 @@ struct SPIRVEmitContext
return m_NonSemanticDebugPrintfExtInst;
}
- SpvStorageClass addressSpaceToStorageClass(AddressSpace addrSpace)
+ static SpvStorageClass addressSpaceToStorageClass(AddressSpace addrSpace)
{
+ SLANG_EXHAUSTIVE_SWITCH_BEGIN
switch (addrSpace)
{
case AddressSpace::Generic:
return SpvStorageClassMax;
+ case AddressSpace::ThreadLocal:
+ return SpvStorageClassPrivate;
+ case AddressSpace::GroupShared:
+ return SpvStorageClassWorkgroup;
+ case AddressSpace::Uniform:
+ return SpvStorageClassUniform;
+ case AddressSpace::Input:
+ return SpvStorageClassInput;
+ case AddressSpace::Output:
+ return SpvStorageClassOutput;
+ case AddressSpace::TaskPayloadWorkgroup:
+ return SpvStorageClassTaskPayloadWorkgroupEXT;
+ case AddressSpace::Function:
+ return SpvStorageClassFunction;
+ case AddressSpace::StorageBuffer:
+ return SpvStorageClassStorageBuffer;
+ case AddressSpace::PushConstant:
+ return SpvStorageClassPushConstant;
+ case AddressSpace::RayPayloadKHR:
+ return SpvStorageClassRayPayloadKHR;
+ case AddressSpace::IncomingRayPayload:
+ return SpvStorageClassIncomingRayPayloadKHR;
+ case AddressSpace::CallableDataKHR:
+ return SpvStorageClassCallableDataKHR;
+ case AddressSpace::IncomingCallableData:
+ return SpvStorageClassIncomingCallableDataKHR;
+ case AddressSpace::HitObjectAttribute:
+ return SpvStorageClassHitObjectAttributeNV;
+ case AddressSpace::HitAttribute:
+ return SpvStorageClassHitAttributeKHR;
+ case AddressSpace::ShaderRecordBuffer:
+ return SpvStorageClassShaderRecordBufferKHR;
+ case AddressSpace::UniformConstant:
+ return SpvStorageClassUniformConstant;
+ case AddressSpace::Image:
+ return SpvStorageClassImage;
case AddressSpace::UserPointer:
return SpvStorageClassPhysicalStorageBuffer;
- default:
- return (SpvStorageClass)addrSpace;
+ case AddressSpace::Global:
+ case AddressSpace::MetalObjectData:
+ // msvc is limiting us from putting the UNEXPECTED macro here, so
+ // just fall out
+ ;
}
+ SLANG_UNEXPECTED("Unhandled AddressSpace in addressSpaceToStorageClass");
+ SLANG_EXHAUSTIVE_SWITCH_END
}
// Now that we've gotten the core infrastructure out of the way,
@@ -1829,7 +1870,7 @@ struct SPIRVEmitContext
SpvStorageClass storageClass = SpvStorageClassFunction;
if (ptrType && ptrType->hasAddressSpace())
{
- storageClass = (SpvStorageClass)ptrType->getAddressSpace();
+ storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace());
}
return storageClass;
}
@@ -2369,7 +2410,7 @@ struct SPIRVEmitContext
if (auto ptrType = as<IRPtrTypeBase>(param->getDataType()))
{
if (ptrType->hasAddressSpace())
- storageClass = (SpvStorageClass)ptrType->getAddressSpace();
+ storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace());
}
if (auto systemValInst = maybeEmitSystemVal(param))
{
@@ -2398,7 +2439,7 @@ struct SPIRVEmitContext
if (auto ptrType = as<IRPtrTypeBase>(globalVar->getDataType()))
{
if (ptrType->hasAddressSpace())
- storageClass = (SpvStorageClass)ptrType->getAddressSpace();
+ storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace());
}
auto varInst = emitOpVariable(
getSection(SpvLogicalSectionID::GlobalVariables),
@@ -2418,7 +2459,7 @@ struct SPIRVEmitContext
const auto kind = (SpvBuiltIn)(getIntVal(spvAsmBuiltinVar->getOperand(0)));
IRBuilder builder(spvAsmBuiltinVar);
builder.setInsertBefore(spvAsmBuiltinVar);
- auto varInst = getBuiltinGlobalVar(builder.getPtrType(kIROp_PtrType, spvAsmBuiltinVar->getDataType(), SpvStorageClassInput), kind, spvAsmBuiltinVar);
+ auto varInst = getBuiltinGlobalVar(builder.getPtrType(kIROp_PtrType, spvAsmBuiltinVar->getDataType(), AddressSpace::Input), kind, spvAsmBuiltinVar);
registerInst(spvAsmBuiltinVar, varInst);
return varInst;
}
@@ -3373,7 +3414,7 @@ struct SPIRVEmitContext
if (auto ptrType = as<IRPtrTypeBase>(globalInst->getDataType()))
{
auto addrSpace = ptrType->getAddressSpace();
- if (addrSpace != AddressSpace(SpvStorageClassInput) && addrSpace != AddressSpace(SpvStorageClassOutput))
+ if (addrSpace != AddressSpace::Input && addrSpace != AddressSpace::Output)
continue;
}
}
@@ -4076,7 +4117,7 @@ struct SPIRVEmitContext
if (!ptrType)
return;
auto addrSpace = ptrType->getAddressSpace();
- if (addrSpace == AddressSpace(SpvStorageClassInput))
+ if (addrSpace == AddressSpace::Input)
{
if (isIntegralScalarOrCompositeType(ptrType->getValueType()))
{
@@ -4091,7 +4132,7 @@ struct SPIRVEmitContext
SpvInst* result = nullptr;
auto ptrType = as<IRPtrTypeBase>(type);
SLANG_ASSERT(ptrType && "`getBuiltinGlobalVar`: `type` must be ptr type.");
- auto storageClass = static_cast<SpvStorageClass>(ptrType->getAddressSpace());
+ auto storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace());
auto key = BuiltinSpvVarKey(builtinVal, storageClass);
if (m_builtinGlobalVars.tryGetValue(key, result))
{
@@ -4103,7 +4144,7 @@ struct SPIRVEmitContext
getSection(SpvLogicalSectionID::GlobalVariables),
nullptr,
type,
- static_cast<SpvStorageClass>(ptrType->getAddressSpace())
+ addressSpaceToStorageClass(ptrType->getAddressSpace())
);
emitOpDecorateBuiltIn(
getSection(SpvLogicalSectionID::Annotations),
@@ -4652,15 +4693,8 @@ struct SPIRVEmitContext
{
IRBuilder builder(m_irModule);
builder.setInsertBefore(inst);
- for (auto storageClass : snippet->usedPtrResultTypeStorageClasses)
- {
- auto newPtrType = builder.getPtrType(
- kIROp_PtrType,
- inst->getDataType(),
- storageClass
- );
- context.qualifiedResultTypes[storageClass] = newPtrType;
- }
+ if(snippet->usedPtrResultTypeStorageClasses.getCount())
+ SLANG_UNIMPLEMENTED_X("specifying storage classes in __target_intrinsic modifiers");
}
return emitSpvSnippet(parent, inst, context, snippet);
}
@@ -4780,11 +4814,6 @@ struct SPIRVEmitContext
emitOperand(kResultID);
break;
case SpvSnippet::ASMOperandType::ResultTypeId:
- if (operand.content != 0xFFFFFFFF)
- {
- emitOperand(context.qualifiedResultTypes.getValue((SpvStorageClass)operand.content));
- }
- else
{
emitOperand(context.resultType);
}
@@ -5067,9 +5096,9 @@ struct SPIRVEmitContext
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto destPtrType = as<IRPtrTypeBase>(inst->getDest()->getDataType());
- SpvStorageClass addrSpace = SpvStorageClassFunction;
+ auto addrSpace = AddressSpace::Function;
if (destPtrType->hasAddressSpace())
- addrSpace = (SpvStorageClass)destPtrType->getAddressSpace();
+ addrSpace = destPtrType->getAddressSpace();
auto ptrElementType = builder.getPtrType(kIROp_PtrType, sourceElementType, addrSpace);
for (UInt i = 0; i < inst->getElementCount(); i++)
{
@@ -5117,8 +5146,8 @@ struct SPIRVEmitContext
if(ptrTypeWithNoAddressSpace->getAddressSpace() == addressSpace)
return ptrTypeWithNoAddressSpace;
- // It has an address space, but it doesn't match fail, this indicates a
- // problem with whatever's creating these types
+ // It has an address space, but it doesn't match then fail, this
+ // indicates a problem with whatever's creating these types
SLANG_ASSERT(!ptrTypeWithNoAddressSpace->hasAddressSpace());
IRBuilder builder(ptrTypeWithNoAddressSpace);
@@ -5133,12 +5162,12 @@ struct SPIRVEmitContext
{
//"%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1;"
IRBuilder builder(inst);
- auto storageClass = isSpirv14OrLater()? SpvStorageClassStorageBuffer : SpvStorageClassUniform;
+ auto addressSpace = isSpirv14OrLater() ? AddressSpace::StorageBuffer : AddressSpace::Uniform;
return emitOpAccessChain(
parent,
inst,
// Make sure the resulting pointer has the correct storage class
- getPtrTypeWithAddressSpace(cast<IRPtrTypeBase>(inst->getDataType()), AddressSpace(storageClass)),
+ getPtrTypeWithAddressSpace(cast<IRPtrTypeBase>(inst->getDataType()), addressSpace),
inst->getOperand(0),
makeArray(emitIntConstant(0, builder.getIntType()), ensureInst(inst->getOperand(1)))
);
@@ -5943,7 +5972,7 @@ struct SPIRVEmitContext
SpvStorageClass storageClass = SpvStorageClassFunction;
if (fieldPtrType->hasAddressSpace())
- storageClass = (SpvStorageClass)fieldPtrType->getAddressSpace();
+ storageClass = addressSpaceToStorageClass(fieldPtrType->getAddressSpace());
spvFieldType = emitOpDebugTypePointer(
getSection(SpvLogicalSectionID::ConstantsAndTypes),
@@ -6108,7 +6137,7 @@ struct SPIRVEmitContext
SpvInst* debugBaseType = emitDebugType(baseType);
SpvStorageClass storageClass = SpvStorageClassFunction;
if (ptrType->hasAddressSpace())
- storageClass = (SpvStorageClass)ptrType->getAddressSpace();
+ storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace());
return emitOpDebugTypePointer(
getSection(SpvLogicalSectionID::ConstantsAndTypes),
diff --git a/source/slang/slang-ir-explicit-global-context.cpp b/source/slang/slang-ir-explicit-global-context.cpp
index 3dc3d3ad4..f68877d7d 100644
--- a/source/slang/slang-ir-explicit-global-context.cpp
+++ b/source/slang/slang-ir-explicit-global-context.cpp
@@ -47,7 +47,7 @@ struct IntroduceExplicitGlobalContextPass
case CodeGenTarget::SPIRVAssembly:
hoistableGlobalObjectKind = GlobalObjectKind::GlobalVar;
requiresFuncTypeCorrectionPass = true;
- addressSpaceOfLocals = (AddressSpace)SpvStorageClassFunction;
+ addressSpaceOfLocals = AddressSpace::Function;
hoistGlobalVarOptions = HoistGlobalVarOptions::PlainGlobal;
break;
case CodeGenTarget::CUDASource:
@@ -284,7 +284,7 @@ struct IntroduceExplicitGlobalContextPass
// The context will usually be passed around by pointer,
// so we get and cache that pointer type up front.
//
- m_contextStructPtrType = builder.getPtrType(kIROp_PtrType, m_contextStructType, (IRIntegerValue)getAddressSpaceOfLocal());
+ m_contextStructPtrType = builder.getPtrType(kIROp_PtrType, m_contextStructType, getAddressSpaceOfLocal());
// The first step will be to create fields in the `KernelContext`
@@ -505,7 +505,7 @@ struct IntroduceExplicitGlobalContextPass
auto fieldInfo = m_mapInstToContextFieldInfo[globalVar];
if (fieldInfo.needDereference)
{
- auto var = builder.emitVar(globalVar->getDataType()->getValueType(), (IRIntegerValue)AddressSpace::GroupShared);
+ auto var = builder.emitVar(globalVar->getDataType()->getValueType(), AddressSpace::GroupShared);
if (auto nameDecor = globalVar->findDecoration<IRNameHintDecoration>())
{
builder.addNameHintDecoration(var, nameDecor->getName());
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 404111a85..f2e2fa7cb 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -2179,7 +2179,7 @@ static void legalizeMeshPayloadInputParam(
builder->setInsertInto(builder->getModule());
const auto ptrType = cast<IRPtrTypeBase>(pp->getDataType());
- const auto g = builder->createGlobalVar(ptrType->getValueType(), SpvStorageClassTaskPayloadWorkgroupEXT);
+ const auto g = builder->createGlobalVar(ptrType->getValueType(), AddressSpace::TaskPayloadWorkgroup);
g->setFullType(builder->getRateQualifiedType(builder->getGroupSharedRate(), g->getFullType()));
// moveValueBefore(g, builder->getFunc());
builder->addNameHintDecoration(g, pp->findDecoration<IRNameHintDecoration>()->getName());
@@ -3589,7 +3589,7 @@ void legalizeDispatchMeshPayloadForGLSL(IRModule* module)
builder.getPtrType(
payloadPtrType->getOp(),
payloadPtrType->getValueType(),
- SpvStorageClassTaskPayloadWorkgroupEXT
+ AddressSpace::TaskPayloadWorkgroup
)
);
payload->setFullType(payloadSharedPtrType);
@@ -3601,7 +3601,7 @@ void legalizeDispatchMeshPayloadForGLSL(IRModule* module)
// parameter and store into the value being passed to this
// call.
builder.setInsertInto(module->getModuleInst());
- const auto v = builder.createGlobalVar(payloadType, SpvStorageClassTaskPayloadWorkgroupEXT);
+ const auto v = builder.createGlobalVar(payloadType, AddressSpace::TaskPayloadWorkgroup);
v->setFullType(builder.getRateQualifiedType(builder.getGroupSharedRate(), v->getFullType()));
builder.setInsertBefore(call);
builder.emitStore(v, builder.emitLoad(payload));
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 4111fb983..3236bb2e6 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3537,10 +3537,9 @@ public:
IRRefType* getRefType(IRType* valueType, AddressSpace addrSpace);
IRConstRefType* getConstRefType(IRType* valueType);
IRPtrTypeBase* getPtrType(IROp op, IRType* valueType);
- IRPtrType* getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace);
+ IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace);
IRPtrType* getPtrType(IROp op, IRType* valueType, IRInst* addressSpace);
- IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace) { return getPtrType(op, valueType, (IRIntegerValue)addressSpace); }
- IRPtrType* getPtrType(IRType* valueType, AddressSpace addressSpace) { return getPtrType(kIROp_PtrType, valueType, (IRIntegerValue)addressSpace); }
+ IRPtrType* getPtrType(IRType* valueType, AddressSpace addressSpace) { return getPtrType(kIROp_PtrType, valueType, addressSpace); }
IRTextureTypeBase* getTextureType(
IRType* elementType,
@@ -4030,7 +4029,7 @@ public:
IRType* valueType);
IRGlobalVar* createGlobalVar(
IRType* valueType,
- IRIntegerValue addressSpace);
+ AddressSpace addressSpace);
IRGlobalParam* createGlobalParam(
IRType* valueType);
@@ -4124,7 +4123,7 @@ public:
IRType* type);
IRVar* emitVar(
IRType* type,
- IRIntegerValue addressSpace);
+ AddressSpace addressSpace);
IRInst* emitLoad(
IRType* type,
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index bff7363b3..435abc369 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -1704,7 +1704,7 @@ namespace Slang
if (auto ptrType = as<IRPtrTypeBase>(type))
{
if (ptrType->hasAddressSpace())
- return (AddressSpace)ptrType->getAddressSpace();
+ return ptrType->getAddressSpace();
return AddressSpace::Global;
}
return AddressSpace::Generic;
diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp
index 24eee3d76..60aa1f10a 100644
--- a/source/slang/slang-ir-specialize-address-space.cpp
+++ b/source/slang/slang-ir-specialize-address-space.cpp
@@ -155,7 +155,7 @@ namespace Slang
{
if (ptrType->hasAddressSpace())
{
- mapInstToAddrSpace[inst] = (AddressSpace)ptrType->getAddressSpace();
+ mapInstToAddrSpace[inst] = ptrType->getAddressSpace();
continue;
}
}
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 528fe2331..11fbb5b4b 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -203,13 +203,13 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
traverseUses(cbParamInst, [&](IRUse* use)
{
builder.setInsertBefore(use->getUser());
- auto addr = builder.emitFieldAddress(builder.getPtrType(kIROp_PtrType, innerType, SpvStorageClassUniform), cbParamInst, key);
+ auto addr = builder.emitFieldAddress(builder.getPtrType(kIROp_PtrType, innerType, AddressSpace::Uniform), cbParamInst, key);
use->set(addr);
});
return structType;
}
- static void insertLoadAtLatestLocation(IRInst* addrInst, IRUse* inUse, SpvStorageClass storageClass)
+ static void insertLoadAtLatestLocation(IRInst* addrInst, IRUse* inUse, AddressSpace addressSpace)
{
struct WorkItem { IRInst* addr; IRUse* use; };
List<WorkItem> workList;
@@ -257,7 +257,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// passing as is, which needs to be a pointer (pass as is).
if (user->getDataType()
&& user->getDataType()->getOp() == kIROp_RefType
- && storageClass == SpvStorageClassInput)
+ && addressSpace == AddressSpace::Input)
{
builder.replaceOperand(use, addr);
continue;
@@ -552,19 +552,19 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
innerType = arrayType->getElementType();
}
- SpvStorageClass storageClass = SpvStorageClassPrivate;
+ AddressSpace addressSpace = AddressSpace::ThreadLocal;
// Figure out storage class based on var layout.
if (auto layout = getVarLayout(inst))
{
- auto cls = getGlobalParamStorageClass(layout);
- if (cls != SpvStorageClassMax)
- storageClass = cls;
+ auto cls = getGlobalParamAddressSpace(layout);
+ if (cls != AddressSpace::Generic)
+ addressSpace = cls;
else if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>())
{
String semanticName = systemValueAttr->getName();
semanticName = semanticName.toLower();
if (semanticName == "sv_pointsize")
- storageClass = SpvStorageClassInput;
+ addressSpace = AddressSpace::Input;
}
}
@@ -572,7 +572,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// placed here then put them in UniformConstant instead
if (isSpirvUniformConstantType(inst->getDataType()))
{
- storageClass = SpvStorageClassUniformConstant;
+ addressSpace = AddressSpace::UniformConstant;
}
// Strip any HLSL wrappers
@@ -582,8 +582,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
if (cbufferType || paramBlockType)
{
innerType = as<IRUniformParameterGroupType>(innerType)->getElementType();
- if (storageClass == SpvStorageClassPrivate)
- storageClass = SpvStorageClassUniform;
+ if (addressSpace == AddressSpace::ThreadLocal)
+ addressSpace = AddressSpace::Uniform;
// Constant buffer is already treated like a pointer type, and
// we are not adding another layer of indirection when replacing it
// with a pointer type. Therefore we don't need to insert a load at
@@ -623,7 +623,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}
}
- else if (storageClass == SpvStorageClassPushConstant)
+ else if (addressSpace == AddressSpace::PushConstant)
{
// Push constant params does not need a VarLayout.
varLayoutInst->removeAndDeallocate();
@@ -632,7 +632,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
else if (auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(innerType))
{
innerType = lowerStructuredBufferType(structuredBufferType).structType;
- storageClass = getStorageBufferStorageClass();
+ addressSpace = getStorageBufferAddressSpace();
needLoad = false;
auto memoryFlags = MemoryQualifierSetModifier::Flags::kNone;
@@ -652,12 +652,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
if (m_sharedContext->isSpirv14OrLater())
{
builder.addDecorationIfNotExist(innerType, kIROp_SPIRVBlockDecoration);
- storageClass = SpvStorageClassStorageBuffer;
+ addressSpace = AddressSpace::StorageBuffer;
}
else
{
builder.addDecorationIfNotExist(innerType, kIROp_SPIRVBufferBlockDecoration);
- storageClass = SpvStorageClassUniform;
+ addressSpace = AddressSpace::Uniform;
}
needLoad = false;
}
@@ -674,14 +674,14 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// Make a pointer type of storageClass.
builder.setInsertBefore(inst);
- ptrType = builder.getPtrType(kIROp_PtrType, innerType, storageClass);
+ ptrType = builder.getPtrType(kIROp_PtrType, innerType, addressSpace);
inst->setFullType(ptrType);
if (needLoad)
{
// Insert an explicit load at each use site.
traverseUses(inst, [&](IRUse* use)
{
- insertLoadAtLatestLocation(inst, use, storageClass);
+ insertLoadAtLatestLocation(inst, use, addressSpace);
});
}
else if (arrayType)
@@ -694,7 +694,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// For array resources, getElement(r, index) ==> getElementPtr(r, index).
IRBuilder builder(getElement);
builder.setInsertBefore(user);
- auto newAddr = builder.emitElementAddress(builder.getPtrType(kIROp_PtrType, innerElementType, storageClass), inst, getElement->getIndex());
+ auto newAddr = builder.emitElementAddress(builder.getPtrType(kIROp_PtrType, innerElementType, addressSpace), inst, getElement->getIndex());
user->replaceUsesWith(newAddr);
user->removeAndDeallocate();
return;
@@ -705,48 +705,48 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
processGlobalVar(inst);
}
- SpvStorageClass getStorageClassFromGlobalParamResourceKind(LayoutResourceKind kind)
+ AddressSpace getAddressSpaceFromGlobalParamResourceKind(LayoutResourceKind kind)
{
- SpvStorageClass storageClass = SpvStorageClassMax;
+ AddressSpace addressSpace = AddressSpace::Generic;
switch (kind)
{
case LayoutResourceKind::Uniform:
case LayoutResourceKind::DescriptorTableSlot:
case LayoutResourceKind::ConstantBuffer:
- storageClass = SpvStorageClassUniform;
+ addressSpace = AddressSpace::Uniform;
break;
case LayoutResourceKind::VaryingInput:
- storageClass = SpvStorageClassInput;
+ addressSpace = AddressSpace::Input;
break;
case LayoutResourceKind::VaryingOutput:
- storageClass = SpvStorageClassOutput;
+ addressSpace = AddressSpace::Output;
break;
case LayoutResourceKind::ShaderResource:
case LayoutResourceKind::UnorderedAccess:
- storageClass = getStorageBufferStorageClass();
+ addressSpace = getStorageBufferAddressSpace();
break;
case LayoutResourceKind::PushConstantBuffer:
- storageClass = SpvStorageClassPushConstant;
+ addressSpace = AddressSpace::PushConstant;
break;
case LayoutResourceKind::RayPayload:
- storageClass = SpvStorageClassIncomingRayPayloadKHR;
+ addressSpace = AddressSpace::IncomingRayPayload;
break;
case LayoutResourceKind::CallablePayload:
- storageClass = SpvStorageClassIncomingCallableDataKHR;
+ addressSpace = AddressSpace::IncomingCallableData;
break;
case LayoutResourceKind::HitAttributes:
- storageClass = SpvStorageClassHitAttributeKHR;
+ addressSpace = AddressSpace::HitAttribute;
break;
case LayoutResourceKind::ShaderRecord:
- storageClass = SpvStorageClassShaderRecordBufferKHR;
+ addressSpace = AddressSpace::ShaderRecordBuffer;
break;
default:
break;
}
- return storageClass;
+ return addressSpace;
}
- SpvStorageClass getGlobalParamStorageClass(IRVarLayout* varLayout)
+ AddressSpace getGlobalParamAddressSpace(IRVarLayout* varLayout)
{
auto typeLayout = varLayout->getTypeLayout()->unwrapArray();
if (auto parameterGroupTypeLayout = as<IRParameterGroupTypeLayout>(typeLayout))
@@ -754,22 +754,22 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
varLayout = parameterGroupTypeLayout->getContainerVarLayout();
}
- SpvStorageClass result = SpvStorageClassMax;
+ auto result = AddressSpace::Generic;
for (auto rr : varLayout->getOffsetAttrs())
{
- auto storageClass = getStorageClassFromGlobalParamResourceKind(rr->getResourceKind());
+ auto addressSpace = getAddressSpaceFromGlobalParamResourceKind(rr->getResourceKind());
// If we haven't inferred a storage class yet, use the one we just found.
- if (result == SpvStorageClassMax)
- result = storageClass;
- else if (result != storageClass)
+ if (result == AddressSpace::Generic)
+ result = addressSpace;
+ else if (result != addressSpace)
{
- // If we have inferred a storage class, and it is different from the one we just found,
- // then we have conflicting uses of the resource, and we cannot infer a storage class.
+ // If we have inferred an address space, and it is different from the one we just found,
+ // then we have conflicting uses of the resource, and we cannot infer an address space.
// An exception is that a uniform storage class can be further specialized by PushConstants.
- if (result == SpvStorageClassUniform)
- result = storageClass;
+ if (result == AddressSpace::Uniform)
+ result = addressSpace;
else
- SLANG_UNEXPECTED("Var layout contains conflicting resource uses, cannot resolve a storage class.");
+ SLANG_UNEXPECTED("Var layout contains conflicting resource uses, cannot resolve a storage class address space.");
}
}
return result;
@@ -783,7 +783,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto newPtrType = builder.getPtrType(
- oldPtrType->getOp(), oldPtrType->getValueType(), SpvStorageClassFunction);
+ oldPtrType->getOp(), oldPtrType->getValueType(), AddressSpace::Function);
inst->setFullType(newPtrType);
addUsersToWorkList(inst);
}
@@ -862,45 +862,45 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
return;
}
- SpvStorageClass storageClass = SpvStorageClassPrivate;
+ auto addressSpace = AddressSpace::ThreadLocal;
if (as<IRGroupSharedRate>(inst->getRate()))
{
- storageClass = SpvStorageClassWorkgroup;
+ addressSpace = AddressSpace::GroupShared;
}
else if (const auto varLayout = getVarLayout(inst))
{
- auto cls = getGlobalParamStorageClass(varLayout);
- if (cls != SpvStorageClassMax)
- storageClass = cls;
+ auto cls = getGlobalParamAddressSpace(varLayout);
+ if (cls != AddressSpace::Generic)
+ addressSpace = cls;
}
for (auto decor : inst->getDecorations())
{
switch (decor->getOp())
{
case kIROp_VulkanRayPayloadDecoration:
- storageClass = SpvStorageClassRayPayloadKHR;
+ addressSpace = AddressSpace::RayPayloadKHR;
break;
case kIROp_VulkanRayPayloadInDecoration:
- storageClass = SpvStorageClassIncomingRayPayloadKHR;
+ addressSpace = AddressSpace::IncomingRayPayload;
break;
case kIROp_VulkanCallablePayloadDecoration:
- storageClass = SpvStorageClassCallableDataKHR;
+ addressSpace = AddressSpace::CallableDataKHR;
break;
case kIROp_VulkanCallablePayloadInDecoration:
- storageClass = SpvStorageClassIncomingCallableDataKHR;
+ addressSpace = AddressSpace::IncomingCallableData;
break;
case kIROp_VulkanHitObjectAttributesDecoration:
- storageClass = SpvStorageClassHitObjectAttributeNV;
+ addressSpace = AddressSpace::HitObjectAttribute;
break;
case kIROp_VulkanHitAttributesDecoration:
- storageClass = SpvStorageClassHitAttributeKHR;
+ addressSpace = AddressSpace::HitAttribute;
break;
}
}
IRBuilder builder(m_sharedContext->m_irModule);
builder.setInsertBefore(inst);
auto newPtrType =
- builder.getPtrType(oldPtrType->getOp(), oldPtrType->getValueType(), storageClass);
+ builder.getPtrType(oldPtrType->getOp(), oldPtrType->getValueType(), addressSpace);
inst->setFullType(newPtrType);
addUsersToWorkList(inst);
return;
@@ -916,22 +916,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
if (!snippet)
return;
if (snippet->resultStorageClass != SpvStorageClassMax)
- {
- auto ptrType = as<IRPtrTypeBase>(inst->getDataType());
- if (!ptrType)
- return;
- IRBuilder builder(m_sharedContext->m_irModule);
- builder.setInsertBefore(inst);
- auto qualPtrType = builder.getPtrType(
- ptrType->getOp(), ptrType->getValueType(), snippet->resultStorageClass);
- List<IRInst*> args;
- for (UInt i = 0; i < inst->getArgCount(); i++)
- args.add(inst->getArg(i));
- auto newCall = builder.emitCallInst(qualPtrType, funcValue, args);
- inst->replaceUsesWith(newCall);
- inst->removeAndDeallocate();
- addUsersToWorkList(newCall);
- }
+ SLANG_UNIMPLEMENTED_X("Specifying storage classes in spirv __target_intrinsic snippets");
return;
}
@@ -1065,7 +1050,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
builder.setInsertBefore(inst);
else
setInsertAfterOrdinaryInst(&builder, x);
- y = builder.emitVar(x->getDataType(), SpvStorageClassFunction);
+ y = builder.emitVar(x->getDataType(), AddressSpace::Function);
builder.emitStore(y, x);
if (x->getParent()->getOp() != kIROp_Module)
m_mapArrayValueToVar.set(x, y);
@@ -1152,9 +1137,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}
- SpvStorageClass getStorageBufferStorageClass()
+ AddressSpace getStorageBufferAddressSpace()
{
- return m_sharedContext->isSpirv14OrLater() ? SpvStorageClassStorageBuffer : SpvStorageClassUniform;
+ return m_sharedContext->isSpirv14OrLater() ? AddressSpace::StorageBuffer : AddressSpace::Uniform;
}
void processStructuredBufferLoad(IRInst* loadInst)
@@ -1165,7 +1150,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
builder.setInsertBefore(loadInst);
IRInst* args[] = { sb, index };
auto addrInst = builder.emitIntrinsicInst(
- builder.getPtrType(kIROp_PtrType, loadInst->getFullType(), getStorageBufferStorageClass()),
+ builder.getPtrType(kIROp_PtrType, loadInst->getFullType(), getStorageBufferAddressSpace()),
kIROp_RWStructuredBufferGetElementPtr,
2,
args);
@@ -1184,7 +1169,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
builder.setInsertBefore(storeInst);
IRInst* args[] = { sb, index };
auto addrInst = builder.emitIntrinsicInst(
- builder.getPtrType(kIROp_PtrType, value->getFullType(), getStorageBufferStorageClass()),
+ builder.getPtrType(kIROp_PtrType, value->getFullType(), getStorageBufferAddressSpace()),
kIROp_RWStructuredBufferGetElementPtr,
2,
args);
@@ -1324,7 +1309,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto newPtrType = builder.getPtrType(
ptrType->getOp(),
ptrType->getValueType(),
- SpvStorageClassImage);
+ AddressSpace::Image);
subscript->setFullType(newPtrType);
// HACK: assumes the image operand is a load and replace it with
@@ -2175,7 +2160,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
if (auto ptrType = as<IRPtrTypeBase>(type))
{
if (ptrType->hasAddressSpace())
- return (AddressSpace)ptrType->getAddressSpace();
+ return ptrType->getAddressSpace();
}
return AddressSpace::Generic;
}
@@ -2240,7 +2225,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto lowered = lowerStructuredBufferType(t);
IRBuilder builder(t);
builder.setInsertBefore(t);
- t->replaceUsesWith(builder.getPtrType(kIROp_PtrType, lowered.structType, getStorageBufferStorageClass()));
+ t->replaceUsesWith(builder.getPtrType(kIROp_PtrType, lowered.structType, getStorageBufferAddressSpace()));
}
for (auto t : textureFootprintTypes)
{
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 1441b0567..e0769686c 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2881,9 +2881,9 @@ namespace Slang
operands);
}
- IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace)
+ IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace)
{
- return (IRPtrType*)getPtrType(op, valueType, getIntValue(getUInt64Type(), addressSpace));
+ return (IRPtrType*)getPtrType(op, valueType, getIntValue(getUInt64Type(), static_cast<IRIntegerValue>(addressSpace)));
}
IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, IRInst* addressSpace)
@@ -4530,7 +4530,7 @@ namespace Slang
IRGlobalVar* IRBuilder::createGlobalVar(
IRType* valueType,
- IRIntegerValue addressSpace)
+ AddressSpace addressSpace)
{
auto ptrType = getPtrType(kIROp_PtrType, valueType, addressSpace);
IRGlobalVar* globalVar = createInst<IRGlobalVar>(
@@ -4807,7 +4807,7 @@ namespace Slang
IRVar* IRBuilder::emitVar(
IRType* type,
- IRIntegerValue addressSpace)
+ AddressSpace addressSpace)
{
auto allocatedType = getPtrType(kIROp_PtrType, type, addressSpace);
auto inst = createInst<IRVar>(
diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp
index 6009ef33e..b10c88a15 100644
--- a/source/slang/slang-legalize-types.cpp
+++ b/source/slang/slang-legalize-types.cpp
@@ -209,7 +209,7 @@ bool isPointerToResourceType(IRType* type)
{
while (auto ptrType = as<IRPtrTypeBase>(type))
{
- if (ptrType->getAddressSpace() == AddressSpace(SpvStorageClassStorageBuffer) ||
+ if (ptrType->getAddressSpace() == AddressSpace::StorageBuffer ||
ptrType->getAddressSpace() == AddressSpace::UserPointer)
return true;
type = ptrType->getValueType();
diff --git a/source/slang/slang-type-system-shared.h b/source/slang/slang-type-system-shared.h
index 404c84cf4..adf6e26f8 100644
--- a/source/slang/slang-type-system-shared.h
+++ b/source/slang/slang-type-system-shared.h
@@ -62,12 +62,45 @@ FOREACH_BASE_TYPE(DEFINE_BASE_TYPE)
enum class AddressSpace : uint64_t
{
Generic = 0x7fffffff,
+ // Corresponds to SPIR-V's SpvStorageClassPrivate
ThreadLocal = 1,
- Global = 2,
- GroupShared = 3,
- Uniform = 4,
+ Global,
+ // Corresponds to SPIR-V's SpvStorageClassWorkgroup
+ GroupShared,
+ // Corresponds to SPIR-V's SpvStorageClassUniform
+ Uniform,
// specific address space for payload data in metal
- MetalObjectData = 5,
+ MetalObjectData,
+ // Corresponds to SPIR-V's SpvStorageClassInput
+ Input,
+ // Corresponds to SPIR-V's SpvStorageClassOutput
+ Output,
+ // Corresponds to SPIR-V's SpvStorageClassTaskPayloadWorkgroupEXT
+ TaskPayloadWorkgroup,
+ // Corresponds to SPIR-V's SpvStorageClassFunction
+ Function,
+ // Corresponds to SPIR-V's SpvStorageClassStorageBuffer
+ StorageBuffer,
+ // Corresponds to SPIR-V's SpvStorageClassPushConstant,
+ PushConstant,
+ // Corresponds to SPIR-V's SpvStorageClassRayPayloadKHR,
+ RayPayloadKHR,
+ // Corresponds to SPIR-V's SpvStorageClassIncomingRayPayloadKHR,
+ IncomingRayPayload,
+ // Corresponds to SPIR-V's SpvStorageClassCallableDataKHR
+ CallableDataKHR,
+ // Corresponds to SPIR-V's SpvStorageClassIncomingCallableDataKHR
+ IncomingCallableData,
+ // Corresponds to SPIR-V's SpvStorageClassHitObjectAttributeNV,
+ HitObjectAttribute,
+ // Corresponds to SPIR-V's SpvStorageClassHitAttributeKHR,
+ HitAttribute,
+ // Corresponds to SPIR-V's SpvStorageClassShaderRecordBufferKHR,
+ ShaderRecordBuffer,
+ // Corresponds to SPIR-V's SpvStorageClassUniformConstant,
+ UniformConstant,
+ // Corresponds to SPIR-V's SpvStorageClassImage
+ Image,
// Default address space for a user-defined pointer
UserPointer = 0x100000001ULL,