diff options
| author | Yong He <yonghe@outlook.com> | 2024-03-03 22:16:49 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-03 22:16:49 -0800 |
| commit | f8c54056048f38369ac93b5da5b823a6f758e227 (patch) | |
| tree | e9d8e2e1b0c56a8ef7db697c40f2080a7c90cfcc | |
| parent | a4919e3e16d6958b70d665ed682aae910ecf1d4b (diff) | |
Fix SPIRV pointer codegen. (#3664)
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 79 | ||||
| -rw-r--r-- | tests/spirv/pointer-array.slang | 25 |
2 files changed, 74 insertions, 30 deletions
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index bc571cba5..5589510ed 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -748,7 +748,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(inst); builder.setInsertBefore(inst); auto newPtrType = builder.getPtrType( - oldPtrType->getOp(), oldPtrType->getValueType(), SpvStorageClassFunction); + oldPtrType->getOp(), translateToStorageBufferPointer(oldPtrType->getValueType()), SpvStorageClassFunction); inst->setFullType(newPtrType); addUsersToWorkList(inst); } @@ -793,7 +793,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(inst); builder.setInsertBefore(inst); auto newPtrType = builder.getPtrType( - oldPtrType->getOp(), oldPtrType->getValueType(), SpvStorageClassPhysicalStorageBuffer); + oldPtrType->getOp(), translateToStorageBufferPointer(oldPtrType->getValueType()), SpvStorageClassPhysicalStorageBuffer); inst->setFullType(newPtrType); addUsersToWorkList(inst); } @@ -806,6 +806,18 @@ struct SPIRVLegalizationContext : public SourceEmitterBase if (!oldPtrType) return; + // Update the pointer value type with storage-buffer-address-space-decorated types. + auto newPtrValueType = translateToStorageBufferPointer(oldPtrType->getValueType()); + if (newPtrValueType != oldPtrType->getValueType()) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + IRType* newPtrType = oldPtrType->hasAddressSpace() + ? builder.getPtrType(oldPtrType->getOp(), newPtrValueType, oldPtrType->getAddressSpace()) + : builder.getPtrType(oldPtrType->getOp(), newPtrValueType); + inst->setFullType(newPtrType); + } + // If the pointer type is already qualified with address spaces (such as // lowered pointer type from a `HLSLStructuredBufferType`), make no // further modifications. @@ -847,7 +859,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(m_sharedContext->m_irModule); builder.setInsertBefore(inst); auto newPtrType = - builder.getPtrType(oldPtrType->getOp(), oldPtrType->getValueType(), storageClass); + builder.getPtrType(oldPtrType->getOp(), translateToStorageBufferPointer(oldPtrType->getValueType()), storageClass); inst->setFullType(newPtrType); addUsersToWorkList(inst); return; @@ -870,7 +882,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(m_sharedContext->m_irModule); builder.setInsertBefore(inst); auto qualPtrType = builder.getPtrType( - ptrType->getOp(), ptrType->getValueType(), snippet->resultStorageClass); + ptrType->getOp(), translateToStorageBufferPointer(ptrType->getValueType()), snippet->resultStorageClass); List<IRInst*> args; for (UInt i = 0; i < inst->getArgCount(); i++) args.add(inst->getArg(i)); @@ -958,7 +970,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } // If we reach here, we need to allocate a temp var. - auto tempVar = builder.emitVar(ptrType->getValueType()); + auto tempVar = builder.emitVar(translateToStorageBufferPointer(ptrType->getValueType())); auto load = builder.emitLoad(arg); builder.emitStore(tempVar, load); newArgs.add(tempVar); @@ -1016,7 +1028,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase builder.setInsertBefore(inst); else setInsertAfterOrdinaryInst(&builder, x); - y = builder.emitVar(x->getDataType(), SpvStorageClassFunction); + y = builder.emitVar(translateToStorageBufferPointer(x->getDataType()), SpvStorageClassFunction); builder.emitStore(y, x); if (x->getParent()->getOp() != kIROp_Module) m_mapArrayValueToVar.set(x, y); @@ -1043,7 +1055,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase builder.setInsertBefore(gepInst); auto newPtrType = builder.getPtrType( oldResultType->getOp(), - oldResultType->getValueType(), + translateToStorageBufferPointer(oldResultType->getValueType()), ptrType->getAddressSpace()); IRInst* args[2] = { base, index }; auto newInst = @@ -1080,7 +1092,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(offsetPtrInst); builder.setInsertBefore(offsetPtrInst); auto newResultType = builder.getPtrType(resultPtrType->getOp(), - resultPtrType->getValueType(), + translateToStorageBufferPointer(resultPtrType->getValueType()), ptrOperandType->getAddressSpace()); auto newInst = builder.replaceOperand(&offsetPtrInst->typeUse, newResultType); addUsersToWorkList(newInst); @@ -1095,7 +1107,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase builder.setInsertBefore(loadInst); IRInst* args[] = { sb, index }; auto addrInst = builder.emitIntrinsicInst( - builder.getPtrType(kIROp_PtrType, loadInst->getFullType(), SpvStorageClassStorageBuffer), + builder.getPtrType(kIROp_PtrType, translateToStorageBufferPointer(loadInst->getFullType()), SpvStorageClassStorageBuffer), kIROp_RWStructuredBufferGetElementPtr, 2, args); @@ -1611,19 +1623,34 @@ struct SPIRVLegalizationContext : public SourceEmitterBase addToWorkList(branch->getOperand(0)); } - IRType* translateToStorageBufferPointer(IRType* pointerType) + // If type is pointer type and does not have an address space, make it a + // storage buffer pointer. + IRType* translateToStorageBufferPointer(IRType* type) { - auto ptrType = as<IRPtrType>(pointerType); - if (!ptrType) - return pointerType; - auto oldValueType = ptrType->getValueType(); - auto newValueType = translateToStorageBufferPointer(oldValueType); - if (oldValueType != newValueType || !ptrType->hasAddressSpace()) + if (auto ptrType = as<IRPtrType>(type)) + { + auto oldValueType = ptrType->getValueType(); + auto newValueType = translateToStorageBufferPointer(oldValueType); + if (oldValueType != newValueType || !ptrType->hasAddressSpace()) + { + IRBuilder builder(m_module); + return builder.getPtrType(ptrType->getOp(), newValueType, + ptrType->hasAddressSpace() ? ptrType->getAddressSpace() : SpvStorageClassPhysicalStorageBuffer); + } + return ptrType; + } + else if (auto arrayTypeBase = as<IRArrayTypeBase>(type)) { - IRBuilder builder(m_module); - return builder.getPtrType(ptrType->getOp(), newValueType, SpvStorageClassPhysicalStorageBuffer); + auto oldValueType = arrayTypeBase->getElementType(); + auto newValueType = translateToStorageBufferPointer(oldValueType); + if (oldValueType != newValueType) + { + IRBuilder builder(m_module); + return builder.getArrayTypeBase(arrayTypeBase->getOp(), newValueType, arrayTypeBase->getElementCount()); + } + return arrayTypeBase; } - return ptrType; + return type; } void translatePtrResultType(IRInst* inst) @@ -1659,17 +1686,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase void processStructField(IRStructField* field) { - auto ptrType = as<IRPtrTypeBase>(field->getFieldType()); - if (!ptrType) - return; - if (ptrType->hasAddressSpace()) - return; - IRBuilder builder(field); - auto newPtrType = builder.getPtrType( - ptrType->getOp(), - ptrType->getValueType(), - SpvStorageClassPhysicalStorageBuffer); - field->setFieldType(newPtrType); + auto newFieldType = translateToStorageBufferPointer(field->getFieldType()); + if (newFieldType != field->getFieldType()) + field->setFieldType(newFieldType); } void processComparison(IRInst* inst) diff --git a/tests/spirv/pointer-array.slang b/tests/spirv/pointer-array.slang new file mode 100644 index 000000000..6c4050536 --- /dev/null +++ b/tests/spirv/pointer-array.slang @@ -0,0 +1,25 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -entry main -stage compute -emit-spirv-directly + + +struct Tester +{ + uint i; +}; + +struct Push +{ + Tester* ptr_array[2]; + uint * out_ptr; +}; + +[[vk::push_constant]] Push p; + +// CHECK: OpEntryPoint + +[shader("compute")] +[numthreads(1, 1, 1)] +void main(int id : SV_DispatchThreadID) +{ + uint i = p.ptr_array[0].i; + *p.out_ptr = i; +}
\ No newline at end of file |
