summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp50
1 files changed, 50 insertions, 0 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index ba238985b..57ad1a988 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -4276,6 +4276,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_MakeArray:
result = emitConstruct(parent, inst);
break;
+ case kIROp_CoopMatMapElementIFunc:
+ result = emitCoopMatMapElementWithIFunc(parent, as<IRCoopMatMapElementIFunc>(inst));
+ break;
case kIROp_MakeTensorAddressingTensorLayout:
result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType())));
break;
@@ -7698,6 +7701,53 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
}
+ SpvInst* emitCoopMatMapElementWithIFunc(SpvInstParent* parent, IRCoopMatMapElementIFunc* inst)
+ {
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_matrix2"));
+ requireSPIRVCapability(SpvCapabilityCooperativeMatrixPerElementOperationsNV);
+
+ IRInst* matOrTuple = inst->getCoopMat();
+
+ IRInst* mat0 = nullptr;
+
+ UInt tupleCount = 0;
+ IRInst* tuple = as<IRMakeStruct>(matOrTuple);
+ if (tuple)
+ {
+ mat0 = tuple->getOperand(0);
+ tupleCount = tuple->getOperandCount();
+ }
+ else
+ {
+ mat0 = matOrTuple;
+ }
+
+ auto funcCall = inst->getIFuncCall();
+
+ IRInst* ifuncThis = nullptr;
+ if (inst->getOperandCount() > 2)
+ ifuncThis = inst->getIFuncThis();
+
+ return emitInstCustomOperandFunc(
+ parent,
+ inst,
+ SpvOpCooperativeMatrixPerElementOpNV,
+ [&]()
+ {
+ emitOperand(mat0->getDataType());
+ emitOperand(kResultID);
+
+ emitOperand(mat0);
+ emitOperand(funcCall);
+
+ if (ifuncThis)
+ emitOperand(ifuncThis);
+
+ for (UInt i = 1; i < tupleCount; i++)
+ emitOperand(tuple->getOperand(i));
+ });
+ }
+
SpvInst* emitSplat(SpvInstParent* parent, IRInst* inst, IRInst* scalar, IRIntegerValue numElems)
{
const auto scalarTy = as<IRBasicType>(scalar->getDataType());