summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp51
1 files changed, 40 insertions, 11 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index baef62f1c..d07d587e5 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1683,6 +1683,29 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
static_cast<IRIntLit*>(coopVecType->getElementCount())->getValue(),
coopVecType);
}
+ case kIROp_CoopMatrixType:
+ {
+ requireSPIRVCapability(SpvCapabilityCooperativeMatrixKHR);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_cooperative_matrix"));
+
+ IRBuilder builder(m_irModule);
+ auto coopMatType = static_cast<IRCoopMatrixType*>(inst);
+ return emitOpTypeCoopMat(
+ coopMatType,
+ coopMatType->getElementType(),
+ emitIntConstant(
+ static_cast<IRIntLit*>(coopMatType->getScope())->getValue(),
+ builder.getIntType()),
+ emitIntConstant(
+ static_cast<IRIntLit*>(coopMatType->getRowCount())->getValue(),
+ builder.getIntType()),
+ emitIntConstant(
+ static_cast<IRIntLit*>(coopMatType->getColumnCount())->getValue(),
+ builder.getIntType()),
+ emitIntConstant(
+ static_cast<IRIntLit*>(coopMatType->getMatrixUse())->getValue(),
+ builder.getIntType()));
+ }
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(inst);
@@ -6264,7 +6287,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto baseTy = base->getDataType();
SLANG_ASSERT(
as<IRPointerLikeType>(baseTy) || as<IRArrayType>(baseTy) || as<IRVectorType>(baseTy) ||
- as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy));
+ as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy) ||
+ as<IRCoopMatrixType>(baseTy));
IRBuilder builder(m_irModule);
builder.setInsertBefore(inst);
@@ -6553,8 +6577,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto fromTypeV = inst->getOperand(0)->getDataType();
const auto toTypeV = inst->getDataType();
SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
- const auto fromType = getVectorElementType(fromTypeV);
- const auto toType = getVectorElementType(toTypeV);
+ const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV);
+ const auto toType = getVectorOrCoopMatrixElementType(toTypeV);
if (as<IRBoolType>(fromType))
{
@@ -6687,10 +6711,14 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
bool isMatrixCast = false;
if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV) ||
- as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV))
+ as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV) ||
+ // Cooperative matrices behave like vectors where arithmetic operations can be performed
+ // directly without having to loop through the matrix and performing operations on the
+ // vectors.
+ as<IRCoopMatrixType>(fromTypeV) || as<IRCoopMatrixType>(toTypeV))
{
- fromType = getVectorElementType(fromTypeV);
- toType = getVectorElementType(toTypeV);
+ fromType = getVectorOrCoopMatrixElementType(fromTypeV);
+ toType = getVectorOrCoopMatrixElementType(toTypeV);
}
else if (as<IRMatrixType>(fromTypeV) || as<IRMatrixType>(toTypeV))
{
@@ -6737,8 +6765,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto fromTypeV = inst->getOperand(0)->getDataType();
const auto toTypeV = inst->getDataType();
SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
- const auto fromType = getVectorElementType(fromTypeV);
- const auto toType = getVectorElementType(toTypeV);
+ const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV);
+ const auto toType = getVectorOrCoopMatrixElementType(toTypeV);
+
SLANG_ASSERT(isFloatingType(toType));
if (isIntegralType(fromType))
@@ -6781,8 +6810,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto fromTypeV = inst->getOperand(0)->getDataType();
const auto toTypeV = inst->getDataType();
SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
- const auto fromType = getVectorElementType(fromTypeV);
- const auto toType = getVectorElementType(toTypeV);
+ const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV);
+ const auto toType = getVectorOrCoopMatrixElementType(toTypeV);
SLANG_ASSERT(isFloatingType(fromType));
if (as<IRBoolType>(toType))
@@ -7085,7 +7114,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
UInt operandCount,
ArrayView<IRInst*> operands)
{
- IRType* elementType = getVectorElementType(operands[0]->getDataType());
+ IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType());
IRBasicType* basicType = as<IRBasicType>(elementType);
bool isFloatingPoint = false;
bool isBool = false;