summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-lower-buffer-element-type.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-10-03 12:52:26 -0700
committerGitHub <noreply@github.com>2025-10-03 19:52:26 +0000
commit6a2cf239a89340ed2985d04609499e8c4a2d8f89 (patch)
treec80f3c8dc7e89762aeab0ee7830d1ad728665460 /source/slang/slang-ir-lower-buffer-element-type.cpp
parentcc8f6a241edb47c43c5698ee33abed4fe57d4566 (diff)
Fix legalization crash when processing metal parameter blocks. (#8591)
Closes #7606. When Slang compile for a bindful target, we will run the resource type legalization pass to hoist resource typed struct fields outside of the struct type and define them as global parameters and passing them around via dedicated function parameters. When we compile for a bindless target, we don't run this pass. However, Metal is a hybrid bindful and bindless target. We need to run type legalization for the constant buffer, but skip type legalization for parameter block. The previous attempt to support this behavior is to hack the type legalization pass to return `LegalVal::simple` when it sees a `ParameterBlock<T>`. However, whenever the code is accessing `parameterBlock.someNestedField`, the type of the nested field may get a `LegalType::tuple`, and now we will run into inconsistent scenarios where we have a `LegalVal::simple` on the operand val, and but the legalization logic is expecting that val to be a `LegalType::tuple`. This breaks a lot of assumptions and invariants in the type legalization pass, resulting unstable/fragile behavior. To systematically solve this problem, this change generalizes the existing legalize buffer element type pass to translate `ParameterBlock<Texture2D>` (and similar cases) to `ParameterBlock<Texture2D.Handle>`. So that such parameter block will always be legalized to `LegalType:::simple` during type legalization, and we will never run into any inconsistent cases. This allowed us to get rid of the hacky logic in the type legalization pass to try to workaround the inconsistencies.
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp1058
1 files changed, 615 insertions, 443 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 83218bade..c69592939 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -214,104 +214,173 @@
namespace Slang
{
+enum ConversionMethodKind
+{
+ Func,
+ Opcode
+};
+
+struct ConversionMethod
+{
+ ConversionMethodKind kind = ConversionMethodKind::Func;
+ union
+ {
+ IRFunc* func;
+ IROp op;
+ };
+ ConversionMethod() { func = nullptr; }
+ operator bool()
+ {
+ return kind == ConversionMethodKind::Func ? func != nullptr : op != kIROp_Nop;
+ }
+ ConversionMethod& operator=(IRFunc* f)
+ {
+ kind = ConversionMethodKind::Func;
+ this->func = f;
+ return *this;
+ }
+ ConversionMethod& operator=(IROp irop)
+ {
+ kind = ConversionMethodKind::Opcode;
+ this->op = irop;
+ return *this;
+ }
+ IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr);
+ void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand);
+};
+
struct TypeLoweringConfig
{
AddressSpace addressSpace;
- IRTypeLayoutRules* layoutRule;
+ IRTypeLayoutRuleName layoutRuleName;
+ IRTypeLayoutRules* getLayoutRule() const { return IRTypeLayoutRules::get(layoutRuleName); }
+
bool operator==(const TypeLoweringConfig& other) const
{
- return addressSpace == other.addressSpace && layoutRule == other.layoutRule;
+ return addressSpace == other.addressSpace && layoutRuleName == other.layoutRuleName;
}
HashCode getHashCode() const
{
- return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRule));
+ return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRuleName));
}
};
+
+struct LoweredElementTypeInfo
+{
+ IRType* originalType;
+ IRType* loweredType;
+ IRType* loweredInnerArrayType =
+ nullptr; // For matrix/array types that are lowered into a struct type, this is the
+ // inner array type of the data field.
+ IRStructKey* loweredInnerStructKey =
+ nullptr; // For matrix/array types that are lowered into a struct type, this is the
+ // struct key of the data field.
+ ConversionMethod convertOriginalToLowered;
+ ConversionMethod convertLoweredToOriginal;
+};
+
+/// Defines target-specific behavior of how to lower buffer element types.
+struct BufferElementTypeLoweringPolicy : public RefObject
+{
+ /// Defines target-specific behavior of how to translate a non-composite logical type to a
+ /// storage type.
+ virtual LoweredElementTypeInfo lowerLeafLogicalType(
+ IRType* type,
+ TypeLoweringConfig config) = 0;
+
+ /// Returns true if we should always create a fresh lowered storage type for a composite type,
+ /// even if every member/element of the composite type is not changed by the lowering.
+ virtual bool shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(TypeLoweringConfig config)
+ {
+ SLANG_UNUSED(config);
+ return false;
+ }
+
+ /// Returns true if the target requires all array of scalars or vectors inside a constant buffer
+ /// to be translated into a 16-byte aligned vector type.
+ virtual bool shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer()
+ {
+ return false;
+ }
+};
+
+BufferElementTypeLoweringPolicy* getBufferElementTypeLoweringPolicy(
+ BufferElementTypeLoweringPolicyKind kind,
+ TargetProgram* target,
+ BufferElementTypeLoweringOptions options);
+
TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType);
-struct LoweredElementTypeContext
+IRInst* ConversionMethod::apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr)
{
- static const IRIntegerValue kMaxArraySizeToUnroll = 32;
+ if (!*this)
+ return builder.emitLoad(operandAddr);
+ if (kind == ConversionMethodKind::Func)
+ return builder.emitCallInst(resultType, func, 1, &operandAddr);
+ else
+ {
+ auto val = builder.emitLoad(operandAddr);
+ return builder.emitIntrinsicInst(resultType, op, 1, &val);
+ }
+}
- enum ConversionMethodKind
+void ConversionMethod::applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand)
+{
+ if (!*this)
{
- Func,
- Opcode
- };
- struct ConversionMethod
+ builder.emitStore(dest, operand);
+ return;
+ }
+ if (kind == ConversionMethodKind::Func)
{
- ConversionMethodKind kind = ConversionMethodKind::Func;
- union
- {
- IRFunc* func;
- IROp op;
- };
- ConversionMethod() { func = nullptr; }
- operator bool()
- {
- return kind == ConversionMethodKind::Func ? func != nullptr : op != kIROp_Nop;
- }
- ConversionMethod& operator=(IRFunc* f)
- {
- kind = ConversionMethodKind::Func;
- this->func = f;
- return *this;
- }
- ConversionMethod& operator=(IROp irop)
- {
- kind = ConversionMethodKind::Opcode;
- this->op = irop;
- return *this;
- }
- IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr)
- {
- if (!*this)
- return builder.emitLoad(operandAddr);
- if (kind == ConversionMethodKind::Func)
- return builder.emitCallInst(resultType, func, 1, &operandAddr);
- else
- {
- auto val = builder.emitLoad(operandAddr);
- return builder.emitIntrinsicInst(resultType, op, 1, &val);
- }
- }
- void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand)
- {
- if (!*this)
- {
- builder.emitStore(dest, operand);
- return;
- }
- if (kind == ConversionMethodKind::Func)
- {
- IRInst* operands[] = {dest, operand};
- builder.emitCallInst(builder.getVoidType(), func, 2, operands);
- }
- else
- {
- auto val = builder.emitIntrinsicInst(
- tryGetPointedToOrBufferElementType(&builder, dest->getDataType()),
- op,
- 1,
- &operand);
- builder.emitStore(dest, val);
- }
- }
- };
+ IRInst* operands[] = {dest, operand};
+ builder.emitCallInst(builder.getVoidType(), func, 2, operands);
+ }
+ else
+ {
+ auto val = builder.emitIntrinsicInst(
+ tryGetPointedToOrBufferElementType(&builder, dest->getDataType()),
+ op,
+ 1,
+ &operand);
+ builder.emitStore(dest, val);
+ }
+}
+
+// Returns the number of elements N that ensures the IRVectorType(elementType,N)
+// has 16-byte aligned size and N is no less than `minCount`.
+IRIntegerValue get16ByteAlignedVectorElementCount(
+ TargetProgram* target,
+ IRType* elementType,
+ IRIntegerValue minCount)
+{
+ IRSizeAndAlignment sizeAlignment;
+ getNaturalSizeAndAlignment(target->getOptionSet(), elementType, &sizeAlignment);
+ if (sizeAlignment.size)
+ return align(sizeAlignment.size * minCount, 16) / sizeAlignment.size;
+ return 4;
+}
- struct LoweredElementTypeInfo
+const char* getLayoutName(IRTypeLayoutRuleName name)
+{
+ switch (name)
{
- IRType* originalType;
- IRType* loweredType;
- IRType* loweredInnerArrayType =
- nullptr; // For matrix/array types that are lowered into a struct type, this is the
- // inner array type of the data field.
- IRStructKey* loweredInnerStructKey =
- nullptr; // For matrix/array types that are lowered into a struct type, this is the
- // struct key of the data field.
- ConversionMethod convertOriginalToLowered;
- ConversionMethod convertLoweredToOriginal;
- };
+ case IRTypeLayoutRuleName::Std140:
+ return "std140";
+ case IRTypeLayoutRuleName::Std430:
+ return "std430";
+ case IRTypeLayoutRuleName::Natural:
+ return "natural";
+ case IRTypeLayoutRuleName::C:
+ return "c";
+ default:
+ return "default";
+ }
+}
+
+struct LoweredElementTypeContext
+{
+ static const IRIntegerValue kMaxArraySizeToUnroll = 32;
struct LoweredTypeMap : RefObject
{
@@ -320,6 +389,7 @@ struct LoweredElementTypeContext
};
Dictionary<TypeLoweringConfig, RefPtr<LoweredTypeMap>> loweredTypeInfoMaps;
+ RefPtr<BufferElementTypeLoweringPolicy> leafTypeLoweringPolicy;
struct ConversionMethodKey
{
@@ -366,150 +436,11 @@ struct LoweredElementTypeContext
// Specialized functions that takes storage-typed pointers instead of logical-typed pointers.
Dictionary<SpecializationKey, IRFunc*> specializedFuncs;
- LoweredElementTypeContext(
- TargetProgram* target,
- BufferElementTypeLoweringOptions inOptions,
- SlangMatrixLayoutMode inDefaultMatrixLayout)
- : target(target), defaultMatrixLayout(inDefaultMatrixLayout), options(inOptions)
- {
- }
-
- IRFunc* createMatrixUnpackFunc(
- IRMatrixType* matrixType,
- IRStructType* structType,
- IRStructKey* dataKey)
- {
- IRBuilder builder(structType);
- builder.setInsertAfter(structType);
- auto func = builder.createFunc();
- auto refStructType = builder.getRefParamType(structType, AddressSpace::Generic);
- auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType);
- func->setFullType(funcType);
- builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage"));
- builder.addForceInlineDecoration(func);
- builder.setInsertInto(func);
- builder.emitBlock();
- auto rowCount = (Index)getIntVal(matrixType->getRowCount());
- auto colCount = (Index)getIntVal(matrixType->getColumnCount());
- auto packedParamRef = builder.emitParam(refStructType);
- auto packedParam = builder.emitLoad(packedParamRef);
- auto vectorArray = builder.emitFieldExtract(packedParam, dataKey);
- List<IRInst*> args;
- args.setCount(rowCount * colCount);
- if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
- {
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- auto vector = builder.emitElementExtract(vectorArray, c);
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- auto element = builder.emitElementExtract(vector, r);
- args[(Index)(r * colCount + c)] = element;
- }
- }
- }
- else
- {
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- auto vector = builder.emitElementExtract(vectorArray, r);
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- auto element = builder.emitElementExtract(vector, c);
- args[(Index)(r * colCount + c)] = element;
- }
- }
- }
- IRInst* result =
- builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer());
- builder.emitReturn(result);
- return func;
- }
-
- IRFunc* createMatrixPackFunc(
- IRMatrixType* matrixType,
- IRStructType* structType,
- IRVectorType* vectorType,
- IRArrayType* arrayType)
+ LoweredElementTypeContext(TargetProgram* target, BufferElementTypeLoweringOptions inOptions)
+ : target(target), options(inOptions)
{
- IRBuilder builder(structType);
- builder.setInsertAfter(structType);
- auto func = builder.createFunc();
- auto outStructType = builder.getRefParamType(structType, AddressSpace::Generic);
- IRType* paramTypes[] = {outStructType, matrixType};
- auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType());
- func->setFullType(funcType);
- builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix"));
- builder.addForceInlineDecoration(func);
- builder.setInsertInto(func);
- builder.emitBlock();
- auto rowCount = getIntVal(matrixType->getRowCount());
- auto colCount = getIntVal(matrixType->getColumnCount());
- auto outParam = builder.emitParam(outStructType);
- auto originalParam = builder.emitParam(matrixType);
- List<IRInst*> elements;
- elements.setCount((Index)(rowCount * colCount));
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- auto vector = builder.emitElementExtract(originalParam, r);
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- auto element = builder.emitElementExtract(vector, c);
- elements[(Index)(r * colCount + c)] = element;
- }
- }
- List<IRInst*> vectors;
- if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
- {
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- List<IRInst*> vecArgs;
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- auto element = elements[(Index)(r * colCount + c)];
- vecArgs.add(element);
- }
- // Fill in default values for remaining elements in the vector.
- for (IRIntegerValue r = rowCount; r < getIntVal(vectorType->getElementCount()); r++)
- {
- vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
- }
- auto colVector = builder.emitMakeVector(
- vectorType,
- (UInt)vecArgs.getCount(),
- vecArgs.getBuffer());
- vectors.add(colVector);
- }
- }
- else
- {
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- List<IRInst*> vecArgs;
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- auto element = elements[(Index)(r * colCount + c)];
- vecArgs.add(element);
- }
- // Fill in default values for remaining elements in the vector.
- for (IRIntegerValue c = colCount; c < getIntVal(vectorType->getElementCount()); c++)
- {
- vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
- }
- auto rowVector = builder.emitMakeVector(
- vectorType,
- (UInt)vecArgs.getCount(),
- vecArgs.getBuffer());
- vectors.add(rowVector);
- }
- }
-
- auto vectorArray =
- builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer());
- auto result = builder.emitMakeStruct(structType, 1, &vectorArray);
- builder.emitStore(outParam, result);
- builder.emitReturn();
- return func;
+ leafTypeLoweringPolicy =
+ getBufferElementTypeLoweringPolicy(options.loweringPolicyKind, target, options);
}
IRFunc* createArrayUnpackFunc(
@@ -637,56 +568,6 @@ struct LoweredElementTypeContext
return func;
}
- const char* getLayoutName(IRTypeLayoutRuleName name)
- {
- switch (name)
- {
- case IRTypeLayoutRuleName::Std140:
- return "std140";
- case IRTypeLayoutRuleName::Std430:
- return "std430";
- case IRTypeLayoutRuleName::Natural:
- return "natural";
- case IRTypeLayoutRuleName::C:
- return "c";
- default:
- return "default";
- }
- }
-
- // Returns the number of elements N that ensures the IRVectorType(elementType,N)
- // has 16-byte aligned size and N is no less than `minCount`.
- IRIntegerValue get16ByteAlignedVectorElementCount(IRType* elementType, IRIntegerValue minCount)
- {
- IRSizeAndAlignment sizeAlignment;
- getNaturalSizeAndAlignment(target->getOptionSet(), elementType, &sizeAlignment);
- if (sizeAlignment.size)
- return align(sizeAlignment.size * minCount, 16) / sizeAlignment.size;
- return 4;
- }
-
- bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config)
- {
- // For spirv, we always want to lower all matrix types, because SPIRV does not support
- // specifying matrix layout/stride if the matrix type is used in places other than
- // defining a struct field. This means that if a matrix is used to define a varying
- // parameter, we always want to wrap it in a struct.
- //
- if (target->shouldEmitSPIRVDirectly())
- {
- return true;
- }
-
- if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout &&
- config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural)
- {
- // For other targets, we only lower the matrix types if they differ from the default
- // matrix layout.
- return false;
- }
- return true;
- }
-
LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, TypeLoweringConfig config)
{
IRBuilder builder(type);
@@ -694,72 +575,13 @@ struct LoweredElementTypeContext
LoweredElementTypeInfo info;
info.originalType = type;
-
- if (auto matrixType = as<IRMatrixType>(type))
- {
- if (!shouldLowerMatrixType(matrixType, config))
- {
- info.loweredType = type;
- return info;
- }
-
- auto loweredType = builder.createStructType();
- builder.addPhysicalTypeDecoration(loweredType);
-
- StringBuilder nameSB;
- bool isColMajor =
- getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
- nameSB << "_MatrixStorage_";
- getTypeNameHint(nameSB, matrixType->getElementType());
- nameSB << getIntVal(matrixType->getRowCount()) << "x"
- << getIntVal(matrixType->getColumnCount());
- if (isColMajor)
- nameSB << "_ColMajor";
- nameSB << getLayoutName(config.layoutRule->ruleName);
- builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
- auto structKey = builder.createStructKey();
- builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
- auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount();
- if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 &&
- options.use16ByteArrayElementForConstantBuffer)
- {
- // For constant buffer layout, we need to use 16-byte aligned vector if
- // we are required to ensure array element types has 16-byte stride.
- vectorSize = builder.getIntValue(get16ByteAlignedVectorElementCount(
- matrixType->getElementType(),
- getIntVal(vectorSize)));
- }
-
- auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize);
- IRSizeAndAlignment elementSizeAlignment;
- getSizeAndAlignment(
- target->getOptionSet(),
- config.layoutRule,
- vectorType,
- &elementSizeAlignment);
- elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment);
-
- auto arrayType = builder.getArrayType(
- vectorType,
- isColMajor ? matrixType->getColumnCount() : matrixType->getRowCount(),
- builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride()));
- builder.createStructField(loweredType, structKey, arrayType);
-
- info.loweredType = loweredType;
- info.loweredInnerArrayType = arrayType;
- info.loweredInnerStructKey = structKey;
- info.convertLoweredToOriginal =
- createMatrixUnpackFunc(matrixType, loweredType, structKey);
- info.convertOriginalToLowered =
- createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType);
- return info;
- }
- else if (auto arrayTypeBase = as<IRArrayTypeBase>(type))
+ if (auto arrayTypeBase = as<IRArrayTypeBase>(type))
{
auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayTypeBase->getElementType(), config);
- if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 &&
- options.use16ByteArrayElementForConstantBuffer)
+ if (config.layoutRuleName == IRTypeLayoutRuleName::Std140 &&
+ leafTypeLoweringPolicy
+ ->shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer())
{
// For constant buffer layout, we need to use 16-byte-aligned vector if
// we are required to ensure array element types has 16-byte stride.
@@ -772,6 +594,7 @@ struct LoweredElementTypeContext
packedVectorType = builder.getVectorType(
vectorType->getElementType(),
builder.getIntValue(get16ByteAlignedVectorElementCount(
+ target,
vectorType->getElementType(),
getIntVal(vectorType->getElementCount()))));
if (packedVectorType != loweredInnerTypeInfo.originalType)
@@ -784,7 +607,7 @@ struct LoweredElementTypeContext
{
packedVectorType = builder.getVectorType(
loweredInnerTypeInfo.loweredType,
- get16ByteAlignedVectorElementCount(scalarType, 1));
+ get16ByteAlignedVectorElementCount(target, scalarType, 1));
loweredInnerTypeInfo.convertLoweredToOriginal = kIROp_VectorReshape;
loweredInnerTypeInfo.convertOriginalToLowered = kIROp_MakeVectorFromScalar;
}
@@ -803,10 +626,10 @@ struct LoweredElementTypeContext
}
}
- // For spirv backend, we always want to lower all array types for non-varying
- // parameters, even if the element type comes out the same. This is because different
- // layout rules may have different array stride requirements.
- if (!target->shouldEmitSPIRVDirectly() || config.addressSpace == AddressSpace::Input)
+ // We can skip lowering this type if all field types are unchanged, unless the target
+ // specific policy tells us to always create a lowered type.
+ if (!leafTypeLoweringPolicy->shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(
+ config))
{
if (!loweredInnerTypeInfo.convertLoweredToOriginal)
{
@@ -823,7 +646,7 @@ struct LoweredElementTypeContext
info.loweredType = loweredType;
StringBuilder nameSB;
- nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_";
+ nameSB << "_Array_" << getLayoutName(config.layoutRuleName) << "_";
getTypeNameHint(nameSB, arrayType->getElementType());
nameSB << getArraySizeVal(arrayType->getElementCount());
@@ -835,11 +658,11 @@ struct LoweredElementTypeContext
IRSizeAndAlignment elementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
loweredInnerTypeInfo.loweredType,
&elementSizeAlignment);
elementSizeAlignment =
- config.layoutRule->alignCompositeElement(elementSizeAlignment);
+ config.getLayoutRule()->alignCompositeElement(elementSizeAlignment);
auto innerArrayType = builder.getArrayType(
loweredInnerTypeInfo.loweredType,
arrayType->getElementCount(),
@@ -857,11 +680,11 @@ struct LoweredElementTypeContext
IRSizeAndAlignment elementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
loweredInnerTypeInfo.loweredType,
&elementSizeAlignment);
elementSizeAlignment =
- config.layoutRule->alignCompositeElement(elementSizeAlignment);
+ config.getLayoutRule()->alignCompositeElement(elementSizeAlignment);
auto innerArrayType = builder.getArrayTypeBase(
arrayTypeBase->getOp(),
loweredInnerTypeInfo.loweredType,
@@ -880,20 +703,15 @@ struct LoweredElementTypeContext
auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), config);
fieldLoweredTypeInfo.add(loweredFieldTypeInfo);
if (loweredFieldTypeInfo.convertLoweredToOriginal ||
- config.layoutRule->ruleName != IRTypeLayoutRuleName::Natural)
+ config.layoutRuleName != IRTypeLayoutRuleName::Natural)
isTrivial = false;
}
- // For spirv backend, we always want to lower all array types, even if the element type
- // comes out the same. This is because different layout rules may have different array
- // stride requirements.
- //
- // Additionally, `buffer` blocks do not work correctly unless lowered when targeting
- // GLSL.
- if (!isKhronosTarget(target->getTargetReq()))
+ // We can skip lowering this type if all field types are unchanged, unless the target
+ // specific policy tells us to always create a lowered type.
+ if (!leafTypeLoweringPolicy->shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(
+ config))
{
- // For non-spirv target, we skip lowering this type if all field types are
- // unchanged.
if (isTrivial)
{
info.loweredType = type;
@@ -905,7 +723,7 @@ struct LoweredElementTypeContext
StringBuilder nameSB;
getTypeNameHint(nameSB, type);
- nameSB << "_" << getLayoutName(config.layoutRule->ruleName);
+ nameSB << "_" << getLayoutName(config.layoutRuleName);
builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
info.loweredType = loweredType;
// Create fields.
@@ -1007,58 +825,7 @@ struct LoweredElementTypeContext
return info;
}
-
- if (target->shouldEmitSPIRVDirectly())
- {
- switch (target->getTargetReq()->getTarget())
- {
- case CodeGenTarget::SPIRV:
- case CodeGenTarget::SPIRVAssembly:
- {
- auto scalarType = type;
- auto vectorType = as<IRVectorType>(scalarType);
- if (vectorType)
- scalarType = vectorType->getElementType();
-
- if (as<IRBoolType>(scalarType))
- {
- // Bool is an abstract type in SPIRV, so we need to lower them into an int.
-
- // Find an integer type of the correct size for the current layout rule.
- IRSizeAndAlignment boolSizeAndAlignment;
- if (getSizeAndAlignment(
- target->getOptionSet(),
- config.layoutRule,
- scalarType,
- &boolSizeAndAlignment) == SLANG_OK)
- {
- IntInfo ii;
- ii.width = boolSizeAndAlignment.size * 8;
- ii.isSigned = true;
- info.loweredType = builder.getType(getIntTypeOpFromInfo(ii));
- }
- else
- {
- // Just in case that fails for some reason, just use an int.
- info.loweredType = builder.getIntType();
- }
-
- if (vectorType)
- info.loweredType = builder.getVectorType(
- info.loweredType,
- vectorType->getElementCount());
- info.convertLoweredToOriginal = kIROp_BuiltinCast;
- info.convertOriginalToLowered = kIROp_BuiltinCast;
- return info;
- }
- }
- default:
- break;
- }
- }
-
- info.loweredType = type;
- return info;
+ return leafTypeLoweringPolicy->lowerLeafLogicalType(type, config);
}
LoweredTypeMap& getTypeLoweringMap(TypeLoweringConfig config)
@@ -1090,7 +857,7 @@ struct LoweredElementTypeContext
IRSizeAndAlignment sizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
info.loweredType,
&sizeAlignment);
loweredTypeInfo.set(type, info);
@@ -1190,13 +957,13 @@ struct LoweredElementTypeContext
IRSizeAndAlignment arrayElementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
loweredInnerType.loweredType,
&arrayElementSizeAlignment);
IRSizeAndAlignment baseSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
tryGetPointedToOrBufferElementType(&builder, fieldAddr->getBase()->getDataType()),
&baseSizeAlignment);
@@ -1701,9 +1468,6 @@ struct LoweredElementTypeContext
switch (ptrType->getAddressSpace())
{
case AddressSpace::UserPointer:
- if (!options.lowerBufferPointer)
- continue;
- [[fallthrough]];
case AddressSpace::Input:
case AddressSpace::Output:
elementType = ptrType->getValueType();
@@ -1720,7 +1484,7 @@ struct LoweredElementTypeContext
IRSizeAndAlignment sizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
elementType,
&sizeAlignment);
SLANG_UNUSED(sizeAlignment);
@@ -2181,63 +1945,60 @@ void lowerBufferElementTypeToStorageType(
IRModule* module,
BufferElementTypeLoweringOptions options)
{
- SlangMatrixLayoutMode defaultMatrixMode =
- (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode();
- if ((isCPUTarget(target->getTargetReq()) || isCUDATarget(target->getTargetReq()) ||
- isMetalTarget(target->getTargetReq())))
- defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
- else if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)
- defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
- LoweredElementTypeContext context(target, options, defaultMatrixMode);
+ LoweredElementTypeContext context(target, options);
context.processModule(module);
}
-IRTypeLayoutRules* getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRules* defaultLayout)
+IRTypeLayoutRuleName getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRuleName defaultLayout)
{
switch (layoutTypeOp)
{
case kIROp_DefaultBufferLayoutType:
return defaultLayout;
case kIROp_Std140BufferLayoutType:
- return IRTypeLayoutRules::getStd140();
+ return IRTypeLayoutRuleName::Std140;
case kIROp_Std430BufferLayoutType:
- return IRTypeLayoutRules::getStd430();
+ return IRTypeLayoutRuleName::Std430;
case kIROp_ScalarBufferLayoutType:
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
case kIROp_CBufferLayoutType:
- return IRTypeLayoutRules::getC();
+ return IRTypeLayoutRuleName::C;
}
return defaultLayout;
}
-IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType)
+IRTypeLayoutRuleName getTypeLayoutRuleNameForBuffer(TargetProgram* target, IRType* bufferType)
{
+ if (bufferType->getOp() == kIROp_ParameterBlockType && isMetalTarget(target->getTargetReq()))
+ {
+ return IRTypeLayoutRuleName::MetalParameterBlock;
+ }
if (target->getTargetReq()->getTarget() != CodeGenTarget::WGSL)
{
if (!isKhronosTarget(target->getTargetReq()))
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
// If we are just emitting GLSL, we can just use the general layout rule.
if (!target->shouldEmitSPIRVDirectly())
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
// If the user specified a C-compatible buffer layout, then do that.
if (target->getOptionSet().shouldUseCLayout())
- return IRTypeLayoutRules::getC();
+ return IRTypeLayoutRuleName::C;
// If the user specified a scalar buffer layout, then just use that.
if (target->getOptionSet().shouldUseScalarLayout())
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
}
if (target->getOptionSet().shouldUseDXLayout())
{
if (as<IRUniformParameterGroupType>(bufferType))
{
- return IRTypeLayoutRules::getConstantBuffer();
+ return IRTypeLayoutRuleName::D3DConstantBuffer;
}
else
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
}
// The default behavior is to use std140 for constant buffers and std430 for other buffers.
@@ -2253,17 +2014,17 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf
auto layoutTypeOp = structBufferType->getDataLayout()
? structBufferType->getDataLayout()->getOp()
: kIROp_DefaultBufferLayoutType;
- return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430());
+ return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std430);
}
- case kIROp_ConstantBufferType:
case kIROp_ParameterBlockType:
+ case kIROp_ConstantBufferType:
{
auto parameterGroupType = as<IRUniformParameterGroupType>(bufferType);
auto layoutTypeOp = parameterGroupType->getDataLayout()
? parameterGroupType->getDataLayout()->getOp()
: kIROp_DefaultBufferLayoutType;
- return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd140());
+ return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std140);
}
case kIROp_GLSLShaderStorageBufferType:
{
@@ -2271,12 +2032,18 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf
auto layoutTypeOp = storageBufferType->getDataLayout()
? storageBufferType->getDataLayout()->getOp()
: kIROp_Std430BufferLayoutType;
- return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430());
+ return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std430);
}
case kIROp_PtrType:
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
}
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
+}
+
+IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType)
+{
+ auto ruleName = getTypeLayoutRuleNameForBuffer(target, bufferType);
+ return IRTypeLayoutRules::get(ruleName);
}
TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType)
@@ -2295,8 +2062,413 @@ TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType*
break;
}
}
- auto rules = getTypeLayoutRuleForBuffer(target, bufferType);
+ auto rules = getTypeLayoutRuleNameForBuffer(target, bufferType);
return TypeLoweringConfig{addrSpace, rules};
}
+struct DefaultBufferElementTypeLoweringPolicy : BufferElementTypeLoweringPolicy
+{
+ TargetProgram* target;
+ BufferElementTypeLoweringOptions options;
+ SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+
+ DefaultBufferElementTypeLoweringPolicy(
+ TargetProgram* inTarget,
+ BufferElementTypeLoweringOptions inOptions)
+ : target(inTarget), options(inOptions)
+ {
+ defaultMatrixLayout = (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode();
+ if ((isCPUTarget(target->getTargetReq()) || isCUDATarget(target->getTargetReq()) ||
+ isMetalTarget(target->getTargetReq())))
+ defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+ else if (defaultMatrixLayout == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)
+ defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+ }
+
+ virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config)
+ {
+ if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout &&
+ config.getLayoutRule()->ruleName == IRTypeLayoutRuleName::Natural)
+ {
+ // We only lower the matrix types if they differ from the default
+ // matrix layout.
+ return false;
+ }
+ return true;
+ }
+
+ IRFunc* createMatrixUnpackFunc(
+ IRMatrixType* matrixType,
+ IRStructType* structType,
+ IRStructKey* dataKey)
+ {
+ IRBuilder builder(structType);
+ builder.setInsertAfter(structType);
+ auto func = builder.createFunc();
+ auto refStructType = builder.getRefParamType(structType, AddressSpace::Generic);
+ auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType);
+ func->setFullType(funcType);
+ builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage"));
+ builder.addForceInlineDecoration(func);
+ builder.setInsertInto(func);
+ builder.emitBlock();
+ auto rowCount = (Index)getIntVal(matrixType->getRowCount());
+ auto colCount = (Index)getIntVal(matrixType->getColumnCount());
+ auto packedParamRef = builder.emitParam(refStructType);
+ auto packedParam = builder.emitLoad(packedParamRef);
+ auto vectorArray = builder.emitFieldExtract(packedParam, dataKey);
+ List<IRInst*> args;
+ args.setCount(rowCount * colCount);
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+ {
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto vector = builder.emitElementExtract(vectorArray, c);
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto element = builder.emitElementExtract(vector, r);
+ args[(Index)(r * colCount + c)] = element;
+ }
+ }
+ }
+ else
+ {
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto vector = builder.emitElementExtract(vectorArray, r);
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto element = builder.emitElementExtract(vector, c);
+ args[(Index)(r * colCount + c)] = element;
+ }
+ }
+ }
+ IRInst* result =
+ builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer());
+ builder.emitReturn(result);
+ return func;
+ }
+
+ IRFunc* createMatrixPackFunc(
+ IRMatrixType* matrixType,
+ IRStructType* structType,
+ IRVectorType* vectorType,
+ IRArrayType* arrayType)
+ {
+ IRBuilder builder(structType);
+ builder.setInsertAfter(structType);
+ auto func = builder.createFunc();
+ auto outStructType = builder.getRefParamType(structType, AddressSpace::Generic);
+ IRType* paramTypes[] = {outStructType, matrixType};
+ auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType());
+ func->setFullType(funcType);
+ builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix"));
+ builder.addForceInlineDecoration(func);
+ builder.setInsertInto(func);
+ builder.emitBlock();
+ auto rowCount = getIntVal(matrixType->getRowCount());
+ auto colCount = getIntVal(matrixType->getColumnCount());
+ auto outParam = builder.emitParam(outStructType);
+ auto originalParam = builder.emitParam(matrixType);
+ List<IRInst*> elements;
+ elements.setCount((Index)(rowCount * colCount));
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto vector = builder.emitElementExtract(originalParam, r);
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto element = builder.emitElementExtract(vector, c);
+ elements[(Index)(r * colCount + c)] = element;
+ }
+ }
+ List<IRInst*> vectors;
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+ {
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ List<IRInst*> vecArgs;
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto element = elements[(Index)(r * colCount + c)];
+ vecArgs.add(element);
+ }
+ // Fill in default values for remaining elements in the vector.
+ for (IRIntegerValue r = rowCount; r < getIntVal(vectorType->getElementCount()); r++)
+ {
+ vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
+ }
+ auto colVector = builder.emitMakeVector(
+ vectorType,
+ (UInt)vecArgs.getCount(),
+ vecArgs.getBuffer());
+ vectors.add(colVector);
+ }
+ }
+ else
+ {
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ List<IRInst*> vecArgs;
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto element = elements[(Index)(r * colCount + c)];
+ vecArgs.add(element);
+ }
+ // Fill in default values for remaining elements in the vector.
+ for (IRIntegerValue c = colCount; c < getIntVal(vectorType->getElementCount()); c++)
+ {
+ vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
+ }
+ auto rowVector = builder.emitMakeVector(
+ vectorType,
+ (UInt)vecArgs.getCount(),
+ vecArgs.getBuffer());
+ vectors.add(rowVector);
+ }
+ }
+
+ auto vectorArray =
+ builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer());
+ auto result = builder.emitMakeStruct(structType, 1, &vectorArray);
+ builder.emitStore(outParam, result);
+ builder.emitReturn();
+ return func;
+ }
+
+ LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override
+ {
+ IRBuilder builder(type);
+ builder.setInsertAfter(type);
+
+ LoweredElementTypeInfo info;
+ info.originalType = type;
+
+ if (auto matrixType = as<IRMatrixType>(type))
+ {
+ if (!shouldLowerMatrixType(matrixType, config))
+ {
+ info.loweredType = type;
+ return info;
+ }
+
+ auto loweredType = builder.createStructType();
+ builder.addPhysicalTypeDecoration(loweredType);
+
+ StringBuilder nameSB;
+ bool isColMajor =
+ getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
+ nameSB << "_MatrixStorage_";
+ getTypeNameHint(nameSB, matrixType->getElementType());
+ nameSB << getIntVal(matrixType->getRowCount()) << "x"
+ << getIntVal(matrixType->getColumnCount());
+ if (isColMajor)
+ nameSB << "_ColMajor";
+ nameSB << getLayoutName(config.layoutRuleName);
+ builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
+ auto structKey = builder.createStructKey();
+ builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
+ auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount();
+ if (config.layoutRuleName == IRTypeLayoutRuleName::Std140 &&
+ shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer())
+ {
+ // For constant buffer layout, we need to use 16-byte aligned vector if
+ // we are required to ensure array element types has 16-byte stride.
+ vectorSize = builder.getIntValue(get16ByteAlignedVectorElementCount(
+ target,
+ matrixType->getElementType(),
+ getIntVal(vectorSize)));
+ }
+
+ auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize);
+ IRSizeAndAlignment elementSizeAlignment;
+ getSizeAndAlignment(
+ target->getOptionSet(),
+ config.getLayoutRule(),
+ vectorType,
+ &elementSizeAlignment);
+ elementSizeAlignment =
+ config.getLayoutRule()->alignCompositeElement(elementSizeAlignment);
+
+ auto arrayType = builder.getArrayType(
+ vectorType,
+ isColMajor ? matrixType->getColumnCount() : matrixType->getRowCount(),
+ builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride()));
+ builder.createStructField(loweredType, structKey, arrayType);
+
+ info.loweredType = loweredType;
+ info.loweredInnerArrayType = arrayType;
+ info.loweredInnerStructKey = structKey;
+ info.convertLoweredToOriginal =
+ createMatrixUnpackFunc(matrixType, loweredType, structKey);
+ info.convertOriginalToLowered =
+ createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType);
+ return info;
+ }
+
+ info.loweredType = type;
+ return info;
+ }
+};
+
+struct KhronosTargetBufferElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy
+{
+ KhronosTargetBufferElementTypeLoweringPolicy(
+ TargetProgram* inTarget,
+ BufferElementTypeLoweringOptions inOptions)
+ : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions)
+ {
+ }
+
+ virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) override
+ {
+ // For spirv, we always want to lower all matrix types, because SPIRV does not support
+ // specifying matrix layout/stride if the matrix type is used in places other than
+ // defining a struct field. This means that if a matrix is used to define a varying
+ // parameter, we always want to wrap it in a struct.
+ //
+ if (target->shouldEmitSPIRVDirectly())
+ return true;
+ return DefaultBufferElementTypeLoweringPolicy::shouldLowerMatrixType(matrixType, config);
+ }
+
+ virtual bool shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(
+ TypeLoweringConfig config) override
+ {
+ // For spirv backend, we always want to lower all array types, even if the element type
+ // comes out the same. This is because different layout rules may have different array
+ // stride requirements.
+ //
+ // Additionally, `buffer` blocks do not work correctly unless lowered when targeting
+ // GLSL.
+ return target->shouldEmitSPIRVDirectly() && config.addressSpace != AddressSpace::Input;
+ }
+
+ LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override
+ {
+ if (target->shouldEmitSPIRVDirectly())
+ {
+ LoweredElementTypeInfo info = {};
+ info.originalType = type;
+
+ switch (target->getTargetReq()->getTarget())
+ {
+ case CodeGenTarget::SPIRV:
+ case CodeGenTarget::SPIRVAssembly:
+ {
+ auto scalarType = type;
+ auto vectorType = as<IRVectorType>(scalarType);
+ if (vectorType)
+ scalarType = vectorType->getElementType();
+ IRBuilder builder(type);
+ builder.setInsertBefore(type);
+
+ if (as<IRBoolType>(scalarType))
+ {
+ // Bool is an abstract type in SPIRV, so we need to lower them into an int.
+
+ // Find an integer type of the correct size for the current layout rule.
+ IRSizeAndAlignment boolSizeAndAlignment;
+ if (getSizeAndAlignment(
+ target->getOptionSet(),
+ config.getLayoutRule(),
+ scalarType,
+ &boolSizeAndAlignment) == SLANG_OK)
+ {
+ IntInfo ii;
+ ii.width = boolSizeAndAlignment.size * 8;
+ ii.isSigned = true;
+ info.loweredType = builder.getType(getIntTypeOpFromInfo(ii));
+ }
+ else
+ {
+ // Just in case that fails for some reason, just use an int.
+ info.loweredType = builder.getIntType();
+ }
+
+ if (vectorType)
+ info.loweredType = builder.getVectorType(
+ info.loweredType,
+ vectorType->getElementCount());
+ info.convertLoweredToOriginal = kIROp_BuiltinCast;
+ info.convertOriginalToLowered = kIROp_BuiltinCast;
+ return info;
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ return DefaultBufferElementTypeLoweringPolicy::lowerLeafLogicalType(type, config);
+ }
+};
+
+struct MetalParameterBlockElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy
+{
+ MetalParameterBlockElementTypeLoweringPolicy(
+ TargetProgram* inTarget,
+ BufferElementTypeLoweringOptions inOptions)
+ : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions)
+ {
+ }
+
+ virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) override
+ {
+ SLANG_UNUSED(matrixType);
+ SLANG_UNUSED(config);
+ return false;
+ }
+
+ LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override
+ {
+ if (config.layoutRuleName == IRTypeLayoutRuleName::MetalParameterBlock &&
+ isResourceType(type))
+ {
+ IRBuilder builder(type);
+ builder.setInsertBefore(type);
+ LoweredElementTypeInfo info = {};
+ info.originalType = type;
+ info.loweredType = builder.getType(kIROp_DescriptorHandleType, type);
+ info.convertLoweredToOriginal = kIROp_CastDescriptorHandleToResource;
+ info.convertOriginalToLowered = kIROp_CastResourceToDescriptorHandle;
+ return info;
+ }
+ return DefaultBufferElementTypeLoweringPolicy::lowerLeafLogicalType(type, config);
+ }
+};
+
+struct WGSLBufferElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy
+{
+ WGSLBufferElementTypeLoweringPolicy(
+ TargetProgram* inTarget,
+ BufferElementTypeLoweringOptions inOptions)
+ : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions)
+ {
+ }
+
+ virtual bool shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer() override
+ {
+ return true;
+ }
+};
+
+BufferElementTypeLoweringPolicy* getBufferElementTypeLoweringPolicy(
+ BufferElementTypeLoweringPolicyKind kind,
+ TargetProgram* target,
+ BufferElementTypeLoweringOptions options)
+{
+ switch (kind)
+ {
+ case BufferElementTypeLoweringPolicyKind::Default:
+ return new DefaultBufferElementTypeLoweringPolicy(target, options);
+ case BufferElementTypeLoweringPolicyKind::KhronosTarget:
+ return new KhronosTargetBufferElementTypeLoweringPolicy(target, options);
+ case BufferElementTypeLoweringPolicyKind::MetalParameterBlock:
+ return new MetalParameterBlockElementTypeLoweringPolicy(target, options);
+ case BufferElementTypeLoweringPolicyKind::WGSL:
+ return new WGSLBufferElementTypeLoweringPolicy(target, options);
+ }
+ SLANG_UNREACHABLE("unknown buffer element type lowering policy");
+}
+
} // namespace Slang