diff options
| author | Yong He <yonghe@outlook.com> | 2024-07-10 16:17:10 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-10 16:17:10 -0700 |
| commit | 746d47bb491e0b97e35ab373b4b78d33b9a61164 (patch) | |
| tree | 74e0936472d911d8c6c561ca4b21e800306c5f51 /source | |
| parent | 82f308ca692878bfe9844b86629c6536b4cd0f0a (diff) | |
Specialize address space during spirv legalization. (#4600)
* Specialize address space during spirv legalization.
* Fix.
* Fix building doc.
* Fix cmake.
* Update assert.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-glsl-legalize.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 69 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-address-space.cpp | 84 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-address-space.h | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 42 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 2 |
7 files changed, 162 insertions, 60 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 8c7232963..1040017da 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4905,7 +4905,7 @@ struct SPIRVEmitContext const SpvWord baseId = getID(ensureInst(base)); // We might replace resultType with a different storage class equivalent - auto resultType = as<IRPtrTypeBase>(inst->getFullType()); + auto resultType = as<IRPtrTypeBase>(inst->getDataType()); SLANG_ASSERT(resultType); if(const auto basePtrType = as<IRPtrTypeBase>(base->getDataType())) diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 0e23460ce..6ff52688b 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -3279,14 +3279,15 @@ void legalizeEntryPointForGLSL( context.sink = codeGenContext->getSink(); context.glslExtensionTracker = glslExtensionTracker; - // We require that the entry-point function has no uses, + // We require that the entry-point function has no calls, // because otherwise we'd invalidate the signature // at all existing call sites. // // TODO: the right thing to do here is to split any // function that both gets called as an entry point // and as an ordinary function. - SLANG_ASSERT(!func->firstUse); + for (auto use = func->firstUse; use; use = use->nextUse) + SLANG_ASSERT(use->getUser()->getOp() != kIROp_Call); // Require SPIRV version based on the stage. switch (stage) diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index 6b333f892..bff7363b3 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -1666,6 +1666,72 @@ namespace Slang } } + struct MetalAddressSpaceAssigner : InitialAddressSpaceAssigner + { + virtual bool tryAssignAddressSpace(IRInst* inst, AddressSpace& outAddressSpace) override + { + switch (inst->getOp()) + { + case kIROp_Var: + outAddressSpace = AddressSpace::ThreadLocal; + return true; + case kIROp_RWStructuredBufferGetElementPtr: + outAddressSpace = AddressSpace::Global; + return true; + default: + return false; + } + } + + virtual AddressSpace getAddressSpaceFromVarType(IRInst* type) override + { + if (as<IRUniformParameterGroupType>(type)) + { + return AddressSpace::Uniform; + } + if (as<IRByteAddressBufferTypeBase>(type)) + { + return AddressSpace::Global; + } + if (as<IRHLSLStructuredBufferTypeBase>(type)) + { + return AddressSpace::Global; + } + if (as<IRGLSLShaderStorageBufferType>(type)) + { + return AddressSpace::Global; + } + if (auto ptrType = as<IRPtrTypeBase>(type)) + { + if (ptrType->hasAddressSpace()) + return (AddressSpace)ptrType->getAddressSpace(); + return AddressSpace::Global; + } + return AddressSpace::Generic; + } + + virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) override + { + if (as<IRGroupSharedRate>(inst->getRate())) + return AddressSpace::GroupShared; + switch (inst->getOp()) + { + case kIROp_RWStructuredBufferGetElementPtr: + return AddressSpace::Global; + case kIROp_Var: + if (as<IRBlock>(inst->getParent())) + return AddressSpace::ThreadLocal; + break; + default: + break; + } + auto type = unwrapAttributedType(inst->getDataType()); + if (!type) + return AddressSpace::Generic; + return getAddressSpaceFromVarType(type); + } + }; + void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) { List<EntryPointInfo> entryPoints; @@ -1689,7 +1755,8 @@ namespace Slang context.legalizeEntryPointForMetal(entryPoint); context.removeSemanticLayoutsFromLegalizedStructs(); - specializeAddressSpace(module); + MetalAddressSpaceAssigner metalAddressSpaceAssigner; + specializeAddressSpace(module, &metalAddressSpaceAssigner); } } diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp index 55d61d527..1d899e240 100644 --- a/source/slang/slang-ir-specialize-address-space.cpp +++ b/source/slang/slang-ir-specialize-address-space.cpp @@ -7,66 +7,30 @@ namespace Slang { - struct AddressSpaceContext + struct AddressSpaceContext : public AddressSpaceSpecializationContext { IRModule* module; Dictionary<IRInst*, AddressSpace> mapInstToAddrSpace; + InitialAddressSpaceAssigner* addrSpaceAssigner; - AddressSpaceContext(IRModule* inModule) + AddressSpaceContext(IRModule* inModule, InitialAddressSpaceAssigner* inAddrSpaceAssigner) : module(inModule) + , addrSpaceAssigner(inAddrSpaceAssigner) { } AddressSpace getAddressSpaceFromVarType(IRInst* type) { - if (as<IRUniformParameterGroupType>(type)) - { - return AddressSpace::Uniform; - } - if (as<IRByteAddressBufferTypeBase>(type)) - { - return AddressSpace::Global; - } - if (as<IRHLSLStructuredBufferTypeBase>(type)) - { - return AddressSpace::Global; - } - if (as<IRGLSLShaderStorageBufferType>(type)) - { - return AddressSpace::Global; - } - if (auto ptrType = as<IRPtrTypeBase>(type)) - { - if (ptrType->hasAddressSpace()) - return (AddressSpace)ptrType->getAddressSpace(); - return AddressSpace::Global; - } - return AddressSpace::Generic; + return addrSpaceAssigner->getAddressSpaceFromVarType(type); } AddressSpace getLeafInstAddressSpace(IRInst* inst) { - if (as<IRGroupSharedRate>(inst->getRate())) - return AddressSpace::GroupShared; - switch (inst->getOp()) - { - case kIROp_RWStructuredBufferGetElementPtr: - return AddressSpace::Global; - case kIROp_Var: - if (as<IRBlock>(inst->getParent())) - return AddressSpace::ThreadLocal; - break; - default: - break; - } - auto type = unwrapAttributedType(inst->getDataType()); - if (!type) - return AddressSpace::Generic; - return getAddressSpaceFromVarType(type); + return addrSpaceAssigner->getLeafInstAddressSpace(inst); } - AddressSpace getAddrSpace(IRInst* inst) + AddressSpace getAddrSpace(IRInst* inst) override { auto addrSpace = mapInstToAddrSpace.tryGetValue(inst); if (addrSpace) @@ -186,20 +150,29 @@ namespace Slang continue; } + // If the inst already has a pointer type with explicit address space, then use it. + if (auto ptrType = as<IRPtrTypeBase>(inst->getDataType())) + { + if (ptrType->hasAddressSpace()) + { + mapInstToAddrSpace[inst] = (AddressSpace)ptrType->getAddressSpace(); + continue; + } + } + + // Otherwise, try to assign an address space based on the instruction type. switch (inst->getOp()) { case kIROp_Var: - { - // All local variables should be in the thread-local address space. - mapInstToAddrSpace[inst] = AddressSpace::ThreadLocal; - changed = true; - break; - } case kIROp_RWStructuredBufferGetElementPtr: { - // The address space of the result of RWStructuredBufferGetElementPtr is always global. - mapInstToAddrSpace[inst] = AddressSpace::Global; - changed = true; + // The address space of these insts should be assigned by the initial address space assigner. + AddressSpace addrSpace = AddressSpace::Generic; + if (addrSpaceAssigner->tryAssignAddressSpace(inst, addrSpace)) + { + mapInstToAddrSpace[inst] = addrSpace; + changed = true; + } break; } case kIROp_GetElementPtr: @@ -340,7 +313,10 @@ namespace Slang { auto rate = inst->getRate(); if (!rate) + { inst->setFullType(dataType); + return; + } IRBuilder builder(inst); builder.setInsertBefore(inst); @@ -405,9 +381,9 @@ namespace Slang } }; - void specializeAddressSpace(IRModule* module) + void specializeAddressSpace(IRModule* module, InitialAddressSpaceAssigner* addrSpaceAssigner) { - AddressSpaceContext context(module); + AddressSpaceContext context(module, addrSpaceAssigner); context.processModule(); } } diff --git a/source/slang/slang-ir-specialize-address-space.h b/source/slang/slang-ir-specialize-address-space.h index d74a59efa..300b6129c 100644 --- a/source/slang/slang-ir-specialize-address-space.h +++ b/source/slang/slang-ir-specialize-address-space.h @@ -4,11 +4,27 @@ namespace Slang { struct IRModule; + struct IRInst; + enum class AddressSpace; + + struct AddressSpaceSpecializationContext + { + public: + virtual AddressSpace getAddrSpace(IRInst* inst) = 0; + }; + + struct InitialAddressSpaceAssigner + { + virtual bool tryAssignAddressSpace(IRInst* inst, AddressSpace& outAddressSpace) = 0; + virtual AddressSpace getAddressSpaceFromVarType(IRInst* type) = 0; + virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) = 0; + }; /// Propagate address space information through the IR module. /// Specialize functions with reference/pointer parameters to use the correct address space /// based on the address space of the arguments. /// void specializeAddressSpace( - IRModule* module); + IRModule* module, + InitialAddressSpaceAssigner* addrSpaceAssigner); } diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index d7b980bf8..27a186ee9 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -22,6 +22,7 @@ #include "slang-ir-redundancy-removal.h" #include "slang-ir-loop-unroll.h" #include "slang-ir-lower-buffer-element-type.h" +#include "slang-ir-specialize-address-space.h" namespace Slang { @@ -1034,6 +1035,11 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } else { + // If we reach here, we have determined that all arguments passed as a pointer + // are actual memory objects, so they can be passed in as-is. + // We still need to make sure the callee is specialized to the address-space + // of the arguments, this is done in a separate specialization pass. + translatePtrResultType(inst); } } @@ -2230,6 +2236,38 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } + struct SpirvAddressSpaceAssigner : InitialAddressSpaceAssigner + { + virtual bool tryAssignAddressSpace(IRInst* inst, AddressSpace& outAddressSpace) override + { + SLANG_UNUSED(inst); + // Don't assign address space to additional insts, since we should have + // already assigned address space to them in earlier stages of legalization. + outAddressSpace = AddressSpace::Generic; + return false; + } + + virtual AddressSpace getAddressSpaceFromVarType(IRInst* type) override + { + if (auto ptrType = as<IRPtrTypeBase>(type)) + { + if (ptrType->hasAddressSpace()) + return (AddressSpace)ptrType->getAddressSpace(); + } + return AddressSpace::Generic; + } + + virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) override + { + // Don't assign address space to additional insts, since we should have + // already assigned address space to them in earlier stages of legalization. + auto type = unwrapAttributedType(inst->getDataType()); + if (!type) + return AddressSpace::Generic; + return getAddressSpaceFromVarType(type); + } + }; + void processModule() { determineSpirvVersion(); @@ -2332,6 +2370,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // the type for IRVar, and use IRPtrType to dedicate pointers in user code, so we can // safely lower the pointer load stores early together with other buffer types. lowerBufferElementTypeToStorageType(m_sharedContext->m_targetProgram, m_module, true); + + // Specalize address space for all pointers. + SpirvAddressSpaceAssigner addressSpaceAssigner; + specializeAddressSpace(m_module, &addressSpaceAssigner); } void updateFunctionTypes() diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 7c04729b0..cf9dfa84c 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -41,7 +41,7 @@ struct IRStructKey; enum class AddressSpace { - Generic = 0, + Generic = 0x7fffffff, ThreadLocal = 1, Global = 2, GroupShared = 3, |
