summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-03-03 22:16:49 -0800
committerGitHub <noreply@github.com>2024-03-03 22:16:49 -0800
commitf8c54056048f38369ac93b5da5b823a6f758e227 (patch)
treee9d8e2e1b0c56a8ef7db697c40f2080a7c90cfcc
parenta4919e3e16d6958b70d665ed682aae910ecf1d4b (diff)
Fix SPIRV pointer codegen. (#3664)
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp79
-rw-r--r--tests/spirv/pointer-array.slang25
2 files changed, 74 insertions, 30 deletions
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index bc571cba5..5589510ed 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -748,7 +748,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto newPtrType = builder.getPtrType(
- oldPtrType->getOp(), oldPtrType->getValueType(), SpvStorageClassFunction);
+ oldPtrType->getOp(), translateToStorageBufferPointer(oldPtrType->getValueType()), SpvStorageClassFunction);
inst->setFullType(newPtrType);
addUsersToWorkList(inst);
}
@@ -793,7 +793,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto newPtrType = builder.getPtrType(
- oldPtrType->getOp(), oldPtrType->getValueType(), SpvStorageClassPhysicalStorageBuffer);
+ oldPtrType->getOp(), translateToStorageBufferPointer(oldPtrType->getValueType()), SpvStorageClassPhysicalStorageBuffer);
inst->setFullType(newPtrType);
addUsersToWorkList(inst);
}
@@ -806,6 +806,18 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
if (!oldPtrType)
return;
+ // Update the pointer value type with storage-buffer-address-space-decorated types.
+ auto newPtrValueType = translateToStorageBufferPointer(oldPtrType->getValueType());
+ if (newPtrValueType != oldPtrType->getValueType())
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ IRType* newPtrType = oldPtrType->hasAddressSpace()
+ ? builder.getPtrType(oldPtrType->getOp(), newPtrValueType, oldPtrType->getAddressSpace())
+ : builder.getPtrType(oldPtrType->getOp(), newPtrValueType);
+ inst->setFullType(newPtrType);
+ }
+
// If the pointer type is already qualified with address spaces (such as
// lowered pointer type from a `HLSLStructuredBufferType`), make no
// further modifications.
@@ -847,7 +859,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(m_sharedContext->m_irModule);
builder.setInsertBefore(inst);
auto newPtrType =
- builder.getPtrType(oldPtrType->getOp(), oldPtrType->getValueType(), storageClass);
+ builder.getPtrType(oldPtrType->getOp(), translateToStorageBufferPointer(oldPtrType->getValueType()), storageClass);
inst->setFullType(newPtrType);
addUsersToWorkList(inst);
return;
@@ -870,7 +882,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(m_sharedContext->m_irModule);
builder.setInsertBefore(inst);
auto qualPtrType = builder.getPtrType(
- ptrType->getOp(), ptrType->getValueType(), snippet->resultStorageClass);
+ ptrType->getOp(), translateToStorageBufferPointer(ptrType->getValueType()), snippet->resultStorageClass);
List<IRInst*> args;
for (UInt i = 0; i < inst->getArgCount(); i++)
args.add(inst->getArg(i));
@@ -958,7 +970,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
// If we reach here, we need to allocate a temp var.
- auto tempVar = builder.emitVar(ptrType->getValueType());
+ auto tempVar = builder.emitVar(translateToStorageBufferPointer(ptrType->getValueType()));
auto load = builder.emitLoad(arg);
builder.emitStore(tempVar, load);
newArgs.add(tempVar);
@@ -1016,7 +1028,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
builder.setInsertBefore(inst);
else
setInsertAfterOrdinaryInst(&builder, x);
- y = builder.emitVar(x->getDataType(), SpvStorageClassFunction);
+ y = builder.emitVar(translateToStorageBufferPointer(x->getDataType()), SpvStorageClassFunction);
builder.emitStore(y, x);
if (x->getParent()->getOp() != kIROp_Module)
m_mapArrayValueToVar.set(x, y);
@@ -1043,7 +1055,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
builder.setInsertBefore(gepInst);
auto newPtrType = builder.getPtrType(
oldResultType->getOp(),
- oldResultType->getValueType(),
+ translateToStorageBufferPointer(oldResultType->getValueType()),
ptrType->getAddressSpace());
IRInst* args[2] = { base, index };
auto newInst =
@@ -1080,7 +1092,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(offsetPtrInst);
builder.setInsertBefore(offsetPtrInst);
auto newResultType = builder.getPtrType(resultPtrType->getOp(),
- resultPtrType->getValueType(),
+ translateToStorageBufferPointer(resultPtrType->getValueType()),
ptrOperandType->getAddressSpace());
auto newInst = builder.replaceOperand(&offsetPtrInst->typeUse, newResultType);
addUsersToWorkList(newInst);
@@ -1095,7 +1107,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
builder.setInsertBefore(loadInst);
IRInst* args[] = { sb, index };
auto addrInst = builder.emitIntrinsicInst(
- builder.getPtrType(kIROp_PtrType, loadInst->getFullType(), SpvStorageClassStorageBuffer),
+ builder.getPtrType(kIROp_PtrType, translateToStorageBufferPointer(loadInst->getFullType()), SpvStorageClassStorageBuffer),
kIROp_RWStructuredBufferGetElementPtr,
2,
args);
@@ -1611,19 +1623,34 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
addToWorkList(branch->getOperand(0));
}
- IRType* translateToStorageBufferPointer(IRType* pointerType)
+ // If type is pointer type and does not have an address space, make it a
+ // storage buffer pointer.
+ IRType* translateToStorageBufferPointer(IRType* type)
{
- auto ptrType = as<IRPtrType>(pointerType);
- if (!ptrType)
- return pointerType;
- auto oldValueType = ptrType->getValueType();
- auto newValueType = translateToStorageBufferPointer(oldValueType);
- if (oldValueType != newValueType || !ptrType->hasAddressSpace())
+ if (auto ptrType = as<IRPtrType>(type))
+ {
+ auto oldValueType = ptrType->getValueType();
+ auto newValueType = translateToStorageBufferPointer(oldValueType);
+ if (oldValueType != newValueType || !ptrType->hasAddressSpace())
+ {
+ IRBuilder builder(m_module);
+ return builder.getPtrType(ptrType->getOp(), newValueType,
+ ptrType->hasAddressSpace() ? ptrType->getAddressSpace() : SpvStorageClassPhysicalStorageBuffer);
+ }
+ return ptrType;
+ }
+ else if (auto arrayTypeBase = as<IRArrayTypeBase>(type))
{
- IRBuilder builder(m_module);
- return builder.getPtrType(ptrType->getOp(), newValueType, SpvStorageClassPhysicalStorageBuffer);
+ auto oldValueType = arrayTypeBase->getElementType();
+ auto newValueType = translateToStorageBufferPointer(oldValueType);
+ if (oldValueType != newValueType)
+ {
+ IRBuilder builder(m_module);
+ return builder.getArrayTypeBase(arrayTypeBase->getOp(), newValueType, arrayTypeBase->getElementCount());
+ }
+ return arrayTypeBase;
}
- return ptrType;
+ return type;
}
void translatePtrResultType(IRInst* inst)
@@ -1659,17 +1686,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
void processStructField(IRStructField* field)
{
- auto ptrType = as<IRPtrTypeBase>(field->getFieldType());
- if (!ptrType)
- return;
- if (ptrType->hasAddressSpace())
- return;
- IRBuilder builder(field);
- auto newPtrType = builder.getPtrType(
- ptrType->getOp(),
- ptrType->getValueType(),
- SpvStorageClassPhysicalStorageBuffer);
- field->setFieldType(newPtrType);
+ auto newFieldType = translateToStorageBufferPointer(field->getFieldType());
+ if (newFieldType != field->getFieldType())
+ field->setFieldType(newFieldType);
}
void processComparison(IRInst* inst)
diff --git a/tests/spirv/pointer-array.slang b/tests/spirv/pointer-array.slang
new file mode 100644
index 000000000..6c4050536
--- /dev/null
+++ b/tests/spirv/pointer-array.slang
@@ -0,0 +1,25 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -entry main -stage compute -emit-spirv-directly
+
+
+struct Tester
+{
+ uint i;
+};
+
+struct Push
+{
+ Tester* ptr_array[2];
+ uint * out_ptr;
+};
+
+[[vk::push_constant]] Push p;
+
+// CHECK: OpEntryPoint
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void main(int id : SV_DispatchThreadID)
+{
+ uint i = p.ptr_array[0].i;
+ *p.out_ptr = i;
+} \ No newline at end of file