summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp170
1 files changed, 42 insertions, 128 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index b1b4c4570..da2620856 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -360,7 +360,7 @@ struct SpvLiteralBits
// > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed
// > four per word, following the little-endian convention (i.e., the
// > first octet is in the lowest-order 8 bits of the word).
- // > The final word contains the string's nul-termination character (0), and
+ // > The final word contains the string’s nul-termination character (0), and
// > all contents past the end of the string in the final word are padded with 0.
// First work out the amount of words we'll need
@@ -2039,24 +2039,17 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(inst);
- auto elementType = matrixType->getElementType();
-
- // SPIR-V only supports floating-point matrices
- // bool/int matrices should be lowered to
- // arrays of vectors before reaching here
- SLANG_ASSERT(!as<IRBoolType>(elementType));
- SLANG_ASSERT(!as<IRIntType>(elementType));
- SLANG_ASSERT(!as<IRUIntType>(elementType));
-
auto vectorSpvType = ensureVectorType(
- static_cast<IRBasicType*>(elementType)->getBaseType(),
+ static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(),
static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(),
nullptr);
const auto columnCount =
static_cast<IRIntLit*>(matrixType->getRowCount())->getValue();
- const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount));
- SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv);
- return matrixSpvType;
+ auto matrixSPVType = emitOpTypeMatrix(
+ inst,
+ vectorSpvType,
+ SpvLiteralInteger::from32(int32_t(columnCount)));
+ return matrixSPVType;
}
case kIROp_ArrayType:
case kIROp_UnsizedArrayType:
@@ -2628,7 +2621,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
SpvWord arrayed =
inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed;
- // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored."
+ // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored."
SpvWord depth =
ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image
SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled
@@ -7780,40 +7773,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
// Otherwise, operands are raw elements, we need to construct row vectors first,
// then construct matrix from row vectors.
List<SpvInst*> rowVectors;
-
- IRIntegerValue rowCount;
- IRIntegerValue colCount;
- IRType* elementType;
-
- // Data type can be either matrix or vector depending on the
- // legalization requirements
- auto dataType = inst->getDataType();
-
- if (auto matrixType = as<IRMatrixType>(dataType))
- {
- elementType = matrixType->getElementType();
- rowCount = getIntVal(matrixType->getRowCount());
- colCount = getIntVal(matrixType->getColumnCount());
- }
- else if (auto arrayType = as<IRArrayType>(dataType))
- {
- auto vectorType = as<IRVectorType>(arrayType->getElementType());
- SLANG_ASSERT(vectorType);
-
- elementType = vectorType->getElementType();
- rowCount = getIntVal(arrayType->getElementCount());
- colCount = getIntVal(vectorType->getElementCount());
- }
- else
- {
- SLANG_UNEXPECTED("data type for makeMatrix operation is "
- "expected be either a matrix or array type");
- }
-
+ auto matrixType = cast<IRMatrixType>(inst->getDataType());
+ auto rowCount = getIntVal(matrixType->getRowCount());
+ auto colCount = getIntVal(matrixType->getColumnCount());
IRBuilder builder(inst);
builder.setInsertBefore(inst);
- auto rowVectorType = builder.getVectorType(elementType, colCount);
-
+ auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount);
List<IRInst*> colElements;
UInt index = 0;
for (IRIntegerValue j = 0; j < rowCount; j++)
@@ -7938,10 +7903,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
ArrayView<IRInst*> operands)
{
IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType());
- SLANG_ASSERT(elementType);
-
IRBasicType* basicType = as<IRBasicType>(elementType);
- SLANG_ASSERT(basicType);
SpvOp opCode = _arithmeticOpCodeConvert(op, basicType);
if (opCode == SpvOpUndef)
@@ -8002,52 +7964,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands");
}
- // Helper method to handle composite arithmetic operations for matrices and arrays
- SpvInst* emitCompositeArithmetic(
- SpvInstParent* parent,
- IRInst* inst,
- IRIntegerValue rowCount,
- IRIntegerValue colCount,
- IRType* elementType,
- IRType* resultType,
- bool isMatrixType)
- {
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
- auto rowVectorType = builder.getVectorType(elementType, colCount);
- List<SpvInst*> rows;
-
- for (IRIntegerValue i = 0; i < rowCount; i++)
- {
- List<IRInst*> operands;
- for (UInt j = 0; j < inst->getOperandCount(); j++)
- {
- auto originalOperand = inst->getOperand(j);
- bool shouldExtract =
- isMatrixType ? as<IRMatrixType>(originalOperand->getDataType()) != nullptr
- : as<IRArrayType>(originalOperand->getDataType()) != nullptr;
-
- if (shouldExtract)
- {
- auto operand = builder.emitElementExtract(originalOperand, i);
- emitLocalInst(parent, operand);
- operands.add(operand);
- }
- else
- {
- operands.add(originalOperand);
- }
- }
- rows.add(emitVectorOrScalarArithmetic(
- parent,
- nullptr,
- rowVectorType,
- inst->getOp(),
- inst->getOperandCount(),
- operands.getArrayView()));
- }
- return emitCompositeConstruct(parent, inst, resultType, rows);
- }
SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
{
@@ -8055,38 +7971,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
auto rowCount = getIntVal(matrixType->getRowCount());
auto colCount = getIntVal(matrixType->getColumnCount());
- return emitCompositeArithmetic(
- parent,
- inst,
- rowCount,
- colCount,
- matrixType->getElementType(),
- inst->getDataType(),
- true);
- }
- else if (const auto arrayType = as<IRArrayType>(inst->getDataType()))
- {
- // Only for legalization
- auto arrayElementType = arrayType->getElementType();
- SLANG_ASSERT(as<IRVectorType>(arrayElementType));
-
- auto vectorType = as<IRVectorType>(arrayElementType);
- auto elementType = vectorType->getElementType();
- SLANG_ASSERT(
- as<IRBoolType>(elementType) || as<IRUIntType>(elementType) ||
- as<IRIntType>(elementType));
-
- auto rowCount = getIntVal(arrayType->getElementCount());
- auto colCount = getIntVal(vectorType->getElementCount());
-
- return emitCompositeArithmetic(
- parent,
- inst,
- rowCount,
- colCount,
- elementType,
- inst->getDataType(),
- false);
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount);
+ List<SpvInst*> rows;
+ for (IRIntegerValue i = 0; i < rowCount; i++)
+ {
+ List<IRInst*> operands;
+ for (UInt j = 0; j < inst->getOperandCount(); j++)
+ {
+ auto originalOperand = inst->getOperand(j);
+ if (as<IRMatrixType>(originalOperand->getDataType()))
+ {
+ auto operand = builder.emitElementExtract(originalOperand, i);
+ emitLocalInst(parent, operand);
+ operands.add(operand);
+ }
+ else
+ {
+ operands.add(originalOperand);
+ }
+ }
+ rows.add(emitVectorOrScalarArithmetic(
+ parent,
+ nullptr,
+ rowVectorType,
+ inst->getOp(),
+ inst->getOperandCount(),
+ operands.getArrayView()));
+ }
+ return emitCompositeConstruct(parent, inst, inst->getDataType(), rows);
}
Array<IRInst*, 4> operands;