diff options
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.cpp | 1058 |
1 files changed, 615 insertions, 443 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 83218bade..c69592939 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -214,104 +214,173 @@ namespace Slang { +enum ConversionMethodKind +{ + Func, + Opcode +}; + +struct ConversionMethod +{ + ConversionMethodKind kind = ConversionMethodKind::Func; + union + { + IRFunc* func; + IROp op; + }; + ConversionMethod() { func = nullptr; } + operator bool() + { + return kind == ConversionMethodKind::Func ? func != nullptr : op != kIROp_Nop; + } + ConversionMethod& operator=(IRFunc* f) + { + kind = ConversionMethodKind::Func; + this->func = f; + return *this; + } + ConversionMethod& operator=(IROp irop) + { + kind = ConversionMethodKind::Opcode; + this->op = irop; + return *this; + } + IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr); + void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand); +}; + struct TypeLoweringConfig { AddressSpace addressSpace; - IRTypeLayoutRules* layoutRule; + IRTypeLayoutRuleName layoutRuleName; + IRTypeLayoutRules* getLayoutRule() const { return IRTypeLayoutRules::get(layoutRuleName); } + bool operator==(const TypeLoweringConfig& other) const { - return addressSpace == other.addressSpace && layoutRule == other.layoutRule; + return addressSpace == other.addressSpace && layoutRuleName == other.layoutRuleName; } HashCode getHashCode() const { - return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRule)); + return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRuleName)); } }; + +struct LoweredElementTypeInfo +{ + IRType* originalType; + IRType* loweredType; + IRType* loweredInnerArrayType = + nullptr; // For matrix/array types that are lowered into a struct type, this is the + // inner array type of the data field. + IRStructKey* loweredInnerStructKey = + nullptr; // For matrix/array types that are lowered into a struct type, this is the + // struct key of the data field. + ConversionMethod convertOriginalToLowered; + ConversionMethod convertLoweredToOriginal; +}; + +/// Defines target-specific behavior of how to lower buffer element types. +struct BufferElementTypeLoweringPolicy : public RefObject +{ + /// Defines target-specific behavior of how to translate a non-composite logical type to a + /// storage type. + virtual LoweredElementTypeInfo lowerLeafLogicalType( + IRType* type, + TypeLoweringConfig config) = 0; + + /// Returns true if we should always create a fresh lowered storage type for a composite type, + /// even if every member/element of the composite type is not changed by the lowering. + virtual bool shouldAlwaysCreateLoweredStorageTypeForCompositeTypes(TypeLoweringConfig config) + { + SLANG_UNUSED(config); + return false; + } + + /// Returns true if the target requires all array of scalars or vectors inside a constant buffer + /// to be translated into a 16-byte aligned vector type. + virtual bool shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer() + { + return false; + } +}; + +BufferElementTypeLoweringPolicy* getBufferElementTypeLoweringPolicy( + BufferElementTypeLoweringPolicyKind kind, + TargetProgram* target, + BufferElementTypeLoweringOptions options); + TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType); -struct LoweredElementTypeContext +IRInst* ConversionMethod::apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr) { - static const IRIntegerValue kMaxArraySizeToUnroll = 32; + if (!*this) + return builder.emitLoad(operandAddr); + if (kind == ConversionMethodKind::Func) + return builder.emitCallInst(resultType, func, 1, &operandAddr); + else + { + auto val = builder.emitLoad(operandAddr); + return builder.emitIntrinsicInst(resultType, op, 1, &val); + } +} - enum ConversionMethodKind +void ConversionMethod::applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand) +{ + if (!*this) { - Func, - Opcode - }; - struct ConversionMethod + builder.emitStore(dest, operand); + return; + } + if (kind == ConversionMethodKind::Func) { - ConversionMethodKind kind = ConversionMethodKind::Func; - union - { - IRFunc* func; - IROp op; - }; - ConversionMethod() { func = nullptr; } - operator bool() - { - return kind == ConversionMethodKind::Func ? func != nullptr : op != kIROp_Nop; - } - ConversionMethod& operator=(IRFunc* f) - { - kind = ConversionMethodKind::Func; - this->func = f; - return *this; - } - ConversionMethod& operator=(IROp irop) - { - kind = ConversionMethodKind::Opcode; - this->op = irop; - return *this; - } - IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr) - { - if (!*this) - return builder.emitLoad(operandAddr); - if (kind == ConversionMethodKind::Func) - return builder.emitCallInst(resultType, func, 1, &operandAddr); - else - { - auto val = builder.emitLoad(operandAddr); - return builder.emitIntrinsicInst(resultType, op, 1, &val); - } - } - void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand) - { - if (!*this) - { - builder.emitStore(dest, operand); - return; - } - if (kind == ConversionMethodKind::Func) - { - IRInst* operands[] = {dest, operand}; - builder.emitCallInst(builder.getVoidType(), func, 2, operands); - } - else - { - auto val = builder.emitIntrinsicInst( - tryGetPointedToOrBufferElementType(&builder, dest->getDataType()), - op, - 1, - &operand); - builder.emitStore(dest, val); - } - } - }; + IRInst* operands[] = {dest, operand}; + builder.emitCallInst(builder.getVoidType(), func, 2, operands); + } + else + { + auto val = builder.emitIntrinsicInst( + tryGetPointedToOrBufferElementType(&builder, dest->getDataType()), + op, + 1, + &operand); + builder.emitStore(dest, val); + } +} + +// Returns the number of elements N that ensures the IRVectorType(elementType,N) +// has 16-byte aligned size and N is no less than `minCount`. +IRIntegerValue get16ByteAlignedVectorElementCount( + TargetProgram* target, + IRType* elementType, + IRIntegerValue minCount) +{ + IRSizeAndAlignment sizeAlignment; + getNaturalSizeAndAlignment(target->getOptionSet(), elementType, &sizeAlignment); + if (sizeAlignment.size) + return align(sizeAlignment.size * minCount, 16) / sizeAlignment.size; + return 4; +} - struct LoweredElementTypeInfo +const char* getLayoutName(IRTypeLayoutRuleName name) +{ + switch (name) { - IRType* originalType; - IRType* loweredType; - IRType* loweredInnerArrayType = - nullptr; // For matrix/array types that are lowered into a struct type, this is the - // inner array type of the data field. - IRStructKey* loweredInnerStructKey = - nullptr; // For matrix/array types that are lowered into a struct type, this is the - // struct key of the data field. - ConversionMethod convertOriginalToLowered; - ConversionMethod convertLoweredToOriginal; - }; + case IRTypeLayoutRuleName::Std140: + return "std140"; + case IRTypeLayoutRuleName::Std430: + return "std430"; + case IRTypeLayoutRuleName::Natural: + return "natural"; + case IRTypeLayoutRuleName::C: + return "c"; + default: + return "default"; + } +} + +struct LoweredElementTypeContext +{ + static const IRIntegerValue kMaxArraySizeToUnroll = 32; struct LoweredTypeMap : RefObject { @@ -320,6 +389,7 @@ struct LoweredElementTypeContext }; Dictionary<TypeLoweringConfig, RefPtr<LoweredTypeMap>> loweredTypeInfoMaps; + RefPtr<BufferElementTypeLoweringPolicy> leafTypeLoweringPolicy; struct ConversionMethodKey { @@ -366,150 +436,11 @@ struct LoweredElementTypeContext // Specialized functions that takes storage-typed pointers instead of logical-typed pointers. Dictionary<SpecializationKey, IRFunc*> specializedFuncs; - LoweredElementTypeContext( - TargetProgram* target, - BufferElementTypeLoweringOptions inOptions, - SlangMatrixLayoutMode inDefaultMatrixLayout) - : target(target), defaultMatrixLayout(inDefaultMatrixLayout), options(inOptions) - { - } - - IRFunc* createMatrixUnpackFunc( - IRMatrixType* matrixType, - IRStructType* structType, - IRStructKey* dataKey) - { - IRBuilder builder(structType); - builder.setInsertAfter(structType); - auto func = builder.createFunc(); - auto refStructType = builder.getRefParamType(structType, AddressSpace::Generic); - auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType); - func->setFullType(funcType); - builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); - builder.addForceInlineDecoration(func); - builder.setInsertInto(func); - builder.emitBlock(); - auto rowCount = (Index)getIntVal(matrixType->getRowCount()); - auto colCount = (Index)getIntVal(matrixType->getColumnCount()); - auto packedParamRef = builder.emitParam(refStructType); - auto packedParam = builder.emitLoad(packedParamRef); - auto vectorArray = builder.emitFieldExtract(packedParam, dataKey); - List<IRInst*> args; - args.setCount(rowCount * colCount); - if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) - { - for (IRIntegerValue c = 0; c < colCount; c++) - { - auto vector = builder.emitElementExtract(vectorArray, c); - for (IRIntegerValue r = 0; r < rowCount; r++) - { - auto element = builder.emitElementExtract(vector, r); - args[(Index)(r * colCount + c)] = element; - } - } - } - else - { - for (IRIntegerValue r = 0; r < rowCount; r++) - { - auto vector = builder.emitElementExtract(vectorArray, r); - for (IRIntegerValue c = 0; c < colCount; c++) - { - auto element = builder.emitElementExtract(vector, c); - args[(Index)(r * colCount + c)] = element; - } - } - } - IRInst* result = - builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer()); - builder.emitReturn(result); - return func; - } - - IRFunc* createMatrixPackFunc( - IRMatrixType* matrixType, - IRStructType* structType, - IRVectorType* vectorType, - IRArrayType* arrayType) + LoweredElementTypeContext(TargetProgram* target, BufferElementTypeLoweringOptions inOptions) + : target(target), options(inOptions) { - IRBuilder builder(structType); - builder.setInsertAfter(structType); - auto func = builder.createFunc(); - auto outStructType = builder.getRefParamType(structType, AddressSpace::Generic); - IRType* paramTypes[] = {outStructType, matrixType}; - auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType()); - func->setFullType(funcType); - builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix")); - builder.addForceInlineDecoration(func); - builder.setInsertInto(func); - builder.emitBlock(); - auto rowCount = getIntVal(matrixType->getRowCount()); - auto colCount = getIntVal(matrixType->getColumnCount()); - auto outParam = builder.emitParam(outStructType); - auto originalParam = builder.emitParam(matrixType); - List<IRInst*> elements; - elements.setCount((Index)(rowCount * colCount)); - for (IRIntegerValue r = 0; r < rowCount; r++) - { - auto vector = builder.emitElementExtract(originalParam, r); - for (IRIntegerValue c = 0; c < colCount; c++) - { - auto element = builder.emitElementExtract(vector, c); - elements[(Index)(r * colCount + c)] = element; - } - } - List<IRInst*> vectors; - if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) - { - for (IRIntegerValue c = 0; c < colCount; c++) - { - List<IRInst*> vecArgs; - for (IRIntegerValue r = 0; r < rowCount; r++) - { - auto element = elements[(Index)(r * colCount + c)]; - vecArgs.add(element); - } - // Fill in default values for remaining elements in the vector. - for (IRIntegerValue r = rowCount; r < getIntVal(vectorType->getElementCount()); r++) - { - vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType())); - } - auto colVector = builder.emitMakeVector( - vectorType, - (UInt)vecArgs.getCount(), - vecArgs.getBuffer()); - vectors.add(colVector); - } - } - else - { - for (IRIntegerValue r = 0; r < rowCount; r++) - { - List<IRInst*> vecArgs; - for (IRIntegerValue c = 0; c < colCount; c++) - { - auto element = elements[(Index)(r * colCount + c)]; - vecArgs.add(element); - } - // Fill in default values for remaining elements in the vector. - for (IRIntegerValue c = colCount; c < getIntVal(vectorType->getElementCount()); c++) - { - vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType())); - } - auto rowVector = builder.emitMakeVector( - vectorType, - (UInt)vecArgs.getCount(), - vecArgs.getBuffer()); - vectors.add(rowVector); - } - } - - auto vectorArray = - builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer()); - auto result = builder.emitMakeStruct(structType, 1, &vectorArray); - builder.emitStore(outParam, result); - builder.emitReturn(); - return func; + leafTypeLoweringPolicy = + getBufferElementTypeLoweringPolicy(options.loweringPolicyKind, target, options); } IRFunc* createArrayUnpackFunc( @@ -637,56 +568,6 @@ struct LoweredElementTypeContext return func; } - const char* getLayoutName(IRTypeLayoutRuleName name) - { - switch (name) - { - case IRTypeLayoutRuleName::Std140: - return "std140"; - case IRTypeLayoutRuleName::Std430: - return "std430"; - case IRTypeLayoutRuleName::Natural: - return "natural"; - case IRTypeLayoutRuleName::C: - return "c"; - default: - return "default"; - } - } - - // Returns the number of elements N that ensures the IRVectorType(elementType,N) - // has 16-byte aligned size and N is no less than `minCount`. - IRIntegerValue get16ByteAlignedVectorElementCount(IRType* elementType, IRIntegerValue minCount) - { - IRSizeAndAlignment sizeAlignment; - getNaturalSizeAndAlignment(target->getOptionSet(), elementType, &sizeAlignment); - if (sizeAlignment.size) - return align(sizeAlignment.size * minCount, 16) / sizeAlignment.size; - return 4; - } - - bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) - { - // For spirv, we always want to lower all matrix types, because SPIRV does not support - // specifying matrix layout/stride if the matrix type is used in places other than - // defining a struct field. This means that if a matrix is used to define a varying - // parameter, we always want to wrap it in a struct. - // - if (target->shouldEmitSPIRVDirectly()) - { - return true; - } - - if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && - config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural) - { - // For other targets, we only lower the matrix types if they differ from the default - // matrix layout. - return false; - } - return true; - } - LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, TypeLoweringConfig config) { IRBuilder builder(type); @@ -694,72 +575,13 @@ struct LoweredElementTypeContext LoweredElementTypeInfo info; info.originalType = type; - - if (auto matrixType = as<IRMatrixType>(type)) - { - if (!shouldLowerMatrixType(matrixType, config)) - { - info.loweredType = type; - return info; - } - - auto loweredType = builder.createStructType(); - builder.addPhysicalTypeDecoration(loweredType); - - StringBuilder nameSB; - bool isColMajor = - getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; - nameSB << "_MatrixStorage_"; - getTypeNameHint(nameSB, matrixType->getElementType()); - nameSB << getIntVal(matrixType->getRowCount()) << "x" - << getIntVal(matrixType->getColumnCount()); - if (isColMajor) - nameSB << "_ColMajor"; - nameSB << getLayoutName(config.layoutRule->ruleName); - builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); - auto structKey = builder.createStructKey(); - builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); - auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount(); - if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 && - options.use16ByteArrayElementForConstantBuffer) - { - // For constant buffer layout, we need to use 16-byte aligned vector if - // we are required to ensure array element types has 16-byte stride. - vectorSize = builder.getIntValue(get16ByteAlignedVectorElementCount( - matrixType->getElementType(), - getIntVal(vectorSize))); - } - - auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize); - IRSizeAndAlignment elementSizeAlignment; - getSizeAndAlignment( - target->getOptionSet(), - config.layoutRule, - vectorType, - &elementSizeAlignment); - elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment); - - auto arrayType = builder.getArrayType( - vectorType, - isColMajor ? matrixType->getColumnCount() : matrixType->getRowCount(), - builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); - builder.createStructField(loweredType, structKey, arrayType); - - info.loweredType = loweredType; - info.loweredInnerArrayType = arrayType; - info.loweredInnerStructKey = structKey; - info.convertLoweredToOriginal = - createMatrixUnpackFunc(matrixType, loweredType, structKey); - info.convertOriginalToLowered = - createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType); - return info; - } - else if (auto arrayTypeBase = as<IRArrayTypeBase>(type)) + if (auto arrayTypeBase = as<IRArrayTypeBase>(type)) { auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayTypeBase->getElementType(), config); - if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 && - options.use16ByteArrayElementForConstantBuffer) + if (config.layoutRuleName == IRTypeLayoutRuleName::Std140 && + leafTypeLoweringPolicy + ->shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer()) { // For constant buffer layout, we need to use 16-byte-aligned vector if // we are required to ensure array element types has 16-byte stride. @@ -772,6 +594,7 @@ struct LoweredElementTypeContext packedVectorType = builder.getVectorType( vectorType->getElementType(), builder.getIntValue(get16ByteAlignedVectorElementCount( + target, vectorType->getElementType(), getIntVal(vectorType->getElementCount())))); if (packedVectorType != loweredInnerTypeInfo.originalType) @@ -784,7 +607,7 @@ struct LoweredElementTypeContext { packedVectorType = builder.getVectorType( loweredInnerTypeInfo.loweredType, - get16ByteAlignedVectorElementCount(scalarType, 1)); + get16ByteAlignedVectorElementCount(target, scalarType, 1)); loweredInnerTypeInfo.convertLoweredToOriginal = kIROp_VectorReshape; loweredInnerTypeInfo.convertOriginalToLowered = kIROp_MakeVectorFromScalar; } @@ -803,10 +626,10 @@ struct LoweredElementTypeContext } } - // For spirv backend, we always want to lower all array types for non-varying - // parameters, even if the element type comes out the same. This is because different - // layout rules may have different array stride requirements. - if (!target->shouldEmitSPIRVDirectly() || config.addressSpace == AddressSpace::Input) + // We can skip lowering this type if all field types are unchanged, unless the target + // specific policy tells us to always create a lowered type. + if (!leafTypeLoweringPolicy->shouldAlwaysCreateLoweredStorageTypeForCompositeTypes( + config)) { if (!loweredInnerTypeInfo.convertLoweredToOriginal) { @@ -823,7 +646,7 @@ struct LoweredElementTypeContext info.loweredType = loweredType; StringBuilder nameSB; - nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_"; + nameSB << "_Array_" << getLayoutName(config.layoutRuleName) << "_"; getTypeNameHint(nameSB, arrayType->getElementType()); nameSB << getArraySizeVal(arrayType->getElementCount()); @@ -835,11 +658,11 @@ struct LoweredElementTypeContext IRSizeAndAlignment elementSizeAlignment; getSizeAndAlignment( target->getOptionSet(), - config.layoutRule, + config.getLayoutRule(), loweredInnerTypeInfo.loweredType, &elementSizeAlignment); elementSizeAlignment = - config.layoutRule->alignCompositeElement(elementSizeAlignment); + config.getLayoutRule()->alignCompositeElement(elementSizeAlignment); auto innerArrayType = builder.getArrayType( loweredInnerTypeInfo.loweredType, arrayType->getElementCount(), @@ -857,11 +680,11 @@ struct LoweredElementTypeContext IRSizeAndAlignment elementSizeAlignment; getSizeAndAlignment( target->getOptionSet(), - config.layoutRule, + config.getLayoutRule(), loweredInnerTypeInfo.loweredType, &elementSizeAlignment); elementSizeAlignment = - config.layoutRule->alignCompositeElement(elementSizeAlignment); + config.getLayoutRule()->alignCompositeElement(elementSizeAlignment); auto innerArrayType = builder.getArrayTypeBase( arrayTypeBase->getOp(), loweredInnerTypeInfo.loweredType, @@ -880,20 +703,15 @@ struct LoweredElementTypeContext auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), config); fieldLoweredTypeInfo.add(loweredFieldTypeInfo); if (loweredFieldTypeInfo.convertLoweredToOriginal || - config.layoutRule->ruleName != IRTypeLayoutRuleName::Natural) + config.layoutRuleName != IRTypeLayoutRuleName::Natural) isTrivial = false; } - // For spirv backend, we always want to lower all array types, even if the element type - // comes out the same. This is because different layout rules may have different array - // stride requirements. - // - // Additionally, `buffer` blocks do not work correctly unless lowered when targeting - // GLSL. - if (!isKhronosTarget(target->getTargetReq())) + // We can skip lowering this type if all field types are unchanged, unless the target + // specific policy tells us to always create a lowered type. + if (!leafTypeLoweringPolicy->shouldAlwaysCreateLoweredStorageTypeForCompositeTypes( + config)) { - // For non-spirv target, we skip lowering this type if all field types are - // unchanged. if (isTrivial) { info.loweredType = type; @@ -905,7 +723,7 @@ struct LoweredElementTypeContext StringBuilder nameSB; getTypeNameHint(nameSB, type); - nameSB << "_" << getLayoutName(config.layoutRule->ruleName); + nameSB << "_" << getLayoutName(config.layoutRuleName); builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); info.loweredType = loweredType; // Create fields. @@ -1007,58 +825,7 @@ struct LoweredElementTypeContext return info; } - - if (target->shouldEmitSPIRVDirectly()) - { - switch (target->getTargetReq()->getTarget()) - { - case CodeGenTarget::SPIRV: - case CodeGenTarget::SPIRVAssembly: - { - auto scalarType = type; - auto vectorType = as<IRVectorType>(scalarType); - if (vectorType) - scalarType = vectorType->getElementType(); - - if (as<IRBoolType>(scalarType)) - { - // Bool is an abstract type in SPIRV, so we need to lower them into an int. - - // Find an integer type of the correct size for the current layout rule. - IRSizeAndAlignment boolSizeAndAlignment; - if (getSizeAndAlignment( - target->getOptionSet(), - config.layoutRule, - scalarType, - &boolSizeAndAlignment) == SLANG_OK) - { - IntInfo ii; - ii.width = boolSizeAndAlignment.size * 8; - ii.isSigned = true; - info.loweredType = builder.getType(getIntTypeOpFromInfo(ii)); - } - else - { - // Just in case that fails for some reason, just use an int. - info.loweredType = builder.getIntType(); - } - - if (vectorType) - info.loweredType = builder.getVectorType( - info.loweredType, - vectorType->getElementCount()); - info.convertLoweredToOriginal = kIROp_BuiltinCast; - info.convertOriginalToLowered = kIROp_BuiltinCast; - return info; - } - } - default: - break; - } - } - - info.loweredType = type; - return info; + return leafTypeLoweringPolicy->lowerLeafLogicalType(type, config); } LoweredTypeMap& getTypeLoweringMap(TypeLoweringConfig config) @@ -1090,7 +857,7 @@ struct LoweredElementTypeContext IRSizeAndAlignment sizeAlignment; getSizeAndAlignment( target->getOptionSet(), - config.layoutRule, + config.getLayoutRule(), info.loweredType, &sizeAlignment); loweredTypeInfo.set(type, info); @@ -1190,13 +957,13 @@ struct LoweredElementTypeContext IRSizeAndAlignment arrayElementSizeAlignment; getSizeAndAlignment( target->getOptionSet(), - config.layoutRule, + config.getLayoutRule(), loweredInnerType.loweredType, &arrayElementSizeAlignment); IRSizeAndAlignment baseSizeAlignment; getSizeAndAlignment( target->getOptionSet(), - config.layoutRule, + config.getLayoutRule(), tryGetPointedToOrBufferElementType(&builder, fieldAddr->getBase()->getDataType()), &baseSizeAlignment); @@ -1701,9 +1468,6 @@ struct LoweredElementTypeContext switch (ptrType->getAddressSpace()) { case AddressSpace::UserPointer: - if (!options.lowerBufferPointer) - continue; - [[fallthrough]]; case AddressSpace::Input: case AddressSpace::Output: elementType = ptrType->getValueType(); @@ -1720,7 +1484,7 @@ struct LoweredElementTypeContext IRSizeAndAlignment sizeAlignment; getSizeAndAlignment( target->getOptionSet(), - config.layoutRule, + config.getLayoutRule(), elementType, &sizeAlignment); SLANG_UNUSED(sizeAlignment); @@ -2181,63 +1945,60 @@ void lowerBufferElementTypeToStorageType( IRModule* module, BufferElementTypeLoweringOptions options) { - SlangMatrixLayoutMode defaultMatrixMode = - (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode(); - if ((isCPUTarget(target->getTargetReq()) || isCUDATarget(target->getTargetReq()) || - isMetalTarget(target->getTargetReq()))) - defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; - else if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) - defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; - LoweredElementTypeContext context(target, options, defaultMatrixMode); + LoweredElementTypeContext context(target, options); context.processModule(module); } -IRTypeLayoutRules* getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRules* defaultLayout) +IRTypeLayoutRuleName getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRuleName defaultLayout) { switch (layoutTypeOp) { case kIROp_DefaultBufferLayoutType: return defaultLayout; case kIROp_Std140BufferLayoutType: - return IRTypeLayoutRules::getStd140(); + return IRTypeLayoutRuleName::Std140; case kIROp_Std430BufferLayoutType: - return IRTypeLayoutRules::getStd430(); + return IRTypeLayoutRuleName::Std430; case kIROp_ScalarBufferLayoutType: - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; case kIROp_CBufferLayoutType: - return IRTypeLayoutRules::getC(); + return IRTypeLayoutRuleName::C; } return defaultLayout; } -IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType) +IRTypeLayoutRuleName getTypeLayoutRuleNameForBuffer(TargetProgram* target, IRType* bufferType) { + if (bufferType->getOp() == kIROp_ParameterBlockType && isMetalTarget(target->getTargetReq())) + { + return IRTypeLayoutRuleName::MetalParameterBlock; + } if (target->getTargetReq()->getTarget() != CodeGenTarget::WGSL) { if (!isKhronosTarget(target->getTargetReq())) - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; // If we are just emitting GLSL, we can just use the general layout rule. if (!target->shouldEmitSPIRVDirectly()) - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; // If the user specified a C-compatible buffer layout, then do that. if (target->getOptionSet().shouldUseCLayout()) - return IRTypeLayoutRules::getC(); + return IRTypeLayoutRuleName::C; // If the user specified a scalar buffer layout, then just use that. if (target->getOptionSet().shouldUseScalarLayout()) - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; } if (target->getOptionSet().shouldUseDXLayout()) { if (as<IRUniformParameterGroupType>(bufferType)) { - return IRTypeLayoutRules::getConstantBuffer(); + return IRTypeLayoutRuleName::D3DConstantBuffer; } else - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; } // The default behavior is to use std140 for constant buffers and std430 for other buffers. @@ -2253,17 +2014,17 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf auto layoutTypeOp = structBufferType->getDataLayout() ? structBufferType->getDataLayout()->getOp() : kIROp_DefaultBufferLayoutType; - return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430()); + return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std430); } - case kIROp_ConstantBufferType: case kIROp_ParameterBlockType: + case kIROp_ConstantBufferType: { auto parameterGroupType = as<IRUniformParameterGroupType>(bufferType); auto layoutTypeOp = parameterGroupType->getDataLayout() ? parameterGroupType->getDataLayout()->getOp() : kIROp_DefaultBufferLayoutType; - return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd140()); + return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std140); } case kIROp_GLSLShaderStorageBufferType: { @@ -2271,12 +2032,18 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf auto layoutTypeOp = storageBufferType->getDataLayout() ? storageBufferType->getDataLayout()->getOp() : kIROp_Std430BufferLayoutType; - return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430()); + return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRuleName::Std430); } case kIROp_PtrType: - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; } - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRuleName::Natural; +} + +IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType) +{ + auto ruleName = getTypeLayoutRuleNameForBuffer(target, bufferType); + return IRTypeLayoutRules::get(ruleName); } TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType) @@ -2295,8 +2062,413 @@ TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* break; } } - auto rules = getTypeLayoutRuleForBuffer(target, bufferType); + auto rules = getTypeLayoutRuleNameForBuffer(target, bufferType); return TypeLoweringConfig{addrSpace, rules}; } +struct DefaultBufferElementTypeLoweringPolicy : BufferElementTypeLoweringPolicy +{ + TargetProgram* target; + BufferElementTypeLoweringOptions options; + SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + + DefaultBufferElementTypeLoweringPolicy( + TargetProgram* inTarget, + BufferElementTypeLoweringOptions inOptions) + : target(inTarget), options(inOptions) + { + defaultMatrixLayout = (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode(); + if ((isCPUTarget(target->getTargetReq()) || isCUDATarget(target->getTargetReq()) || + isMetalTarget(target->getTargetReq()))) + defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + else if (defaultMatrixLayout == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) + defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + } + + virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) + { + if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && + config.getLayoutRule()->ruleName == IRTypeLayoutRuleName::Natural) + { + // We only lower the matrix types if they differ from the default + // matrix layout. + return false; + } + return true; + } + + IRFunc* createMatrixUnpackFunc( + IRMatrixType* matrixType, + IRStructType* structType, + IRStructKey* dataKey) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto refStructType = builder.getRefParamType(structType, AddressSpace::Generic); + auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.addForceInlineDecoration(func); + builder.setInsertInto(func); + builder.emitBlock(); + auto rowCount = (Index)getIntVal(matrixType->getRowCount()); + auto colCount = (Index)getIntVal(matrixType->getColumnCount()); + auto packedParamRef = builder.emitParam(refStructType); + auto packedParam = builder.emitLoad(packedParamRef); + auto vectorArray = builder.emitFieldExtract(packedParam, dataKey); + List<IRInst*> args; + args.setCount(rowCount * colCount); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto vector = builder.emitElementExtract(vectorArray, c); + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto element = builder.emitElementExtract(vector, r); + args[(Index)(r * colCount + c)] = element; + } + } + } + else + { + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto vector = builder.emitElementExtract(vectorArray, r); + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = builder.emitElementExtract(vector, c); + args[(Index)(r * colCount + c)] = element; + } + } + } + IRInst* result = + builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer()); + builder.emitReturn(result); + return func; + } + + IRFunc* createMatrixPackFunc( + IRMatrixType* matrixType, + IRStructType* structType, + IRVectorType* vectorType, + IRArrayType* arrayType) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto outStructType = builder.getRefParamType(structType, AddressSpace::Generic); + IRType* paramTypes[] = {outStructType, matrixType}; + auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType()); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix")); + builder.addForceInlineDecoration(func); + builder.setInsertInto(func); + builder.emitBlock(); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); + auto outParam = builder.emitParam(outStructType); + auto originalParam = builder.emitParam(matrixType); + List<IRInst*> elements; + elements.setCount((Index)(rowCount * colCount)); + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto vector = builder.emitElementExtract(originalParam, r); + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = builder.emitElementExtract(vector, c); + elements[(Index)(r * colCount + c)] = element; + } + } + List<IRInst*> vectors; + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue c = 0; c < colCount; c++) + { + List<IRInst*> vecArgs; + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto element = elements[(Index)(r * colCount + c)]; + vecArgs.add(element); + } + // Fill in default values for remaining elements in the vector. + for (IRIntegerValue r = rowCount; r < getIntVal(vectorType->getElementCount()); r++) + { + vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType())); + } + auto colVector = builder.emitMakeVector( + vectorType, + (UInt)vecArgs.getCount(), + vecArgs.getBuffer()); + vectors.add(colVector); + } + } + else + { + for (IRIntegerValue r = 0; r < rowCount; r++) + { + List<IRInst*> vecArgs; + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = elements[(Index)(r * colCount + c)]; + vecArgs.add(element); + } + // Fill in default values for remaining elements in the vector. + for (IRIntegerValue c = colCount; c < getIntVal(vectorType->getElementCount()); c++) + { + vecArgs.add(builder.emitDefaultConstruct(vectorType->getElementType())); + } + auto rowVector = builder.emitMakeVector( + vectorType, + (UInt)vecArgs.getCount(), + vecArgs.getBuffer()); + vectors.add(rowVector); + } + } + + auto vectorArray = + builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer()); + auto result = builder.emitMakeStruct(structType, 1, &vectorArray); + builder.emitStore(outParam, result); + builder.emitReturn(); + return func; + } + + LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override + { + IRBuilder builder(type); + builder.setInsertAfter(type); + + LoweredElementTypeInfo info; + info.originalType = type; + + if (auto matrixType = as<IRMatrixType>(type)) + { + if (!shouldLowerMatrixType(matrixType, config)) + { + info.loweredType = type; + return info; + } + + auto loweredType = builder.createStructType(); + builder.addPhysicalTypeDecoration(loweredType); + + StringBuilder nameSB; + bool isColMajor = + getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; + nameSB << "_MatrixStorage_"; + getTypeNameHint(nameSB, matrixType->getElementType()); + nameSB << getIntVal(matrixType->getRowCount()) << "x" + << getIntVal(matrixType->getColumnCount()); + if (isColMajor) + nameSB << "_ColMajor"; + nameSB << getLayoutName(config.layoutRuleName); + builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + auto structKey = builder.createStructKey(); + builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); + auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount(); + if (config.layoutRuleName == IRTypeLayoutRuleName::Std140 && + shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer()) + { + // For constant buffer layout, we need to use 16-byte aligned vector if + // we are required to ensure array element types has 16-byte stride. + vectorSize = builder.getIntValue(get16ByteAlignedVectorElementCount( + target, + matrixType->getElementType(), + getIntVal(vectorSize))); + } + + auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize); + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + config.getLayoutRule(), + vectorType, + &elementSizeAlignment); + elementSizeAlignment = + config.getLayoutRule()->alignCompositeElement(elementSizeAlignment); + + auto arrayType = builder.getArrayType( + vectorType, + isColMajor ? matrixType->getColumnCount() : matrixType->getRowCount(), + builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); + builder.createStructField(loweredType, structKey, arrayType); + + info.loweredType = loweredType; + info.loweredInnerArrayType = arrayType; + info.loweredInnerStructKey = structKey; + info.convertLoweredToOriginal = + createMatrixUnpackFunc(matrixType, loweredType, structKey); + info.convertOriginalToLowered = + createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType); + return info; + } + + info.loweredType = type; + return info; + } +}; + +struct KhronosTargetBufferElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy +{ + KhronosTargetBufferElementTypeLoweringPolicy( + TargetProgram* inTarget, + BufferElementTypeLoweringOptions inOptions) + : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions) + { + } + + virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) override + { + // For spirv, we always want to lower all matrix types, because SPIRV does not support + // specifying matrix layout/stride if the matrix type is used in places other than + // defining a struct field. This means that if a matrix is used to define a varying + // parameter, we always want to wrap it in a struct. + // + if (target->shouldEmitSPIRVDirectly()) + return true; + return DefaultBufferElementTypeLoweringPolicy::shouldLowerMatrixType(matrixType, config); + } + + virtual bool shouldAlwaysCreateLoweredStorageTypeForCompositeTypes( + TypeLoweringConfig config) override + { + // For spirv backend, we always want to lower all array types, even if the element type + // comes out the same. This is because different layout rules may have different array + // stride requirements. + // + // Additionally, `buffer` blocks do not work correctly unless lowered when targeting + // GLSL. + return target->shouldEmitSPIRVDirectly() && config.addressSpace != AddressSpace::Input; + } + + LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override + { + if (target->shouldEmitSPIRVDirectly()) + { + LoweredElementTypeInfo info = {}; + info.originalType = type; + + switch (target->getTargetReq()->getTarget()) + { + case CodeGenTarget::SPIRV: + case CodeGenTarget::SPIRVAssembly: + { + auto scalarType = type; + auto vectorType = as<IRVectorType>(scalarType); + if (vectorType) + scalarType = vectorType->getElementType(); + IRBuilder builder(type); + builder.setInsertBefore(type); + + if (as<IRBoolType>(scalarType)) + { + // Bool is an abstract type in SPIRV, so we need to lower them into an int. + + // Find an integer type of the correct size for the current layout rule. + IRSizeAndAlignment boolSizeAndAlignment; + if (getSizeAndAlignment( + target->getOptionSet(), + config.getLayoutRule(), + scalarType, + &boolSizeAndAlignment) == SLANG_OK) + { + IntInfo ii; + ii.width = boolSizeAndAlignment.size * 8; + ii.isSigned = true; + info.loweredType = builder.getType(getIntTypeOpFromInfo(ii)); + } + else + { + // Just in case that fails for some reason, just use an int. + info.loweredType = builder.getIntType(); + } + + if (vectorType) + info.loweredType = builder.getVectorType( + info.loweredType, + vectorType->getElementCount()); + info.convertLoweredToOriginal = kIROp_BuiltinCast; + info.convertOriginalToLowered = kIROp_BuiltinCast; + return info; + } + } + break; + default: + break; + } + } + return DefaultBufferElementTypeLoweringPolicy::lowerLeafLogicalType(type, config); + } +}; + +struct MetalParameterBlockElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy +{ + MetalParameterBlockElementTypeLoweringPolicy( + TargetProgram* inTarget, + BufferElementTypeLoweringOptions inOptions) + : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions) + { + } + + virtual bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) override + { + SLANG_UNUSED(matrixType); + SLANG_UNUSED(config); + return false; + } + + LoweredElementTypeInfo lowerLeafLogicalType(IRType* type, TypeLoweringConfig config) override + { + if (config.layoutRuleName == IRTypeLayoutRuleName::MetalParameterBlock && + isResourceType(type)) + { + IRBuilder builder(type); + builder.setInsertBefore(type); + LoweredElementTypeInfo info = {}; + info.originalType = type; + info.loweredType = builder.getType(kIROp_DescriptorHandleType, type); + info.convertLoweredToOriginal = kIROp_CastDescriptorHandleToResource; + info.convertOriginalToLowered = kIROp_CastResourceToDescriptorHandle; + return info; + } + return DefaultBufferElementTypeLoweringPolicy::lowerLeafLogicalType(type, config); + } +}; + +struct WGSLBufferElementTypeLoweringPolicy : DefaultBufferElementTypeLoweringPolicy +{ + WGSLBufferElementTypeLoweringPolicy( + TargetProgram* inTarget, + BufferElementTypeLoweringOptions inOptions) + : DefaultBufferElementTypeLoweringPolicy(inTarget, inOptions) + { + } + + virtual bool shouldTranslateArrayElementTo16ByteAlignedVectorForConstantBuffer() override + { + return true; + } +}; + +BufferElementTypeLoweringPolicy* getBufferElementTypeLoweringPolicy( + BufferElementTypeLoweringPolicyKind kind, + TargetProgram* target, + BufferElementTypeLoweringOptions options) +{ + switch (kind) + { + case BufferElementTypeLoweringPolicyKind::Default: + return new DefaultBufferElementTypeLoweringPolicy(target, options); + case BufferElementTypeLoweringPolicyKind::KhronosTarget: + return new KhronosTargetBufferElementTypeLoweringPolicy(target, options); + case BufferElementTypeLoweringPolicyKind::MetalParameterBlock: + return new MetalParameterBlockElementTypeLoweringPolicy(target, options); + case BufferElementTypeLoweringPolicyKind::WGSL: + return new WGSLBufferElementTypeLoweringPolicy(target, options); + } + SLANG_UNREACHABLE("unknown buffer element type lowering policy"); +} + } // namespace Slang |
