diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2025-05-15 07:02:38 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-15 00:02:38 -0700 |
| commit | 49de1e8f60c698e9d524befacc988fb06574b234 (patch) | |
| tree | cc1006b24532b0f98a2f8af49010925e9d992f66 /source/slang/slang-emit-spirv.cpp | |
| parent | dd275dd952afc1b0d1a156d786c28620a48863e1 (diff) | |
Support tensor addressing (#7060)
This commit implements two new types and related Load/Store functions in CoopMat.
tensor_addrressing.TensorLayout
tensor_addressing.TensorView
CoopMat.Load(..., TensorLayout)
CoopMat.Load(..., TensorLayout, TensorView)
CoopMat.Store(..., TensorLayout)
CoopMat.Store(..., TensorLayout, TensorView)
CoopMat.Load(..., TensorLayout, TensorView)
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 100 |
1 files changed, 100 insertions, 0 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 32d3ba7c3..7d202c7c1 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1854,6 +1854,100 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex static_cast<IRIntLit*>(coopMatType->getMatrixUse())->getValue(), builder.getIntType())); } + case kIROp_TensorAddressingTensorLayoutType: + { + requireSPIRVCapability(SpvCapabilityTensorAddressingNV); + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_tensor_addressing")); + + IRBuilder builder(m_irModule); + auto tensorLayoutType = static_cast<IRTensorAddressingTensorLayoutType*>(inst); + return emitOpTypeTensorLayout( + tensorLayoutType, + emitIntConstant( + static_cast<IRIntLit*>(tensorLayoutType->getDimension())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(tensorLayoutType->getClampMode())->getValue(), + builder.getIntType())); + } + case kIROp_TensorAddressingTensorViewType: + { + requireSPIRVCapability(SpvCapabilityTensorAddressingNV); + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_tensor_addressing")); + + IRBuilder builder(m_irModule); + auto tensorViewType = static_cast<IRTensorAddressingTensorViewType*>(inst); + + IRIntegerValue dim = + static_cast<IRIntLit*>(tensorViewType->getDimension())->getValue(); + SpvInst* spvDim = emitIntConstant(dim, builder.getIntType()); + + SpvInst* spvHasDimension = + ensureInst(static_cast<IRBoolLit*>(tensorViewType->getHasDimension())); + + SpvInst* spvPermutations[5] = {nullptr, nullptr, nullptr, nullptr, nullptr}; + for (int i = 0; i < (int)dim; i++) + { + spvPermutations[i] = emitIntConstant( + static_cast<IRIntLit*>(tensorViewType->getPermutation(i))->getValue(), + builder.getIntType()); + } + + if (dim == 1) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0]); + } + else if (dim == 2) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0], + spvPermutations[1]); + } + else if (dim == 3) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0], + spvPermutations[1], + spvPermutations[2]); + } + else if (dim == 4) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0], + spvPermutations[1], + spvPermutations[2], + spvPermutations[3]); + } + else if (dim == 5) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0], + spvPermutations[1], + spvPermutations[2], + spvPermutations[3], + spvPermutations[4]); + } + else + { + SLANG_UNEXPECTED("Unsupported tensor dimension"); + } + } case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); @@ -4070,6 +4164,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MakeArray: result = emitConstruct(parent, inst); break; + case kIROp_MakeTensorAddressingTensorLayout: + result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType()))); + break; + case kIROp_MakeTensorAddressingTensorView: + result = emitOpCreateTensorView(parent, inst, getID(ensureInst(inst->getDataType()))); + break; case kIROp_Select: result = emitInst( parent, |
