diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2025-05-15 07:02:38 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-15 00:02:38 -0700 |
| commit | 49de1e8f60c698e9d524befacc988fb06574b234 (patch) | |
| tree | cc1006b24532b0f98a2f8af49010925e9d992f66 /source | |
| parent | dd275dd952afc1b0d1a156d786c28620a48863e1 (diff) | |
Support tensor addressing (#7060)
This commit implements two new types and related Load/Store functions in CoopMat.
tensor_addressing.TensorLayout
tensor_addressing.TensorView
CoopMat.Load(..., TensorLayout)
CoopMat.Load(..., TensorLayout, TensorView)
CoopMat.Store(..., TensorLayout)
CoopMat.Store(..., TensorLayout, TensorView)
CoopMat.Load(..., TensorLayout, TensorView)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 625 | ||||
| -rw-r--r-- | source/slang/slang-capabilities.capdef | 19 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv-ops.h | 50 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 100 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 17 |
7 files changed, 777 insertions, 39 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 829d5ce97..d7a65c031 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -22539,6 +22539,243 @@ enum CoopMatClampMode }; +${{{{ +// SPIRV described that the max value for `Dim` is 5. +// +// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpTypeTensorLayoutNV +// OpTypeTensorLayoutNV: +// Dim is the number of dimensions in the tensor layout, and must be a constant +// instruction with scalar 32-bit integer type. The value must be greater than +// zero and less than or equal to 5. +// +// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpTypeTensorViewNV +// OpTypeTensorViewNV: +// Dim is the number of dimensions in the tensor view, and must be a constant +// instruction with scalar 32-bit integer type. The value must be greater than +// zero and less than or equal to 5. +// +const int kMaxCoopMatTensorDimension = 5; +}}}} + +// +// TensorLayout +// + +__intrinsic_type($(kIROp_TensorAddressingTensorLayoutType)) +[require(tensor_addressing)] +__generic< + let Dim : uint32_t, + let ClampMode : CoopMatClampMode = CoopMatClampMode.Undefined +> +struct TensorLayout +{ + __intrinsic_op($(kIROp_MakeTensorAddressingTensorLayout)) + __init(); +}; + + +${{{{ + for (int iDim = 1; iDim < kMaxCoopMatTensorDimension; ++iDim) + { + StringBuilder dimParams; + StringBuilder dimAsms; + StringBuilder strideParams; + StringBuilder strideAsms; + StringBuilder sliceParams; + StringBuilder sliceAsms; + StringBuilder blockSizeParams; + StringBuilder blockSizeAsms; + for (int j = 1; j < iDim; ++j) + { + dimParams << ", uint32_t dim" << j; + dimAsms << " $dim" << j; + strideParams << ", uint32_t stride" << j; + strideAsms << " $stride" << j; + sliceParams << ", uint32_t offset" << j << ", uint32_t span" << j; + sliceAsms << " $offset" << j << " $span" << j; + blockSizeParams << ", uint32_t blockSize" << j; + 
blockSizeAsms << " $blockSize" << j; + } +}}}} + + +extension< + let ClampMode : CoopMatClampMode +> TensorLayout<$(iDim), ClampMode> +{ + [require(tensor_addressing)] + This Dimension(uint32_t dim0 $(dimParams)) + { + __target_switch + { + case spirv: + return spirv_asm + { + result:$$This = OpTensorLayoutSetDimensionNV $this $dim0 $(dimAsms) + }; + } + } + + [require(tensor_addressing)] + This Stride(uint32_t stride0 $(strideParams)) + { + __target_switch + { + case spirv: + return spirv_asm + { + result:$$This = OpTensorLayoutSetStrideNV $this $stride0 $(strideAsms); + }; + } + } + + [require(tensor_addressing)] + This Slice(uint32_t offset0, uint32_t span0 $(sliceParams)) + { + __target_switch + { + case spirv: + return spirv_asm + { + result:$$This = OpTensorLayoutSliceNV $this $offset0 $span0 $(sliceAsms); + }; + } + } + + [require(tensor_addressing)] + This ClampValue(CoopMatClampMode clampMode) + { + __target_switch + { + case spirv: + return spirv_asm + { + result:$$This = OpTensorLayoutSetClampValueNV $this $clampMode; + }; + } + } + + [require(tensor_addressing)] + This BlockSize(uint32_t blockSize0 $(blockSizeParams)) + { + __target_switch + { + case spirv: + return spirv_asm + { + result:$$This = OpTensorLayoutSetBlockSizeNV $this $blockSize0 $(blockSizeAsms); + }; + } + } +}; + +${{{{ + } // iDim +}}}} + +// +// TensorView +// + +${{{{ + StringBuilder tensorViewStruct; + for (int j = 0; j < kMaxCoopMatTensorDimension; ++j) + { + // Assigning the max value as a default value, + // because the max value is an invalid value and it allows us to check if the value + // is explicitly set by the user or not. 
+ tensorViewStruct << ", let p" << j << " : uint32_t = 0xff"; + } +}}}} + +__intrinsic_type($(kIROp_TensorAddressingTensorViewType)) +__generic< + let Dim : uint32_t, + let HasDimensions : bool + $(tensorViewStruct) +> +struct TensorView +{ + __intrinsic_op($(kIROp_MakeTensorAddressingTensorView)) + __init(); +}; + +${{{{ + for (int iDim = 1; iDim < kMaxCoopMatTensorDimension; ++iDim) + { + StringBuilder tensorViewTypes; + StringBuilder tensorViewExtensions; + StringBuilder dimParams; + StringBuilder dimAsms; + StringBuilder strideParams; + StringBuilder strideAsms; + for (int j = 1; j < iDim; ++j) + { + tensorViewTypes << ", let Dim" << j << " : uint32_t"; + tensorViewExtensions << ", Dim" << j; + dimParams << ", uint32_t dim" << j; + dimAsms << " $dim" << j; + strideParams << ", uint32_t stride" << j; + strideAsms << " $stride" << j; + } + for (int j = iDim; j < kMaxCoopMatTensorDimension; ++j) + { + tensorViewExtensions << ", 0xff"; + } +}}}} + +[require(tensor_addressing)] +extension< + let HasDimensions : bool, + let Dim0 : uint32_t + $(tensorViewTypes) +> TensorView<$(iDim), HasDimensions, Dim0 $(tensorViewExtensions)> +{ + [require(tensor_addressing)] + This Dimension(uint32_t dim0 $(dimParams)) + { + __target_switch + { + case spirv: + return spirv_asm + { + result:$$This = OpTensorViewSetDimensionNV $this $dim0 $(dimAsms); + }; + } + } + + [require(tensor_addressing)] + This Stride(uint32_t stride0 $(strideParams)) + { + __target_switch + { + case spirv: + return spirv_asm + { + result:$$This = OpTensorViewSetStrideNV $this $stride0 $(strideAsms); + }; + } + } + + [require(tensor_addressing)] + This Clip(uint clipRowOffset, uint clipRowSpan, uint clipColOffset, uint clipColSpan) + { + __target_switch + { + case spirv: + return spirv_asm + { + result:$$This = OpTensorViewSetClipNV $this $clipRowOffset $clipRowSpan $clipColOffset $clipColSpan; + }; + } + } +}; + +${{{{ + } // iDim +}}}} + + // // Cooperative Matrix type // @@ -22788,21 +23025,23 @@ struct 
CoopMat [require(cooperative_matrix)] void Store< let matrixLayout : CoopMatMatrixLayout - >(RWStructuredBuffer<T> buffer, uint element, uint stride) + >(RWByteAddressBuffer buffer, uint element, uint stride) { - __Store(buffer, element, stride, matrixLayout); + __store<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride); } [require(cooperative_matrix)] void Store< let matrixLayout : CoopMatMatrixLayout - >(RWByteAddressBuffer buffer, uint element, uint stride) + >(RWStructuredBuffer<T> buffer, uint element, uint stride) { - __Store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout); + __store<matrixLayout>(buffer, element, stride); } [require(cooperative_matrix)] - internal void __Store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + internal void __store< + let matrixLayout : CoopMatMatrixLayout + >(RWStructuredBuffer<T> buffer, uint element, uint stride) { let zero = 0; let alignment = 16U; @@ -22883,20 +23122,19 @@ struct CoopMat // Load // - [__NoSideEffect] - [require(cooperative_matrix)] - static This Load< - let matrixLayout : CoopMatMatrixLayout - >(ByteAddressBuffer buffer, uint element, uint stride) +${{{{ + for (const char* RW : { "", "RW" }) { - return Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride); - } +}}}} [__NoSideEffect] [require(cooperative_matrix)] static This Load< let matrixLayout : CoopMatMatrixLayout - >(RWByteAddressBuffer buffer, uint element, uint stride) + >( + $(RW)ByteAddressBuffer buffer, + uint element, + uint stride) { return Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride); } @@ -22905,7 +23143,10 @@ struct CoopMat [require(cooperative_matrix)] static This Load< let matrixLayout : CoopMatMatrixLayout - >(StructuredBuffer<T> buffer, uint element, uint stride) + >( + $(RW)StructuredBuffer<T> buffer, + uint element, + uint stride) { let zero = 0; let alignment = 16U; @@ 
-22917,21 +23158,9 @@ struct CoopMat }; } - [__NoSideEffect] - [require(cooperative_matrix)] - static This Load< - let matrixLayout : CoopMatMatrixLayout - >(RWStructuredBuffer<T> buffer, uint element, uint stride) - { - let zero = 0; - let alignment = 16U; - return spirv_asm - { - %storagePointerType = OpTypePointer StorageBuffer $$T; - %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; - result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; - }; - } +${{{{ + } // RW +}}}} [ForceInline] [__NoSideEffect] @@ -23076,6 +23305,340 @@ struct CoopMat return true; } + // + // Load with TensorLayout and TensorView + // + +${{{{ + StringBuilder tensorViewTypes; + StringBuilder tensorViewParams; + for (int j = 0; j < kMaxCoopMatTensorDimension; ++j) + { + tensorViewTypes << ", let p" << j << " : uint32_t = " << kMaxCoopMatTensorDimension; + tensorViewParams << ", p" << j; + } + + for (const char* RW : { "", "RW" }) + { +}}}} + + [require(cooperative_matrix_tensor_addressing)] + static This Load< + let Dim : uint32_t, + let ClampMode : CoopMatClampMode + >( + $(RW)ByteAddressBuffer buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout) + { + return __loadLayout<Dim, ClampMode>(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout); + } + + [require(cooperative_matrix_tensor_addressing)] + static This Load< + let Dim : uint32_t, + let ClampMode : CoopMatClampMode + >( + $(RW)StructuredBuffer<T> buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout) + { + return __loadLayout<Dim, ClampMode>(buffer, element, tensorLayout); + } + + [require(cooperative_matrix_tensor_addressing)] + static This __loadLayout< + let Dim : uint32_t, + let ClampMode : CoopMatClampMode + >( + $(RW)StructuredBuffer<T> buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout) + { + let zero = 0; + let alignment = 16U; + + // 
https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixLoadTensorNV + This ret; + return spirv_asm + { + OpCapability CooperativeMatrixTensorAddressingNV; + OpExtension "SPV_NV_cooperative_matrix2"; + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + result:$$This = OpCooperativeMatrixLoadTensorNV %pointer $ret $tensorLayout Aligned !alignment None; + }; + } + + [require(cooperative_matrix_tensor_addressing)] + static This Load< + let Dim : uint32_t, + let ClampMode : CoopMatClampMode, + let DimView : uint32_t, + let HasDimensions : bool + $(tensorViewTypes) + >( + $(RW)ByteAddressBuffer buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout, + TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView) + { + return __loadView<Dim, ClampMode, DimView, HasDimensions $(tensorViewParams)>(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout, tensorView); + } + + [require(cooperative_matrix_tensor_addressing)] + static This Load< + let Dim : uint32_t, + let ClampMode : CoopMatClampMode, + let DimView : uint32_t, + let HasDimensions : bool + $(tensorViewTypes) + >( + $(RW)StructuredBuffer<T> buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout, + TensorView<DimView, HasDimensions $(tensorViewParams) > tensorView) + { + return __loadView<Dim, ClampMode, DimView, HasDimensions $(tensorViewParams)>(buffer, element, tensorLayout, tensorView); + } + + [require(cooperative_matrix_tensor_addressing)] + static This __loadView< + let Dim : uint32_t, + let ClampMode : CoopMatClampMode, + let DimView : uint32_t, + let HasDimensions : bool + $(tensorViewTypes) + >( + $(RW)StructuredBuffer<T> buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout, + TensorView<DimView, HasDimensions $(tensorViewParams) > tensorView) + { + let zero = 0; + let alignment = 16U; + + // 
https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixLoadTensorNV + This ret; + return spirv_asm + { + OpCapability CooperativeMatrixTensorAddressingNV; + OpExtension "SPV_NV_cooperative_matrix2"; + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + result:$$This = OpCooperativeMatrixLoadTensorNV %pointer $ret $tensorLayout Aligned !alignment TensorView $tensorView; + }; + } + + [require(cooperative_matrix_block_load)] + static This Load< + U, + let Dim : uint32_t, + let ClampMode : CoopMatClampMode + >( + $(RW)ByteAddressBuffer buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout, + functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc) + { + return __loadLayoutDecode<U, Dim, ClampMode>(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout, decodeFunc); + } + + [require(cooperative_matrix_block_load)] + static This Load< + U, + let Dim : uint32_t, + let ClampMode : CoopMatClampMode + >( + $(RW)StructuredBuffer<T> buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout, + functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc) + { + return __loadLayoutDecode<U, Dim, ClampMode>(buffer, element, tensorLayout, decodeFunc); + } + + [require(cooperative_matrix_block_load)] + static This __loadLayoutDecode< + U, + let Dim : uint32_t, + let ClampMode : CoopMatClampMode + >( + $(RW)StructuredBuffer<T> buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout, + functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc) + { + let zero = 0; + let alignment = 16U; + + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixLoadTensorNV + This ret; + return spirv_asm + { + OpCapability CooperativeMatrixBlockLoadsNV; + OpExtension "SPV_NV_cooperative_matrix2"; + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + 
result:$$This = OpCooperativeMatrixLoadTensorNV %pointer $ret $tensorLayout Aligned !alignment DecodeFunc $decodeFunc; + }; + } + + [require(cooperative_matrix_block_load)] + static This Load< + U, + let Dim : uint32_t, + let ClampMode : CoopMatClampMode, + let DimView : uint32_t, + let HasDimensions : bool + $(tensorViewTypes) + >( + $(RW)ByteAddressBuffer buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout, + TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView, + functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc) + { + return __loadViewDecode<U, Dim, ClampMode, DimView, HasDimensions $(tensorViewParams)>(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout, tensorView, decodeFunc); + } + + [require(cooperative_matrix_block_load)] + static This Load< + U, + let Dim : uint32_t, + let ClampMode : CoopMatClampMode, + let DimView : uint32_t, + let HasDimensions : bool + $(tensorViewTypes) + >( + $(RW)StructuredBuffer<T> buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout, + TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView, + functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc) + { + return __loadViewDecode<U, Dim, ClampMode, DimView, HasDimensions $(tensorViewParams)>(buffer, element, tensorLayout, tensorView, decodeFunc); + } + + [require(cooperative_matrix_block_load)] + static This __loadViewDecode< + U, + let Dim : uint32_t, + let ClampMode : CoopMatClampMode, + let DimView : uint32_t, + let HasDimensions : bool + $(tensorViewTypes) + >( + $(RW)StructuredBuffer<T> buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout, + TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView, + functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc) + { + let zero = 0; + let alignment = 16U; + + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixLoadTensorNV + This ret; + return spirv_asm + { + OpCapability 
CooperativeMatrixTensorAddressingNV; + OpCapability CooperativeMatrixBlockLoadsNV; + OpExtension "SPV_NV_cooperative_matrix2"; + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + result:$$This = OpCooperativeMatrixLoadTensorNV %pointer $ret $tensorLayout Aligned !alignment TensorView|DecodeFunc $tensorView $decodeFunc; + }; + } + +${{{{ + } // RW +}}}} + + [require(cooperative_matrix_tensor_addressing)] + void Store< + let Dim : uint32_t, + let ClampMode : CoopMatClampMode + >( + RWByteAddressBuffer buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout) + { + Store(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout); + } + + [require(cooperative_matrix_tensor_addressing)] + void Store< + let Dim : uint32_t, + let ClampMode : CoopMatClampMode + >( + RWStructuredBuffer<T> buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout) + { + let zero = 0; + let alignment = 16U; + + __target_switch + { + case spirv: + spirv_asm + { + OpCapability CooperativeMatrixTensorAddressingNV; + OpExtension "SPV_NV_cooperative_matrix2"; + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + OpCooperativeMatrixStoreTensorNV %pointer $this $tensorLayout Aligned !alignment None; + }; + } + } + + [require(cooperative_matrix_tensor_addressing)] + void Store< + let Dim : uint32_t, + let ClampMode : CoopMatClampMode, + let DimView : uint32_t, + let HasDimensions : bool + $(tensorViewTypes) + >( + RWByteAddressBuffer buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout, + TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView) + { + Store(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout, tensorView); + } + + [require(cooperative_matrix_tensor_addressing)] + void Store< + let Dim : uint32_t, + let ClampMode : CoopMatClampMode, + let DimView : uint32_t, + let 
HasDimensions : bool + $(tensorViewTypes) + >( + RWStructuredBuffer<T> buffer, + uint element, + TensorLayout<Dim, ClampMode> tensorLayout, + TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView) + { + let zero = 0; + let alignment = 16U; + + __target_switch + { + case spirv: + spirv_asm + { + OpCapability CooperativeMatrixTensorAddressingNV; + OpExtension "SPV_NV_cooperative_matrix2"; + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + OpCooperativeMatrixStoreTensorNV %pointer $this $tensorLayout Aligned !alignment TensorView $tensorView; + }; + } + } + } // struct CoopMat @@ -23098,7 +23661,7 @@ CoopMat<T, S, M, N, R> coopMatLoad< uint element, uint stride) { - return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride); + return CoopMat<T, S, M, N, R>.Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride); } [ForceInline] @@ -23115,7 +23678,7 @@ CoopMat<T, S, M, N, R> coopMatLoad< uint element, uint stride) { - return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride); + return CoopMat<T, S, M, N, R>.Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride); } [ForceInline] diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index b909fa0f9..937584908 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -567,25 +567,25 @@ def SPV_GOOGLE_user_type : _spirv_1_0; /// [EXT] def SPV_EXT_replicated_composites : _spirv_1_0; -/// Represents the SPIR-V extension for SPV_NV_cooperative_vector. +/// Represents the SPIR-V extension for SPV_KHR_vulkan_memory_model. /// [EXT] -def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites; +def SPV_KHR_vulkan_memory_model : _spirv_1_3; -/// Represents the SPIR-V extension for SPV_KHR_cooperative_matrix. 
+/// Represents the SPIR-V extension for SPV_NV_cooperative_vector. /// [EXT] -def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer; +def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites + SPV_KHR_vulkan_memory_model; -/// Represents the SPIR-V extension for SPV_NV_cooperative_matrix2. +/// Represents the SPIR-V extension for SPV_KHR_cooperative_matrix. /// [EXT] -def SPV_NV_cooperative_matrix2 : _spirv_1_6 + SPV_KHR_cooperative_matrix; +def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer + SPV_KHR_vulkan_memory_model; /// Represents the SPIR-V extension for SPV_NV_tensor_addressing. /// [EXT] def SPV_NV_tensor_addressing : _spirv_1_6; -/// Represents the SPIR-V extension for SPV_KHR_vulkan_memory_model. +/// Represents the SPIR-V extension for SPV_NV_cooperative_matrix2. /// [EXT] -def SPV_KHR_vulkan_memory_model : _spirv_1_3; +def SPV_NV_cooperative_matrix2 : SPV_NV_tensor_addressing + SPV_KHR_cooperative_matrix; // SPIRV Capabilities. 
@@ -1183,6 +1183,9 @@ alias cooperative_matrix_block_load = spvCooperativeMatrixBlockLoadsNV; /// Capabilities needed to use tensor addressing /// [Compound] alias tensor_addressing = spvTensorAddressingNV; +/// Capabilities needed to use tensor addressing +/// [Compound] +alias cooperative_matrix_2 = spvCooperativeMatrixKHR + spvCooperativeMatrixReductionsNV + spvCooperativeMatrixConversionsNV + spvCooperativeMatrixPerElementOperationsNV + spvCooperativeMatrixTensorAddressingNV + spvCooperativeMatrixBlockLoadsNV + spvTensorAddressingNV; // Non-internal shader stages // diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index b716e470d..81ad4bddf 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -174,6 +174,56 @@ SpvInst* emitOpTypeCoopMat( matrixUse); } +// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpTypeTensorLayoutNV +template<typename T1, typename T2> +SpvInst* emitOpTypeTensorLayout(IRInst* inst, const T1& dim, const T2& clampMode) +{ + static_assert(isSingular<T1>); + return emitInstMemoized( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + inst, + SpvOpTypeTensorLayoutNV, + kResultID, + dim, + clampMode); +} + +// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpTypeTensorViewNV +template<typename T1, typename T2, typename... TPerms> +SpvInst* emitOpTypeTensorView( + IRInst* inst, + const T1& dim, + const T2& hasDimensions, + const TPerms&... 
perms) +{ + static_assert(isSingular<T1>); + static_assert(isSingular<T2>); + return emitInstMemoized( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + inst, + SpvOpTypeTensorViewNV, + kResultID, + dim, + hasDimensions, + perms...); +} + +// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpCreateTensorLayoutNV +template<typename T1> +SpvInst* emitOpCreateTensorLayout(SpvInstParent* parent, IRInst* inst, const T1& idResultType) +{ + static_assert(isSingular<T1>); + return emitInst(parent, inst, SpvOpCreateTensorLayoutNV, idResultType, kResultID); +} + +// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpCreateTensorViewNV +template<typename T1> +SpvInst* emitOpCreateTensorView(SpvInstParent* parent, IRInst* inst, const T1& idResultType) +{ + static_assert(isSingular<T1>); + return emitInst(parent, inst, SpvOpCreateTensorViewNV, idResultType, kResultID); +} + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeMatrix template<typename T> SpvInst* emitOpTypeMatrix(IRInst* inst, const T& columnType, const SpvLiteralInteger& columnCount) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 32d3ba7c3..7d202c7c1 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1854,6 +1854,100 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex static_cast<IRIntLit*>(coopMatType->getMatrixUse())->getValue(), builder.getIntType())); } + case kIROp_TensorAddressingTensorLayoutType: + { + requireSPIRVCapability(SpvCapabilityTensorAddressingNV); + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_tensor_addressing")); + + IRBuilder builder(m_irModule); + auto tensorLayoutType = static_cast<IRTensorAddressingTensorLayoutType*>(inst); + return emitOpTypeTensorLayout( + tensorLayoutType, + emitIntConstant( + static_cast<IRIntLit*>(tensorLayoutType->getDimension())->getValue(), + 
builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(tensorLayoutType->getClampMode())->getValue(), + builder.getIntType())); + } + case kIROp_TensorAddressingTensorViewType: + { + requireSPIRVCapability(SpvCapabilityTensorAddressingNV); + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_tensor_addressing")); + + IRBuilder builder(m_irModule); + auto tensorViewType = static_cast<IRTensorAddressingTensorViewType*>(inst); + + IRIntegerValue dim = + static_cast<IRIntLit*>(tensorViewType->getDimension())->getValue(); + SpvInst* spvDim = emitIntConstant(dim, builder.getIntType()); + + SpvInst* spvHasDimension = + ensureInst(static_cast<IRBoolLit*>(tensorViewType->getHasDimension())); + + SpvInst* spvPermutations[5] = {nullptr, nullptr, nullptr, nullptr, nullptr}; + for (int i = 0; i < (int)dim; i++) + { + spvPermutations[i] = emitIntConstant( + static_cast<IRIntLit*>(tensorViewType->getPermutation(i))->getValue(), + builder.getIntType()); + } + + if (dim == 1) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0]); + } + else if (dim == 2) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0], + spvPermutations[1]); + } + else if (dim == 3) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0], + spvPermutations[1], + spvPermutations[2]); + } + else if (dim == 4) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0], + spvPermutations[1], + spvPermutations[2], + spvPermutations[3]); + } + else if (dim == 5) + { + return emitOpTypeTensorView( + tensorViewType, + spvDim, + spvHasDimension, + spvPermutations[0], + spvPermutations[1], + spvPermutations[2], + spvPermutations[3], + spvPermutations[4]); + } + else + { + SLANG_UNEXPECTED("Unsupported tensor dimension"); + } + } case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); @@ -4070,6 
+4164,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MakeArray: result = emitConstruct(parent, inst); break; + case kIROp_MakeTensorAddressingTensorLayout: + result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType()))); + break; + case kIROp_MakeTensorAddressingTensorView: + result = emitOpCreateTensorView(parent, inst, getID(ensureInst(inst->getDataType()))); + break; case kIROp_Select: result = emitInst( parent, diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index f863858e4..e44954521 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -238,6 +238,10 @@ INST(RayQueryType, RayQuery, 1, HOISTABLE) INST(HitObjectType, HitObject, 0, HOISTABLE) INST(CoopVectorType, CoopVectorType, 2, HOISTABLE) INST(CoopMatrixType, CoopMatrixType, 5, HOISTABLE) +INST(TensorAddressingTensorLayoutType, TensorAddressingTensorLayoutType, 2, HOISTABLE) +INST(TensorAddressingTensorViewType, TensorAddressingTensorViewType, 3, HOISTABLE) +INST(MakeTensorAddressingTensorLayout, MakeTensorAddressingTensorLayout, 0, 0) +INST(MakeTensorAddressingTensorView, MakeTensorAddressingTensorView, 0, 0) // Opaque type that can be dynamically cast to other resource types. INST(DynamicResourceType, DynamicResource, 0, HOISTABLE) diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 0287ae81a..3319aa89d 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -1806,6 +1806,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } } + // Scan through the entry points and find the max version required. 
auto processInst = [&](IRInst* globalInst) { diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 461ed567a..9fd2263c1 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1882,6 +1882,23 @@ struct IRCoopMatrixType : IRType IR_LEAF_ISA(CoopMatrixType) }; +struct IRTensorAddressingTensorLayoutType : IRType +{ + IRInst* getDimension() { return getOperand(0); } + IRInst* getClampMode() { return getOperand(1); } + + IR_LEAF_ISA(TensorAddressingTensorLayoutType) +}; + +struct IRTensorAddressingTensorViewType : IRType +{ + IRInst* getDimension() { return getOperand(0); } + IRInst* getHasDimension() { return getOperand(1); } + IRInst* getPermutation(int index) { return getOperand(2 + index); } + + IR_LEAF_ISA(TensorAddressingTensorViewType) +}; + bool isDefinition(IRInst* inVal); // A structure type is represented as a parent instruction, |
