diff options
| author | Yong He <yonghe@outlook.com> | 2024-05-08 23:06:46 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-08 23:06:46 -0700 |
| commit | bf088c3f12cb47d204fdd3df1bb8a2415d46ba7b (patch) | |
| tree | 82145968864a816ceba4c46619c3841b9a0befd4 /source | |
| parent | 526430a0e7053b04eeb9b0c6514065a850042aaf (diff) | |
Metal: propagate and specialize address space. (#4137)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-address-space.cpp | 413 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-address-space.h | 14 |
7 files changed, 447 insertions, 2 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 626c372e9..2119d8fee 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1748,6 +1748,15 @@ void CLikeSourceEmitter::emitArgs(IRInst* inst) m_writer->emit(")"); } +void CLikeSourceEmitter::emitRateQualifiers(IRInst* value) +{ + const auto rate = value->getRate(); + if (rate) + { + emitRateQualifiersAndAddressSpaceImpl(rate, -1); + } +} + void CLikeSourceEmitter::emitRateQualifiersAndAddressSpace(IRInst* value) { const auto rate = value->getRate(); @@ -1770,7 +1779,7 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) emitTempModifiers(inst); - emitRateQualifiersAndAddressSpace(inst); + emitRateQualifiers(inst); if(as<IRModuleInst>(inst->getParent())) { diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index ba17caace..450770238 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -306,7 +306,7 @@ public: void emitArgs(IRInst* inst); - + void emitRateQualifiers(IRInst* value); void emitRateQualifiersAndAddressSpace(IRInst* value); void emitInstResultDecl(IRInst* inst); diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index 7da48cac1..4d8a207c3 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -420,6 +420,11 @@ void MetalSourceEmitter::emitSimpleValueImpl(IRInst* inst) Super::emitSimpleValueImpl(inst); } +void MetalSourceEmitter::emitParamTypeImpl(IRType* type, String const& name) +{ + emitType(type, name); +} + void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) { switch (type->getOp()) diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h index fc1390143..a60d28b96 100644 --- a/source/slang/slang-emit-metal.h +++ b/source/slang/slang-emit-metal.h @@ -37,6 +37,7 @@ protected: virtual void emitMeshShaderModifiersImpl(IRInst* varInst) SLANG_OVERRIDE; virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; + virtual void emitParamTypeImpl(IRType* type, String const& name) SLANG_OVERRIDE; virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE; virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE; virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE; diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index 822a1e2f1..f81608316 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -3,6 +3,7 @@ #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir-clone.h" +#include "slang-ir-specialize-address-space.h" namespace Slang { @@ -333,5 +334,7 @@ namespace Slang for (auto entryPoint : entryPoints) legalizeEntryPointForMetal(entryPoint, sink); + + specializeAddressSpace(module); } } diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp new file mode 100644 index 000000000..5a1874e08 --- /dev/null +++ b/source/slang/slang-ir-specialize-address-space.cpp @@ -0,0 +1,413 @@ +#include "slang-ir-specialize-address-space.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir-clone.h" + +namespace Slang +{ + struct AddressSpaceContext + { + IRModule* module; + + Dictionary<IRInst*, AddressSpace> mapInstToAddrSpace; + + AddressSpaceContext(IRModule* inModule) + : module(inModule) + { + } + + AddressSpace getLeafInstAddressSpace(IRInst* inst) + { + if (as<IRGroupSharedRate>(inst->getRate())) + return AddressSpace::GroupShared; + switch (inst->getOp()) + { + case kIROp_RWStructuredBufferGetElementPtr: + return AddressSpace::Global; + case kIROp_Var: + if (as<IRBlock>(inst->getParent())) + return AddressSpace::ThreadLocal; + break; + default: + break; + } + auto type = unwrapAttributedType(inst->getDataType()); + if (!type) + return AddressSpace::Generic; + if (as<IRUniformParameterGroupType>(type)) + { + return AddressSpace::Uniform; + } + if (as<IRByteAddressBufferTypeBase>(type)) + { + return AddressSpace::Global; + } + if (as<IRHLSLStructuredBufferTypeBase>(type)) + { + return AddressSpace::Global; + } + if (as<IRGLSLShaderStorageBufferType>(type)) + { + return AddressSpace::Global; + } + if (auto ptrType = as<IRPtrTypeBase>(type)) + { + if (ptrType->hasAddressSpace()) + return (AddressSpace)ptrType->getAddressSpace(); + return AddressSpace::Global; + } + return AddressSpace::Generic; + } + + AddressSpace getAddrSpace(IRInst* inst) + { + auto addrSpace = mapInstToAddrSpace.tryGetValue(inst); + if (addrSpace) + return *addrSpace; + return AddressSpace::Generic; + } + + List<IRFunc*> workList; + + struct FuncSpecializationKey + { + private: + IRFunc* func; + List<AddressSpace> argAddrSpaces; + HashCode hashCode; + public: + IRFunc* getFunc() const { return func; } + ArrayView<AddressSpace> getArgAddrSpaces() const { return argAddrSpaces.getArrayView(); } + + FuncSpecializationKey() = default; + + FuncSpecializationKey(IRFunc* func, List<AddressSpace> argAddrSpaces) + : func(func) + , argAddrSpaces(argAddrSpaces) + { + Hasher hasher; + hasher.addHash(Slang::getHashCode(func)); + for (auto addrSpace : argAddrSpaces) + { + hasher.addHash((HashCode)addrSpace); + } + hashCode = hasher.getResult(); + } + + bool operator==(const FuncSpecializationKey& key) const + { + if (func != key.func) + return false; + if (argAddrSpaces.getCount() != key.argAddrSpaces.getCount()) + return false; + for (Index i = 0; i < argAddrSpaces.getCount(); i++) + { + if (argAddrSpaces[i] != key.argAddrSpaces[i]) + return false; + } + return true; + } + + HashCode getHashCode() const + { + return hashCode; + } + }; + + Dictionary<FuncSpecializationKey, IRFunc*> functionSpecializations; + + IRFunc* specializeFunc(const FuncSpecializationKey& key) + { + auto func = key.getFunc(); + IRCloneEnv cloneEnv; + IRBuilder builder(module); + + // First, clone the function body. + builder.setInsertBefore(func); + auto specializedFunc = as<IRFunc>(cloneInst(&cloneEnv, &builder, func)); + + // Update the parameter types with new address spaces in the specialized function. + Index paramIndex = 0; + for (auto param : specializedFunc->getParams()) + { + auto paramType = param->getFullType(); + auto ptrType = as<IRPtrTypeBase>(paramType); + if (ptrType) + { + auto paramAddrSpace = key.getArgAddrSpaces()[paramIndex]; + auto newParamType = builder.getPtrType(ptrType->getOp(), ptrType->getValueType(), paramAddrSpace); + param->setFullType(newParamType); + mapInstToAddrSpace[param] = paramAddrSpace; + } + paramIndex++; + } + + // Update the function type. + fixUpFuncType(specializedFunc); + + functionSpecializations[key] = specializedFunc; + return specializedFunc; + } + + AddressSpace getFuncResultAddrSpace(IRFunc* callee) + { + auto funcType = as<IRFuncType>(callee->getDataType()); + auto ptrResultType = as<IRPtrTypeBase>(funcType->getResultType()); + if (!ptrResultType) + return AddressSpace::Generic; + AddressSpace resultAddrSpace = AddressSpace::Generic; + if (ptrResultType->hasAddressSpace()) + resultAddrSpace = (AddressSpace)ptrResultType->getAddressSpace(); + return resultAddrSpace; + } + + // Return true if the address space of the function return type is changed. + bool processFunction(IRFunc* func) + { + bool retValAddrSpaceChanged = false; + Dictionary<IRInst*, AddressSpace> mapVarValueToAddrSpace; + bool changed = true; + while (changed) + { + changed = false; + for (auto block : func->getBlocks()) + { + bool isFirstBlock = block == func->getFirstBlock(); + + for (auto inst : block->getChildren()) + { + // If we have already assigned an address space to this instruction, then skip it. + if (mapInstToAddrSpace.containsKey(inst)) + { + // TODO: if the inst is a phi node, we need to check if the address space of the phi arguments + // is consistent. If not, then we need to report an error. + // For now, we just skip the checks. + continue; + } + + switch (inst->getOp()) + { + case kIROp_Var: + { + // All local variables should be in the thread-local address space. + mapInstToAddrSpace[inst] = AddressSpace::ThreadLocal; + changed = true; + break; + } + case kIROp_RWStructuredBufferGetElementPtr: + { + // The address space of the result of RWStructuredBufferGetElementPtr is always global. + mapInstToAddrSpace[inst] = AddressSpace::Global; + changed = true; + break; + } + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + if (!mapInstToAddrSpace.containsKey(inst)) + { + auto addrSpace = getAddrSpace(inst->getOperand(0)); + if (addrSpace != AddressSpace::Generic) + { + mapInstToAddrSpace[inst] = addrSpace; + changed = true; + } + } + break; + case kIROp_Store: + { + auto addrSpace = getAddrSpace(inst->getOperand(1)); + if (addrSpace != AddressSpace::Generic) + { + mapVarValueToAddrSpace[inst->getOperand(0)] = addrSpace; + changed = true; + } + } + break; + case kIROp_Load: + { + if (auto addrSpace = mapVarValueToAddrSpace.tryGetValue(inst->getOperand(0))) + { + mapInstToAddrSpace[inst] = *addrSpace; + changed = true; + } + } + break; + case kIROp_Param: + if (!isFirstBlock) + { + auto phiArgs = getPhiArgs(inst); + AddressSpace addrSpace = AddressSpace::Generic; + for (auto arg : phiArgs) + { + auto argAddrSpace = getAddrSpace(arg); + if (argAddrSpace != AddressSpace::Generic) + { + if (addrSpace != AddressSpace::Generic && addrSpace != argAddrSpace) + { + // TODO: this is an error in user code, because the address spaces of the + // phi arguments don't match. + } + addrSpace = argAddrSpace; + } + } + if (addrSpace != AddressSpace::Generic) + { + mapInstToAddrSpace[inst] = addrSpace; + changed = true; + } + break; + } + break; + case kIROp_Call: + { + auto callInst = as<IRCall>(inst); + auto callee = as<IRFunc>(inst->getOperand(0)); + if (callee) + { + List<AddressSpace> argAddrSpaces; + bool fullySpecialized = true; + for (UInt i = 0; i < callInst->getArgCount(); i++) + { + auto arg = callInst->getArg(i); + auto argAddrSpace = getAddrSpace(arg); + argAddrSpaces.add(getAddrSpace(arg)); + if (argAddrSpace == AddressSpace::Generic && + as<IRPtrTypeBase>(arg->getDataType())) + { + fullySpecialized = false; + break; + } + } + if (!fullySpecialized) + break; + + FuncSpecializationKey key(callee, argAddrSpaces); + IRFunc* specializedCallee = nullptr; + if (IRFunc** specializedFunc = functionSpecializations.tryGetValue(key)) + { + specializedCallee = *specializedFunc; + } + else + { + specializedCallee = specializeFunc(key); + workList.add(specializedCallee); + } + IRBuilder builder(callInst); + builder.setInsertBefore(callInst); + if (specializedCallee != callInst->getCallee()) + { + callInst = as<IRCall>(builder.replaceOperand(callInst->getOperands(), specializedCallee)); + } + auto callResultAddrSpace = getFuncResultAddrSpace(specializedCallee); + if (callResultAddrSpace != AddressSpace::Generic) + { + mapInstToAddrSpace[callInst] = callResultAddrSpace; + changed = true; + } + } + } + break; + case kIROp_Return: + { + auto retVal = inst->getOperand(0); + auto addrSpace = getAddrSpace(retVal); + if (addrSpace != AddressSpace::Generic) + { + auto funcType = as<IRFuncType>(func->getDataType()); + auto ptrResultType = as<IRPtrTypeBase>(funcType->getResultType()); + SLANG_ASSERT(ptrResultType); + AddressSpace resultAddrSpace = getFuncResultAddrSpace(func); + if (resultAddrSpace != addrSpace) + { + IRBuilder builder(func); + auto newResultType = builder.getPtrType(ptrResultType->getOp(), ptrResultType->getValueType(), addrSpace); + fixUpFuncType(func, newResultType); + retValAddrSpaceChanged = true; + } + } + } + break; + } + } + } + } + return retValAddrSpaceChanged; + } + + static void setDataType(IRInst* inst, IRType* dataType) + { + auto rate = inst->getRate(); + if (!rate) + inst->setFullType(dataType); + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newType = builder.getRateQualifiedType(rate, dataType); + inst->setFullType(newType); + } + + void applyAddressSpaceToInstType() + { + for (auto [inst, addrSpace] : mapInstToAddrSpace) + { + auto ptrType = as<IRPtrTypeBase>(inst->getDataType()); + if (ptrType) + { + IRBuilder builder(inst); + auto newType = builder.getPtrType(ptrType->getOp(), ptrType->getValueType(), addrSpace); + setDataType(inst, newType); + } + } + } + + void processModule() + { + for (auto globalInst : module->getGlobalInsts()) + { + auto addrSpace = getLeafInstAddressSpace(globalInst); + if (addrSpace != AddressSpace::Generic) + { + mapInstToAddrSpace[globalInst] = addrSpace; + } + if (auto func = as<IRFunc>(globalInst)) + { + if (func->findDecoration<IREntryPointDecoration>()) + workList.add(func); + } + } + + HashSet<IRFunc*> newWorkList; + while (workList.getCount()) + { + for (Index i = 0; i < workList.getCount(); i++) + { + auto func = workList[i]; + bool resultTypeChanged = processFunction(func); + if (resultTypeChanged) + { + for (auto use = func->firstUse; use; use = use->nextUse) + { + if (auto callInst = as<IRCall>(use->getUser())) + { + newWorkList.add(getParentFunc(callInst)); + } + } + } + } + workList.clear(); + for (auto f : newWorkList) + workList.add(f); + } + + applyAddressSpaceToInstType(); + } + }; + + void specializeAddressSpace(IRModule* module) + { + AddressSpaceContext context(module); + context.processModule(); + } +} diff --git a/source/slang/slang-ir-specialize-address-space.h b/source/slang/slang-ir-specialize-address-space.h new file mode 100644 index 000000000..d74a59efa --- /dev/null +++ b/source/slang/slang-ir-specialize-address-space.h @@ -0,0 +1,14 @@ +// slang-ir-specialize-address-space.h +#pragma once + +namespace Slang +{ + struct IRModule; + + /// Propagate address space information through the IR module. + /// Specialize functions with reference/pointer parameters to use the correct address space + /// based on the address space of the arguments. + /// + void specializeAddressSpace( + IRModule* module); +} |
