summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorJulius Ikkala <julius.ikkala@gmail.com>2025-10-05 04:03:27 +0300
committerGitHub <noreply@github.com>2025-10-05 01:03:27 +0000
commit04093bcbaea9784cdffe55f3931f50db7ad9f808 (patch)
treef612a78fb4ac1342af881b9b5acfd89fe8e3f843 /source/slang
parent3375cde1add65894b8f2e2780cc91ab4ccf6d8fb (diff)
Matrix legalization for missing instructions & MakeMatrix of vectors (#8605)
Fixes these issues: * During matrix legalization, `MakeMatrix` crashed if it was given a list of vectors instead of individual elements. * Matrix casts, IRem, and Frem would be emitted using arrays, e.g. `IntToFloatCast` with `float2[2]` parameters. I found these bugs while enabling various `hlsl-intrinsic` tests for the LLVM target. For now, I've chose to get rid of all matrix types with the matrix legalization pass so that the LLVM emitter doesn't need to be aware. These bugs were preventing `tests/hlsl-intrinsic/matrix-double-reduced-intrinsic.slang` and `tests/hlsl-intrinsic/matrix-double.slang` from passing there.
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ir-legalize-matrix-types.cpp54
1 files changed, 38 insertions, 16 deletions
diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp
index 327fa7ead..56d7bf207 100644
--- a/source/slang/slang-ir-legalize-matrix-types.cpp
+++ b/source/slang/slang-ir-legalize-matrix-types.cpp
@@ -126,27 +126,43 @@ struct MatrixTypeLoweringContext
UInt operandIndex = 0;
// Assert that we have the expected number of operands
- SLANG_ASSERT(
- makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue()) &&
- "makeMatrix operand count must match matrix dimensions");
+ if (makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue()))
+ {
+ // Each operand is a matrix element
+ for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
+ {
+ List<IRInst*> rowElements;
+ for (IRIntegerValue col = 0; col < columnCount->getValue(); col++)
+ {
+ SLANG_ASSERT(
+ operandIndex < makeMatrix->getOperandCount() &&
+ "Operand index out of bounds");
+ rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex)));
+ operandIndex++;
+ }
- for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
+ SLANG_ASSERT(
+ rowElements.getCount() == columnCount->getValue() &&
+ "Row elements count must match column count");
+ auto rowVector = builder.emitMakeVector(vectorType, rowElements);
+ rowVectors.add(rowVector);
+ }
+ }
+ else if (makeMatrix->getOperandCount() == UInt(rowCount->getValue()))
{
- List<IRInst*> rowElements;
- for (IRIntegerValue col = 0; col < columnCount->getValue(); col++)
+ // Each operand is a vector with width columnCount->getValue().
+ for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
{
+ auto rowVector = getReplacement(makeMatrix->getOperand(row));
+ auto vecType = as<IRVectorType>(rowVector->getDataType());
SLANG_ASSERT(
- operandIndex < makeMatrix->getOperandCount() && "Operand index out of bounds");
- rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex)));
- operandIndex++;
+ getIntVal(vecType->getElementCount()) == columnCount->getValue() &&
+ "Row elements count must match column count");
+ rowVectors.add(rowVector);
}
-
- SLANG_ASSERT(
- rowElements.getCount() == columnCount->getValue() &&
- "Row elements count must match column count");
- auto rowVector = builder.emitMakeVector(vectorType, rowElements);
- rowVectors.add(rowVector);
}
+ else
+ SLANG_ASSERT_FAILURE("makeMatrix operand count must match matrix dimensions");
SLANG_ASSERT(
rowVectors.getCount() == rowCount->getValue() &&
@@ -509,6 +525,8 @@ struct MatrixTypeLoweringContext
case kIROp_Sub:
case kIROp_Mul:
case kIROp_Div:
+ case kIROp_IRem:
+ case kIROp_FRem:
case kIROp_Lsh:
case kIROp_Rsh:
case kIROp_And:
@@ -527,6 +545,10 @@ struct MatrixTypeLoweringContext
case kIROp_Not:
case kIROp_BitNot:
case kIROp_Neg:
+ case kIROp_IntCast:
+ case kIROp_FloatCast:
+ case kIROp_CastIntToFloat:
+ case kIROp_CastFloatToInt:
return legalizeUnaryOperation(inst, inst->getOp());
default:
break;
@@ -594,4 +616,4 @@ void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, Diagnos
context.processModule();
}
-} // namespace Slang \ No newline at end of file
+} // namespace Slang