summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-11-06 10:58:09 -0800
committerGitHub <noreply@github.com>2024-11-06 10:58:09 -0800
commitb86703432629bbfd75a902671d15e40c591065a7 (patch)
tree2109f87304fa534034f01812551f29be098c4710 /source
parentf8294202ce8d5658f6308eeaed634058db9bbb4b (diff)
[WGSL] Enable arbitrary arrays in uniform buffers. (#5497)
* [WGSL] Enable arbitrary arrays in uniform buffers. * format code * Undo irrelevant change and fixups. * Update expected failure list. * Fix. * Rename. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit.cpp5
-rw-r--r--source/slang/slang-ir-insts.h1
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp353
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.h10
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp7
-rw-r--r--source/slang/slang-ir.cpp5
6 files changed, 281 insertions, 100 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index c23195f7c..1950f251c 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -1455,7 +1455,10 @@ Result linkAndOptimizeIR(
if (requiredLoweringPassSet.meshOutput)
legalizeMeshOutputTypes(irModule);
- lowerBufferElementTypeToStorageType(targetProgram, irModule);
+ BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions;
+ bufferElementTypeLoweringOptions.use16ByteArrayElementForConstantBuffer =
+ isWGPUTarget(targetRequest);
+ lowerBufferElementTypeToStorageType(targetProgram, irModule, bufferElementTypeLoweringOptions);
// Rewrite functions that return arrays to return them via `out` parameter,
// since our target languages doesn't allow returning arrays.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 9a081f9de..4baa786e3 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3566,6 +3566,7 @@ public:
IRInst* replaceOperand(IRUse* use, IRInst* newValue);
IRInst* getBoolValue(bool value);
+ IRInst* getIntValue(IRIntegerValue value);
IRInst* getIntValue(IRType* type, IRIntegerValue value);
IRInst* getFloatValue(IRType* type, IRFloatingPointValue value);
IRStringLit* getStringValue(const UnownedStringSlice& slice);
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 7d722bf9c..f9017ebe1 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -12,6 +12,47 @@ struct LoweredElementTypeContext
{
static const IRIntegerValue kMaxArraySizeToUnroll = 32;
+ enum ConversionMethodKind
+ {
+ Func,
+ Opcode
+ };
+ struct ConversionMethod
+ {
+ ConversionMethodKind kind = ConversionMethodKind::Func;
+ union
+ {
+ IRFunc* func;
+ IROp op;
+ };
+ ConversionMethod() { func = nullptr; }
+ operator bool()
+ {
+ return kind == ConversionMethodKind::Func ? func != nullptr : op != kIROp_Nop;
+ }
+ ConversionMethod& operator=(IRFunc* f)
+ {
+ kind = ConversionMethodKind::Func;
+ this->func = f;
+ return *this;
+ }
+ ConversionMethod& operator=(IROp irop)
+ {
+ kind = ConversionMethodKind::Opcode;
+ this->op = irop;
+ return *this;
+ }
+ IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operand)
+ {
+ if (!*this)
+ return operand;
+ if (kind == ConversionMethodKind::Func)
+ return builder.emitCallInst(resultType, func, 1, &operand);
+ else
+ return builder.emitIntrinsicInst(resultType, op, 1, &operand);
+ }
+ };
+
struct LoweredElementTypeInfo
{
IRType* originalType;
@@ -22,25 +63,48 @@ struct LoweredElementTypeContext
IRStructKey* loweredInnerStructKey =
nullptr; // For matrix/array types that are lowered into a struct type, this is the
// struct key of the data field.
- IRFunc* convertOriginalToLowered = nullptr;
- IRFunc* convertLoweredToOriginal = nullptr;
+ ConversionMethod convertOriginalToLowered;
+ ConversionMethod convertLoweredToOriginal;
};
Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo[(int)IRTypeLayoutRuleName::_Count];
Dictionary<IRType*, LoweredElementTypeInfo>
mapLoweredTypeToInfo[(int)IRTypeLayoutRuleName::_Count];
+ struct ConversionMethodKey
+ {
+ IRType* toType;
+ IRType* fromType;
+ bool operator==(const ConversionMethodKey& other) const
+ {
+ return toType == other.toType && fromType == other.fromType;
+ }
+ HashCode64 getHashCode() const
+ {
+ return combineHash(Slang::getHashCode(toType), Slang::getHashCode(fromType));
+ }
+ };
+
+ Dictionary<ConversionMethodKey, ConversionMethod> conversionMethodMap;
+ ConversionMethod getConversionMethod(IRType* toType, IRType* fromType)
+ {
+ ConversionMethodKey key;
+ key.toType = toType;
+ key.fromType = fromType;
+ ConversionMethod method;
+ conversionMethodMap.tryGetValue(key, method);
+ return method;
+ }
+
SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
TargetProgram* target;
- bool lowerBufferPointer = false;
+ BufferElementTypeLoweringOptions options;
LoweredElementTypeContext(
TargetProgram* target,
- bool lowerBufferPointer,
+ BufferElementTypeLoweringOptions inOptions,
SlangMatrixLayoutMode inDefaultMatrixLayout)
- : target(target)
- , defaultMatrixLayout(inDefaultMatrixLayout)
- , lowerBufferPointer(lowerBufferPointer)
+ : target(target), defaultMatrixLayout(inDefaultMatrixLayout), options(inOptions)
{
}
@@ -133,6 +197,11 @@ struct LoweredElementTypeContext
auto element = elements[(Index)(r * colCount + c)];
vecArgs.add(element);
}
+ // Fill in default values for remaining elements in the vector.
+ for (IRIntegerValue r = rowCount; r < getIntVal(vectorType->getElementCount()); r++)
+ {
+ vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
+ }
auto colVector = builder.emitMakeVector(
vectorType,
(UInt)vecArgs.getCount(),
@@ -150,6 +219,11 @@ struct LoweredElementTypeContext
auto element = elements[(Index)(r * colCount + c)];
vecArgs.add(element);
}
+ // Fill in default values for remaining elements in the vector.
+ for (IRIntegerValue c = colCount; c < getIntVal(vectorType->getElementCount()); c++)
+ {
+ vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
+ }
auto rowVector = builder.emitMakeVector(
vectorType,
(UInt)vecArgs.getCount(),
@@ -192,13 +266,10 @@ struct LoweredElementTypeContext
for (IRIntegerValue ii = 0; ii < count; ++ii)
{
auto packedElement = builder.emitElementExtract(packedArray, ii);
- auto originalElement = innerTypeInfo.convertLoweredToOriginal
- ? builder.emitCallInst(
- innerTypeInfo.originalType,
- innerTypeInfo.convertLoweredToOriginal,
- 1,
- &packedElement)
- : packedElement;
+ auto originalElement = innerTypeInfo.convertLoweredToOriginal.apply(
+ builder,
+ innerTypeInfo.originalType,
+ packedElement);
args[(Index)ii] = originalElement;
}
result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());
@@ -218,13 +289,10 @@ struct LoweredElementTypeContext
builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst());
auto packedElement = builder.emitElementExtract(packedArray, loopParam);
- auto originalElement = innerTypeInfo.convertLoweredToOriginal
- ? builder.emitCallInst(
- innerTypeInfo.originalType,
- innerTypeInfo.convertLoweredToOriginal,
- 1,
- &packedElement)
- : packedElement;
+ auto originalElement = innerTypeInfo.convertLoweredToOriginal.apply(
+ builder,
+ innerTypeInfo.originalType,
+ packedElement);
auto varPtr = builder.emitElementAddress(resultVar, loopParam);
builder.emitStore(varPtr, originalElement);
builder.setInsertInto(loopBreakBlock);
@@ -259,13 +327,10 @@ struct LoweredElementTypeContext
for (IRIntegerValue ii = 0; ii < count; ++ii)
{
auto originalElement = builder.emitElementExtract(originalParam, ii);
- auto packedElement = innerTypeInfo.convertOriginalToLowered
- ? builder.emitCallInst(
- innerTypeInfo.loweredType,
- innerTypeInfo.convertOriginalToLowered,
- 1,
- &originalElement)
- : originalElement;
+ auto packedElement = innerTypeInfo.convertOriginalToLowered.apply(
+ builder,
+ innerTypeInfo.loweredType,
+ originalElement);
args[(Index)ii] = packedElement;
}
packedArray =
@@ -286,13 +351,10 @@ struct LoweredElementTypeContext
builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst());
auto originalElement = builder.emitElementExtract(originalParam, loopParam);
- auto packedElement = innerTypeInfo.convertOriginalToLowered
- ? builder.emitCallInst(
- innerTypeInfo.loweredType,
- innerTypeInfo.convertOriginalToLowered,
- 1,
- &originalElement)
- : originalElement;
+ auto packedElement = innerTypeInfo.convertOriginalToLowered.apply(
+ builder,
+ innerTypeInfo.loweredType,
+ originalElement);
auto varPtr = builder.emitElementAddress(packedArrayVar, loopParam);
builder.emitStore(varPtr, packedElement);
builder.setInsertInto(loopBreakBlock);
@@ -319,6 +381,17 @@ struct LoweredElementTypeContext
}
}
+ // Returns the number of elements N that ensures the IRVectorType(elementType,N)
+ // has 16-byte aligned size and N is no less than `minCount`.
+ IRIntegerValue get16ByteAlignedVectorElementCount(IRType* elementType, IRIntegerValue minCount)
+ {
+ IRSizeAndAlignment sizeAlignment;
+ getNaturalSizeAndAlignment(target->getOptionSet(), elementType, &sizeAlignment);
+ if (sizeAlignment.size)
+ return align(sizeAlignment.size * minCount, 16) / sizeAlignment.size;
+ return 4;
+ }
+
LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, IRTypeLayoutRules* rules)
{
IRBuilder builder(type);
@@ -357,9 +430,18 @@ struct LoweredElementTypeContext
builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
auto structKey = builder.createStructKey();
builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
- auto vectorType = builder.getVectorType(
- matrixType->getElementType(),
- isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount());
+ auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount();
+ if (rules->ruleName == IRTypeLayoutRuleName::Std140 &&
+ options.use16ByteArrayElementForConstantBuffer)
+ {
+ // For constant buffer layout, we need to use 16-byte aligned vector if
+ // we are required to ensure array element types has 16-byte stride.
+ vectorSize = builder.getIntValue(get16ByteAlignedVectorElementCount(
+ matrixType->getElementType(),
+ getIntVal(vectorSize)));
+ }
+
+ auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize);
IRSizeAndAlignment elementSizeAlignment;
getSizeAndAlignment(target->getOptionSet(), rules, vectorType, &elementSizeAlignment);
elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment);
@@ -382,6 +464,52 @@ struct LoweredElementTypeContext
else if (auto arrayType = as<IRArrayType>(type))
{
auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), rules);
+
+ if (rules->ruleName == IRTypeLayoutRuleName::Std140 &&
+ options.use16ByteArrayElementForConstantBuffer)
+ {
+ // For constant buffer layout, we need to use 16-byte-aligned vector if
+ // we are required to ensure array element types has 16-byte stride.
+ // We only need to handle the case where the element type is a scalar or vector
+ // type here, because if the element type is a matrix type or struct type,
+ // the size promotion will be handled during lowering of the element type.
+ IRType* packedVectorType = nullptr;
+ if (auto vectorType = as<IRVectorType>(loweredInnerTypeInfo.loweredType))
+ {
+ packedVectorType = builder.getVectorType(
+ vectorType->getElementType(),
+ builder.getIntValue(get16ByteAlignedVectorElementCount(
+ vectorType->getElementType(),
+ getIntVal(vectorType->getElementCount()))));
+ if (packedVectorType != loweredInnerTypeInfo.originalType)
+ {
+ loweredInnerTypeInfo.convertLoweredToOriginal = kIROp_VectorReshape;
+ loweredInnerTypeInfo.convertOriginalToLowered = kIROp_VectorReshape;
+ }
+ }
+ else if (auto scalarType = as<IRBasicType>(loweredInnerTypeInfo.loweredType))
+ {
+ packedVectorType = builder.getVectorType(
+ loweredInnerTypeInfo.loweredType,
+ get16ByteAlignedVectorElementCount(scalarType, 1));
+ loweredInnerTypeInfo.convertLoweredToOriginal = kIROp_VectorReshape;
+ loweredInnerTypeInfo.convertOriginalToLowered = kIROp_MakeVectorFromScalar;
+ }
+ if (packedVectorType)
+ {
+ loweredInnerTypeInfo.loweredType = packedVectorType;
+ if (loweredInnerTypeInfo.convertLoweredToOriginal)
+ conversionMethodMap[ConversionMethodKey{
+ packedVectorType,
+ loweredInnerTypeInfo.originalType}] =
+ loweredInnerTypeInfo.convertOriginalToLowered;
+ if (loweredInnerTypeInfo.convertOriginalToLowered)
+ conversionMethodMap[ConversionMethodKey{
+ loweredInnerTypeInfo.originalType,
+ packedVectorType}] = loweredInnerTypeInfo.convertLoweredToOriginal;
+ }
+ }
+
// For spirv backend, we always want to lower all array types, even if the element type
// comes out the same. This is because different layout rules may have different array
// stride requirements.
@@ -393,6 +521,7 @@ struct LoweredElementTypeContext
return info;
}
}
+
auto loweredType = builder.createStructType();
info.loweredType = loweredType;
StringBuilder nameSB;
@@ -486,11 +615,11 @@ struct LoweredElementTypeContext
{
builder.setInsertAfter(loweredType);
info.convertLoweredToOriginal = builder.createFunc();
- builder.setInsertInto(info.convertLoweredToOriginal);
+ builder.setInsertInto(info.convertLoweredToOriginal.func);
builder.addNameHintDecoration(
- info.convertLoweredToOriginal,
+ info.convertLoweredToOriginal.func,
UnownedStringSlice("unpackStorage"));
- info.convertLoweredToOriginal->setFullType(
+ info.convertLoweredToOriginal.func->setFullType(
builder.getFuncType(1, (IRType**)&loweredType, type));
builder.emitBlock();
auto loweredParam = builder.emitParam(loweredType);
@@ -508,13 +637,10 @@ struct LoweredElementTypeContext
loweredParam,
field->getKey());
auto unpackedField =
- fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal
- ? builder.emitCallInst(
- field->getFieldType(),
- fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal,
- 1,
- &storageField)
- : storageField;
+ fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal.apply(
+ builder,
+ field->getFieldType(),
+ storageField);
args.add(unpackedField);
fieldId++;
}
@@ -524,13 +650,13 @@ struct LoweredElementTypeContext
// Create pack func.
{
- builder.setInsertAfter(info.convertLoweredToOriginal);
+ builder.setInsertAfter(info.convertLoweredToOriginal.func);
info.convertOriginalToLowered = builder.createFunc();
- builder.setInsertInto(info.convertOriginalToLowered);
+ builder.setInsertInto(info.convertOriginalToLowered.func);
builder.addNameHintDecoration(
- info.convertOriginalToLowered,
+ info.convertOriginalToLowered.func,
UnownedStringSlice("packStorage"));
- info.convertOriginalToLowered->setFullType(
+ info.convertOriginalToLowered.func->setFullType(
builder.getFuncType(1, (IRType**)&type, loweredType));
builder.emitBlock();
auto param = builder.emitParam(type);
@@ -545,14 +671,10 @@ struct LoweredElementTypeContext
}
auto fieldVal =
builder.emitFieldExtract(field->getFieldType(), param, field->getKey());
- auto packedField =
- fieldLoweredTypeInfo[fieldId].convertOriginalToLowered
- ? builder.emitCallInst(
- fieldLoweredTypeInfo[fieldId].loweredType,
- fieldLoweredTypeInfo[fieldId].convertOriginalToLowered,
- 1,
- &fieldVal)
- : fieldVal;
+ auto packedField = fieldLoweredTypeInfo[fieldId].convertOriginalToLowered.apply(
+ builder,
+ fieldLoweredTypeInfo[fieldId].loweredType,
+ fieldVal);
args.add(packedField);
fieldId++;
}
@@ -587,11 +709,11 @@ struct LoweredElementTypeContext
{
builder.setInsertAfter(type);
info.convertLoweredToOriginal = builder.createFunc();
- builder.setInsertInto(info.convertLoweredToOriginal);
+ builder.setInsertInto(info.convertLoweredToOriginal.func);
builder.addNameHintDecoration(
- info.convertLoweredToOriginal,
+ info.convertLoweredToOriginal.func,
UnownedStringSlice("unpackStorage"));
- info.convertLoweredToOriginal->setFullType(
+ info.convertLoweredToOriginal.func->setFullType(
builder.getFuncType(1, (IRType**)&info.loweredType, type));
builder.emitBlock();
auto loweredParam = builder.emitParam(info.loweredType);
@@ -601,13 +723,13 @@ struct LoweredElementTypeContext
// Create pack func.
{
- builder.setInsertAfter(info.convertLoweredToOriginal);
+ builder.setInsertAfter(info.convertLoweredToOriginal.func);
info.convertOriginalToLowered = builder.createFunc();
- builder.setInsertInto(info.convertOriginalToLowered);
+ builder.setInsertInto(info.convertOriginalToLowered.func);
builder.addNameHintDecoration(
- info.convertOriginalToLowered,
+ info.convertOriginalToLowered.func,
UnownedStringSlice("packStorage"));
- info.convertOriginalToLowered->setFullType(
+ info.convertOriginalToLowered.func->setFullType(
builder.getFuncType(1, (IRType**)&type, info.loweredType));
builder.emitBlock();
auto param = builder.emitParam(type);
@@ -644,6 +766,8 @@ struct LoweredElementTypeContext
getSizeAndAlignment(target->getOptionSet(), rules, info.loweredType, &sizeAlignment);
loweredTypeInfo[(int)rules->ruleName].set(type, info);
mapLoweredTypeToInfo[(int)rules->ruleName].set(info.loweredType, info);
+ conversionMethodMap[{info.originalType, info.loweredType}] = info.convertLoweredToOriginal;
+ conversionMethodMap[{info.loweredType, info.originalType}] = info.convertOriginalToLowered;
return info;
}
@@ -693,7 +817,7 @@ struct LoweredElementTypeContext
for (auto globalInst : module->getGlobalInsts())
{
IRType* elementType = nullptr;
- if (lowerBufferPointer)
+ if (options.lowerBufferPointer)
{
if (auto ptrType = as<IRPtrType>(globalInst))
{
@@ -729,7 +853,8 @@ struct LoweredElementTypeContext
auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, layoutRules);
// If the lowered type is the same as original type, no change is required.
- if (!loweredBufferElementTypeInfo.convertLoweredToOriginal)
+ if (loweredBufferElementTypeInfo.loweredType ==
+ loweredBufferElementTypeInfo.originalType)
continue;
builder.setInsertBefore(bufferType);
@@ -862,8 +987,41 @@ struct LoweredElementTypeContext
}
}
- auto loweredElementTypeInfo =
- getLoweredTypeInfo((IRType*)originalElementType, layoutRules);
+ LoweredElementTypeInfo loweredElementTypeInfo = {};
+ if (auto getElementPtr = as<IRGetElementPtr>(ptrVal))
+ {
+ if (auto arrayType = as<IRArrayTypeBase>(
+ tryGetPointedToType(&builder, getElementPtr->getBase()->getDataType())))
+ {
+ // For WGSL, an array of scalar or vector type will always be converted to
+ // an array of 16-byte aligned vector type. In this case, we will run into a
+ // GetElementPtr where the result type is different from the element type of
+ // the base array.
+ // We should setup loweredElementTypeInfo so the remaining logic can handle
+ // this case and insert proper packing/unpacking logic around it.
+ if (arrayType->getElementType() != originalElementType)
+ {
+ loweredElementTypeInfo.loweredType = arrayType->getElementType();
+ loweredElementTypeInfo.originalType = (IRType*)originalElementType;
+ loweredElementTypeInfo.convertLoweredToOriginal = getConversionMethod(
+ loweredElementTypeInfo.originalType,
+ loweredElementTypeInfo.loweredType);
+ loweredElementTypeInfo.convertOriginalToLowered = getConversionMethod(
+ loweredElementTypeInfo.loweredType,
+ loweredElementTypeInfo.originalType);
+ }
+ }
+ }
+
+ // For general cases we simply check if the element type needs lowering.
+ // If so we will insert packing/unpacking logic if necessary.
+ //
+ if (!loweredElementTypeInfo.loweredType)
+ {
+ loweredElementTypeInfo =
+ getLoweredTypeInfo((IRType*)originalElementType, layoutRules);
+ }
+
if (!loweredElementTypeInfo.convertLoweredToOriginal)
continue;
@@ -891,11 +1049,11 @@ struct LoweredElementTypeContext
builder.setInsertBefore(user);
auto newLoad = cloneInst(&cloneEnv, &builder, user);
newLoad->setFullType(loweredElementTypeInfo.loweredType);
- auto unpackedVal = builder.emitCallInst(
- (IRType*)originalElementType,
- loweredElementTypeInfo.convertLoweredToOriginal,
- 1,
- &newLoad);
+ auto unpackedVal =
+ loweredElementTypeInfo.convertLoweredToOriginal.apply(
+ builder,
+ loweredElementTypeInfo.originalType,
+ newLoad);
user->replaceUsesWith(unpackedVal);
user->removeAndDeallocate();
break;
@@ -910,11 +1068,11 @@ struct LoweredElementTypeContext
IRCloneEnv cloneEnv = {};
builder.setInsertBefore(user);
auto originalVal = getStoreVal(user);
- auto packedVal = builder.emitCallInst(
- loweredElementTypeInfo.loweredType,
- loweredElementTypeInfo.convertOriginalToLowered,
- 1,
- &originalVal);
+ auto packedVal =
+ loweredElementTypeInfo.convertOriginalToLowered.apply(
+ builder,
+ loweredElementTypeInfo.loweredType,
+ originalVal);
if (auto store = as<IRStore>(user))
store->val.set(packedVal);
else if (auto sbStore = as<IRRWStructuredBufferStore>(user))
@@ -954,9 +1112,9 @@ struct LoweredElementTypeContext
}
else
{
- // If we getting a derived address from the pointer, we need to
- // recursively lower the new address. We do so by pushing the
- // address inst into the work list.
+ // If we getting a derived address from the pointer, we need
+ // to recursively lower the new address. We do so by pushing
+ // the address inst into the work list.
ptrValsWorkList.add(user);
}
}
@@ -973,7 +1131,8 @@ struct LoweredElementTypeContext
// an argument, we don't need to do any marshalling here.
if (as<IRHLSLStructuredBufferTypeBase>(ptrVal->getDataType()))
break;
- if (lowerBufferPointer && as<IRPtrType>(ptrVal->getDataType()))
+ if (options.lowerBufferPointer &&
+ as<IRPtrType>(ptrVal->getDataType()))
break;
// If we are calling a function with an l-value pointer from buffer
// access, we need to materialize the object as a local variable,
@@ -981,21 +1140,21 @@ struct LoweredElementTypeContext
builder.setInsertBefore(user);
auto newLoad =
builder.emitLoad(loweredElementTypeInfo.loweredType, ptrVal);
- auto unpackedVal = builder.emitCallInst(
- (IRType*)originalElementType,
- loweredElementTypeInfo.convertLoweredToOriginal,
- 1,
- &newLoad);
+ auto unpackedVal =
+ loweredElementTypeInfo.convertLoweredToOriginal.apply(
+ builder,
+ (IRType*)originalElementType,
+ newLoad);
auto var = builder.emitVar((IRType*)originalElementType);
builder.emitStore(var, unpackedVal);
use->set(var);
builder.setInsertAfter(user);
auto newVal = builder.emitLoad(var);
- auto packedVal = builder.emitCallInst(
- (IRType*)loweredElementTypeInfo.loweredType,
- loweredElementTypeInfo.convertOriginalToLowered,
- 1,
- &newVal);
+ auto packedVal =
+ loweredElementTypeInfo.convertOriginalToLowered.apply(
+ builder,
+ (IRType*)loweredElementTypeInfo.loweredType,
+ newVal);
builder.emitStore(ptrVal, packedVal);
}
break;
@@ -1148,7 +1307,7 @@ struct LoweredElementTypeContext
void lowerBufferElementTypeToStorageType(
TargetProgram* target,
IRModule* module,
- bool lowerBufferPointer)
+ BufferElementTypeLoweringOptions options)
{
SlangMatrixLayoutMode defaultMatrixMode =
(SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode();
@@ -1157,7 +1316,7 @@ void lowerBufferElementTypeToStorageType(
defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
else if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)
defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
- LoweredElementTypeContext context(target, lowerBufferPointer, defaultMatrixMode);
+ LoweredElementTypeContext context(target, options, defaultMatrixMode);
context.processModule(module);
}
diff --git a/source/slang/slang-ir-lower-buffer-element-type.h b/source/slang/slang-ir-lower-buffer-element-type.h
index d6082798f..2c69c5476 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.h
+++ b/source/slang/slang-ir-lower-buffer-element-type.h
@@ -8,6 +8,14 @@ class TargetProgram;
struct IRTypeLayoutRules;
struct IRType;
+struct BufferElementTypeLoweringOptions
+{
+ bool lowerBufferPointer = false;
+
+ // For WGSL, we can only create arrays that has a stride of 16 bytes for constant buffers.
+ bool use16ByteArrayElementForConstantBuffer = false;
+};
+
// For each struct type S used as element type of a ConstantBuffer, ParameterBlock or
// [RW]StructuredBuffer, we create a lowered type L, where matrix types are lowered to arrays of
// vectors based on major-ness, and loads from the buffer are converted to L_to_S(load(buffer)), and
@@ -18,7 +26,7 @@ struct IRType;
void lowerBufferElementTypeToStorageType(
TargetProgram* target,
IRModule* module,
- bool lowerBufferPointer = false);
+ BufferElementTypeLoweringOptions options = BufferElementTypeLoweringOptions());
// Returns the type layout rules should be used for a buffer resource type.
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index ff1ddadca..4baa28d67 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -2205,7 +2205,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// pointers in this pass. In the future we should consider separate out IRAddress as
// the type for IRVar, and use IRPtrType to dedicate pointers in user code, so we can
// safely lower the pointer load stores early together with other buffer types.
- lowerBufferElementTypeToStorageType(m_sharedContext->m_targetProgram, m_module, true);
+ BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions;
+ bufferElementTypeLoweringOptions.lowerBufferPointer = true;
+ lowerBufferElementTypeToStorageType(
+ m_sharedContext->m_targetProgram,
+ m_module,
+ bufferElementTypeLoweringOptions);
// The above step may produce empty struct types, so we need to lower them out of existence.
legalizeEmptyTypes(m_sharedContext->m_targetProgram, m_module, m_sink);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 2cbafea6c..823b3cd7d 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2264,6 +2264,11 @@ IRInst* IRBuilder::getBoolValue(bool inValue)
return _findOrEmitConstant(keyInst);
}
+IRInst* IRBuilder::getIntValue(IRIntegerValue value)
+{
+ return getIntValue(getIntType(), value);
+}
+
IRInst* IRBuilder::getIntValue(IRType* type, IRIntegerValue inValue)
{
IRConstant keyInst;