summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-lower-buffer-element-type.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-12-18 11:33:55 -0800
committerGitHub <noreply@github.com>2024-12-18 11:33:55 -0800
commitae04e604d43d169bcba7f24c8c23a0fdf4cbb483 (patch)
tree899c872ec5cc5c6ccc27930ef6971a0baf018569 /source/slang/slang-ir-lower-buffer-element-type.cpp
parent41c627fd420a644f0ae86e36f4752e820e2d683c (diff)
Allow `Optional`, `Tuple` and `bool` to be used in varying input/output. (#5889)
* Allow `Optional` and `Tuple` to be used in varying input/output. * Fix. * format code * Fix. * Fix test. * Fix. * enhance test. * Fix. * format code --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp150
1 files changed, 106 insertions, 44 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index dd62ca02c..7f67c9254 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -8,6 +8,22 @@
namespace Slang
{
+
+struct TypeLoweringConfig
+{
+ AddressSpace addressSpace;
+ IRTypeLayoutRules* layoutRule;
+ bool operator==(const TypeLoweringConfig& other) const
+ {
+ return addressSpace == other.addressSpace && layoutRule == other.layoutRule;
+ }
+ HashCode getHashCode() const
+ {
+ return combineHash(Slang::getHashCode(addressSpace), Slang::getHashCode(layoutRule));
+ }
+};
+TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType);
+
struct LoweredElementTypeContext
{
static const IRIntegerValue kMaxArraySizeToUnroll = 32;
@@ -67,9 +83,13 @@ struct LoweredElementTypeContext
ConversionMethod convertLoweredToOriginal;
};
- Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo[(int)IRTypeLayoutRuleName::_Count];
- Dictionary<IRType*, LoweredElementTypeInfo>
- mapLoweredTypeToInfo[(int)IRTypeLayoutRuleName::_Count];
+ struct LoweredTypeMap : RefObject
+ {
+ Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo;
+ Dictionary<IRType*, LoweredElementTypeInfo> mapLoweredTypeToInfo;
+ };
+
+ Dictionary<TypeLoweringConfig, RefPtr<LoweredTypeMap>> loweredTypeInfoMaps;
struct ConversionMethodKey
{
@@ -392,7 +412,7 @@ struct LoweredElementTypeContext
return 4;
}
- LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, IRTypeLayoutRules* rules)
+ LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, TypeLoweringConfig config)
{
IRBuilder builder(type);
builder.setInsertAfter(type);
@@ -409,7 +429,7 @@ struct LoweredElementTypeContext
// For other targets, we only lower the matrix types if they differ from the default
// matrix layout.
if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout &&
- rules->ruleName == IRTypeLayoutRuleName::Natural)
+ config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural)
{
info.loweredType = type;
return info;
@@ -426,12 +446,12 @@ struct LoweredElementTypeContext
<< getIntVal(matrixType->getColumnCount());
if (isColMajor)
nameSB << "_ColMajor";
- nameSB << getLayoutName(rules->ruleName);
+ nameSB << getLayoutName(config.layoutRule->ruleName);
builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
auto structKey = builder.createStructKey();
builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
auto vectorSize = isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount();
- if (rules->ruleName == IRTypeLayoutRuleName::Std140 &&
+ if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 &&
options.use16ByteArrayElementForConstantBuffer)
{
// For constant buffer layout, we need to use 16-byte aligned vector if
@@ -443,8 +463,12 @@ struct LoweredElementTypeContext
auto vectorType = builder.getVectorType(matrixType->getElementType(), vectorSize);
IRSizeAndAlignment elementSizeAlignment;
- getSizeAndAlignment(target->getOptionSet(), rules, vectorType, &elementSizeAlignment);
- elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment);
+ getSizeAndAlignment(
+ target->getOptionSet(),
+ config.layoutRule,
+ vectorType,
+ &elementSizeAlignment);
+ elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment);
auto arrayType = builder.getArrayType(
vectorType,
@@ -463,9 +487,9 @@ struct LoweredElementTypeContext
}
else if (auto arrayType = as<IRArrayType>(type))
{
- auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), rules);
+ auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), config);
- if (rules->ruleName == IRTypeLayoutRuleName::Std140 &&
+ if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 &&
options.use16ByteArrayElementForConstantBuffer)
{
// For constant buffer layout, we need to use 16-byte-aligned vector if
@@ -510,10 +534,10 @@ struct LoweredElementTypeContext
}
}
- // For spirv backend, we always want to lower all array types, even if the element type
- // comes out the same. This is because different layout rules may have different array
- // stride requirements.
- if (!target->shouldEmitSPIRVDirectly())
+ // For spirv backend, we always want to lower all array types for non-varying
+ // parameters, even if the element type comes out the same. This is because different
+ // layout rules may have different array stride requirements.
+ if (!target->shouldEmitSPIRVDirectly() || config.addressSpace == AddressSpace::Input)
{
if (!loweredInnerTypeInfo.convertLoweredToOriginal)
{
@@ -525,7 +549,7 @@ struct LoweredElementTypeContext
auto loweredType = builder.createStructType();
info.loweredType = loweredType;
StringBuilder nameSB;
- nameSB << "_Array_" << getLayoutName(rules->ruleName) << "_";
+ nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_";
getTypeNameHint(nameSB, arrayType->getElementType());
nameSB << getIntVal(arrayType->getElementCount());
builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
@@ -534,10 +558,10 @@ struct LoweredElementTypeContext
IRSizeAndAlignment elementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- rules,
+ config.layoutRule,
loweredInnerTypeInfo.loweredType,
&elementSizeAlignment);
- elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment);
+ elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment);
auto innerArrayType = builder.getArrayType(
loweredInnerTypeInfo.loweredType,
arrayType->getElementCount(),
@@ -566,10 +590,10 @@ struct LoweredElementTypeContext
bool isTrivial = true;
for (auto field : structType->getFields())
{
- auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), rules);
+ auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), config);
fieldLoweredTypeInfo.add(loweredFieldTypeInfo);
if (loweredFieldTypeInfo.convertLoweredToOriginal ||
- rules->ruleName != IRTypeLayoutRuleName::Natural)
+ config.layoutRule->ruleName != IRTypeLayoutRuleName::Natural)
isTrivial = false;
}
@@ -589,7 +613,7 @@ struct LoweredElementTypeContext
auto loweredType = builder.createStructType();
StringBuilder nameSB;
getTypeNameHint(nameSB, type);
- nameSB << "_" << getLayoutName(rules->ruleName);
+ nameSB << "_" << getLayoutName(config.layoutRule->ruleName);
builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
info.loweredType = loweredType;
// Create fields.
@@ -748,24 +772,40 @@ struct LoweredElementTypeContext
return info;
}
- LoweredElementTypeInfo getLoweredTypeInfo(IRType* type, IRTypeLayoutRules* rules)
+ LoweredTypeMap& getTypeLoweringMap(TypeLoweringConfig config)
+ {
+ RefPtr<LoweredTypeMap> map;
+ if (loweredTypeInfoMaps.tryGetValue(config, map))
+ return *map;
+ map = new LoweredTypeMap();
+ loweredTypeInfoMaps.add(config, map);
+ return *map;
+ }
+
+ LoweredElementTypeInfo getLoweredTypeInfo(IRType* type, TypeLoweringConfig config)
{
// If `type` is already a lowered type, no more lowering is required.
LoweredElementTypeInfo info;
- if (mapLoweredTypeToInfo->tryGetValue(type))
+ auto& map = getTypeLoweringMap(config);
+ auto& mapLoweredTypeToInfo = map.mapLoweredTypeToInfo;
+ auto& loweredTypeInfo = map.loweredTypeInfo;
+ if (mapLoweredTypeToInfo.tryGetValue(type))
{
info.originalType = type;
info.loweredType = type;
return info;
}
-
- if (loweredTypeInfo[(int)rules->ruleName].tryGetValue(type, info))
+ if (loweredTypeInfo.tryGetValue(type, info))
return info;
- info = getLoweredTypeInfoImpl(type, rules);
+ info = getLoweredTypeInfoImpl(type, config);
IRSizeAndAlignment sizeAlignment;
- getSizeAndAlignment(target->getOptionSet(), rules, info.loweredType, &sizeAlignment);
- loweredTypeInfo[(int)rules->ruleName].set(type, info);
- mapLoweredTypeToInfo[(int)rules->ruleName].set(info.loweredType, info);
+ getSizeAndAlignment(
+ target->getOptionSet(),
+ config.layoutRule,
+ info.loweredType,
+ &sizeAlignment);
+ loweredTypeInfo.set(type, info);
+ mapLoweredTypeToInfo.set(info.loweredType, info);
conversionMethodMap[{info.originalType, info.loweredType}] = info.convertLoweredToOriginal;
conversionMethodMap[{info.loweredType, info.originalType}] = info.convertOriginalToLowered;
return info;
@@ -802,7 +842,7 @@ struct LoweredElementTypeContext
struct MatrixAddrWorkItem
{
IRInst* matrixAddrInst;
- IRTypeLayoutRules* layoutRules;
+ TypeLoweringConfig config;
};
void processModule(IRModule* module)
@@ -812,17 +852,25 @@ struct LoweredElementTypeContext
{
IRType* bufferType;
IRType* elementType;
+ bool shouldWrapArrayInStruct = false;
};
List<BufferTypeInfo> bufferTypeInsts;
for (auto globalInst : module->getGlobalInsts())
{
IRType* elementType = nullptr;
+
if (options.lowerBufferPointer)
{
- if (auto ptrType = as<IRPtrType>(globalInst))
+ if (auto ptrType = as<IRPtrTypeBase>(globalInst))
{
- if (ptrType->getAddressSpace() == AddressSpace::UserPointer)
+ switch (ptrType->getAddressSpace())
+ {
+ case AddressSpace::UserPointer:
+ case AddressSpace::Input:
+ case AddressSpace::Output:
elementType = ptrType->getValueType();
+ break;
+ }
}
}
else
@@ -849,8 +897,8 @@ struct LoweredElementTypeContext
{
auto bufferType = bufferTypeInfo.bufferType;
auto elementType = bufferTypeInfo.elementType;
- auto layoutRules = getTypeLayoutRuleForBuffer(target, bufferType);
- auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, layoutRules);
+ auto config = getTypeLoweringConfigForBuffer(target, bufferType);
+ auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, config);
// If the lowered type is the same as original type, no change is required.
if (loweredBufferElementTypeInfo.loweredType ==
@@ -902,18 +950,18 @@ struct LoweredElementTypeContext
builder.setInsertBefore(ptrVal);
auto newArrayPtrVal = fieldAddr->getBase();
auto loweredInnerType =
- getLoweredTypeInfo(unsizedArrayType->getElementType(), layoutRules);
+ getLoweredTypeInfo(unsizedArrayType->getElementType(), config);
IRSizeAndAlignment arrayElementSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- layoutRules,
+ config.layoutRule,
loweredInnerType.loweredType,
&arrayElementSizeAlignment);
IRSizeAndAlignment baseSizeAlignment;
getSizeAndAlignment(
target->getOptionSet(),
- layoutRules,
+ config.layoutRule,
tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()),
&baseSizeAlignment);
@@ -998,7 +1046,7 @@ struct LoweredElementTypeContext
if (!loweredElementTypeInfo.loweredType)
{
loweredElementTypeInfo =
- getLoweredTypeInfo((IRType*)originalElementType, layoutRules);
+ getLoweredTypeInfo((IRType*)originalElementType, config);
}
if (!loweredElementTypeInfo.convertLoweredToOriginal)
@@ -1087,7 +1135,7 @@ struct LoweredElementTypeContext
// We are tring to get a pointer to a lowered matrix element.
// We process this insts at a later phase.
SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr);
- matrixAddrInsts.add(MatrixAddrWorkItem{user, layoutRules});
+ matrixAddrInsts.add(MatrixAddrWorkItem{user, config});
}
else
{
@@ -1160,15 +1208,12 @@ struct LoweredElementTypeContext
for (auto workItem : matrixAddrInsts)
{
auto majorAddr = workItem.matrixAddrInst;
- auto layoutRules = workItem.layoutRules;
-
- int layoutRuleName = (int)layoutRules->ruleName;
auto majorGEP = as<IRGetElementPtr>(majorAddr);
SLANG_ASSERT(majorGEP);
auto loweredMatrixType =
cast<IRPtrTypeBase>(majorGEP->getBase()->getFullType())->getValueType();
- auto matrixTypeInfo =
- mapLoweredTypeToInfo[layoutRuleName].tryGetValue(loweredMatrixType);
+ auto matrixTypeInfo = getTypeLoweringMap(workItem.config)
+ .mapLoweredTypeToInfo.tryGetValue(loweredMatrixType);
SLANG_ASSERT(matrixTypeInfo);
auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType);
auto rowCount = getIntVal(matrixType->getRowCount());
@@ -1379,4 +1424,21 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf
return IRTypeLayoutRules::getNatural();
}
+TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* bufferType)
+{
+ AddressSpace addrSpace = AddressSpace::Generic;
+ if (auto ptrType = as<IRPtrTypeBase>(bufferType))
+ {
+ switch (ptrType->getAddressSpace())
+ {
+ case AddressSpace::Input:
+ case AddressSpace::Output:
+ addrSpace = AddressSpace::Input;
+ break;
+ }
+ }
+ auto rules = getTypeLayoutRuleForBuffer(target, bufferType);
+ return TypeLoweringConfig{addrSpace, rules};
+}
+
} // namespace Slang