summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-10-21 08:49:15 -0700
committerGitHub <noreply@github.com>2024-10-21 08:49:15 -0700
commit3e84726f45c66b477569be9e62da71956ab78e94 (patch)
tree8c69306133ee04b6acd14dd07d12a0ed47bf0079 /source
parent20fa42e82dfa8398c9c818773fa40817883fb7ec (diff)
Fix spirv codegen for pointer to empty structs. (#5355)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-compiler-tu.cpp6
-rw-r--r--source/slang/slang-ir-insts.h4
-rw-r--r--source/slang/slang-ir-legalize-types.cpp9
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp31
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp16
-rw-r--r--source/slang/slang-ir.cpp7
-rw-r--r--source/slang/slang-legalize-types.cpp42
7 files changed, 89 insertions, 26 deletions
diff --git a/source/slang/slang-compiler-tu.cpp b/source/slang/slang-compiler-tu.cpp
index 49595117d..5f88e871d 100644
--- a/source/slang/slang-compiler-tu.cpp
+++ b/source/slang/slang-compiler-tu.cpp
@@ -104,7 +104,7 @@ namespace Slang
applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet);
applySettingsToDiagnosticSink(&sink, &sink, m_optionSet);
- TargetRequest* targetReq = new TargetRequest(linkage, targetEnum);
+ RefPtr<TargetRequest> targetReq = new TargetRequest(linkage, targetEnum);
List<RefPtr<ComponentType>> allComponentTypes;
allComponentTypes.add(this); // Add Module as a component type
@@ -206,8 +206,8 @@ namespace Slang
}
}
- ISlangBlob* blob;
- outArtifact->loadBlob(ArtifactKeep::Yes, &blob);
+ ComPtr<ISlangBlob> blob;
+ outArtifact->loadBlob(ArtifactKeep::Yes, blob.writeRef());
// Add the precompiled blob to the module
builder.setInsertInto(module);
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index fc8bf99d0..bfe6c27b5 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3624,6 +3624,10 @@ public:
IRGenericKind* getGenericKind();
IRPtrType* getPtrType(IRType* valueType);
+
+ // Form a ptr type to `valueType` using the same opcode and address space as `ptrWithAddrSpace`.
+ IRPtrTypeBase* getPtrTypeWithAddressSpace(IRType* valueType, IRPtrTypeBase* ptrWithAddrSpace);
+
IROutType* getOutType(IRType* valueType);
IRInOutType* getInOutType(IRType* valueType);
IRRefType* getRefType(IRType* valueType, AddressSpace addrSpace);
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 073b7216b..fe8ee311d 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -934,6 +934,8 @@ static LegalVal legalizeStore(
case LegalVal::Flavor::simple:
{
+ if (legalVal.flavor == LegalVal::Flavor::none)
+ return LegalVal();
context->builder->emitStore(legalPtrVal.getSimple(), legalVal.getSimple());
return legalVal;
}
@@ -2248,7 +2250,7 @@ static LegalVal legalizeLocalVar(
// Easy case: the type is usable as-is, and we
// should just do that.
auto type = maybeSimpleType.getSimple();
- type = context->builder->getPtrType(type);
+ type = context->builder->getPtrTypeWithAddressSpace(type, irLocalVar->getDataType());
if( originalRate )
{
type = context->builder->getRateQualifiedType(
@@ -3669,7 +3671,7 @@ static LegalVal legalizeGlobalVar(
auto legalValueType = legalizeType(
context,
originalValueType);
-
+ auto varPtrType = as<IRPtrTypeBase>(irGlobalVar->getDataType());
switch (legalValueType.flavor)
{
case LegalType::Flavor::simple:
@@ -3678,7 +3680,8 @@ static LegalVal legalizeGlobalVar(
context->builder->setDataType(
irGlobalVar,
context->builder->getPtrType(
- legalValueType.getSimple()));
+ legalValueType.getSimple(),
+ varPtrType ? varPtrType->getAddressSpace():AddressSpace::Global));
return LegalVal::simple(irGlobalVar);
default:
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 29999017a..9ea41e3b4 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -671,7 +671,26 @@ namespace Slang
if (auto unsizedArrayType = as<IRUnsizedArrayType>(ptrType->getValueType()))
{
builder.setInsertBefore(ptrVal);
- auto newArrayPtrVal = builder.emitGetOffsetPtr(fieldAddr->getBase(), builder.getIntValue(builder.getIntType(), 1));
+ auto newArrayPtrVal = fieldAddr->getBase();
+ // Is base a pointer to an empty struct? If so, don't offset it.
+ // For example, if the user has written:
+ // ```
+ // struct S {int arr[]};
+ // uniform S* p;
+ // void test() { p->arr[1]; }
+ // ```
+ // Then `S` will become an empty struct after we remove `arr[]`.
+ // And `p` will be come a `void*`.
+ // We don't want to offset `p` to `p+1` to get the starting address of the array in this case.
+ IRSizeAndAlignment parentStructSize = {};
+ getNaturalSizeAndAlignment(
+ target->getOptionSet(),
+ tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()),
+ &parentStructSize);
+ if (parentStructSize.size != 0)
+ {
+ newArrayPtrVal = builder.emitGetOffsetPtr(fieldAddr->getBase(), builder.getIntValue(builder.getIntType(), 1));
+ }
auto loweredInnerType = getLoweredTypeInfo(unsizedArrayType->getElementType(), layoutRules);
IRSizeAndAlignment arrayElementSizeAlignment;
@@ -685,12 +704,14 @@ namespace Slang
&baseSizeAlignment);
// Convert pointer to uint64 and adjust offset.
- auto rawPtr = builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal);
IRIntegerValue offset = baseSizeAlignment.size;
offset = align(offset, arrayElementSizeAlignment.alignment);
- newArrayPtrVal = builder.emitAdd(rawPtr->getFullType(), rawPtr,
- builder.getIntValue(builder.getUInt64Type(), offset));
-
+ if (offset != 0)
+ {
+ auto rawPtr = builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal);
+ newArrayPtrVal = builder.emitAdd(rawPtr->getFullType(), rawPtr,
+ builder.getIntValue(builder.getUInt64Type(), offset));
+ }
newArrayPtrVal = builder.emitBitCast(
builder.getPtrType(loweredInnerType.loweredType,
ptrType->getAddressSpace()), newArrayPtrVal);
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 7f815df95..cac7c9c5c 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -23,6 +23,7 @@
#include "slang-ir-loop-unroll.h"
#include "slang-ir-lower-buffer-element-type.h"
#include "slang-ir-specialize-address-space.h"
+#include "slang-legalize-types.h"
namespace Slang
{
@@ -37,6 +38,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRModule* m_module;
+ DiagnosticSink* m_sink;
+
struct LoweredStructuredBufferTypeInfo
{
IRType* structType;
@@ -173,8 +176,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}
- SPIRVLegalizationContext(SPIRVEmitSharedContext* sharedContext, IRModule* module)
- : m_sharedContext(sharedContext), m_module(module)
+ SPIRVLegalizationContext(SPIRVEmitSharedContext* sharedContext, IRModule* module, DiagnosticSink* sink)
+ : m_sharedContext(sharedContext), m_module(module), m_sink(sink)
{
}
@@ -2108,6 +2111,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// safely lower the pointer load stores early together with other buffer types.
lowerBufferElementTypeToStorageType(m_sharedContext->m_targetProgram, m_module, true);
+ // The above step may produce empty struct types, so we need to lower them out of existence.
+ legalizeEmptyTypes(m_sharedContext->m_targetProgram, m_module, m_sink);
+
// Specalize address space for all pointers.
SpirvAddressSpaceAssigner addressSpaceAssigner;
specializeAddressSpace(m_module, &addressSpaceAssigner);
@@ -2184,9 +2190,9 @@ SpvSnippet* SPIRVEmitSharedContext::getParsedSpvSnippet(IRTargetIntrinsicDecorat
return snippet;
}
-void legalizeSPIRV(SPIRVEmitSharedContext* sharedContext, IRModule* module)
+void legalizeSPIRV(SPIRVEmitSharedContext* sharedContext, IRModule* module, DiagnosticSink* sink)
{
- SPIRVLegalizationContext context(sharedContext, module);
+ SPIRVLegalizationContext context(sharedContext, module, sink);
context.processModule();
}
@@ -2326,7 +2332,7 @@ void legalizeIRForSPIRV(
CodeGenContext* codeGenContext)
{
SLANG_UNUSED(entryPoints);
- legalizeSPIRV(context, module);
+ legalizeSPIRV(context, module, codeGenContext->getSink());
simplifyIRForSpirvLegalization(context->m_targetProgram, codeGenContext->getSink(), module);
buildEntryPointReferenceGraph(context->m_referencingEntryPoints, module);
insertFragmentShaderInterlock(context, module);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index e0998779a..2fd090877 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2881,6 +2881,13 @@ namespace Slang
operands);
}
+ IRPtrTypeBase* IRBuilder::getPtrTypeWithAddressSpace(IRType* valueType, IRPtrTypeBase* ptrWithAddrSpace)
+ {
+ if (ptrWithAddrSpace->hasAddressSpace())
+ return (IRPtrTypeBase*)getPtrType(ptrWithAddrSpace->getOp(), valueType, ptrWithAddrSpace->getAddressSpace());
+ return (IRPtrTypeBase*)getPtrType(ptrWithAddrSpace->getOp(), valueType);
+ }
+
IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace)
{
return (IRPtrType*)getPtrType(op, valueType, getIntValue(getUInt64Type(), static_cast<IRIntegerValue>(addressSpace)));
diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp
index b10c88a15..e83316dd3 100644
--- a/source/slang/slang-legalize-types.cpp
+++ b/source/slang/slang-legalize-types.cpp
@@ -896,24 +896,46 @@ static LegalType createLegalUniformBufferType(
// Create a pointer type with a given legalized value type.
static LegalType createLegalPtrType(
TypeLegalizationContext* context,
- IROp op,
+ IRInst* originalPtrType,
LegalType legalValueType)
{
switch (legalValueType.flavor)
{
case LegalType::Flavor::none:
+ if (auto ptrType = as<IRPtrType>(originalPtrType))
+ {
+ switch (ptrType->getAddressSpace())
+ {
+ case AddressSpace::UserPointer:
+ case AddressSpace::Global:
+ // If this is a physical pointer, we need to create an untyped pointer if
+ // the element type is nothing.
+ return LegalType::simple(
+ context->getBuilder()->getPtrTypeWithAddressSpace(
+ context->getBuilder()->getVoidType(),
+ ptrType));
+ }
+ }
return LegalType();
case LegalType::Flavor::simple:
{
- // Easy case: we just have a simple element type,
- // so we want to create a uniform buffer that wraps it.
+ // Easy case: we just have a simple element type.
+ if (auto ptrTypeBase = as<IRPtrTypeBase>(originalPtrType))
+ {
+ if (ptrTypeBase->hasAddressSpace())
+ {
+ return LegalType::simple(
+ context->getBuilder()->getPtrTypeWithAddressSpace(
+ legalValueType.getSimple(),
+ ptrTypeBase));
+ }
+ }
return LegalType::simple(createBuiltinGenericType(
context,
- op,
+ originalPtrType->getOp(),
legalValueType.getSimple()));
}
- break;
case LegalType::Flavor::implicitDeref:
{
@@ -936,7 +958,7 @@ static LegalType createLegalPtrType(
// will matter.
return LegalType::implicitDeref(createLegalPtrType(
context,
- op,
+ originalPtrType,
legalValueType.getImplicitDeref()->valueType));
}
break;
@@ -948,11 +970,11 @@ static LegalType createLegalPtrType(
auto ordinaryType = createLegalPtrType(
context,
- op,
+ originalPtrType,
pairType->ordinaryType);
auto specialType = createLegalPtrType(
context,
- op,
+ originalPtrType,
pairType->specialType);
return LegalType::pair(ordinaryType, specialType, pairType->pairInfo);
@@ -974,7 +996,7 @@ static LegalType createLegalPtrType(
newElement.key = ee.key;
newElement.type = createLegalPtrType(
context,
- op,
+ originalPtrType,
ee.type);
ptrPseudoTupleType->elements.add(newElement);
@@ -1310,7 +1332,7 @@ LegalType legalizeTypeImpl(
if (legalValueType.flavor == LegalType::Flavor::simple &&
legalValueType.getSimple() == ptrType->getValueType())
return LegalType::simple(ptrType);
- return createLegalPtrType(context, ptrType->getOp(), legalValueType);
+ return createLegalPtrType(context, ptrType, legalValueType);
}
else if(auto structType = as<IRStructType>(type))
{