diff options
| author | Yong He <yonghe@outlook.com> | 2024-05-06 19:21:03 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-06 19:21:03 -0700 |
| commit | 1b3a428bfa24350d9d69b092747b4ad142b7c4b4 (patch) | |
| tree | 5664611c59824c4c208660c7351bdf68e37bec26 /source/slang/slang-ir-explicit-global-context.cpp | |
| parent | 618428a87b8295347288262ea13eff63cc62aa56 (diff) | |
Support groupshared variables for Metal. (#4116)
Diffstat (limited to 'source/slang/slang-ir-explicit-global-context.cpp')
| -rw-r--r-- | source/slang/slang-ir-explicit-global-context.cpp | 90 |
1 files changed, 77 insertions, 13 deletions
diff --git a/source/slang/slang-ir-explicit-global-context.cpp b/source/slang/slang-ir-explicit-global-context.cpp index f63ceb71e..9bbaf3875 100644 --- a/source/slang/slang-ir-explicit-global-context.cpp +++ b/source/slang/slang-ir-explicit-global-context.cpp @@ -24,6 +24,11 @@ struct IntroduceExplicitGlobalContextPass List<IRGlobalVar*> m_globalVars; List<IRFunc*> m_entryPoints; + enum class GlobalObjectKind + { + GlobalParam, GlobalVar + }; + void processModule() { IRBuilder builder(m_module); @@ -181,14 +186,14 @@ struct IntroduceExplicitGlobalContextPass // parameters, we create a field that exactly matches its type. // - createContextStructField(globalParam, globalParam->getFullType()); + createContextStructField(globalParam, GlobalObjectKind::GlobalParam, globalParam->getFullType()); } for( auto globalVar : m_globalVars ) { // A `IRGlobalVar` represents a pointer to where the variable is stored, // so we need to create a field of the pointed-to type to represent it. // - createContextStructField(globalVar, globalVar->getDataType()->getValueType()); + createContextStructField(globalVar, GlobalObjectKind::GlobalVar, getGlobalVarPtrType(globalVar)); } // Once all the fields have been created, we can process the entry points. @@ -229,10 +234,18 @@ struct IntroduceExplicitGlobalContextPass // variable parameter, and to record the context pointer // value to use for a function. // - Dictionary<IRInst*, IRStructKey*> m_mapInstToContextFieldKey; + struct ContextFieldInfo + { + IRStructKey* key = nullptr; + + // Is this field a pointer to the actual value? + // For groupshared variables, this will be true. + bool needDereference = false; + }; + Dictionary<IRInst*, ContextFieldInfo> m_mapInstToContextFieldInfo; Dictionary<IRFunc*, IRInst*> m_mapFuncToContextPtr; - void createContextStructField(IRInst* originalInst, IRType* type) + void createContextStructField(IRInst* originalInst, GlobalObjectKind kind, IRType* type) { // Creating a field in the context struct to represent // `originalInst` is straightforward. @@ -240,11 +253,27 @@ struct IntroduceExplicitGlobalContextPass IRBuilder builder(m_module); builder.setInsertBefore(m_contextStructType); + IRType* fieldDataType = type; + bool needDereference = false; + if (kind == GlobalObjectKind::GlobalVar) + { + auto ptrType = as<IRPtrTypeBase>(type); + if (ptrType->getAddressSpace() == (IRIntegerValue)AddressSpace::GroupShared) + { + fieldDataType = ptrType; + needDereference = true; + } + else + { + fieldDataType = as<IRPtrTypeBase>(type)->getValueType(); + } + } + // We create a "key" for the new field, and then a field // of the appropraite type. // auto key = builder.createStructKey(); - builder.createStructField(m_contextStructType, key, type); + builder.createStructField(m_contextStructType, key, fieldDataType); // Clone all original decorations to the new struct key. IRCloneEnv cloneEnv; @@ -254,7 +283,7 @@ struct IntroduceExplicitGlobalContextPass // for the instruction, so that we can use the key // to access the field later. // - m_mapInstToContextFieldKey.add(originalInst, key); + m_mapInstToContextFieldInfo.add(originalInst, ContextFieldInfo{ key, needDereference }); } void createContextForEntryPoint(IRFunc* entryPointFunc) @@ -321,14 +350,14 @@ struct IntroduceExplicitGlobalContextPass // for (auto entryPointParam : entryPointParams) { - auto fieldKey = m_mapInstToContextFieldKey[entryPointParam.globalParam]; + auto fieldInfo = m_mapInstToContextFieldInfo[entryPointParam.globalParam]; auto fieldType = entryPointParam.globalParam->getFullType(); auto fieldPtrType = builder.getPtrType(fieldType); // We compute the addrress of the field and store the // value of the parameter into it. // - auto fieldPtr = builder.emitFieldAddress(fieldPtrType, contextVarPtr, fieldKey); + auto fieldPtr = builder.emitFieldAddress(fieldPtrType, contextVarPtr, fieldInfo.key); builder.emitStore(fieldPtr, entryPointParam.entryPointParam); } @@ -341,6 +370,27 @@ struct IntroduceExplicitGlobalContextPass // run the pass in `slang-ir-explicit-global-init` first, // in order to move all initialization of globals into the // entry point functions. + // + // To support groupshared variables on Metal,we need to allocate the + // memory by defining a local variable in the entry point, and pass + // the address of that variable to the context. + // + for (auto globalVar : m_globalVars) + { + auto fieldInfo = m_mapInstToContextFieldInfo[globalVar]; + if (fieldInfo.needDereference) + { + auto var = builder.emitVar(globalVar->getDataType()->getValueType(), (IRIntegerValue)AddressSpace::GroupShared); + if (auto nameDecor = globalVar->findDecoration<IRNameHintDecoration>()) + { + builder.addNameHintDecoration(var, nameDecor->getName()); + } + auto ptrPtrType = builder.getPtrType(getGlobalVarPtrType(globalVar), AddressSpace::ThreadLocal); + auto fieldPtr = builder.emitFieldAddress(ptrPtrType, contextVarPtr, fieldInfo.key); + builder.emitStore(fieldPtr, var); + } + } + } void replaceUsesOfGlobalParam(IRGlobalParam* globalParam) @@ -350,7 +400,7 @@ struct IntroduceExplicitGlobalContextPass // A global shader parameter was mapped to a field // in the context structure, so we find the appropriate key. // - auto key = m_mapInstToContextFieldKey[globalParam]; + auto fieldInfo = m_mapInstToContextFieldInfo[globalParam]; auto valType = globalParam->getFullType(); auto ptrType = builder.getPtrType(valType); @@ -375,12 +425,22 @@ struct IntroduceExplicitGlobalContextPass // taking the address of the corresponding field // in the context struct and loading from it. // - auto ptr = builder.emitFieldAddress(ptrType, contextParam, key); + auto ptr = builder.emitFieldAddress(ptrType, contextParam, fieldInfo.key); auto val = builder.emitLoad(valType, ptr); use->set(val); } } + IRType* getGlobalVarPtrType(IRGlobalVar* globalVar) + { + IRBuilder builder(globalVar); + if (as<IRGroupSharedRate>(globalVar->getRate())) + { + return builder.getPtrType(globalVar->getDataType()->getValueType(), AddressSpace::GroupShared); + } + return builder.getPtrType(globalVar->getDataType()->getValueType(), AddressSpace::ThreadLocal); + } + void replaceUsesOfGlobalVar(IRGlobalVar* globalVar) { IRBuilder builder(m_module); @@ -388,9 +448,11 @@ struct IntroduceExplicitGlobalContextPass // A global variable was mapped to a field // in the context structure, so we find the appropriate key. // - auto key = m_mapInstToContextFieldKey[globalVar]; + auto fieldInfo = m_mapInstToContextFieldInfo[globalVar]; - auto ptrType = globalVar->getDataType(); + auto ptrType = getGlobalVarPtrType(globalVar); + if (fieldInfo.needDereference) + ptrType = builder.getPtrType(kIROp_PtrType, ptrType, AddressSpace::ThreadLocal); // We then iterate over the uses of the variable, // being careful to defend against the use/def information @@ -412,7 +474,9 @@ struct IntroduceExplicitGlobalContextPass // taking the address of the corresponding field // in the context struct. // - auto ptr = builder.emitFieldAddress(ptrType, contextParam, key); + auto ptr = builder.emitFieldAddress(ptrType, contextParam, fieldInfo.key); + if (fieldInfo.needDereference) + ptr = builder.emitLoad(ptr); use->set(ptr); } } |
