diff options
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 170 |
1 files changed, 42 insertions, 128 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index b1b4c4570..da2620856 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -360,7 +360,7 @@ struct SpvLiteralBits // > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed // > four per word, following the little-endian convention (i.e., the // > first octet is in the lowest-order 8 bits of the word). - // > The final word contains the string's nul-termination character (0), and + // > The final word contains the string’s nul-termination character (0), and // > all contents past the end of the string in the final word are padded with 0. // First work out the amount of words we'll need @@ -2039,24 +2039,17 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); - auto elementType = matrixType->getElementType(); - - // SPIR-V only supports floating-point matrices - // bool/int matrices should be lowered to - // arrays of vectors before reaching here - SLANG_ASSERT(!as<IRBoolType>(elementType)); - SLANG_ASSERT(!as<IRIntType>(elementType)); - SLANG_ASSERT(!as<IRUIntType>(elementType)); - auto vectorSpvType = ensureVectorType( - static_cast<IRBasicType*>(elementType)->getBaseType(), + static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(), nullptr); const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); - const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount)); - SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv); - return matrixSpvType; + auto matrixSPVType = emitOpTypeMatrix( + inst, + vectorSpvType, + SpvLiteralInteger::from32(int32_t(columnCount))); + return matrixSPVType; } case kIROp_ArrayType: case kIROp_UnsizedArrayType: @@ -2628,7 +2621,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvWord arrayed = inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed; - // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored." + // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored." SpvWord depth = ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled @@ -7780,40 +7773,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // Otherwise, operands are raw elements, we need to construct row vectors first, // then construct matrix from row vectors. List<SpvInst*> rowVectors; - - IRIntegerValue rowCount; - IRIntegerValue colCount; - IRType* elementType; - - // Data type can be either matrix or vector depending on the - // legalization requirements - auto dataType = inst->getDataType(); - - if (auto matrixType = as<IRMatrixType>(dataType)) - { - elementType = matrixType->getElementType(); - rowCount = getIntVal(matrixType->getRowCount()); - colCount = getIntVal(matrixType->getColumnCount()); - } - else if (auto arrayType = as<IRArrayType>(dataType)) - { - auto vectorType = as<IRVectorType>(arrayType->getElementType()); - SLANG_ASSERT(vectorType); - - elementType = vectorType->getElementType(); - rowCount = getIntVal(arrayType->getElementCount()); - colCount = getIntVal(vectorType->getElementCount()); - } - else - { - SLANG_UNEXPECTED("data type for makeMatrix operation is " - "expected be either a matrix or array type"); - } - + auto matrixType = cast<IRMatrixType>(inst->getDataType()); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); IRBuilder builder(inst); builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(elementType, colCount); - + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); List<IRInst*> colElements; UInt index = 0; for (IRIntegerValue j = 0; j < rowCount; j++) @@ -7938,10 +7903,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex ArrayView<IRInst*> operands) { IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); - SLANG_ASSERT(elementType); - IRBasicType* basicType = as<IRBasicType>(elementType); - SLANG_ASSERT(basicType); SpvOp opCode = _arithmeticOpCodeConvert(op, basicType); if (opCode == SpvOpUndef) @@ -8002,52 +7964,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands"); } - // Helper method to handle composite arithmetic operations for matrices and arrays - SpvInst* emitCompositeArithmetic( - SpvInstParent* parent, - IRInst* inst, - IRIntegerValue rowCount, - IRIntegerValue colCount, - IRType* elementType, - IRType* resultType, - bool isMatrixType) - { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(elementType, colCount); - List<SpvInst*> rows; - - for (IRIntegerValue i = 0; i < rowCount; i++) - { - List<IRInst*> operands; - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - auto originalOperand = inst->getOperand(j); - bool shouldExtract = - isMatrixType ? as<IRMatrixType>(originalOperand->getDataType()) != nullptr - : as<IRArrayType>(originalOperand->getDataType()) != nullptr; - - if (shouldExtract) - { - auto operand = builder.emitElementExtract(originalOperand, i); - emitLocalInst(parent, operand); - operands.add(operand); - } - else - { - operands.add(originalOperand); - } - } - rows.add(emitVectorOrScalarArithmetic( - parent, - nullptr, - rowVectorType, - inst->getOp(), - inst->getOperandCount(), - operands.getArrayView())); - } - return emitCompositeConstruct(parent, inst, resultType, rows); - } SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { @@ -8055,38 +7971,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { auto rowCount = getIntVal(matrixType->getRowCount()); auto colCount = getIntVal(matrixType->getColumnCount()); - return emitCompositeArithmetic( - parent, - inst, - rowCount, - colCount, - matrixType->getElementType(), - inst->getDataType(), - true); - } - else if (const auto arrayType = as<IRArrayType>(inst->getDataType())) - { - // Only for legalization - auto arrayElementType = arrayType->getElementType(); - SLANG_ASSERT(as<IRVectorType>(arrayElementType)); - - auto vectorType = as<IRVectorType>(arrayElementType); - auto elementType = vectorType->getElementType(); - SLANG_ASSERT( - as<IRBoolType>(elementType) || as<IRUIntType>(elementType) || - as<IRIntType>(elementType)); - - auto rowCount = getIntVal(arrayType->getElementCount()); - auto colCount = getIntVal(vectorType->getElementCount()); - - return emitCompositeArithmetic( - parent, - inst, - rowCount, - colCount, - elementType, - inst->getDataType(), - false); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + List<SpvInst*> rows; + for (IRIntegerValue i = 0; i < rowCount; i++) + { + List<IRInst*> operands; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto originalOperand = inst->getOperand(j); + if (as<IRMatrixType>(originalOperand->getDataType())) + { + auto operand = builder.emitElementExtract(originalOperand, i); + emitLocalInst(parent, operand); + operands.add(operand); + } + else + { + operands.add(originalOperand); + } + } + rows.add(emitVectorOrScalarArithmetic( + parent, + nullptr, + rowVectorType, + inst->getOp(), + inst->getOperandCount(), + operands.getArrayView())); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), rows); } Array<IRInst*, 4> operands; |
