summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-lower-buffer-element-type.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-03-06 14:26:34 -0800
committerGitHub <noreply@github.com>2025-03-06 14:26:34 -0800
commit4485cf3eaf142cfd5f8470e86739acc67d4e12ea (patch)
treec6ce220dfe5f3ab25ea558f2512f3761c9565c69 /source/slang/slang-ir-lower-buffer-element-type.cpp
parent55dd2deaff82bbdb72e125ba4b350030b7e5f427 (diff)
Update SPIRV-Tools and fix new validation errors. (#6511)
* Update SPIRV-Tools and fix new validation errors. * Implement pointers for glsl target. * Reworked packStorage/unpackStorage code gen to operate on pointers rather than values.
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp431
1 files changed, 260 insertions, 171 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 74e84f1ee..6f0e22a57 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -58,14 +58,39 @@ struct LoweredElementTypeContext
this->op = irop;
return *this;
}
- IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operand)
+ IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr)
{
if (!*this)
- return operand;
+ return builder.emitLoad(operandAddr);
if (kind == ConversionMethodKind::Func)
- return builder.emitCallInst(resultType, func, 1, &operand);
+ return builder.emitCallInst(resultType, func, 1, &operandAddr);
else
- return builder.emitIntrinsicInst(resultType, op, 1, &operand);
+ {
+ auto val = builder.emitLoad(operandAddr);
+ return builder.emitIntrinsicInst(resultType, op, 1, &val);
+ }
+ }
+ void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand)
+ {
+ if (!*this)
+ {
+ builder.emitStore(dest, operand);
+ return;
+ }
+ if (kind == ConversionMethodKind::Func)
+ {
+ IRInst* operands[] = {dest, operand};
+ builder.emitCallInst(builder.getVoidType(), func, 2, operands);
+ }
+ else
+ {
+ auto val = builder.emitIntrinsicInst(
+ tryGetPointedToType(&builder, dest->getDataType()),
+ op,
+ 1,
+ &operand);
+ builder.emitStore(dest, val);
+ }
}
};
@@ -131,21 +156,23 @@ struct LoweredElementTypeContext
IRFunc* createMatrixUnpackFunc(
IRMatrixType* matrixType,
IRStructType* structType,
- IRStructKey* dataKey,
- IRArrayType* arrayType)
+ IRStructKey* dataKey)
{
IRBuilder builder(structType);
builder.setInsertAfter(structType);
auto func = builder.createFunc();
- auto funcType = builder.getFuncType(1, (IRType**)&structType, matrixType);
+ auto refStructType = builder.getRefType(structType, AddressSpace::Generic);
+ auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType);
func->setFullType(funcType);
builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage"));
+ builder.addForceInlineDecoration(func);
builder.setInsertInto(func);
builder.emitBlock();
auto rowCount = (Index)getIntVal(matrixType->getRowCount());
auto colCount = (Index)getIntVal(matrixType->getColumnCount());
- auto packedParam = builder.emitParam(structType);
- auto vectorArray = builder.emitFieldExtract(arrayType, packedParam, dataKey);
+ auto packedParamRef = builder.emitParam(refStructType);
+ auto packedParam = builder.emitLoad(packedParamRef);
+ auto vectorArray = builder.emitFieldExtract(packedParam, dataKey);
List<IRInst*> args;
args.setCount(rowCount * colCount);
if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
@@ -187,13 +214,17 @@ struct LoweredElementTypeContext
IRBuilder builder(structType);
builder.setInsertAfter(structType);
auto func = builder.createFunc();
- auto funcType = builder.getFuncType(1, (IRType**)&matrixType, structType);
+ auto outStructType = builder.getRefType(structType, AddressSpace::Generic);
+ IRType* paramTypes[] = {outStructType, matrixType};
+ auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType());
func->setFullType(funcType);
builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix"));
+ builder.addForceInlineDecoration(func);
builder.setInsertInto(func);
builder.emitBlock();
auto rowCount = getIntVal(matrixType->getRowCount());
auto colCount = getIntVal(matrixType->getColumnCount());
+ auto outParam = builder.emitParam(outStructType);
auto originalParam = builder.emitParam(matrixType);
List<IRInst*> elements;
elements.setCount((Index)(rowCount * colCount));
@@ -255,7 +286,8 @@ struct LoweredElementTypeContext
auto vectorArray =
builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer());
auto result = builder.emitMakeStruct(structType, 1, &vectorArray);
- builder.emitReturn(result);
+ builder.emitStore(outParam, result);
+ builder.emitReturn();
return func;
}
@@ -263,19 +295,20 @@ struct LoweredElementTypeContext
IRArrayType* arrayType,
IRStructType* structType,
IRStructKey* dataKey,
- IRArrayType* innerArrayType,
LoweredElementTypeInfo innerTypeInfo)
{
IRBuilder builder(structType);
builder.setInsertAfter(structType);
auto func = builder.createFunc();
- auto funcType = builder.getFuncType(1, (IRType**)&structType, arrayType);
+ auto refStructType = builder.getRefType(structType, AddressSpace::Generic);
+ auto funcType = builder.getFuncType(1, (IRType**)&refStructType, arrayType);
func->setFullType(funcType);
builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage"));
+ builder.addForceInlineDecoration(func);
builder.setInsertInto(func);
builder.emitBlock();
- auto packedParam = builder.emitParam(structType);
- auto packedArray = builder.emitFieldExtract(innerArrayType, packedParam, dataKey);
+ auto packedParam = builder.emitParam(refStructType);
+ auto packedArray = builder.emitFieldAddress(packedParam, dataKey);
auto count = getIntVal(arrayType->getElementCount());
IRInst* result = nullptr;
if (count <= kMaxArraySizeToUnroll)
@@ -285,11 +318,11 @@ struct LoweredElementTypeContext
args.setCount((Index)count);
for (IRIntegerValue ii = 0; ii < count; ++ii)
{
- auto packedElement = builder.emitElementExtract(packedArray, ii);
+ auto packedElementAddr = builder.emitElementAddress(packedArray, ii);
auto originalElement = innerTypeInfo.convertLoweredToOriginal.apply(
builder,
innerTypeInfo.originalType,
- packedElement);
+ packedElementAddr);
args[(Index)ii] = originalElement;
}
result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());
@@ -308,11 +341,11 @@ struct LoweredElementTypeContext
loopBreakBlock);
builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst());
- auto packedElement = builder.emitElementExtract(packedArray, loopParam);
+ auto packedElementAddr = builder.emitElementAddress(packedArray, loopParam);
auto originalElement = innerTypeInfo.convertLoweredToOriginal.apply(
builder,
innerTypeInfo.originalType,
- packedElement);
+ packedElementAddr);
auto varPtr = builder.emitElementAddress(resultVar, loopParam);
builder.emitStore(varPtr, originalElement);
builder.setInsertInto(loopBreakBlock);
@@ -325,20 +358,24 @@ struct LoweredElementTypeContext
IRFunc* createArrayPackFunc(
IRArrayType* arrayType,
IRStructType* structType,
- IRArrayType* innerArrayType,
+ IRStructKey* arrayStructKey,
LoweredElementTypeInfo innerTypeInfo)
{
IRBuilder builder(structType);
builder.setInsertAfter(structType);
auto func = builder.createFunc();
- auto funcType = builder.getFuncType(1, (IRType**)&arrayType, structType);
+ auto outLoweredType = builder.getRefType(structType, AddressSpace::Generic);
+ IRType* paramTypes[] = {outLoweredType, structType};
+ auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType());
func->setFullType(funcType);
builder.addNameHintDecoration(func, UnownedStringSlice("packStorage"));
+ builder.addForceInlineDecoration(func);
builder.setInsertInto(func);
builder.emitBlock();
+ auto outParam = builder.emitParam(outLoweredType);
auto originalParam = builder.emitParam(arrayType);
- IRInst* packedArray = nullptr;
auto count = getIntVal(arrayType->getElementCount());
+ auto destArray = builder.emitFieldAddress(outParam, arrayStructKey);
if (count <= kMaxArraySizeToUnroll)
{
// If the array is small enough, just process each element directly.
@@ -347,19 +384,16 @@ struct LoweredElementTypeContext
for (IRIntegerValue ii = 0; ii < count; ++ii)
{
auto originalElement = builder.emitElementExtract(originalParam, ii);
- auto packedElement = innerTypeInfo.convertOriginalToLowered.apply(
+ auto destArrayElement = builder.emitElementAddress(destArray, ii);
+ innerTypeInfo.convertOriginalToLowered.applyDestinationDriven(
builder,
- innerTypeInfo.loweredType,
+ destArrayElement,
originalElement);
- args[(Index)ii] = packedElement;
}
- packedArray =
- builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer());
}
else
{
// The general case for large arrays is to emit a loop through the elements.
- IRVar* packedArrayVar = builder.emitVar(innerArrayType);
IRBlock* loopBodyBlock;
IRBlock* loopBreakBlock;
auto loopParam = emitLoopBlocks(
@@ -371,18 +405,14 @@ struct LoweredElementTypeContext
builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst());
auto originalElement = builder.emitElementExtract(originalParam, loopParam);
- auto packedElement = innerTypeInfo.convertOriginalToLowered.apply(
+ auto varPtr = builder.emitElementAddress(destArray, loopParam);
+ innerTypeInfo.convertOriginalToLowered.applyDestinationDriven(
builder,
- innerTypeInfo.loweredType,
+ varPtr,
originalElement);
- auto varPtr = builder.emitElementAddress(packedArrayVar, loopParam);
- builder.emitStore(varPtr, packedElement);
builder.setInsertInto(loopBreakBlock);
- packedArray = builder.emitLoad(packedArrayVar);
}
-
- auto result = builder.emitMakeStruct(structType, 1, &packedArray);
- builder.emitReturn(result);
+ builder.emitReturn();
return func;
}
@@ -451,6 +481,8 @@ struct LoweredElementTypeContext
}
auto loweredType = builder.createStructType();
+ builder.addPhysicalTypeDecoration(loweredType);
+
StringBuilder nameSB;
bool isColMajor =
getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
@@ -494,14 +526,14 @@ struct LoweredElementTypeContext
info.loweredInnerArrayType = arrayType;
info.loweredInnerStructKey = structKey;
info.convertLoweredToOriginal =
- createMatrixUnpackFunc(matrixType, loweredType, structKey, arrayType);
+ createMatrixUnpackFunc(matrixType, loweredType, structKey);
info.convertOriginalToLowered =
createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType);
return info;
}
- else if (auto arrayType = as<IRArrayType>(type))
+ else if (auto arrayTypeBase = as<IRArrayTypeBase>(type))
{
- auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), config);
+ auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayTypeBase->getElementType(), config);
if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 &&
options.use16ByteArrayElementForConstantBuffer)
@@ -560,42 +592,59 @@ struct LoweredElementTypeContext
}
}
- auto loweredType = builder.createStructType();
- info.loweredType = loweredType;
- StringBuilder nameSB;
- nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_";
- getTypeNameHint(nameSB, arrayType->getElementType());
- nameSB << getIntVal(arrayType->getElementCount());
- builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
- auto structKey = builder.createStructKey();
- builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
- IRSizeAndAlignment elementSizeAlignment;
- getSizeAndAlignment(
- target->getOptionSet(),
- config.layoutRule,
- loweredInnerTypeInfo.loweredType,
- &elementSizeAlignment);
- elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment);
- auto innerArrayType = builder.getArrayType(
- loweredInnerTypeInfo.loweredType,
- arrayType->getElementCount(),
- builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride()));
- builder.createStructField(loweredType, structKey, innerArrayType);
- info.loweredInnerArrayType = innerArrayType;
- info.loweredInnerStructKey = structKey;
- info.convertLoweredToOriginal = createArrayUnpackFunc(
- arrayType,
- loweredType,
- structKey,
- innerArrayType,
- loweredInnerTypeInfo);
- info.convertOriginalToLowered =
- createArrayPackFunc(arrayType, loweredType, innerArrayType, loweredInnerTypeInfo);
- return info;
- }
- else if (as<IRArrayTypeBase>(type))
- {
- info.loweredType = builder.getVoidType();
+ auto arrayType = as<IRArrayType>(arrayTypeBase);
+ if (arrayType)
+ {
+ auto loweredType = builder.createStructType();
+ builder.addPhysicalTypeDecoration(loweredType);
+
+ info.loweredType = loweredType;
+ StringBuilder nameSB;
+ nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_";
+ getTypeNameHint(nameSB, arrayType->getElementType());
+ nameSB << getIntVal(arrayType->getElementCount());
+ builder.addNameHintDecoration(
+ loweredType,
+ nameSB.produceString().getUnownedSlice());
+ auto structKey = builder.createStructKey();
+ builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
+ IRSizeAndAlignment elementSizeAlignment;
+ getSizeAndAlignment(
+ target->getOptionSet(),
+ config.layoutRule,
+ loweredInnerTypeInfo.loweredType,
+ &elementSizeAlignment);
+ elementSizeAlignment =
+ config.layoutRule->alignCompositeElement(elementSizeAlignment);
+ auto innerArrayType = builder.getArrayType(
+ loweredInnerTypeInfo.loweredType,
+ arrayType->getElementCount(),
+ builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride()));
+ builder.createStructField(loweredType, structKey, innerArrayType);
+ info.loweredInnerArrayType = innerArrayType;
+ info.loweredInnerStructKey = structKey;
+ info.convertLoweredToOriginal =
+ createArrayUnpackFunc(arrayType, loweredType, structKey, loweredInnerTypeInfo);
+ info.convertOriginalToLowered =
+ createArrayPackFunc(arrayType, loweredType, structKey, loweredInnerTypeInfo);
+ }
+ else
+ {
+ IRSizeAndAlignment elementSizeAlignment;
+ getSizeAndAlignment(
+ target->getOptionSet(),
+ config.layoutRule,
+ loweredInnerTypeInfo.loweredType,
+ &elementSizeAlignment);
+ elementSizeAlignment =
+ config.layoutRule->alignCompositeElement(elementSizeAlignment);
+ auto innerArrayType = builder.getArrayTypeBase(
+ arrayTypeBase->getOp(),
+ loweredInnerTypeInfo.loweredType,
+ nullptr,
+ builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride()));
+ info.loweredType = innerArrayType;
+ }
return info;
}
else if (auto structType = as<IRStructType>(type))
@@ -625,6 +674,8 @@ struct LoweredElementTypeContext
}
}
auto loweredType = builder.createStructType();
+ builder.addPhysicalTypeDecoration(loweredType);
+
StringBuilder nameSB;
getTypeNameHint(nameSB, type);
nameSB << "_" << getLayoutName(config.layoutRule->ruleName);
@@ -635,12 +686,15 @@ struct LoweredElementTypeContext
Index fieldId = 0;
for (auto field : structType->getFields())
{
- if (as<IRVoidType>(fieldLoweredTypeInfo[fieldId].loweredType))
+ auto& loweredFieldTypeInfo = fieldLoweredTypeInfo[fieldId];
+ // When lowering type for user pointer, skip fields that are unsized array.
+ if (config.addressSpace == AddressSpace::UserPointer &&
+ as<IRUnsizedArrayType>(loweredFieldTypeInfo.loweredType))
{
fieldId++;
+ loweredFieldTypeInfo.loweredType = builder.getVoidType();
continue;
}
- auto loweredFieldTypeInfo = fieldLoweredTypeInfo[fieldId];
builder.createStructField(
loweredType,
field->getKey(),
@@ -657,10 +711,12 @@ struct LoweredElementTypeContext
builder.addNameHintDecoration(
info.convertLoweredToOriginal.func,
UnownedStringSlice("unpackStorage"));
+ builder.addForceInlineDecoration(info.convertLoweredToOriginal.func);
+ auto refLoweredType = builder.getRefType(loweredType, AddressSpace::Generic);
info.convertLoweredToOriginal.func->setFullType(
- builder.getFuncType(1, (IRType**)&loweredType, type));
+ builder.getFuncType(1, (IRType**)&refLoweredType, type));
builder.emitBlock();
- auto loweredParam = builder.emitParam(loweredType);
+ auto loweredParam = builder.emitParam(refLoweredType);
List<IRInst*> args;
Index fieldId = 0;
for (auto field : structType->getFields())
@@ -670,10 +726,7 @@ struct LoweredElementTypeContext
fieldId++;
continue;
}
- auto storageField = builder.emitFieldExtract(
- fieldLoweredTypeInfo[fieldId].loweredType,
- loweredParam,
- field->getKey());
+ auto storageField = builder.emitFieldAddress(loweredParam, field->getKey());
auto unpackedField =
fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal.apply(
builder,
@@ -694,9 +747,14 @@ struct LoweredElementTypeContext
builder.addNameHintDecoration(
info.convertOriginalToLowered.func,
UnownedStringSlice("packStorage"));
+ builder.addForceInlineDecoration(info.convertOriginalToLowered.func);
+
+ auto outLoweredType = builder.getRefType(loweredType, AddressSpace::Generic);
+ IRType* paramTypes[] = {outLoweredType, type};
info.convertOriginalToLowered.func->setFullType(
- builder.getFuncType(1, (IRType**)&type, loweredType));
+ builder.getFuncType(2, paramTypes, builder.getVoidType()));
builder.emitBlock();
+ auto outParam = builder.emitParam(outLoweredType);
auto param = builder.emitParam(type);
List<IRInst*> args;
Index fieldId = 0;
@@ -709,15 +767,15 @@ struct LoweredElementTypeContext
}
auto fieldVal =
builder.emitFieldExtract(field->getFieldType(), param, field->getKey());
- auto packedField = fieldLoweredTypeInfo[fieldId].convertOriginalToLowered.apply(
+ auto destAddr = builder.emitFieldAddress(outParam, field->getKey());
+
+ fieldLoweredTypeInfo[fieldId].convertOriginalToLowered.applyDestinationDriven(
builder,
- fieldLoweredTypeInfo[fieldId].loweredType,
+ destAddr,
fieldVal);
- args.add(packedField);
fieldId++;
}
- auto result = builder.emitMakeStruct(loweredType, args);
- builder.emitReturn(result);
+ builder.emitReturn();
}
return info;
@@ -743,37 +801,8 @@ struct LoweredElementTypeContext
info.loweredType = builder.getVectorType(
info.loweredType,
vectorType->getElementCount());
- // Create unpack func.
- {
- builder.setInsertAfter(type);
- info.convertLoweredToOriginal = builder.createFunc();
- builder.setInsertInto(info.convertLoweredToOriginal.func);
- builder.addNameHintDecoration(
- info.convertLoweredToOriginal.func,
- UnownedStringSlice("unpackStorage"));
- info.convertLoweredToOriginal.func->setFullType(
- builder.getFuncType(1, (IRType**)&info.loweredType, type));
- builder.emitBlock();
- auto loweredParam = builder.emitParam(info.loweredType);
- auto result = builder.emitCast(type, loweredParam);
- builder.emitReturn(result);
- }
-
- // Create pack func.
- {
- builder.setInsertAfter(info.convertLoweredToOriginal.func);
- info.convertOriginalToLowered = builder.createFunc();
- builder.setInsertInto(info.convertOriginalToLowered.func);
- builder.addNameHintDecoration(
- info.convertOriginalToLowered.func,
- UnownedStringSlice("packStorage"));
- info.convertOriginalToLowered.func->setFullType(
- builder.getFuncType(1, (IRType**)&type, info.loweredType));
- builder.emitBlock();
- auto param = builder.emitParam(type);
- auto result = builder.emitCast(info.loweredType, param);
- builder.emitReturn(result);
- }
+ info.convertLoweredToOriginal = kIROp_BuiltinCast;
+ info.convertOriginalToLowered = kIROp_BuiltinCast;
return info;
}
}
@@ -828,7 +857,8 @@ struct LoweredElementTypeContext
IRType* getLoweredPtrLikeType(IRType* originalPtrLikeType, IRType* newElementType)
{
if (as<IRPointerLikeType>(originalPtrLikeType) || as<IRPtrTypeBase>(originalPtrLikeType) ||
- as<IRHLSLStructuredBufferTypeBase>(originalPtrLikeType))
+ as<IRHLSLStructuredBufferTypeBase>(originalPtrLikeType) ||
+ as<IRGLSLShaderStorageBufferType>(originalPtrLikeType))
{
IRBuilder builder(newElementType);
builder.setInsertAfter(newElementType);
@@ -859,6 +889,26 @@ struct LoweredElementTypeContext
TypeLoweringConfig config;
};
+ IRInst* getBufferAddr(IRBuilder& builder, IRInst* loadStoreInst)
+ {
+ switch (loadStoreInst->getOp())
+ {
+ case kIROp_Load:
+ case kIROp_Store:
+ return loadStoreInst->getOperand(0);
+ case kIROp_StructuredBufferLoad:
+ case kIROp_StructuredBufferLoadStatus:
+ case kIROp_RWStructuredBufferLoad:
+ case kIROp_RWStructuredBufferLoadStatus:
+ case kIROp_RWStructuredBufferStore:
+ return builder.emitRWStructuredBufferGetElementPtr(
+ loadStoreInst->getOperand(0),
+ loadStoreInst->getOperand(1));
+ default:
+ return nullptr;
+ }
+ }
+
void processModule(IRModule* module)
{
IRBuilder builder(module);
@@ -891,6 +941,8 @@ struct LoweredElementTypeContext
elementType = structBuffer->getElementType();
else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst))
elementType = constBuffer->getElementType();
+ else if (auto storageBuffer = as<IRGLSLShaderStorageBufferType>(globalInst))
+ elementType = storageBuffer->getElementType();
if (as<IRTextureBufferType>(globalInst))
continue;
if (!as<IRStructType>(elementType) && !as<IRMatrixType>(elementType) &&
@@ -908,6 +960,10 @@ struct LoweredElementTypeContext
{
auto bufferType = bufferTypeInfo.bufferType;
auto elementType = bufferTypeInfo.elementType;
+
+ if (elementType->findDecoration<IRPhysicalTypeDecoration>())
+ continue;
+
auto config = getTypeLoweringConfigForBuffer(target, bufferType);
auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, config);
@@ -954,8 +1010,13 @@ struct LoweredElementTypeContext
// getOffsetPtr(trailingPtr, index).
if (auto fieldAddr = as<IRFieldAddress>(ptrVal))
{
- if (auto ptrType = as<IRPtrType>(ptrVal->getDataType()))
+ auto handleUnsizedArrayAccess = [&]() -> bool
{
+ auto ptrType = as<IRPtrType>(ptrVal->getDataType());
+ if (!ptrType)
+ return false;
+ if (ptrType->getAddressSpace() != AddressSpace::UserPointer)
+ return false;
if (auto unsizedArrayType = as<IRUnsizedArrayType>(ptrType->getValueType()))
{
builder.setInsertBefore(ptrVal);
@@ -1019,9 +1080,12 @@ struct LoweredElementTypeContext
});
SLANG_ASSERT(!ptrVal->hasUses());
ptrVal->removeAndDeallocate();
- continue;
+ return true;
}
- }
+ return false;
+ };
+ if (handleUnsizedArrayAccess())
+ continue;
}
LoweredElementTypeInfo loweredElementTypeInfo = {};
@@ -1060,7 +1124,7 @@ struct LoweredElementTypeContext
getLoweredTypeInfo((IRType*)originalElementType, config);
}
- if (!loweredElementTypeInfo.convertLoweredToOriginal)
+ if (loweredElementTypeInfo.loweredType == loweredElementTypeInfo.originalType)
continue;
ptrVal->setFullType(getLoweredPtrLikeType(
@@ -1083,15 +1147,28 @@ struct LoweredElementTypeContext
case kIROp_RWStructuredBufferLoadStatus:
case kIROp_StructuredBufferConsume:
{
- IRCloneEnv cloneEnv = {};
builder.setInsertBefore(user);
- auto newLoad = cloneInst(&cloneEnv, &builder, user);
- newLoad->setFullType(loweredElementTypeInfo.loweredType);
+ auto addr = getBufferAddr(builder, user);
+ if (!addr)
+ {
+ IRCloneEnv cloneEnv = {};
+ builder.setInsertBefore(user);
+ auto newLoad = cloneInst(&cloneEnv, &builder, user);
+ newLoad->setFullType(loweredElementTypeInfo.loweredType);
+ addr = builder.emitVar(loweredElementTypeInfo.loweredType);
+ builder.emitStore(addr, newLoad);
+ }
+ if (auto alignedAttr = user->findAttr<IRAlignedAttr>())
+ {
+ builder.addAlignedAddressDecoration(
+ addr,
+ alignedAttr->getAlignment());
+ }
auto unpackedVal =
loweredElementTypeInfo.convertLoweredToOriginal.apply(
builder,
loweredElementTypeInfo.originalType,
- newLoad);
+ addr);
user->replaceUsesWith(unpackedVal);
user->removeAndDeallocate();
break;
@@ -1106,19 +1183,33 @@ struct LoweredElementTypeContext
IRCloneEnv cloneEnv = {};
builder.setInsertBefore(user);
auto originalVal = getStoreVal(user);
- auto packedVal =
- loweredElementTypeInfo.convertOriginalToLowered.apply(
- builder,
- loweredElementTypeInfo.loweredType,
- originalVal);
- if (auto store = as<IRStore>(user))
- store->val.set(packedVal);
- else if (auto sbStore = as<IRRWStructuredBufferStore>(user))
- sbStore->setOperand(2, packedVal);
+ IRInst* addr = getBufferAddr(builder, user);
+ if (addr)
+ {
+ if (auto alignedAttr = user->findAttr<IRAlignedAttr>())
+ {
+ builder.addAlignedAddressDecoration(
+ addr,
+ alignedAttr->getAlignment());
+ }
+
+ loweredElementTypeInfo.convertOriginalToLowered
+ .applyDestinationDriven(builder, addr, originalVal);
+ user->removeAndDeallocate();
+ }
else if (auto sbAppend = as<IRStructuredBufferAppend>(user))
+ {
+ builder.setInsertBefore(sbAppend);
+ addr = builder.emitVar(loweredElementTypeInfo.loweredType);
+ loweredElementTypeInfo.convertOriginalToLowered
+ .applyDestinationDriven(builder, addr, originalVal);
+ auto packedVal = builder.emitLoad(addr);
sbAppend->setOperand(1, packedVal);
+ }
else
+ {
SLANG_UNREACHABLE("unhandled store type");
+ }
break;
}
case kIROp_GetElementPtr:
@@ -1176,24 +1267,18 @@ struct LoweredElementTypeContext
// access, we need to materialize the object as a local variable,
// and pass the address of the local variable to the function.
builder.setInsertBefore(user);
- auto newLoad =
- builder.emitLoad(loweredElementTypeInfo.loweredType, ptrVal);
auto unpackedVal =
loweredElementTypeInfo.convertLoweredToOriginal.apply(
builder,
(IRType*)originalElementType,
- newLoad);
+ ptrVal);
auto var = builder.emitVar((IRType*)originalElementType);
builder.emitStore(var, unpackedVal);
use->set(var);
builder.setInsertAfter(user);
auto newVal = builder.emitLoad(var);
- auto packedVal =
- loweredElementTypeInfo.convertOriginalToLowered.apply(
- builder,
- (IRType*)loweredElementTypeInfo.loweredType,
- newVal);
- builder.emitStore(ptrVal, packedVal);
+ loweredElementTypeInfo.convertOriginalToLowered
+ .applyDestinationDriven(builder, ptrVal, newVal);
}
break;
default:
@@ -1355,6 +1440,21 @@ void lowerBufferElementTypeToStorageType(
context.processModule(module);
}
+IRTypeLayoutRules* getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRules* defaultLayout)
+{
+ switch (layoutTypeOp)
+ {
+ case kIROp_DefaultBufferLayoutType:
+ return defaultLayout;
+ case kIROp_Std140BufferLayoutType:
+ return IRTypeLayoutRules::getStd140();
+ case kIROp_Std430BufferLayoutType:
+ return IRTypeLayoutRules::getStd430();
+ case kIROp_ScalarBufferLayoutType:
+ return IRTypeLayoutRules::getNatural();
+ }
+ return defaultLayout;
+}
IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType)
{
@@ -1395,18 +1495,7 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf
auto layoutTypeOp = structBufferType->getDataLayout()
? structBufferType->getDataLayout()->getOp()
: kIROp_DefaultBufferLayoutType;
- switch (layoutTypeOp)
- {
- case kIROp_DefaultBufferLayoutType:
- return IRTypeLayoutRules::getStd430();
- case kIROp_Std140BufferLayoutType:
- return IRTypeLayoutRules::getStd140();
- case kIROp_Std430BufferLayoutType:
- return IRTypeLayoutRules::getStd430();
- case kIROp_ScalarBufferLayoutType:
- return IRTypeLayoutRules::getNatural();
- }
- return IRTypeLayoutRules::getStd430();
+ return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430());
}
case kIROp_ConstantBufferType:
case kIROp_ParameterBlockType:
@@ -1416,18 +1505,15 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf
auto layoutTypeOp = parameterGroupType->getDataLayout()
? parameterGroupType->getDataLayout()->getOp()
: kIROp_DefaultBufferLayoutType;
- switch (layoutTypeOp)
- {
- case kIROp_DefaultBufferLayoutType:
- return IRTypeLayoutRules::getStd140();
- case kIROp_Std140BufferLayoutType:
- return IRTypeLayoutRules::getStd140();
- case kIROp_Std430BufferLayoutType:
- return IRTypeLayoutRules::getStd430();
- case kIROp_ScalarBufferLayoutType:
- return IRTypeLayoutRules::getNatural();
- }
- return IRTypeLayoutRules::getStd140();
+ return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd140());
+ }
+ case kIROp_GLSLShaderStorageBufferType:
+ {
+ auto storageBufferType = as<IRGLSLShaderStorageBufferType>(bufferType);
+ auto layoutTypeOp = storageBufferType->getDataLayout()
+ ? storageBufferType->getDataLayout()->getOp()
+ : kIROp_Std430BufferLayoutType;
+ return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430());
}
case kIROp_PtrType:
return IRTypeLayoutRules::getNatural();
@@ -1446,6 +1532,9 @@ TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType*
case AddressSpace::Output:
addrSpace = AddressSpace::Input;
break;
+ case AddressSpace::UserPointer:
+ addrSpace = AddressSpace::UserPointer;
+ break;
}
}
auto rules = getTypeLayoutRuleForBuffer(target, bufferType);