diff options
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.cpp | 168 |
1 files changed, 119 insertions, 49 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 3ef94d415..0472e44df 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -3,6 +3,7 @@ #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir-clone.h" +#include "slang-ir-layout.h" namespace Slang { @@ -17,13 +18,15 @@ namespace Slang IRFunc* convertOriginalToLowered = nullptr; IRFunc* convertLoweredToOriginal = nullptr; }; - Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo; - Dictionary<IRType*, LoweredElementTypeInfo> mapLoweredTypeToInfo; + + Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo[(int)IRTypeLayoutRuleName::_Count]; + Dictionary<IRType*, LoweredElementTypeInfo> mapLoweredTypeToInfo[(int)IRTypeLayoutRuleName::_Count]; SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + TargetRequest* target; - LoweredElementTypeContext(SlangMatrixLayoutMode inDefaultMatrixLayout) - : defaultMatrixLayout(inDefaultMatrixLayout) + LoweredElementTypeContext(TargetRequest* target, SlangMatrixLayoutMode inDefaultMatrixLayout) + : target(target), defaultMatrixLayout(inDefaultMatrixLayout) {} IRFunc* createMatrixUnpackFunc( @@ -200,7 +203,18 @@ namespace Slang return func; } - LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type) + const char* getLayoutName(IRTypeLayoutRuleName name) + { + switch (name) + { + case IRTypeLayoutRuleName::Std140: return "std140"; + case IRTypeLayoutRuleName::Std430: return "std430"; + case IRTypeLayoutRuleName::Natural: return "natural"; + default: return "default"; + } + } + + LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, IRTypeLayoutRules* rules) { IRBuilder builder(type); builder.setInsertAfter(type); @@ -224,12 +238,20 @@ namespace Slang nameSB << getIntVal(matrixType->getRowCount()) << "x" << getIntVal(matrixType->getColumnCount()); if (isColMajor) nameSB << "_ColMajor"; + nameSB << getLayoutName(rules->ruleName); builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); auto structKey = builder.createStructKey(); builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); auto vectorType = builder.getVectorType(matrixType->getElementType(), isColMajor?matrixType->getRowCount():matrixType->getColumnCount()); - auto arrayType = builder.getArrayType(vectorType, isColMajor?matrixType->getColumnCount():matrixType->getRowCount()); + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment(rules, vectorType, &elementSizeAlignment); + elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment); + + auto arrayType = builder.getArrayType( + vectorType, + isColMajor?matrixType->getColumnCount():matrixType->getRowCount(), + builder.getIntValue(builder.getIntType(), elementSizeAlignment.size)); builder.createStructField(loweredType, structKey, arrayType); info.loweredType = loweredType; @@ -241,45 +263,48 @@ namespace Slang } else if (auto arrayType = as<IRArrayType>(type)) { - auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType()); - - if (loweredInnerTypeInfo.loweredType != loweredInnerTypeInfo.originalType) - { - auto loweredType = builder.createStructType(); - info.loweredType = loweredType; - StringBuilder nameSB; - nameSB << "_ArrayStorage_"; - getTypeNameHint(nameSB, arrayType->getElementType()); - nameSB << getIntVal(arrayType->getElementCount()); - builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); - auto structKey = builder.createStructKey(); - builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); - auto innerArrayType = builder.getArrayType(loweredInnerTypeInfo.loweredType, arrayType->getElementCount()); - builder.createStructField(loweredType, structKey, innerArrayType); - info.loweredInnerArrayType = innerArrayType; - info.loweredInnerStructKey = structKey; - info.convertLoweredToOriginal = createArrayUnpackFunc(arrayType, loweredType, structKey, innerArrayType, loweredInnerTypeInfo); - info.convertOriginalToLowered = createArrayPackFunc(arrayType, loweredType, innerArrayType, loweredInnerTypeInfo); - } - else + auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), rules); + if (!loweredInnerTypeInfo.convertLoweredToOriginal && rules->ruleName == IRTypeLayoutRuleName::Natural) { info.loweredType = type; + return info; } + auto loweredType = builder.createStructType(); + info.loweredType = loweredType; + StringBuilder nameSB; + nameSB << "_Array_" << getLayoutName(rules->ruleName) << "_"; + getTypeNameHint(nameSB, arrayType->getElementType()); + nameSB << getIntVal(arrayType->getElementCount()); + builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + auto structKey = builder.createStructKey(); + builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment(rules, loweredType, &elementSizeAlignment); + elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment); + auto innerArrayType = builder.getArrayType( + loweredInnerTypeInfo.loweredType, + arrayType->getElementCount(), + builder.getIntValue(builder.getIntType(), elementSizeAlignment.size)); + builder.createStructField(loweredType, structKey, innerArrayType); + info.loweredInnerArrayType = innerArrayType; + info.loweredInnerStructKey = structKey; + info.convertLoweredToOriginal = createArrayUnpackFunc(arrayType, loweredType, structKey, innerArrayType, loweredInnerTypeInfo); + info.convertOriginalToLowered = createArrayPackFunc(arrayType, loweredType, innerArrayType, loweredInnerTypeInfo); + return info; } else if (auto structType = as<IRStructType>(type)) { - bool hasNonTrivialField = false; List<LoweredElementTypeInfo> fieldLoweredTypeInfo; + bool isTrivial = true; for (auto field : structType->getFields()) { - auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType()); + auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), rules); fieldLoweredTypeInfo.add(loweredFieldTypeInfo); - if (loweredFieldTypeInfo.loweredType != loweredFieldTypeInfo.originalType) - hasNonTrivialField = true; + if (loweredFieldTypeInfo.convertLoweredToOriginal || rules->ruleName != IRTypeLayoutRuleName::Natural) + isTrivial = false; } - - if (!hasNonTrivialField) + if (isTrivial) { info.loweredType = type; return info; @@ -288,10 +313,9 @@ namespace Slang auto loweredType = builder.createStructType(); StringBuilder nameSB; getTypeNameHint(nameSB, type); - nameSB << "_Storage"; + nameSB << "_" << getLayoutName(rules->ruleName); builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); info.loweredType = loweredType; - // Create fields. { Index fieldId = 0; @@ -340,7 +364,7 @@ namespace Slang Index fieldId = 0; for (auto field : structType->getFields()) { - auto fieldVal = builder.emitFieldExtract(type, param, field->getKey()); + auto fieldVal = builder.emitFieldExtract(field->getFieldType(), param, field->getKey()); auto packedField = fieldLoweredTypeInfo[fieldId].convertOriginalToLowered ? builder.emitCallInst(fieldLoweredTypeInfo[fieldId].loweredType, fieldLoweredTypeInfo[fieldId].convertOriginalToLowered, 1, &fieldVal) : fieldVal; @@ -358,14 +382,16 @@ namespace Slang return info; } - LoweredElementTypeInfo getLoweredTypeInfo(IRType* type) + LoweredElementTypeInfo getLoweredTypeInfo(IRType* type, IRTypeLayoutRules* rules) { LoweredElementTypeInfo info; - if (loweredTypeInfo.tryGetValue(type, info)) + if (loweredTypeInfo[(int)rules->ruleName].tryGetValue(type, info)) return info; - info = getLoweredTypeInfoImpl(type); - loweredTypeInfo[type] = info; - mapLoweredTypeToInfo[info.loweredType] = info; + info = getLoweredTypeInfoImpl(type, rules); + IRSizeAndAlignment sizeAlignment; + getSizeAndAlignment(rules, info.loweredType, &sizeAlignment); + loweredTypeInfo[(int)rules->ruleName].set(type, info); + mapLoweredTypeToInfo[(int)rules->ruleName].set(info.loweredType, info); return info; } @@ -389,6 +415,12 @@ namespace Slang return nullptr; } + struct MatrixAddrWorkItem + { + IRInst* matrixAddrInst; + IRTypeLayoutRules* layoutRules; + }; + void processModule(IRModule* module) { IRBuilder builder(module); @@ -414,13 +446,15 @@ namespace Slang // Maintain a pending work list of all matrix addresses, and try to lower them out of existance // after everything else has been lowered. - List<IRInst*> matrixAddrInsts; + + List<MatrixAddrWorkItem> matrixAddrInsts; for (auto bufferTypeInfo : bufferTypeInsts) { auto bufferType = bufferTypeInfo.bufferType; auto elementType = bufferTypeInfo.elementType; - auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType); + auto layoutRules = getTypeLayoutRuleForBuffer(target, bufferType); + auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, layoutRules); // If the lowered type is the same as original type, no change is required. if (!loweredBufferElementTypeInfo.convertLoweredToOriginal) @@ -450,7 +484,7 @@ namespace Slang auto ptrVal = ptrValsWorkList[i]; auto oldPtrType = ptrVal->getFullType(); auto originalElementType = oldPtrType->getOperand(0); - auto loweredElementTypeInfo = getLoweredTypeInfo((IRType*)originalElementType); + auto loweredElementTypeInfo = getLoweredTypeInfo((IRType*)originalElementType, layoutRules); if (!loweredElementTypeInfo.convertLoweredToOriginal) continue; @@ -517,7 +551,7 @@ namespace Slang // We are tring to get a pointer to a lowered matrix element. // We process this insts at a later phase. SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr); - matrixAddrInsts.add(user); + matrixAddrInsts.add(MatrixAddrWorkItem{ user, layoutRules }); } else { @@ -544,19 +578,24 @@ namespace Slang bufferType->removeAndDeallocate(); } + // Process all matrix address uses. lowerMatrixAddresses(module, matrixAddrInsts); } // Lower all getElementPtr insts of a lowered matrix out of existance. - void lowerMatrixAddresses(IRModule* module, List<IRInst*>& matrixAddrInsts) + void lowerMatrixAddresses(IRModule* module, List<MatrixAddrWorkItem>& matrixAddrInsts) { IRBuilder builder(module); - for (auto majorAddr : matrixAddrInsts) + for (auto workItem : matrixAddrInsts) { + auto majorAddr = workItem.matrixAddrInst; + auto layoutRules = workItem.layoutRules; + + int layoutRuleName = (int)layoutRules->ruleName; auto majorGEP = as<IRGetElementPtr>(majorAddr); SLANG_ASSERT(majorGEP); auto loweredMatrixType = cast<IRPtrTypeBase>(majorGEP->getBase()->getFullType())->getValueType(); - auto matrixTypeInfo = mapLoweredTypeToInfo.tryGetValue(loweredMatrixType); + auto matrixTypeInfo = mapLoweredTypeToInfo[layoutRuleName].tryGetValue(loweredMatrixType); SLANG_ASSERT(matrixTypeInfo); auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType); auto rowCount = getIntVal(matrixType->getRowCount()); @@ -652,7 +691,38 @@ namespace Slang SlangMatrixLayoutMode defaultMatrixMode = (SlangMatrixLayoutMode)target->getDefaultMatrixLayoutMode(); if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; - LoweredElementTypeContext context(defaultMatrixMode); + LoweredElementTypeContext context(target, defaultMatrixMode); context.processModule(module); } + + + IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetRequest* target, IRType* bufferType) + { + if (!isKhronosTarget(target)) + return IRTypeLayoutRules::getNatural(); + + // If we are just emitting GLSL, we can just use the general layout rule. + if (!target->shouldEmitSPIRVDirectly()) + return IRTypeLayoutRules::getNatural(); + + // If the user specified a scalar buffer layout, then just use that. + if (target->getForceGLSLScalarBufferLayout()) + return IRTypeLayoutRules::getNatural(); + + // The default behavior is to use std140 for constant buffers and std430 for other buffers. + switch (bufferType->getOp()) + { + case kIROp_HLSLStructuredBufferType: + case kIROp_HLSLRWStructuredBufferType: + case kIROp_HLSLAppendStructuredBufferType: + case kIROp_HLSLConsumeStructuredBufferType: + case kIROp_HLSLRasterizerOrderedStructuredBufferType: + return IRTypeLayoutRules::getStd430(); + case kIROp_ConstantBufferType: + case kIROp_ParameterBlockType: + return IRTypeLayoutRules::getStd140(); + } + return IRTypeLayoutRules::getNatural(); + } + } |
