From 48b6e2432ea28c06d04931fccd633e31eed6d995 Mon Sep 17 00:00:00 2001 From: venkataram-nv Date: Fri, 18 Jul 2025 09:38:00 -0700 Subject: Lower int/uint/bool matrices to arrays for SPIRV (#7687) * Add tests for expected behaviour * Allow matrix types in logical or/and * Legalize int/bool matrix types and construction with makeMatrix * Legalize uint matrices and operations * Limit testing to only SPIRV * Better tests for int and bool * Add test for uint * Remove GLSL tests * Remove old test for diagnosing int matrices * Emit SPIRV directly in tests * format code * Address PR comments * Improve testing * Address PR comments * format code * Add tests for matrix intrinsic operations * Move matrix lowering to dedicated legalization pass * Fix compiler warning * Remove signal again * Reorder matrix and vector legalization * Fix formatting * Add shift and comparison tests --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> --- source/slang/slang-emit-spirv.cpp | 172 ++++++++++++++++++++++++++++---------- 1 file changed, 129 insertions(+), 43 deletions(-) (limited to 'source/slang/slang-emit-spirv.cpp') diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 2b6f1c821..bbed44c51 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -216,7 +216,7 @@ struct SpvInst : SpvInstParent // // > Word Count: The complete number of words taken by an instruction, // > including the word holding the word count and opcode, and any optional - // > operands. An instruction’s word count is the total space taken by the instruction. + // > operands. An instruction's word count is the total space taken by the instruction. // SpvWord wordCount = 1 + SpvWord(operandWordsCount); @@ -360,7 +360,7 @@ struct SpvLiteralBits // > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed // > four per word, following the little-endian convention (i.e., the // > first octet is in the lowest-order 8 bits of the word). - // > The final word contains the string’s nul-termination character (0), and + // > The final word contains the string's nul-termination character (0), and // > all contents past the end of the string in the final word are padded with 0. // First work out the amount of words we'll need @@ -2039,17 +2039,24 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MatrixType: { auto matrixType = static_cast(inst); + auto elementType = matrixType->getElementType(); + + // SPIR-V only supports floating-point matrices + // bool/int matrices should be lowered to + // arrays of vectors before reaching here + SLANG_ASSERT(!as(elementType)); + SLANG_ASSERT(!as(elementType)); + SLANG_ASSERT(!as(elementType)); + auto vectorSpvType = ensureVectorType( - static_cast(matrixType->getElementType())->getBaseType(), + static_cast(elementType)->getBaseType(), static_cast(matrixType->getColumnCount())->getValue(), nullptr); const auto columnCount = static_cast(matrixType->getRowCount())->getValue(); - auto matrixSPVType = emitOpTypeMatrix( - inst, - vectorSpvType, - SpvLiteralInteger::from32(int32_t(columnCount))); - return matrixSPVType; + const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount)); + SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv); + return matrixSpvType; } case kIROp_ArrayType: case kIROp_UnsizedArrayType: @@ -2621,7 +2628,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvWord arrayed = inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed; - // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored." + // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored." SpvWord depth = ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled @@ -7767,12 +7774,40 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // Otherwise, operands are raw elements, we need to construct row vectors first, // then construct matrix from row vectors. List rowVectors; - auto matrixType = as(inst->getDataType()); - auto rowCount = getIntVal(matrixType->getRowCount()); - auto colCount = getIntVal(matrixType->getColumnCount()); + + IRIntegerValue rowCount; + IRIntegerValue colCount; + IRType* elementType; + + // Data type can be either matrix or vector depending on the + // legalization requirements + auto dataType = inst->getDataType(); + + if (auto matrixType = as(dataType)) + { + elementType = matrixType->getElementType(); + rowCount = getIntVal(matrixType->getRowCount()); + colCount = getIntVal(matrixType->getColumnCount()); + } + else if (auto arrayType = as(dataType)) + { + auto vectorType = as(arrayType->getElementType()); + SLANG_ASSERT(vectorType); + + elementType = vectorType->getElementType(); + rowCount = getIntVal(arrayType->getElementCount()); + colCount = getIntVal(vectorType->getElementCount()); + } + else + { + SLANG_UNEXPECTED("data type for makeMatrix operation is " + "expected be either a matrix or array type"); + } + IRBuilder builder(inst); builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + auto rowVectorType = builder.getVectorType(elementType, colCount); + List colElements; UInt index = 0; for (IRIntegerValue j = 0; j < rowCount; j++) @@ -7897,7 +7932,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex ArrayView operands) { IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); + SLANG_ASSERT(elementType); + IRBasicType* basicType = as(elementType); + SLANG_ASSERT(basicType); SpvOp opCode = _arithmeticOpCodeConvert(op, basicType); if (opCode == SpvOpUndef) @@ -7958,6 +7996,52 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands"); } + // Helper method to handle composite arithmetic operations for matrices and arrays + SpvInst* emitCompositeArithmetic( + SpvInstParent* parent, + IRInst* inst, + IRIntegerValue rowCount, + IRIntegerValue colCount, + IRType* elementType, + IRType* resultType, + bool isMatrixType) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(elementType, colCount); + List rows; + + for (IRIntegerValue i = 0; i < rowCount; i++) + { + List operands; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto originalOperand = inst->getOperand(j); + bool shouldExtract = + isMatrixType ? as(originalOperand->getDataType()) != nullptr + : as(originalOperand->getDataType()) != nullptr; + + if (shouldExtract) + { + auto operand = builder.emitElementExtract(originalOperand, i); + emitLocalInst(parent, operand); + operands.add(operand); + } + else + { + operands.add(originalOperand); + } + } + rows.add(emitVectorOrScalarArithmetic( + parent, + nullptr, + rowVectorType, + inst->getOp(), + inst->getOperandCount(), + operands.getArrayView())); + } + return emitCompositeConstruct(parent, inst, resultType, rows); + } SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { @@ -7965,36 +8049,38 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { auto rowCount = getIntVal(matrixType->getRowCount()); auto colCount = getIntVal(matrixType->getColumnCount()); - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); - List rows; - for (IRIntegerValue i = 0; i < rowCount; i++) - { - List operands; - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - auto originalOperand = inst->getOperand(j); - if (as(originalOperand->getDataType())) - { - auto operand = builder.emitElementExtract(originalOperand, i); - emitLocalInst(parent, operand); - operands.add(operand); - } - else - { - operands.add(originalOperand); - } - } - rows.add(emitVectorOrScalarArithmetic( - parent, - nullptr, - rowVectorType, - inst->getOp(), - inst->getOperandCount(), - operands.getArrayView())); - } - return emitCompositeConstruct(parent, inst, inst->getDataType(), rows); + return emitCompositeArithmetic( + parent, + inst, + rowCount, + colCount, + matrixType->getElementType(), + inst->getDataType(), + true); + } + else if (const auto arrayType = as(inst->getDataType())) + { + // Only for legalization + auto arrayElementType = arrayType->getElementType(); + SLANG_ASSERT(as(arrayElementType)); + + auto vectorType = as(arrayElementType); + auto elementType = vectorType->getElementType(); + SLANG_ASSERT( + as(elementType) || as(elementType) || + as(elementType)); + + auto rowCount = getIntVal(arrayType->getElementCount()); + auto colCount = getIntVal(vectorType->getElementCount()); + + return emitCompositeArithmetic( + parent, + inst, + rowCount, + colCount, + elementType, + inst->getDataType(), + false); } Array operands; -- cgit v1.2.3