diff options
| author | Yong He <yonghe@outlook.com> | 2024-07-10 16:17:10 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-10 16:17:10 -0700 |
| commit | 746d47bb491e0b97e35ab373b4b78d33b9a61164 (patch) | |
| tree | 74e0936472d911d8c6c561ca4b21e800306c5f51 | |
| parent | 82f308ca692878bfe9844b86629c6536b4cd0f0a (diff) | |
Specialize address space during spirv legalization. (#4600)
* Specialize address space during spirv legalization.
* Fix.
* Fix building doc.
* Fix cmake.
* Update assert.
| -rw-r--r-- | CMakeLists.txt | 1 | ||||
| -rw-r--r-- | cmake/CompilerFlags.cmake | 7 | ||||
| -rw-r--r-- | docs/building.md | 1 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-glsl-legalize.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 69 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-address-space.cpp | 84 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-address-space.h | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 42 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 2 | ||||
| -rw-r--r-- | tests/spirv/address-space-specialize.slang | 33 |
11 files changed, 204 insertions, 60 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 7d9c8c0be..70e204e7e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,6 +85,7 @@ option(SLANG_EMBED_STDLIB_SOURCE "Embed stdlib source in the binary" ON) option(SLANG_EMBED_STDLIB "Build slang with an embedded version of the stdlib") option(SLANG_ENABLE_FULL_IR_VALIDATION "Enable full IR validation (SLOW!)") +option(SLANG_ENABLE_IR_BREAK_ALLOC, "Enable _debugUID on IR allocation") option(SLANG_ENABLE_ASAN "Enable ASAN (address sanitizer)") option(SLANG_ENABLE_PREBUILT_BINARIES "Enable using prebuilt binaries" ON) diff --git a/cmake/CompilerFlags.cmake b/cmake/CompilerFlags.cmake index 2a9988d4e..cd9021cd0 100644 --- a/cmake/CompilerFlags.cmake +++ b/cmake/CompilerFlags.cmake @@ -195,6 +195,13 @@ function(set_default_compile_options target) ) endif() + if(SLANG_ENABLE_IR_BREAK_ALLOC) + target_compile_definitions( + ${target} + PRIVATE SLANG_ENABLE_IR_BREAK_ALLOC + ) + endif() + if(SLANG_ENABLE_DX_ON_VK) target_compile_definitions(${target} PRIVATE SLANG_CONFIG_DX_ON_VK) endif() diff --git a/docs/building.md b/docs/building.md index aba486a4e..a42f4cde6 100644 --- a/docs/building.md +++ b/docs/building.md @@ -64,6 +64,7 @@ See the [documentation on testing](../tools/slang-test/README.md) for more infor | `SLANG_EMBED_STDLIB_SOURCE` | `TRUE` | Embed stdlib source in the binary | | `SLANG_ENABLE_ASAN` | `FALSE` | Enable ASAN (address sanitizer) | | `SLANG_ENABLE_FULL_IR_VALIDATION` | `FALSE` | Enable full IR validation (SLOW!) | +| `SLANG_ENABLE_IR_BREAK_ALLOC` | `FALSE` | Enable IR BreakAlloc functionality for debugging. | | `SLANG_ENABLE_GFX` | `TRUE` | Enable gfx targets | | `SLANG_ENABLE_SLANGD` | `TRUE` | Enable language server target | | `SLANG_ENABLE_SLANGC` | `TRUE` | Enable standalone compiler target | diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 8c7232963..1040017da 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4905,7 +4905,7 @@ struct SPIRVEmitContext const SpvWord baseId = getID(ensureInst(base)); // We might replace resultType with a different storage class equivalent - auto resultType = as<IRPtrTypeBase>(inst->getFullType()); + auto resultType = as<IRPtrTypeBase>(inst->getDataType()); SLANG_ASSERT(resultType); if(const auto basePtrType = as<IRPtrTypeBase>(base->getDataType())) diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 0e23460ce..6ff52688b 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -3279,14 +3279,15 @@ void legalizeEntryPointForGLSL( context.sink = codeGenContext->getSink(); context.glslExtensionTracker = glslExtensionTracker; - // We require that the entry-point function has no uses, + // We require that the entry-point function has no calls, // because otherwise we'd invalidate the signature // at all existing call sites. // // TODO: the right thing to do here is to split any // function that both gets called as an entry point // and as an ordinary function. - SLANG_ASSERT(!func->firstUse); + for (auto use = func->firstUse; use; use = use->nextUse) + SLANG_ASSERT(use->getUser()->getOp() != kIROp_Call); // Require SPIRV version based on the stage. switch (stage) diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index 6b333f892..bff7363b3 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -1666,6 +1666,72 @@ namespace Slang } } + struct MetalAddressSpaceAssigner : InitialAddressSpaceAssigner + { + virtual bool tryAssignAddressSpace(IRInst* inst, AddressSpace& outAddressSpace) override + { + switch (inst->getOp()) + { + case kIROp_Var: + outAddressSpace = AddressSpace::ThreadLocal; + return true; + case kIROp_RWStructuredBufferGetElementPtr: + outAddressSpace = AddressSpace::Global; + return true; + default: + return false; + } + } + + virtual AddressSpace getAddressSpaceFromVarType(IRInst* type) override + { + if (as<IRUniformParameterGroupType>(type)) + { + return AddressSpace::Uniform; + } + if (as<IRByteAddressBufferTypeBase>(type)) + { + return AddressSpace::Global; + } + if (as<IRHLSLStructuredBufferTypeBase>(type)) + { + return AddressSpace::Global; + } + if (as<IRGLSLShaderStorageBufferType>(type)) + { + return AddressSpace::Global; + } + if (auto ptrType = as<IRPtrTypeBase>(type)) + { + if (ptrType->hasAddressSpace()) + return (AddressSpace)ptrType->getAddressSpace(); + return AddressSpace::Global; + } + return AddressSpace::Generic; + } + + virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) override + { + if (as<IRGroupSharedRate>(inst->getRate())) + return AddressSpace::GroupShared; + switch (inst->getOp()) + { + case kIROp_RWStructuredBufferGetElementPtr: + return AddressSpace::Global; + case kIROp_Var: + if (as<IRBlock>(inst->getParent())) + return AddressSpace::ThreadLocal; + break; + default: + break; + } + auto type = unwrapAttributedType(inst->getDataType()); + if (!type) + return AddressSpace::Generic; + return getAddressSpaceFromVarType(type); + } + }; + void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) { List<EntryPointInfo> entryPoints; @@ -1689,7 +1755,8 @@ namespace Slang context.legalizeEntryPointForMetal(entryPoint); context.removeSemanticLayoutsFromLegalizedStructs(); - specializeAddressSpace(module); + MetalAddressSpaceAssigner metalAddressSpaceAssigner; + specializeAddressSpace(module, &metalAddressSpaceAssigner); } } diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp index 55d61d527..1d899e240 100644 --- a/source/slang/slang-ir-specialize-address-space.cpp +++ b/source/slang/slang-ir-specialize-address-space.cpp @@ -7,66 +7,30 @@ namespace Slang { - struct AddressSpaceContext + struct AddressSpaceContext : public AddressSpaceSpecializationContext { IRModule* module; Dictionary<IRInst*, AddressSpace> mapInstToAddrSpace; + InitialAddressSpaceAssigner* addrSpaceAssigner; - AddressSpaceContext(IRModule* inModule) + AddressSpaceContext(IRModule* inModule, InitialAddressSpaceAssigner* inAddrSpaceAssigner) : module(inModule) + , addrSpaceAssigner(inAddrSpaceAssigner) { } AddressSpace getAddressSpaceFromVarType(IRInst* type) { - if (as<IRUniformParameterGroupType>(type)) - { - return AddressSpace::Uniform; - } - if (as<IRByteAddressBufferTypeBase>(type)) - { - return AddressSpace::Global; - } - if (as<IRHLSLStructuredBufferTypeBase>(type)) - { - return AddressSpace::Global; - } - if (as<IRGLSLShaderStorageBufferType>(type)) - { - return AddressSpace::Global; - } - if (auto ptrType = as<IRPtrTypeBase>(type)) - { - if (ptrType->hasAddressSpace()) - return (AddressSpace)ptrType->getAddressSpace(); - return AddressSpace::Global; - } - return AddressSpace::Generic; + return addrSpaceAssigner->getAddressSpaceFromVarType(type); } AddressSpace getLeafInstAddressSpace(IRInst* inst) { - if (as<IRGroupSharedRate>(inst->getRate())) - return AddressSpace::GroupShared; - switch (inst->getOp()) - { - case kIROp_RWStructuredBufferGetElementPtr: - return AddressSpace::Global; - case kIROp_Var: - if (as<IRBlock>(inst->getParent())) - return AddressSpace::ThreadLocal; - break; - default: - break; - } - auto type = unwrapAttributedType(inst->getDataType()); - if (!type) - return AddressSpace::Generic; - return getAddressSpaceFromVarType(type); + return addrSpaceAssigner->getLeafInstAddressSpace(inst); } - AddressSpace getAddrSpace(IRInst* inst) + AddressSpace getAddrSpace(IRInst* inst) override { auto addrSpace = mapInstToAddrSpace.tryGetValue(inst); if (addrSpace) @@ -186,20 +150,29 @@ namespace Slang continue; } + // If the inst already has a pointer type with explicit address space, then use it. + if (auto ptrType = as<IRPtrTypeBase>(inst->getDataType())) + { + if (ptrType->hasAddressSpace()) + { + mapInstToAddrSpace[inst] = (AddressSpace)ptrType->getAddressSpace(); + continue; + } + } + + // Otherwise, try to assign an address space based on the instruction type. switch (inst->getOp()) { case kIROp_Var: - { - // All local variables should be in the thread-local address space. - mapInstToAddrSpace[inst] = AddressSpace::ThreadLocal; - changed = true; - break; - } case kIROp_RWStructuredBufferGetElementPtr: { - // The address space of the result of RWStructuredBufferGetElementPtr is always global. - mapInstToAddrSpace[inst] = AddressSpace::Global; - changed = true; + // The address space of these insts should be assigned by the initial address space assigner. + AddressSpace addrSpace = AddressSpace::Generic; + if (addrSpaceAssigner->tryAssignAddressSpace(inst, addrSpace)) + { + mapInstToAddrSpace[inst] = addrSpace; + changed = true; + } break; } case kIROp_GetElementPtr: @@ -340,7 +313,10 @@ namespace Slang { auto rate = inst->getRate(); if (!rate) + { inst->setFullType(dataType); + return; + } IRBuilder builder(inst); builder.setInsertBefore(inst); @@ -405,9 +381,9 @@ namespace Slang } }; - void specializeAddressSpace(IRModule* module) + void specializeAddressSpace(IRModule* module, InitialAddressSpaceAssigner* addrSpaceAssigner) { - AddressSpaceContext context(module); + AddressSpaceContext context(module, addrSpaceAssigner); context.processModule(); } } diff --git a/source/slang/slang-ir-specialize-address-space.h b/source/slang/slang-ir-specialize-address-space.h index d74a59efa..300b6129c 100644 --- a/source/slang/slang-ir-specialize-address-space.h +++ b/source/slang/slang-ir-specialize-address-space.h @@ -4,11 +4,27 @@ namespace Slang { struct IRModule; + struct IRInst; + enum class AddressSpace; + + struct AddressSpaceSpecializationContext + { + public: + virtual AddressSpace getAddrSpace(IRInst* inst) = 0; + }; + + struct InitialAddressSpaceAssigner + { + virtual bool tryAssignAddressSpace(IRInst* inst, AddressSpace& outAddressSpace) = 0; + virtual AddressSpace getAddressSpaceFromVarType(IRInst* type) = 0; + virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) = 0; + }; /// Propagate address space information through the IR module. /// Specialize functions with reference/pointer parameters to use the correct address space /// based on the address space of the arguments. /// void specializeAddressSpace( - IRModule* module); + IRModule* module, + InitialAddressSpaceAssigner* addrSpaceAssigner); } diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index d7b980bf8..27a186ee9 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -22,6 +22,7 @@ #include "slang-ir-redundancy-removal.h" #include "slang-ir-loop-unroll.h" #include "slang-ir-lower-buffer-element-type.h" +#include "slang-ir-specialize-address-space.h" namespace Slang { @@ -1034,6 +1035,11 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } else { + // If we reach here, we have determined that all arguments passed as a pointer + // are actual memory objects, so they can be passed in as-is. + // We still need to make sure the callee is specialized to the address-space + // of the arguments, this is done in a separate specialization pass. + translatePtrResultType(inst); } } @@ -2230,6 +2236,38 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } + struct SpirvAddressSpaceAssigner : InitialAddressSpaceAssigner + { + virtual bool tryAssignAddressSpace(IRInst* inst, AddressSpace& outAddressSpace) override + { + SLANG_UNUSED(inst); + // Don't assign address space to additional insts, since we should have + // already assigned address space to them in earlier stages of legalization. + outAddressSpace = AddressSpace::Generic; + return false; + } + + virtual AddressSpace getAddressSpaceFromVarType(IRInst* type) override + { + if (auto ptrType = as<IRPtrTypeBase>(type)) + { + if (ptrType->hasAddressSpace()) + return (AddressSpace)ptrType->getAddressSpace(); + } + return AddressSpace::Generic; + } + + virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) override + { + // Don't assign address space to additional insts, since we should have + // already assigned address space to them in earlier stages of legalization. + auto type = unwrapAttributedType(inst->getDataType()); + if (!type) + return AddressSpace::Generic; + return getAddressSpaceFromVarType(type); + } + }; + void processModule() { determineSpirvVersion(); @@ -2332,6 +2370,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // the type for IRVar, and use IRPtrType to dedicate pointers in user code, so we can // safely lower the pointer load stores early together with other buffer types. lowerBufferElementTypeToStorageType(m_sharedContext->m_targetProgram, m_module, true); + + // Specalize address space for all pointers. + SpirvAddressSpaceAssigner addressSpaceAssigner; + specializeAddressSpace(m_module, &addressSpaceAssigner); } void updateFunctionTypes() diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 7c04729b0..cf9dfa84c 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -41,7 +41,7 @@ struct IRStructKey; enum class AddressSpace { - Generic = 0, + Generic = 0x7fffffff, ThreadLocal = 1, Global = 2, GroupShared = 3, diff --git a/tests/spirv/address-space-specialize.slang b/tests/spirv/address-space-specialize.slang new file mode 100644 index 000000000..e2b48489a --- /dev/null +++ b/tests/spirv/address-space-specialize.slang @@ -0,0 +1,33 @@ +//TEST:SIMPLE(filecheck=CHECK):-target spirv -entry main -stage compute -emit-spirv-directly -O0 + +// Test that we can pass arguments in different address space to an `inout` parameter, and have +// the callee specialized to the address space of the argument. +// If successful, we should generate SPIRV that passes validation. + +static int gArray0[2]; +groupshared int gArray1[2]; + +// CHECK: %array = OpFunctionParameter %_ptr_Private__arr_int_int_2 +// CHECK: %array_0 = OpFunctionParameter %_ptr_Workgroup__arr_int_int_2 + +void modify(inout int array[2]) +{ + array[0] = 1; + array[1] = 2; +} + +void atomicOp(inout int array[2]) +{ + InterlockedAdd(array[0], 1); +} + +RWStructuredBuffer<int> output; + +[numthreads(1,1,1)] +void main() +{ + modify(gArray0); + modify(gArray1); + atomicOp(gArray1); + output[0] = gArray0[0] + gArray1[1]; +} |
