diff options
| author | Julius Ikkala <julius.ikkala@gmail.com> | 2025-10-05 04:03:27 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-05 01:03:27 +0000 |
| commit | 04093bcbaea9784cdffe55f3931f50db7ad9f808 (patch) | |
| tree | f612a78fb4ac1342af881b9b5acfd89fe8e3f843 /source | |
| parent | 3375cde1add65894b8f2e2780cc91ab4ccf6d8fb (diff) | |
Matrix legalization for missing instructions & MakeMatrix of vectors (#8605)
Fixes these issues:
* During matrix legalization, `MakeMatrix` crashed if it was given a
list of vectors instead of individual elements.
* Matrix casts, IRem, and Frem would be emitted using arrays, e.g.
`IntToFloatCast` with `float2[2]` parameters.
I found these bugs while enabling various `hlsl-intrinsic` tests for the
LLVM target. For now, I've chose to get rid of all matrix types with the
matrix legalization pass so that the LLVM emitter doesn't need to be
aware. These bugs were preventing
`tests/hlsl-intrinsic/matrix-double-reduced-intrinsic.slang` and
`tests/hlsl-intrinsic/matrix-double.slang` from passing there.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-legalize-matrix-types.cpp | 54 |
1 files changed, 38 insertions, 16 deletions
diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp index 327fa7ead..56d7bf207 100644 --- a/source/slang/slang-ir-legalize-matrix-types.cpp +++ b/source/slang/slang-ir-legalize-matrix-types.cpp @@ -126,27 +126,43 @@ struct MatrixTypeLoweringContext UInt operandIndex = 0; // Assert that we have the expected number of operands - SLANG_ASSERT( - makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue()) && - "makeMatrix operand count must match matrix dimensions"); + if (makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue())) + { + // Each operand is a matrix element + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + List<IRInst*> rowElements; + for (IRIntegerValue col = 0; col < columnCount->getValue(); col++) + { + SLANG_ASSERT( + operandIndex < makeMatrix->getOperandCount() && + "Operand index out of bounds"); + rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex))); + operandIndex++; + } - for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + SLANG_ASSERT( + rowElements.getCount() == columnCount->getValue() && + "Row elements count must match column count"); + auto rowVector = builder.emitMakeVector(vectorType, rowElements); + rowVectors.add(rowVector); + } + } + else if (makeMatrix->getOperandCount() == UInt(rowCount->getValue())) { - List<IRInst*> rowElements; - for (IRIntegerValue col = 0; col < columnCount->getValue(); col++) + // Each operand is a vector with width columnCount->getValue(). + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) { + auto rowVector = getReplacement(makeMatrix->getOperand(row)); + auto vecType = as<IRVectorType>(rowVector->getDataType()); SLANG_ASSERT( - operandIndex < makeMatrix->getOperandCount() && "Operand index out of bounds"); - rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex))); - operandIndex++; + getIntVal(vecType->getElementCount()) == columnCount->getValue() && + "Row elements count must match column count"); + rowVectors.add(rowVector); } - - SLANG_ASSERT( - rowElements.getCount() == columnCount->getValue() && - "Row elements count must match column count"); - auto rowVector = builder.emitMakeVector(vectorType, rowElements); - rowVectors.add(rowVector); } + else + SLANG_ASSERT_FAILURE("makeMatrix operand count must match matrix dimensions"); SLANG_ASSERT( rowVectors.getCount() == rowCount->getValue() && @@ -509,6 +525,8 @@ struct MatrixTypeLoweringContext case kIROp_Sub: case kIROp_Mul: case kIROp_Div: + case kIROp_IRem: + case kIROp_FRem: case kIROp_Lsh: case kIROp_Rsh: case kIROp_And: @@ -527,6 +545,10 @@ struct MatrixTypeLoweringContext case kIROp_Not: case kIROp_BitNot: case kIROp_Neg: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: return legalizeUnaryOperation(inst, inst->getOp()); default: break; @@ -594,4 +616,4 @@ void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, Diagnos context.processModule(); } -} // namespace Slang
\ No newline at end of file +} // namespace Slang |
