summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-07-10 16:17:10 -0700
committerGitHub <noreply@github.com>2024-07-10 16:17:10 -0700
commit746d47bb491e0b97e35ab373b4b78d33b9a61164 (patch)
tree74e0936472d911d8c6c561ca4b21e800306c5f51 /source
parent82f308ca692878bfe9844b86629c6536b4cd0f0a (diff)
Specialize address space during spirv legalization. (#4600)
* Specialize address space during spirv legalization. * Fix. * Fix building doc. * Fix cmake. * Update assert.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-spirv.cpp2
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp5
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp69
-rw-r--r--source/slang/slang-ir-specialize-address-space.cpp84
-rw-r--r--source/slang/slang-ir-specialize-address-space.h18
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp42
-rw-r--r--source/slang/slang-ir.h2
7 files changed, 162 insertions, 60 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 8c7232963..1040017da 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -4905,7 +4905,7 @@ struct SPIRVEmitContext
const SpvWord baseId = getID(ensureInst(base));
// We might replace resultType with a different storage class equivalent
- auto resultType = as<IRPtrTypeBase>(inst->getFullType());
+ auto resultType = as<IRPtrTypeBase>(inst->getDataType());
SLANG_ASSERT(resultType);
if(const auto basePtrType = as<IRPtrTypeBase>(base->getDataType()))
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 0e23460ce..6ff52688b 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -3279,14 +3279,15 @@ void legalizeEntryPointForGLSL(
context.sink = codeGenContext->getSink();
context.glslExtensionTracker = glslExtensionTracker;
- // We require that the entry-point function has no uses,
+ // We require that the entry-point function has no calls,
// because otherwise we'd invalidate the signature
// at all existing call sites.
//
// TODO: the right thing to do here is to split any
// function that both gets called as an entry point
// and as an ordinary function.
- SLANG_ASSERT(!func->firstUse);
+ for (auto use = func->firstUse; use; use = use->nextUse)
+ SLANG_ASSERT(use->getUser()->getOp() != kIROp_Call);
// Require SPIRV version based on the stage.
switch (stage)
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index 6b333f892..bff7363b3 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -1666,6 +1666,72 @@ namespace Slang
}
}
+ struct MetalAddressSpaceAssigner : InitialAddressSpaceAssigner
+ {
+ virtual bool tryAssignAddressSpace(IRInst* inst, AddressSpace& outAddressSpace) override
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Var:
+ outAddressSpace = AddressSpace::ThreadLocal;
+ return true;
+ case kIROp_RWStructuredBufferGetElementPtr:
+ outAddressSpace = AddressSpace::Global;
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ virtual AddressSpace getAddressSpaceFromVarType(IRInst* type) override
+ {
+ if (as<IRUniformParameterGroupType>(type))
+ {
+ return AddressSpace::Uniform;
+ }
+ if (as<IRByteAddressBufferTypeBase>(type))
+ {
+ return AddressSpace::Global;
+ }
+ if (as<IRHLSLStructuredBufferTypeBase>(type))
+ {
+ return AddressSpace::Global;
+ }
+ if (as<IRGLSLShaderStorageBufferType>(type))
+ {
+ return AddressSpace::Global;
+ }
+ if (auto ptrType = as<IRPtrTypeBase>(type))
+ {
+ if (ptrType->hasAddressSpace())
+ return (AddressSpace)ptrType->getAddressSpace();
+ return AddressSpace::Global;
+ }
+ return AddressSpace::Generic;
+ }
+
+ virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) override
+ {
+ if (as<IRGroupSharedRate>(inst->getRate()))
+ return AddressSpace::GroupShared;
+ switch (inst->getOp())
+ {
+ case kIROp_RWStructuredBufferGetElementPtr:
+ return AddressSpace::Global;
+ case kIROp_Var:
+ if (as<IRBlock>(inst->getParent()))
+ return AddressSpace::ThreadLocal;
+ break;
+ default:
+ break;
+ }
+ auto type = unwrapAttributedType(inst->getDataType());
+ if (!type)
+ return AddressSpace::Generic;
+ return getAddressSpaceFromVarType(type);
+ }
+ };
+
void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink)
{
List<EntryPointInfo> entryPoints;
@@ -1689,7 +1755,8 @@ namespace Slang
context.legalizeEntryPointForMetal(entryPoint);
context.removeSemanticLayoutsFromLegalizedStructs();
- specializeAddressSpace(module);
+ MetalAddressSpaceAssigner metalAddressSpaceAssigner;
+ specializeAddressSpace(module, &metalAddressSpaceAssigner);
}
}
diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp
index 55d61d527..1d899e240 100644
--- a/source/slang/slang-ir-specialize-address-space.cpp
+++ b/source/slang/slang-ir-specialize-address-space.cpp
@@ -7,66 +7,30 @@
namespace Slang
{
- struct AddressSpaceContext
+ struct AddressSpaceContext : public AddressSpaceSpecializationContext
{
IRModule* module;
Dictionary<IRInst*, AddressSpace> mapInstToAddrSpace;
+ InitialAddressSpaceAssigner* addrSpaceAssigner;
- AddressSpaceContext(IRModule* inModule)
+ AddressSpaceContext(IRModule* inModule, InitialAddressSpaceAssigner* inAddrSpaceAssigner)
: module(inModule)
+ , addrSpaceAssigner(inAddrSpaceAssigner)
{
}
AddressSpace getAddressSpaceFromVarType(IRInst* type)
{
- if (as<IRUniformParameterGroupType>(type))
- {
- return AddressSpace::Uniform;
- }
- if (as<IRByteAddressBufferTypeBase>(type))
- {
- return AddressSpace::Global;
- }
- if (as<IRHLSLStructuredBufferTypeBase>(type))
- {
- return AddressSpace::Global;
- }
- if (as<IRGLSLShaderStorageBufferType>(type))
- {
- return AddressSpace::Global;
- }
- if (auto ptrType = as<IRPtrTypeBase>(type))
- {
- if (ptrType->hasAddressSpace())
- return (AddressSpace)ptrType->getAddressSpace();
- return AddressSpace::Global;
- }
- return AddressSpace::Generic;
+ return addrSpaceAssigner->getAddressSpaceFromVarType(type);
}
AddressSpace getLeafInstAddressSpace(IRInst* inst)
{
- if (as<IRGroupSharedRate>(inst->getRate()))
- return AddressSpace::GroupShared;
- switch (inst->getOp())
- {
- case kIROp_RWStructuredBufferGetElementPtr:
- return AddressSpace::Global;
- case kIROp_Var:
- if (as<IRBlock>(inst->getParent()))
- return AddressSpace::ThreadLocal;
- break;
- default:
- break;
- }
- auto type = unwrapAttributedType(inst->getDataType());
- if (!type)
- return AddressSpace::Generic;
- return getAddressSpaceFromVarType(type);
+ return addrSpaceAssigner->getLeafInstAddressSpace(inst);
}
- AddressSpace getAddrSpace(IRInst* inst)
+ AddressSpace getAddrSpace(IRInst* inst) override
{
auto addrSpace = mapInstToAddrSpace.tryGetValue(inst);
if (addrSpace)
@@ -186,20 +150,29 @@ namespace Slang
continue;
}
+ // If the inst already has a pointer type with explicit address space, then use it.
+ if (auto ptrType = as<IRPtrTypeBase>(inst->getDataType()))
+ {
+ if (ptrType->hasAddressSpace())
+ {
+ mapInstToAddrSpace[inst] = (AddressSpace)ptrType->getAddressSpace();
+ continue;
+ }
+ }
+
+ // Otherwise, try to assign an address space based on the instruction type.
switch (inst->getOp())
{
case kIROp_Var:
- {
- // All local variables should be in the thread-local address space.
- mapInstToAddrSpace[inst] = AddressSpace::ThreadLocal;
- changed = true;
- break;
- }
case kIROp_RWStructuredBufferGetElementPtr:
{
- // The address space of the result of RWStructuredBufferGetElementPtr is always global.
- mapInstToAddrSpace[inst] = AddressSpace::Global;
- changed = true;
+ // The address space of these insts should be assigned by the initial address space assigner.
+ AddressSpace addrSpace = AddressSpace::Generic;
+ if (addrSpaceAssigner->tryAssignAddressSpace(inst, addrSpace))
+ {
+ mapInstToAddrSpace[inst] = addrSpace;
+ changed = true;
+ }
break;
}
case kIROp_GetElementPtr:
@@ -340,7 +313,10 @@ namespace Slang
{
auto rate = inst->getRate();
if (!rate)
+ {
inst->setFullType(dataType);
+ return;
+ }
IRBuilder builder(inst);
builder.setInsertBefore(inst);
@@ -405,9 +381,9 @@ namespace Slang
}
};
- void specializeAddressSpace(IRModule* module)
+ void specializeAddressSpace(IRModule* module, InitialAddressSpaceAssigner* addrSpaceAssigner)
{
- AddressSpaceContext context(module);
+ AddressSpaceContext context(module, addrSpaceAssigner);
context.processModule();
}
}
diff --git a/source/slang/slang-ir-specialize-address-space.h b/source/slang/slang-ir-specialize-address-space.h
index d74a59efa..300b6129c 100644
--- a/source/slang/slang-ir-specialize-address-space.h
+++ b/source/slang/slang-ir-specialize-address-space.h
@@ -4,11 +4,27 @@
namespace Slang
{
struct IRModule;
+ struct IRInst;
+ enum class AddressSpace;
+
+ struct AddressSpaceSpecializationContext
+ {
+ public:
+ virtual AddressSpace getAddrSpace(IRInst* inst) = 0;
+ };
+
+ struct InitialAddressSpaceAssigner
+ {
+ virtual bool tryAssignAddressSpace(IRInst* inst, AddressSpace& outAddressSpace) = 0;
+ virtual AddressSpace getAddressSpaceFromVarType(IRInst* type) = 0;
+ virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) = 0;
+ };
/// Propagate address space information through the IR module.
/// Specialize functions with reference/pointer parameters to use the correct address space
/// based on the address space of the arguments.
///
void specializeAddressSpace(
- IRModule* module);
+ IRModule* module,
+ InitialAddressSpaceAssigner* addrSpaceAssigner);
}
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index d7b980bf8..27a186ee9 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -22,6 +22,7 @@
#include "slang-ir-redundancy-removal.h"
#include "slang-ir-loop-unroll.h"
#include "slang-ir-lower-buffer-element-type.h"
+#include "slang-ir-specialize-address-space.h"
namespace Slang
{
@@ -1034,6 +1035,11 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
else
{
+ // If we reach here, we have determined that all arguments passed as a pointer
+ // are actual memory objects, so they can be passed in as-is.
+ // We still need to make sure the callee is specialized to the address-space
+ // of the arguments, this is done in a separate specialization pass.
+
translatePtrResultType(inst);
}
}
@@ -2230,6 +2236,38 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}
+ struct SpirvAddressSpaceAssigner : InitialAddressSpaceAssigner
+ {
+ virtual bool tryAssignAddressSpace(IRInst* inst, AddressSpace& outAddressSpace) override
+ {
+ SLANG_UNUSED(inst);
+ // Don't assign address space to additional insts, since we should have
+ // already assigned address space to them in earlier stages of legalization.
+ outAddressSpace = AddressSpace::Generic;
+ return false;
+ }
+
+ virtual AddressSpace getAddressSpaceFromVarType(IRInst* type) override
+ {
+ if (auto ptrType = as<IRPtrTypeBase>(type))
+ {
+ if (ptrType->hasAddressSpace())
+ return (AddressSpace)ptrType->getAddressSpace();
+ }
+ return AddressSpace::Generic;
+ }
+
+ virtual AddressSpace getLeafInstAddressSpace(IRInst* inst) override
+ {
+ // Don't assign address space to additional insts, since we should have
+ // already assigned address space to them in earlier stages of legalization.
+ auto type = unwrapAttributedType(inst->getDataType());
+ if (!type)
+ return AddressSpace::Generic;
+ return getAddressSpaceFromVarType(type);
+ }
+ };
+
void processModule()
{
determineSpirvVersion();
@@ -2332,6 +2370,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// the type for IRVar, and use IRPtrType to dedicate pointers in user code, so we can
// safely lower the pointer load stores early together with other buffer types.
lowerBufferElementTypeToStorageType(m_sharedContext->m_targetProgram, m_module, true);
+
+ // Specalize address space for all pointers.
+ SpirvAddressSpaceAssigner addressSpaceAssigner;
+ specializeAddressSpace(m_module, &addressSpaceAssigner);
}
void updateFunctionTypes()
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 7c04729b0..cf9dfa84c 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -41,7 +41,7 @@ struct IRStructKey;
enum class AddressSpace
{
- Generic = 0,
+ Generic = 0x7fffffff,
ThreadLocal = 1,
Global = 2,
GroupShared = 3,