diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-14 16:23:19 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-14 16:23:19 -0700 |
| commit | 661d6198bbb9857d3fdc6df477e0742ed0b0765c (patch) | |
| tree | 974a57cfa2e43624e91502e9e652a0cc78105b3a /source/slang/slang-ir-lower-buffer-element-type.cpp | |
| parent | 0403e0556b470f6b316153caea2dc6f5c314da5b (diff) | |
Support per field matrix layout (#3101)
* Support per field matrix layout
* Fix warnings.
* Fix.
* Fix tests.
* Fix spiv gen.
* Fix.
* More test fixes.
* Fix.
* Run only GPU tests on self-hosted servers.
* Remove -use-glsl-matrix-layout-modifier.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-buffer-element-type.cpp | 658 |
1 files changed, 658 insertions, 0 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp new file mode 100644 index 000000000..3ef94d415 --- /dev/null +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -0,0 +1,658 @@ +#include "slang-ir-lower-buffer-element-type.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir-clone.h" + +namespace Slang +{ + struct LoweredElementTypeContext + { + struct LoweredElementTypeInfo + { + IRType* originalType; + IRType* loweredType; + IRType* loweredInnerArrayType = nullptr; // For matrix/array types that are lowered into a struct type, this is the inner array type of the data field. + IRStructKey* loweredInnerStructKey = nullptr; // For matrix/array types that are lowered into a struct type, this is the struct key of the data field. + IRFunc* convertOriginalToLowered = nullptr; + IRFunc* convertLoweredToOriginal = nullptr; + }; + Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo; + Dictionary<IRType*, LoweredElementTypeInfo> mapLoweredTypeToInfo; + + SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + + LoweredElementTypeContext(SlangMatrixLayoutMode inDefaultMatrixLayout) + : defaultMatrixLayout(inDefaultMatrixLayout) + {} + + IRFunc* createMatrixUnpackFunc( + IRMatrixType* matrixType, + IRStructType* structType, + IRStructKey* dataKey, + IRArrayType* arrayType) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&structType, matrixType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.setInsertInto(func); + builder.emitBlock(); + auto rowCount = (Index)getIntVal(matrixType->getRowCount()); + auto colCount = (Index)getIntVal(matrixType->getColumnCount()); + auto packedParam = builder.emitParam(structType); + auto vectorArray = builder.emitFieldExtract(arrayType, packedParam, dataKey); + List<IRInst*> args; + args.setCount(rowCount * colCount); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto vector = builder.emitElementExtract(vectorArray, c); + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto element = builder.emitElementExtract(vector, r); + args[(Index)(r*colCount + c)] = element; + } + } + } + else + { + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto vector = builder.emitElementExtract(vectorArray, r); + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = builder.emitElementExtract(vector, c); + args[(Index)(r * colCount + c)] = element; + } + } + } + IRInst* result = builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer()); + builder.emitReturn(result); + return func; + } + + IRFunc* createMatrixPackFunc( + IRMatrixType* matrixType, + IRStructType* structType, + IRVectorType* vectorType, + IRArrayType* arrayType) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&matrixType, structType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix")); + builder.setInsertInto(func); + builder.emitBlock(); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); + auto originalParam = builder.emitParam(matrixType); + List<IRInst*> elements; + elements.setCount((Index)(rowCount * colCount)); + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto vector = builder.emitElementExtract(originalParam, r); + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = builder.emitElementExtract(vector, c); + elements[(Index)(r * colCount + c)] = element; + } + } + List<IRInst*> vectors; + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue c = 0; c < colCount; c++) + { + List<IRInst*> vecArgs; + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto element = elements[(Index)(r * colCount + c)]; + vecArgs.add(element); + } + auto colVector = builder.emitMakeVector(vectorType, (UInt)vecArgs.getCount(), vecArgs.getBuffer()); + vectors.add(colVector); + } + } + else + { + for (IRIntegerValue r = 0; r < rowCount; r++) + { + List<IRInst*> vecArgs; + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = elements[(Index)(r * colCount + c)]; + vecArgs.add(element); + } + auto rowVector = builder.emitMakeVector(vectorType, (UInt)vecArgs.getCount(), vecArgs.getBuffer()); + vectors.add(rowVector); + } + } + + auto vectorArray = builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer()); + auto result = builder.emitMakeStruct(structType, 1, &vectorArray); + builder.emitReturn(result); + return func; + } + + IRFunc* createArrayUnpackFunc( + IRArrayType* arrayType, + IRStructType* structType, + IRStructKey* dataKey, + IRArrayType* innerArrayType, + LoweredElementTypeInfo innerTypeInfo) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&structType, arrayType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.setInsertInto(func); + builder.emitBlock(); + auto packedParam = builder.emitParam(structType); + auto packedArray = builder.emitFieldExtract(innerArrayType, packedParam, dataKey); + auto count = getIntVal(arrayType->getElementCount()); + List<IRInst*> args; + args.setCount((Index)count); + for (IRIntegerValue ii = 0; ii < count; ++ii) + { + auto packedElement = builder.emitElementExtract(packedArray, ii); + auto originalElement = builder.emitCallInst(innerTypeInfo.originalType, innerTypeInfo.convertLoweredToOriginal, 1, &packedElement); + args[(Index)ii] = originalElement; + } + auto result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); + builder.emitReturn(result); + return func; + } + + IRFunc* createArrayPackFunc( + IRArrayType* arrayType, + IRStructType* structType, + IRArrayType* innerArrayType, + LoweredElementTypeInfo innerTypeInfo) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&structType, arrayType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("packStorage")); + builder.setInsertInto(func); + builder.emitBlock(); + auto originalParam = builder.emitParam(arrayType); + auto count = getIntVal(arrayType->getElementCount()); + List<IRInst*> args; + args.setCount((Index)count); + for (IRIntegerValue ii = 0; ii < count; ++ii) + { + auto originalElement = builder.emitElementExtract(originalParam, ii); + auto packedElement = builder.emitCallInst(innerTypeInfo.loweredType, innerTypeInfo.convertOriginalToLowered, 1, &originalElement); + args[(Index)ii] = packedElement; + } + auto packedArray = builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer()); + auto result = builder.emitMakeStruct(structType, 1, &packedArray); + builder.emitReturn(result); + return func; + } + + LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type) + { + IRBuilder builder(type); + builder.setInsertAfter(type); + + LoweredElementTypeInfo info; + info.originalType = type; + + if (auto matrixType = as<IRMatrixType>(type)) + { + if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout) + { + info.loweredType = type; + return info; + } + + auto loweredType = builder.createStructType(); + StringBuilder nameSB; + bool isColMajor = getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; + nameSB << "_MatrixStorage_"; + getTypeNameHint(nameSB, matrixType->getElementType()); + nameSB << getIntVal(matrixType->getRowCount()) << "x" << getIntVal(matrixType->getColumnCount()); + if (isColMajor) + nameSB << "_ColMajor"; + builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + auto structKey = builder.createStructKey(); + builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); + auto vectorType = builder.getVectorType(matrixType->getElementType(), + isColMajor?matrixType->getRowCount():matrixType->getColumnCount()); + auto arrayType = builder.getArrayType(vectorType, isColMajor?matrixType->getColumnCount():matrixType->getRowCount()); + builder.createStructField(loweredType, structKey, arrayType); + + info.loweredType = loweredType; + info.loweredInnerArrayType = arrayType; + info.loweredInnerStructKey = structKey; + info.convertLoweredToOriginal = createMatrixUnpackFunc(matrixType, loweredType, structKey, arrayType); + info.convertOriginalToLowered = createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType); + return info; + } + else if (auto arrayType = as<IRArrayType>(type)) + { + auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType()); + + if (loweredInnerTypeInfo.loweredType != loweredInnerTypeInfo.originalType) + { + auto loweredType = builder.createStructType(); + info.loweredType = loweredType; + StringBuilder nameSB; + nameSB << "_ArrayStorage_"; + getTypeNameHint(nameSB, arrayType->getElementType()); + nameSB << getIntVal(arrayType->getElementCount()); + builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + auto structKey = builder.createStructKey(); + builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); + auto innerArrayType = builder.getArrayType(loweredInnerTypeInfo.loweredType, arrayType->getElementCount()); + builder.createStructField(loweredType, structKey, innerArrayType); + info.loweredInnerArrayType = innerArrayType; + info.loweredInnerStructKey = structKey; + info.convertLoweredToOriginal = createArrayUnpackFunc(arrayType, loweredType, structKey, innerArrayType, loweredInnerTypeInfo); + info.convertOriginalToLowered = createArrayPackFunc(arrayType, loweredType, innerArrayType, loweredInnerTypeInfo); + } + else + { + info.loweredType = type; + } + return info; + } + else if (auto structType = as<IRStructType>(type)) + { + bool hasNonTrivialField = false; + List<LoweredElementTypeInfo> fieldLoweredTypeInfo; + for (auto field : structType->getFields()) + { + auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType()); + fieldLoweredTypeInfo.add(loweredFieldTypeInfo); + if (loweredFieldTypeInfo.loweredType != loweredFieldTypeInfo.originalType) + hasNonTrivialField = true; + } + + if (!hasNonTrivialField) + { + info.loweredType = type; + return info; + } + + auto loweredType = builder.createStructType(); + StringBuilder nameSB; + getTypeNameHint(nameSB, type); + nameSB << "_Storage"; + builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + info.loweredType = loweredType; + + // Create fields. + { + Index fieldId = 0; + for (auto field : structType->getFields()) + { + auto loweredFieldTypeInfo = fieldLoweredTypeInfo[fieldId]; + builder.createStructField(loweredType, field->getKey(), loweredFieldTypeInfo.loweredType); + fieldId++; + } + } + + // Create unpack func. + { + builder.setInsertAfter(loweredType); + info.convertLoweredToOriginal = builder.createFunc(); + builder.setInsertInto(info.convertLoweredToOriginal); + builder.addNameHintDecoration(info.convertLoweredToOriginal, UnownedStringSlice("unpackStorage")); + info.convertLoweredToOriginal->setFullType(builder.getFuncType(1, (IRType**)&loweredType, type)); + builder.emitBlock(); + auto loweredParam = builder.emitParam(loweredType); + List<IRInst*> args; + Index fieldId = 0; + for (auto field : structType->getFields()) + { + auto storageField = builder.emitFieldExtract(fieldLoweredTypeInfo[fieldId].loweredType, loweredParam, field->getKey()); + auto unpackedField = fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal + ? builder.emitCallInst(field->getFieldType(), fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal, 1, &storageField) + : storageField; + args.add(unpackedField); + fieldId++; + } + auto result = builder.emitMakeStruct(type, args); + builder.emitReturn(result); + } + + // Create pack func. + { + builder.setInsertAfter(info.convertLoweredToOriginal); + info.convertOriginalToLowered = builder.createFunc(); + builder.setInsertInto(info.convertOriginalToLowered); + builder.addNameHintDecoration(info.convertOriginalToLowered, UnownedStringSlice("packStorage")); + info.convertOriginalToLowered->setFullType(builder.getFuncType(1, (IRType**)&type, loweredType)); + builder.emitBlock(); + auto param = builder.emitParam(type); + List<IRInst*> args; + Index fieldId = 0; + for (auto field : structType->getFields()) + { + auto fieldVal = builder.emitFieldExtract(type, param, field->getKey()); + auto packedField = fieldLoweredTypeInfo[fieldId].convertOriginalToLowered + ? builder.emitCallInst(fieldLoweredTypeInfo[fieldId].loweredType, fieldLoweredTypeInfo[fieldId].convertOriginalToLowered, 1, &fieldVal) + : fieldVal; + args.add(packedField); + fieldId++; + } + auto result = builder.emitMakeStruct(loweredType, args); + builder.emitReturn(result); + } + + return info; + } + + info.loweredType = type; + return info; + } + + LoweredElementTypeInfo getLoweredTypeInfo(IRType* type) + { + LoweredElementTypeInfo info; + if (loweredTypeInfo.tryGetValue(type, info)) + return info; + info = getLoweredTypeInfoImpl(type); + loweredTypeInfo[type] = info; + mapLoweredTypeToInfo[info.loweredType] = info; + return info; + } + + IRType* getLoweredPtrLikeType(IRType* originalPtrLikeType, IRType* newElementType) + { + if (as<IRPointerLikeType>(originalPtrLikeType) || as<IRPtrTypeBase>(originalPtrLikeType) || as<IRHLSLStructuredBufferTypeBase>(originalPtrLikeType)) + { + IRBuilder builder(newElementType); + builder.setInsertAfter(newElementType); + return builder.getType(originalPtrLikeType->getOp(), newElementType); + } + SLANG_UNREACHABLE("unhandled ptr like or buffer type"); + } + + IRInst* getStoreVal(IRInst* storeInst) + { + if (auto store = as<IRStore>(storeInst)) + return store->getVal(); + else if (auto sbStore = as<IRRWStructuredBufferStore>(storeInst)) + return sbStore->getVal(); + return nullptr; + } + + void processModule(IRModule* module) + { + IRBuilder builder(module); + struct BufferTypeInfo + { + IRType* bufferType; + IRType* elementType; + }; + List<BufferTypeInfo> bufferTypeInsts; + for (auto globalInst : module->getGlobalInsts()) + { + IRType* elementType = nullptr; + if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst)) + elementType = structBuffer->getElementType(); + else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst)) + elementType = constBuffer->getElementType(); + if (as<IRTextureBufferType>(globalInst)) + continue; + if (!as<IRStructType>(elementType) && !as<IRMatrixType>(elementType) && !as<IRArrayType>(elementType)) + continue; + bufferTypeInsts.add(BufferTypeInfo{ (IRType*)globalInst, elementType }); + } + + // Maintain a pending work list of all matrix addresses, and try to lower them out of existance + // after everything else has been lowered. + List<IRInst*> matrixAddrInsts; + + for (auto bufferTypeInfo : bufferTypeInsts) + { + auto bufferType = bufferTypeInfo.bufferType; + auto elementType = bufferTypeInfo.elementType; + auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType); + + // If the lowered type is the same as original type, no change is required. + if (!loweredBufferElementTypeInfo.convertLoweredToOriginal) + continue; + + builder.setInsertBefore(bufferType); + + auto loweredBufferType = builder.getType( + bufferType->getOp(), + loweredBufferElementTypeInfo.loweredType); + + // We treat a value of a buffer type as a pointer, and use a work list to translate + // all loads and stores through the pointer values that needs lowering. + + List<IRInst*> ptrValsWorkList; + traverseUses(bufferType, [&](IRUse* use) + { + auto user = use->getUser(); + if (use != &user->typeUse) + return; + ptrValsWorkList.add(use->getUser()); + }); + + // Translate the values to use new lowered buffer type instead. + for (Index i = 0; i < ptrValsWorkList.getCount(); i++) + { + auto ptrVal = ptrValsWorkList[i]; + auto oldPtrType = ptrVal->getFullType(); + auto originalElementType = oldPtrType->getOperand(0); + auto loweredElementTypeInfo = getLoweredTypeInfo((IRType*)originalElementType); + if (!loweredElementTypeInfo.convertLoweredToOriginal) + continue; + + ptrVal->setFullType(getLoweredPtrLikeType(ptrVal->getFullType(), loweredElementTypeInfo.loweredType)); + + traverseUses(ptrVal, [&](IRUse* use) + { + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_Load: + case kIROp_StructuredBufferLoad: + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferLoad: + case kIROp_RWStructuredBufferLoadStatus: + { + IRCloneEnv cloneEnv = {}; + builder.setInsertBefore(user); + auto newLoad = cloneInst(&cloneEnv, &builder, user); + newLoad->setFullType(loweredElementTypeInfo.loweredType); + auto unpackedVal = builder.emitCallInst(elementType, loweredElementTypeInfo.convertLoweredToOriginal, 1, &newLoad); + user->replaceUsesWith(unpackedVal); + user->removeAndDeallocate(); + break; + } + case kIROp_Store: + case kIROp_RWStructuredBufferStore: + { + // Use must be the dest operand of the store inst. + if (use != user->getOperands() + 0) + break; + IRCloneEnv cloneEnv = {}; + builder.setInsertBefore(user); + auto originalVal = getStoreVal(user); + auto packedVal = builder.emitCallInst(loweredElementTypeInfo.loweredType, loweredElementTypeInfo.convertOriginalToLowered, 1, &originalVal); + if (auto store = as<IRStore>(user)) + store->val.set(packedVal); + else if (auto sbStore = as<IRRWStructuredBufferStore>(user)) + sbStore->setOperand(2, packedVal); + else + SLANG_UNREACHABLE("unhandled store type"); + break; + } + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + { + // If original type is an array, the lowered type will be a struct. + // In that case, all existing address insts should be appended with a field extract. + if (as<IRArrayType>(originalElementType)) + { + builder.setInsertBefore(user); + List<IRInst*> args; + for (UInt i = 0; i < user->getOperandCount(); i++) + args.add(user->getOperand(i)); + auto newArrayPtrVal = builder.emitFieldAddress( + builder.getPtrType(loweredElementTypeInfo.loweredInnerArrayType), + ptrVal, + loweredElementTypeInfo.loweredInnerStructKey); + builder.replaceOperand(use, newArrayPtrVal); + ptrValsWorkList.add(user); + } + else if (as<IRMatrixType>(originalElementType)) + { + // We are tring to get a pointer to a lowered matrix element. + // We process this insts at a later phase. + SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr); + matrixAddrInsts.add(user); + } + else + { + // If we getting a derived address from the pointer, we need to recursively + // lower the new address. We do so by pushing the address inst into the + // work list. + ptrValsWorkList.add(user); + } + } + break; + case kIROp_RWStructuredBufferGetElementPtr: + ptrValsWorkList.add(user); + break; + default: + SLANG_UNREACHABLE("unhandled inst of a buffer/pointer value that needs storage lowering."); + break; + } + }); + } + + // Replace all remaining uses of bufferType to loweredBufferType, these uses are non-operational and should be + // directly replaceable, such as uses in `IRFuncType`. + bufferType->replaceUsesWith(loweredBufferType); + bufferType->removeAndDeallocate(); + } + + lowerMatrixAddresses(module, matrixAddrInsts); + } + + // Lower all getElementPtr insts of a lowered matrix out of existance. + void lowerMatrixAddresses(IRModule* module, List<IRInst*>& matrixAddrInsts) + { + IRBuilder builder(module); + for (auto majorAddr : matrixAddrInsts) + { + auto majorGEP = as<IRGetElementPtr>(majorAddr); + SLANG_ASSERT(majorGEP); + auto loweredMatrixType = cast<IRPtrTypeBase>(majorGEP->getBase()->getFullType())->getValueType(); + auto matrixTypeInfo = mapLoweredTypeToInfo.tryGetValue(loweredMatrixType); + SLANG_ASSERT(matrixTypeInfo); + auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType); + auto rowCount = getIntVal(matrixType->getRowCount()); + traverseUses(majorAddr, [&](IRUse* use) + { + auto user = use->getUser(); + builder.setInsertBefore(user); + switch (user->getOp()) + { + case kIROp_Load: + { + IRInst* resultInst = nullptr; + auto dataPtr = builder.emitFieldAddress( + builder.getPtrType(matrixTypeInfo->loweredInnerArrayType), + majorGEP->getBase(), + matrixTypeInfo->loweredInnerStructKey); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + List<IRInst*> args; + for (IRIntegerValue i = 0; i < rowCount; i++) + { + auto vector = builder.emitLoad(builder.emitElementAddress(dataPtr, i)); + auto element = builder.emitElementExtract(vector, majorGEP->getIndex()); + args.add(element); + } + resultInst = builder.emitMakeVector(builder.getVectorType(matrixType->getElementType(), (IRIntegerValue)args.getCount()), args); + } + else + { + auto element = builder.emitElementAddress(dataPtr, majorGEP->getIndex()); + resultInst = builder.emitLoad(element); + } + user->replaceUsesWith(resultInst); + user->removeAndDeallocate(); + } + break; + case kIROp_Store: + { + auto storeInst = cast<IRStore>(user); + if (storeInst->getOperand(0) != majorAddr) + break; + auto dataPtr = builder.emitFieldAddress( + builder.getPtrType(matrixTypeInfo->loweredInnerArrayType), + majorGEP->getBase(), + matrixTypeInfo->loweredInnerStructKey); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue i = 0; i < rowCount; i++) + { + auto vectorAddr = builder.emitElementAddress(dataPtr, i); + auto elementAddr = builder.emitElementAddress(vectorAddr, majorGEP->getIndex()); + builder.emitStore(elementAddr, builder.emitElementExtract(storeInst->getVal(), i)); + } + } + else + { + auto rowAddr = builder.emitElementAddress(dataPtr, majorGEP->getIndex()); + builder.emitStore(rowAddr, storeInst->getVal()); + user->removeAndDeallocate(); + } + break; + } + case kIROp_GetElementPtr: + { + auto gep2 = cast<IRGetElementPtr>(user); + auto rowIndex = majorGEP->getIndex(); + auto colIndex = gep2->getIndex(); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + Swap(rowIndex, colIndex); + } + auto dataPtr = builder.emitFieldAddress( + builder.getPtrType(matrixTypeInfo->loweredInnerArrayType), + majorGEP->getBase(), + matrixTypeInfo->loweredInnerStructKey); + auto vectorAddr = builder.emitElementAddress(dataPtr, rowIndex); + auto elementAddr = builder.emitElementAddress(vectorAddr, colIndex); + gep2->replaceUsesWith(elementAddr); + gep2->removeAndDeallocate(); + break; + } + default: + SLANG_UNREACHABLE("unhandled inst of a matrix address inst that needs storage lowering."); + break; + } + }); + } + } + }; + + void lowerBufferElementTypeToStorageType(TargetRequest* target, IRModule* module) + { + SlangMatrixLayoutMode defaultMatrixMode = (SlangMatrixLayoutMode)target->getDefaultMatrixLayoutMode(); + if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) + defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + LoweredElementTypeContext context(defaultMatrixMode); + context.processModule(module); + } +} |
