summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-lower-buffer-element-type.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp1058
1 files changed, 615 insertions, 443 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 83218bade..c69592939 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -214,104 +214,173 @@
namespace Slang
{
+enum ConversionMethodKind
+{
+ Func,
+ Opcode
+};
+
+struct ConversionMethod
+{
+ ConversionMethodKind kind = ConversionMethodKind::Func;
+ union
+ {
+ IRFunc* func;
+ IROp op;
+ };
+ ConversionMethod() { func = nullptr; }
+ operator bool()
+ {
+ return kind == ConversionMethodKind::Func ? func != nullptr : op != kIROp_Nop;
+ }
+ ConversionMethod& operator=(IRFunc* f)
+ {
+ kind = ConversionMethodKind::Func;
+ this->func = f;
+ return *this;
+ }
+ ConversionMethod& operator=(IROp irop)
+ {
+ kind = ConversionMethodKind::Opcode;
+ this->op = irop;
+ return *this;
+ }
+ IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr);
+ void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand);
+};
+
struct TypeLoweringConfig
{
AddressSpace addressSpace;
- IRTypeLayoutRules* layoutRule;
+ IRTypeLayoutRuleName layoutRuleName;
+ IRTypeLayoutRules* getLayoutRule() const { return IRTypeLayoutRules::get(layoutRuleName); }
+
bool operator==(const TypeLoweringConfig& other) const
{
- return addressSpace == other.addressSpace && layoutRule == other.layoutRule;
+ return addressSpace == other.addressSpace && layoutRuleName == other.layoutRuleName;
}
HashCode getHashCode() const
{
- return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRule));
+ return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRuleName));
}
};
+
+struct LoweredElementTypeInfo
+{
+ IRType* originalType;
+ IRType* loweredType;
+ IRType* loweredInnerArrayType =
+ nullptr; // For matrix/array types that are lowered into a struct type, this is the
+ // inner array type of the data field.
+ IRStructKey* loweredInnerStructKey =
+ nullptr; // For matrix/array types that are lowered into a struct type, this is the
+ // struct key of the data field.
+ ConversionMethod convertOriginalToLowered;
+ ConversionMethod convertLoweredToOriginal;
+};
+
+/// Defines target-specific behavior of how to lower buffer element types.
+struct BufferElementTypeLoweringPolicy : public RefObject
+{
+ /// Defines target-specific behavior of how to translate a non-composite logical type to a
+ /// storage type.
+ virtual LoweredElementTypeInfo lowerLeafLogicalType(
+ IRType* type,
+ TypeLoweringConfig config) = 0;
+
+ /// Returns true if we should always create a fresh lowered storage type for a composite type,
+ /// even if every member/element of the composite type is not changed by the lowering.
+ virtual bool shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(TypeLoweringConfig config)
+ {
+ SLANG_UNUSED(config);
+ return false;
+ }
+
+ /// Returns true if the target requires all array of scalars or vectors inside a constant buffer
+ /// to be translated into a 16-byte aligned vector type.
+ virtual bool shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer()
+ {
+ return false;
+ }
+};
+
+BufferElementTypeLoweringPolicy* getBufferElementTypeLoweringPolicy(
+ BufferElementTypeLoweringPolicyKind kind,
+ TargetProgram* target,
+ BufferElementTypeLoweringOptions options);
+
TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType);
-struct LoweredElementTypeContext
+IRInst* ConversionMethod::apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr)
{
- static const IRIntegerValue kMaxArraySizeToUnroll = 32;
+ if (!*this)
+ return builder.emitLoad(operandAddr);
+ if (kind == ConversionMethodKind::Func)
+ return builder.emitCallInst(resultType, func, 1, &operandAddr);
+ else
+ {
+ auto val = builder.emitLoad(operandAddr);
+ return builder.emitIntrinsicInst(resultType, op, 1, &val);
+ }
+}
- enum ConversionMethodKind
+void ConversionMethod::applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand)
+{
+ if (!*this)
{
- Func,
- Opcode
- };
- struct ConversionMethod
+ builder.emitStore(dest, operand);
+ return;
+ }
+ if (kind == ConversionMethodKind::Func)
{
- ConversionMethodKind kind = ConversionMethodKind::Func;
- union
- {
- IRFunc* func;
- IROp op;
- };
- ConversionMethod() { func = nullptr; }
- operator bool()
- {
- return kind == ConversionMethodKind::Func ? func != nullptr : op != kIROp_Nop;
- }
- ConversionMethod& operator=(IRFunc* f)
- {
- kind = ConversionMethodKind::Func;
- this->func = f;
- return *this;
- }
- ConversionMethod& operator=(IROp irop)
- {
- kind = ConversionMethodKind::Opcode;
- this->op = irop;
- return *this;
- }
- IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr)
- {
- if (!*this)
- return builder.emitLoad(operandAddr);
- if (kind == ConversionMethodKind::Func)
- return builder.emitCallInst(resultType, func, 1, &operandAddr);
- else
- {
- auto val = builder.emitLoad(operandAddr);
- return builder.emitIntrinsicInst(resultType, op, 1, &val);
- }
- }
- void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand)
- {
- if (!*this)
- {
- builder.emitStore(dest, operand);
- return;
- }
- if (kind == ConversionMethodKind::Func)
- {
- IRInst* operands[] = {dest, operand};
- builder.emitCallInst(builder.getVoidType(), func, 2, operands);
- }
- else
- {
- auto val = builder.emitIntrinsicInst(
- tryGetPointedToOrBufferElementType(&builder, dest->getDataType()),
- op,
- 1,
- &operand);
- builder.emitStore(dest, val);
- }
- }
- };
+ IRInst* operands[] = {dest, operand};
+ builder.emitCallInst(builder.getVoidType(), func, 2, operands);
+ }
+ else
+ {
+ auto val = builder.emitIntrinsicInst(
+ tryGetPointedToOrBufferElementType(&builder, dest->getDataType()),
+ op,
+ 1,
+ &operand);
+ builder.emitStore(dest, val);
+ }
+}
+
+// Returns the number of elements N that ensures the IRVectorType(elementType,N)
+// has 16-byte aligned size and N is no less than `minCount`.
+IRIntegerValue get16ByteAlignedVectorElementCount(
+ TargetProgram* target,
+ IRType* elementType,
+ IRIntegerValue minCount)
+{
+ IRSizeAndAlignment sizeAlignment;
+ getNaturalSizeAndAlignment(target->getOptionSet(), elementType, &sizeAlignment);
+ if (sizeAlignment.size)
+ return align(sizeAlignment.size * minCount, 16) / sizeAlignment.size;
+ return 4;
+}
- struct LoweredElementTypeInfo
+const char* getLayoutName(IRTypeLayoutRuleName name)
+{
+ switch (name)
{
- IRType* originalType;
- IRType* loweredType;
- IRType* loweredInnerArrayType =
- nullptr; // For matrix/array types that are lowered into a struct type, this is the
- // inner array type of the data field.
- IRStructKey* loweredInnerStructKey =
- nullptr; // For matrix/array types that are lowered into a struct type, this is the
- // struct key of the data field.
- ConversionMethod convertOriginalToLowered;
- ConversionMethod convertLoweredToOriginal;
- };
+ case IRTypeLayoutRuleName::Std140:
+ return "std140";
+ case IRTypeLayoutRuleName::Std430:
+ return "std430";
+ case IRTypeLayoutRuleName::Natural:
+ return "natural";
+ case IRTypeLayoutRuleName::C:
+ return "c";
+ default:
+ return "default";
+ }
+}
+
+struct LoweredElementTypeContext
+{
+ static const IRIntegerValue kMaxArraySizeToUnroll = 32;
struct LoweredTypeMap : RefObject
{
@@ -320,6 +389,7 @@ struct LoweredElementTypeContext
};
Dictionary<TypeLoweringConfig, RefPtr<LoweredTypeMap>> loweredTypeInfoMaps;
+ RefPtr<BufferElementTypeLoweringPolicy> leafTypeLoweringPolicy;
struct ConversionMethodKey
{
@@ -366,150 +436,11 @@ struct LoweredElementTypeContext
// Specialized functions that takes storage-typed pointers instead of logical-typed pointers.
Dictionary<SpecializationKey, IRFunc*> specializedFuncs;
- LoweredElementTypeContext(
- TargetProgram* target,
- BufferElementTypeLoweringOptions inOptions,
- SlangMatrixLayoutMode inDefaultMatrixLayout)
- : target(target), defaultMatrixLayout(inDefaultMatrixLayout), options(inOptions)
- {
- }
-
- IRFunc* createMatrixUnpackFunc(
- IRMatrixType* matrixType,
- IRStructType* structType,
- IRStructKey* dataKey)
- {
- IRBuilder builder(structType);
- builder.setInsertAfter(structType);
- auto func = builder.createFunc();
- auto refStructType = builder.getRefParamType(structType, AddressSpace::Generic);
- auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType);
- func->setFullType(funcType);
- builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage"));
- builder.addForceInlineDecoration(func);
- builder.setInsertInto(func);
- builder.emitBlock();
- auto rowCount = (Index)getIntVal(matrixType->getRowCount());
- auto colCount = (Index)getIntVal(matrixType->getColumnCount());
- auto packedParamRef = builder.emitParam(refStructType);
- auto packedParam = builder.emitLoad(packedParamRef);
- auto vectorArray = builder.emitFieldExtract(packedParam, dataKey);
- List<IRInst*> args;
- args.setCount(rowCount * colCount);
- if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
- {
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- auto vector = builder.emitElementExtract(vectorArray, c);
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- auto element = builder.emitElementExtract(vector, r);
- args[(Index)(r * colCount + c)] = element;
- }
- }
- }
- else
- {
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- auto vector = builder.emitElementExtract(vectorArray, r);
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- auto element = builder.emitElementExtract(vector, c);
- args[(Index)(r * colCount + c)] = element;
- }
- }
- }
- IRInst* result =
- builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer());
- builder.emitReturn(result);
- return func;
- }
-
- IRFunc* createMatrixPackFunc(
- IRMatrixType* matrixType,
- IRStructType* structType,
- IRVectorType* vectorType,
- IRArrayType* arrayType)
+ LoweredElementTypeContext(TargetProgram* target, BufferElementTypeLoweringOptions inOptions)
+ : target(target), options(inOptions)
{
- IRBuilder builder(structType);
- builder.setInsertAfter(structType);
- auto func = builder.createFunc();
- auto outStructType = builder.getRefParamType(structType, AddressSpace::Generic);
- IRType* paramTypes[] = {outStructType, matrixType};
- auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType());
- func->setFullType(funcType);
- builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix"));
- builder.addForceInlineDecoration(func);
- builder.setInsertInto(func);
- builder.emitBlock();
- auto rowCount = getIntVal(matrixType->getRowCount());
- auto colCount = getIntVal(matrixType->getColumnCount());
- auto outParam = builder.emitParam(outStructType);
- auto originalParam = builder.emitParam(matrixType);
- List<IRInst*> elements;
- elements.setCount((Index)(rowCount * colCount));
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- auto vector = builder.emitElementExtract(originalParam, r);
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- auto element = builder.emitElementExtract(vector, c);
- elements[(Index)(r * colCount + c)] = element;
- }
- }
- List<IRInst*> vectors;
- if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
- {
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- List<IRInst*> vecArgs;
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- auto element = elements[(Index)(r * colCount + c)];
- vecArgs.add(element);
- }
- // Fill in default values for remaining elements in the vector.
- for (IRIntegerValue r = rowCount; r < getIntVal(vectorType->getElementCount()); r++)
- {
- vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
- }
- auto colVector = builder.emitMakeVector(
- vectorType,
- (UInt)vecArgs.getCount(),
- vecArgs.getBuffer());
- vectors.add(colVector);
- }
- }
- else
- {
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- List<IRInst*> vecArgs;
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- auto element = elements[(Index)(r * colCount + c)];
- vecArgs.add(element);
- }
- // Fill in default values for remaining elements in the vector.
- for (IRIntegerValue c = colCount; c < getIntVal(vectorType->getElementCount()); c++)
- {
- vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
- }
- auto rowVector = builder.emitMakeVector(
- vectorType,
- (UInt)vecArgs.getCount(),
- vecArgs.getBuffer());
- vectors.add(rowVector);
- }
- }
-
- auto vectorArray =
- builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer());
- auto result = builder.emitMakeStruct(structType, 1, &vectorArray);
- builder.emitStore(outParam, result);
- builder.emitReturn();
- return func;
+ leafTypeLoweringPolicy =
+ getBufferElementTypeLoweringPolicy(options.loweringPolicyKind, target, options);
}
IRFunc* createArrayUnpackFunc(
@@ -637,56 +568,6 @@ struct LoweredElementTypeContext
return func;
}
- const char* getLayoutName(IRTypeLayoutRuleName name)
- {
- switch (name)
- {
- case IRTypeLayoutRuleName::Std140:
- return "std140";
- case IRTypeLayoutRuleName::Std430:
- return "std430";
- case IRTypeLayoutRuleName::Natural:
- return "natural";
- case IRTypeLayoutRuleName::C:
- return "c";
- default:
- return "default";
- }
- }
-
- // Returns the number of elements N that ensures the IRVectorType(elementType,N)
- // has 16-byte aligned size and N is no less than `minCount`.
- IRIntegerValue get16ByteAlignedVectorElementCount(IRType* elementType, IRIntegerValue minCount)
- {
- IRSizeAndAlignment sizeAlignment;
- getNaturalSizeAndAlignment(target->getOptionSet(), elementType, &sizeAlignment);
- if (sizeAlignment.size)
- return align(sizeAlignment.size * minCount, 16) / sizeAlignment.size;
- return 4;
- }
-
- bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config)
- {
- // For spirv, we always want to lower all matrix types, because SPIRV does not support
- // specifying matrix layout/stride if the matrix type is used in places other than
- // defining a struct field. This means that if a matrix is used to define a varying
- // parameter, we always want to wrap it in a struct.
- //
- if (target->shouldEmitSPIRVDirectly())
- {
- return true;
- }
-
- if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout &&
- config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural)
- {
- // For other targets, we only lower the matrix types if they differ from the default
- // matrix layout.
- return false;
- }
- return true;
- }
-
LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, TypeLoweringConfig config)
{
IRBuilder builder(type);
@@ -694,72 +575,13 @@ struct LoweredElementTypeContext
LoweredElementTypeInfo info;
info.originalType = type;
-
- if (auto matrixType = as<IRMatrixType>(type))
- {
- if (!shouldLowerMatrixType(matrixType, config))
- {
- info.loweredType = type;
- return info;
- }
-
- auto loweredType = builder.createStructType();
- builder.addPhysicalTypeDecoration(loweredType);
-
- StringBuilder nameSB;
- bool isColMajor =
- getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
- nameSB << "_MatrixStorage_";
- getTypeNameHint(nameSB, matrixType->getElementType());
- nameSB << getIntVal(matrixType->getRowCount()) << "x"
- << getIntVal(matrixType->getColumnCount());
- if (isColMajor)
- nameSB << "_ColMajor";
- nameSB << getLayoutName(config.layoutRule->ruleName);
- builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
- auto structKey = builder.createStructKey();
- builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
- auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount();
- if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 &&
- options.use16ByteArrayElementForConstantBuffer)
- {
- // For constant buffer layout, we need to use 16-byte aligned vector if
- // we are required to ensure array element types has 16-byte stride.
- vectorSize = builder.getIntValue(get16ByteAlignedVectorElementCount(
- matrixType->getElementType(),
- getIntVal(vectorSize)));
- }
-
- auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize);
- IRSizeAndAlignment elementSizeAlignment;
- getSizeAndAlignment(
- target->getOptionSet(),
- config.layoutRule,
- vectorType,
- &elementSizeAlignment);
- elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment);
-
- auto arrayType = builder.getArrayType(
- vectorType,
- isColMajor ? matrixType->getColumnCount() : matrixType->getRowCount(),
- builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride()));
- builder.createStructField(loweredType, structKey, arrayType);
-
- info.loweredType = loweredType;
- info.loweredInnerArrayType = arrayType;
- info.loweredInnerStructKey = structKey;
- info.convertLoweredToOriginal =
- createMatrixUnpackFunc(matrixType, loweredType, structKey);
- info.convertOriginalToLowered =
- createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType);
- return info;
- }
- else if (auto arrayTypeBase = as<IRArrayTypeBase>(type))
+ if (auto arrayTypeBase = as<IRArrayTypeBase>(type))
{
auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayTypeBase->getElementType(), config);
- if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 &&
- options.use16ByteArrayElementForConstantBuffer)
+ if (config.layoutRuleName == IRTypeLayoutRuleName::Std140 &&
+ leafTypeLoweringPolicy
+ ->shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer())
{
// For constant buffer layout, we need to use 16-byte-aligned vector if
// we are required to ensure array element types has 16-byte stride.
@@ -772,6 +594,7 @@ struct LoweredElementTypeContext
packedVectorType = builder.getVectorType(
vectorType->getElementType(),
builder.getIntValue(get16ByteAlignedVectorElementCount(
+ target,
vectorType->getElementType(),
getIntVal(vectorType->getElementCount()))));
if (packedVectorType != loweredInnerTypeInfo.originalType)
@@ -784,7 +607,7 @@ struct LoweredElementTypeContext
{
packedVectorType = builder.getVectorType(
loweredInnerTypeInfo.loweredType,
- get16ByteAlignedVectorElementCount(scalarType, 1));
+ get16ByteAlignedVectorElementCount(target, scalarType, 1));
loweredInnerTypeInfo.convertLoweredToOriginal = kIROp_VectorReshape;
loweredInnerTypeInfo.convertOriginalToLowered = kIROp_MakeVectorFromScalar;
}
@@ -803,10 +626,10 @@ struct LoweredElementTypeContext
}
}
- // For spirv backend, we always want to lower all array types for non-varying
- // parameters, even if the element type comes out the same. This is because different
- // layout rules may have different array stride requirements.
- if (!target->shouldEmitSPIRVDirectly() || config.addressSpace == AddressSpace::Input)
+ // We can skip lowering this type if all field types are unchanged, unless the target
+ // specific policy tells us to always create a lowered type.
+ if (!leafTypeLoweringPolicy->shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(
+ config))
{
if (!loweredInnerTypeInfo.convertLoweredToOriginal)
{
@@ -823,7 +646,7 @@ struct LoweredElementTypeContext
info.loweredType = loweredType;
StringBuilder nameSB;
- nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_";
+ nameSB << "_Array_" << getLayoutName(config.layoutRuleName) << "_";
getTypeNameHint(nameSB, arrayType->getElementType());
nameSB << getArraySizeVal(arrayType->getElementCount());
@@ -835,11 +658,11 @@ struct LoweredElementTypeContext
IRSizeAndAlignment elementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
loweredInnerTypeInfo.loweredType,
&elementSizeAlignment);
elementSizeAlignment =
- config.layoutRule->alignCompositeElement(elementSizeAlignment);
+ config.getLayoutRule()->alignCompositeElement(elementSizeAlignment);
auto innerArrayType = builder.getArrayType(
loweredInnerTypeInfo.loweredType,
arrayType->getElementCount(),
@@ -857,11 +680,11 @@ struct LoweredElementTypeContext
IRSizeAndAlignment elementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
loweredInnerTypeInfo.loweredType,
&elementSizeAlignment);
elementSizeAlignment =
- config.layoutRule->alignCompositeElement(elementSizeAlignment);
+ config.getLayoutRule()->alignCompositeElement(elementSizeAlignment);
auto innerArrayType = builder.getArrayTypeBase(
arrayTypeBase->getOp(),
loweredInnerTypeInfo.loweredType,
@@ -880,20 +703,15 @@ struct LoweredElementTypeContext
auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), config);
fieldLoweredTypeInfo.add(loweredFieldTypeInfo);
if (loweredFieldTypeInfo.convertLoweredToOriginal ||
- config.layoutRule->ruleName != IRTypeLayoutRuleName::Natural)
+ config.layoutRuleName != IRTypeLayoutRuleName::Natural)
isTrivial = false;
}
- // For spirv backend, we always want to lower all array types, even if the element type
- // comes out the same. This is because different layout rules may have different array
- // stride requirements.
- //
- // Additionally, `buffer` blocks do not work correctly unless lowered when targeting
- // GLSL.
- if (!isKhronosTarget(target->getTargetReq()))
+ // We can skip lowering this type if all field types are unchanged, unless the target
+ // specific policy tells us to always create a lowered type.
+ if (!leafTypeLoweringPolicy->shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(
+ config))
{
- // For non-spirv target, we skip lowering this type if all field types are
- // unchanged.
if (isTrivial)
{
info.loweredType = type;
@@ -905,7 +723,7 @@ struct LoweredElementTypeContext
StringBuilder nameSB;
getTypeNameHint(nameSB, type);
- nameSB << "_" << getLayoutName(config.layoutRule->ruleName);
+ nameSB << "_" << getLayoutName(config.layoutRuleName);
builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
info.loweredType = loweredType;
// Create fields.
@@ -1007,58 +825,7 @@ struct LoweredElementTypeContext
return info;
}
-
- if (target->shouldEmitSPIRVDirectly())
- {
- switch (target->getTargetReq()->getTarget())
- {
- case CodeGenTarget::SPIRV:
- case CodeGenTarget::SPIRVAssembly:
- {
- auto scalarType = type;
- auto vectorType = as<IRVectorType>(scalarType);
- if (vectorType)
- scalarType = vectorType->getElementType();
-
- if (as<IRBoolType>(scalarType))
- {
- // Bool is an abstract type in SPIRV, so we need to lower them into an int.
-
- // Find an integer type of the correct size for the current layout rule.
- IRSizeAndAlignment boolSizeAndAlignment;
- if (getSizeAndAlignment(
- target->getOptionSet(),
- config.layoutRule,
- scalarType,
- &boolSizeAndAlignment) == SLANG_OK)
- {
- IntInfo ii;
- ii.width = boolSizeAndAlignment.size * 8;
- ii.isSigned = true;
- info.loweredType = builder.getType(getIntTypeOpFromInfo(ii));
- }
- else
- {
- // Just in case that fails for some reason, just use an int.
- info.loweredType = builder.getIntType();
- }
-
- if (vectorType)
- info.loweredType = builder.getVectorType(
- info.loweredType,
- vectorType->getElementCount());
- info.convertLoweredToOriginal = kIROp_BuiltinCast;
- info.convertOriginalToLowered = kIROp_BuiltinCast;
- return info;
- }
- }
- default:
- break;
- }
- }
-
- info.loweredType = type;
- return info;
+ return leafTypeLoweringPolicy->lowerLeafLogicalType(type, config);
}
LoweredTypeMap& getTypeLoweringMap(TypeLoweringConfig config)
@@ -1090,7 +857,7 @@ struct LoweredElementTypeContext
IRSizeAndAlignment sizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
info.loweredType,
&sizeAlignment);
loweredTypeInfo.set(type, info);
@@ -1190,13 +957,13 @@ struct LoweredElementTypeContext
IRSizeAndAlignment arrayElementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
loweredInnerType.loweredType,
&arrayElementSizeAlignment);
IRSizeAndAlignment baseSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
tryGetPointedToOrBufferElementType(&builder, fieldAddr->getBase()->getDataType()),
&baseSizeAlignment);
@@ -1701,9 +1468,6 @@ struct LoweredElementTypeContext
switch (ptrType->getAddressSpace())
{
case AddressSpace::UserPointer:
- if (!options.lowerBufferPointer)
- continue;
- [[fallthrough]];
case AddressSpace::Input:
case AddressSpace::Output:
elementType = ptrType->getValueType();
@@ -1720,7 +1484,7 @@ struct LoweredElementTypeContext
IRSizeAndAlignment sizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
elementType,
&sizeAlignment);
SLANG_UNUSED(sizeAlignment);
@@ -2181,63 +1945,60 @@ void lowerBufferElementTypeToStorageType(
IRModule* module,
BufferElementTypeLoweringOptions options)
{
- SlangMatrixLayoutMode defaultMatrixMode =
- (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode();
- if ((isCPUTarget(target->getTargetReq()) || isCUDATarget(target->getTargetReq()) ||
- isMetalTarget(target->getTargetReq())))
- defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
- else if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)
- defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
- LoweredElementTypeContext context(target, options, defaultMatrixMode);
+ LoweredElementTypeContext context(target, options);
context.processModule(module);
}
-IRTypeLayoutRules* getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRules* defaultLayout)
+IRTypeLayoutRuleName getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRuleName defaultLayout)
{
switch (layoutTypeOp)
{
case kIROp_DefaultBufferLayoutType:
return defaultLayout;
case kIROp_Std140BufferLayoutType:
- return IRTypeLayoutRules::getStd140();
+ return IRTypeLayoutRuleName::Std140;
case kIROp_Std430BufferLayoutType:
- return IRTypeLayoutRules::getStd430();
+ return IRTypeLayoutRuleName::Std430;
case kIROp_ScalarBufferLayoutType:
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
case kIROp_CBufferLayoutType:
- return IRTypeLayoutRules::getC();
+ return IRTypeLayoutRuleName::C;
}
return defaultLayout;
}
-IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType)
+IRTypeLayoutRuleName getTypeLayoutRuleNameForBuffer(TargetProgram* target, IRType* bufferType)
{
+ if (bufferType->getOp() == kIROp_ParameterBlockType && isMetalTarget(target->getTargetReq()))
+ {
+ return IRTypeLayoutRuleName::MetalParameterBlock;
+ }
if (target->getTargetReq()->getTarget() != CodeGenTarget::WGSL)
{
if (!isKhronosTarget(target->getTargetReq()))
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
// If we are just emitting GLSL, we can just use the general layout rule.
if (!target->shouldEmitSPIRVDirectly())
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
// If the user specified a C-compatible buffer layout, then do that.
if (target->getOptionSet().shouldUseCLayout())
- return IRTypeLayoutRules::getC();
+ return IRTypeLayoutRuleName::C;
// If the user specified a scalar buffer layout, then just use that.
if (target->getOptionSet().shouldUseScalarLayout())
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
}
if (target->getOptionSet().shouldUseDXLayout())
{
if (as<IRUniformParameterGroupType>(bufferType))
{
- return IRTypeLayoutRules::getConstantBuffer();
+ return IRTypeLayoutRuleName::D3DConstantBuffer;
}
else
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
}
// The default behavior is to use std140 for constant buffers and std430 for other buffers.
@@ -2253,17 +2014,17 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf
auto layoutTypeOp = structBufferType->getDataLayout()
? structBufferType->getDataLayout()->getOp()
: kIROp_DefaultBufferLayoutType;
- return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430());
+ return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std430);
}
- case kIROp_ConstantBufferType:
case kIROp_ParameterBlockType:
+ case kIROp_ConstantBufferType:
{
auto parameterGroupType = as<IRUniformParameterGroupType>(bufferType);
auto layoutTypeOp = parameterGroupType->getDataLayout()
? parameterGroupType->getDataLayout()->getOp()
: kIROp_DefaultBufferLayoutType;
- return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd140());
+ return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std140);
}
case kIROp_GLSLShaderStorageBufferType:
{
@@ -2271,12 +2032,18 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf
auto layoutTypeOp = storageBufferType->getDataLayout()
? storageBufferType->getDataLayout()->getOp()
: kIROp_Std430BufferLayoutType;
- return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430());
+ return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std430);
}
case kIROp_PtrType:
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
}
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
+}
+
+IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType)
+{
+ auto ruleName = getTypeLayoutRuleNameForBuffer(target, bufferType);
+ return IRTypeLayoutRules::get(ruleName);
}
TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType)
@@ -2295,8 +2062,413 @@ TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType*
break;
}
}
- auto rules = getTypeLayoutRuleForBuffer(target, bufferType);
+ auto rules = getTypeLayoutRuleNameForBuffer(target, bufferType);
return TypeLoweringConfig{addrSpace, rules};
}
+struct DefaultBufferElementTypeLoweringPolicy : BufferElementTypeLoweringPolicy
+{
+ TargetProgram* target;
+ BufferElementTypeLoweringOptions options;
+ SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+
+ DefaultBufferElementTypeLoweringPolicy(
+ TargetProgram* inTarget,
+ BufferElementTypeLoweringOptions inOptions)
+ : target(inTarget), options(inOptions)
+ {
+ defaultMatrixLayout = (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode();
+ if ((isCPUTarget(target->getTargetReq()) || isCUDATarget(target->getTargetReq()) ||
+ isMetalTarget(target->getTargetReq())))
+ defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+ else if (defaultMatrixLayout == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)
+ defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+ }
+
+ virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config)
+ {
+ if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout &&
+ config.getLayoutRule()->ruleName == IRTypeLayoutRuleName::Natural)
+ {
+ // We only lower the matrix types if they differ from the default
+ // matrix layout.
+ return false;
+ }
+ return true;
+ }
+
+ IRFunc* createMatrixUnpackFunc(
+ IRMatrixType* matrixType,
+ IRStructType* structType,
+ IRStructKey* dataKey)
+ {
+ IRBuilder builder(structType);
+ builder.setInsertAfter(structType);
+ auto func = builder.createFunc();
+ auto refStructType = builder.getRefParamType(structType, AddressSpace::Generic);
+ auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType);
+ func->setFullType(funcType);
+ builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage"));
+ builder.addForceInlineDecoration(func);
+ builder.setInsertInto(func);
+ builder.emitBlock();
+ auto rowCount = (Index)getIntVal(matrixType->getRowCount());
+ auto colCount = (Index)getIntVal(matrixType->getColumnCount());
+ auto packedParamRef = builder.emitParam(refStructType);
+ auto packedParam = builder.emitLoad(packedParamRef);
+ auto vectorArray = builder.emitFieldExtract(packedParam, dataKey);
+ List<IRInst*> args;
+ args.setCount(rowCount * colCount);
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+ {
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto vector = builder.emitElementExtract(vectorArray, c);
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto element = builder.emitElementExtract(vector, r);
+ args[(Index)(r * colCount + c)] = element;
+ }
+ }
+ }
+ else
+ {
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto vector = builder.emitElementExtract(vectorArray, r);
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto element = builder.emitElementExtract(vector, c);
+ args[(Index)(r * colCount + c)] = element;
+ }
+ }
+ }
+ IRInst* result =
+ builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer());
+ builder.emitReturn(result);
+ return func;
+ }
+
+ IRFunc* createMatrixPackFunc(
+ IRMatrixType* matrixType,
+ IRStructType* structType,
+ IRVectorType* vectorType,
+ IRArrayType* arrayType)
+ {
+ IRBuilder builder(structType);
+ builder.setInsertAfter(structType);
+ auto func = builder.createFunc();
+ auto outStructType = builder.getRefParamType(structType, AddressSpace::Generic);
+ IRType* paramTypes[] = {outStructType, matrixType};
+ auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType());
+ func->setFullType(funcType);
+ builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix"));
+ builder.addForceInlineDecoration(func);
+ builder.setInsertInto(func);
+ builder.emitBlock();
+ auto rowCount = getIntVal(matrixType->getRowCount());
+ auto colCount = getIntVal(matrixType->getColumnCount());
+ auto outParam = builder.emitParam(outStructType);
+ auto originalParam = builder.emitParam(matrixType);
+ List<IRInst*> elements;
+ elements.setCount((Index)(rowCount * colCount));
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto vector = builder.emitElementExtract(originalParam, r);
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto element = builder.emitElementExtract(vector, c);
+ elements[(Index)(r * colCount + c)] = element;
+ }
+ }
+ List<IRInst*> vectors;
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+ {
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ List<IRInst*> vecArgs;
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto element = elements[(Index)(r * colCount + c)];
+ vecArgs.add(element);
+ }
+ // Fill in default values for remaining elements in the vector.
+ for (IRIntegerValue r = rowCount; r < getIntVal(vectorType->getElementCount()); r++)
+ {
+ vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
+ }
+ auto colVector = builder.emitMakeVector(
+ vectorType,
+ (UInt)vecArgs.getCount(),
+ vecArgs.getBuffer());
+ vectors.add(colVector);
+ }
+ }
+ else
+ {
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ List<IRInst*> vecArgs;
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto element = elements[(Index)(r * colCount + c)];
+ vecArgs.add(element);
+ }
+ // Fill in default values for remaining elements in the vector.
+ for (IRIntegerValue c = colCount; c < getIntVal(vectorType->getElementCount()); c++)
+ {
+ vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
+ }
+ auto rowVector = builder.emitMakeVector(
+ vectorType,
+ (UInt)vecArgs.getCount(),
+ vecArgs.getBuffer());
+ vectors.add(rowVector);
+ }
+ }
+
+ auto vectorArray =
+ builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer());
+ auto result = builder.emitMakeStruct(structType, 1, &vectorArray);
+ builder.emitStore(outParam, result);
+ builder.emitReturn();
+ return func;
+ }
+
+ LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override
+ {
+ IRBuilder builder(type);
+ builder.setInsertAfter(type);
+
+ LoweredElementTypeInfo info;
+ info.originalType = type;
+
+ if (auto matrixType = as<IRMatrixType>(type))
+ {
+ if (!shouldLowerMatrixType(matrixType, config))
+ {
+ info.loweredType = type;
+ return info;
+ }
+
+ auto loweredType = builder.createStructType();
+ builder.addPhysicalTypeDecoration(loweredType);
+
+ StringBuilder nameSB;
+ bool isColMajor =
+ getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
+ nameSB << "_MatrixStorage_";
+ getTypeNameHint(nameSB, matrixType->getElementType());
+ nameSB << getIntVal(matrixType->getRowCount()) << "x"
+ << getIntVal(matrixType->getColumnCount());
+ if (isColMajor)
+ nameSB << "_ColMajor";
+ nameSB << getLayoutName(config.layoutRuleName);
+ builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
+ auto structKey = builder.createStructKey();
+ builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
+ auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount();
+ if (config.layoutRuleName == IRTypeLayoutRuleName::Std140 &&
+ shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer())
+ {
+ // For constant buffer layout, we need to use 16-byte aligned vector if
+ // we are required to ensure array element types has 16-byte stride.
+ vectorSize = builder.getIntValue(get16ByteAlignedVectorElementCount(
+ target,
+ matrixType->getElementType(),
+ getIntVal(vectorSize)));
+ }
+
+ auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize);
+ IRSizeAndAlignment elementSizeAlignment;
+ getSizeAndAlignment(
+ target->getOptionSet(),
+ config.getLayoutRule(),
+ vectorType,
+ &elementSizeAlignment);
+ elementSizeAlignment =
+ config.getLayoutRule()->alignCompositeElement(elementSizeAlignment);
+
+ auto arrayType = builder.getArrayType(
+ vectorType,
+ isColMajor ? matrixType->getColumnCount() : matrixType->getRowCount(),
+ builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride()));
+ builder.createStructField(loweredType, structKey, arrayType);
+
+ info.loweredType = loweredType;
+ info.loweredInnerArrayType = arrayType;
+ info.loweredInnerStructKey = structKey;
+ info.convertLoweredToOriginal =
+ createMatrixUnpackFunc(matrixType, loweredType, structKey);
+ info.convertOriginalToLowered =
+ createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType);
+ return info;
+ }
+
+ info.loweredType = type;
+ return info;
+ }
+};
+
+struct KhronosTargetBufferElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy
+{
+ KhronosTargetBufferElementTypeLoweringPolicy(
+ TargetProgram* inTarget,
+ BufferElementTypeLoweringOptions inOptions)
+ : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions)
+ {
+ }
+
+ virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) override
+ {
+ // For spirv, we always want to lower all matrix types, because SPIRV does not support
+ // specifying matrix layout/stride if the matrix type is used in places other than
+ // defining a struct field. This means that if a matrix is used to define a varying
+ // parameter, we always want to wrap it in a struct.
+ //
+ if (target->shouldEmitSPIRVDirectly())
+ return true;
+ return DefaultBufferElementTypeLoweringPolicy::shouldLowerMatrixType(matrixType, config);
+ }
+
+ virtual bool shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(
+ TypeLoweringConfig config) override
+ {
+ // For spirv backend, we always want to lower all array types, even if the element type
+ // comes out the same. This is because different layout rules may have different array
+ // stride requirements.
+ //
+ // Additionally, `buffer` blocks do not work correctly unless lowered when targeting
+ // GLSL.
+ return target->shouldEmitSPIRVDirectly() && config.addressSpace != AddressSpace::Input;
+ }
+
+ LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override
+ {
+ if (target->shouldEmitSPIRVDirectly())
+ {
+ LoweredElementTypeInfo info = {};
+ info.originalType = type;
+
+ switch (target->getTargetReq()->getTarget())
+ {
+ case CodeGenTarget::SPIRV:
+ case CodeGenTarget::SPIRVAssembly:
+ {
+ auto scalarType = type;
+ auto vectorType = as<IRVectorType>(scalarType);
+ if (vectorType)
+ scalarType = vectorType->getElementType();
+ IRBuilder builder(type);
+ builder.setInsertBefore(type);
+
+ if (as<IRBoolType>(scalarType))
+ {
+ // Bool is an abstract type in SPIRV, so we need to lower them into an int.
+
+ // Find an integer type of the correct size for the current layout rule.
+ IRSizeAndAlignment boolSizeAndAlignment;
+ if (getSizeAndAlignment(
+ target->getOptionSet(),
+ config.getLayoutRule(),
+ scalarType,
+ &boolSizeAndAlignment) == SLANG_OK)
+ {
+ IntInfo ii;
+ ii.width = boolSizeAndAlignment.size * 8;
+ ii.isSigned = true;
+ info.loweredType = builder.getType(getIntTypeOpFromInfo(ii));
+ }
+ else
+ {
+ // Just in case that fails for some reason, just use an int.
+ info.loweredType = builder.getIntType();
+ }
+
+ if (vectorType)
+ info.loweredType = builder.getVectorType(
+ info.loweredType,
+ vectorType->getElementCount());
+ info.convertLoweredToOriginal = kIROp_BuiltinCast;
+ info.convertOriginalToLowered = kIROp_BuiltinCast;
+ return info;
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ return DefaultBufferElementTypeLoweringPolicy::lowerLeafLogicalType(type, config);
+ }
+};
+
+struct MetalParameterBlockElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy
+{
+ MetalParameterBlockElementTypeLoweringPolicy(
+ TargetProgram* inTarget,
+ BufferElementTypeLoweringOptions inOptions)
+ : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions)
+ {
+ }
+
+ virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) override
+ {
+ SLANG_UNUSED(matrixType);
+ SLANG_UNUSED(config);
+ return false;
+ }
+
+ LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override
+ {
+ if (config.layoutRuleName == IRTypeLayoutRuleName::MetalParameterBlock &&
+ isResourceType(type))
+ {
+ IRBuilder builder(type);
+ builder.setInsertBefore(type);
+ LoweredElementTypeInfo info = {};
+ info.originalType = type;
+ info.loweredType = builder.getType(kIROp_DescriptorHandleType, type);
+ info.convertLoweredToOriginal = kIROp_CastDescriptorHandleToResource;
+ info.convertOriginalToLowered = kIROp_CastResourceToDescriptorHandle;
+ return info;
+ }
+ return DefaultBufferElementTypeLoweringPolicy::lowerLeafLogicalType(type, config);
+ }
+};
+
+struct WGSLBufferElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy
+{
+ WGSLBufferElementTypeLoweringPolicy(
+ TargetProgram* inTarget,
+ BufferElementTypeLoweringOptions inOptions)
+ : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions)
+ {
+ }
+
+ virtual bool shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer() override
+ {
+ return true;
+ }
+};
+
+BufferElementTypeLoweringPolicy* getBufferElementTypeLoweringPolicy(
+ BufferElementTypeLoweringPolicyKind kind,
+ TargetProgram* target,
+ BufferElementTypeLoweringOptions options)
+{
+ switch (kind)
+ {
+ case BufferElementTypeLoweringPolicyKind::Default:
+ return new DefaultBufferElementTypeLoweringPolicy(target, options);
+ case BufferElementTypeLoweringPolicyKind::KhronosTarget:
+ return new KhronosTargetBufferElementTypeLoweringPolicy(target, options);
+ case BufferElementTypeLoweringPolicyKind::MetalParameterBlock:
+ return new MetalParameterBlockElementTypeLoweringPolicy(target, options);
+ case BufferElementTypeLoweringPolicyKind::WGSL:
+ return new WGSLBufferElementTypeLoweringPolicy(target, options);
+ }
+ SLANG_UNREACHABLE("unknown buffer element type lowering policy");
+}
+
} // namespace Slang