diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ir-legalize-matrix-types.cpp | 54 |
1 files changed, 38 insertions, 16 deletions
diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp index 327fa7ead..56d7bf207 100644 --- a/source/slang/slang-ir-legalize-matrix-types.cpp +++ b/source/slang/slang-ir-legalize-matrix-types.cpp @@ -126,27 +126,43 @@ struct MatrixTypeLoweringContext UInt operandIndex = 0; // Assert that we have the expected number of operands - SLANG_ASSERT( - makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue()) && - "makeMatrix operand count must match matrix dimensions"); + if (makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue())) + { + // Each operand is a matrix element + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + List<IRInst*> rowElements; + for (IRIntegerValue col = 0; col < columnCount->getValue(); col++) + { + SLANG_ASSERT( + operandIndex < makeMatrix->getOperandCount() && + "Operand index out of bounds"); + rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex))); + operandIndex++; + } - for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + SLANG_ASSERT( + rowElements.getCount() == columnCount->getValue() && + "Row elements count must match column count"); + auto rowVector = builder.emitMakeVector(vectorType, rowElements); + rowVectors.add(rowVector); + } + } + else if (makeMatrix->getOperandCount() == UInt(rowCount->getValue())) { - List<IRInst*> rowElements; - for (IRIntegerValue col = 0; col < columnCount->getValue(); col++) + // Each operand is a vector with width columnCount->getValue(). + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) { + auto rowVector = getReplacement(makeMatrix->getOperand(row)); + auto vecType = as<IRVectorType>(rowVector->getDataType()); SLANG_ASSERT( - operandIndex < makeMatrix->getOperandCount() && "Operand index out of bounds"); - rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex))); - operandIndex++; + getIntVal(vecType->getElementCount()) == columnCount->getValue() && + "Row elements count must match column count"); + rowVectors.add(rowVector); } - - SLANG_ASSERT( - rowElements.getCount() == columnCount->getValue() && - "Row elements count must match column count"); - auto rowVector = builder.emitMakeVector(vectorType, rowElements); - rowVectors.add(rowVector); } + else + SLANG_ASSERT_FAILURE("makeMatrix operand count must match matrix dimensions"); SLANG_ASSERT( rowVectors.getCount() == rowCount->getValue() && @@ -509,6 +525,8 @@ struct MatrixTypeLoweringContext case kIROp_Sub: case kIROp_Mul: case kIROp_Div: + case kIROp_IRem: + case kIROp_FRem: case kIROp_Lsh: case kIROp_Rsh: case kIROp_And: @@ -527,6 +545,10 @@ struct MatrixTypeLoweringContext case kIROp_Not: case kIROp_BitNot: case kIROp_Neg: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: return legalizeUnaryOperation(inst, inst->getOp()); default: break; @@ -594,4 +616,4 @@ void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, Diagnos context.processModule(); } -} // namespace Slang
\ No newline at end of file +} // namespace Slang |
