summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-legalize-matrix-types.cpp54
1 files changed, 38 insertions, 16 deletions
diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp
index 327fa7ead..56d7bf207 100644
--- a/source/slang/slang-ir-legalize-matrix-types.cpp
+++ b/source/slang/slang-ir-legalize-matrix-types.cpp
@@ -126,27 +126,43 @@ struct MatrixTypeLoweringContext
UInt operandIndex = 0;
// Assert that we have the expected number of operands
- SLANG_ASSERT(
- makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue()) &&
- "makeMatrix operand count must match matrix dimensions");
+ if (makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue()))
+ {
+ // Each operand is a matrix element
+ for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
+ {
+ List<IRInst*> rowElements;
+ for (IRIntegerValue col = 0; col < columnCount->getValue(); col++)
+ {
+ SLANG_ASSERT(
+ operandIndex < makeMatrix->getOperandCount() &&
+ "Operand index out of bounds");
+ rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex)));
+ operandIndex++;
+ }
- for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
+ SLANG_ASSERT(
+ rowElements.getCount() == columnCount->getValue() &&
+ "Row elements count must match column count");
+ auto rowVector = builder.emitMakeVector(vectorType, rowElements);
+ rowVectors.add(rowVector);
+ }
+ }
+ else if (makeMatrix->getOperandCount() == UInt(rowCount->getValue()))
{
- List<IRInst*> rowElements;
- for (IRIntegerValue col = 0; col < columnCount->getValue(); col++)
+ // Each operand is a vector with width columnCount->getValue().
+ for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
{
+ auto rowVector = getReplacement(makeMatrix->getOperand(row));
+ auto vecType = as<IRVectorType>(rowVector->getDataType());
SLANG_ASSERT(
- operandIndex < makeMatrix->getOperandCount() && "Operand index out of bounds");
- rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex)));
- operandIndex++;
+ getIntVal(vecType->getElementCount()) == columnCount->getValue() &&
+ "Row elements count must match column count");
+ rowVectors.add(rowVector);
}
-
- SLANG_ASSERT(
- rowElements.getCount() == columnCount->getValue() &&
- "Row elements count must match column count");
- auto rowVector = builder.emitMakeVector(vectorType, rowElements);
- rowVectors.add(rowVector);
}
+ else
+ SLANG_ASSERT_FAILURE("makeMatrix operand count must match matrix dimensions");
SLANG_ASSERT(
rowVectors.getCount() == rowCount->getValue() &&
@@ -509,6 +525,8 @@ struct MatrixTypeLoweringContext
case kIROp_Sub:
case kIROp_Mul:
case kIROp_Div:
+ case kIROp_IRem:
+ case kIROp_FRem:
case kIROp_Lsh:
case kIROp_Rsh:
case kIROp_And:
@@ -527,6 +545,10 @@ struct MatrixTypeLoweringContext
case kIROp_Not:
case kIROp_BitNot:
case kIROp_Neg:
+ case kIROp_IntCast:
+ case kIROp_FloatCast:
+ case kIROp_CastIntToFloat:
+ case kIROp_CastFloatToInt:
return legalizeUnaryOperation(inst, inst->getOp());
default:
break;
@@ -594,4 +616,4 @@ void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, Diagnos
context.processModule();
}
-} // namespace Slang \ No newline at end of file
+} // namespace Slang