diff options
| author | Darren Wihandi <65404740+fairywreath@users.noreply.github.com> | 2025-04-15 15:57:45 -0600 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-15 14:57:45 -0700 |
| commit | d0b6a0b1ab49b5958015f31364c5ad73d9cd03eb (patch) | |
| tree | e419bb3c89fa8c389eb0ccbbe8aaa29a1dcd515f /source/slang/slang-emit-spirv.cpp | |
| parent | a6174ff9443507dece534aa193f8c45e8f0ce7db (diff) | |
Add cooperative matrix 1 support (#6565)
* initial wip for spirv
* working tiled example
* clean up store and load
* minor fixes
* fix loadAny name
* add initial tests, including broken/unimplemented intrinsics
* fix subscript
* run tests at 16x16, remove not supported arithmetic tests
* minor fixups on implementation
* rename CoopMatMatrixUse
* Update tests to pass validation layers locally
* Add mat-mul-add test and minor fixes
* Add more tests
* Remove dead code
* Add coopMatLoad function and tests, enforce constexpr for matrix layout
* Use getVectorOrCoopMatrixElementType in place of getVectorElementType
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 51 |
1 files changed, 40 insertions, 11 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index baef62f1c..d07d587e5 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1683,6 +1683,29 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex static_cast<IRIntLit*>(coopVecType->getElementCount())->getValue(), coopVecType); } + case kIROp_CoopMatrixType: + { + requireSPIRVCapability(SpvCapabilityCooperativeMatrixKHR); + ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_cooperative_matrix")); + + IRBuilder builder(m_irModule); + auto coopMatType = static_cast<IRCoopMatrixType*>(inst); + return emitOpTypeCoopMat( + coopMatType, + coopMatType->getElementType(), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getScope())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getRowCount())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getColumnCount())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getMatrixUse())->getValue(), + builder.getIntType())); + } case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); @@ -6264,7 +6287,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto baseTy = base->getDataType(); SLANG_ASSERT( as<IRPointerLikeType>(baseTy) || as<IRArrayType>(baseTy) || as<IRVectorType>(baseTy) || - as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy)); + as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy) || + as<IRCoopMatrixType>(baseTy)); IRBuilder builder(m_irModule); builder.setInsertBefore(inst); @@ -6553,8 +6577,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); if (as<IRBoolType>(fromType)) { @@ -6687,10 +6711,14 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex bool isMatrixCast = false; if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV) || - as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV)) + as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV) || + // Cooperative matrices behave like vectors where arithmetic operations can be performed + // directly without having to loop through the matrix and performing operations on the + // vectors. + as<IRCoopMatrixType>(fromTypeV) || as<IRCoopMatrixType>(toTypeV)) { - fromType = getVectorElementType(fromTypeV); - toType = getVectorElementType(toTypeV); + fromType = getVectorOrCoopMatrixElementType(fromTypeV); + toType = getVectorOrCoopMatrixElementType(toTypeV); } else if (as<IRMatrixType>(fromTypeV) || as<IRMatrixType>(toTypeV)) { @@ -6737,8 +6765,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); + SLANG_ASSERT(isFloatingType(toType)); if (isIntegralType(fromType)) @@ -6781,8 +6810,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); SLANG_ASSERT(isFloatingType(fromType)); if (as<IRBoolType>(toType)) @@ -7085,7 +7114,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex UInt operandCount, ArrayView<IRInst*> operands) { - IRType* elementType = getVectorElementType(operands[0]->getDataType()); + IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); IRBasicType* basicType = as<IRBasicType>(elementType); bool isFloatingPoint = false; bool isBool = false; |
