summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-c-like.cpp1
-rw-r--r--source/slang/slang-emit.cpp55
-rw-r--r--source/slang/slang-ir-insts-stable-names.lua1
-rw-r--r--source/slang/slang-ir-insts.lua1
-rw-r--r--source/slang/slang-ir-layout.cpp1
-rw-r--r--source/slang/slang-ir-legalize-types.cpp128
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp1058
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.h19
-rw-r--r--source/slang/slang-ir-wrap-cbuffer-element.cpp133
-rw-r--r--source/slang/slang-ir-wrap-cbuffer-element.h23
-rw-r--r--source/slang/slang-ir.cpp1
-rw-r--r--source/slang/slang-ir.h1
12 files changed, 834 insertions, 588 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 77c45a6d9..d69485cce 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -2543,6 +2543,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
case kIROp_CastDescriptorHandleToUInt2:
case kIROp_CastUInt2ToDescriptorHandle:
case kIROp_CastDescriptorHandleToResource:
+ case kIROp_CastResourceToDescriptorHandle:
case kIROp_CastUInt64ToDescriptorHandle:
case kIROp_CastDescriptorHandleToUInt64:
emitOperand(inst->getOperand(0), outerPrec);
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 09c2efea9..804a44b81 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -120,6 +120,7 @@
#include "slang-ir-variable-scope-correction.h"
#include "slang-ir-vk-invert-y.h"
#include "slang-ir-wgsl-legalize.h"
+#include "slang-ir-wrap-cbuffer-element.h"
#include "slang-ir-wrap-structured-buffers.h"
#include "slang-legalize-types.h"
#include "slang-lower-to-ir.h"
@@ -1276,6 +1277,23 @@ Result linkAndOptimizeIR(
// We don't need the legalize pass for C/C++ based types
if (options.shouldLegalizeExistentialAndResourceTypes)
{
+ if (isMetalTarget(targetRequest))
+ {
+ // Metal is a special target in that we want to legalize constant buffer
+ // types as if it is a bindful target, and skip legalizing parameter block
+ // types as if it is a bindless target.
+ // To achieve this, we want to ensure that all resource typed fields in parameter blocks
+ // are translated into descriptor handles first before running the resource type
+ // legalization pass for metal, so that type legalization pass won't mess around with
+ // them.
+ BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions = {};
+ bufferElementTypeLoweringOptions.loweringPolicyKind =
+ BufferElementTypeLoweringPolicyKind::MetalParameterBlock;
+ lowerBufferElementTypeToStorageType(
+ targetProgram,
+ irModule,
+ bufferElementTypeLoweringOptions);
+ }
// The Slang language allows interfaces to be used like
// ordinary types (including placing them in constant
@@ -1403,17 +1421,17 @@ Result linkAndOptimizeIR(
// Some information for `static_assert` is available only after the specialization.
checkStaticAssert(irModule->getModuleInst(), sink);
- // For HLSL (and fxc/dxc) only, we need to "wrap" any
- // structured buffers defined over matrix types so
- // that they instead use an intermediate `struct`.
- // This is required to get those targets to respect
- // the options for matrix layout set via `#pragma`
- // or command-line options.
- //
switch (target)
{
case CodeGenTarget::HLSL:
{
+ // For HLSL(fxc) only, we need to "wrap" any
+ // structured buffers defined over matrix types so
+ // that they instead use an intermediate `struct`.
+ // This is required to get those targets to respect
+ // the options for matrix layout set via `#pragma`
+ // or command-line options.
+ //
wrapStructuredBuffersOfMatrices(irModule);
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "STRUCTURED BUFFERS WRAPPED");
@@ -1421,7 +1439,12 @@ Result linkAndOptimizeIR(
validateIRModuleIfEnabled(codeGenContext, irModule);
}
break;
-
+ case CodeGenTarget::Metal:
+ case CodeGenTarget::MetalLib:
+ // Metal does not allow `ConstantBuffer<StructuredBuffer<T>>`, so we need to create
+ // a wrapper struct for the `StructuredBuffer<T>`.
+ wrapCBufferElementsForMetal(irModule);
+ break;
default:
break;
}
@@ -1828,11 +1851,16 @@ Result linkAndOptimizeIR(
rcpWOfPositionInput(irModule);
}
- bool emitSpirvDirectly = targetProgram->shouldEmitSPIRVDirectly();
-
- BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions;
- bufferElementTypeLoweringOptions.use16ByteArrayElementForConstantBuffer =
- isWGPUTarget(targetRequest);
+ BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions = {};
+ if (isWGPUTarget(targetRequest))
+ bufferElementTypeLoweringOptions.loweringPolicyKind =
+ BufferElementTypeLoweringPolicyKind::WGSL;
+ else if (isKhronosTarget(targetRequest))
+ bufferElementTypeLoweringOptions.loweringPolicyKind =
+ BufferElementTypeLoweringPolicyKind::KhronosTarget;
+ else
+ bufferElementTypeLoweringOptions.loweringPolicyKind =
+ BufferElementTypeLoweringPolicyKind::Default;
lowerBufferElementTypeToStorageType(targetProgram, irModule, bufferElementTypeLoweringOptions);
// If we are generating code for glsl or metal, perform address space propagation now.
@@ -1850,6 +1878,7 @@ Result linkAndOptimizeIR(
performForceInlining(irModule);
+ bool emitSpirvDirectly = targetProgram->shouldEmitSPIRVDirectly();
if (emitSpirvDirectly)
{
performIntrinsicFunctionInlining(irModule);
diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua
index cfde8c5fa..4b56cc52f 100644
--- a/source/slang/slang-ir-insts-stable-names.lua
+++ b/source/slang/slang-ir-insts-stable-names.lua
@@ -676,4 +676,5 @@ return {
["CastStorageToLogicalBase.CastStorageToLogicalDeref"] = 672,
["Decoration.DisableCopyEliminationDecoration"] = 673,
["Decoration.TempCallArgImmutableVar"] = 674,
+ ["CastResourceToDescriptorHandle"] = 675,
}
diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua
index 045144e06..8b8515424 100644
--- a/source/slang/slang-ir-insts.lua
+++ b/source/slang/slang-ir-insts.lua
@@ -1908,6 +1908,7 @@ local insts = {
-- Represents a no-op cast to convert a resource pointer to a resource on targets where the resource handles are
-- already concrete types.
{ CastDescriptorHandleToResource = { min_operands = 1 } },
+ { CastResourceToDescriptorHandle = { min_operands = 1 } },
{ TreatAsDynamicUniform = { min_operands = 1 } },
{ sizeOf = { min_operands = 1 } },
{ alignOf = { min_operands = 1 } },
diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp
index 123dfdea4..9db4b3097 100644
--- a/source/slang/slang-ir-layout.cpp
+++ b/source/slang/slang-ir-layout.cpp
@@ -769,6 +769,7 @@ IRTypeLayoutRules* IRTypeLayoutRules::get(IRTypeLayoutRuleName name)
case IRTypeLayoutRuleName::Std140:
return getStd140();
case IRTypeLayoutRuleName::Natural:
+ case IRTypeLayoutRuleName::MetalParameterBlock:
return getNatural();
case IRTypeLayoutRuleName::C:
return getC();
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 5168f0466..021566e12 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -1950,133 +1950,6 @@ static LegalVal legalizeDefaultConstruct(IRTypeLegalizationContext* context, Leg
}
}
-// If a legalized `val` has a different flavor than `type`, try to coerce it to `type`.
-//
-static LegalVal coerceToLegalType(IRTypeLegalizationContext* context, LegalType type, LegalVal val)
-{
- switch (type.flavor)
- {
- case LegalType::Flavor::none:
- return LegalVal();
- case LegalType::Flavor::simple:
- {
- if (val.flavor != LegalVal::Flavor::simple)
- return val;
- auto simpleVal = val.getSimple();
- if (simpleVal->getDataType() == type.getSimple())
- return val;
-
- auto resultType = type.getSimple();
- auto structType = as<IRStructType>(resultType);
- if (!structType)
- {
- auto resultValueType = tryGetPointedToType(context->builder, resultType);
- if (!resultValueType)
- return val;
- auto valValueType = tryGetPointedToType(context->builder, simpleVal->getDataType());
- if (!valValueType)
- return val;
- if (resultValueType == valValueType)
- return val;
- auto loadedVal = context->builder->emitLoad(val.getSimple());
- auto innerLegalVal = coerceToLegalType(
- context,
- LegalType::simple(resultValueType),
- LegalVal::simple(loadedVal));
- return LegalVal::implicitDeref(innerLegalVal);
- }
- ShortList<IRInst*> fields;
- for (auto field : structType->getFields())
- {
- if (as<IRVoidType>(field->getFieldType()))
- continue;
- auto fieldVal = coerceToLegalType(
- context,
- LegalType::simple(field->getFieldType()),
- LegalVal::simple(
- context->builder->emitFieldExtract(simpleVal, field->getKey())));
- fields.add(fieldVal.getSimple());
- }
- return LegalVal::simple(context->builder->emitMakeStruct(
- structType,
- (UInt)fields.getCount(),
- fields.getArrayView().getBuffer()));
- }
- case LegalType::Flavor::implicitDeref:
- {
- auto innerVal = val;
- if (innerVal.flavor == LegalVal::Flavor::implicitDeref)
- innerVal = innerVal.getImplicitDeref();
- else if (innerVal.flavor == LegalVal::Flavor::simple)
- innerVal = LegalVal::simple(context->builder->emitLoad(innerVal.getSimple()));
- innerVal = coerceToLegalType(context, type.getImplicitDeref()->valueType, innerVal);
- return LegalVal::implicitDeref(innerVal);
- }
- case LegalType::Flavor::pair:
- {
- if (val.flavor == LegalVal::Flavor::pair)
- return val;
- else if (val.flavor == LegalVal::Flavor::simple)
- {
- auto pairType = type.getPair();
- auto pairInfo = pairType->pairInfo;
- LegalVal ordinaryVal = coerceToLegalType(context, pairType->ordinaryType, val);
- LegalVal specialVal = coerceToLegalType(context, pairType->specialType, val);
- return LegalVal::pair(ordinaryVal, specialVal, pairInfo);
- }
- else if (val.flavor == LegalVal::Flavor::implicitDeref)
- {
- LegalVal innerVal = coerceToLegalType(context, type, val.getImplicitDeref());
- return LegalVal::implicitDeref(innerVal);
- }
- else
- {
- SLANG_UNEXPECTED("unhandled legal type coercion");
- UNREACHABLE_RETURN(LegalVal());
- }
- }
- case LegalType::Flavor::tuple:
- {
- if (val.flavor == LegalVal::Flavor::tuple)
- return val;
- else if (val.flavor == LegalVal::Flavor::simple)
- {
- auto tupleType = type.getTuple();
- RefPtr<TuplePseudoVal> tupleVal = new TuplePseudoVal();
- auto simpleVal = val.getSimple();
- for (auto elem : tupleType->elements)
- {
- IRInst* elementVal = nullptr;
- if (as<IRPtrTypeBase>(simpleVal->getDataType()) ||
- as<IRPointerLikeType>(simpleVal->getDataType()))
- elementVal = context->builder->emitFieldAddress(simpleVal, elem.key);
- else
- elementVal = context->builder->emitFieldExtract(simpleVal, elem.key);
- LegalVal legalElementVal =
- coerceToLegalType(context, elem.type, LegalVal::simple(elementVal));
- TuplePseudoVal::Element tupleElem;
- tupleElem.key = elem.key;
- tupleElem.val = legalElementVal;
- tupleVal->elements.add(tupleElem);
- }
- return LegalVal::tuple(tupleVal);
- }
- else if (val.flavor == LegalVal::Flavor::implicitDeref)
- {
- LegalVal innerVal = coerceToLegalType(context, type, val.getImplicitDeref());
- return LegalVal::implicitDeref(innerVal);
- }
- else
- {
- SLANG_UNEXPECTED("unhandled legal type coercion");
- UNREACHABLE_RETURN(LegalVal());
- }
- }
- default:
- return val;
- }
-}
-
static LegalVal legalizeUndefined(IRTypeLegalizationContext* context, IRInst* inst)
{
IRType* opaqueType = nullptr;
@@ -2188,7 +2061,6 @@ static LegalVal legalizeInst(
SLANG_UNEXPECTED("non-simple operand(s)!");
break;
}
- result = coerceToLegalType(context, type, result);
return result;
}
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 83218bade..c69592939 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -214,104 +214,173 @@
namespace Slang
{
+enum ConversionMethodKind
+{
+ Func,
+ Opcode
+};
+
+struct ConversionMethod
+{
+ ConversionMethodKind kind = ConversionMethodKind::Func;
+ union
+ {
+ IRFunc* func;
+ IROp op;
+ };
+ ConversionMethod() { func = nullptr; }
+ operator bool()
+ {
+ return kind == ConversionMethodKind::Func ? func != nullptr : op != kIROp_Nop;
+ }
+ ConversionMethod& operator=(IRFunc* f)
+ {
+ kind = ConversionMethodKind::Func;
+ this->func = f;
+ return *this;
+ }
+ ConversionMethod& operator=(IROp irop)
+ {
+ kind = ConversionMethodKind::Opcode;
+ this->op = irop;
+ return *this;
+ }
+ IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr);
+ void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand);
+};
+
struct TypeLoweringConfig
{
AddressSpace addressSpace;
- IRTypeLayoutRules* layoutRule;
+ IRTypeLayoutRuleName layoutRuleName;
+ IRTypeLayoutRules* getLayoutRule() const { return IRTypeLayoutRules::get(layoutRuleName); }
+
bool operator==(const TypeLoweringConfig& other) const
{
- return addressSpace == other.addressSpace && layoutRule == other.layoutRule;
+ return addressSpace == other.addressSpace && layoutRuleName == other.layoutRuleName;
}
HashCode getHashCode() const
{
- return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRule));
+ return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRuleName));
}
};
+
+struct LoweredElementTypeInfo
+{
+ IRType* originalType;
+ IRType* loweredType;
+ IRType* loweredInnerArrayType =
+ nullptr; // For matrix/array types that are lowered into a struct type, this is the
+ // inner array type of the data field.
+ IRStructKey* loweredInnerStructKey =
+ nullptr; // For matrix/array types that are lowered into a struct type, this is the
+ // struct key of the data field.
+ ConversionMethod convertOriginalToLowered;
+ ConversionMethod convertLoweredToOriginal;
+};
+
+/// Defines target-specific behavior of how to lower buffer element types.
+struct BufferElementTypeLoweringPolicy : public RefObject
+{
+ /// Defines target-specific behavior of how to translate a non-composite logical type to a
+ /// storage type.
+ virtual LoweredElementTypeInfo lowerLeafLogicalType(
+ IRType* type,
+ TypeLoweringConfig config) = 0;
+
+ /// Returns true if we should always create a fresh lowered storage type for a composite type,
+ /// even if every member/element of the composite type is not changed by the lowering.
+ virtual bool shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(TypeLoweringConfig config)
+ {
+ SLANG_UNUSED(config);
+ return false;
+ }
+
+ /// Returns true if the target requires all array of scalars or vectors inside a constant buffer
+ /// to be translated into a 16-byte aligned vector type.
+ virtual bool shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer()
+ {
+ return false;
+ }
+};
+
+BufferElementTypeLoweringPolicy* getBufferElementTypeLoweringPolicy(
+ BufferElementTypeLoweringPolicyKind kind,
+ TargetProgram* target,
+ BufferElementTypeLoweringOptions options);
+
TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType);
-struct LoweredElementTypeContext
+IRInst* ConversionMethod::apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr)
{
- static const IRIntegerValue kMaxArraySizeToUnroll = 32;
+ if (!*this)
+ return builder.emitLoad(operandAddr);
+ if (kind == ConversionMethodKind::Func)
+ return builder.emitCallInst(resultType, func, 1, &operandAddr);
+ else
+ {
+ auto val = builder.emitLoad(operandAddr);
+ return builder.emitIntrinsicInst(resultType, op, 1, &val);
+ }
+}
- enum ConversionMethodKind
+void ConversionMethod::applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand)
+{
+ if (!*this)
{
- Func,
- Opcode
- };
- struct ConversionMethod
+ builder.emitStore(dest, operand);
+ return;
+ }
+ if (kind == ConversionMethodKind::Func)
{
- ConversionMethodKind kind = ConversionMethodKind::Func;
- union
- {
- IRFunc* func;
- IROp op;
- };
- ConversionMethod() { func = nullptr; }
- operator bool()
- {
- return kind == ConversionMethodKind::Func ? func != nullptr : op != kIROp_Nop;
- }
- ConversionMethod& operator=(IRFunc* f)
- {
- kind = ConversionMethodKind::Func;
- this->func = f;
- return *this;
- }
- ConversionMethod& operator=(IROp irop)
- {
- kind = ConversionMethodKind::Opcode;
- this->op = irop;
- return *this;
- }
- IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr)
- {
- if (!*this)
- return builder.emitLoad(operandAddr);
- if (kind == ConversionMethodKind::Func)
- return builder.emitCallInst(resultType, func, 1, &operandAddr);
- else
- {
- auto val = builder.emitLoad(operandAddr);
- return builder.emitIntrinsicInst(resultType, op, 1, &val);
- }
- }
- void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand)
- {
- if (!*this)
- {
- builder.emitStore(dest, operand);
- return;
- }
- if (kind == ConversionMethodKind::Func)
- {
- IRInst* operands[] = {dest, operand};
- builder.emitCallInst(builder.getVoidType(), func, 2, operands);
- }
- else
- {
- auto val = builder.emitIntrinsicInst(
- tryGetPointedToOrBufferElementType(&builder, dest->getDataType()),
- op,
- 1,
- &operand);
- builder.emitStore(dest, val);
- }
- }
- };
+ IRInst* operands[] = {dest, operand};
+ builder.emitCallInst(builder.getVoidType(), func, 2, operands);
+ }
+ else
+ {
+ auto val = builder.emitIntrinsicInst(
+ tryGetPointedToOrBufferElementType(&builder, dest->getDataType()),
+ op,
+ 1,
+ &operand);
+ builder.emitStore(dest, val);
+ }
+}
+
+// Returns the number of elements N that ensures the IRVectorType(elementType,N)
+// has 16-byte aligned size and N is no less than `minCount`.
+IRIntegerValue get16ByteAlignedVectorElementCount(
+ TargetProgram* target,
+ IRType* elementType,
+ IRIntegerValue minCount)
+{
+ IRSizeAndAlignment sizeAlignment;
+ getNaturalSizeAndAlignment(target->getOptionSet(), elementType, &sizeAlignment);
+ if (sizeAlignment.size)
+ return align(sizeAlignment.size * minCount, 16) / sizeAlignment.size;
+ return 4;
+}
- struct LoweredElementTypeInfo
+const char* getLayoutName(IRTypeLayoutRuleName name)
+{
+ switch (name)
{
- IRType* originalType;
- IRType* loweredType;
- IRType* loweredInnerArrayType =
- nullptr; // For matrix/array types that are lowered into a struct type, this is the
- // inner array type of the data field.
- IRStructKey* loweredInnerStructKey =
- nullptr; // For matrix/array types that are lowered into a struct type, this is the
- // struct key of the data field.
- ConversionMethod convertOriginalToLowered;
- ConversionMethod convertLoweredToOriginal;
- };
+ case IRTypeLayoutRuleName::Std140:
+ return "std140";
+ case IRTypeLayoutRuleName::Std430:
+ return "std430";
+ case IRTypeLayoutRuleName::Natural:
+ return "natural";
+ case IRTypeLayoutRuleName::C:
+ return "c";
+ default:
+ return "default";
+ }
+}
+
+struct LoweredElementTypeContext
+{
+ static const IRIntegerValue kMaxArraySizeToUnroll = 32;
struct LoweredTypeMap : RefObject
{
@@ -320,6 +389,7 @@ struct LoweredElementTypeContext
};
Dictionary<TypeLoweringConfig, RefPtr<LoweredTypeMap>> loweredTypeInfoMaps;
+ RefPtr<BufferElementTypeLoweringPolicy> leafTypeLoweringPolicy;
struct ConversionMethodKey
{
@@ -366,150 +436,11 @@ struct LoweredElementTypeContext
// Specialized functions that takes storage-typed pointers instead of logical-typed pointers.
Dictionary<SpecializationKey, IRFunc*> specializedFuncs;
- LoweredElementTypeContext(
- TargetProgram* target,
- BufferElementTypeLoweringOptions inOptions,
- SlangMatrixLayoutMode inDefaultMatrixLayout)
- : target(target), defaultMatrixLayout(inDefaultMatrixLayout), options(inOptions)
- {
- }
-
- IRFunc* createMatrixUnpackFunc(
- IRMatrixType* matrixType,
- IRStructType* structType,
- IRStructKey* dataKey)
- {
- IRBuilder builder(structType);
- builder.setInsertAfter(structType);
- auto func = builder.createFunc();
- auto refStructType = builder.getRefParamType(structType, AddressSpace::Generic);
- auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType);
- func->setFullType(funcType);
- builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage"));
- builder.addForceInlineDecoration(func);
- builder.setInsertInto(func);
- builder.emitBlock();
- auto rowCount = (Index)getIntVal(matrixType->getRowCount());
- auto colCount = (Index)getIntVal(matrixType->getColumnCount());
- auto packedParamRef = builder.emitParam(refStructType);
- auto packedParam = builder.emitLoad(packedParamRef);
- auto vectorArray = builder.emitFieldExtract(packedParam, dataKey);
- List<IRInst*> args;
- args.setCount(rowCount * colCount);
- if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
- {
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- auto vector = builder.emitElementExtract(vectorArray, c);
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- auto element = builder.emitElementExtract(vector, r);
- args[(Index)(r * colCount + c)] = element;
- }
- }
- }
- else
- {
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- auto vector = builder.emitElementExtract(vectorArray, r);
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- auto element = builder.emitElementExtract(vector, c);
- args[(Index)(r * colCount + c)] = element;
- }
- }
- }
- IRInst* result =
- builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer());
- builder.emitReturn(result);
- return func;
- }
-
- IRFunc* createMatrixPackFunc(
- IRMatrixType* matrixType,
- IRStructType* structType,
- IRVectorType* vectorType,
- IRArrayType* arrayType)
+ LoweredElementTypeContext(TargetProgram* target, BufferElementTypeLoweringOptions inOptions)
+ : target(target), options(inOptions)
{
- IRBuilder builder(structType);
- builder.setInsertAfter(structType);
- auto func = builder.createFunc();
- auto outStructType = builder.getRefParamType(structType, AddressSpace::Generic);
- IRType* paramTypes[] = {outStructType, matrixType};
- auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType());
- func->setFullType(funcType);
- builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix"));
- builder.addForceInlineDecoration(func);
- builder.setInsertInto(func);
- builder.emitBlock();
- auto rowCount = getIntVal(matrixType->getRowCount());
- auto colCount = getIntVal(matrixType->getColumnCount());
- auto outParam = builder.emitParam(outStructType);
- auto originalParam = builder.emitParam(matrixType);
- List<IRInst*> elements;
- elements.setCount((Index)(rowCount * colCount));
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- auto vector = builder.emitElementExtract(originalParam, r);
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- auto element = builder.emitElementExtract(vector, c);
- elements[(Index)(r * colCount + c)] = element;
- }
- }
- List<IRInst*> vectors;
- if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
- {
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- List<IRInst*> vecArgs;
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- auto element = elements[(Index)(r * colCount + c)];
- vecArgs.add(element);
- }
- // Fill in default values for remaining elements in the vector.
- for (IRIntegerValue r = rowCount; r < getIntVal(vectorType->getElementCount()); r++)
- {
- vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
- }
- auto colVector = builder.emitMakeVector(
- vectorType,
- (UInt)vecArgs.getCount(),
- vecArgs.getBuffer());
- vectors.add(colVector);
- }
- }
- else
- {
- for (IRIntegerValue r = 0; r < rowCount; r++)
- {
- List<IRInst*> vecArgs;
- for (IRIntegerValue c = 0; c < colCount; c++)
- {
- auto element = elements[(Index)(r * colCount + c)];
- vecArgs.add(element);
- }
- // Fill in default values for remaining elements in the vector.
- for (IRIntegerValue c = colCount; c < getIntVal(vectorType->getElementCount()); c++)
- {
- vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
- }
- auto rowVector = builder.emitMakeVector(
- vectorType,
- (UInt)vecArgs.getCount(),
- vecArgs.getBuffer());
- vectors.add(rowVector);
- }
- }
-
- auto vectorArray =
- builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer());
- auto result = builder.emitMakeStruct(structType, 1, &vectorArray);
- builder.emitStore(outParam, result);
- builder.emitReturn();
- return func;
+ leafTypeLoweringPolicy =
+ getBufferElementTypeLoweringPolicy(options.loweringPolicyKind, target, options);
}
IRFunc* createArrayUnpackFunc(
@@ -637,56 +568,6 @@ struct LoweredElementTypeContext
return func;
}
- const char* getLayoutName(IRTypeLayoutRuleName name)
- {
- switch (name)
- {
- case IRTypeLayoutRuleName::Std140:
- return "std140";
- case IRTypeLayoutRuleName::Std430:
- return "std430";
- case IRTypeLayoutRuleName::Natural:
- return "natural";
- case IRTypeLayoutRuleName::C:
- return "c";
- default:
- return "default";
- }
- }
-
- // Returns the number of elements N that ensures the IRVectorType(elementType,N)
- // has 16-byte aligned size and N is no less than `minCount`.
- IRIntegerValue get16ByteAlignedVectorElementCount(IRType* elementType, IRIntegerValue minCount)
- {
- IRSizeAndAlignment sizeAlignment;
- getNaturalSizeAndAlignment(target->getOptionSet(), elementType, &sizeAlignment);
- if (sizeAlignment.size)
- return align(sizeAlignment.size * minCount, 16) / sizeAlignment.size;
- return 4;
- }
-
- bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config)
- {
- // For spirv, we always want to lower all matrix types, because SPIRV does not support
- // specifying matrix layout/stride if the matrix type is used in places other than
- // defining a struct field. This means that if a matrix is used to define a varying
- // parameter, we always want to wrap it in a struct.
- //
- if (target->shouldEmitSPIRVDirectly())
- {
- return true;
- }
-
- if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout &&
- config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural)
- {
- // For other targets, we only lower the matrix types if they differ from the default
- // matrix layout.
- return false;
- }
- return true;
- }
-
LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, TypeLoweringConfig config)
{
IRBuilder builder(type);
@@ -694,72 +575,13 @@ struct LoweredElementTypeContext
LoweredElementTypeInfo info;
info.originalType = type;
-
- if (auto matrixType = as<IRMatrixType>(type))
- {
- if (!shouldLowerMatrixType(matrixType, config))
- {
- info.loweredType = type;
- return info;
- }
-
- auto loweredType = builder.createStructType();
- builder.addPhysicalTypeDecoration(loweredType);
-
- StringBuilder nameSB;
- bool isColMajor =
- getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
- nameSB << "_MatrixStorage_";
- getTypeNameHint(nameSB, matrixType->getElementType());
- nameSB << getIntVal(matrixType->getRowCount()) << "x"
- << getIntVal(matrixType->getColumnCount());
- if (isColMajor)
- nameSB << "_ColMajor";
- nameSB << getLayoutName(config.layoutRule->ruleName);
- builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
- auto structKey = builder.createStructKey();
- builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
- auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount();
- if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 &&
- options.use16ByteArrayElementForConstantBuffer)
- {
- // For constant buffer layout, we need to use 16-byte aligned vector if
- // we are required to ensure array element types has 16-byte stride.
- vectorSize = builder.getIntValue(get16ByteAlignedVectorElementCount(
- matrixType->getElementType(),
- getIntVal(vectorSize)));
- }
-
- auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize);
- IRSizeAndAlignment elementSizeAlignment;
- getSizeAndAlignment(
- target->getOptionSet(),
- config.layoutRule,
- vectorType,
- &elementSizeAlignment);
- elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment);
-
- auto arrayType = builder.getArrayType(
- vectorType,
- isColMajor ? matrixType->getColumnCount() : matrixType->getRowCount(),
- builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride()));
- builder.createStructField(loweredType, structKey, arrayType);
-
- info.loweredType = loweredType;
- info.loweredInnerArrayType = arrayType;
- info.loweredInnerStructKey = structKey;
- info.convertLoweredToOriginal =
- createMatrixUnpackFunc(matrixType, loweredType, structKey);
- info.convertOriginalToLowered =
- createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType);
- return info;
- }
- else if (auto arrayTypeBase = as<IRArrayTypeBase>(type))
+ if (auto arrayTypeBase = as<IRArrayTypeBase>(type))
{
auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayTypeBase->getElementType(), config);
- if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 &&
- options.use16ByteArrayElementForConstantBuffer)
+ if (config.layoutRuleName == IRTypeLayoutRuleName::Std140 &&
+ leafTypeLoweringPolicy
+ ->shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer())
{
// For constant buffer layout, we need to use 16-byte-aligned vector if
// we are required to ensure array element types has 16-byte stride.
@@ -772,6 +594,7 @@ struct LoweredElementTypeContext
packedVectorType = builder.getVectorType(
vectorType->getElementType(),
builder.getIntValue(get16ByteAlignedVectorElementCount(
+ target,
vectorType->getElementType(),
getIntVal(vectorType->getElementCount()))));
if (packedVectorType != loweredInnerTypeInfo.originalType)
@@ -784,7 +607,7 @@ struct LoweredElementTypeContext
{
packedVectorType = builder.getVectorType(
loweredInnerTypeInfo.loweredType,
- get16ByteAlignedVectorElementCount(scalarType, 1));
+ get16ByteAlignedVectorElementCount(target, scalarType, 1));
loweredInnerTypeInfo.convertLoweredToOriginal = kIROp_VectorReshape;
loweredInnerTypeInfo.convertOriginalToLowered = kIROp_MakeVectorFromScalar;
}
@@ -803,10 +626,10 @@ struct LoweredElementTypeContext
}
}
- // For spirv backend, we always want to lower all array types for non-varying
- // parameters, even if the element type comes out the same. This is because different
- // layout rules may have different array stride requirements.
- if (!target->shouldEmitSPIRVDirectly() || config.addressSpace == AddressSpace::Input)
+ // We can skip lowering this type if all field types are unchanged, unless the target
+ // specific policy tells us to always create a lowered type.
+ if (!leafTypeLoweringPolicy->shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(
+ config))
{
if (!loweredInnerTypeInfo.convertLoweredToOriginal)
{
@@ -823,7 +646,7 @@ struct LoweredElementTypeContext
info.loweredType = loweredType;
StringBuilder nameSB;
- nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_";
+ nameSB << "_Array_" << getLayoutName(config.layoutRuleName) << "_";
getTypeNameHint(nameSB, arrayType->getElementType());
nameSB << getArraySizeVal(arrayType->getElementCount());
@@ -835,11 +658,11 @@ struct LoweredElementTypeContext
IRSizeAndAlignment elementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
loweredInnerTypeInfo.loweredType,
&elementSizeAlignment);
elementSizeAlignment =
- config.layoutRule->alignCompositeElement(elementSizeAlignment);
+ config.getLayoutRule()->alignCompositeElement(elementSizeAlignment);
auto innerArrayType = builder.getArrayType(
loweredInnerTypeInfo.loweredType,
arrayType->getElementCount(),
@@ -857,11 +680,11 @@ struct LoweredElementTypeContext
IRSizeAndAlignment elementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
loweredInnerTypeInfo.loweredType,
&elementSizeAlignment);
elementSizeAlignment =
- config.layoutRule->alignCompositeElement(elementSizeAlignment);
+ config.getLayoutRule()->alignCompositeElement(elementSizeAlignment);
auto innerArrayType = builder.getArrayTypeBase(
arrayTypeBase->getOp(),
loweredInnerTypeInfo.loweredType,
@@ -880,20 +703,15 @@ struct LoweredElementTypeContext
auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), config);
fieldLoweredTypeInfo.add(loweredFieldTypeInfo);
if (loweredFieldTypeInfo.convertLoweredToOriginal ||
- config.layoutRule->ruleName != IRTypeLayoutRuleName::Natural)
+ config.layoutRuleName != IRTypeLayoutRuleName::Natural)
isTrivial = false;
}
- // For spirv backend, we always want to lower all array types, even if the element type
- // comes out the same. This is because different layout rules may have different array
- // stride requirements.
- //
- // Additionally, `buffer` blocks do not work correctly unless lowered when targeting
- // GLSL.
- if (!isKhronosTarget(target->getTargetReq()))
+ // We can skip lowering this type if all field types are unchanged, unless the target
+ // specific policy tells us to always create a lowered type.
+ if (!leafTypeLoweringPolicy->shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(
+ config))
{
- // For non-spirv target, we skip lowering this type if all field types are
- // unchanged.
if (isTrivial)
{
info.loweredType = type;
@@ -905,7 +723,7 @@ struct LoweredElementTypeContext
StringBuilder nameSB;
getTypeNameHint(nameSB, type);
- nameSB << "_" << getLayoutName(config.layoutRule->ruleName);
+ nameSB << "_" << getLayoutName(config.layoutRuleName);
builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
info.loweredType = loweredType;
// Create fields.
@@ -1007,58 +825,7 @@ struct LoweredElementTypeContext
return info;
}
-
- if (target->shouldEmitSPIRVDirectly())
- {
- switch (target->getTargetReq()->getTarget())
- {
- case CodeGenTarget::SPIRV:
- case CodeGenTarget::SPIRVAssembly:
- {
- auto scalarType = type;
- auto vectorType = as<IRVectorType>(scalarType);
- if (vectorType)
- scalarType = vectorType->getElementType();
-
- if (as<IRBoolType>(scalarType))
- {
- // Bool is an abstract type in SPIRV, so we need to lower them into an int.
-
- // Find an integer type of the correct size for the current layout rule.
- IRSizeAndAlignment boolSizeAndAlignment;
- if (getSizeAndAlignment(
- target->getOptionSet(),
- config.layoutRule,
- scalarType,
- &boolSizeAndAlignment) == SLANG_OK)
- {
- IntInfo ii;
- ii.width = boolSizeAndAlignment.size * 8;
- ii.isSigned = true;
- info.loweredType = builder.getType(getIntTypeOpFromInfo(ii));
- }
- else
- {
- // Just in case that fails for some reason, just use an int.
- info.loweredType = builder.getIntType();
- }
-
- if (vectorType)
- info.loweredType = builder.getVectorType(
- info.loweredType,
- vectorType->getElementCount());
- info.convertLoweredToOriginal = kIROp_BuiltinCast;
- info.convertOriginalToLowered = kIROp_BuiltinCast;
- return info;
- }
- }
- default:
- break;
- }
- }
-
- info.loweredType = type;
- return info;
+ return leafTypeLoweringPolicy->lowerLeafLogicalType(type, config);
}
LoweredTypeMap& getTypeLoweringMap(TypeLoweringConfig config)
@@ -1090,7 +857,7 @@ struct LoweredElementTypeContext
IRSizeAndAlignment sizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
info.loweredType,
&sizeAlignment);
loweredTypeInfo.set(type, info);
@@ -1190,13 +957,13 @@ struct LoweredElementTypeContext
IRSizeAndAlignment arrayElementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
loweredInnerType.loweredType,
&arrayElementSizeAlignment);
IRSizeAndAlignment baseSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
tryGetPointedToOrBufferElementType(&builder, fieldAddr->getBase()->getDataType()),
&baseSizeAlignment);
@@ -1701,9 +1468,6 @@ struct LoweredElementTypeContext
switch (ptrType->getAddressSpace())
{
case AddressSpace::UserPointer:
- if (!options.lowerBufferPointer)
- continue;
- [[fallthrough]];
case AddressSpace::Input:
case AddressSpace::Output:
elementType = ptrType->getValueType();
@@ -1720,7 +1484,7 @@ struct LoweredElementTypeContext
IRSizeAndAlignment sizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- config.layoutRule,
+ config.getLayoutRule(),
elementType,
&sizeAlignment);
SLANG_UNUSED(sizeAlignment);
@@ -2181,63 +1945,60 @@ void lowerBufferElementTypeToStorageType(
IRModule* module,
BufferElementTypeLoweringOptions options)
{
- SlangMatrixLayoutMode defaultMatrixMode =
- (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode();
- if ((isCPUTarget(target->getTargetReq()) || isCUDATarget(target->getTargetReq()) ||
- isMetalTarget(target->getTargetReq())))
- defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
- else if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)
- defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
- LoweredElementTypeContext context(target, options, defaultMatrixMode);
+ LoweredElementTypeContext context(target, options);
context.processModule(module);
}
-IRTypeLayoutRules* getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRules* defaultLayout)
+IRTypeLayoutRuleName getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRuleName defaultLayout)
{
switch (layoutTypeOp)
{
case kIROp_DefaultBufferLayoutType:
return defaultLayout;
case kIROp_Std140BufferLayoutType:
- return IRTypeLayoutRules::getStd140();
+ return IRTypeLayoutRuleName::Std140;
case kIROp_Std430BufferLayoutType:
- return IRTypeLayoutRules::getStd430();
+ return IRTypeLayoutRuleName::Std430;
case kIROp_ScalarBufferLayoutType:
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
case kIROp_CBufferLayoutType:
- return IRTypeLayoutRules::getC();
+ return IRTypeLayoutRuleName::C;
}
return defaultLayout;
}
-IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType)
+IRTypeLayoutRuleName getTypeLayoutRuleNameForBuffer(TargetProgram* target, IRType* bufferType)
{
+ if (bufferType->getOp() == kIROp_ParameterBlockType && isMetalTarget(target->getTargetReq()))
+ {
+ return IRTypeLayoutRuleName::MetalParameterBlock;
+ }
if (target->getTargetReq()->getTarget() != CodeGenTarget::WGSL)
{
if (!isKhronosTarget(target->getTargetReq()))
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
// If we are just emitting GLSL, we can just use the general layout rule.
if (!target->shouldEmitSPIRVDirectly())
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
// If the user specified a C-compatible buffer layout, then do that.
if (target->getOptionSet().shouldUseCLayout())
- return IRTypeLayoutRules::getC();
+ return IRTypeLayoutRuleName::C;
// If the user specified a scalar buffer layout, then just use that.
if (target->getOptionSet().shouldUseScalarLayout())
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
}
if (target->getOptionSet().shouldUseDXLayout())
{
if (as<IRUniformParameterGroupType>(bufferType))
{
- return IRTypeLayoutRules::getConstantBuffer();
+ return IRTypeLayoutRuleName::D3DConstantBuffer;
}
else
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
}
// The default behavior is to use std140 for constant buffers and std430 for other buffers.
@@ -2253,17 +2014,17 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf
auto layoutTypeOp = structBufferType->getDataLayout()
? structBufferType->getDataLayout()->getOp()
: kIROp_DefaultBufferLayoutType;
- return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430());
+ return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std430);
}
- case kIROp_ConstantBufferType:
case kIROp_ParameterBlockType:
+ case kIROp_ConstantBufferType:
{
auto parameterGroupType = as<IRUniformParameterGroupType>(bufferType);
auto layoutTypeOp = parameterGroupType->getDataLayout()
? parameterGroupType->getDataLayout()->getOp()
: kIROp_DefaultBufferLayoutType;
- return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd140());
+ return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std140);
}
case kIROp_GLSLShaderStorageBufferType:
{
@@ -2271,12 +2032,18 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf
auto layoutTypeOp = storageBufferType->getDataLayout()
? storageBufferType->getDataLayout()->getOp()
: kIROp_Std430BufferLayoutType;
- return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430());
+ return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std430);
}
case kIROp_PtrType:
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
}
- return IRTypeLayoutRules::getNatural();
+ return IRTypeLayoutRuleName::Natural;
+}
+
+IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType)
+{
+ auto ruleName = getTypeLayoutRuleNameForBuffer(target, bufferType);
+ return IRTypeLayoutRules::get(ruleName);
}
TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType)
@@ -2295,8 +2062,413 @@ TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType*
break;
}
}
- auto rules = getTypeLayoutRuleForBuffer(target, bufferType);
+ auto rules = getTypeLayoutRuleNameForBuffer(target, bufferType);
return TypeLoweringConfig{addrSpace, rules};
}
+struct DefaultBufferElementTypeLoweringPolicy : BufferElementTypeLoweringPolicy
+{
+ TargetProgram* target;
+ BufferElementTypeLoweringOptions options;
+ SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+
+ DefaultBufferElementTypeLoweringPolicy(
+ TargetProgram* inTarget,
+ BufferElementTypeLoweringOptions inOptions)
+ : target(inTarget), options(inOptions)
+ {
+ defaultMatrixLayout = (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode();
+ if ((isCPUTarget(target->getTargetReq()) || isCUDATarget(target->getTargetReq()) ||
+ isMetalTarget(target->getTargetReq())))
+ defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+ else if (defaultMatrixLayout == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)
+ defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+ }
+
+ virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config)
+ {
+ if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout &&
+ config.getLayoutRule()->ruleName == IRTypeLayoutRuleName::Natural)
+ {
+ // We only lower the matrix types if they differ from the default
+ // matrix layout.
+ return false;
+ }
+ return true;
+ }
+
+ IRFunc* createMatrixUnpackFunc(
+ IRMatrixType* matrixType,
+ IRStructType* structType,
+ IRStructKey* dataKey)
+ {
+ IRBuilder builder(structType);
+ builder.setInsertAfter(structType);
+ auto func = builder.createFunc();
+ auto refStructType = builder.getRefParamType(structType, AddressSpace::Generic);
+ auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType);
+ func->setFullType(funcType);
+ builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage"));
+ builder.addForceInlineDecoration(func);
+ builder.setInsertInto(func);
+ builder.emitBlock();
+ auto rowCount = (Index)getIntVal(matrixType->getRowCount());
+ auto colCount = (Index)getIntVal(matrixType->getColumnCount());
+ auto packedParamRef = builder.emitParam(refStructType);
+ auto packedParam = builder.emitLoad(packedParamRef);
+ auto vectorArray = builder.emitFieldExtract(packedParam, dataKey);
+ List<IRInst*> args;
+ args.setCount(rowCount * colCount);
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+ {
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto vector = builder.emitElementExtract(vectorArray, c);
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto element = builder.emitElementExtract(vector, r);
+ args[(Index)(r * colCount + c)] = element;
+ }
+ }
+ }
+ else
+ {
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto vector = builder.emitElementExtract(vectorArray, r);
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto element = builder.emitElementExtract(vector, c);
+ args[(Index)(r * colCount + c)] = element;
+ }
+ }
+ }
+ IRInst* result =
+ builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer());
+ builder.emitReturn(result);
+ return func;
+ }
+
+ IRFunc* createMatrixPackFunc(
+ IRMatrixType* matrixType,
+ IRStructType* structType,
+ IRVectorType* vectorType,
+ IRArrayType* arrayType)
+ {
+ IRBuilder builder(structType);
+ builder.setInsertAfter(structType);
+ auto func = builder.createFunc();
+ auto outStructType = builder.getRefParamType(structType, AddressSpace::Generic);
+ IRType* paramTypes[] = {outStructType, matrixType};
+ auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType());
+ func->setFullType(funcType);
+ builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix"));
+ builder.addForceInlineDecoration(func);
+ builder.setInsertInto(func);
+ builder.emitBlock();
+ auto rowCount = getIntVal(matrixType->getRowCount());
+ auto colCount = getIntVal(matrixType->getColumnCount());
+ auto outParam = builder.emitParam(outStructType);
+ auto originalParam = builder.emitParam(matrixType);
+ List<IRInst*> elements;
+ elements.setCount((Index)(rowCount * colCount));
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto vector = builder.emitElementExtract(originalParam, r);
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto element = builder.emitElementExtract(vector, c);
+ elements[(Index)(r * colCount + c)] = element;
+ }
+ }
+ List<IRInst*> vectors;
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+ {
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ List<IRInst*> vecArgs;
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto element = elements[(Index)(r * colCount + c)];
+ vecArgs.add(element);
+ }
+ // Fill in default values for remaining elements in the vector.
+ for (IRIntegerValue r = rowCount; r < getIntVal(vectorType->getElementCount()); r++)
+ {
+ vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
+ }
+ auto colVector = builder.emitMakeVector(
+ vectorType,
+ (UInt)vecArgs.getCount(),
+ vecArgs.getBuffer());
+ vectors.add(colVector);
+ }
+ }
+ else
+ {
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ List<IRInst*> vecArgs;
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto element = elements[(Index)(r * colCount + c)];
+ vecArgs.add(element);
+ }
+ // Fill in default values for remaining elements in the vector.
+ for (IRIntegerValue c = colCount; c < getIntVal(vectorType->getElementCount()); c++)
+ {
+ vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType()));
+ }
+ auto rowVector = builder.emitMakeVector(
+ vectorType,
+ (UInt)vecArgs.getCount(),
+ vecArgs.getBuffer());
+ vectors.add(rowVector);
+ }
+ }
+
+ auto vectorArray =
+ builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer());
+ auto result = builder.emitMakeStruct(structType, 1, &vectorArray);
+ builder.emitStore(outParam, result);
+ builder.emitReturn();
+ return func;
+ }
+
+ LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override
+ {
+ IRBuilder builder(type);
+ builder.setInsertAfter(type);
+
+ LoweredElementTypeInfo info;
+ info.originalType = type;
+
+ if (auto matrixType = as<IRMatrixType>(type))
+ {
+ if (!shouldLowerMatrixType(matrixType, config))
+ {
+ info.loweredType = type;
+ return info;
+ }
+
+ auto loweredType = builder.createStructType();
+ builder.addPhysicalTypeDecoration(loweredType);
+
+ StringBuilder nameSB;
+ bool isColMajor =
+ getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
+ nameSB << "_MatrixStorage_";
+ getTypeNameHint(nameSB, matrixType->getElementType());
+ nameSB << getIntVal(matrixType->getRowCount()) << "x"
+ << getIntVal(matrixType->getColumnCount());
+ if (isColMajor)
+ nameSB << "_ColMajor";
+ nameSB << getLayoutName(config.layoutRuleName);
+ builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
+ auto structKey = builder.createStructKey();
+ builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
+ auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount();
+ if (config.layoutRuleName == IRTypeLayoutRuleName::Std140 &&
+ shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer())
+ {
+ // For constant buffer layout, we need to use 16-byte aligned vector if
+ // we are required to ensure array element types has 16-byte stride.
+ vectorSize = builder.getIntValue(get16ByteAlignedVectorElementCount(
+ target,
+ matrixType->getElementType(),
+ getIntVal(vectorSize)));
+ }
+
+ auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize);
+ IRSizeAndAlignment elementSizeAlignment;
+ getSizeAndAlignment(
+ target->getOptionSet(),
+ config.getLayoutRule(),
+ vectorType,
+ &elementSizeAlignment);
+ elementSizeAlignment =
+ config.getLayoutRule()->alignCompositeElement(elementSizeAlignment);
+
+ auto arrayType = builder.getArrayType(
+ vectorType,
+ isColMajor ? matrixType->getColumnCount() : matrixType->getRowCount(),
+ builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride()));
+ builder.createStructField(loweredType, structKey, arrayType);
+
+ info.loweredType = loweredType;
+ info.loweredInnerArrayType = arrayType;
+ info.loweredInnerStructKey = structKey;
+ info.convertLoweredToOriginal =
+ createMatrixUnpackFunc(matrixType, loweredType, structKey);
+ info.convertOriginalToLowered =
+ createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType);
+ return info;
+ }
+
+ info.loweredType = type;
+ return info;
+ }
+};
+
+struct KhronosTargetBufferElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy
+{
+ KhronosTargetBufferElementTypeLoweringPolicy(
+ TargetProgram* inTarget,
+ BufferElementTypeLoweringOptions inOptions)
+ : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions)
+ {
+ }
+
+ virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) override
+ {
+ // For spirv, we always want to lower all matrix types, because SPIRV does not support
+ // specifying matrix layout/stride if the matrix type is used in places other than
+ // defining a struct field. This means that if a matrix is used to define a varying
+ // parameter, we always want to wrap it in a struct.
+ //
+ if (target->shouldEmitSPIRVDirectly())
+ return true;
+ return DefaultBufferElementTypeLoweringPolicy::shouldLowerMatrixType(matrixType, config);
+ }
+
+ virtual bool shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(
+ TypeLoweringConfig config) override
+ {
+ // For spirv backend, we always want to lower all array types, even if the element type
+ // comes out the same. This is because different layout rules may have different array
+ // stride requirements.
+ //
+ // Additionally, `buffer` blocks do not work correctly unless lowered when targeting
+ // GLSL.
+ return target->shouldEmitSPIRVDirectly() && config.addressSpace != AddressSpace::Input;
+ }
+
+ LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override
+ {
+ if (target->shouldEmitSPIRVDirectly())
+ {
+ LoweredElementTypeInfo info = {};
+ info.originalType = type;
+
+ switch (target->getTargetReq()->getTarget())
+ {
+ case CodeGenTarget::SPIRV:
+ case CodeGenTarget::SPIRVAssembly:
+ {
+ auto scalarType = type;
+ auto vectorType = as<IRVectorType>(scalarType);
+ if (vectorType)
+ scalarType = vectorType->getElementType();
+ IRBuilder builder(type);
+ builder.setInsertBefore(type);
+
+ if (as<IRBoolType>(scalarType))
+ {
+ // Bool is an abstract type in SPIRV, so we need to lower them into an int.
+
+ // Find an integer type of the correct size for the current layout rule.
+ IRSizeAndAlignment boolSizeAndAlignment;
+ if (getSizeAndAlignment(
+ target->getOptionSet(),
+ config.getLayoutRule(),
+ scalarType,
+ &boolSizeAndAlignment) == SLANG_OK)
+ {
+ IntInfo ii;
+ ii.width = boolSizeAndAlignment.size * 8;
+ ii.isSigned = true;
+ info.loweredType = builder.getType(getIntTypeOpFromInfo(ii));
+ }
+ else
+ {
+ // Just in case that fails for some reason, just use an int.
+ info.loweredType = builder.getIntType();
+ }
+
+ if (vectorType)
+ info.loweredType = builder.getVectorType(
+ info.loweredType,
+ vectorType->getElementCount());
+ info.convertLoweredToOriginal = kIROp_BuiltinCast;
+ info.convertOriginalToLowered = kIROp_BuiltinCast;
+ return info;
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ return DefaultBufferElementTypeLoweringPolicy::lowerLeafLogicalType(type, config);
+ }
+};
+
+struct MetalParameterBlockElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy
+{
+ MetalParameterBlockElementTypeLoweringPolicy(
+ TargetProgram* inTarget,
+ BufferElementTypeLoweringOptions inOptions)
+ : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions)
+ {
+ }
+
+ virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) override
+ {
+ SLANG_UNUSED(matrixType);
+ SLANG_UNUSED(config);
+ return false;
+ }
+
+ LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override
+ {
+ if (config.layoutRuleName == IRTypeLayoutRuleName::MetalParameterBlock &&
+ isResourceType(type))
+ {
+ IRBuilder builder(type);
+ builder.setInsertBefore(type);
+ LoweredElementTypeInfo info = {};
+ info.originalType = type;
+ info.loweredType = builder.getType(kIROp_DescriptorHandleType, type);
+ info.convertLoweredToOriginal = kIROp_CastDescriptorHandleToResource;
+ info.convertOriginalToLowered = kIROp_CastResourceToDescriptorHandle;
+ return info;
+ }
+ return DefaultBufferElementTypeLoweringPolicy::lowerLeafLogicalType(type, config);
+ }
+};
+
+struct WGSLBufferElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy
+{
+ WGSLBufferElementTypeLoweringPolicy(
+ TargetProgram* inTarget,
+ BufferElementTypeLoweringOptions inOptions)
+ : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions)
+ {
+ }
+
+ virtual bool shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer() override
+ {
+ return true;
+ }
+};
+
+BufferElementTypeLoweringPolicy* getBufferElementTypeLoweringPolicy(
+ BufferElementTypeLoweringPolicyKind kind,
+ TargetProgram* target,
+ BufferElementTypeLoweringOptions options)
+{
+ switch (kind)
+ {
+ case BufferElementTypeLoweringPolicyKind::Default:
+ return new DefaultBufferElementTypeLoweringPolicy(target, options);
+ case BufferElementTypeLoweringPolicyKind::KhronosTarget:
+ return new KhronosTargetBufferElementTypeLoweringPolicy(target, options);
+ case BufferElementTypeLoweringPolicyKind::MetalParameterBlock:
+ return new MetalParameterBlockElementTypeLoweringPolicy(target, options);
+ case BufferElementTypeLoweringPolicyKind::WGSL:
+ return new WGSLBufferElementTypeLoweringPolicy(target, options);
+ }
+ SLANG_UNREACHABLE("unknown buffer element type lowering policy");
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-lower-buffer-element-type.h b/source/slang/slang-ir-lower-buffer-element-type.h
index 9d6e53609..81ce73ddb 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.h
+++ b/source/slang/slang-ir-lower-buffer-element-type.h
@@ -1,19 +1,28 @@
#ifndef SLANG_IR_LOWER_BUFFER_ELEMENT_TYPE_H
#define SLANG_IR_LOWER_BUFFER_ELEMENT_TYPE_H
+#include "slang.h"
+
namespace Slang
{
struct IRModule;
class TargetProgram;
struct IRTypeLayoutRules;
struct IRType;
+enum class IRTypeLayoutRuleName;
-struct BufferElementTypeLoweringOptions
+enum class BufferElementTypeLoweringPolicyKind
{
- bool lowerBufferPointer = true;
+ Default,
+ KhronosTarget,
+ MetalParameterBlock,
+ WGSL
+};
- // For WGSL, we can only create arrays that has a stride of 16 bytes for constant buffers.
- bool use16ByteArrayElementForConstantBuffer = false;
+struct BufferElementTypeLoweringOptions
+{
+ BufferElementTypeLoweringPolicyKind loweringPolicyKind =
+ BufferElementTypeLoweringPolicyKind::Default;
};
// For each struct type S used as element type of a ConstantBuffer, ParameterBlock or
@@ -31,6 +40,8 @@ void lowerBufferElementTypeToStorageType(
// Returns the type layout rules should be used for a buffer resource type.
IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType);
+IRTypeLayoutRuleName getTypeLayoutRuleNameForBuffer(TargetProgram* target, IRType* bufferType);
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-ir-wrap-cbuffer-element.cpp b/source/slang/slang-ir-wrap-cbuffer-element.cpp
new file mode 100644
index 000000000..9c070ab80
--- /dev/null
+++ b/source/slang/slang-ir-wrap-cbuffer-element.cpp
@@ -0,0 +1,133 @@
+#include "slang-ir-wrap-cbuffer-element.h"
+
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+
+// This pass implements a simple translation that wraps the element type T in a ConstantBuffer<T>
+// (or ParameterBlock<T>) type in `struct S { T inner; }`, and replace the ConstantBuffer<T> type
+// with ConstantBuffer<S>. This is needed because some backends do not allow certain types to be
+// used directly as the element type of a constant buffer.
+// For example, Metal does not allow `ParameterBlock<StructuredBuffer<int>>` as that will create
+// a double pointer that Metal compiler does not like. We can easily work around this limitation
+// by wrapping the `StructuredBuffer<int>` in a struct.
+
+namespace Slang
+{
+
+void maybeProvideNameHint(
+ IRBuilder& builder,
+ IRStructType* wrappedStructType,
+ IRParameterGroupType* originalParamGroupType)
+{
+ StringBuilder sb;
+ sb << "wrapper_";
+ getTypeNameHint(sb, originalParamGroupType->getElementType());
+ builder.addNameHintDecoration(wrappedStructType, sb.produceString().getUnownedSlice());
+}
+
+void wrapCBufferElements(IRModule* module, WrapCBufferElementPolicy* policy)
+{
+ struct WorkItem
+ {
+ IRStructKey* wrappedFieldKey;
+ IRInst* inst;
+ IRInst* newParameterGroupType;
+ };
+
+ IRBuilder builder(module);
+
+ List<WorkItem> workList;
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ // Discover all insts whose type is a parameter group type.
+ if (auto paramGroupType = as<IRParameterGroupType>(globalInst))
+ {
+ if (!policy->shouldWrapBufferElementInStruct(paramGroupType))
+ continue;
+
+ // Create the wrapper struct.
+ builder.setInsertBefore(paramGroupType);
+ auto structType = builder.createStructType();
+ maybeProvideNameHint(builder, structType, paramGroupType);
+ auto fieldKey = builder.createStructKey();
+ builder.addNameHintDecoration(fieldKey, toSlice("inner"));
+ builder.createStructField(structType, fieldKey, paramGroupType->getElementType());
+
+ // Create the new parameter group type whose element is the wrapper struct.
+ List<IRInst*> bufferTypeOperands;
+ bufferTypeOperands.add(structType);
+ for (UInt i = 1; i < paramGroupType->getOperandCount(); ++i)
+ {
+ bufferTypeOperands.add(paramGroupType->getOperand(i));
+ }
+ auto newParameterGroupType = builder.getType(
+ paramGroupType->getOp(),
+ (UInt)bufferTypeOperands.getCount(),
+ bufferTypeOperands.getArrayView().getBuffer());
+
+ // Traverse all uses of the parameter group type, and add them to the work list
+ // for further processing.
+ traverseUses(
+ paramGroupType,
+ [&](IRUse* use)
+ {
+ if (use->getUser()->getFullType() != paramGroupType)
+ return;
+ WorkItem item;
+ item.wrappedFieldKey = fieldKey;
+ item.inst = use->getUser();
+ workList.add(item);
+ });
+ paramGroupType->replaceUsesWith(newParameterGroupType);
+ }
+ }
+
+ // Now we have a work list of all instructions that uses a parameter group.
+ // We need to update all uses of parameter group x with `x.inner` instead.
+ for (auto item : workList)
+ {
+ traverseUses(
+ item.inst,
+ [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ IRBuilder builder(user);
+ builder.setInsertBefore(user);
+
+ // Note that we insert the field address instruction right before each use, instead
+ // of immediately after the original parameter group inst, because the parameter
+ // group inst may be defined in a scope that does not allow field address
+ // instructions.
+ auto unwrapped = builder.emitFieldAddress(item.inst, item.wrappedFieldKey);
+ builder.replaceOperand(use, unwrapped);
+ });
+ }
+}
+
+class MetalWrapCBufferElementPolicy : public WrapCBufferElementPolicy
+{
+public:
+ virtual bool shouldWrapBufferElementInStruct(IRParameterGroupType* cbufferType) override
+ {
+ // Metal allows structs, scalars, vectors and matrices directly as buffer elements.
+ if (as<IRStructType>(cbufferType->getElementType()))
+ return false;
+ if (as<IRBasicType>(cbufferType->getElementType()))
+ return false;
+ if (as<IRMatrixType>(cbufferType->getElementType()))
+ return false;
+ if (as<IRVectorType>(cbufferType->getElementType()))
+ return false;
+
+ // Wrap everything else in a struct.
+ return true;
+ }
+};
+
+void wrapCBufferElementsForMetal(IRModule* module)
+{
+ MetalWrapCBufferElementPolicy policy = {};
+ wrapCBufferElements(module, &policy);
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-wrap-cbuffer-element.h b/source/slang/slang-ir-wrap-cbuffer-element.h
new file mode 100644
index 000000000..a75ec78a3
--- /dev/null
+++ b/source/slang/slang-ir-wrap-cbuffer-element.h
@@ -0,0 +1,23 @@
+#ifndef SLANG_IR_WRAP_CBUFFER_ELEMENT_H
+#define SLANG_IR_WRAP_CBUFFER_ELEMENT_H
+
+namespace Slang
+{
+struct IRModule;
+struct IRParameterGroupType;
+
+class WrapCBufferElementPolicy
+{
+public:
+ virtual bool shouldWrapBufferElementInStruct(IRParameterGroupType* cbufferType) = 0;
+};
+
+// Wrap the element type of a ConstantBuffer/ParameterBlock in a struct if the element type is not
+// something that allowed directly as a buffer element type by the target.
+void wrapCBufferElements(IRModule* module, WrapCBufferElementPolicy* policy);
+
+void wrapCBufferElementsForMetal(IRModule* module);
+
+} // namespace Slang
+
+#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 9ec8a2c8b..d114a9a40 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -8731,6 +8731,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options)
case kIROp_CastUInt64ToDescriptorHandle:
case kIROp_CastDescriptorHandleToUInt64:
case kIROp_CastDescriptorHandleToResource:
+ case kIROp_CastResourceToDescriptorHandle:
case kIROp_GetDynamicResourceHeap:
case kIROp_CastDynamicResource:
case kIROp_AllocObj:
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 161e70b25..aef3d6aeb 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -524,6 +524,7 @@ enum class IRTypeLayoutRuleName
Std430,
Std140,
D3DConstantBuffer,
+ MetalParameterBlock,
C,
_Count,
};