summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-lower-buffer-element-type.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-14 16:23:19 -0700
committerGitHub <noreply@github.com>2023-08-14 16:23:19 -0700
commit661d6198bbb9857d3fdc6df477e0742ed0b0765c (patch)
tree974a57cfa2e43624e91502e9e652a0cc78105b3a /source/slang/slang-ir-lower-buffer-element-type.cpp
parent0403e0556b470f6b316153caea2dc6f5c314da5b (diff)
Support per field matrix layout (#3101)
* Support per field matrix layout * Fix warnings. * Fix. * Fix tests. * Fix spiv gen. * Fix. * More test fixes. * Fix. * Run only GPU tests on self-hosted servers. * Remove -use-glsl-matrix-layout-modifier. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-lower-buffer-element-type.cpp')
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp658
1 files changed, 658 insertions, 0 deletions
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
new file mode 100644
index 000000000..3ef94d415
--- /dev/null
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -0,0 +1,658 @@
+#include "slang-ir-lower-buffer-element-type.h"
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir-clone.h"
+
+namespace Slang
+{
+ struct LoweredElementTypeContext
+ {
+ // Describes how one IR type is represented when it is stored in a buffer.
+ // When `loweredType == originalType` the type is stored as-is and both
+ // conversion functions stay null.
+ struct LoweredElementTypeInfo
+ {
+     // The type as written in the IR before lowering.
+     IRType* originalType = nullptr;
+     // The type actually used for storage (may be identical to `originalType`).
+     IRType* loweredType = nullptr;
+     // For matrix/array types that are lowered into a struct type, this is the
+     // inner array type of the data field.
+     IRType* loweredInnerArrayType = nullptr;
+     // For matrix/array types that are lowered into a struct type, this is the
+     // struct key of the data field.
+     IRStructKey* loweredInnerStructKey = nullptr;
+     // Function converting a value of `originalType` into `loweredType`.
+     IRFunc* convertOriginalToLowered = nullptr;
+     // Function converting a value of `loweredType` back into `originalType`.
+     IRFunc* convertLoweredToOriginal = nullptr;
+ };
+ // Cache of lowering results, keyed by the original type.
+ Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo;
+ // Reverse lookup of the same records, keyed by the lowered type.
+ Dictionary<IRType*, LoweredElementTypeInfo> mapLoweredTypeToInfo;
+
+ // Matrices whose layout equals this mode are stored without translation.
+ SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+
+ // Creates a context that treats `inDefaultMatrixLayout` as the native layout.
+ LoweredElementTypeContext(SlangMatrixLayoutMode inDefaultMatrixLayout)
+ {
+     defaultMatrixLayout = inDefaultMatrixLayout;
+ }
+
+ // Builds a function that converts the storage struct of a matrix back into a
+ // value of `matrixType`.
+ //   matrixType - the original matrix type; this is the function's result type.
+ //   structType - the storage struct; this is the function's only parameter type.
+ //   dataKey    - key of the struct field holding the array of vectors.
+ //   arrayType  - type of that field.
+ // The new function is inserted right after `structType` and returned.
+ IRFunc* createMatrixUnpackFunc(
+     IRMatrixType* matrixType,
+     IRStructType* structType,
+     IRStructKey* dataKey,
+     IRArrayType* arrayType)
+ {
+     IRBuilder builder(structType);
+     builder.setInsertAfter(structType);
+     auto unpackFunc = builder.createFunc();
+     unpackFunc->setFullType(builder.getFuncType(1, (IRType**)&structType, matrixType));
+     builder.addNameHintDecoration(unpackFunc, UnownedStringSlice("unpackStorage"));
+     builder.setInsertInto(unpackFunc);
+     builder.emitBlock();
+
+     const Index rowCount = (Index)getIntVal(matrixType->getRowCount());
+     const Index colCount = (Index)getIntVal(matrixType->getColumnCount());
+     const bool isColMajor = getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
+     // The storage array is walked one vector at a time: one vector per column
+     // for column-major layout, one per row otherwise.
+     const Index vectorCount = isColMajor ? colCount : rowCount;
+     const Index vectorLength = isColMajor ? rowCount : colCount;
+
+     auto storageParam = builder.emitParam(structType);
+     auto storedVectors = builder.emitFieldExtract(arrayType, storageParam, dataKey);
+
+     // The operand list of the matrix constructor is filled in row-major order
+     // (index = row * colCount + col) regardless of the storage layout.
+     List<IRInst*> matrixElements;
+     matrixElements.setCount(rowCount * colCount);
+     for (IRIntegerValue v = 0; v < vectorCount; v++)
+     {
+         auto storedVector = builder.emitElementExtract(storedVectors, v);
+         for (IRIntegerValue e = 0; e < vectorLength; e++)
+         {
+             auto scalar = builder.emitElementExtract(storedVector, e);
+             const IRIntegerValue row = isColMajor ? e : v;
+             const IRIntegerValue col = isColMajor ? v : e;
+             matrixElements[(Index)(row * colCount + col)] = scalar;
+         }
+     }
+
+     IRInst* matrixValue = builder.emitMakeMatrix(
+         matrixType, (UInt)matrixElements.getCount(), matrixElements.getBuffer());
+     builder.emitReturn(matrixValue);
+     return unpackFunc;
+ }
+
+ // Builds a function that converts a value of `matrixType` into its storage
+ // struct.
+ //   matrixType - the original matrix type; this is the function's parameter type.
+ //   structType - the storage struct; this is the function's result type.
+ //   vectorType - element type of the storage array.
+ //   arrayType  - type of the single data field of `structType`.
+ // The new function is inserted right after `structType` and returned.
+ IRFunc* createMatrixPackFunc(
+     IRMatrixType* matrixType,
+     IRStructType* structType,
+     IRVectorType* vectorType,
+     IRArrayType* arrayType)
+ {
+     IRBuilder builder(structType);
+     builder.setInsertAfter(structType);
+     auto packFunc = builder.createFunc();
+     packFunc->setFullType(builder.getFuncType(1, (IRType**)&matrixType, structType));
+     builder.addNameHintDecoration(packFunc, UnownedStringSlice("packMatrix"));
+     builder.setInsertInto(packFunc);
+     builder.emitBlock();
+
+     const auto rowCount = getIntVal(matrixType->getRowCount());
+     const auto colCount = getIntVal(matrixType->getColumnCount());
+     auto matrixParam = builder.emitParam(matrixType);
+
+     // Step 1: pull every scalar out of the matrix, indexed row * colCount + col.
+     List<IRInst*> scalars;
+     scalars.setCount((Index)(rowCount * colCount));
+     for (IRIntegerValue row = 0; row < rowCount; row++)
+     {
+         auto rowValue = builder.emitElementExtract(matrixParam, row);
+         for (IRIntegerValue col = 0; col < colCount; col++)
+         {
+             scalars[(Index)(row * colCount + col)] = builder.emitElementExtract(rowValue, col);
+         }
+     }
+
+     // Step 2: regroup the scalars into storage vectors, one per column for
+     // column-major layout and one per row otherwise.
+     const bool isColMajor = getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
+     const IRIntegerValue vectorCount = isColMajor ? colCount : rowCount;
+     const IRIntegerValue vectorLength = isColMajor ? rowCount : colCount;
+     List<IRInst*> packedVectors;
+     for (IRIntegerValue v = 0; v < vectorCount; v++)
+     {
+         List<IRInst*> lanes;
+         for (IRIntegerValue e = 0; e < vectorLength; e++)
+         {
+             const IRIntegerValue row = isColMajor ? e : v;
+             const IRIntegerValue col = isColMajor ? v : e;
+             lanes.add(scalars[(Index)(row * colCount + col)]);
+         }
+         auto packedVector = builder.emitMakeVector(vectorType, (UInt)lanes.getCount(), lanes.getBuffer());
+         packedVectors.add(packedVector);
+     }
+
+     auto dataArray = builder.emitMakeArray(arrayType, (UInt)packedVectors.getCount(), packedVectors.getBuffer());
+     auto storageValue = builder.emitMakeStruct(structType, 1, &dataArray);
+     builder.emitReturn(storageValue);
+     return packFunc;
+ }
+
+ // Builds a function that converts the storage struct of an array back into a
+ // value of `arrayType`, unpacking each element through
+ // `innerTypeInfo.convertLoweredToOriginal`.
+ //   arrayType      - the original array type; this is the function's result type.
+ //   structType     - the storage struct; this is the function's parameter type.
+ //   dataKey        - key of the struct field holding the lowered array.
+ //   innerArrayType - type of that field.
+ //   innerTypeInfo  - lowering record of the array's element type.
+ IRFunc* createArrayUnpackFunc(
+     IRArrayType* arrayType,
+     IRStructType* structType,
+     IRStructKey* dataKey,
+     IRArrayType* innerArrayType,
+     LoweredElementTypeInfo innerTypeInfo)
+ {
+     IRBuilder builder(structType);
+     builder.setInsertAfter(structType);
+     auto unpackFunc = builder.createFunc();
+     unpackFunc->setFullType(builder.getFuncType(1, (IRType**)&structType, arrayType));
+     builder.addNameHintDecoration(unpackFunc, UnownedStringSlice("unpackStorage"));
+     builder.setInsertInto(unpackFunc);
+     builder.emitBlock();
+
+     auto storageParam = builder.emitParam(structType);
+     auto storedArray = builder.emitFieldExtract(innerArrayType, storageParam, dataKey);
+
+     const auto elementCount = getIntVal(arrayType->getElementCount());
+     List<IRInst*> unpackedElements;
+     unpackedElements.setCount((Index)elementCount);
+     IRIntegerValue index = 0;
+     while (index < elementCount)
+     {
+         auto storedElement = builder.emitElementExtract(storedArray, index);
+         unpackedElements[(Index)index] = builder.emitCallInst(
+             innerTypeInfo.originalType,
+             innerTypeInfo.convertLoweredToOriginal,
+             1,
+             &storedElement);
+         ++index;
+     }
+
+     auto arrayValue = builder.emitMakeArray(
+         arrayType, (UInt)unpackedElements.getCount(), unpackedElements.getBuffer());
+     builder.emitReturn(arrayValue);
+     return unpackFunc;
+ }
+
+ // Builds a function that converts a value of `arrayType` into its storage
+ // struct, packing each element through `innerTypeInfo.convertOriginalToLowered`.
+ //   arrayType      - the original array type; this is the function's parameter type.
+ //   structType     - the storage struct; this is the function's result type.
+ //   innerArrayType - type of the struct's single data field.
+ //   innerTypeInfo  - lowering record of the array's element type.
+ // The new function is inserted right after `structType` and returned.
+ IRFunc* createArrayPackFunc(
+     IRArrayType* arrayType,
+     IRStructType* structType,
+     IRArrayType* innerArrayType,
+     LoweredElementTypeInfo innerTypeInfo)
+ {
+     IRBuilder builder(structType);
+     builder.setInsertAfter(structType);
+     auto func = builder.createFunc();
+     // The pack function maps original array -> storage struct. (The signature
+     // used to be declared the other way round, contradicting the parameter
+     // emitted and the value returned below.)
+     auto funcType = builder.getFuncType(1, (IRType**)&arrayType, structType);
+     func->setFullType(funcType);
+     builder.addNameHintDecoration(func, UnownedStringSlice("packStorage"));
+     builder.setInsertInto(func);
+     builder.emitBlock();
+     auto originalParam = builder.emitParam(arrayType);
+     auto count = getIntVal(arrayType->getElementCount());
+     List<IRInst*> args;
+     args.setCount((Index)count);
+     for (IRIntegerValue ii = 0; ii < count; ++ii)
+     {
+         auto originalElement = builder.emitElementExtract(originalParam, ii);
+         auto packedElement = builder.emitCallInst(innerTypeInfo.loweredType, innerTypeInfo.convertOriginalToLowered, 1, &originalElement);
+         args[(Index)ii] = packedElement;
+     }
+     auto packedArray = builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer());
+     auto result = builder.emitMakeStruct(structType, 1, &packedArray);
+     builder.emitReturn(result);
+     return func;
+ }
+
+ // Computes (without consulting the cache) how `type` is represented in buffer
+ // storage.
+ //  - A matrix whose layout differs from `defaultMatrixLayout` becomes a struct
+ //    wrapping an array of vectors.
+ //  - An array whose element type needs lowering becomes a struct wrapping an
+ //    array of the lowered element type.
+ //  - A struct with at least one field that needs lowering becomes a new struct
+ //    that reuses the same field keys with lowered field types.
+ //  - Everything else is returned unchanged (`loweredType == originalType`, no
+ //    conversion functions).
+ // New types and conversion functions are inserted after `type`.
+ LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type)
+ {
+     IRBuilder builder(type);
+     builder.setInsertAfter(type);
+
+     LoweredElementTypeInfo info;
+     info.originalType = type;
+
+     if (auto matrixType = as<IRMatrixType>(type))
+     {
+         // A matrix already in the default layout is stored as-is.
+         if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout)
+         {
+             info.loweredType = type;
+             return info;
+         }
+
+         auto loweredType = builder.createStructType();
+         StringBuilder nameSB;
+         bool isColMajor = getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
+         nameSB << "_MatrixStorage_";
+         getTypeNameHint(nameSB, matrixType->getElementType());
+         nameSB << getIntVal(matrixType->getRowCount()) << "x" << getIntVal(matrixType->getColumnCount());
+         if (isColMajor)
+             nameSB << "_ColMajor";
+         builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
+         auto structKey = builder.createStructKey();
+         builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
+         // Column-major: colCount vectors of rowCount elements each.
+         // Row-major:    rowCount vectors of colCount elements each.
+         auto vectorType = builder.getVectorType(matrixType->getElementType(),
+             isColMajor?matrixType->getRowCount():matrixType->getColumnCount());
+         auto arrayType = builder.getArrayType(vectorType, isColMajor?matrixType->getColumnCount():matrixType->getRowCount());
+         builder.createStructField(loweredType, structKey, arrayType);
+
+         info.loweredType = loweredType;
+         info.loweredInnerArrayType = arrayType;
+         info.loweredInnerStructKey = structKey;
+         info.convertLoweredToOriginal = createMatrixUnpackFunc(matrixType, loweredType, structKey, arrayType);
+         info.convertOriginalToLowered = createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType);
+         return info;
+     }
+     else if (auto arrayType = as<IRArrayType>(type))
+     {
+         auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType());
+
+         // Only wrap the array when its element type itself changes.
+         if (loweredInnerTypeInfo.loweredType != loweredInnerTypeInfo.originalType)
+         {
+             auto loweredType = builder.createStructType();
+             info.loweredType = loweredType;
+             StringBuilder nameSB;
+             nameSB << "_ArrayStorage_";
+             getTypeNameHint(nameSB, arrayType->getElementType());
+             nameSB << getIntVal(arrayType->getElementCount());
+             builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
+             auto structKey = builder.createStructKey();
+             builder.addNameHintDecoration(structKey, UnownedStringSlice("data"));
+             auto innerArrayType = builder.getArrayType(loweredInnerTypeInfo.loweredType, arrayType->getElementCount());
+             builder.createStructField(loweredType, structKey, innerArrayType);
+             info.loweredInnerArrayType = innerArrayType;
+             info.loweredInnerStructKey = structKey;
+             info.convertLoweredToOriginal = createArrayUnpackFunc(arrayType, loweredType, structKey, innerArrayType, loweredInnerTypeInfo);
+             info.convertOriginalToLowered = createArrayPackFunc(arrayType, loweredType, innerArrayType, loweredInnerTypeInfo);
+         }
+         else
+         {
+             info.loweredType = type;
+         }
+         return info;
+     }
+     else if (auto structType = as<IRStructType>(type))
+     {
+         // Lower every field first; the struct only needs a storage twin when
+         // at least one field changes.
+         bool hasNonTrivialField = false;
+         List<LoweredElementTypeInfo> fieldLoweredTypeInfo;
+         for (auto field : structType->getFields())
+         {
+             auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType());
+             fieldLoweredTypeInfo.add(loweredFieldTypeInfo);
+             if (loweredFieldTypeInfo.loweredType != loweredFieldTypeInfo.originalType)
+                 hasNonTrivialField = true;
+         }
+
+         if (!hasNonTrivialField)
+         {
+             info.loweredType = type;
+             return info;
+         }
+
+         auto loweredType = builder.createStructType();
+         StringBuilder nameSB;
+         getTypeNameHint(nameSB, type);
+         nameSB << "_Storage";
+         builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
+         info.loweredType = loweredType;
+
+         // Create fields. The storage struct reuses the original field keys.
+         {
+             Index fieldId = 0;
+             for (auto field : structType->getFields())
+             {
+                 auto loweredFieldTypeInfo = fieldLoweredTypeInfo[fieldId];
+                 builder.createStructField(loweredType, field->getKey(), loweredFieldTypeInfo.loweredType);
+                 fieldId++;
+             }
+         }
+
+         // Create unpack func (storage struct -> original struct).
+         {
+             builder.setInsertAfter(loweredType);
+             info.convertLoweredToOriginal = builder.createFunc();
+             builder.setInsertInto(info.convertLoweredToOriginal);
+             builder.addNameHintDecoration(info.convertLoweredToOriginal, UnownedStringSlice("unpackStorage"));
+             info.convertLoweredToOriginal->setFullType(builder.getFuncType(1, (IRType**)&loweredType, type));
+             builder.emitBlock();
+             auto loweredParam = builder.emitParam(loweredType);
+             List<IRInst*> args;
+             Index fieldId = 0;
+             for (auto field : structType->getFields())
+             {
+                 auto storageField = builder.emitFieldExtract(fieldLoweredTypeInfo[fieldId].loweredType, loweredParam, field->getKey());
+                 // Fields that were not lowered have no conversion function and
+                 // are forwarded unchanged.
+                 auto unpackedField = fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal
+                     ? builder.emitCallInst(field->getFieldType(), fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal, 1, &storageField)
+                     : storageField;
+                 args.add(unpackedField);
+                 fieldId++;
+             }
+             auto result = builder.emitMakeStruct(type, args);
+             builder.emitReturn(result);
+         }
+
+         // Create pack func (original struct -> storage struct).
+         {
+             builder.setInsertAfter(info.convertLoweredToOriginal);
+             info.convertOriginalToLowered = builder.createFunc();
+             builder.setInsertInto(info.convertOriginalToLowered);
+             builder.addNameHintDecoration(info.convertOriginalToLowered, UnownedStringSlice("packStorage"));
+             info.convertOriginalToLowered->setFullType(builder.getFuncType(1, (IRType**)&type, loweredType));
+             builder.emitBlock();
+             auto param = builder.emitParam(type);
+             List<IRInst*> args;
+             Index fieldId = 0;
+             for (auto field : structType->getFields())
+             {
+                 // The extracted value has the field's own type, not the type
+                 // of the enclosing struct.
+                 auto fieldVal = builder.emitFieldExtract(field->getFieldType(), param, field->getKey());
+                 auto packedField = fieldLoweredTypeInfo[fieldId].convertOriginalToLowered
+                     ? builder.emitCallInst(fieldLoweredTypeInfo[fieldId].loweredType, fieldLoweredTypeInfo[fieldId].convertOriginalToLowered, 1, &fieldVal)
+                     : fieldVal;
+                 args.add(packedField);
+                 fieldId++;
+             }
+             auto result = builder.emitMakeStruct(loweredType, args);
+             builder.emitReturn(result);
+         }
+
+         return info;
+     }
+
+     // Any other type is stored without translation.
+     info.loweredType = type;
+     return info;
+ }
+
+ // Cached front end of `getLoweredTypeInfoImpl`: returns the stored record for
+ // `type` if there is one, otherwise computes it and registers it in both
+ // `loweredTypeInfo` (by original type) and `mapLoweredTypeToInfo` (by lowered type).
+ LoweredElementTypeInfo getLoweredTypeInfo(IRType* type)
+ {
+     LoweredElementTypeInfo record;
+     if (!loweredTypeInfo.tryGetValue(type, record))
+     {
+         record = getLoweredTypeInfoImpl(type);
+         loweredTypeInfo[type] = record;
+         mapLoweredTypeToInfo[record.loweredType] = record;
+     }
+     return record;
+ }
+
+ // Returns a type with the same opcode as `originalPtrLikeType` but with
+ // `newElementType` as its operand. Only pointer-like, pointer and structured
+ // buffer types are accepted; anything else hits SLANG_UNREACHABLE.
+ IRType* getLoweredPtrLikeType(IRType* originalPtrLikeType, IRType* newElementType)
+ {
+     const bool isSupported =
+         as<IRPointerLikeType>(originalPtrLikeType)
+         || as<IRPtrTypeBase>(originalPtrLikeType)
+         || as<IRHLSLStructuredBufferTypeBase>(originalPtrLikeType);
+     if (!isSupported)
+     {
+         SLANG_UNREACHABLE("unhandled ptr like or buffer type");
+     }
+
+     IRBuilder typeBuilder(newElementType);
+     typeBuilder.setInsertAfter(newElementType);
+     return typeBuilder.getType(originalPtrLikeType->getOp(), newElementType);
+ }
+
+ // Returns the value operand of a plain store or an RW structured buffer
+ // store, or nullptr when `storeInst` is neither.
+ IRInst* getStoreVal(IRInst* storeInst)
+ {
+     auto plainStore = as<IRStore>(storeInst);
+     if (plainStore)
+     {
+         return plainStore->getVal();
+     }
+     auto bufferStore = as<IRRWStructuredBufferStore>(storeInst);
+     if (bufferStore)
+     {
+         return bufferStore->getVal();
+     }
+     return nullptr;
+ }
+
+ // Rewrites every structured buffer / uniform parameter group type in `module`
+ // whose element type needs a storage representation:
+ //  1. the buffer type is replaced by one over the lowered element type;
+ //  2. every value of that type, and every address derived from it, is retyped;
+ //  3. loads through those addresses are followed by an unpack call, and stored
+ //     values are preceded by a pack call;
+ //  4. element addresses into lowered matrices are collected and resolved at
+ //     the end by `lowerMatrixAddresses`.
+ void processModule(IRModule* module)
+ {
+     IRBuilder builder(module);
+     struct BufferTypeInfo
+     {
+         IRType* bufferType;
+         IRType* elementType;
+     };
+     // Collect the candidate buffer types up front; the rewriting loop below
+     // adds and removes global insts.
+     List<BufferTypeInfo> bufferTypeInsts;
+     for (auto globalInst : module->getGlobalInsts())
+     {
+         IRType* elementType = nullptr;
+         if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst))
+             elementType = structBuffer->getElementType();
+         else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst))
+             elementType = constBuffer->getElementType();
+         // Texture buffers are deliberately left untouched.
+         if (as<IRTextureBufferType>(globalInst))
+             continue;
+         // Only struct, matrix and array element types can require lowering.
+         if (!as<IRStructType>(elementType) && !as<IRMatrixType>(elementType) && !as<IRArrayType>(elementType))
+             continue;
+         bufferTypeInsts.add(BufferTypeInfo{ (IRType*)globalInst, elementType });
+     }
+
+     // Maintain a pending work list of all matrix addresses, and try to lower them out of existence
+     // after everything else has been lowered.
+     List<IRInst*> matrixAddrInsts;
+
+     for (auto bufferTypeInfo : bufferTypeInsts)
+     {
+         auto bufferType = bufferTypeInfo.bufferType;
+         auto elementType = bufferTypeInfo.elementType;
+         auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType);
+
+         // If the lowered type is the same as original type, no change is required.
+         if (!loweredBufferElementTypeInfo.convertLoweredToOriginal)
+             continue;
+
+         builder.setInsertBefore(bufferType);
+
+         auto loweredBufferType = builder.getType(
+             bufferType->getOp(),
+             loweredBufferElementTypeInfo.loweredType);
+
+         // We treat a value of a buffer type as a pointer, and use a work list to translate
+         // all loads and stores through the pointer values that need lowering.
+
+         List<IRInst*> ptrValsWorkList;
+         traverseUses(bufferType, [&](IRUse* use)
+             {
+                 // Only insts whose *type* is the buffer type are values of it.
+                 auto user = use->getUser();
+                 if (use != &user->typeUse)
+                     return;
+                 ptrValsWorkList.add(user);
+             });
+
+         // Translate the values to use new lowered buffer type instead.
+         // The list grows while it is being processed, hence the index loop.
+         for (Index i = 0; i < ptrValsWorkList.getCount(); i++)
+         {
+             auto ptrVal = ptrValsWorkList[i];
+             auto oldPtrType = ptrVal->getFullType();
+             auto originalElementType = oldPtrType->getOperand(0);
+             auto loweredElementTypeInfo = getLoweredTypeInfo((IRType*)originalElementType);
+             if (!loweredElementTypeInfo.convertLoweredToOriginal)
+                 continue;
+
+             ptrVal->setFullType(getLoweredPtrLikeType(ptrVal->getFullType(), loweredElementTypeInfo.loweredType));
+
+             traverseUses(ptrVal, [&](IRUse* use)
+                 {
+                     auto user = use->getUser();
+                     switch (user->getOp())
+                     {
+                     case kIROp_Load:
+                     case kIROp_StructuredBufferLoad:
+                     case kIROp_StructuredBufferLoadStatus:
+                     case kIROp_RWStructuredBufferLoad:
+                     case kIROp_RWStructuredBufferLoadStatus:
+                         {
+                             // Re-emit the load with the lowered result type and
+                             // unpack it back to the type this address originally
+                             // pointed to. That is the pointee type of `ptrVal`,
+                             // which differs from the buffer element type for
+                             // derived field/element addresses.
+                             IRCloneEnv cloneEnv = {};
+                             builder.setInsertBefore(user);
+                             auto newLoad = cloneInst(&cloneEnv, &builder, user);
+                             newLoad->setFullType(loweredElementTypeInfo.loweredType);
+                             auto unpackedVal = builder.emitCallInst((IRType*)originalElementType, loweredElementTypeInfo.convertLoweredToOriginal, 1, &newLoad);
+                             user->replaceUsesWith(unpackedVal);
+                             user->removeAndDeallocate();
+                             break;
+                         }
+                     case kIROp_Store:
+                     case kIROp_RWStructuredBufferStore:
+                         {
+                             // Use must be the dest operand of the store inst.
+                             if (use != user->getOperands() + 0)
+                                 break;
+                             builder.setInsertBefore(user);
+                             auto originalVal = getStoreVal(user);
+                             auto packedVal = builder.emitCallInst(loweredElementTypeInfo.loweredType, loweredElementTypeInfo.convertOriginalToLowered, 1, &originalVal);
+                             if (auto store = as<IRStore>(user))
+                                 store->val.set(packedVal);
+                             else if (auto sbStore = as<IRRWStructuredBufferStore>(user))
+                                 sbStore->setOperand(2, packedVal);
+                             else
+                                 SLANG_UNREACHABLE("unhandled store type");
+                             break;
+                         }
+                     case kIROp_GetElementPtr:
+                     case kIROp_FieldAddress:
+                         {
+                             // If original type is an array, the lowered type will be a struct.
+                             // In that case, the address inst must go through the struct's
+                             // data field first.
+                             if (as<IRArrayType>(originalElementType))
+                             {
+                                 builder.setInsertBefore(user);
+                                 auto newArrayPtrVal = builder.emitFieldAddress(
+                                     builder.getPtrType(loweredElementTypeInfo.loweredInnerArrayType),
+                                     ptrVal,
+                                     loweredElementTypeInfo.loweredInnerStructKey);
+                                 builder.replaceOperand(use, newArrayPtrVal);
+                                 ptrValsWorkList.add(user);
+                             }
+                             else if (as<IRMatrixType>(originalElementType))
+                             {
+                                 // We are trying to get a pointer to a lowered matrix element.
+                                 // We process these insts at a later phase.
+                                 SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr);
+                                 matrixAddrInsts.add(user);
+                             }
+                             else
+                             {
+                                 // If we are getting a derived address from the pointer, we need to
+                                 // recursively lower the new address. We do so by pushing the address
+                                 // inst into the work list.
+                                 ptrValsWorkList.add(user);
+                             }
+                         }
+                         break;
+                     case kIROp_RWStructuredBufferGetElementPtr:
+                         ptrValsWorkList.add(user);
+                         break;
+                     default:
+                         SLANG_UNREACHABLE("unhandled inst of a buffer/pointer value that needs storage lowering.");
+                         break;
+                     }
+                 });
+         }
+
+         // Replace all remaining uses of bufferType to loweredBufferType, these uses are non-operational and should be
+         // directly replaceable, such as uses in `IRFuncType`.
+         bufferType->replaceUsesWith(loweredBufferType);
+         bufferType->removeAndDeallocate();
+     }
+
+     lowerMatrixAddresses(module, matrixAddrInsts);
+ }
+
+ // Lower all getElementPtr insts of a lowered matrix out of existence.
+ //
+ // Each entry of `matrixAddrInsts` is a getElementPtr whose base points to the
+ // storage struct of a matrix and whose index selects a row of the original
+ // matrix. Every use of such an address is rewritten to go through the struct's
+ // data field:
+ //  - load:  row-major storage loads the row vector directly; column-major
+ //           storage gathers one element from each column vector.
+ //  - store: the mirror image of the load case.
+ //  - nested getElementPtr (a single scalar): becomes data[a][b], with the two
+ //           indices swapped for column-major storage.
+ void lowerMatrixAddresses(IRModule* module, List<IRInst*>& matrixAddrInsts)
+ {
+     IRBuilder builder(module);
+     for (auto majorAddr : matrixAddrInsts)
+     {
+         auto majorGEP = as<IRGetElementPtr>(majorAddr);
+         SLANG_ASSERT(majorGEP);
+         auto loweredMatrixType = cast<IRPtrTypeBase>(majorGEP->getBase()->getFullType())->getValueType();
+         auto matrixTypeInfo = mapLoweredTypeToInfo.tryGetValue(loweredMatrixType);
+         SLANG_ASSERT(matrixTypeInfo);
+         auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType);
+         SLANG_ASSERT(matrixType);
+         // Column-major storage is an array of `colCount` vectors, and a row of
+         // the matrix has `colCount` elements, so `colCount` bounds both the
+         // gather and the scatter below.
+         auto colCount = getIntVal(matrixType->getColumnCount());
+         traverseUses(majorAddr, [&](IRUse* use)
+             {
+                 auto user = use->getUser();
+                 builder.setInsertBefore(user);
+                 switch (user->getOp())
+                 {
+                 case kIROp_Load:
+                     {
+                         IRInst* resultInst = nullptr;
+                         auto dataPtr = builder.emitFieldAddress(
+                             builder.getPtrType(matrixTypeInfo->loweredInnerArrayType),
+                             majorGEP->getBase(),
+                             matrixTypeInfo->loweredInnerStructKey);
+                         if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+                         {
+                             // Gather element [row] out of every column vector.
+                             List<IRInst*> args;
+                             for (IRIntegerValue c = 0; c < colCount; c++)
+                             {
+                                 auto vector = builder.emitLoad(builder.emitElementAddress(dataPtr, c));
+                                 auto element = builder.emitElementExtract(vector, majorGEP->getIndex());
+                                 args.add(element);
+                             }
+                             resultInst = builder.emitMakeVector(builder.getVectorType(matrixType->getElementType(), (IRIntegerValue)args.getCount()), args);
+                         }
+                         else
+                         {
+                             auto element = builder.emitElementAddress(dataPtr, majorGEP->getIndex());
+                             resultInst = builder.emitLoad(element);
+                         }
+                         user->replaceUsesWith(resultInst);
+                         user->removeAndDeallocate();
+                     }
+                     break;
+                 case kIROp_Store:
+                     {
+                         auto storeInst = cast<IRStore>(user);
+                         // Only rewrite stores *to* the address, not stores of it.
+                         if (storeInst->getOperand(0) != majorAddr)
+                             break;
+                         auto dataPtr = builder.emitFieldAddress(
+                             builder.getPtrType(matrixTypeInfo->loweredInnerArrayType),
+                             majorGEP->getBase(),
+                             matrixTypeInfo->loweredInnerStructKey);
+                         if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+                         {
+                             // Scatter element c of the stored row into
+                             // element [row] of column vector c.
+                             for (IRIntegerValue c = 0; c < colCount; c++)
+                             {
+                                 auto vectorAddr = builder.emitElementAddress(dataPtr, c);
+                                 auto elementAddr = builder.emitElementAddress(vectorAddr, majorGEP->getIndex());
+                                 builder.emitStore(elementAddr, builder.emitElementExtract(storeInst->getVal(), c));
+                             }
+                         }
+                         else
+                         {
+                             auto rowAddr = builder.emitElementAddress(dataPtr, majorGEP->getIndex());
+                             builder.emitStore(rowAddr, storeInst->getVal());
+                         }
+                         // The original store through the old address has been
+                         // replaced in both layouts and must not survive.
+                         user->removeAndDeallocate();
+                         break;
+                     }
+                 case kIROp_GetElementPtr:
+                     {
+                         auto gep2 = cast<IRGetElementPtr>(user);
+                         auto rowIndex = majorGEP->getIndex();
+                         auto colIndex = gep2->getIndex();
+                         // Column-major storage is indexed [col][row].
+                         if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+                         {
+                             Swap(rowIndex, colIndex);
+                         }
+                         auto dataPtr = builder.emitFieldAddress(
+                             builder.getPtrType(matrixTypeInfo->loweredInnerArrayType),
+                             majorGEP->getBase(),
+                             matrixTypeInfo->loweredInnerStructKey);
+                         auto vectorAddr = builder.emitElementAddress(dataPtr, rowIndex);
+                         auto elementAddr = builder.emitElementAddress(vectorAddr, colIndex);
+                         gep2->replaceUsesWith(elementAddr);
+                         gep2->removeAndDeallocate();
+                         break;
+                     }
+                 default:
+                     SLANG_UNREACHABLE("unhandled inst of a matrix address inst that needs storage lowering.");
+                     break;
+                 }
+             });
+     }
+ }
+ };
+
+ // Entry point of the pass: lowers buffer element types in `module` to their
+ // storage types, using the default matrix layout of `target` (row-major when
+ // the target reports the layout as unknown).
+ void lowerBufferElementTypeToStorageType(TargetRequest* target, IRModule* module)
+ {
+     SlangMatrixLayoutMode nativeLayout = (SlangMatrixLayoutMode)target->getDefaultMatrixLayoutMode();
+     switch (nativeLayout)
+     {
+     case SLANG_MATRIX_LAYOUT_MODE_UNKNOWN:
+         nativeLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR;
+         break;
+     default:
+         break;
+     }
+
+     LoweredElementTypeContext loweringContext(nativeLayout);
+     loweringContext.processModule(module);
+ }
+}