From 49de1e8f60c698e9d524befacc988fb06574b234 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Thu, 15 May 2025 07:02:38 +0000 Subject: Support tensor addressing (#7060) This commit implements two new types and related Load/Store functions in CoopMat. tensor_addrressing.TensorLayout tensor_addressing.TensorView CoopMat.Load(..., TensorLayout) CoopMat.Load(..., TensorLayout, TensorView) CoopMat.Store(..., TensorLayout) CoopMat.Store(..., TensorLayout, TensorView) CoopMat.Load(..., TensorLayout, TensorView) --- source/slang/slang-emit-spirv.cpp | 100 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) (limited to 'source/slang/slang-emit-spirv.cpp') diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 32d3ba7c3..7d202c7c1 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1854,6 +1854,100 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex static_cast(coopMatType->getMatrixUse())->getValue(), builder.getIntType())); } + case kIROp_TensorAddressingTensorLayoutType: + { + requireSPIRVCapability(SpvCapabilityTensorAddressingNV); + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_tensor_addressing")); + + IRBuilder builder(m_irModule); + auto tensorLayoutType = static_cast(inst); + return emitOpTypeTensorLayout( + tensorLayoutType, + emitIntConstant( + static_cast(tensorLayoutType->getDimension())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast(tensorLayoutType->getClampMode())->getValue(), + builder.getIntType())); + } + case kIROp_TensorAddressingTensorViewType: + { + requireSPIRVCapability(SpvCapabilityTensorAddressingNV); + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_tensor_addressing")); + + IRBuilder builder(m_irModule); + auto tensorViewType = static_cast(inst); + + IRIntegerValue dim = + static_cast(tensorViewType->getDimension())->getValue(); + SpvInst* spvDim = emitIntConstant(dim, builder.getIntType()); + + SpvInst* spvHasDimension = + ensureInst(static_cast(tensorViewType->getHasDimension())); + + SpvInst* spvPermutations[5] = {nullptr, nullptr, nullptr, nullptr, nullptr}; + for (int i = 0; i < (int)dim; i++) + { + spvPermutations[i] = emitIntConstant( + static_cast(tensorViewType->getPermutation(i))->getValue(), + builder.getIntType()); + } + + if (dim == 1) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0]); + } + else if (dim == 2) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0], + spvPermutations[1]); + } + else if (dim == 3) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0], + spvPermutations[1], + spvPermutations[2]); + } + else if (dim == 4) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0], + spvPermutations[1], + spvPermutations[2], + spvPermutations[3]); + } + else if (dim == 5) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0], + spvPermutations[1], + spvPermutations[2], + spvPermutations[3], + spvPermutations[4]); + } + else + { + SLANG_UNEXPECTED("Unsupported tensor dimension"); + } + } case kIROp_MatrixType: { auto matrixType = static_cast(inst); @@ -4070,6 +4164,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MakeArray: result = emitConstruct(parent, inst); break; + case kIROp_MakeTensorAddressingTensorLayout: + result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType()))); + break; + case kIROp_MakeTensorAddressingTensorView: + result = emitOpCreateTensorView(parent, inst, getID(ensureInst(inst->getDataType()))); + break; case kIROp_Select: result = emitInst( parent, -- cgit v1.2.3