summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2025-05-15 07:02:38 +0000
committerGitHub <noreply@github.com>2025-05-15 00:02:38 -0700
commit49de1e8f60c698e9d524befacc988fb06574b234 (patch)
treecc1006b24532b0f98a2f8af49010925e9d992f66 /source/slang/slang-emit-spirv.cpp
parentdd275dd952afc1b0d1a156d786c28620a48863e1 (diff)
Support tensor addressing (#7060)
This commit implements two new types and related Load/Store functions in CoopMat. tensor_addrressing.TensorLayout tensor_addressing.TensorView CoopMat.Load(..., TensorLayout) CoopMat.Load(..., TensorLayout, TensorView) CoopMat.Store(..., TensorLayout) CoopMat.Store(..., TensorLayout, TensorView) CoopMat.Load(..., TensorLayout, TensorView)
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp100
1 files changed, 100 insertions, 0 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 32d3ba7c3..7d202c7c1 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1854,6 +1854,100 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
static_cast<IRIntLit*>(coopMatType->getMatrixUse())->getValue(),
builder.getIntType()));
}
+ case kIROp_TensorAddressingTensorLayoutType:
+ {
+ requireSPIRVCapability(SpvCapabilityTensorAddressingNV);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_tensor_addressing"));
+
+ IRBuilder builder(m_irModule);
+ auto tensorLayoutType = static_cast<IRTensorAddressingTensorLayoutType*>(inst);
+ return emitOpTypeTensorLayout(
+ tensorLayoutType,
+ emitIntConstant(
+ static_cast<IRIntLit*>(tensorLayoutType->getDimension())->getValue(),
+ builder.getIntType()),
+ emitIntConstant(
+ static_cast<IRIntLit*>(tensorLayoutType->getClampMode())->getValue(),
+ builder.getIntType()));
+ }
+ case kIROp_TensorAddressingTensorViewType:
+ {
+ requireSPIRVCapability(SpvCapabilityTensorAddressingNV);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_tensor_addressing"));
+
+ IRBuilder builder(m_irModule);
+ auto tensorViewType = static_cast<IRTensorAddressingTensorViewType*>(inst);
+
+ IRIntegerValue dim =
+ static_cast<IRIntLit*>(tensorViewType->getDimension())->getValue();
+ SpvInst* spvDim = emitIntConstant(dim, builder.getIntType());
+
+ SpvInst* spvHasDimension =
+ ensureInst(static_cast<IRBoolLit*>(tensorViewType->getHasDimension()));
+
+ SpvInst* spvPermutations[5] = {nullptr, nullptr, nullptr, nullptr, nullptr};
+ for (int i = 0; i < (int)dim; i++)
+ {
+ spvPermutations[i] = emitIntConstant(
+ static_cast<IRIntLit*>(tensorViewType->getPermutation(i))->getValue(),
+ builder.getIntType());
+ }
+
+ if (dim == 1)
+ {
+ return emitOpTypeTensorView(
+ tensorViewType,
+ spvDim,
+ spvHasDimension,
+ spvPermutations[0]);
+ }
+ else if (dim == 2)
+ {
+ return emitOpTypeTensorView(
+ tensorViewType,
+ spvDim,
+ spvHasDimension,
+ spvPermutations[0],
+ spvPermutations[1]);
+ }
+ else if (dim == 3)
+ {
+ return emitOpTypeTensorView(
+ tensorViewType,
+ spvDim,
+ spvHasDimension,
+ spvPermutations[0],
+ spvPermutations[1],
+ spvPermutations[2]);
+ }
+ else if (dim == 4)
+ {
+ return emitOpTypeTensorView(
+ tensorViewType,
+ spvDim,
+ spvHasDimension,
+ spvPermutations[0],
+ spvPermutations[1],
+ spvPermutations[2],
+ spvPermutations[3]);
+ }
+ else if (dim == 5)
+ {
+ return emitOpTypeTensorView(
+ tensorViewType,
+ spvDim,
+ spvHasDimension,
+ spvPermutations[0],
+ spvPermutations[1],
+ spvPermutations[2],
+ spvPermutations[3],
+ spvPermutations[4]);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unsupported tensor dimension");
+ }
+ }
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(inst);
@@ -4070,6 +4164,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_MakeArray:
result = emitConstruct(parent, inst);
break;
+ case kIROp_MakeTensorAddressingTensorLayout:
+ result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType())));
+ break;
+ case kIROp_MakeTensorAddressingTensorView:
+ result = emitOpCreateTensorView(parent, inst, getID(ensureInst(inst->getDataType())));
+ break;
case kIROp_Select:
result = emitInst(
parent,