diff options
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 51 |
1 files changed, 40 insertions, 11 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index baef62f1c..d07d587e5 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1683,6 +1683,29 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex static_cast<IRIntLit*>(coopVecType->getElementCount())->getValue(), coopVecType); } + case kIROp_CoopMatrixType: + { + requireSPIRVCapability(SpvCapabilityCooperativeMatrixKHR); + ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_cooperative_matrix")); + + IRBuilder builder(m_irModule); + auto coopMatType = static_cast<IRCoopMatrixType*>(inst); + return emitOpTypeCoopMat( + coopMatType, + coopMatType->getElementType(), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getScope())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getRowCount())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getColumnCount())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getMatrixUse())->getValue(), + builder.getIntType())); + } case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); @@ -6264,7 +6287,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto baseTy = base->getDataType(); SLANG_ASSERT( as<IRPointerLikeType>(baseTy) || as<IRArrayType>(baseTy) || as<IRVectorType>(baseTy) || - as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy)); + as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy) || + as<IRCoopMatrixType>(baseTy)); IRBuilder builder(m_irModule); builder.setInsertBefore(inst); @@ -6553,8 +6577,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); if (as<IRBoolType>(fromType)) { @@ -6687,10 +6711,14 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex bool isMatrixCast = false; if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV) || - as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV)) + as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV) || + // Cooperative matrices behave like vectors where arithmetic operations can be performed + // directly without having to loop through the matrix and performing operations on the + // vectors. + as<IRCoopMatrixType>(fromTypeV) || as<IRCoopMatrixType>(toTypeV)) { - fromType = getVectorElementType(fromTypeV); - toType = getVectorElementType(toTypeV); + fromType = getVectorOrCoopMatrixElementType(fromTypeV); + toType = getVectorOrCoopMatrixElementType(toTypeV); } else if (as<IRMatrixType>(fromTypeV) || as<IRMatrixType>(toTypeV)) { @@ -6737,8 +6765,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); + SLANG_ASSERT(isFloatingType(toType)); if (isIntegralType(fromType)) @@ -6781,8 +6810,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); SLANG_ASSERT(isFloatingType(fromType)); if (as<IRBoolType>(toType)) @@ -7085,7 +7114,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex UInt operandCount, ArrayView<IRInst*> operands) { - IRType* elementType = getVectorElementType(operands[0]->getDataType()); + IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); IRBasicType* basicType = as<IRBasicType>(elementType); bool isFloatingPoint = false; bool isBool = false; |
