summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-05-06 19:21:03 -0700
committerGitHub <noreply@github.com>2024-05-06 19:21:03 -0700
commit1b3a428bfa24350d9d69b092747b4ad142b7c4b4 (patch)
tree5664611c59824c4c208660c7351bdf68e37bec26 /source
parent618428a87b8295347288262ea13eff63cc62aa56 (diff)
Support groupshared variables for Metal. (#4116)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-metal.cpp42
-rw-r--r--source/slang/slang-emit-metal.h2
-rw-r--r--source/slang/slang-ir-explicit-global-context.cpp90
-rw-r--r--source/slang/slang-ir-insts.h3
4 files changed, 120 insertions, 17 deletions
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp
index 7580ed74d..2c327b613 100644
--- a/source/slang/slang-emit-metal.cpp
+++ b/source/slang/slang-emit-metal.cpp
@@ -494,7 +494,7 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
case kIROp_ParameterBlockType:
case kIROp_ConstantBufferType:
{
- emitType((IRType*)type->getOperand(0));
+ emitSimpleTypeImpl((IRType*)type->getOperand(0));
m_writer->emit(" constant*");
return;
}
@@ -607,11 +607,17 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
}
}
-void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace)
+void MetalSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator)
{
- if (as<IRGroupSharedRate>(rate))
+ switch (type->getOp())
{
- m_writer->emit("threadgroup ");
+ case kIROp_ArrayType:
+ emitSimpleType(type);
+ emitDeclarator(declarator);
+ break;
+ default:
+ Super::_emitType(type, declarator);
+ break;
}
}
@@ -796,6 +802,34 @@ void MetalSourceEmitter::emitPackOffsetModifier(IRInst* varInst, IRType* valueTy
// We emit packoffset as a semantic in `emitSemantic`, so nothing to do here.
}
+void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace)
+{
+ if (as<IRGroupSharedRate>(rate))
+ {
+ m_writer->emit("threadgroup ");
+ return;
+ }
+
+ switch ((AddressSpace)addressSpace)
+ {
+ case AddressSpace::GroupShared:
+ m_writer->emit("threadgroup ");
+ break;
+ case AddressSpace::Uniform:
+ m_writer->emit("constant ");
+ break;
+ case AddressSpace::Global:
+ m_writer->emit("device ");
+ break;
+ case AddressSpace::ThreadLocal:
+ m_writer->emit("thread ");
+ break;
+ default:
+ break;
+ }
+}
+
+
void MetalSourceEmitter::emitMeshShaderModifiersImpl(IRInst* varInst)
{
SLANG_UNUSED(varInst);
diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h
index d925365da..fc1390143 100644
--- a/source/slang/slang-emit-metal.h
+++ b/source/slang/slang-emit-metal.h
@@ -57,6 +57,8 @@ protected:
void emitFuncParamLayoutImpl(IRInst* param);
+ virtual void _emitType(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE;
+
void _emitHLSLParameterGroup(IRGlobalParam* varDecl, IRUniformParameterGroupType* type);
void _emitHLSLTextureType(IRTextureTypeBase* texType);
diff --git a/source/slang/slang-ir-explicit-global-context.cpp b/source/slang/slang-ir-explicit-global-context.cpp
index f63ceb71e..9bbaf3875 100644
--- a/source/slang/slang-ir-explicit-global-context.cpp
+++ b/source/slang/slang-ir-explicit-global-context.cpp
@@ -24,6 +24,11 @@ struct IntroduceExplicitGlobalContextPass
List<IRGlobalVar*> m_globalVars;
List<IRFunc*> m_entryPoints;
+ enum class GlobalObjectKind
+ {
+ GlobalParam, GlobalVar
+ };
+
void processModule()
{
IRBuilder builder(m_module);
@@ -181,14 +186,14 @@ struct IntroduceExplicitGlobalContextPass
// parameters, we create a field that exactly matches its type.
//
- createContextStructField(globalParam, globalParam->getFullType());
+ createContextStructField(globalParam, GlobalObjectKind::GlobalParam, globalParam->getFullType());
}
for( auto globalVar : m_globalVars )
{
// A `IRGlobalVar` represents a pointer to where the variable is stored,
// so we need to create a field of the pointed-to type to represent it.
//
- createContextStructField(globalVar, globalVar->getDataType()->getValueType());
+ createContextStructField(globalVar, GlobalObjectKind::GlobalVar, getGlobalVarPtrType(globalVar));
}
// Once all the fields have been created, we can process the entry points.
@@ -229,10 +234,18 @@ struct IntroduceExplicitGlobalContextPass
// variable parameter, and to record the context pointer
// value to use for a function.
//
- Dictionary<IRInst*, IRStructKey*> m_mapInstToContextFieldKey;
+ struct ContextFieldInfo
+ {
+ IRStructKey* key = nullptr;
+
+ // Is this field a pointer to the actual value?
+ // For groupshared variables, this will be true.
+ bool needDereference = false;
+ };
+ Dictionary<IRInst*, ContextFieldInfo> m_mapInstToContextFieldInfo;
Dictionary<IRFunc*, IRInst*> m_mapFuncToContextPtr;
- void createContextStructField(IRInst* originalInst, IRType* type)
+ void createContextStructField(IRInst* originalInst, GlobalObjectKind kind, IRType* type)
{
// Creating a field in the context struct to represent
// `originalInst` is straightforward.
@@ -240,11 +253,27 @@ struct IntroduceExplicitGlobalContextPass
IRBuilder builder(m_module);
builder.setInsertBefore(m_contextStructType);
+ IRType* fieldDataType = type;
+ bool needDereference = false;
+ if (kind == GlobalObjectKind::GlobalVar)
+ {
+ auto ptrType = as<IRPtrTypeBase>(type);
+ if (ptrType->getAddressSpace() == (IRIntegerValue)AddressSpace::GroupShared)
+ {
+ fieldDataType = ptrType;
+ needDereference = true;
+ }
+ else
+ {
+ fieldDataType = as<IRPtrTypeBase>(type)->getValueType();
+ }
+ }
+
// We create a "key" for the new field, and then a field
// of the appropraite type.
//
auto key = builder.createStructKey();
- builder.createStructField(m_contextStructType, key, type);
+ builder.createStructField(m_contextStructType, key, fieldDataType);
// Clone all original decorations to the new struct key.
IRCloneEnv cloneEnv;
@@ -254,7 +283,7 @@ struct IntroduceExplicitGlobalContextPass
// for the instruction, so that we can use the key
// to access the field later.
//
- m_mapInstToContextFieldKey.add(originalInst, key);
+ m_mapInstToContextFieldInfo.add(originalInst, ContextFieldInfo{ key, needDereference });
}
void createContextForEntryPoint(IRFunc* entryPointFunc)
@@ -321,14 +350,14 @@ struct IntroduceExplicitGlobalContextPass
//
for (auto entryPointParam : entryPointParams)
{
- auto fieldKey = m_mapInstToContextFieldKey[entryPointParam.globalParam];
+ auto fieldInfo = m_mapInstToContextFieldInfo[entryPointParam.globalParam];
auto fieldType = entryPointParam.globalParam->getFullType();
auto fieldPtrType = builder.getPtrType(fieldType);
// We compute the addrress of the field and store the
// value of the parameter into it.
//
- auto fieldPtr = builder.emitFieldAddress(fieldPtrType, contextVarPtr, fieldKey);
+ auto fieldPtr = builder.emitFieldAddress(fieldPtrType, contextVarPtr, fieldInfo.key);
builder.emitStore(fieldPtr, entryPointParam.entryPointParam);
}
@@ -341,6 +370,27 @@ struct IntroduceExplicitGlobalContextPass
// run the pass in `slang-ir-explicit-global-init` first,
// in order to move all initialization of globals into the
// entry point functions.
+ //
+ // To support groupshared variables on Metal,we need to allocate the
+ // memory by defining a local variable in the entry point, and pass
+ // the address of that variable to the context.
+ //
+ for (auto globalVar : m_globalVars)
+ {
+ auto fieldInfo = m_mapInstToContextFieldInfo[globalVar];
+ if (fieldInfo.needDereference)
+ {
+ auto var = builder.emitVar(globalVar->getDataType()->getValueType(), (IRIntegerValue)AddressSpace::GroupShared);
+ if (auto nameDecor = globalVar->findDecoration<IRNameHintDecoration>())
+ {
+ builder.addNameHintDecoration(var, nameDecor->getName());
+ }
+ auto ptrPtrType = builder.getPtrType(getGlobalVarPtrType(globalVar), AddressSpace::ThreadLocal);
+ auto fieldPtr = builder.emitFieldAddress(ptrPtrType, contextVarPtr, fieldInfo.key);
+ builder.emitStore(fieldPtr, var);
+ }
+ }
+
}
void replaceUsesOfGlobalParam(IRGlobalParam* globalParam)
@@ -350,7 +400,7 @@ struct IntroduceExplicitGlobalContextPass
// A global shader parameter was mapped to a field
// in the context structure, so we find the appropriate key.
//
- auto key = m_mapInstToContextFieldKey[globalParam];
+ auto fieldInfo = m_mapInstToContextFieldInfo[globalParam];
auto valType = globalParam->getFullType();
auto ptrType = builder.getPtrType(valType);
@@ -375,12 +425,22 @@ struct IntroduceExplicitGlobalContextPass
// taking the address of the corresponding field
// in the context struct and loading from it.
//
- auto ptr = builder.emitFieldAddress(ptrType, contextParam, key);
+ auto ptr = builder.emitFieldAddress(ptrType, contextParam, fieldInfo.key);
auto val = builder.emitLoad(valType, ptr);
use->set(val);
}
}
+ IRType* getGlobalVarPtrType(IRGlobalVar* globalVar)
+ {
+ IRBuilder builder(globalVar);
+ if (as<IRGroupSharedRate>(globalVar->getRate()))
+ {
+ return builder.getPtrType(globalVar->getDataType()->getValueType(), AddressSpace::GroupShared);
+ }
+ return builder.getPtrType(globalVar->getDataType()->getValueType(), AddressSpace::ThreadLocal);
+ }
+
void replaceUsesOfGlobalVar(IRGlobalVar* globalVar)
{
IRBuilder builder(m_module);
@@ -388,9 +448,11 @@ struct IntroduceExplicitGlobalContextPass
// A global variable was mapped to a field
// in the context structure, so we find the appropriate key.
//
- auto key = m_mapInstToContextFieldKey[globalVar];
+ auto fieldInfo = m_mapInstToContextFieldInfo[globalVar];
- auto ptrType = globalVar->getDataType();
+ auto ptrType = getGlobalVarPtrType(globalVar);
+ if (fieldInfo.needDereference)
+ ptrType = builder.getPtrType(kIROp_PtrType, ptrType, AddressSpace::ThreadLocal);
// We then iterate over the uses of the variable,
// being careful to defend against the use/def information
@@ -412,7 +474,9 @@ struct IntroduceExplicitGlobalContextPass
// taking the address of the corresponding field
// in the context struct.
//
- auto ptr = builder.emitFieldAddress(ptrType, contextParam, key);
+ auto ptr = builder.emitFieldAddress(ptrType, contextParam, fieldInfo.key);
+ if (fieldInfo.needDereference)
+ ptr = builder.emitLoad(ptr);
use->set(ptr);
}
}
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 9329e3806..5c4f01ae7 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3435,6 +3435,9 @@ public:
IRConstRefType* getConstRefType(IRType* valueType);
IRPtrTypeBase* getPtrType(IROp op, IRType* valueType);
IRPtrType* getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace);
+ IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace) { return getPtrType(op, valueType, (IRIntegerValue)addressSpace); }
+ IRPtrType* getPtrType(IRType* valueType, AddressSpace addressSpace) { return getPtrType(kIROp_PtrType, valueType, (IRIntegerValue)addressSpace); }
+
IRTextureTypeBase* getTextureType(
IRType* elementType,
IRInst* shape,