summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
authorvenkataram-nv <vedavamadath@nvidia.com>2025-07-18 09:38:00 -0700
committerGitHub <noreply@github.com>2025-07-18 16:38:00 +0000
commit48b6e2432ea28c06d04931fccd633e31eed6d995 (patch)
treeb976380fd3464b231275e0ae2c1c6ac8af1bb6c3 /source/slang/slang-emit-spirv.cpp
parent85edfb178cd243134f4bb3d35ad71f154d76c81c (diff)
Lower int/uint/bool matrices to arrays for SPIRV (#7687)
* Add tests for expected behaviour * Allow matrix types in logical or/and * Legalize int/bool matrix types and construction with makeMatrix * Legalize uint matrices and operations * Limit testing to only SPIRV * Better tests for int and bool * Add test for uint * Remove GLSL tests * Remove old test for diagnosing int matrices * Emit SPIRV directly in tests * format code * Address PR comments * Improve testing * Address PR comments * format code * Add tests for matrix intrinsic operations * Move matrix lowering to dedicated legalization pass * Fix compiler warning * Remove signal again * Reorder matrix and vector legalization * Fix formatting * Add shift and comparison tests --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp172
1 files changed, 129 insertions, 43 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 2b6f1c821..bbed44c51 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -216,7 +216,7 @@ struct SpvInst : SpvInstParent
//
// > Word Count: The complete number of words taken by an instruction,
// > including the word holding the word count and opcode, and any optional
- // > operands. An instruction’s word count is the total space taken by the instruction.
+ // > operands. An instruction's word count is the total space taken by the instruction.
//
SpvWord wordCount = 1 + SpvWord(operandWordsCount);
@@ -360,7 +360,7 @@ struct SpvLiteralBits
// > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed
// > four per word, following the little-endian convention (i.e., the
// > first octet is in the lowest-order 8 bits of the word).
- // > The final word contains the string’s nul-termination character (0), and
+ // > The final word contains the string's nul-termination character (0), and
// > all contents past the end of the string in the final word are padded with 0.
// First work out the amount of words we'll need
@@ -2039,17 +2039,24 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(inst);
+ auto elementType = matrixType->getElementType();
+
+ // SPIR-V only supports floating-point matrices
+ // bool/int matrices should be lowered to
+ // arrays of vectors before reaching here
+ SLANG_ASSERT(!as<IRBoolType>(elementType));
+ SLANG_ASSERT(!as<IRIntType>(elementType));
+ SLANG_ASSERT(!as<IRUIntType>(elementType));
+
auto vectorSpvType = ensureVectorType(
- static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(),
+ static_cast<IRBasicType*>(elementType)->getBaseType(),
static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(),
nullptr);
const auto columnCount =
static_cast<IRIntLit*>(matrixType->getRowCount())->getValue();
- auto matrixSPVType = emitOpTypeMatrix(
- inst,
- vectorSpvType,
- SpvLiteralInteger::from32(int32_t(columnCount)));
- return matrixSPVType;
+ const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount));
+ SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv);
+ return matrixSpvType;
}
case kIROp_ArrayType:
case kIROp_UnsizedArrayType:
@@ -2621,7 +2628,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
SpvWord arrayed =
inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed;
- // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored."
+ // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored."
SpvWord depth =
ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image
SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled
@@ -7767,12 +7774,40 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
// Otherwise, operands are raw elements, we need to construct row vectors first,
// then construct matrix from row vectors.
List<SpvInst*> rowVectors;
- auto matrixType = as<IRMatrixType>(inst->getDataType());
- auto rowCount = getIntVal(matrixType->getRowCount());
- auto colCount = getIntVal(matrixType->getColumnCount());
+
+ IRIntegerValue rowCount;
+ IRIntegerValue colCount;
+ IRType* elementType;
+
+ // Data type can be either matrix or vector depending on the
+ // legalization requirements
+ auto dataType = inst->getDataType();
+
+ if (auto matrixType = as<IRMatrixType>(dataType))
+ {
+ elementType = matrixType->getElementType();
+ rowCount = getIntVal(matrixType->getRowCount());
+ colCount = getIntVal(matrixType->getColumnCount());
+ }
+ else if (auto arrayType = as<IRArrayType>(dataType))
+ {
+ auto vectorType = as<IRVectorType>(arrayType->getElementType());
+ SLANG_ASSERT(vectorType);
+
+ elementType = vectorType->getElementType();
+ rowCount = getIntVal(arrayType->getElementCount());
+ colCount = getIntVal(vectorType->getElementCount());
+ }
+ else
+ {
+ SLANG_UNEXPECTED("data type for makeMatrix operation is "
+ "expected be either a matrix or array type");
+ }
+
IRBuilder builder(inst);
builder.setInsertBefore(inst);
- auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount);
+ auto rowVectorType = builder.getVectorType(elementType, colCount);
+
List<IRInst*> colElements;
UInt index = 0;
for (IRIntegerValue j = 0; j < rowCount; j++)
@@ -7897,7 +7932,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
ArrayView<IRInst*> operands)
{
IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType());
+ SLANG_ASSERT(elementType);
+
IRBasicType* basicType = as<IRBasicType>(elementType);
+ SLANG_ASSERT(basicType);
SpvOp opCode = _arithmeticOpCodeConvert(op, basicType);
if (opCode == SpvOpUndef)
@@ -7958,6 +7996,52 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands");
}
+ // Helper method to handle composite arithmetic operations for matrices and arrays
+ SpvInst* emitCompositeArithmetic(
+ SpvInstParent* parent,
+ IRInst* inst,
+ IRIntegerValue rowCount,
+ IRIntegerValue colCount,
+ IRType* elementType,
+ IRType* resultType,
+ bool isMatrixType)
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto rowVectorType = builder.getVectorType(elementType, colCount);
+ List<SpvInst*> rows;
+
+ for (IRIntegerValue i = 0; i < rowCount; i++)
+ {
+ List<IRInst*> operands;
+ for (UInt j = 0; j < inst->getOperandCount(); j++)
+ {
+ auto originalOperand = inst->getOperand(j);
+ bool shouldExtract =
+ isMatrixType ? as<IRMatrixType>(originalOperand->getDataType()) != nullptr
+ : as<IRArrayType>(originalOperand->getDataType()) != nullptr;
+
+ if (shouldExtract)
+ {
+ auto operand = builder.emitElementExtract(originalOperand, i);
+ emitLocalInst(parent, operand);
+ operands.add(operand);
+ }
+ else
+ {
+ operands.add(originalOperand);
+ }
+ }
+ rows.add(emitVectorOrScalarArithmetic(
+ parent,
+ nullptr,
+ rowVectorType,
+ inst->getOp(),
+ inst->getOperandCount(),
+ operands.getArrayView()));
+ }
+ return emitCompositeConstruct(parent, inst, resultType, rows);
+ }
SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
{
@@ -7965,36 +8049,38 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
auto rowCount = getIntVal(matrixType->getRowCount());
auto colCount = getIntVal(matrixType->getColumnCount());
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
- auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount);
- List<SpvInst*> rows;
- for (IRIntegerValue i = 0; i < rowCount; i++)
- {
- List<IRInst*> operands;
- for (UInt j = 0; j < inst->getOperandCount(); j++)
- {
- auto originalOperand = inst->getOperand(j);
- if (as<IRMatrixType>(originalOperand->getDataType()))
- {
- auto operand = builder.emitElementExtract(originalOperand, i);
- emitLocalInst(parent, operand);
- operands.add(operand);
- }
- else
- {
- operands.add(originalOperand);
- }
- }
- rows.add(emitVectorOrScalarArithmetic(
- parent,
- nullptr,
- rowVectorType,
- inst->getOp(),
- inst->getOperandCount(),
- operands.getArrayView()));
- }
- return emitCompositeConstruct(parent, inst, inst->getDataType(), rows);
+ return emitCompositeArithmetic(
+ parent,
+ inst,
+ rowCount,
+ colCount,
+ matrixType->getElementType(),
+ inst->getDataType(),
+ true);
+ }
+ else if (const auto arrayType = as<IRArrayType>(inst->getDataType()))
+ {
+ // Only for legalization
+ auto arrayElementType = arrayType->getElementType();
+ SLANG_ASSERT(as<IRVectorType>(arrayElementType));
+
+ auto vectorType = as<IRVectorType>(arrayElementType);
+ auto elementType = vectorType->getElementType();
+ SLANG_ASSERT(
+ as<IRBoolType>(elementType) || as<IRUIntType>(elementType) ||
+ as<IRIntType>(elementType));
+
+ auto rowCount = getIntVal(arrayType->getElementCount());
+ auto colCount = getIntVal(vectorType->getElementCount());
+
+ return emitCompositeArithmetic(
+ parent,
+ inst,
+ rowCount,
+ colCount,
+ elementType,
+ inst->getDataType(),
+ false);
}
Array<IRInst*, 4> operands;