diff options
| author | Yong He <yonghe@outlook.com> | 2024-12-18 11:33:55 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-18 11:33:55 -0800 |
| commit | ae04e604d43d169bcba7f24c8c23a0fdf4cbb483 (patch) | |
| tree | 899c872ec5cc5c6ccc27930ef6971a0baf018569 /source/slang/slang-ir-lower-buffer-element-type.cpp | |
| parent | 41c627fd420a644f0ae86e36f4752e820e2d683c (diff) | |
Allow `Optional`, `Tuple` and `bool` to be used in varying input/output. (#5889)
* Allow `Optional` and `Tuple` to be used in varying input/output.
* Fix.
* format code
* Fix.
* Fix test.
* Fix.
* enhance test.
* Fix.
* format code
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.cpp | 150 |
1 files changed, 106 insertions, 44 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index dd62ca02c..7f67c9254 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -8,6 +8,22 @@ namespace Slang { + +struct TypeLoweringConfig +{ + AddressSpace addressSpace; + IRTypeLayoutRules* layoutRule; + bool operator==(const TypeLoweringConfig& other) const + { + return addressSpace == other.addressSpace && layoutRule == other.layoutRule; + } + HashCode getHashCode() const + { + return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRule)); + } +}; +TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType); + struct LoweredElementTypeContext { static const IRIntegerValue kMaxArraySizeToUnroll = 32; @@ -67,9 +83,13 @@ struct LoweredElementTypeContext ConversionMethod convertLoweredToOriginal; }; - Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo[(int)IRTypeLayoutRuleName::_Count]; - Dictionary<IRType*, LoweredElementTypeInfo> - mapLoweredTypeToInfo[(int)IRTypeLayoutRuleName::_Count]; + struct LoweredTypeMap : RefObject + { + Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo; + Dictionary<IRType*, LoweredElementTypeInfo> mapLoweredTypeToInfo; + }; + + Dictionary<TypeLoweringConfig, RefPtr<LoweredTypeMap>> loweredTypeInfoMaps; struct ConversionMethodKey { @@ -392,7 +412,7 @@ struct LoweredElementTypeContext return 4; } - LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, IRTypeLayoutRules* rules) + LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, TypeLoweringConfig config) { IRBuilder builder(type); builder.setInsertAfter(type); @@ -409,7 +429,7 @@ struct LoweredElementTypeContext // For other targets, we only lower the matrix types if they differ from the default // matrix layout. if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && - rules->ruleName == IRTypeLayoutRuleName::Natural) + config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural) { info.loweredType = type; return info; @@ -426,12 +446,12 @@ struct LoweredElementTypeContext << getIntVal(matrixType->getColumnCount()); if (isColMajor) nameSB << "_ColMajor"; - nameSB << getLayoutName(rules->ruleName); + nameSB << getLayoutName(config.layoutRule->ruleName); builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); auto structKey = builder.createStructKey(); builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount(); - if (rules->ruleName == IRTypeLayoutRuleName::Std140 && + if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 && options.use16ByteArrayElementForConstantBuffer) { // For constant buffer layout, we need to use 16-byte aligned vector if @@ -443,8 +463,12 @@ struct LoweredElementTypeContext auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize); IRSizeAndAlignment elementSizeAlignment; - getSizeAndAlignment(target->getOptionSet(), rules, vectorType, &elementSizeAlignment); - elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment); + getSizeAndAlignment( + target->getOptionSet(), + config.layoutRule, + vectorType, + &elementSizeAlignment); + elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment); auto arrayType = builder.getArrayType( vectorType, @@ -463,9 +487,9 @@ struct LoweredElementTypeContext } else if (auto arrayType = as<IRArrayType>(type)) { - auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), rules); + auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), config); - if (rules->ruleName == IRTypeLayoutRuleName::Std140 && + if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 && options.use16ByteArrayElementForConstantBuffer) { // For constant buffer layout, we need to use 16-byte-aligned vector if @@ -510,10 +534,10 @@ struct LoweredElementTypeContext } } - // For spirv backend, we always want to lower all array types, even if the element type - // comes out the same. This is because different layout rules may have different array - // stride requirements. - if (!target->shouldEmitSPIRVDirectly()) + // For spirv backend, we always want to lower all array types for non-varying + // parameters, even if the element type comes out the same. This is because different + // layout rules may have different array stride requirements. + if (!target->shouldEmitSPIRVDirectly() || config.addressSpace == AddressSpace::Input) { if (!loweredInnerTypeInfo.convertLoweredToOriginal) { @@ -525,7 +549,7 @@ struct LoweredElementTypeContext auto loweredType = builder.createStructType(); info.loweredType = loweredType; StringBuilder nameSB; - nameSB << "_Array_" << getLayoutName(rules->ruleName) << "_"; + nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_"; getTypeNameHint(nameSB, arrayType->getElementType()); nameSB << getIntVal(arrayType->getElementCount()); builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); @@ -534,10 +558,10 @@ struct LoweredElementTypeContext IRSizeAndAlignment elementSizeAlignment; getSizeAndAlignment( target->getOptionSet(), - rules, + config.layoutRule, loweredInnerTypeInfo.loweredType, &elementSizeAlignment); - elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment); + elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment); auto innerArrayType = builder.getArrayType( loweredInnerTypeInfo.loweredType, arrayType->getElementCount(), @@ -566,10 +590,10 @@ struct LoweredElementTypeContext bool isTrivial = true; for (auto field : structType->getFields()) { - auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), rules); + auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), config); fieldLoweredTypeInfo.add(loweredFieldTypeInfo); if (loweredFieldTypeInfo.convertLoweredToOriginal || - rules->ruleName != IRTypeLayoutRuleName::Natural) + config.layoutRule->ruleName != IRTypeLayoutRuleName::Natural) isTrivial = false; } @@ -589,7 +613,7 @@ struct LoweredElementTypeContext auto loweredType = builder.createStructType(); StringBuilder nameSB; getTypeNameHint(nameSB, type); - nameSB << "_" << getLayoutName(rules->ruleName); + nameSB << "_" << getLayoutName(config.layoutRule->ruleName); builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); info.loweredType = loweredType; // Create fields. @@ -748,24 +772,40 @@ struct LoweredElementTypeContext return info; } - LoweredElementTypeInfo getLoweredTypeInfo(IRType* type, IRTypeLayoutRules* rules) + LoweredTypeMap& getTypeLoweringMap(TypeLoweringConfig config) + { + RefPtr<LoweredTypeMap> map; + if (loweredTypeInfoMaps.tryGetValue(config, map)) + return *map; + map = new LoweredTypeMap(); + loweredTypeInfoMaps.add(config, map); + return *map; + } + + LoweredElementTypeInfo getLoweredTypeInfo(IRType* type, TypeLoweringConfig config) { // If `type` is already a lowered type, no more lowering is required. LoweredElementTypeInfo info; - if (mapLoweredTypeToInfo->tryGetValue(type)) + auto& map = getTypeLoweringMap(config); + auto& mapLoweredTypeToInfo = map.mapLoweredTypeToInfo; + auto& loweredTypeInfo = map.loweredTypeInfo; + if (mapLoweredTypeToInfo.tryGetValue(type)) { info.originalType = type; info.loweredType = type; return info; } - - if (loweredTypeInfo[(int)rules->ruleName].tryGetValue(type, info)) + if (loweredTypeInfo.tryGetValue(type, info)) return info; - info = getLoweredTypeInfoImpl(type, rules); + info = getLoweredTypeInfoImpl(type, config); IRSizeAndAlignment sizeAlignment; - getSizeAndAlignment(target->getOptionSet(), rules, info.loweredType, &sizeAlignment); - loweredTypeInfo[(int)rules->ruleName].set(type, info); - mapLoweredTypeToInfo[(int)rules->ruleName].set(info.loweredType, info); + getSizeAndAlignment( + target->getOptionSet(), + config.layoutRule, + info.loweredType, + &sizeAlignment); + loweredTypeInfo.set(type, info); + mapLoweredTypeToInfo.set(info.loweredType, info); conversionMethodMap[{info.originalType, info.loweredType}] = info.convertLoweredToOriginal; conversionMethodMap[{info.loweredType, info.originalType}] = info.convertOriginalToLowered; return info; @@ -802,7 +842,7 @@ struct LoweredElementTypeContext struct MatrixAddrWorkItem { IRInst* matrixAddrInst; - IRTypeLayoutRules* layoutRules; + TypeLoweringConfig config; }; void processModule(IRModule* module) @@ -812,17 +852,25 @@ struct LoweredElementTypeContext { IRType* bufferType; IRType* elementType; + bool shouldWrapArrayInStruct = false; }; List<BufferTypeInfo> bufferTypeInsts; for (auto globalInst : module->getGlobalInsts()) { IRType* elementType = nullptr; + if (options.lowerBufferPointer) { - if (auto ptrType = as<IRPtrType>(globalInst)) + if (auto ptrType = as<IRPtrTypeBase>(globalInst)) { - if (ptrType->getAddressSpace() == AddressSpace::UserPointer) + switch (ptrType->getAddressSpace()) + { + case AddressSpace::UserPointer: + case AddressSpace::Input: + case AddressSpace::Output: elementType = ptrType->getValueType(); + break; + } } } else @@ -849,8 +897,8 @@ struct LoweredElementTypeContext { auto bufferType = bufferTypeInfo.bufferType; auto elementType = bufferTypeInfo.elementType; - auto layoutRules = getTypeLayoutRuleForBuffer(target, bufferType); - auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, layoutRules); + auto config = getTypeLoweringConfigForBuffer(target, bufferType); + auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, config); // If the lowered type is the same as original type, no change is required. if (loweredBufferElementTypeInfo.loweredType == @@ -902,18 +950,18 @@ struct LoweredElementTypeContext builder.setInsertBefore(ptrVal); auto newArrayPtrVal = fieldAddr->getBase(); auto loweredInnerType = - getLoweredTypeInfo(unsizedArrayType->getElementType(), layoutRules); + getLoweredTypeInfo(unsizedArrayType->getElementType(), config); IRSizeAndAlignment arrayElementSizeAlignment; getSizeAndAlignment( target->getOptionSet(), - layoutRules, + config.layoutRule, loweredInnerType.loweredType, &arrayElementSizeAlignment); IRSizeAndAlignment baseSizeAlignment; getSizeAndAlignment( target->getOptionSet(), - layoutRules, + config.layoutRule, tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()), &baseSizeAlignment); @@ -998,7 +1046,7 @@ struct LoweredElementTypeContext if (!loweredElementTypeInfo.loweredType) { loweredElementTypeInfo = - getLoweredTypeInfo((IRType*)originalElementType, layoutRules); + getLoweredTypeInfo((IRType*)originalElementType, config); } if (!loweredElementTypeInfo.convertLoweredToOriginal) @@ -1087,7 +1135,7 @@ struct LoweredElementTypeContext // We are tring to get a pointer to a lowered matrix element. // We process this insts at a later phase. SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr); - matrixAddrInsts.add(MatrixAddrWorkItem{user, layoutRules}); + matrixAddrInsts.add(MatrixAddrWorkItem{user, config}); } else { @@ -1160,15 +1208,12 @@ struct LoweredElementTypeContext for (auto workItem : matrixAddrInsts) { auto majorAddr = workItem.matrixAddrInst; - auto layoutRules = workItem.layoutRules; - - int layoutRuleName = (int)layoutRules->ruleName; auto majorGEP = as<IRGetElementPtr>(majorAddr); SLANG_ASSERT(majorGEP); auto loweredMatrixType = cast<IRPtrTypeBase>(majorGEP->getBase()->getFullType())->getValueType(); - auto matrixTypeInfo = - mapLoweredTypeToInfo[layoutRuleName].tryGetValue(loweredMatrixType); + auto matrixTypeInfo = getTypeLoweringMap(workItem.config) + .mapLoweredTypeToInfo.tryGetValue(loweredMatrixType); SLANG_ASSERT(matrixTypeInfo); auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType); auto rowCount = getIntVal(matrixType->getRowCount()); @@ -1379,4 +1424,21 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf return IRTypeLayoutRules::getNatural(); } +TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType) +{ + AddressSpace addrSpace = AddressSpace::Generic; + if (auto ptrType = as<IRPtrTypeBase>(bufferType)) + { + switch (ptrType->getAddressSpace()) + { + case AddressSpace::Input: + case AddressSpace::Output: + addrSpace = AddressSpace::Input; + break; + } + } + auto rules = getTypeLayoutRuleForBuffer(target, bufferType); + return TypeLoweringConfig{addrSpace, rules}; +} + } // namespace Slang |
