diff options
| author | venkataram-nv <vedavamadath@nvidia.com> | 2025-07-18 09:38:00 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-18 16:38:00 +0000 |
| commit | 48b6e2432ea28c06d04931fccd633e31eed6d995 (patch) | |
| tree | b976380fd3464b231275e0ae2c1c6ac8af1bb6c3 /source/slang/slang-emit-spirv.cpp | |
| parent | 85edfb178cd243134f4bb3d35ad71f154d76c81c (diff) | |
Lower int/uint/bool matrices to arrays for SPIRV (#7687)
* Add tests for expected behaviour
* Allow matrix types in logical or/and
* Legalize int/bool matrix types and construction with makeMatrix
* Legalize uint matrices and operations
* Limit testing to only SPIRV
* Better tests for int and bool
* Add test for uint
* Remove GLSL tests
* Remove old test for diagnosing int matrices
* Emit SPIRV directly in tests
* format code
* Address PR comments
* Improve testing
* Address PR comments
* format code
* Add tests for matrix intrinsic operations
* Move matrix lowering to dedicated legalization pass
* Fix compiler warning
* Remove signal again
* Reorder matrix and vector legalization
* Fix formatting
* Add shift and comparison tests
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 172 |
1 files changed, 129 insertions, 43 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 2b6f1c821..bbed44c51 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -216,7 +216,7 @@ struct SpvInst : SpvInstParent // // > Word Count: The complete number of words taken by an instruction, // > including the word holding the word count and opcode, and any optional - // > operands. An instruction’s word count is the total space taken by the instruction. + // > operands. An instruction's word count is the total space taken by the instruction. // SpvWord wordCount = 1 + SpvWord(operandWordsCount); @@ -360,7 +360,7 @@ struct SpvLiteralBits // > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed // > four per word, following the little-endian convention (i.e., the // > first octet is in the lowest-order 8 bits of the word). - // > The final word contains the string’s nul-termination character (0), and + // > The final word contains the string's nul-termination character (0), and // > all contents past the end of the string in the final word are padded with 0. // First work out the amount of words we'll need @@ -2039,17 +2039,24 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); + auto elementType = matrixType->getElementType(); + + // SPIR-V only supports floating-point matrices + // bool/int matrices should be lowered to + // arrays of vectors before reaching here + SLANG_ASSERT(!as<IRBoolType>(elementType)); + SLANG_ASSERT(!as<IRIntType>(elementType)); + SLANG_ASSERT(!as<IRUIntType>(elementType)); + auto vectorSpvType = ensureVectorType( - static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), + static_cast<IRBasicType*>(elementType)->getBaseType(), static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(), nullptr); const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); - auto matrixSPVType = emitOpTypeMatrix( - inst, - vectorSpvType, - SpvLiteralInteger::from32(int32_t(columnCount))); - return matrixSPVType; + const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount)); + SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv); + return matrixSpvType; } case kIROp_ArrayType: case kIROp_UnsizedArrayType: @@ -2621,7 +2628,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvWord arrayed = inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed; - // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored." + // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored." SpvWord depth = ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled @@ -7767,12 +7774,40 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // Otherwise, operands are raw elements, we need to construct row vectors first, // then construct matrix from row vectors. List<SpvInst*> rowVectors; - auto matrixType = as<IRMatrixType>(inst->getDataType()); - auto rowCount = getIntVal(matrixType->getRowCount()); - auto colCount = getIntVal(matrixType->getColumnCount()); + + IRIntegerValue rowCount; + IRIntegerValue colCount; + IRType* elementType; + + // Data type can be either matrix or vector depending on the + // legalization requirements + auto dataType = inst->getDataType(); + + if (auto matrixType = as<IRMatrixType>(dataType)) + { + elementType = matrixType->getElementType(); + rowCount = getIntVal(matrixType->getRowCount()); + colCount = getIntVal(matrixType->getColumnCount()); + } + else if (auto arrayType = as<IRArrayType>(dataType)) + { + auto vectorType = as<IRVectorType>(arrayType->getElementType()); + SLANG_ASSERT(vectorType); + + elementType = vectorType->getElementType(); + rowCount = getIntVal(arrayType->getElementCount()); + colCount = getIntVal(vectorType->getElementCount()); + } + else + { + SLANG_UNEXPECTED("data type for makeMatrix operation is " + "expected be either a matrix or array type"); + } + IRBuilder builder(inst); builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + auto rowVectorType = builder.getVectorType(elementType, colCount); + List<IRInst*> colElements; UInt index = 0; for (IRIntegerValue j = 0; j < rowCount; j++) @@ -7897,7 +7932,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex ArrayView<IRInst*> operands) { IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); + SLANG_ASSERT(elementType); + IRBasicType* basicType = as<IRBasicType>(elementType); + SLANG_ASSERT(basicType); SpvOp opCode = _arithmeticOpCodeConvert(op, basicType); if (opCode == SpvOpUndef) @@ -7958,6 +7996,52 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands"); } + // Helper method to handle composite arithmetic operations for matrices and arrays + SpvInst* emitCompositeArithmetic( + SpvInstParent* parent, + IRInst* inst, + IRIntegerValue rowCount, + IRIntegerValue colCount, + IRType* elementType, + IRType* resultType, + bool isMatrixType) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(elementType, colCount); + List<SpvInst*> rows; + + for (IRIntegerValue i = 0; i < rowCount; i++) + { + List<IRInst*> operands; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto originalOperand = inst->getOperand(j); + bool shouldExtract = + isMatrixType ? as<IRMatrixType>(originalOperand->getDataType()) != nullptr + : as<IRArrayType>(originalOperand->getDataType()) != nullptr; + + if (shouldExtract) + { + auto operand = builder.emitElementExtract(originalOperand, i); + emitLocalInst(parent, operand); + operands.add(operand); + } + else + { + operands.add(originalOperand); + } + } + rows.add(emitVectorOrScalarArithmetic( + parent, + nullptr, + rowVectorType, + inst->getOp(), + inst->getOperandCount(), + operands.getArrayView())); + } + return emitCompositeConstruct(parent, inst, resultType, rows); + } SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { @@ -7965,36 +8049,38 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { auto rowCount = getIntVal(matrixType->getRowCount()); auto colCount = getIntVal(matrixType->getColumnCount()); - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); - List<SpvInst*> rows; - for (IRIntegerValue i = 0; i < rowCount; i++) - { - List<IRInst*> operands; - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - auto originalOperand = inst->getOperand(j); - if (as<IRMatrixType>(originalOperand->getDataType())) - { - auto operand = builder.emitElementExtract(originalOperand, i); - emitLocalInst(parent, operand); - operands.add(operand); - } - else - { - operands.add(originalOperand); - } - } - rows.add(emitVectorOrScalarArithmetic( - parent, - nullptr, - rowVectorType, - inst->getOp(), - inst->getOperandCount(), - operands.getArrayView())); - } - return emitCompositeConstruct(parent, inst, inst->getDataType(), rows); + return emitCompositeArithmetic( + parent, + inst, + rowCount, + colCount, + matrixType->getElementType(), + inst->getDataType(), + true); + } + else if (const auto arrayType = as<IRArrayType>(inst->getDataType())) + { + // Only for legalization + auto arrayElementType = arrayType->getElementType(); + SLANG_ASSERT(as<IRVectorType>(arrayElementType)); + + auto vectorType = as<IRVectorType>(arrayElementType); + auto elementType = vectorType->getElementType(); + SLANG_ASSERT( + as<IRBoolType>(elementType) || as<IRUIntType>(elementType) || + as<IRIntType>(elementType)); + + auto rowCount = getIntVal(arrayType->getElementCount()); + auto colCount = getIntVal(vectorType->getElementCount()); + + return emitCompositeArithmetic( + parent, + inst, + rowCount, + colCount, + elementType, + inst->getDataType(), + false); } Array<IRInst*, 4> operands; |
