diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2025-05-29 16:36:49 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-29 16:36:49 -0700 |
| commit | 984d7f22f8a0909dc870c65bb927094c54f55402 (patch) | |
| tree | ab255bf44e14f6cbaa09522f90b12464f1c6a339 /source/slang/slang-emit-spirv.cpp | |
| parent | f4d7954e088966c2ae8618b1cc17aac4d64ef013 (diff) | |
Implement MapElement for CoopMat (#7159)
With this PR, MapElement works for the following signatures:
- CoopMat<...>::MapElement(functype(...));
- CoopMat<...>::MapElement(capturing-lambda);
- CoopMat<...>::MapElement(not-capturing-lambda);
- Tuple<CoopMat<...>,...>::MapElement(functype(...));
- Tuple<CoopMat<...>,...>::MapElement(capturing-lambda);
- Tuple<CoopMat<...>,...>::MapElement(not-capturing-lambda);
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index ba238985b..57ad1a988 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4276,6 +4276,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MakeArray: result = emitConstruct(parent, inst); break; + case kIROp_CoopMatMapElementIFunc: + result = emitCoopMatMapElementWithIFunc(parent, as<IRCoopMatMapElementIFunc>(inst)); + break; case kIROp_MakeTensorAddressingTensorLayout: result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType()))); break; @@ -7698,6 +7701,53 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } } + SpvInst* emitCoopMatMapElementWithIFunc(SpvInstParent* parent, IRCoopMatMapElementIFunc* inst) + { + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_matrix2")); + requireSPIRVCapability(SpvCapabilityCooperativeMatrixPerElementOperationsNV); + + IRInst* matOrTuple = inst->getCoopMat(); + + IRInst* mat0 = nullptr; + + UInt tupleCount = 0; + IRInst* tuple = as<IRMakeStruct>(matOrTuple); + if (tuple) + { + mat0 = tuple->getOperand(0); + tupleCount = tuple->getOperandCount(); + } + else + { + mat0 = matOrTuple; + } + + auto funcCall = inst->getIFuncCall(); + + IRInst* ifuncThis = nullptr; + if (inst->getOperandCount() > 2) + ifuncThis = inst->getIFuncThis(); + + return emitInstCustomOperandFunc( + parent, + inst, + SpvOpCooperativeMatrixPerElementOpNV, + [&]() + { + emitOperand(mat0->getDataType()); + emitOperand(kResultID); + + emitOperand(mat0); + emitOperand(funcCall); + + if (ifuncThis) + emitOperand(ifuncThis); + + for (UInt i = 1; i < tupleCount; i++) + emitOperand(tuple->getOperand(i)); + }); + } + SpvInst* emitSplat(SpvInstParent* parent, IRInst* inst, IRInst* scalar, IRIntegerValue numElems) { const auto scalarTy = as<IRBasicType>(scalar->getDataType()); |
