summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2025-05-29 16:36:49 -0700
committerGitHub <noreply@github.com>2025-05-29 16:36:49 -0700
commit984d7f22f8a0909dc870c65bb927094c54f55402 (patch)
treeab255bf44e14f6cbaa09522f90b12464f1c6a339 /source/slang/slang-emit-spirv.cpp
parentf4d7954e088966c2ae8618b1cc17aac4d64ef013 (diff)
Implement MapElement for CoopMat (#7159)
With this PR, MapElement works for the following signatures: - CoopMat<...>::MapElement(functype(...)); - CoopMat<...>::MapElement(capturing-lambda); - CoopMat<...>::MapElement(not-capturing-lambda); - Tuple<CoopMat<...>,...>::MapElement(functype(...)); - Tuple<CoopMat<...>,...>::MapElement(capturing-lambda); - Tuple<CoopMat<...>,...>::MapElement(not-capturing-lambda);
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp50
1 files changed, 50 insertions, 0 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index ba238985b..57ad1a988 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -4276,6 +4276,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_MakeArray:
result = emitConstruct(parent, inst);
break;
+ case kIROp_CoopMatMapElementIFunc:
+ result = emitCoopMatMapElementWithIFunc(parent, as<IRCoopMatMapElementIFunc>(inst));
+ break;
case kIROp_MakeTensorAddressingTensorLayout:
result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType())));
break;
@@ -7698,6 +7701,53 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
}
+ SpvInst* emitCoopMatMapElementWithIFunc(SpvInstParent* parent, IRCoopMatMapElementIFunc* inst)
+ {
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_matrix2"));
+ requireSPIRVCapability(SpvCapabilityCooperativeMatrixPerElementOperationsNV);
+
+ IRInst* matOrTuple = inst->getCoopMat();
+
+ IRInst* mat0 = nullptr;
+
+ UInt tupleCount = 0;
+ IRInst* tuple = as<IRMakeStruct>(matOrTuple);
+ if (tuple)
+ {
+ mat0 = tuple->getOperand(0);
+ tupleCount = tuple->getOperandCount();
+ }
+ else
+ {
+ mat0 = matOrTuple;
+ }
+
+ auto funcCall = inst->getIFuncCall();
+
+ IRInst* ifuncThis = nullptr;
+ if (inst->getOperandCount() > 2)
+ ifuncThis = inst->getIFuncThis();
+
+ return emitInstCustomOperandFunc(
+ parent,
+ inst,
+ SpvOpCooperativeMatrixPerElementOpNV,
+ [&]()
+ {
+ emitOperand(mat0->getDataType());
+ emitOperand(kResultID);
+
+ emitOperand(mat0);
+ emitOperand(funcCall);
+
+ if (ifuncThis)
+ emitOperand(ifuncThis);
+
+ for (UInt i = 1; i < tupleCount; i++)
+ emitOperand(tuple->getOperand(i));
+ });
+ }
+
SpvInst* emitSplat(SpvInstParent* parent, IRInst* inst, IRInst* scalar, IRIntegerValue numElems)
{
const auto scalarTy = as<IRBasicType>(scalar->getDataType());