diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2024-10-29 14:49:26 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-29 14:49:26 +0800 |
| commit | f65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch) | |
| tree | ea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ir-lower-buffer-element-type.cpp | |
| parent | a729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff) | |
format
* format
* Minor test fixes
* enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.cpp | 1821 |
1 files changed, 991 insertions, 830 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 9ea41e3b4..46243eba3 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -1,486 +1,570 @@ #include "slang-ir-lower-buffer-element-type.h" -#include "slang-ir.h" -#include "slang-ir-insts.h" -#include "slang-ir-util.h" + #include "slang-ir-clone.h" +#include "slang-ir-insts.h" #include "slang-ir-layout.h" +#include "slang-ir-util.h" +#include "slang-ir.h" namespace Slang { - struct LoweredElementTypeContext - { - static const IRIntegerValue kMaxArraySizeToUnroll = 32; +struct LoweredElementTypeContext +{ + static const IRIntegerValue kMaxArraySizeToUnroll = 32; - struct LoweredElementTypeInfo - { - IRType* originalType; - IRType* loweredType; - IRType* loweredInnerArrayType = nullptr; // For matrix/array types that are lowered into a struct type, this is the inner array type of the data field. - IRStructKey* loweredInnerStructKey = nullptr; // For matrix/array types that are lowered into a struct type, this is the struct key of the data field. - IRFunc* convertOriginalToLowered = nullptr; - IRFunc* convertLoweredToOriginal = nullptr; - }; + struct LoweredElementTypeInfo + { + IRType* originalType; + IRType* loweredType; + IRType* loweredInnerArrayType = + nullptr; // For matrix/array types that are lowered into a struct type, this is the + // inner array type of the data field. + IRStructKey* loweredInnerStructKey = + nullptr; // For matrix/array types that are lowered into a struct type, this is the + // struct key of the data field. + IRFunc* convertOriginalToLowered = nullptr; + IRFunc* convertLoweredToOriginal = nullptr; + }; - Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo[(int)IRTypeLayoutRuleName::_Count]; - Dictionary<IRType*, LoweredElementTypeInfo> mapLoweredTypeToInfo[(int)IRTypeLayoutRuleName::_Count]; + Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo[(int)IRTypeLayoutRuleName::_Count]; + Dictionary<IRType*, LoweredElementTypeInfo> + mapLoweredTypeToInfo[(int)IRTypeLayoutRuleName::_Count]; - SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; - TargetProgram* target; - bool lowerBufferPointer = false; + SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + TargetProgram* target; + bool lowerBufferPointer = false; - LoweredElementTypeContext(TargetProgram* target, bool lowerBufferPointer, SlangMatrixLayoutMode inDefaultMatrixLayout) - : target(target), defaultMatrixLayout(inDefaultMatrixLayout), lowerBufferPointer(lowerBufferPointer) - {} + LoweredElementTypeContext( + TargetProgram* target, + bool lowerBufferPointer, + SlangMatrixLayoutMode inDefaultMatrixLayout) + : target(target) + , defaultMatrixLayout(inDefaultMatrixLayout) + , lowerBufferPointer(lowerBufferPointer) + { + } - IRFunc* createMatrixUnpackFunc( - IRMatrixType* matrixType, - IRStructType* structType, - IRStructKey* dataKey, - IRArrayType* arrayType) + IRFunc* createMatrixUnpackFunc( + IRMatrixType* matrixType, + IRStructType* structType, + IRStructKey* dataKey, + IRArrayType* arrayType) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&structType, matrixType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.setInsertInto(func); + builder.emitBlock(); + auto rowCount = (Index)getIntVal(matrixType->getRowCount()); + auto colCount = (Index)getIntVal(matrixType->getColumnCount()); + auto packedParam = builder.emitParam(structType); + auto vectorArray = builder.emitFieldExtract(arrayType, packedParam, dataKey); + List<IRInst*> args; + args.setCount(rowCount * colCount); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) { - IRBuilder builder(structType); - builder.setInsertAfter(structType); - auto func = builder.createFunc(); - auto funcType = builder.getFuncType(1, (IRType**)&structType, matrixType); - func->setFullType(funcType); - builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); - builder.setInsertInto(func); - builder.emitBlock(); - auto rowCount = (Index)getIntVal(matrixType->getRowCount()); - auto colCount = (Index)getIntVal(matrixType->getColumnCount()); - auto packedParam = builder.emitParam(structType); - auto vectorArray = builder.emitFieldExtract(arrayType, packedParam, dataKey); - List<IRInst*> args; - args.setCount(rowCount * colCount); - if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) - { - for (IRIntegerValue c = 0; c < colCount; c++) - { - auto vector = builder.emitElementExtract(vectorArray, c); - for (IRIntegerValue r = 0; r < rowCount; r++) - { - auto element = builder.emitElementExtract(vector, r); - args[(Index)(r*colCount + c)] = element; - } - } - } - else + for (IRIntegerValue c = 0; c < colCount; c++) { + auto vector = builder.emitElementExtract(vectorArray, c); for (IRIntegerValue r = 0; r < rowCount; r++) { - auto vector = builder.emitElementExtract(vectorArray, r); - for (IRIntegerValue c = 0; c < colCount; c++) - { - auto element = builder.emitElementExtract(vector, c); - args[(Index)(r * colCount + c)] = element; - } + auto element = builder.emitElementExtract(vector, r); + args[(Index)(r * colCount + c)] = element; } } - IRInst* result = builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer()); - builder.emitReturn(result); - return func; } - - IRFunc* createMatrixPackFunc( - IRMatrixType* matrixType, - IRStructType* structType, - IRVectorType* vectorType, - IRArrayType* arrayType) + else { - IRBuilder builder(structType); - builder.setInsertAfter(structType); - auto func = builder.createFunc(); - auto funcType = builder.getFuncType(1, (IRType**)&matrixType, structType); - func->setFullType(funcType); - builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix")); - builder.setInsertInto(func); - builder.emitBlock(); - auto rowCount = getIntVal(matrixType->getRowCount()); - auto colCount = getIntVal(matrixType->getColumnCount()); - auto originalParam = builder.emitParam(matrixType); - List<IRInst*> elements; - elements.setCount((Index)(rowCount * colCount)); for (IRIntegerValue r = 0; r < rowCount; r++) { - auto vector = builder.emitElementExtract(originalParam, r); + auto vector = builder.emitElementExtract(vectorArray, r); for (IRIntegerValue c = 0; c < colCount; c++) { auto element = builder.emitElementExtract(vector, c); - elements[(Index)(r * colCount + c)] = element; + args[(Index)(r * colCount + c)] = element; } } - List<IRInst*> vectors; - if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + } + IRInst* result = + builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer()); + builder.emitReturn(result); + return func; + } + + IRFunc* createMatrixPackFunc( + IRMatrixType* matrixType, + IRStructType* structType, + IRVectorType* vectorType, + IRArrayType* arrayType) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&matrixType, structType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix")); + builder.setInsertInto(func); + builder.emitBlock(); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); + auto originalParam = builder.emitParam(matrixType); + List<IRInst*> elements; + elements.setCount((Index)(rowCount * colCount)); + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto vector = builder.emitElementExtract(originalParam, r); + for (IRIntegerValue c = 0; c < colCount; c++) { - for (IRIntegerValue c = 0; c < colCount; c++) - { - List<IRInst*> vecArgs; - for (IRIntegerValue r = 0; r < rowCount; r++) - { - auto element = elements[(Index)(r * colCount + c)]; - vecArgs.add(element); - } - auto colVector = builder.emitMakeVector(vectorType, (UInt)vecArgs.getCount(), vecArgs.getBuffer()); - vectors.add(colVector); - } + auto element = builder.emitElementExtract(vector, c); + elements[(Index)(r * colCount + c)] = element; } - else + } + List<IRInst*> vectors; + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue c = 0; c < colCount; c++) { + List<IRInst*> vecArgs; for (IRIntegerValue r = 0; r < rowCount; r++) { - List<IRInst*> vecArgs; - for (IRIntegerValue c = 0; c < colCount; c++) - { - auto element = elements[(Index)(r * colCount + c)]; - vecArgs.add(element); - } - auto rowVector = builder.emitMakeVector(vectorType, (UInt)vecArgs.getCount(), vecArgs.getBuffer()); - vectors.add(rowVector); + auto element = elements[(Index)(r * colCount + c)]; + vecArgs.add(element); } + auto colVector = builder.emitMakeVector( + vectorType, + (UInt)vecArgs.getCount(), + vecArgs.getBuffer()); + vectors.add(colVector); } - - auto vectorArray = builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer()); - auto result = builder.emitMakeStruct(structType, 1, &vectorArray); - builder.emitReturn(result); - return func; } - - IRFunc* createArrayUnpackFunc( - IRArrayType* arrayType, - IRStructType* structType, - IRStructKey* dataKey, - IRArrayType* innerArrayType, - LoweredElementTypeInfo innerTypeInfo) + else { - IRBuilder builder(structType); - builder.setInsertAfter(structType); - auto func = builder.createFunc(); - auto funcType = builder.getFuncType(1, (IRType**)&structType, arrayType); - func->setFullType(funcType); - builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); - builder.setInsertInto(func); - builder.emitBlock(); - auto packedParam = builder.emitParam(structType); - auto packedArray = builder.emitFieldExtract(innerArrayType, packedParam, dataKey); - auto count = getIntVal(arrayType->getElementCount()); - IRInst* result = nullptr; - if (count <= kMaxArraySizeToUnroll) + for (IRIntegerValue r = 0; r < rowCount; r++) { - // If the array is small enough, just process each element directly. - List<IRInst*> args; - args.setCount((Index)count); - for (IRIntegerValue ii = 0; ii < count; ++ii) + List<IRInst*> vecArgs; + for (IRIntegerValue c = 0; c < colCount; c++) { - auto packedElement = builder.emitElementExtract(packedArray, ii); - auto originalElement = innerTypeInfo.convertLoweredToOriginal - ? builder.emitCallInst(innerTypeInfo.originalType, innerTypeInfo.convertLoweredToOriginal, 1, &packedElement) - : packedElement; - args[(Index)ii] = originalElement; + auto element = elements[(Index)(r * colCount + c)]; + vecArgs.add(element); } - result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); - - } - else - { - // The general case for large arrays is to emit a loop through the elements. - IRVar* resultVar = builder.emitVar(arrayType); - IRBlock* loopBodyBlock; - IRBlock* loopBreakBlock; - auto loopParam = emitLoopBlocks(&builder, builder.getIntValue(builder.getIntType(), 0), builder.getIntValue(builder.getIntType(), count), - loopBodyBlock, loopBreakBlock); - - builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst()); - auto packedElement = builder.emitElementExtract(packedArray, loopParam); - auto originalElement = innerTypeInfo.convertLoweredToOriginal - ? builder.emitCallInst(innerTypeInfo.originalType, innerTypeInfo.convertLoweredToOriginal, 1, &packedElement) - : packedElement; - auto varPtr = builder.emitElementAddress(resultVar, loopParam); - builder.emitStore(varPtr, originalElement); - builder.setInsertInto(loopBreakBlock); - result = builder.emitLoad(resultVar); + auto rowVector = builder.emitMakeVector( + vectorType, + (UInt)vecArgs.getCount(), + vecArgs.getBuffer()); + vectors.add(rowVector); } - builder.emitReturn(result); - return func; } - IRFunc* createArrayPackFunc( - IRArrayType* arrayType, - IRStructType* structType, - IRArrayType* innerArrayType, - LoweredElementTypeInfo innerTypeInfo) + auto vectorArray = + builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer()); + auto result = builder.emitMakeStruct(structType, 1, &vectorArray); + builder.emitReturn(result); + return func; + } + + IRFunc* createArrayUnpackFunc( + IRArrayType* arrayType, + IRStructType* structType, + IRStructKey* dataKey, + IRArrayType* innerArrayType, + LoweredElementTypeInfo innerTypeInfo) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&structType, arrayType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.setInsertInto(func); + builder.emitBlock(); + auto packedParam = builder.emitParam(structType); + auto packedArray = builder.emitFieldExtract(innerArrayType, packedParam, dataKey); + auto count = getIntVal(arrayType->getElementCount()); + IRInst* result = nullptr; + if (count <= kMaxArraySizeToUnroll) { - IRBuilder builder(structType); - builder.setInsertAfter(structType); - auto func = builder.createFunc(); - auto funcType = builder.getFuncType(1, (IRType**)&arrayType, structType); - func->setFullType(funcType); - builder.addNameHintDecoration(func, UnownedStringSlice("packStorage")); - builder.setInsertInto(func); - builder.emitBlock(); - auto originalParam = builder.emitParam(arrayType); - IRInst* packedArray = nullptr; - auto count = getIntVal(arrayType->getElementCount()); - if (count <= kMaxArraySizeToUnroll) + // If the array is small enough, just process each element directly. + List<IRInst*> args; + args.setCount((Index)count); + for (IRIntegerValue ii = 0; ii < count; ++ii) { - // If the array is small enough, just process each element directly. - List<IRInst*> args; - args.setCount((Index)count); - for (IRIntegerValue ii = 0; ii < count; ++ii) - { - auto originalElement = builder.emitElementExtract(originalParam, ii); - auto packedElement = innerTypeInfo.convertOriginalToLowered - ? builder.emitCallInst(innerTypeInfo.loweredType, innerTypeInfo.convertOriginalToLowered, 1, &originalElement) - : originalElement; - args[(Index)ii] = packedElement; - } - packedArray = builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer()); + auto packedElement = builder.emitElementExtract(packedArray, ii); + auto originalElement = innerTypeInfo.convertLoweredToOriginal + ? builder.emitCallInst( + innerTypeInfo.originalType, + innerTypeInfo.convertLoweredToOriginal, + 1, + &packedElement) + : packedElement; + args[(Index)ii] = originalElement; } - else - { - // The general case for large arrays is to emit a loop through the elements. - IRVar* packedArrayVar = builder.emitVar(innerArrayType); - IRBlock* loopBodyBlock; - IRBlock* loopBreakBlock; - auto loopParam = emitLoopBlocks(&builder, builder.getIntValue(builder.getIntType(), 0), builder.getIntValue(builder.getIntType(), count), - loopBodyBlock, loopBreakBlock); + result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); + } + else + { + // The general case for large arrays is to emit a loop through the elements. + IRVar* resultVar = builder.emitVar(arrayType); + IRBlock* loopBodyBlock; + IRBlock* loopBreakBlock; + auto loopParam = emitLoopBlocks( + &builder, + builder.getIntValue(builder.getIntType(), 0), + builder.getIntValue(builder.getIntType(), count), + loopBodyBlock, + loopBreakBlock); - builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst()); - auto originalElement = builder.emitElementExtract(originalParam, loopParam); - auto packedElement = innerTypeInfo.convertOriginalToLowered - ? builder.emitCallInst(innerTypeInfo.loweredType, innerTypeInfo.convertOriginalToLowered, 1, &originalElement) - : originalElement; - auto varPtr = builder.emitElementAddress(packedArrayVar, loopParam); - builder.emitStore(varPtr, packedElement); - builder.setInsertInto(loopBreakBlock); - packedArray = builder.emitLoad(packedArrayVar); - } - - auto result = builder.emitMakeStruct(structType, 1, &packedArray); - builder.emitReturn(result); - return func; + builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst()); + auto packedElement = builder.emitElementExtract(packedArray, loopParam); + auto originalElement = innerTypeInfo.convertLoweredToOriginal + ? builder.emitCallInst( + innerTypeInfo.originalType, + innerTypeInfo.convertLoweredToOriginal, + 1, + &packedElement) + : packedElement; + auto varPtr = builder.emitElementAddress(resultVar, loopParam); + builder.emitStore(varPtr, originalElement); + builder.setInsertInto(loopBreakBlock); + result = builder.emitLoad(resultVar); } + builder.emitReturn(result); + return func; + } - const char* getLayoutName(IRTypeLayoutRuleName name) + IRFunc* createArrayPackFunc( + IRArrayType* arrayType, + IRStructType* structType, + IRArrayType* innerArrayType, + LoweredElementTypeInfo innerTypeInfo) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&arrayType, structType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("packStorage")); + builder.setInsertInto(func); + builder.emitBlock(); + auto originalParam = builder.emitParam(arrayType); + IRInst* packedArray = nullptr; + auto count = getIntVal(arrayType->getElementCount()); + if (count <= kMaxArraySizeToUnroll) { - switch (name) + // If the array is small enough, just process each element directly. + List<IRInst*> args; + args.setCount((Index)count); + for (IRIntegerValue ii = 0; ii < count; ++ii) { - case IRTypeLayoutRuleName::Std140: return "std140"; - case IRTypeLayoutRuleName::Std430: return "std430"; - case IRTypeLayoutRuleName::Natural: return "natural"; - default: return "default"; + auto originalElement = builder.emitElementExtract(originalParam, ii); + auto packedElement = innerTypeInfo.convertOriginalToLowered + ? builder.emitCallInst( + innerTypeInfo.loweredType, + innerTypeInfo.convertOriginalToLowered, + 1, + &originalElement) + : originalElement; + args[(Index)ii] = packedElement; } + packedArray = + builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer()); + } + else + { + // The general case for large arrays is to emit a loop through the elements. + IRVar* packedArrayVar = builder.emitVar(innerArrayType); + IRBlock* loopBodyBlock; + IRBlock* loopBreakBlock; + auto loopParam = emitLoopBlocks( + &builder, + builder.getIntValue(builder.getIntType(), 0), + builder.getIntValue(builder.getIntType(), count), + loopBodyBlock, + loopBreakBlock); + + builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst()); + auto originalElement = builder.emitElementExtract(originalParam, loopParam); + auto packedElement = innerTypeInfo.convertOriginalToLowered + ? builder.emitCallInst( + innerTypeInfo.loweredType, + innerTypeInfo.convertOriginalToLowered, + 1, + &originalElement) + : originalElement; + auto varPtr = builder.emitElementAddress(packedArrayVar, loopParam); + builder.emitStore(varPtr, packedElement); + builder.setInsertInto(loopBreakBlock); + packedArray = builder.emitLoad(packedArrayVar); } - LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, IRTypeLayoutRules* rules) + auto result = builder.emitMakeStruct(structType, 1, &packedArray); + builder.emitReturn(result); + return func; + } + + const char* getLayoutName(IRTypeLayoutRuleName name) + { + switch (name) { - IRBuilder builder(type); - builder.setInsertAfter(type); + case IRTypeLayoutRuleName::Std140: return "std140"; + case IRTypeLayoutRuleName::Std430: return "std430"; + case IRTypeLayoutRuleName::Natural: return "natural"; + default: return "default"; + } + } - LoweredElementTypeInfo info; - info.originalType = type; + LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, IRTypeLayoutRules* rules) + { + IRBuilder builder(type); + builder.setInsertAfter(type); - if (auto matrixType = as<IRMatrixType>(type)) + LoweredElementTypeInfo info; + info.originalType = type; + + if (auto matrixType = as<IRMatrixType>(type)) + { + // For spirv, we always want to lower all matrix types, because matrix types + // are considered abstract types. + if (!target->shouldEmitSPIRVDirectly()) { - // For spirv, we always want to lower all matrix types, because matrix types - // are considered abstract types. - if (!target->shouldEmitSPIRVDirectly()) + // For other targets, we only lower the matrix types if they differ from the default + // matrix layout. + if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && + rules->ruleName == IRTypeLayoutRuleName::Natural) { - // For other targets, we only lower the matrix types if they differ from the default - // matrix layout. - if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && - rules->ruleName == IRTypeLayoutRuleName::Natural) - { - info.loweredType = type; - return info; - } + info.loweredType = type; + return info; } + } - auto loweredType = builder.createStructType(); - StringBuilder nameSB; - bool isColMajor = getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; - nameSB << "_MatrixStorage_"; - getTypeNameHint(nameSB, matrixType->getElementType()); - nameSB << getIntVal(matrixType->getRowCount()) << "x" << getIntVal(matrixType->getColumnCount()); - if (isColMajor) - nameSB << "_ColMajor"; - nameSB << getLayoutName(rules->ruleName); - builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); - auto structKey = builder.createStructKey(); - builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); - auto vectorType = builder.getVectorType(matrixType->getElementType(), - isColMajor?matrixType->getRowCount():matrixType->getColumnCount()); - IRSizeAndAlignment elementSizeAlignment; - getSizeAndAlignment(target->getOptionSet(), rules, vectorType, &elementSizeAlignment); - elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment); + auto loweredType = builder.createStructType(); + StringBuilder nameSB; + bool isColMajor = + getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; + nameSB << "_MatrixStorage_"; + getTypeNameHint(nameSB, matrixType->getElementType()); + nameSB << getIntVal(matrixType->getRowCount()) << "x" + << getIntVal(matrixType->getColumnCount()); + if (isColMajor) + nameSB << "_ColMajor"; + nameSB << getLayoutName(rules->ruleName); + builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + auto structKey = builder.createStructKey(); + builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); + auto vectorType = builder.getVectorType( + matrixType->getElementType(), + isColMajor ? matrixType->getRowCount() : matrixType->getColumnCount()); + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment(target->getOptionSet(), rules, vectorType, &elementSizeAlignment); + elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment); - auto arrayType = builder.getArrayType( - vectorType, - isColMajor?matrixType->getColumnCount():matrixType->getRowCount(), - builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); - builder.createStructField(loweredType, structKey, arrayType); + auto arrayType = builder.getArrayType( + vectorType, + isColMajor ? matrixType->getColumnCount() : matrixType->getRowCount(), + builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); + builder.createStructField(loweredType, structKey, arrayType); - info.loweredType = loweredType; - info.loweredInnerArrayType = arrayType; - info.loweredInnerStructKey = structKey; - info.convertLoweredToOriginal = createMatrixUnpackFunc(matrixType, loweredType, structKey, arrayType); - info.convertOriginalToLowered = createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType); - return info; - } - else if (auto arrayType = as<IRArrayType>(type)) + info.loweredType = loweredType; + info.loweredInnerArrayType = arrayType; + info.loweredInnerStructKey = structKey; + info.convertLoweredToOriginal = + createMatrixUnpackFunc(matrixType, loweredType, structKey, arrayType); + info.convertOriginalToLowered = + createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType); + return info; + } + else if (auto arrayType = as<IRArrayType>(type)) + { + auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), rules); + // For spirv backend, we always want to lower all array types, even if the element type + // comes out the same. This is because different layout rules may have different array + // stride requirements. + if (!target->shouldEmitSPIRVDirectly()) { - auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), rules); - // For spirv backend, we always want to lower all array types, even if the element type - // comes out the same. This is because different layout rules may have different array - // stride requirements. - if (!target->shouldEmitSPIRVDirectly()) + if (!loweredInnerTypeInfo.convertLoweredToOriginal) { - if (!loweredInnerTypeInfo.convertLoweredToOriginal) - { - info.loweredType = type; - return info; - } + info.loweredType = type; + return info; } - auto loweredType = builder.createStructType(); - info.loweredType = loweredType; - StringBuilder nameSB; - nameSB << "_Array_" << getLayoutName(rules->ruleName) << "_"; - getTypeNameHint(nameSB, arrayType->getElementType()); - nameSB << getIntVal(arrayType->getElementCount()); - builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); - auto structKey = builder.createStructKey(); - builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); - IRSizeAndAlignment elementSizeAlignment; - getSizeAndAlignment(target->getOptionSet(), rules, loweredInnerTypeInfo.loweredType, &elementSizeAlignment); - elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment); - auto innerArrayType = builder.getArrayType( - loweredInnerTypeInfo.loweredType, - arrayType->getElementCount(), - builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); - builder.createStructField(loweredType, structKey, innerArrayType); - info.loweredInnerArrayType = innerArrayType; - info.loweredInnerStructKey = structKey; - info.convertLoweredToOriginal = createArrayUnpackFunc(arrayType, loweredType, structKey, innerArrayType, loweredInnerTypeInfo); - info.convertOriginalToLowered = createArrayPackFunc(arrayType, loweredType, innerArrayType, loweredInnerTypeInfo); - return info; } - else if (as<IRArrayTypeBase>(type)) + auto loweredType = builder.createStructType(); + info.loweredType = loweredType; + StringBuilder nameSB; + nameSB << "_Array_" << getLayoutName(rules->ruleName) << "_"; + getTypeNameHint(nameSB, arrayType->getElementType()); + nameSB << getIntVal(arrayType->getElementCount()); + builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + auto structKey = builder.createStructKey(); + builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + rules, + loweredInnerTypeInfo.loweredType, + &elementSizeAlignment); + elementSizeAlignment = rules->alignCompositeElement(elementSizeAlignment); + auto innerArrayType = builder.getArrayType( + loweredInnerTypeInfo.loweredType, + arrayType->getElementCount(), + builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); + builder.createStructField(loweredType, structKey, innerArrayType); + info.loweredInnerArrayType = innerArrayType; + info.loweredInnerStructKey = structKey; + info.convertLoweredToOriginal = createArrayUnpackFunc( + arrayType, + loweredType, + structKey, + innerArrayType, + loweredInnerTypeInfo); + info.convertOriginalToLowered = + createArrayPackFunc(arrayType, loweredType, innerArrayType, loweredInnerTypeInfo); + return info; + } + else if (as<IRArrayTypeBase>(type)) + { + info.loweredType = builder.getVoidType(); + return info; + } + else if (auto structType = as<IRStructType>(type)) + { + List<LoweredElementTypeInfo> fieldLoweredTypeInfo; + bool isTrivial = true; + for (auto field : structType->getFields()) { - info.loweredType = builder.getVoidType(); - return info; + auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), rules); + fieldLoweredTypeInfo.add(loweredFieldTypeInfo); + if (loweredFieldTypeInfo.convertLoweredToOriginal || + rules->ruleName != IRTypeLayoutRuleName::Natural) + isTrivial = false; } - else if (auto structType = as<IRStructType>(type)) - { - List<LoweredElementTypeInfo> fieldLoweredTypeInfo; - bool isTrivial = true; - for (auto field : structType->getFields()) - { - auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType(), rules); - fieldLoweredTypeInfo.add(loweredFieldTypeInfo); - if (loweredFieldTypeInfo.convertLoweredToOriginal || rules->ruleName != IRTypeLayoutRuleName::Natural) - isTrivial = false; - } - // For spirv backend, we always want to lower all array types, even if the element type - // comes out the same. This is because different layout rules may have different array - // stride requirements. - if (!target->shouldEmitSPIRVDirectly()) + // For spirv backend, we always want to lower all array types, even if the element type + // comes out the same. This is because different layout rules may have different array + // stride requirements. + if (!target->shouldEmitSPIRVDirectly()) + { + // For non-spirv target, we skip lowering this type if all field types are + // unchanged. + if (isTrivial) { - // For non-spirv target, we skip lowering this type if all field types are unchanged. - if (isTrivial) - { - info.loweredType = type; - return info; - } + info.loweredType = type; + return info; } - auto loweredType = builder.createStructType(); - StringBuilder nameSB; - getTypeNameHint(nameSB, type); - nameSB << "_" << getLayoutName(rules->ruleName); - builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); - info.loweredType = loweredType; - // Create fields. + } + auto loweredType = builder.createStructType(); + StringBuilder nameSB; + getTypeNameHint(nameSB, type); + nameSB << "_" << getLayoutName(rules->ruleName); + builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + info.loweredType = loweredType; + // Create fields. + { + Index fieldId = 0; + for (auto field : structType->getFields()) { - Index fieldId = 0; - for (auto field : structType->getFields()) + if (as<IRVoidType>(fieldLoweredTypeInfo[fieldId].loweredType)) { - if (as<IRVoidType>(fieldLoweredTypeInfo[fieldId].loweredType)) - { - fieldId++; - continue; - } - auto loweredFieldTypeInfo = fieldLoweredTypeInfo[fieldId]; - builder.createStructField(loweredType, field->getKey(), loweredFieldTypeInfo.loweredType); fieldId++; + continue; } + auto loweredFieldTypeInfo = fieldLoweredTypeInfo[fieldId]; + builder.createStructField( + loweredType, + field->getKey(), + loweredFieldTypeInfo.loweredType); + fieldId++; } + } - // Create unpack func. + // Create unpack func. + { + builder.setInsertAfter(loweredType); + info.convertLoweredToOriginal = builder.createFunc(); + builder.setInsertInto(info.convertLoweredToOriginal); + builder.addNameHintDecoration( + info.convertLoweredToOriginal, + UnownedStringSlice("unpackStorage")); + info.convertLoweredToOriginal->setFullType( + builder.getFuncType(1, (IRType**)&loweredType, type)); + builder.emitBlock(); + auto loweredParam = builder.emitParam(loweredType); + List<IRInst*> args; + Index fieldId = 0; + for (auto field : structType->getFields()) { - builder.setInsertAfter(loweredType); - info.convertLoweredToOriginal = builder.createFunc(); - builder.setInsertInto(info.convertLoweredToOriginal); - builder.addNameHintDecoration(info.convertLoweredToOriginal, UnownedStringSlice("unpackStorage")); - info.convertLoweredToOriginal->setFullType(builder.getFuncType(1, (IRType**)&loweredType, type)); - builder.emitBlock(); - auto loweredParam = builder.emitParam(loweredType); - List<IRInst*> args; - Index fieldId = 0; - for (auto field : structType->getFields()) + if (as<IRVoidType>(fieldLoweredTypeInfo[fieldId].loweredType)) { - if (as<IRVoidType>(fieldLoweredTypeInfo[fieldId].loweredType)) - { - fieldId++; - continue; - } - auto storageField = builder.emitFieldExtract(fieldLoweredTypeInfo[fieldId].loweredType, loweredParam, field->getKey()); - auto unpackedField = fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal - ? builder.emitCallInst(field->getFieldType(), fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal, 1, &storageField) - : storageField; - args.add(unpackedField); fieldId++; + continue; } - auto result = builder.emitMakeStruct(type, args); - builder.emitReturn(result); + auto storageField = builder.emitFieldExtract( + fieldLoweredTypeInfo[fieldId].loweredType, + loweredParam, + field->getKey()); + auto unpackedField = + fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal + ? builder.emitCallInst( + field->getFieldType(), + fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal, + 1, + &storageField) + : storageField; + args.add(unpackedField); + fieldId++; } + auto result = builder.emitMakeStruct(type, args); + builder.emitReturn(result); + } - // Create pack func. + // Create pack func. + { + builder.setInsertAfter(info.convertLoweredToOriginal); + info.convertOriginalToLowered = builder.createFunc(); + builder.setInsertInto(info.convertOriginalToLowered); + builder.addNameHintDecoration( + info.convertOriginalToLowered, + UnownedStringSlice("packStorage")); + info.convertOriginalToLowered->setFullType( + builder.getFuncType(1, (IRType**)&type, loweredType)); + builder.emitBlock(); + auto param = builder.emitParam(type); + List<IRInst*> args; + Index fieldId = 0; + for (auto field : structType->getFields()) { - builder.setInsertAfter(info.convertLoweredToOriginal); - info.convertOriginalToLowered = builder.createFunc(); - builder.setInsertInto(info.convertOriginalToLowered); - builder.addNameHintDecoration(info.convertOriginalToLowered, UnownedStringSlice("packStorage")); - info.convertOriginalToLowered->setFullType(builder.getFuncType(1, (IRType**)&type, loweredType)); - builder.emitBlock(); - auto param = builder.emitParam(type); - List<IRInst*> args; - Index fieldId = 0; - for (auto field : structType->getFields()) + if (as<IRVoidType>(fieldLoweredTypeInfo[fieldId].loweredType)) { - if (as<IRVoidType>(fieldLoweredTypeInfo[fieldId].loweredType)) - { - fieldId++; - continue; - } - auto fieldVal = builder.emitFieldExtract(field->getFieldType(), param, field->getKey()); - auto packedField = fieldLoweredTypeInfo[fieldId].convertOriginalToLowered - ? builder.emitCallInst(fieldLoweredTypeInfo[fieldId].loweredType, fieldLoweredTypeInfo[fieldId].convertOriginalToLowered, 1, &fieldVal) - : fieldVal; - args.add(packedField); fieldId++; + continue; } - auto result = builder.emitMakeStruct(loweredType, args); - builder.emitReturn(result); + auto fieldVal = + builder.emitFieldExtract(field->getFieldType(), param, field->getKey()); + auto packedField = + fieldLoweredTypeInfo[fieldId].convertOriginalToLowered + ? builder.emitCallInst( + fieldLoweredTypeInfo[fieldId].loweredType, + fieldLoweredTypeInfo[fieldId].convertOriginalToLowered, + 1, + &fieldVal) + : fieldVal; + args.add(packedField); + fieldId++; } - - return info; + auto result = builder.emitMakeStruct(loweredType, args); + builder.emitReturn(result); } - if (target->shouldEmitSPIRVDirectly()) + return info; + } + + if (target->shouldEmitSPIRVDirectly()) + { + switch (target->getTargetReq()->getTarget()) { - switch (target->getTargetReq()->getTarget()) - { - case CodeGenTarget::SPIRV: - case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::SPIRV: + case CodeGenTarget::SPIRVAssembly: { auto scalarType = type; auto vectorType = as<IRVectorType>(scalarType); @@ -492,14 +576,19 @@ namespace Slang // Bool is an abstract type in SPIRV, so we need to lower them into an int. info.loweredType = builder.getIntType(); if (vectorType) - info.loweredType = builder.getVectorType(info.loweredType, vectorType->getElementCount()); + info.loweredType = builder.getVectorType( + info.loweredType, + vectorType->getElementCount()); // Create unpack func. { builder.setInsertAfter(type); info.convertLoweredToOriginal = builder.createFunc(); builder.setInsertInto(info.convertLoweredToOriginal); - builder.addNameHintDecoration(info.convertLoweredToOriginal, UnownedStringSlice("unpackStorage")); - info.convertLoweredToOriginal->setFullType(builder.getFuncType(1, (IRType**)&info.loweredType, type)); + builder.addNameHintDecoration( + info.convertLoweredToOriginal, + UnownedStringSlice("unpackStorage")); + info.convertLoweredToOriginal->setFullType( + builder.getFuncType(1, (IRType**)&info.loweredType, type)); builder.emitBlock(); auto loweredParam = builder.emitParam(info.loweredType); auto result = builder.emitCast(type, loweredParam); @@ -511,8 +600,11 @@ namespace Slang builder.setInsertAfter(info.convertLoweredToOriginal); info.convertOriginalToLowered = builder.createFunc(); builder.setInsertInto(info.convertOriginalToLowered); - builder.addNameHintDecoration(info.convertOriginalToLowered, UnownedStringSlice("packStorage")); - info.convertOriginalToLowered->setFullType(builder.getFuncType(1, (IRType**)&type, info.loweredType)); + builder.addNameHintDecoration( + info.convertOriginalToLowered, + UnownedStringSlice("packStorage")); + info.convertOriginalToLowered->setFullType( + builder.getFuncType(1, (IRType**)&type, info.loweredType)); builder.emitBlock(); auto param = builder.emitParam(type); auto result = builder.emitCast(info.loweredType, param); @@ -521,530 +613,599 @@ namespace Slang return info; } } - default: - break; - } + default: break; } + } + + info.loweredType = type; + return info; + } + LoweredElementTypeInfo getLoweredTypeInfo(IRType* type, IRTypeLayoutRules* rules) + { + // If `type` is already a lowered type, no more lowering is required. + LoweredElementTypeInfo info; + if (mapLoweredTypeToInfo->tryGetValue(type)) + { + info.originalType = type; info.loweredType = type; return info; } - LoweredElementTypeInfo getLoweredTypeInfo(IRType* type, IRTypeLayoutRules* rules) - { - // If `type` is already a lowered type, no more lowering is required. - LoweredElementTypeInfo info; - if (mapLoweredTypeToInfo->tryGetValue(type)) - { - info.originalType = type; - info.loweredType = type; - return info; - } - - if (loweredTypeInfo[(int)rules->ruleName].tryGetValue(type, info)) - return info; - info = getLoweredTypeInfoImpl(type, rules); - IRSizeAndAlignment sizeAlignment; - getSizeAndAlignment(target->getOptionSet(), rules, info.loweredType, &sizeAlignment); - loweredTypeInfo[(int)rules->ruleName].set(type, info); - mapLoweredTypeToInfo[(int)rules->ruleName].set(info.loweredType, info); + if (loweredTypeInfo[(int)rules->ruleName].tryGetValue(type, info)) return info; - } + info = getLoweredTypeInfoImpl(type, rules); + IRSizeAndAlignment sizeAlignment; + getSizeAndAlignment(target->getOptionSet(), rules, info.loweredType, &sizeAlignment); + loweredTypeInfo[(int)rules->ruleName].set(type, info); + mapLoweredTypeToInfo[(int)rules->ruleName].set(info.loweredType, info); + return info; + } - IRType* getLoweredPtrLikeType(IRType* originalPtrLikeType, IRType* newElementType) + IRType* getLoweredPtrLikeType(IRType* originalPtrLikeType, IRType* newElementType) + { + if (as<IRPointerLikeType>(originalPtrLikeType) || as<IRPtrTypeBase>(originalPtrLikeType) || + as<IRHLSLStructuredBufferTypeBase>(originalPtrLikeType)) { - if (as<IRPointerLikeType>(originalPtrLikeType) || as<IRPtrTypeBase>(originalPtrLikeType) || as<IRHLSLStructuredBufferTypeBase>(originalPtrLikeType)) - { - IRBuilder builder(newElementType); - builder.setInsertAfter(newElementType); - ShortList<IRInst*> operands; - for (UInt i = 0; i < originalPtrLikeType->getOperandCount(); i++) - operands.add(originalPtrLikeType->getOperand(i)); - operands[0] = newElementType; - return builder.getType(originalPtrLikeType->getOp(), (UInt)operands.getCount(), operands.getArrayView().getBuffer()); - } - SLANG_UNREACHABLE("unhandled ptr like or buffer type"); + IRBuilder builder(newElementType); + builder.setInsertAfter(newElementType); + ShortList<IRInst*> operands; + for (UInt i = 0; i < originalPtrLikeType->getOperandCount(); i++) + operands.add(originalPtrLikeType->getOperand(i)); + operands[0] = newElementType; + return builder.getType( + originalPtrLikeType->getOp(), + (UInt)operands.getCount(), + operands.getArrayView().getBuffer()); } + SLANG_UNREACHABLE("unhandled ptr like or buffer type"); + } - IRInst* getStoreVal(IRInst* storeInst) - { - if (auto store = as<IRStore>(storeInst)) - return store->getVal(); - else if (auto sbStore = as<IRRWStructuredBufferStore>(storeInst)) - return sbStore->getVal(); - return nullptr; - } + IRInst* getStoreVal(IRInst* storeInst) + { + if (auto store = as<IRStore>(storeInst)) + return store->getVal(); + else if (auto sbStore = as<IRRWStructuredBufferStore>(storeInst)) + return sbStore->getVal(); + return nullptr; + } - struct MatrixAddrWorkItem + struct MatrixAddrWorkItem + { + IRInst* matrixAddrInst; + IRTypeLayoutRules* layoutRules; + }; + + void processModule(IRModule* module) + { + IRBuilder builder(module); + struct BufferTypeInfo { - IRInst* matrixAddrInst; - IRTypeLayoutRules* layoutRules; + IRType* bufferType; + IRType* elementType; }; - - void processModule(IRModule* module) + List<BufferTypeInfo> bufferTypeInsts; + for (auto globalInst : module->getGlobalInsts()) { - IRBuilder builder(module); - struct BufferTypeInfo - { - IRType* bufferType; - IRType* elementType; - }; - List<BufferTypeInfo> bufferTypeInsts; - for (auto globalInst : module->getGlobalInsts()) + IRType* elementType = nullptr; + if (lowerBufferPointer) { - IRType* elementType = nullptr; - if (lowerBufferPointer) - { - if (auto ptrType = as<IRPtrType>(globalInst)) - { - if (ptrType->getAddressSpace() == AddressSpace::UserPointer) - elementType = ptrType->getValueType(); - } - } - else + if (auto ptrType = as<IRPtrType>(globalInst)) { - if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst)) - elementType = structBuffer->getElementType(); - else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst)) - elementType = constBuffer->getElementType(); + if (ptrType->getAddressSpace() == AddressSpace::UserPointer) + elementType = ptrType->getValueType(); } - if (as<IRTextureBufferType>(globalInst)) - continue; - if (!as<IRStructType>(elementType) && !as<IRMatrixType>(elementType) && !as<IRArrayType>(elementType) && !as<IRBoolType>(elementType)) - continue; - bufferTypeInsts.add(BufferTypeInfo{ (IRType*)globalInst, elementType }); } + else + { + if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst)) + elementType = structBuffer->getElementType(); + else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst)) + elementType = constBuffer->getElementType(); + } + if (as<IRTextureBufferType>(globalInst)) + continue; + if (!as<IRStructType>(elementType) && !as<IRMatrixType>(elementType) && + !as<IRArrayType>(elementType) && !as<IRBoolType>(elementType)) + continue; + bufferTypeInsts.add(BufferTypeInfo{(IRType*)globalInst, elementType}); + } - // Maintain a pending work list of all matrix addresses, and try to lower them out of existance - // after everything else has been lowered. - - List<MatrixAddrWorkItem> matrixAddrInsts; + // Maintain a pending work list of all matrix addresses, and try to lower them out of + // existance after everything else has been lowered. - for (auto bufferTypeInfo : bufferTypeInsts) - { - auto bufferType = bufferTypeInfo.bufferType; - auto elementType = bufferTypeInfo.elementType; - auto layoutRules = getTypeLayoutRuleForBuffer(target, bufferType); - auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, layoutRules); + List<MatrixAddrWorkItem> matrixAddrInsts; - // If the lowered type is the same as original type, no change is required. - if (!loweredBufferElementTypeInfo.convertLoweredToOriginal) - continue; + for (auto bufferTypeInfo : bufferTypeInsts) + { + auto bufferType = bufferTypeInfo.bufferType; + auto elementType = bufferTypeInfo.elementType; + auto layoutRules = getTypeLayoutRuleForBuffer(target, bufferType); + auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, layoutRules); - builder.setInsertBefore(bufferType); + // If the lowered type is the same as original type, no change is required. + if (!loweredBufferElementTypeInfo.convertLoweredToOriginal) + continue; - ShortList<IRInst*> typeOperands; - for (UInt i = 0; i < bufferType->getOperandCount(); i++) - typeOperands.add(bufferType->getOperand(i)); - typeOperands[0] = loweredBufferElementTypeInfo.loweredType; - auto loweredBufferType = builder.getType( - bufferType->getOp(), - (UInt)typeOperands.getCount(), - typeOperands.getArrayView().getBuffer()); + builder.setInsertBefore(bufferType); - // We treat a value of a buffer type as a pointer, and use a work list to translate - // all loads and stores through the pointer values that needs lowering. + ShortList<IRInst*> typeOperands; + for (UInt i = 0; i < bufferType->getOperandCount(); i++) + typeOperands.add(bufferType->getOperand(i)); + typeOperands[0] = loweredBufferElementTypeInfo.loweredType; + auto loweredBufferType = builder.getType( + bufferType->getOp(), + (UInt)typeOperands.getCount(), + typeOperands.getArrayView().getBuffer()); - List<IRInst*> ptrValsWorkList; - traverseUses(bufferType, [&](IRUse* use) - { - auto user = use->getUser(); - if (use != &user->typeUse) - return; - ptrValsWorkList.add(use->getUser()); - }); + // We treat a value of a buffer type as a pointer, and use a work list to translate + // all loads and stores through the pointer values that needs lowering. - // Translate the values to use new lowered buffer type instead. - for (Index i = 0; i < ptrValsWorkList.getCount(); i++) + List<IRInst*> ptrValsWorkList; + traverseUses( + bufferType, + [&](IRUse* use) { - auto ptrVal = ptrValsWorkList[i]; - auto oldPtrType = ptrVal->getFullType(); - auto originalElementType = oldPtrType->getOperand(0); + auto user = use->getUser(); + if (use != &user->typeUse) + return; + ptrValsWorkList.add(use->getUser()); + }); + + // Translate the values to use new lowered buffer type instead. + for (Index i = 0; i < ptrValsWorkList.getCount(); i++) + { + auto ptrVal = ptrValsWorkList[i]; + auto oldPtrType = ptrVal->getFullType(); + auto originalElementType = oldPtrType->getOperand(0); - // If we are accessing an unsized array element from a pointer, we need to compute - // the trailing ptr that points to the first element of the array. - // And then replace all getElementPtr(arrayPtr, index) with getOffsetPtr(trailingPtr, index). - if (auto fieldAddr = as<IRFieldAddress>(ptrVal)) + // If we are accessing an unsized array element from a pointer, we need to compute + // the trailing ptr that points to the first element of the array. + // And then replace all getElementPtr(arrayPtr, index) with + // getOffsetPtr(trailingPtr, index). + if (auto fieldAddr = as<IRFieldAddress>(ptrVal)) + { + if (auto ptrType = as<IRPtrType>(ptrVal->getDataType())) { - if (auto ptrType = as<IRPtrType>(ptrVal->getDataType())) + if (auto unsizedArrayType = as<IRUnsizedArrayType>(ptrType->getValueType())) { - if (auto unsizedArrayType = as<IRUnsizedArrayType>(ptrType->getValueType())) + builder.setInsertBefore(ptrVal); + auto newArrayPtrVal = fieldAddr->getBase(); + // Is base a pointer to an empty struct? If so, don't offset it. + // For example, if the user has written: + // ``` + // struct S {int arr[]}; + // uniform S* p; + // void test() { p->arr[1]; } + // ``` + // Then `S` will become an empty struct after we remove `arr[]`. + // And `p` will be come a `void*`. + // We don't want to offset `p` to `p+1` to get the starting address of + // the array in this case. + IRSizeAndAlignment parentStructSize = {}; + getNaturalSizeAndAlignment( + target->getOptionSet(), + tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()), + &parentStructSize); + if (parentStructSize.size != 0) { - builder.setInsertBefore(ptrVal); - auto newArrayPtrVal = fieldAddr->getBase(); - // Is base a pointer to an empty struct? If so, don't offset it. - // For example, if the user has written: - // ``` - // struct S {int arr[]}; - // uniform S* p; - // void test() { p->arr[1]; } - // ``` - // Then `S` will become an empty struct after we remove `arr[]`. - // And `p` will be come a `void*`. - // We don't want to offset `p` to `p+1` to get the starting address of the array in this case. - IRSizeAndAlignment parentStructSize = {}; - getNaturalSizeAndAlignment( - target->getOptionSet(), - tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()), - &parentStructSize); - if (parentStructSize.size != 0) - { - newArrayPtrVal = builder.emitGetOffsetPtr(fieldAddr->getBase(), builder.getIntValue(builder.getIntType(), 1)); - } - auto loweredInnerType = getLoweredTypeInfo(unsizedArrayType->getElementType(), layoutRules); - - IRSizeAndAlignment arrayElementSizeAlignment; - getSizeAndAlignment( - target->getOptionSet(), layoutRules, loweredInnerType.loweredType, &arrayElementSizeAlignment); - IRSizeAndAlignment baseSizeAlignment; - getSizeAndAlignment( - target->getOptionSet(), - layoutRules, - tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()), - &baseSizeAlignment); - - // Convert pointer to uint64 and adjust offset. - IRIntegerValue offset = baseSizeAlignment.size; - offset = align(offset, arrayElementSizeAlignment.alignment); - if (offset != 0) - { - auto rawPtr = builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal); - newArrayPtrVal = builder.emitAdd(rawPtr->getFullType(), rawPtr, - builder.getIntValue(builder.getUInt64Type(), offset)); - } - newArrayPtrVal = builder.emitBitCast( - builder.getPtrType(loweredInnerType.loweredType, - ptrType->getAddressSpace()), newArrayPtrVal); - traverseUses(ptrVal, [&](IRUse* use) - { - auto user = use->getUser(); - if (user->getOp() == kIROp_GetElementPtr) - { - builder.setInsertBefore(user); - auto newElementPtr = builder.emitGetOffsetPtr(newArrayPtrVal, user->getOperand(1)); - user->replaceUsesWith(newElementPtr); - user->removeAndDeallocate(); - ptrValsWorkList.add(newElementPtr); - } - else if (user->getOp() == kIROp_GetOffsetPtr) - { - } - else - { - SLANG_UNEXPECTED("unknown use of pointer to unsized array."); - } - }); - SLANG_ASSERT(!ptrVal->hasUses()); - ptrVal->removeAndDeallocate(); - continue; + newArrayPtrVal = builder.emitGetOffsetPtr( + fieldAddr->getBase(), + builder.getIntValue(builder.getIntType(), 1)); } - } - } + auto loweredInnerType = + getLoweredTypeInfo(unsizedArrayType->getElementType(), layoutRules); - auto loweredElementTypeInfo = getLoweredTypeInfo((IRType*)originalElementType, layoutRules); - if (!loweredElementTypeInfo.convertLoweredToOriginal) - continue; - - ptrVal->setFullType(getLoweredPtrLikeType(ptrVal->getFullType(), loweredElementTypeInfo.loweredType)); + IRSizeAndAlignment arrayElementSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + layoutRules, + loweredInnerType.loweredType, + &arrayElementSizeAlignment); + IRSizeAndAlignment baseSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + layoutRules, + tryGetPointedToType(&builder, fieldAddr->getBase()->getDataType()), + &baseSizeAlignment); - traverseUses(ptrVal, [&](IRUse* use) - { - auto user = use->getUser(); - if (as<IRDecoration>(user)) - return; - switch (user->getOp()) + // Convert pointer to uint64 and adjust offset. + IRIntegerValue offset = baseSizeAlignment.size; + offset = align(offset, arrayElementSizeAlignment.alignment); + if (offset != 0) { - case kIROp_Load: - case kIROp_StructuredBufferLoad: - case kIROp_StructuredBufferLoadStatus: - case kIROp_RWStructuredBufferLoad: - case kIROp_RWStructuredBufferLoadStatus: - case kIROp_StructuredBufferConsume: - { - IRCloneEnv cloneEnv = {}; - builder.setInsertBefore(user); - auto newLoad = cloneInst(&cloneEnv, &builder, user); - newLoad->setFullType(loweredElementTypeInfo.loweredType); - auto unpackedVal = builder.emitCallInst((IRType*)originalElementType, loweredElementTypeInfo.convertLoweredToOriginal, 1, &newLoad); - user->replaceUsesWith(unpackedVal); - user->removeAndDeallocate(); - break; - } - case kIROp_Store: - case kIROp_RWStructuredBufferStore: - case kIROp_StructuredBufferAppend: - { - // Use must be the dest operand of the store inst. - if (use != user->getOperands() + 0) - break; - IRCloneEnv cloneEnv = {}; - builder.setInsertBefore(user); - auto originalVal = getStoreVal(user); - auto packedVal = builder.emitCallInst(loweredElementTypeInfo.loweredType, loweredElementTypeInfo.convertOriginalToLowered, 1, &originalVal); - if (auto store = as<IRStore>(user)) - store->val.set(packedVal); - else if (auto sbStore = as<IRRWStructuredBufferStore>(user)) - sbStore->setOperand(2, packedVal); - else if (auto sbAppend = as<IRStructuredBufferAppend>(user)) - sbAppend->setOperand(1, packedVal); - else - SLANG_UNREACHABLE("unhandled store type"); - break; - } - case kIROp_GetElementPtr: - case kIROp_FieldAddress: + auto rawPtr = + builder.emitBitCast(builder.getUInt64Type(), newArrayPtrVal); + newArrayPtrVal = builder.emitAdd( + rawPtr->getFullType(), + rawPtr, + builder.getIntValue(builder.getUInt64Type(), offset)); + } + newArrayPtrVal = builder.emitBitCast( + builder.getPtrType( + loweredInnerType.loweredType, + ptrType->getAddressSpace()), + newArrayPtrVal); + traverseUses( + ptrVal, + [&](IRUse* use) { - // If original type is an array, the lowered type will be a struct. - // In that case, all existing address insts should be appended with a field extract. - if (as<IRArrayType>(originalElementType)) + auto user = use->getUser(); + if (user->getOp() == kIROp_GetElementPtr) { builder.setInsertBefore(user); - List<IRInst*> args; - for (UInt i = 0; i < user->getOperandCount(); i++) - args.add(user->getOperand(i)); - auto newArrayPtrVal = builder.emitFieldAddress( - builder.getPtrType(loweredElementTypeInfo.loweredInnerArrayType), - ptrVal, - loweredElementTypeInfo.loweredInnerStructKey); - builder.replaceOperand(use, newArrayPtrVal); - ptrValsWorkList.add(user); + auto newElementPtr = builder.emitGetOffsetPtr( + newArrayPtrVal, + user->getOperand(1)); + user->replaceUsesWith(newElementPtr); + user->removeAndDeallocate(); + ptrValsWorkList.add(newElementPtr); } - else if (as<IRMatrixType>(originalElementType)) + else if (user->getOp() == kIROp_GetOffsetPtr) { - // We are tring to get a pointer to a lowered matrix element. - // We process this insts at a later phase. - SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr); - matrixAddrInsts.add(MatrixAddrWorkItem{ user, layoutRules }); } else { - // If we getting a derived address from the pointer, we need to recursively - // lower the new address. We do so by pushing the address inst into the - // work list. - ptrValsWorkList.add(user); + SLANG_UNEXPECTED( + "unknown use of pointer to unsized array."); } - } - break; - case kIROp_RWStructuredBufferGetElementPtr: - case kIROp_GetOffsetPtr: - ptrValsWorkList.add(user); - break; - case kIROp_StructuredBufferGetDimensions: - break; - case kIROp_Call: - { - // If a structured buffer or pointer typed value is used directly as an argument, - // we don't need to do any marshalling here. - if (as<IRHLSLStructuredBufferTypeBase>(ptrVal->getDataType())) - break; - if (lowerBufferPointer && as<IRPtrType>(ptrVal->getDataType())) - break; - // If we are calling a function with an l-value pointer from buffer access, - // we need to materialize the object as a local variable, and pass the address - // of the local variable to the function. - builder.setInsertBefore(user); - auto newLoad = builder.emitLoad(loweredElementTypeInfo.loweredType, ptrVal); - auto unpackedVal = builder.emitCallInst((IRType*)originalElementType, loweredElementTypeInfo.convertLoweredToOriginal, 1, &newLoad); - auto var = builder.emitVar((IRType*)originalElementType); - builder.emitStore(var, unpackedVal); - use->set(var); - builder.setInsertAfter(user); - auto newVal = builder.emitLoad(var); - auto packedVal = builder.emitCallInst((IRType*)loweredElementTypeInfo.loweredType, loweredElementTypeInfo.convertOriginalToLowered, 1, &newVal); - builder.emitStore(ptrVal, packedVal); - } - break; - default: - break; - } - }); + }); + SLANG_ASSERT(!ptrVal->hasUses()); + ptrVal->removeAndDeallocate(); + continue; + } + } } - // Replace all remaining uses of bufferType to loweredBufferType, these uses are non-operational and should be - // directly replaceable, such as uses in `IRFuncType`. - bufferType->replaceUsesWith(loweredBufferType); - bufferType->removeAndDeallocate(); - } - - // Process all matrix address uses. - lowerMatrixAddresses(module, matrixAddrInsts); - } + auto loweredElementTypeInfo = + getLoweredTypeInfo((IRType*)originalElementType, layoutRules); + if (!loweredElementTypeInfo.convertLoweredToOriginal) + continue; - // Lower all getElementPtr insts of a lowered matrix out of existance. - void lowerMatrixAddresses(IRModule* module, List<MatrixAddrWorkItem>& matrixAddrInsts) - { - IRBuilder builder(module); - for (auto workItem : matrixAddrInsts) - { - auto majorAddr = workItem.matrixAddrInst; - auto layoutRules = workItem.layoutRules; + ptrVal->setFullType(getLoweredPtrLikeType( + ptrVal->getFullType(), + loweredElementTypeInfo.loweredType)); - int layoutRuleName = (int)layoutRules->ruleName; - auto majorGEP = as<IRGetElementPtr>(majorAddr); - SLANG_ASSERT(majorGEP); - auto loweredMatrixType = cast<IRPtrTypeBase>(majorGEP->getBase()->getFullType())->getValueType(); - auto matrixTypeInfo = mapLoweredTypeToInfo[layoutRuleName].tryGetValue(loweredMatrixType); - SLANG_ASSERT(matrixTypeInfo); - auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType); - auto rowCount = getIntVal(matrixType->getRowCount()); - traverseUses(majorAddr, [&](IRUse* use) + traverseUses( + ptrVal, + [&](IRUse* use) { auto user = use->getUser(); - builder.setInsertBefore(user); + if (as<IRDecoration>(user)) + return; switch (user->getOp()) { case kIROp_Load: + case kIROp_StructuredBufferLoad: + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferLoad: + case kIROp_RWStructuredBufferLoadStatus: + case kIROp_StructuredBufferConsume: { - IRInst* resultInst = nullptr; - auto dataPtr = builder.emitFieldAddress( - getLoweredPtrLikeType(majorAddr->getDataType(), matrixTypeInfo->loweredInnerArrayType), - majorGEP->getBase(), - matrixTypeInfo->loweredInnerStructKey); - if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) - { - List<IRInst*> args; - for (IRIntegerValue i = 0; i < rowCount; i++) - { - auto vector = builder.emitLoad(builder.emitElementAddress(dataPtr, i)); - auto element = builder.emitElementExtract(vector, majorGEP->getIndex()); - args.add(element); - } - resultInst = builder.emitMakeVector(builder.getVectorType(matrixType->getElementType(), (IRIntegerValue)args.getCount()), args); - } - else - { - auto element = builder.emitElementAddress(dataPtr, majorGEP->getIndex()); - resultInst = builder.emitLoad(element); - } - user->replaceUsesWith(resultInst); + IRCloneEnv cloneEnv = {}; + builder.setInsertBefore(user); + auto newLoad = cloneInst(&cloneEnv, &builder, user); + newLoad->setFullType(loweredElementTypeInfo.loweredType); + auto unpackedVal = builder.emitCallInst( + (IRType*)originalElementType, + loweredElementTypeInfo.convertLoweredToOriginal, + 1, + &newLoad); + user->replaceUsesWith(unpackedVal); user->removeAndDeallocate(); + break; } - break; case kIROp_Store: + case kIROp_RWStructuredBufferStore: + case kIROp_StructuredBufferAppend: { - auto storeInst = cast<IRStore>(user); - if (storeInst->getOperand(0) != majorAddr) + // Use must be the dest operand of the store inst. + if (use != user->getOperands() + 0) break; - auto dataPtr = builder.emitFieldAddress( - getLoweredPtrLikeType(majorAddr->getDataType(), matrixTypeInfo->loweredInnerArrayType), - majorGEP->getBase(), - matrixTypeInfo->loweredInnerStructKey); - if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) - { - for (IRIntegerValue i = 0; i < rowCount; i++) - { - auto vectorAddr = builder.emitElementAddress(dataPtr, i); - auto elementAddr = builder.emitElementAddress(vectorAddr, majorGEP->getIndex()); - builder.emitStore(elementAddr, builder.emitElementExtract(storeInst->getVal(), i)); - } - } + IRCloneEnv cloneEnv = {}; + builder.setInsertBefore(user); + auto originalVal = getStoreVal(user); + auto packedVal = builder.emitCallInst( + loweredElementTypeInfo.loweredType, + loweredElementTypeInfo.convertOriginalToLowered, + 1, + &originalVal); + if (auto store = as<IRStore>(user)) + store->val.set(packedVal); + else if (auto sbStore = as<IRRWStructuredBufferStore>(user)) + sbStore->setOperand(2, packedVal); + else if (auto sbAppend = as<IRStructuredBufferAppend>(user)) + sbAppend->setOperand(1, packedVal); else - { - auto rowAddr = builder.emitElementAddress(dataPtr, majorGEP->getIndex()); - builder.emitStore(rowAddr, storeInst->getVal()); - user->removeAndDeallocate(); - } + SLANG_UNREACHABLE("unhandled store type"); break; } case kIROp_GetElementPtr: + case kIROp_FieldAddress: { - auto gep2 = cast<IRGetElementPtr>(user); - auto rowIndex = majorGEP->getIndex(); - auto colIndex = gep2->getIndex(); - if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + // If original type is an array, the lowered type will be a struct. + // In that case, all existing address insts should be appended with + // a field extract. + if (as<IRArrayType>(originalElementType)) + { + builder.setInsertBefore(user); + List<IRInst*> args; + for (UInt i = 0; i < user->getOperandCount(); i++) + args.add(user->getOperand(i)); + auto newArrayPtrVal = builder.emitFieldAddress( + builder.getPtrType( + loweredElementTypeInfo.loweredInnerArrayType), + ptrVal, + loweredElementTypeInfo.loweredInnerStructKey); + builder.replaceOperand(use, newArrayPtrVal); + ptrValsWorkList.add(user); + } + else if (as<IRMatrixType>(originalElementType)) { - Swap(rowIndex, colIndex); + // We are tring to get a pointer to a lowered matrix element. + // We process this insts at a later phase. + SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr); + matrixAddrInsts.add(MatrixAddrWorkItem{user, layoutRules}); } - auto dataPtr = builder.emitFieldAddress( - getLoweredPtrLikeType(majorAddr->getDataType(), matrixTypeInfo->loweredInnerArrayType), - majorGEP->getBase(), - matrixTypeInfo->loweredInnerStructKey); - auto vectorAddr = builder.emitElementAddress(dataPtr, rowIndex); - auto elementAddr = builder.emitElementAddress(vectorAddr, colIndex); - gep2->replaceUsesWith(elementAddr); - gep2->removeAndDeallocate(); - break; + else + { + // If we getting a derived address from the pointer, we need to + // recursively lower the new address. We do so by pushing the + // address inst into the work list. + ptrValsWorkList.add(user); + } + } + break; + case kIROp_RWStructuredBufferGetElementPtr: + case kIROp_GetOffsetPtr: ptrValsWorkList.add(user); break; + case kIROp_StructuredBufferGetDimensions: break; + case kIROp_Call: + { + // If a structured buffer or pointer typed value is used directly as + // an argument, we don't need to do any marshalling here. + if (as<IRHLSLStructuredBufferTypeBase>(ptrVal->getDataType())) + break; + if (lowerBufferPointer && as<IRPtrType>(ptrVal->getDataType())) + break; + // If we are calling a function with an l-value pointer from buffer + // access, we need to materialize the object as a local variable, + // and pass the address of the local variable to the function. + builder.setInsertBefore(user); + auto newLoad = + builder.emitLoad(loweredElementTypeInfo.loweredType, ptrVal); + auto unpackedVal = builder.emitCallInst( + (IRType*)originalElementType, + loweredElementTypeInfo.convertLoweredToOriginal, + 1, + &newLoad); + auto var = builder.emitVar((IRType*)originalElementType); + builder.emitStore(var, unpackedVal); + use->set(var); + builder.setInsertAfter(user); + auto newVal = builder.emitLoad(var); + auto packedVal = builder.emitCallInst( + (IRType*)loweredElementTypeInfo.loweredType, + loweredElementTypeInfo.convertOriginalToLowered, + 1, + &newVal); + builder.emitStore(ptrVal, packedVal); } - default: - SLANG_UNREACHABLE("unhandled inst of a matrix address inst that needs storage lowering."); break; + default: break; } }); } + + // Replace all remaining uses of bufferType to loweredBufferType, these uses are + // non-operational and should be directly replaceable, such as uses in `IRFuncType`. + bufferType->replaceUsesWith(loweredBufferType); + bufferType->removeAndDeallocate(); } - }; - void lowerBufferElementTypeToStorageType(TargetProgram* target, IRModule* module, bool lowerBufferPointer) + // Process all matrix address uses. + lowerMatrixAddresses(module, matrixAddrInsts); + } + + // Lower all getElementPtr insts of a lowered matrix out of existance. + void lowerMatrixAddresses(IRModule* module, List<MatrixAddrWorkItem>& matrixAddrInsts) { - SlangMatrixLayoutMode defaultMatrixMode = (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode(); - if ((isCPUTarget(target->getTargetReq()) || isCUDATarget(target->getTargetReq()) || isMetalTarget(target->getTargetReq()))) - defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; - else if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) - defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; - LoweredElementTypeContext context(target, lowerBufferPointer, defaultMatrixMode); - context.processModule(module); + IRBuilder builder(module); + for (auto workItem : matrixAddrInsts) + { + auto majorAddr = workItem.matrixAddrInst; + auto layoutRules = workItem.layoutRules; + + int layoutRuleName = (int)layoutRules->ruleName; + auto majorGEP = as<IRGetElementPtr>(majorAddr); + SLANG_ASSERT(majorGEP); + auto loweredMatrixType = + cast<IRPtrTypeBase>(majorGEP->getBase()->getFullType())->getValueType(); + auto matrixTypeInfo = + mapLoweredTypeToInfo[layoutRuleName].tryGetValue(loweredMatrixType); + SLANG_ASSERT(matrixTypeInfo); + auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType); + auto rowCount = getIntVal(matrixType->getRowCount()); + traverseUses( + majorAddr, + [&](IRUse* use) + { + auto user = use->getUser(); + builder.setInsertBefore(user); + switch (user->getOp()) + { + case kIROp_Load: + { + IRInst* resultInst = nullptr; + auto dataPtr = builder.emitFieldAddress( + getLoweredPtrLikeType( + majorAddr->getDataType(), + matrixTypeInfo->loweredInnerArrayType), + majorGEP->getBase(), + matrixTypeInfo->loweredInnerStructKey); + if (getIntVal(matrixType->getLayout()) == + SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + List<IRInst*> args; + for (IRIntegerValue i = 0; i < rowCount; i++) + { + auto vector = + builder.emitLoad(builder.emitElementAddress(dataPtr, i)); + auto element = + builder.emitElementExtract(vector, majorGEP->getIndex()); + args.add(element); + } + resultInst = builder.emitMakeVector( + builder.getVectorType( + matrixType->getElementType(), + (IRIntegerValue)args.getCount()), + args); + } + else + { + auto element = + builder.emitElementAddress(dataPtr, majorGEP->getIndex()); + resultInst = builder.emitLoad(element); + } + user->replaceUsesWith(resultInst); + user->removeAndDeallocate(); + } + break; + case kIROp_Store: + { + auto storeInst = cast<IRStore>(user); + if (storeInst->getOperand(0) != majorAddr) + break; + auto dataPtr = builder.emitFieldAddress( + getLoweredPtrLikeType( + majorAddr->getDataType(), + matrixTypeInfo->loweredInnerArrayType), + majorGEP->getBase(), + matrixTypeInfo->loweredInnerStructKey); + if (getIntVal(matrixType->getLayout()) == + SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue i = 0; i < rowCount; i++) + { + auto vectorAddr = builder.emitElementAddress(dataPtr, i); + auto elementAddr = builder.emitElementAddress( + vectorAddr, + majorGEP->getIndex()); + builder.emitStore( + elementAddr, + builder.emitElementExtract(storeInst->getVal(), i)); + } + } + else + { + auto rowAddr = + builder.emitElementAddress(dataPtr, majorGEP->getIndex()); + builder.emitStore(rowAddr, storeInst->getVal()); + user->removeAndDeallocate(); + } + break; + } + case kIROp_GetElementPtr: + { + auto gep2 = cast<IRGetElementPtr>(user); + auto rowIndex = majorGEP->getIndex(); + auto colIndex = gep2->getIndex(); + if (getIntVal(matrixType->getLayout()) == + SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + Swap(rowIndex, colIndex); + } + auto dataPtr = builder.emitFieldAddress( + getLoweredPtrLikeType( + majorAddr->getDataType(), + matrixTypeInfo->loweredInnerArrayType), + majorGEP->getBase(), + matrixTypeInfo->loweredInnerStructKey); + auto vectorAddr = builder.emitElementAddress(dataPtr, rowIndex); + auto elementAddr = builder.emitElementAddress(vectorAddr, colIndex); + gep2->replaceUsesWith(elementAddr); + gep2->removeAndDeallocate(); + break; + } + default: + SLANG_UNREACHABLE("unhandled inst of a matrix address inst that needs " + "storage lowering."); + break; + } + }); + } } +}; + +void lowerBufferElementTypeToStorageType( + TargetProgram* target, + IRModule* module, + bool lowerBufferPointer) +{ + SlangMatrixLayoutMode defaultMatrixMode = + (SlangMatrixLayoutMode)target->getOptionSet().getMatrixLayoutMode(); + if ((isCPUTarget(target->getTargetReq()) || isCUDATarget(target->getTargetReq()) || + isMetalTarget(target->getTargetReq()))) + defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + else if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) + defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + LoweredElementTypeContext context(target, lowerBufferPointer, defaultMatrixMode); + context.processModule(module); +} - IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType) +IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType) +{ + if (target->getTargetReq()->getTarget() != CodeGenTarget::WGSL) { - if (target->getTargetReq()->getTarget() != CodeGenTarget::WGSL) - { - if (!isKhronosTarget(target->getTargetReq())) - return IRTypeLayoutRules::getNatural(); + if (!isKhronosTarget(target->getTargetReq())) + return IRTypeLayoutRules::getNatural(); - // If we are just emitting GLSL, we can just use the general layout rule. - if (!target->shouldEmitSPIRVDirectly()) - return IRTypeLayoutRules::getNatural(); + // If we are just emitting GLSL, we can just use the general layout rule. + if (!target->shouldEmitSPIRVDirectly()) + return IRTypeLayoutRules::getNatural(); - // If the user specified a scalar buffer layout, then just use that. - if (target->getOptionSet().shouldUseScalarLayout()) - return IRTypeLayoutRules::getNatural(); - } + // If the user specified a scalar buffer layout, then just use that. + if (target->getOptionSet().shouldUseScalarLayout()) + return IRTypeLayoutRules::getNatural(); + } - if (target->getOptionSet().shouldUseDXLayout()) + if (target->getOptionSet().shouldUseDXLayout()) + { + if (as<IRUniformParameterGroupType>(bufferType)) { - if (as<IRUniformParameterGroupType>(bufferType)) - { - return IRTypeLayoutRules::getConstantBuffer(); - } - else - return IRTypeLayoutRules::getNatural(); + return IRTypeLayoutRules::getConstantBuffer(); } + else + return IRTypeLayoutRules::getNatural(); + } - // The default behavior is to use std140 for constant buffers and std430 for other buffers. - switch (bufferType->getOp()) - { - case kIROp_HLSLStructuredBufferType: - case kIROp_HLSLRWStructuredBufferType: - case kIROp_HLSLAppendStructuredBufferType: - case kIROp_HLSLConsumeStructuredBufferType: - case kIROp_HLSLRasterizerOrderedStructuredBufferType: + // The default behavior is to use std140 for constant buffers and std430 for other buffers. + switch (bufferType->getOp()) + { + case kIROp_HLSLStructuredBufferType: + case kIROp_HLSLRWStructuredBufferType: + case kIROp_HLSLAppendStructuredBufferType: + case kIROp_HLSLConsumeStructuredBufferType: + case kIROp_HLSLRasterizerOrderedStructuredBufferType: { auto structBufferType = as<IRHLSLStructuredBufferTypeBase>(bufferType); auto layoutTypeOp = structBufferType->getDataLayout() - ? structBufferType->getDataLayout()->getOp() - : kIROp_DefaultBufferLayoutType; + ? structBufferType->getDataLayout()->getOp() + : kIROp_DefaultBufferLayoutType; switch (layoutTypeOp) { - case kIROp_DefaultBufferLayoutType: - return IRTypeLayoutRules::getStd430(); - case kIROp_Std140BufferLayoutType: - return IRTypeLayoutRules::getStd140(); - case kIROp_Std430BufferLayoutType: - return IRTypeLayoutRules::getStd430(); - case kIROp_ScalarBufferLayoutType: - return IRTypeLayoutRules::getNatural(); + case kIROp_DefaultBufferLayoutType: return IRTypeLayoutRules::getStd430(); + case kIROp_Std140BufferLayoutType: return IRTypeLayoutRules::getStd140(); + case kIROp_Std430BufferLayoutType: return IRTypeLayoutRules::getStd430(); + case kIROp_ScalarBufferLayoutType: return IRTypeLayoutRules::getNatural(); } return IRTypeLayoutRules::getStd430(); } - case kIROp_ConstantBufferType: - case kIROp_ParameterBlockType: - return IRTypeLayoutRules::getStd140(); - case kIROp_PtrType: - return IRTypeLayoutRules::getNatural(); - } - return IRTypeLayoutRules::getNatural(); + case kIROp_ConstantBufferType: + case kIROp_ParameterBlockType: return IRTypeLayoutRules::getStd140(); + case kIROp_PtrType: return IRTypeLayoutRules::getNatural(); } - + return IRTypeLayoutRules::getNatural(); } + +} // namespace Slang |
