diff options
| author | venkataram-nv <vedavamadath@nvidia.com> | 2025-07-30 09:27:38 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-30 16:27:38 +0000 |
| commit | 92ee2927d0012dd454dff7bb53b900f5240073d5 (patch) | |
| tree | d0a648fbb1e6b08c6eec90fadb23435731c1eefe /source/slang/slang-emit-spirv.cpp | |
| parent | 42dc521f7817328a20e40b3352ae667dfd124edb (diff) | |
Lowering unsupported matrix types for GLSL/WGSL/Metal targets (#7936)
* Add emit cases for WGSL and GLSL
* Fix compilation warnings
Modify short cutting test to reflect change in emit logic
Lower matrix for metal as well
Add emit matrix logic for metal
Fix compiler warning
Brace initializer for lowered matrices
Fix compiler warnings
* Tests for metal
* Fix mult, any, and determinant
* Fix matrix-matrix multiplication
* Fix mat mul to be element-wise
* Fix compiler warning
* Move makeMatrix to legalization
* Move unary and binary arithmetic operator lowering to legalization
* Remove emit logic and move final comparison operators to legalization
* Handle vector/matrix negation for WGSL
* Restore older SPIR-V emit logic
* Address PR comments
* Revert to zero minus for negation
* format code
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 170 |
1 files changed, 42 insertions, 128 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index b1b4c4570..da2620856 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -360,7 +360,7 @@ struct SpvLiteralBits // > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed // > four per word, following the little-endian convention (i.e., the // > first octet is in the lowest-order 8 bits of the word). - // > The final word contains the string's nul-termination character (0), and + // > The final word contains the string’s nul-termination character (0), and // > all contents past the end of the string in the final word are padded with 0. // First work out the amount of words we'll need @@ -2039,24 +2039,17 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); - auto elementType = matrixType->getElementType(); - - // SPIR-V only supports floating-point matrices - // bool/int matrices should be lowered to - // arrays of vectors before reaching here - SLANG_ASSERT(!as<IRBoolType>(elementType)); - SLANG_ASSERT(!as<IRIntType>(elementType)); - SLANG_ASSERT(!as<IRUIntType>(elementType)); - auto vectorSpvType = ensureVectorType( - static_cast<IRBasicType*>(elementType)->getBaseType(), + static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(), nullptr); const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); - const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount)); - SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv); - return matrixSpvType; + auto matrixSPVType = emitOpTypeMatrix( + inst, + vectorSpvType, + SpvLiteralInteger::from32(int32_t(columnCount))); + return matrixSPVType; } case kIROp_ArrayType: case kIROp_UnsizedArrayType: @@ -2628,7 +2621,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvWord arrayed = inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed; - // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored." + // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored." SpvWord depth = ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled @@ -7780,40 +7773,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // Otherwise, operands are raw elements, we need to construct row vectors first, // then construct matrix from row vectors. List<SpvInst*> rowVectors; - - IRIntegerValue rowCount; - IRIntegerValue colCount; - IRType* elementType; - - // Data type can be either matrix or vector depending on the - // legalization requirements - auto dataType = inst->getDataType(); - - if (auto matrixType = as<IRMatrixType>(dataType)) - { - elementType = matrixType->getElementType(); - rowCount = getIntVal(matrixType->getRowCount()); - colCount = getIntVal(matrixType->getColumnCount()); - } - else if (auto arrayType = as<IRArrayType>(dataType)) - { - auto vectorType = as<IRVectorType>(arrayType->getElementType()); - SLANG_ASSERT(vectorType); - - elementType = vectorType->getElementType(); - rowCount = getIntVal(arrayType->getElementCount()); - colCount = getIntVal(vectorType->getElementCount()); - } - else - { - SLANG_UNEXPECTED("data type for makeMatrix operation is " - "expected be either a matrix or array type"); - } - + auto matrixType = cast<IRMatrixType>(inst->getDataType()); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); IRBuilder builder(inst); builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(elementType, colCount); - + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); List<IRInst*> colElements; UInt index = 0; for (IRIntegerValue j = 0; j < rowCount; j++) @@ -7938,10 +7903,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex ArrayView<IRInst*> operands) { IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); - SLANG_ASSERT(elementType); - IRBasicType* basicType = as<IRBasicType>(elementType); - SLANG_ASSERT(basicType); SpvOp opCode = _arithmeticOpCodeConvert(op, basicType); if (opCode == SpvOpUndef) @@ -8002,52 +7964,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands"); } - // Helper method to handle composite arithmetic operations for matrices and arrays - SpvInst* emitCompositeArithmetic( - SpvInstParent* parent, - IRInst* inst, - IRIntegerValue rowCount, - IRIntegerValue colCount, - IRType* elementType, - IRType* resultType, - bool isMatrixType) - { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(elementType, colCount); - List<SpvInst*> rows; - - for (IRIntegerValue i = 0; i < rowCount; i++) - { - List<IRInst*> operands; - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - auto originalOperand = inst->getOperand(j); - bool shouldExtract = - isMatrixType ? as<IRMatrixType>(originalOperand->getDataType()) != nullptr - : as<IRArrayType>(originalOperand->getDataType()) != nullptr; - - if (shouldExtract) - { - auto operand = builder.emitElementExtract(originalOperand, i); - emitLocalInst(parent, operand); - operands.add(operand); - } - else - { - operands.add(originalOperand); - } - } - rows.add(emitVectorOrScalarArithmetic( - parent, - nullptr, - rowVectorType, - inst->getOp(), - inst->getOperandCount(), - operands.getArrayView())); - } - return emitCompositeConstruct(parent, inst, resultType, rows); - } SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { @@ -8055,38 +7971,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { auto rowCount = getIntVal(matrixType->getRowCount()); auto colCount = getIntVal(matrixType->getColumnCount()); - return emitCompositeArithmetic( - parent, - inst, - rowCount, - colCount, - matrixType->getElementType(), - inst->getDataType(), - true); - } - else if (const auto arrayType = as<IRArrayType>(inst->getDataType())) - { - // Only for legalization - auto arrayElementType = arrayType->getElementType(); - SLANG_ASSERT(as<IRVectorType>(arrayElementType)); - - auto vectorType = as<IRVectorType>(arrayElementType); - auto elementType = vectorType->getElementType(); - SLANG_ASSERT( - as<IRBoolType>(elementType) || as<IRUIntType>(elementType) || - as<IRIntType>(elementType)); - - auto rowCount = getIntVal(arrayType->getElementCount()); - auto colCount = getIntVal(vectorType->getElementCount()); - - return emitCompositeArithmetic( - parent, - inst, - rowCount, - colCount, - elementType, - inst->getDataType(), - false); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + List<SpvInst*> rows; + for (IRIntegerValue i = 0; i < rowCount; i++) + { + List<IRInst*> operands; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto originalOperand = inst->getOperand(j); + if (as<IRMatrixType>(originalOperand->getDataType())) + { + auto operand = builder.emitElementExtract(originalOperand, i); + emitLocalInst(parent, operand); + operands.add(operand); + } + else + { + operands.add(originalOperand); + } + } + rows.add(emitVectorOrScalarArithmetic( + parent, + nullptr, + rowVectorType, + inst->getOp(), + inst->getOperandCount(), + operands.getArrayView())); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), rows); } Array<IRInst*, 4> operands; |
