diff options
| author | Yong He <yonghe@outlook.com> | 2025-10-03 12:52:26 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-03 19:52:26 +0000 |
| commit | 6a2cf239a89340ed2985d04609499e8c4a2d8f89 (patch) | |
| tree | c80f3c8dc7e89762aeab0ee7830d1ad728665460 /source/slang | |
| parent | cc8f6a241edb47c43c5698ee33abed4fe57d4566 (diff) | |
Fix legalization crash when processing metal parameter blocks. (#8591)
Closes #7606.
When Slang compile for a bindful target, we will run the resource type
legalization pass to hoist resource typed struct fields outside of the
struct type and define them as global parameters and passing them around
via dedicated function parameters.
When we compile for a bindless target, we don't run this pass.
However, Metal is a hybrid bindful and bindless target. We need to run
type legalization for the constant buffer, but skip type legalization
for parameter block.
The previous attempt to support this behavior is to hack the type
legalization pass to return `LegalVal::simple` when it sees a
`ParameterBlock<T>`. However, whenever the code is accessing
`parameterBlock.someNestedField`, the type of the nested field may get a
`LegalType::tuple`, and now we will run into inconsistent scenarios
where we have a `LegalVal::simple` on the operand val, and but the
legalization logic is expecting that val to be a `LegalType::tuple`.
This breaks a lot of assumptions and invariants in the type legalization
pass, resulting unstable/fragile behavior.
To systematically solve this problem, this change generalizes the
existing legalize buffer element type pass to translate
`ParameterBlock<Texture2D>` (and similar cases) to
`ParameterBlock<Texture2D.Handle>`. So that such parameter block will
always be legalized to `LegalType:::simple` during type legalization,
and we will never run into any inconsistent cases. This allowed us to
get rid of the hacky logic in the type legalization pass to try to
workaround the inconsistencies.
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 55 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts-stable-names.lua | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.lua | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-layout.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-types.cpp | 128 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.cpp | 1058 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.h | 19 | ||||
| -rw-r--r-- | source/slang/slang-ir-wrap-cbuffer-element.cpp | 133 | ||||
| -rw-r--r-- | source/slang/slang-ir-wrap-cbuffer-element.h | 23 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 1 |
12 files changed, 834 insertions, 588 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 77c45a6d9..d69485cce 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -2543,6 +2543,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO case kIROp_CastDescriptorHandleToUInt2: case kIROp_CastUInt2ToDescriptorHandle: case kIROp_CastDescriptorHandleToResource: + case kIROp_CastResourceToDescriptorHandle: case kIROp_CastUInt64ToDescriptorHandle: case kIROp_CastDescriptorHandleToUInt64: emitOperand(inst->getOperand(0), outerPrec); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 09c2efea9..804a44b81 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -120,6 +120,7 @@ #include "slang-ir-variable-scope-correction.h" #include "slang-ir-vk-invert-y.h" #include "slang-ir-wgsl-legalize.h" +#include "slang-ir-wrap-cbuffer-element.h" #include "slang-ir-wrap-structured-buffers.h" #include "slang-legalize-types.h" #include "slang-lower-to-ir.h" @@ -1276,6 +1277,23 @@ Result linkAndOptimizeIR( // We don't need the legalize pass for C/C++ based types if (options.shouldLegalizeExistentialAndResourceTypes) { + if (isMetalTarget(targetRequest)) + { + // Metal is a special target in that we want to legalize constant buffer + // types as if it is a bindful target, and skip legalizing parameter block + // types as if it is a bindless target. + // To achieve this, we want to ensure that all resource typed fields in parameter blocks + // are translated into descriptor handles first before running the resource type + // legalization pass for metal, so that type legalization pass won't mess around with + // them. + BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions = {}; + bufferElementTypeLoweringOptions.loweringPolicyKind = + BufferElementTypeLoweringPolicyKind::MetalParameterBlock; + lowerBufferElementTypeToStorageType( + targetProgram, + irModule, + bufferElementTypeLoweringOptions); + } // The Slang language allows interfaces to be used like // ordinary types (including placing them in constant @@ -1403,17 +1421,17 @@ Result linkAndOptimizeIR( // Some information for `static_assert` is available only after the specialization. checkStaticAssert(irModule->getModuleInst(), sink); - // For HLSL (and fxc/dxc) only, we need to "wrap" any - // structured buffers defined over matrix types so - // that they instead use an intermediate `struct`. - // This is required to get those targets to respect - // the options for matrix layout set via `#pragma` - // or command-line options. - // switch (target) { case CodeGenTarget::HLSL: { + // For HLSL(fxc) only, we need to "wrap" any + // structured buffers defined over matrix types so + // that they instead use an intermediate `struct`. + // This is required to get those targets to respect + // the options for matrix layout set via `#pragma` + // or command-line options. + // wrapStructuredBuffersOfMatrices(irModule); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "STRUCTURED BUFFERS WRAPPED"); @@ -1421,7 +1439,12 @@ Result linkAndOptimizeIR( validateIRModuleIfEnabled(codeGenContext, irModule); } break; - + case CodeGenTarget::Metal: + case CodeGenTarget::MetalLib: + // Metal does not allow `ConstantBuffer<StructuredBuffer<T>>`, so we need to create + // a wrapper struct for the `StructuredBuffer<T>`. + wrapCBufferElementsForMetal(irModule); + break; default: break; } @@ -1828,11 +1851,16 @@ Result linkAndOptimizeIR( rcpWOfPositionInput(irModule); } - bool emitSpirvDirectly = targetProgram->shouldEmitSPIRVDirectly(); - - BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions; - bufferElementTypeLoweringOptions.use16ByteArrayElementForConstantBuffer = - isWGPUTarget(targetRequest); + BufferElementTypeLoweringOptions bufferElementTypeLoweringOptions = {}; + if (isWGPUTarget(targetRequest)) + bufferElementTypeLoweringOptions.loweringPolicyKind = + BufferElementTypeLoweringPolicyKind::WGSL; + else if (isKhronosTarget(targetRequest)) + bufferElementTypeLoweringOptions.loweringPolicyKind = + BufferElementTypeLoweringPolicyKind::KhronosTarget; + else + bufferElementTypeLoweringOptions.loweringPolicyKind = + BufferElementTypeLoweringPolicyKind::Default; lowerBufferElementTypeToStorageType(targetProgram, irModule, bufferElementTypeLoweringOptions); // If we are generating code for glsl or metal, perform address space propagation now. @@ -1850,6 +1878,7 @@ Result linkAndOptimizeIR( performForceInlining(irModule); + bool emitSpirvDirectly = targetProgram->shouldEmitSPIRVDirectly(); if (emitSpirvDirectly) { performIntrinsicFunctionInlining(irModule); diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index cfde8c5fa..4b56cc52f 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -676,4 +676,5 @@ return { ["CastStorageToLogicalBase.CastStorageToLogicalDeref"] = 672, ["Decoration.DisableCopyEliminationDecoration"] = 673, ["Decoration.TempCallArgImmutableVar"] = 674, + ["CastResourceToDescriptorHandle"] = 675, } diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 045144e06..8b8515424 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -1908,6 +1908,7 @@ local insts = { -- Represents a no-op cast to convert a resource pointer to a resource on targets where the resource handles are -- already concrete types. { CastDescriptorHandleToResource = { min_operands = 1 } }, + { CastResourceToDescriptorHandle = { min_operands = 1 } }, { TreatAsDynamicUniform = { min_operands = 1 } }, { sizeOf = { min_operands = 1 } }, { alignOf = { min_operands = 1 } }, diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp index 123dfdea4..9db4b3097 100644 --- a/source/slang/slang-ir-layout.cpp +++ b/source/slang/slang-ir-layout.cpp @@ -769,6 +769,7 @@ IRTypeLayoutRules* IRTypeLayoutRules::get(IRTypeLayoutRuleName name) case IRTypeLayoutRuleName::Std140: return getStd140(); case IRTypeLayoutRuleName::Natural: + case IRTypeLayoutRuleName::MetalParameterBlock: return getNatural(); case IRTypeLayoutRuleName::C: return getC(); diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 5168f0466..021566e12 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -1950,133 +1950,6 @@ static LegalVal legalizeDefaultConstruct(IRTypeLegalizationContext* context, Leg } } -// If a legalized `val` has a different flavor than `type`, try to coerce it to `type`. -// -static LegalVal coerceToLegalType(IRTypeLegalizationContext* context, LegalType type, LegalVal val) -{ - switch (type.flavor) - { - case LegalType::Flavor::none: - return LegalVal(); - case LegalType::Flavor::simple: - { - if (val.flavor != LegalVal::Flavor::simple) - return val; - auto simpleVal = val.getSimple(); - if (simpleVal->getDataType() == type.getSimple()) - return val; - - auto resultType = type.getSimple(); - auto structType = as<IRStructType>(resultType); - if (!structType) - { - auto resultValueType = tryGetPointedToType(context->builder, resultType); - if (!resultValueType) - return val; - auto valValueType = tryGetPointedToType(context->builder, simpleVal->getDataType()); - if (!valValueType) - return val; - if (resultValueType == valValueType) - return val; - auto loadedVal = context->builder->emitLoad(val.getSimple()); - auto innerLegalVal = coerceToLegalType( - context, - LegalType::simple(resultValueType), - LegalVal::simple(loadedVal)); - return LegalVal::implicitDeref(innerLegalVal); - } - ShortList<IRInst*> fields; - for (auto field : structType->getFields()) - { - if (as<IRVoidType>(field->getFieldType())) - continue; - auto fieldVal = coerceToLegalType( - context, - LegalType::simple(field->getFieldType()), - LegalVal::simple( - context->builder->emitFieldExtract(simpleVal, field->getKey()))); - fields.add(fieldVal.getSimple()); - } - return LegalVal::simple(context->builder->emitMakeStruct( - structType, - (UInt)fields.getCount(), - fields.getArrayView().getBuffer())); - } - case LegalType::Flavor::implicitDeref: - { - auto innerVal = val; - if (innerVal.flavor == LegalVal::Flavor::implicitDeref) - innerVal = innerVal.getImplicitDeref(); - else if (innerVal.flavor == LegalVal::Flavor::simple) - innerVal = LegalVal::simple(context->builder->emitLoad(innerVal.getSimple())); - innerVal = coerceToLegalType(context, type.getImplicitDeref()->valueType, innerVal); - return LegalVal::implicitDeref(innerVal); - } - case LegalType::Flavor::pair: - { - if (val.flavor == LegalVal::Flavor::pair) - return val; - else if (val.flavor == LegalVal::Flavor::simple) - { - auto pairType = type.getPair(); - auto pairInfo = pairType->pairInfo; - LegalVal ordinaryVal = coerceToLegalType(context, pairType->ordinaryType, val); - LegalVal specialVal = coerceToLegalType(context, pairType->specialType, val); - return LegalVal::pair(ordinaryVal, specialVal, pairInfo); - } - else if (val.flavor == LegalVal::Flavor::implicitDeref) - { - LegalVal innerVal = coerceToLegalType(context, type, val.getImplicitDeref()); - return LegalVal::implicitDeref(innerVal); - } - else - { - SLANG_UNEXPECTED("unhandled legal type coercion"); - UNREACHABLE_RETURN(LegalVal()); - } - } - case LegalType::Flavor::tuple: - { - if (val.flavor == LegalVal::Flavor::tuple) - return val; - else if (val.flavor == LegalVal::Flavor::simple) - { - auto tupleType = type.getTuple(); - RefPtr<TuplePseudoVal> tupleVal = new TuplePseudoVal(); - auto simpleVal = val.getSimple(); - for (auto elem : tupleType->elements) - { - IRInst* elementVal = nullptr; - if (as<IRPtrTypeBase>(simpleVal->getDataType()) || - as<IRPointerLikeType>(simpleVal->getDataType())) - elementVal = context->builder->emitFieldAddress(simpleVal, elem.key); - else - elementVal = context->builder->emitFieldExtract(simpleVal, elem.key); - LegalVal legalElementVal = - coerceToLegalType(context, elem.type, LegalVal::simple(elementVal)); - TuplePseudoVal::Element tupleElem; - tupleElem.key = elem.key; - tupleElem.val = legalElementVal; - tupleVal->elements.add(tupleElem); - } - return LegalVal::tuple(tupleVal); - } - else if (val.flavor == LegalVal::Flavor::implicitDeref) - { - LegalVal innerVal = coerceToLegalType(context, type, val.getImplicitDeref()); - return LegalVal::implicitDeref(innerVal); - } - else - { - SLANG_UNEXPECTED("unhandled legal type coercion"); - UNREACHABLE_RETURN(LegalVal()); - } - } - default: - return val; - } -} - static LegalVal legalizeUndefined(IRTypeLegalizationContext* context, IRInst* inst) { IRType* opaqueType = nullptr; @@ -2188,7 +2061,6 @@ static LegalVal legalizeInst( SLANG_UNEXPECTED("non-simple operand(s)!"); break; } - result = coerceToLegalType(context, type, result); return result; } diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 83218bade..c69592939 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -214,104 +214,173 @@ namespace Slang { +enum ConversionMethodKind +{ + Func, + Opcode +}; + +struct ConversionMethod +{ + ConversionMethodKind kind = ConversionMethodKind::Func; + union + { + IRFunc* func; + IROp op; + }; + ConversionMethod() { func = nullptr; } + operator bool() + { + return kind == ConversionMethodKind::Func ? func != nullptr : op != kIROp_Nop; + } + ConversionMethod& operator=(IRFunc* f) + { + kind = ConversionMethodKind::Func; + this->func = f; + return *this; + } + ConversionMethod& operator=(IROp irop) + { + kind = ConversionMethodKind::Opcode; + this->op = irop; + return *this; + } + IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr); + void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand); +}; + struct TypeLoweringConfig { AddressSpace addressSpace; - IRTypeLayoutRules* layoutRule; + IRTypeLayoutRuleName layoutRuleName; + IRTypeLayoutRules* getLayoutRule() const { return IRTypeLayoutRules::get(layoutRuleName); } + bool operator==(const TypeLoweringConfig& other) const { - return addressSpace == other.addressSpace && layoutRule == other.layoutRule; + return addressSpace == other.addressSpace && layoutRuleName == other.layoutRuleName; } HashCode getHashCode() const { - return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRule)); + return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRuleName)); } }; + +struct LoweredElementTypeInfo +{ + IRType* originalType; + IRType* loweredType; + IRType* loweredInnerArrayType = + nullptr; // For matrix/array types that are lowered into a struct type, this is the + // inner array type of the data field. + IRStructKey* loweredInnerStructKey = + nullptr; // For matrix/array types that are lowered into a struct type, this is the + // struct key of the data field. + ConversionMethod convertOriginalToLowered; + ConversionMethod convertLoweredToOriginal; +}; + +/// Defines target-specific behavior of how to lower buffer element types. +struct BufferElementTypeLoweringPolicy : public RefObject +{ + /// Defines target-specific behavior of how to translate a non-composite logical type to a + /// storage type. + virtual LoweredElementTypeInfo lowerLeafLogicalType( + IRType* type, + TypeLoweringConfig config) = 0; + + /// Returns true if we should always create a fresh lowered storage type for a composite type, + /// even if every member/element of the composite type is not changed by the lowering. + virtual bool shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(TypeLoweringConfig config) + { + SLANG_UNUSED(config); + return false; + } + + /// Returns true if the target requires all array of scalars or vectors inside a constant buffer + /// to be translated into a 16-byte aligned vector type. + virtual bool shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer() + { + return false; + } +}; + +BufferElementTypeLoweringPolicy* getBufferElementTypeLoweringPolicy( + BufferElementTypeLoweringPolicyKind kind, + TargetProgram* target, + BufferElementTypeLoweringOptions options); + TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType); -struct LoweredElementTypeContext +IRInst* ConversionMethod::apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr) { - static const IRIntegerValue kMaxArraySizeToUnroll = 32; + if (!*this) + return builder.emitLoad(operandAddr); + if (kind == ConversionMethodKind::Func) + return builder.emitCallInst(resultType, func, 1, &operandAddr); + else + { + auto val = builder.emitLoad(operandAddr); + return builder.emitIntrinsicInst(resultType, op, 1, &val); + } +} - enum ConversionMethodKind +void ConversionMethod::applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand) +{ + if (!*this) { - Func, - Opcode - }; - struct ConversionMethod + builder.emitStore(dest, operand); + return; + } + if (kind == ConversionMethodKind::Func) { - ConversionMethodKind kind = ConversionMethodKind::Func; - union - { - IRFunc* func; - IROp op; - }; - ConversionMethod() { func = nullptr; } - operator bool() - { - return kind == ConversionMethodKind::Func ? func != nullptr : op != kIROp_Nop; - } - ConversionMethod& operator=(IRFunc* f) - { - kind = ConversionMethodKind::Func; - this->func = f; - return *this; - } - ConversionMethod& operator=(IROp irop) - { - kind = ConversionMethodKind::Opcode; - this->op = irop; - return *this; - } - IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr) - { - if (!*this) - return builder.emitLoad(operandAddr); - if (kind == ConversionMethodKind::Func) - return builder.emitCallInst(resultType, func, 1, &operandAddr); - else - { - auto val = builder.emitLoad(operandAddr); - return builder.emitIntrinsicInst(resultType, op, 1, &val); - } - } - void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand) - { - if (!*this) - { - builder.emitStore(dest, operand); - return; - } - if (kind == ConversionMethodKind::Func) - { - IRInst* operands[] = {dest, operand}; - builder.emitCallInst(builder.getVoidType(), func, 2, operands); - } - else - { - auto val = builder.emitIntrinsicInst( - tryGetPointedToOrBufferElementType(&builder, dest->getDataType()), - op, - 1, - &operand); - builder.emitStore(dest, val); - } - } - }; + IRInst* operands[] = {dest, operand}; + builder.emitCallInst(builder.getVoidType(), func, 2, operands); + } + else + { + auto val = builder.emitIntrinsicInst( + tryGetPointedToOrBufferElementType(&builder, dest->getDataType()), + op, + 1, + &operand); + builder.emitStore(dest, val); + } +} + +// Returns the number of elements N that ensures the IRVectorType(elementType,N) +// has 16-byte aligned size and N is no less than `minCount`. +IRIntegerValue get16ByteAlignedVectorElementCount( + TargetProgram* target, + IRType* elementType, + IRIntegerValue minCount) +{ + IRSizeAndAlignment sizeAlignment; + getNaturalSizeAndAlignment(target->getOptionSet(), elementType, &sizeAlignment); + if (sizeAlignment.size) + return align(sizeAlignment.size * minCount, 16) / sizeAlignment.size; + return 4; +} - struct LoweredElementTypeInfo +const char* getLayoutName(IRTypeLayoutRuleName name) +{ + switch (name) { - IRType* originalType; - IRType* loweredType; - IRType* loweredInnerArrayType = - nullptr; // For matrix/array types that are lowered into a struct type, this is the - // inner array type of the data field. - IRStructKey* loweredInnerStructKey = - nullptr; // For matrix/array types that are lowered into a struct type, this is the - // struct key of the data field. - ConversionMethod convertOriginalToLowered; - ConversionMethod convertLoweredToOriginal; - }; + case IRTypeLayoutRuleName::Std140: + return "std140"; + case IRTypeLayoutRuleName::Std430: + return "std430"; + case IRTypeLayoutRuleName::Natural: + return "natural"; + case IRTypeLayoutRuleName::C: + return "c"; + default: + return "default"; + } +} + +struct LoweredElementTypeContext +{ + static const IRIntegerValue kMaxArraySizeToUnroll = 32; struct LoweredTypeMap : RefObject { @@ -320,6 +389,7 @@ struct LoweredElementTypeContext }; Dictionary<TypeLoweringConfig, RefPtr<LoweredTypeMap>> loweredTypeInfoMaps; + RefPtr<BufferElementTypeLoweringPolicy> leafTypeLoweringPolicy; struct ConversionMethodKey { @@ -366,150 +436,11 @@ struct LoweredElementTypeContext // Specialized functions that takes storage-typed pointers instead of logical-typed pointers. Dictionary<SpecializationKey, IRFunc*> specializedFuncs; - LoweredElementTypeContext( - TargetProgram* target, - BufferElementTypeLoweringOptions inOptions, - SlangMatrixLayoutMode inDefaultMatrixLayout) - : target(target), defaultMatrixLayout(inDefaultMatrixLayout), options(inOptions) - { - } - - IRFunc* createMatrixUnpackFunc( - IRMatrixType* matrixType, - IRStructType* structType, - IRStructKey* dataKey) - { - IRBuilder builder(structType); - builder.setInsertAfter(structType); - auto func = builder.createFunc(); - auto refStructType = builder.getRefParamType(structType, AddressSpace::Generic); - auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType); - func->setFullType(funcType); - builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); - builder.addForceInlineDecoration(func); - builder.setInsertInto(func); - builder.emitBlock(); - auto rowCount = (Index)getIntVal(matrixType->getRowCount()); - auto colCount = (Index)getIntVal(matrixType->getColumnCount()); - auto packedParamRef = builder.emitParam(refStructType); - auto packedParam = builder.emitLoad(packedParamRef); - auto vectorArray = builder.emitFieldExtract(packedParam, dataKey); - List<IRInst*> args; - args.setCount(rowCount * colCount); - if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) - { - for (IRIntegerValue c = 0; c < colCount; c++) - { - auto vector = builder.emitElementExtract(vectorArray, c); - for (IRIntegerValue r = 0; r < rowCount; r++) - { - auto element = builder.emitElementExtract(vector, r); - args[(Index)(r * colCount + c)] = element; - } - } - } - else - { - for (IRIntegerValue r = 0; r < rowCount; r++) - { - auto vector = builder.emitElementExtract(vectorArray, r); - for (IRIntegerValue c = 0; c < colCount; c++) - { - auto element = builder.emitElementExtract(vector, c); - args[(Index)(r * colCount + c)] = element; - } - } - } - IRInst* result = - builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer()); - builder.emitReturn(result); - return func; - } - - IRFunc* createMatrixPackFunc( - IRMatrixType* matrixType, - IRStructType* structType, - IRVectorType* vectorType, - IRArrayType* arrayType) + LoweredElementTypeContext(TargetProgram* target, BufferElementTypeLoweringOptions inOptions) + : target(target), options(inOptions) { - IRBuilder builder(structType); - builder.setInsertAfter(structType); - auto func = builder.createFunc(); - auto outStructType = builder.getRefParamType(structType, AddressSpace::Generic); - IRType* paramTypes[] = {outStructType, matrixType}; - auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType()); - func->setFullType(funcType); - builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix")); - builder.addForceInlineDecoration(func); - builder.setInsertInto(func); - builder.emitBlock(); - auto rowCount = getIntVal(matrixType->getRowCount()); - auto colCount = getIntVal(matrixType->getColumnCount()); - auto outParam = builder.emitParam(outStructType); - auto originalParam = builder.emitParam(matrixType); - List<IRInst*> elements; - elements.setCount((Index)(rowCount * colCount)); - for (IRIntegerValue r = 0; r < rowCount; r++) - { - auto vector = builder.emitElementExtract(originalParam, r); - for (IRIntegerValue c = 0; c < colCount; c++) - { - auto element = builder.emitElementExtract(vector, c); - elements[(Index)(r * colCount + c)] = element; - } - } - List<IRInst*> vectors; - if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) - { - for (IRIntegerValue c = 0; c < colCount; c++) - { - List<IRInst*> vecArgs; - for (IRIntegerValue r = 0; r < rowCount; r++) - { - auto element = elements[(Index)(r * colCount + c)]; - vecArgs.add(element); - } - // Fill in default values for remaining elements in the vector. - for (IRIntegerValue r = rowCount; r < getIntVal(vectorType->getElementCount()); r++) - { - vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType())); - } - auto colVector = builder.emitMakeVector( - vectorType, - (UInt)vecArgs.getCount(), - vecArgs.getBuffer()); - vectors.add(colVector); - } - } - else - { - for (IRIntegerValue r = 0; r < rowCount; r++) - { - List<IRInst*> vecArgs; - for (IRIntegerValue c = 0; c < colCount; c++) - { - auto element = elements[(Index)(r * colCount + c)]; - vecArgs.add(element); - } - // Fill in default values for remaining elements in the vector. - for (IRIntegerValue c = colCount; c < getIntVal(vectorType->getElementCount()); c++) - { - vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType())); - } - auto rowVector = builder.emitMakeVector( - vectorType, - (UInt)vecArgs.getCount(), - vecArgs.getBuffer()); - vectors.add(rowVector); - } - } - - auto vectorArray = - builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer()); - auto result = builder.emitMakeStruct(structType, 1, &vectorArray); - builder.emitStore(outParam, result); - builder.emitReturn(); - return func; + leafTypeLoweringPolicy = + getBufferElementTypeLoweringPolicy(options.loweringPolicyKind, target, options); } IRFunc* createArrayUnpackFunc( @@ -637,56 +568,6 @@ struct LoweredElementTypeContext return func; } - const char* getLayoutName(IRTypeLayoutRuleName name) - { - switch (name) - { - case IRTypeLayoutRuleName::Std140: - return "std140"; - case IRTypeLayoutRuleName::Std430: - return "std430"; - case IRTypeLayoutRuleName::Natural: - return "natural"; - case IRTypeLayoutRuleName::C: - return "c"; - default: - return "default"; - } - } - - // Returns the number of elements N that ensures the IRVectorType(elementType,N) - // has 16-byte aligned size and N is no less than `minCount`. - IRIntegerValue get16ByteAlignedVectorElementCount(IRType* elementType, IRIntegerValue minCount) - { - IRSizeAndAlignment sizeAlignment; - getNaturalSizeAndAlignment(target->getOptionSet(), elementType, &sizeAlignment); - if (sizeAlignment.size) - return align(sizeAlignment.size * minCount, 16) / sizeAlignment.size; - return 4; - } - - bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) - { - // For spirv, we always want to lower all matrix types, because SPIRV does not support - // specifying matrix layout/stride if the matrix type is used in places other than - // defining a struct field. This means that if a matrix is used to define a varying - // parameter, we always want to wrap it in a struct. - // - if (target->shouldEmitSPIRVDirectly()) - { - return true; - } - - if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && - config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural) - { - // For other targets, we only lower the matrix types if they differ from the default - // matrix layout. - return false; - } - return true; - } - LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, TypeLoweringConfig config) { IRBuilder builder(type); @@ -694,72 +575,13 @@ struct LoweredElementTypeContext LoweredElementTypeInfo info; info.originalType = type; - - if (auto matrixType = as<IRMatrixType>(type)) - { - if (!shouldLowerMatrixType(matrixType, config)) - { - info.loweredType = type; - return info; - } - - auto loweredType = builder.createStructType(); - builder.addPhysicalTypeDecoration(loweredType); - - StringBuilder nameSB; - bool isColMajor = - getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; - nameSB << "_MatrixStorage_"; - getTypeNameHint(nameSB, matrixType->getElementType()); - nameSB << getIntVal(matrixType->getRowCount()) << "x" - << getIntVal(matrixType->getColumnCount()); - if (isColMajor) - nameSB << "_ColMajor"; - nameSB << getLayoutName(config.layoutRule->ruleName); - builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); - auto structKey = builder.createStructKey(); - builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); - auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount(); - if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 && - options.use16ByteArrayElementForConstantBuffer) - { - // For constant buffer layout, we need to use 16-byte aligned vector if - // we are required to ensure array element types has 16-byte stride. - vectorSize = builder.getIntValue(get16ByteAlignedVectorElementCount( - matrixType->getElementType(), - getIntVal(vectorSize))); - } - - auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize); - IRSizeAndAlignment elementSizeAlignment; - getSizeAndAlignment( - target->getOptionSet(), - config.layoutRule, - vectorType, - &elementSizeAlignment); - elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment); - - auto arrayType = builder.getArrayType( - vectorType, - isColMajor ? matrixType->getColumnCount() : matrixType->getRowCount(), - builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); - builder.createStructField(loweredType, structKey, arrayType); - - info.loweredType = loweredType; - info.loweredInnerArrayType = arrayType; - info.loweredInnerStructKey = structKey; - info.convertLoweredToOriginal = - createMatrixUnpackFunc(matrixType, loweredType, structKey); - info.convertOriginalToLowered = - createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType); - return info; - } - else if (auto arrayTypeBase = as<IRArrayTypeBase>(type)) + if (auto arrayTypeBase = as<IRArrayTypeBase>(type)) { auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayTypeBase->getElementType(), config); - if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 && - options.use16ByteArrayElementForConstantBuffer) + if (config.layoutRuleName == IRTypeLayoutRuleName::Std140 && + leafTypeLoweringPolicy + ->shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer()) { // For constant buffer layout, we need to use 16-byte-aligned vector if // we are required to ensure array element types has 16-byte stride. @@ -772,6 +594,7 @@ struct LoweredElementTypeContext packedVectorType = builder.getVectorType( vectorType->getElementType(), builder.getIntValue(get16ByteAlignedVectorElementCount( + target, vectorType->getElementType(), getIntVal(vectorType->getElementCount())))); if (packedVectorType != loweredInnerTypeInfo.originalType) @@ -784,7 +607,7 @@ struct LoweredElementTypeContext { packedVectorType = builder.getVectorType( loweredInnerTypeInfo.loweredType, - get16ByteAlignedVectorElementCount(scalarType, 1)); + get16ByteAlignedVectorElementCount(target, scalarType, 1)); loweredInnerTypeInfo.convertLoweredToOriginal = kIROp_VectorReshape; loweredInnerTypeInfo.convertOriginalToLowered = kIROp_MakeVectorFromScalar; } @@ -803,10 +626,10 @@ struct LoweredElementTypeContext } } - // For spirv backend, we always want to lower all array types for non-varying - // parameters, even if the element type comes out the same. This is because different - // layout rules may have different array stride requirements. - if (!target->shouldEmitSPIRVDirectly() || config.addressSpace == AddressSpace::Input) + // We can skip lowering this type if all field types are unchanged, unless the target + // specific policy tells us to always create a lowered type. + if (!leafTypeLoweringPolicy->shouldAlwaysCreateLoweredStorageTypeForCompositeTypes( + config)) { if (!loweredInnerTypeInfo.convertLoweredToOriginal) { @@ -823,7 +646,7 @@ struct LoweredElementTypeContext info.loweredType = loweredType; StringBuilder nameSB; - nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_"; + nameSB << "_Array_" << getLayoutName(config.layoutRuleName) << "_"; getTypeNameHint(nameSB, arrayType->getElementType()); nameSB << getArraySizeVal(arrayType->getElementCount()); @@ -835,11 +658,11 @@ struct LoweredElementTypeContext IRSizeAndAlignment elementSizeAlignment; getSizeAndAlignment( target->getOptionSet(), - config.layoutRule, + config.getLayoutRule(), loweredInnerTypeInfo.loweredType, &elementSizeAlignment); elementSizeAlignment = - config.layoutRule->alignCompositeElement(elementSizeAlignment); + config.getLayoutRule()->alignCompositeElement(elementSizeAlignment); auto innerArrayType = builder.getArrayType( loweredInnerTypeInfo.loweredType, arrayType->getElementCount(), @@ -857,11 +680,11 @@ struct LoweredElementTypeContext IRSizeAndAlignment elementSizeAlignment; getSizeAndAlignment( target->getOptionSet(), - config.layoutRule, + config.getLayoutRule(), loweredInnerTypeInfo.loweredType, &elementSizeAlignment); elementSizeAlignment = - config.layoutRule->alignCompositeElement(elementSizeAlignment); + config.getLayoutRule()->alignCompositeElement(elementSizeAlignment); auto innerArrayType = builder.getArrayTypeBase( arrayTypeBase->getOp(), loweredInnerTypeInfo.loweredType, @@ -880,20 +703,15 @@ struct LoweredElementTypeContext auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), config); fieldLoweredTypeInfo.add(loweredFieldTypeInfo); if (loweredFieldTypeInfo.convertLoweredToOriginal || - config.layoutRule->ruleName != IRTypeLayoutRuleName::Natural) + config.layoutRuleName != IRTypeLayoutRuleName::Natural) isTrivial = false; } - // For spirv backend, we always want to lower all array types, even if the element type - // comes out the same. This is because different layout rules may have different array - // stride requirements. - // - // Additionally, `buffer` blocks do not work correctly unless lowered when targeting - // GLSL. - if (!isKhronosTarget(target->getTargetReq())) + // We can skip lowering this type if all field types are unchanged, unless the target + // specific policy tells us to always create a lowered type. + if (!leafTypeLoweringPolicy->shouldAlwaysCreateLoweredStorageTypeForCompositeTypes( + config)) { - // For non-spirv target, we skip lowering this type if all field types are - // unchanged. if (isTrivial) { info.loweredType = type; @@ -905,7 +723,7 @@ struct LoweredElementTypeContext StringBuilder nameSB; getTypeNameHint(nameSB, type); - nameSB << "_" << getLayoutName(config.layoutRule->ruleName); + nameSB << "_" << getLayoutName(config.layoutRuleName); builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); info.loweredType = loweredType; // Create fields. @@ -1007,58 +825,7 @@ struct LoweredElementTypeContext return info; } - - if (target->shouldEmitSPIRVDirectly()) - { - switch (target->getTargetReq()->getTarget()) - { - case CodeGenTarget::SPIRV: - case CodeGenTarget::SPIRVAssembly: - { - auto scalarType = type; - auto vectorType = as<IRVectorType>(scalarType); - if (vectorType) - scalarType = vectorType->getElementType(); - - if (as<IRBoolType>(scalarType)) - { - // Bool is an abstract type in SPIRV, so we need to lower them into an int. - - // Find an integer type of the correct size for the current layout rule. - IRSizeAndAlignment boolSizeAndAlignment; - if (getSizeAndAlignment( - target->getOptionSet(), - config.layoutRule, - scalarType, - &boolSizeAndAlignment) == SLANG_OK) - { - IntInfo ii; - ii.width = boolSizeAndAlignment.size * 8; - ii.isSigned = true; - info.loweredType = builder.getType(getIntTypeOpFromInfo(ii)); - } - else - { - // Just in case that fails for some reason, just use an int. - info.loweredType = builder.getIntType(); - } - - if (vectorType) - info.loweredType = builder.getVectorType( - info.loweredType, - vectorType->getElementCount()); - info.convertLoweredToOriginal = kIROp_BuiltinCast; - info.convertOriginalToLowered = kIROp_BuiltinCast; - return info; - } - } - default: - break; - } - } - - info.loweredType = type; - return info; + return leafTypeLoweringPolicy->lowerLeafLogicalType(type, config); } LoweredTypeMap& getTypeLoweringMap(TypeLoweringConfig config) @@ -1090,7 +857,7 @@ struct LoweredElementTypeContext IRSizeAndAlignment sizeAlignment; getSizeAndAlignment( target->getOptionSet(), - config.layoutRule, + config.getLayoutRule(), info.loweredType, &sizeAlignment); loweredTypeInfo.set(type, info); @@ -1190,13 +957,13 @@ struct LoweredElementTypeContext IRSizeAndAlignment arrayElementSizeAlignment; getSizeAndAlignment( target->getOptionSet(), - config.layoutRule, + config.getLayoutRule(), loweredInnerType.loweredType, &arrayElementSizeAlignment); IRSizeAndAlignment baseSizeAlignment; getSizeAndAlignment( target->getOptionSet(), - config.layoutRule, + config.getLayoutRule(), tryGetPointedToOrBufferElementType(&builder, fieldAddr->getBase()->getDataType()), &baseSizeAlignment); @@ -1701,9 +1468,6 @@ struct LoweredElementTypeContext switch (ptrType->getAddressSpace()) { case AddressSpace::UserPointer: - if (!options.lowerBufferPointer) - continue; - [[fallthrough]]; case AddressSpace::Input: case AddressSpace::Output: elementType = ptrType->getValueType(); @@ -1720,7 +1484,7 @@ struct LoweredElementTypeContext IRSizeAndAlignment sizeAlignment; getSizeAndAlignment( target->getOptionSet(), - config.layoutRule, + config.getLayoutRule(), elementType, &sizeAlignment); SLANG_UNUSED(sizeAlignment); @@ -2181,63 +1945,60 @@ void lowerBufferElementTypeToStorageType( IRModule* module, BufferElementTypeLoweringOptions options) { - SlangMatrixLayoutMode defaultMatrixMode = - (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode(); - if ((isCPUTarget(target->getTargetReq()) || isCUDATarget(target->getTargetReq()) || - isMetalTarget(target->getTargetReq()))) - defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; - else if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) - defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; - LoweredElementTypeContext context(target, options, defaultMatrixMode); + LoweredElementTypeContext context(target, options); context.processModule(module); } -IRTypeLayoutRules* getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRules* defaultLayout) +IRTypeLayoutRuleName getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRuleName defaultLayout) { switch (layoutTypeOp) { case kIROp_DefaultBufferLayoutType: return defaultLayout; case kIROp_Std140BufferLayoutType: - return IRTypeLayoutRules::getStd140(); + return IRTypeLayoutRuleName::Std140; case kIROp_Std430BufferLayoutType: - return IRTypeLayoutRules::getStd430(); + return IRTypeLayoutRuleName::Std430; case kIROp_ScalarBufferLayoutType: - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; case kIROp_CBufferLayoutType: - return IRTypeLayoutRules::getC(); + return IRTypeLayoutRuleName::C; } return defaultLayout; } -IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType) +IRTypeLayoutRuleName getTypeLayoutRuleNameForBuffer(TargetProgram* target, IRType* bufferType) { + if (bufferType->getOp() == kIROp_ParameterBlockType && isMetalTarget(target->getTargetReq())) + { + return IRTypeLayoutRuleName::MetalParameterBlock; + } if (target->getTargetReq()->getTarget() != CodeGenTarget::WGSL) { if (!isKhronosTarget(target->getTargetReq())) - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; // If we are just emitting GLSL, we can just use the general layout rule. if (!target->shouldEmitSPIRVDirectly()) - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; // If the user specified a C-compatible buffer layout, then do that. if (target->getOptionSet().shouldUseCLayout()) - return IRTypeLayoutRules::getC(); + return IRTypeLayoutRuleName::C; // If the user specified a scalar buffer layout, then just use that. if (target->getOptionSet().shouldUseScalarLayout()) - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; } if (target->getOptionSet().shouldUseDXLayout()) { if (as<IRUniformParameterGroupType>(bufferType)) { - return IRTypeLayoutRules::getConstantBuffer(); + return IRTypeLayoutRuleName::D3DConstantBuffer; } else - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; } // The default behavior is to use std140 for constant buffers and std430 for other buffers. @@ -2253,17 +2014,17 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf auto layoutTypeOp = structBufferType->getDataLayout() ? structBufferType->getDataLayout()->getOp() : kIROp_DefaultBufferLayoutType; - return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430()); + return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std430); } - case kIROp_ConstantBufferType: case kIROp_ParameterBlockType: + case kIROp_ConstantBufferType: { auto parameterGroupType = as<IRUniformParameterGroupType>(bufferType); auto layoutTypeOp = parameterGroupType->getDataLayout() ? parameterGroupType->getDataLayout()->getOp() : kIROp_DefaultBufferLayoutType; - return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd140()); + return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std140); } case kIROp_GLSLShaderStorageBufferType: { @@ -2271,12 +2032,18 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf auto layoutTypeOp = storageBufferType->getDataLayout() ? storageBufferType->getDataLayout()->getOp() : kIROp_Std430BufferLayoutType; - return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430()); + return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std430); } case kIROp_PtrType: - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; } - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; +} + +IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType) +{ + auto ruleName = getTypeLayoutRuleNameForBuffer(target, bufferType); + return IRTypeLayoutRules::get(ruleName); } TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType) @@ -2295,8 +2062,413 @@ TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* break; } } - auto rules = getTypeLayoutRuleForBuffer(target, bufferType); + auto rules = getTypeLayoutRuleNameForBuffer(target, bufferType); return TypeLoweringConfig{addrSpace, rules}; } +struct DefaultBufferElementTypeLoweringPolicy : BufferElementTypeLoweringPolicy +{ + TargetProgram* target; + BufferElementTypeLoweringOptions options; + SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + + DefaultBufferElementTypeLoweringPolicy( + TargetProgram* inTarget, + BufferElementTypeLoweringOptions inOptions) + : target(inTarget), options(inOptions) + { + defaultMatrixLayout = (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode(); + if ((isCPUTarget(target->getTargetReq()) || isCUDATarget(target->getTargetReq()) || + isMetalTarget(target->getTargetReq()))) + defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + else if (defaultMatrixLayout == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) + defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + } + + virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) + { + if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && + config.getLayoutRule()->ruleName == IRTypeLayoutRuleName::Natural) + { + // We only lower the matrix types if they differ from the default + // matrix layout. + return false; + } + return true; + } + + IRFunc* createMatrixUnpackFunc( + IRMatrixType* matrixType, + IRStructType* structType, + IRStructKey* dataKey) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto refStructType = builder.getRefParamType(structType, AddressSpace::Generic); + auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.addForceInlineDecoration(func); + builder.setInsertInto(func); + builder.emitBlock(); + auto rowCount = (Index)getIntVal(matrixType->getRowCount()); + auto colCount = (Index)getIntVal(matrixType->getColumnCount()); + auto packedParamRef = builder.emitParam(refStructType); + auto packedParam = builder.emitLoad(packedParamRef); + auto vectorArray = builder.emitFieldExtract(packedParam, dataKey); + List<IRInst*> args; + args.setCount(rowCount * colCount); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto vector = builder.emitElementExtract(vectorArray, c); + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto element = builder.emitElementExtract(vector, r); + args[(Index)(r * colCount + c)] = element; + } + } + } + else + { + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto vector = builder.emitElementExtract(vectorArray, r); + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = builder.emitElementExtract(vector, c); + args[(Index)(r * colCount + c)] = element; + } + } + } + IRInst* result = + builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer()); + builder.emitReturn(result); + return func; + } + + IRFunc* createMatrixPackFunc( + IRMatrixType* matrixType, + IRStructType* structType, + IRVectorType* vectorType, + IRArrayType* arrayType) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto outStructType = builder.getRefParamType(structType, AddressSpace::Generic); + IRType* paramTypes[] = {outStructType, matrixType}; + auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType()); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix")); + builder.addForceInlineDecoration(func); + builder.setInsertInto(func); + builder.emitBlock(); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); + auto outParam = builder.emitParam(outStructType); + auto originalParam = builder.emitParam(matrixType); + List<IRInst*> elements; + elements.setCount((Index)(rowCount * colCount)); + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto vector = builder.emitElementExtract(originalParam, r); + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = builder.emitElementExtract(vector, c); + elements[(Index)(r * colCount + c)] = element; + } + } + List<IRInst*> vectors; + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue c = 0; c < colCount; c++) + { + List<IRInst*> vecArgs; + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto element = elements[(Index)(r * colCount + c)]; + vecArgs.add(element); + } + // Fill in default values for remaining elements in the vector. + for (IRIntegerValue r = rowCount; r < getIntVal(vectorType->getElementCount()); r++) + { + vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType())); + } + auto colVector = builder.emitMakeVector( + vectorType, + (UInt)vecArgs.getCount(), + vecArgs.getBuffer()); + vectors.add(colVector); + } + } + else + { + for (IRIntegerValue r = 0; r < rowCount; r++) + { + List<IRInst*> vecArgs; + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = elements[(Index)(r * colCount + c)]; + vecArgs.add(element); + } + // Fill in default values for remaining elements in the vector. + for (IRIntegerValue c = colCount; c < getIntVal(vectorType->getElementCount()); c++) + { + vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType())); + } + auto rowVector = builder.emitMakeVector( + vectorType, + (UInt)vecArgs.getCount(), + vecArgs.getBuffer()); + vectors.add(rowVector); + } + } + + auto vectorArray = + builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer()); + auto result = builder.emitMakeStruct(structType, 1, &vectorArray); + builder.emitStore(outParam, result); + builder.emitReturn(); + return func; + } + + LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override + { + IRBuilder builder(type); + builder.setInsertAfter(type); + + LoweredElementTypeInfo info; + info.originalType = type; + + if (auto matrixType = as<IRMatrixType>(type)) + { + if (!shouldLowerMatrixType(matrixType, config)) + { + info.loweredType = type; + return info; + } + + auto loweredType = builder.createStructType(); + builder.addPhysicalTypeDecoration(loweredType); + + StringBuilder nameSB; + bool isColMajor = + getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; + nameSB << "_MatrixStorage_"; + getTypeNameHint(nameSB, matrixType->getElementType()); + nameSB << getIntVal(matrixType->getRowCount()) << "x" + << getIntVal(matrixType->getColumnCount()); + if (isColMajor) + nameSB << "_ColMajor"; + nameSB << getLayoutName(config.layoutRuleName); + builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + auto structKey = builder.createStructKey(); + builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); + auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount(); + if (config.layoutRuleName == IRTypeLayoutRuleName::Std140 && + shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer()) + { + // For constant buffer layout, we need to use 16-byte aligned vector if + // we are required to ensure array element types has 16-byte stride. + vectorSize = builder.getIntValue(get16ByteAlignedVectorElementCount( + target, + matrixType->getElementType(), + getIntVal(vectorSize))); + } + + auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize); + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + config.getLayoutRule(), + vectorType, + &elementSizeAlignment); + elementSizeAlignment = + config.getLayoutRule()->alignCompositeElement(elementSizeAlignment); + + auto arrayType = builder.getArrayType( + vectorType, + isColMajor ? matrixType->getColumnCount() : matrixType->getRowCount(), + builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); + builder.createStructField(loweredType, structKey, arrayType); + + info.loweredType = loweredType; + info.loweredInnerArrayType = arrayType; + info.loweredInnerStructKey = structKey; + info.convertLoweredToOriginal = + createMatrixUnpackFunc(matrixType, loweredType, structKey); + info.convertOriginalToLowered = + createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType); + return info; + } + + info.loweredType = type; + return info; + } +}; + +struct KhronosTargetBufferElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy +{ + KhronosTargetBufferElementTypeLoweringPolicy( + TargetProgram* inTarget, + BufferElementTypeLoweringOptions inOptions) + : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions) + { + } + + virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) override + { + // For spirv, we always want to lower all matrix types, because SPIRV does not support + // specifying matrix layout/stride if the matrix type is used in places other than + // defining a struct field. This means that if a matrix is used to define a varying + // parameter, we always want to wrap it in a struct. + // + if (target->shouldEmitSPIRVDirectly()) + return true; + return DefaultBufferElementTypeLoweringPolicy::shouldLowerMatrixType(matrixType, config); + } + + virtual bool shouldAlwaysCreateLoweredStorageTypeForCompositeTypes( + TypeLoweringConfig config) override + { + // For spirv backend, we always want to lower all array types, even if the element type + // comes out the same. This is because different layout rules may have different array + // stride requirements. + // + // Additionally, `buffer` blocks do not work correctly unless lowered when targeting + // GLSL. + return target->shouldEmitSPIRVDirectly() && config.addressSpace != AddressSpace::Input; + } + + LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override + { + if (target->shouldEmitSPIRVDirectly()) + { + LoweredElementTypeInfo info = {}; + info.originalType = type; + + switch (target->getTargetReq()->getTarget()) + { + case CodeGenTarget::SPIRV: + case CodeGenTarget::SPIRVAssembly: + { + auto scalarType = type; + auto vectorType = as<IRVectorType>(scalarType); + if (vectorType) + scalarType = vectorType->getElementType(); + IRBuilder builder(type); + builder.setInsertBefore(type); + + if (as<IRBoolType>(scalarType)) + { + // Bool is an abstract type in SPIRV, so we need to lower them into an int. + + // Find an integer type of the correct size for the current layout rule. + IRSizeAndAlignment boolSizeAndAlignment; + if (getSizeAndAlignment( + target->getOptionSet(), + config.getLayoutRule(), + scalarType, + &boolSizeAndAlignment) == SLANG_OK) + { + IntInfo ii; + ii.width = boolSizeAndAlignment.size * 8; + ii.isSigned = true; + info.loweredType = builder.getType(getIntTypeOpFromInfo(ii)); + } + else + { + // Just in case that fails for some reason, just use an int. + info.loweredType = builder.getIntType(); + } + + if (vectorType) + info.loweredType = builder.getVectorType( + info.loweredType, + vectorType->getElementCount()); + info.convertLoweredToOriginal = kIROp_BuiltinCast; + info.convertOriginalToLowered = kIROp_BuiltinCast; + return info; + } + } + break; + default: + break; + } + } + return DefaultBufferElementTypeLoweringPolicy::lowerLeafLogicalType(type, config); + } +}; + +struct MetalParameterBlockElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy +{ + MetalParameterBlockElementTypeLoweringPolicy( + TargetProgram* inTarget, + BufferElementTypeLoweringOptions inOptions) + : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions) + { + } + + virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) override + { + SLANG_UNUSED(matrixType); + SLANG_UNUSED(config); + return false; + } + + LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override + { + if (config.layoutRuleName == IRTypeLayoutRuleName::MetalParameterBlock && + isResourceType(type)) + { + IRBuilder builder(type); + builder.setInsertBefore(type); + LoweredElementTypeInfo info = {}; + info.originalType = type; + info.loweredType = builder.getType(kIROp_DescriptorHandleType, type); + info.convertLoweredToOriginal = kIROp_CastDescriptorHandleToResource; + info.convertOriginalToLowered = kIROp_CastResourceToDescriptorHandle; + return info; + } + return DefaultBufferElementTypeLoweringPolicy::lowerLeafLogicalType(type, config); + } +}; + +struct WGSLBufferElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy +{ + WGSLBufferElementTypeLoweringPolicy( + TargetProgram* inTarget, + BufferElementTypeLoweringOptions inOptions) + : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions) + { + } + + virtual bool shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer() override + { + return true; + } +}; + +BufferElementTypeLoweringPolicy* getBufferElementTypeLoweringPolicy( + BufferElementTypeLoweringPolicyKind kind, + TargetProgram* target, + BufferElementTypeLoweringOptions options) +{ + switch (kind) + { + case BufferElementTypeLoweringPolicyKind::Default: + return new DefaultBufferElementTypeLoweringPolicy(target, options); + case BufferElementTypeLoweringPolicyKind::KhronosTarget: + return new KhronosTargetBufferElementTypeLoweringPolicy(target, options); + case BufferElementTypeLoweringPolicyKind::MetalParameterBlock: + return new MetalParameterBlockElementTypeLoweringPolicy(target, options); + case BufferElementTypeLoweringPolicyKind::WGSL: + return new WGSLBufferElementTypeLoweringPolicy(target, options); + } + SLANG_UNREACHABLE("unknown buffer element type lowering policy"); +} + } // namespace Slang diff --git a/source/slang/slang-ir-lower-buffer-element-type.h b/source/slang/slang-ir-lower-buffer-element-type.h index 9d6e53609..81ce73ddb 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.h +++ b/source/slang/slang-ir-lower-buffer-element-type.h @@ -1,19 +1,28 @@ #ifndef SLANG_IR_LOWER_BUFFER_ELEMENT_TYPE_H #define SLANG_IR_LOWER_BUFFER_ELEMENT_TYPE_H +#include "slang.h" + namespace Slang { struct IRModule; class TargetProgram; struct IRTypeLayoutRules; struct IRType; +enum class IRTypeLayoutRuleName; -struct BufferElementTypeLoweringOptions +enum class BufferElementTypeLoweringPolicyKind { - bool lowerBufferPointer = true; + Default, + KhronosTarget, + MetalParameterBlock, + WGSL +}; - // For WGSL, we can only create arrays that has a stride of 16 bytes for constant buffers. - bool use16ByteArrayElementForConstantBuffer = false; +struct BufferElementTypeLoweringOptions +{ + BufferElementTypeLoweringPolicyKind loweringPolicyKind = + BufferElementTypeLoweringPolicyKind::Default; }; // For each struct type S used as element type of a ConstantBuffer, ParameterBlock or @@ -31,6 +40,8 @@ void lowerBufferElementTypeToStorageType( // Returns the type layout rules should be used for a buffer resource type. IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType); +IRTypeLayoutRuleName getTypeLayoutRuleNameForBuffer(TargetProgram* target, IRType* bufferType); + } // namespace Slang #endif diff --git a/source/slang/slang-ir-wrap-cbuffer-element.cpp b/source/slang/slang-ir-wrap-cbuffer-element.cpp new file mode 100644 index 000000000..9c070ab80 --- /dev/null +++ b/source/slang/slang-ir-wrap-cbuffer-element.cpp @@ -0,0 +1,133 @@ +#include "slang-ir-wrap-cbuffer-element.h" + +#include "slang-ir-insts.h" +#include "slang-ir-util.h" + +// This pass implements a simple translation that wraps the element type T in a ConstantBuffer<T> +// (or ParameterBlock<T>) type in `struct S { T inner; }`, and replace the ConstantBuffer<T> type +// with ConstantBuffer<S>. This is needed because some backends do not allow certain types to be +// used directly as the element type of a constant buffer. +// For example, Metal does not allow `ParameterBlock<StructuredBuffer<int>>` as that will create +// a double pointer that Metal compiler does not like. We can easily work around this limitation +// by wrapping the `StructuredBuffer<int>` in a struct. + +namespace Slang +{ + +void maybeProvideNameHint( + IRBuilder& builder, + IRStructType* wrappedStructType, + IRParameterGroupType* originalParamGroupType) +{ + StringBuilder sb; + sb << "wrapper_"; + getTypeNameHint(sb, originalParamGroupType->getElementType()); + builder.addNameHintDecoration(wrappedStructType, sb.produceString().getUnownedSlice()); +} + +void wrapCBufferElements(IRModule* module, WrapCBufferElementPolicy* policy) +{ + struct WorkItem + { + IRStructKey* wrappedFieldKey; + IRInst* inst; + IRInst* newParameterGroupType; + }; + + IRBuilder builder(module); + + List<WorkItem> workList; + for (auto globalInst : module->getGlobalInsts()) + { + // Discover all insts whose type is a parameter group type. + if (auto paramGroupType = as<IRParameterGroupType>(globalInst)) + { + if (!policy->shouldWrapBufferElementInStruct(paramGroupType)) + continue; + + // Create the wrapper struct. + builder.setInsertBefore(paramGroupType); + auto structType = builder.createStructType(); + maybeProvideNameHint(builder, structType, paramGroupType); + auto fieldKey = builder.createStructKey(); + builder.addNameHintDecoration(fieldKey, toSlice("inner")); + builder.createStructField(structType, fieldKey, paramGroupType->getElementType()); + + // Create the new parameter group type whose element is the wrapper struct. + List<IRInst*> bufferTypeOperands; + bufferTypeOperands.add(structType); + for (UInt i = 1; i < paramGroupType->getOperandCount(); ++i) + { + bufferTypeOperands.add(paramGroupType->getOperand(i)); + } + auto newParameterGroupType = builder.getType( + paramGroupType->getOp(), + (UInt)bufferTypeOperands.getCount(), + bufferTypeOperands.getArrayView().getBuffer()); + + // Traverse all uses of the parameter group type, and add them to the work list + // for further processing. + traverseUses( + paramGroupType, + [&](IRUse* use) + { + if (use->getUser()->getFullType() != paramGroupType) + return; + WorkItem item; + item.wrappedFieldKey = fieldKey; + item.inst = use->getUser(); + workList.add(item); + }); + paramGroupType->replaceUsesWith(newParameterGroupType); + } + } + + // Now we have a work list of all instructions that uses a parameter group. + // We need to update all uses of parameter group x with `x.inner` instead. + for (auto item : workList) + { + traverseUses( + item.inst, + [&](IRUse* use) + { + auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(user); + + // Note that we insert the field address instruction right before each use, instead + // of immediately after the original parameter group inst, because the parameter + // group inst may be defined in a scope that does not allow field address + // instructions. + auto unwrapped = builder.emitFieldAddress(item.inst, item.wrappedFieldKey); + builder.replaceOperand(use, unwrapped); + }); + } +} + +class MetalWrapCBufferElementPolicy : public WrapCBufferElementPolicy +{ +public: + virtual bool shouldWrapBufferElementInStruct(IRParameterGroupType* cbufferType) override + { + // Metal allows structs, scalars, vectors and matrices directly as buffer elements. + if (as<IRStructType>(cbufferType->getElementType())) + return false; + if (as<IRBasicType>(cbufferType->getElementType())) + return false; + if (as<IRMatrixType>(cbufferType->getElementType())) + return false; + if (as<IRVectorType>(cbufferType->getElementType())) + return false; + + // Wrap everything else in a struct. + return true; + } +}; + +void wrapCBufferElementsForMetal(IRModule* module) +{ + MetalWrapCBufferElementPolicy policy = {}; + wrapCBufferElements(module, &policy); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-wrap-cbuffer-element.h b/source/slang/slang-ir-wrap-cbuffer-element.h new file mode 100644 index 000000000..a75ec78a3 --- /dev/null +++ b/source/slang/slang-ir-wrap-cbuffer-element.h @@ -0,0 +1,23 @@ +#ifndef SLANG_IR_WRAP_CBUFFER_ELEMENT_H +#define SLANG_IR_WRAP_CBUFFER_ELEMENT_H + +namespace Slang +{ +struct IRModule; +struct IRParameterGroupType; + +class WrapCBufferElementPolicy +{ +public: + virtual bool shouldWrapBufferElementInStruct(IRParameterGroupType* cbufferType) = 0; +}; + +// Wrap the element type of a ConstantBuffer/ParameterBlock in a struct if the element type is not +// something that allowed directly as a buffer element type by the target. +void wrapCBufferElements(IRModule* module, WrapCBufferElementPolicy* policy); + +void wrapCBufferElementsForMetal(IRModule* module); + +} // namespace Slang + +#endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 9ec8a2c8b..d114a9a40 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8731,6 +8731,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_CastUInt64ToDescriptorHandle: case kIROp_CastDescriptorHandleToUInt64: case kIROp_CastDescriptorHandleToResource: + case kIROp_CastResourceToDescriptorHandle: case kIROp_GetDynamicResourceHeap: case kIROp_CastDynamicResource: case kIROp_AllocObj: diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 161e70b25..aef3d6aeb 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -524,6 +524,7 @@ enum class IRTypeLayoutRuleName Std430, Std140, D3DConstantBuffer, + MetalParameterBlock, C, _Count, }; |
