From 984d7f22f8a0909dc870c65bb927094c54f55402 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Thu, 29 May 2025 16:36:49 -0700 Subject: Implement MapElement for CoopMat (#7159) With this PR, MapElement works for the following signatures: - CoopMat<...>::MapElement(functype(...)); - CoopMat<...>::MapElement(capturing-lambda); - CoopMat<...>::MapElement(not-capturing-lambda); - Tuple,...>::MapElement(functype(...)); - Tuple,...>::MapElement(capturing-lambda); - Tuple,...>::MapElement(not-capturing-lambda); --- source/slang/slang-emit-spirv.cpp | 50 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) (limited to 'source/slang/slang-emit-spirv.cpp') diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index ba238985b..57ad1a988 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4276,6 +4276,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MakeArray: result = emitConstruct(parent, inst); break; + case kIROp_CoopMatMapElementIFunc: + result = emitCoopMatMapElementWithIFunc(parent, as(inst)); + break; case kIROp_MakeTensorAddressingTensorLayout: result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType()))); break; @@ -7698,6 +7701,53 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } } + SpvInst* emitCoopMatMapElementWithIFunc(SpvInstParent* parent, IRCoopMatMapElementIFunc* inst) + { + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_matrix2")); + requireSPIRVCapability(SpvCapabilityCooperativeMatrixPerElementOperationsNV); + + IRInst* matOrTuple = inst->getCoopMat(); + + IRInst* mat0 = nullptr; + + UInt tupleCount = 0; + IRInst* tuple = as(matOrTuple); + if (tuple) + { + mat0 = tuple->getOperand(0); + tupleCount = tuple->getOperandCount(); + } + else + { + mat0 = matOrTuple; + } + + auto funcCall = inst->getIFuncCall(); + + IRInst* ifuncThis = nullptr; + if (inst->getOperandCount() > 2) + ifuncThis = inst->getIFuncThis(); + + return emitInstCustomOperandFunc( + parent, + inst, + SpvOpCooperativeMatrixPerElementOpNV, + [&]() + { + emitOperand(mat0->getDataType()); + emitOperand(kResultID); + + emitOperand(mat0); + emitOperand(funcCall); + + if (ifuncThis) + emitOperand(ifuncThis); + + for (UInt i = 1; i < tupleCount; i++) + emitOperand(tuple->getOperand(i)); + }); + } + SpvInst* emitSplat(SpvInstParent* parent, IRInst* inst, IRInst* scalar, IRIntegerValue numElems) { const auto scalarTy = as(scalar->getDataType()); -- cgit v1.2.3