summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2025-05-15 07:02:38 +0000
committerGitHub <noreply@github.com>2025-05-15 00:02:38 -0700
commit49de1e8f60c698e9d524befacc988fb06574b234 (patch)
treecc1006b24532b0f98a2f8af49010925e9d992f66 /source
parentdd275dd952afc1b0d1a156d786c28620a48863e1 (diff)
Support tensor addressing (#7060)
This commit implements two new types and the related Load/Store functions in CoopMat: tensor_addressing.TensorLayout, tensor_addressing.TensorView, CoopMat.Load(..., TensorLayout), CoopMat.Load(..., TensorLayout, TensorView), CoopMat.Store(..., TensorLayout), CoopMat.Store(..., TensorLayout, TensorView), CoopMat.Load(..., TensorLayout, TensorView).
Diffstat (limited to 'source')
-rw-r--r--source/slang/hlsl.meta.slang625
-rw-r--r--source/slang/slang-capabilities.capdef19
-rw-r--r--source/slang/slang-emit-spirv-ops.h50
-rw-r--r--source/slang/slang-emit-spirv.cpp100
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp1
-rw-r--r--source/slang/slang-ir.h17
7 files changed, 777 insertions, 39 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 829d5ce97..d7a65c031 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -22539,6 +22539,243 @@ enum CoopMatClampMode
};
+${{{{
+// The SPIR-V specification states that the maximum value for `Dim` is 5.
+//
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpTypeTensorLayoutNV
+// OpTypeTensorLayoutNV:
+// Dim is the number of dimensions in the tensor layout, and must be a constant
+// instruction with scalar 32-bit integer type. The value must be greater than
+// zero and less than or equal to 5.
+//
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpTypeTensorViewNV
+// OpTypeTensorViewNV:
+// Dim is the number of dimensions in the tensor view, and must be a constant
+// instruction with scalar 32-bit integer type. The value must be greater than
+// zero and less than or equal to 5.
+//
+// Inclusive upper bound on `Dim` for both tensor layouts and tensor views,
+// taken from the two specification excerpts quoted here.
+const int kMaxCoopMatTensorDimension = 5;
+}}}}
+
+//
+// TensorLayout
+//
+
+// Tensor layout type, backed by the kIROp_TensorAddressingTensorLayoutType IR type.
+// Dim - number of dimensions of the layout.
+// ClampMode - clamp mode of the layout; defaults to CoopMatClampMode.Undefined.
+__intrinsic_type($(kIROp_TensorAddressingTensorLayoutType))
+[require(tensor_addressing)]
+__generic<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode = CoopMatClampMode.Undefined
+>
+struct TensorLayout
+{
+ // Default construction maps to the kIROp_MakeTensorAddressingTensorLayout IR op.
+ __intrinsic_op($(kIROp_MakeTensorAddressingTensorLayout))
+ __init();
+};
+
+
+${{{{
+ // Generate one TensorLayout extension per dimension count.
+ // kMaxCoopMatTensorDimension is itself a valid dimension count, so the bound is inclusive.
+ for (int iDim = 1; iDim <= kMaxCoopMatTensorDimension; ++iDim)
+ {
+ StringBuilder dimParams;
+ StringBuilder dimAsms;
+ StringBuilder strideParams;
+ StringBuilder strideAsms;
+ StringBuilder sliceParams;
+ StringBuilder sliceAsms;
+ StringBuilder blockSizeParams;
+ StringBuilder blockSizeAsms;
+ // Index 0 is spelled out in each signature below; these builders append
+ // the parameters (and matching spirv_asm operands) for indices 1 .. iDim-1.
+ for (int j = 1; j < iDim; ++j)
+ {
+ dimParams << ", uint32_t dim" << j;
+ dimAsms << " $dim" << j;
+ strideParams << ", uint32_t stride" << j;
+ strideAsms << " $stride" << j;
+ sliceParams << ", uint32_t offset" << j << ", uint32_t span" << j;
+ sliceAsms << " $offset" << j << " $span" << j;
+ blockSizeParams << ", uint32_t blockSize" << j;
+ blockSizeAsms << " $blockSize" << j;
+ }
+}}}}
+
+
+extension<
+ let ClampMode : CoopMatClampMode
+> TensorLayout<$(iDim), ClampMode>
+{
+ // Returns a layout produced by OpTensorLayoutSetDimensionNV with one operand per dimension.
+ [require(tensor_addressing)]
+ This Dimension(uint32_t dim0 $(dimParams))
+ {
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$This = OpTensorLayoutSetDimensionNV $this $dim0 $(dimAsms);
+ };
+ }
+ }
+
+ // Returns a layout produced by OpTensorLayoutSetStrideNV with one operand per dimension.
+ [require(tensor_addressing)]
+ This Stride(uint32_t stride0 $(strideParams))
+ {
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$This = OpTensorLayoutSetStrideNV $this $stride0 $(strideAsms);
+ };
+ }
+ }
+
+ // Returns a layout produced by OpTensorLayoutSliceNV with an (offset, span) pair per dimension.
+ [require(tensor_addressing)]
+ This Slice(uint32_t offset0, uint32_t span0 $(sliceParams))
+ {
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$This = OpTensorLayoutSliceNV $this $offset0 $span0 $(sliceAsms);
+ };
+ }
+ }
+
+ // Returns a layout produced by OpTensorLayoutSetClampValueNV.
+ // NOTE(review): the operand is typed as CoopMatClampMode here; confirm against the
+ // SPV_NV_tensor_addressing spec whether a plain 32-bit clamp value is expected instead.
+ [require(tensor_addressing)]
+ This ClampValue(CoopMatClampMode clampMode)
+ {
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$This = OpTensorLayoutSetClampValueNV $this $clampMode;
+ };
+ }
+ }
+
+ // Returns a layout produced by OpTensorLayoutSetBlockSizeNV with one operand per dimension.
+ [require(tensor_addressing)]
+ This BlockSize(uint32_t blockSize0 $(blockSizeParams))
+ {
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$This = OpTensorLayoutSetBlockSizeNV $this $blockSize0 $(blockSizeAsms);
+ };
+ }
+ }
+};
+
+${{{{
+ } // iDim
+}}}}
+
+//
+// TensorView
+//
+
+${{{{
+ // Build the trailing generic parameters `p0` .. `p<kMaxCoopMatTensorDimension-1>`
+ // that are spliced into the TensorView declaration.
+ StringBuilder tensorViewStruct;
+ for (int j = 0; j < kMaxCoopMatTensorDimension; ++j)
+ {
+ // Every parameter defaults to 0xff. The value is used as an invalid sentinel,
+ // which makes it possible to check whether the user set the value explicitly.
+ // NOTE(review): presumably chosen because it is outside the valid range for
+ // these parameters - confirm.
+ tensorViewStruct << ", let p" << j << " : uint32_t = 0xff";
+ }
+}}}}
+
+// Tensor view type, backed by the kIROp_TensorAddressingTensorViewType IR type.
+// Dim - number of dimensions of the view.
+// HasDimensions - boolean generic argument forwarded to the IR type.
+// $(tensorViewStruct) expands to the additional `p<N>` uint32_t parameters.
+__intrinsic_type($(kIROp_TensorAddressingTensorViewType))
+__generic<
+ let Dim : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewStruct)
+>
+struct TensorView
+{
+ // Default construction maps to the kIROp_MakeTensorAddressingTensorView IR op.
+ __intrinsic_op($(kIROp_MakeTensorAddressingTensorView))
+ __init();
+};
+
+${{{{
+ // Generate one TensorView extension per dimension count.
+ // kMaxCoopMatTensorDimension is itself a valid dimension count, so the bound is inclusive.
+ for (int iDim = 1; iDim <= kMaxCoopMatTensorDimension; ++iDim)
+ {
+ StringBuilder tensorViewTypes;
+ StringBuilder tensorViewExtensions;
+ StringBuilder dimParams;
+ StringBuilder dimAsms;
+ StringBuilder strideParams;
+ StringBuilder strideAsms;
+ // Index 0 is spelled out in the declarations below; append indices 1 .. iDim-1.
+ for (int j = 1; j < iDim; ++j)
+ {
+ tensorViewTypes << ", let Dim" << j << " : uint32_t";
+ tensorViewExtensions << ", Dim" << j;
+ dimParams << ", uint32_t dim" << j;
+ dimAsms << " $dim" << j;
+ strideParams << ", uint32_t stride" << j;
+ strideAsms << " $stride" << j;
+ }
+ // Pad the remaining generic arguments with the 0xff sentinel so the extension
+ // matches a TensorView whose unused parameters were left unset.
+ for (int j = iDim; j < kMaxCoopMatTensorDimension; ++j)
+ {
+ tensorViewExtensions << ", 0xff";
+ }
+}}}}
+
+[require(tensor_addressing)]
+extension<
+ let HasDimensions : bool,
+ let Dim0 : uint32_t
+ $(tensorViewTypes)
+> TensorView<$(iDim), HasDimensions, Dim0 $(tensorViewExtensions)>
+{
+ // Returns a view produced by OpTensorViewSetDimensionNV with one operand per dimension.
+ [require(tensor_addressing)]
+ This Dimension(uint32_t dim0 $(dimParams))
+ {
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$This = OpTensorViewSetDimensionNV $this $dim0 $(dimAsms);
+ };
+ }
+ }
+
+ // Returns a view produced by OpTensorViewSetStrideNV with one operand per dimension.
+ [require(tensor_addressing)]
+ This Stride(uint32_t stride0 $(strideParams))
+ {
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$This = OpTensorViewSetStrideNV $this $stride0 $(strideAsms);
+ };
+ }
+ }
+
+ // Returns a view produced by OpTensorViewSetClipNV from row/column offset and span.
+ [require(tensor_addressing)]
+ This Clip(uint clipRowOffset, uint clipRowSpan, uint clipColOffset, uint clipColSpan)
+ {
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$This = OpTensorViewSetClipNV $this $clipRowOffset $clipRowSpan $clipColOffset $clipColSpan;
+ };
+ }
+ }
+};
+
+${{{{
+ } // iDim
+}}}}
+
+
//
// Cooperative Matrix type
//
@@ -22788,21 +23025,23 @@ struct CoopMat
[require(cooperative_matrix)]
void Store<
let matrixLayout : CoopMatMatrixLayout
- >(RWStructuredBuffer<T> buffer, uint element, uint stride)
+ >(RWByteAddressBuffer buffer, uint element, uint stride)
{
- __Store(buffer, element, stride, matrixLayout);
+ __store<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
}
[require(cooperative_matrix)]
void Store<
let matrixLayout : CoopMatMatrixLayout
- >(RWByteAddressBuffer buffer, uint element, uint stride)
+ >(RWStructuredBuffer<T> buffer, uint element, uint stride)
{
- __Store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ __store<matrixLayout>(buffer, element, stride);
}
[require(cooperative_matrix)]
- internal void __Store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ internal void __store<
+ let matrixLayout : CoopMatMatrixLayout
+ >(RWStructuredBuffer<T> buffer, uint element, uint stride)
{
let zero = 0;
let alignment = 16U;
@@ -22883,20 +23122,19 @@ struct CoopMat
// Load
//
- [__NoSideEffect]
- [require(cooperative_matrix)]
- static This Load<
- let matrixLayout : CoopMatMatrixLayout
- >(ByteAddressBuffer buffer, uint element, uint stride)
+${{{{
+ for (const char* RW : { "", "RW" })
{
- return Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
- }
+}}}}
[__NoSideEffect]
[require(cooperative_matrix)]
static This Load<
let matrixLayout : CoopMatMatrixLayout
- >(RWByteAddressBuffer buffer, uint element, uint stride)
+ >(
+ $(RW)ByteAddressBuffer buffer,
+ uint element,
+ uint stride)
{
return Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
}
@@ -22905,7 +23143,10 @@ struct CoopMat
[require(cooperative_matrix)]
static This Load<
let matrixLayout : CoopMatMatrixLayout
- >(StructuredBuffer<T> buffer, uint element, uint stride)
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ uint stride)
{
let zero = 0;
let alignment = 16U;
@@ -22917,21 +23158,9 @@ struct CoopMat
};
}
- [__NoSideEffect]
- [require(cooperative_matrix)]
- static This Load<
- let matrixLayout : CoopMatMatrixLayout
- >(RWStructuredBuffer<T> buffer, uint element, uint stride)
- {
- let zero = 0;
- let alignment = 16U;
- return spirv_asm
- {
- %storagePointerType = OpTypePointer StorageBuffer $$T;
- %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
- result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
- };
- }
+${{{{
+ } // RW
+}}}}
[ForceInline]
[__NoSideEffect]
@@ -23076,6 +23305,340 @@ struct CoopMat
return true;
}
+ //
+ // Load with TensorLayout and TensorView
+ //
+
+${{{{
+ // tensorViewTypes: generic parameter declarations `p0` .. `p<max-1>` appended to the
+ // Load/Store overloads that take a TensorView.
+ // tensorViewParams: the matching argument list used when naming TensorView<...>.
+ // NOTE(review): the default here is kMaxCoopMatTensorDimension, whereas the TensorView
+ // declaration defaults the same parameters to 0xff - confirm the difference is intended.
+ StringBuilder tensorViewTypes;
+ StringBuilder tensorViewParams;
+ for (int j = 0; j < kMaxCoopMatTensorDimension; ++j)
+ {
+ tensorViewTypes << ", let p" << j << " : uint32_t = " << kMaxCoopMatTensorDimension;
+ tensorViewParams << ", p" << j;
+ }
+
+ // Everything up to the matching `} // RW` is generated twice: once for the
+ // read-only buffer types and once for their RW counterparts.
+ for (const char* RW : { "", "RW" })
+ {
+}}}}
+
+ // Loads a matrix through `tensorLayout` from a $(RW)ByteAddressBuffer.
+ // The buffer is converted with __getEquivalentStructuredBuffer<T> and the call is
+ // forwarded to __loadLayout.
+ [require(cooperative_matrix_tensor_addressing)]
+ static This Load<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ $(RW)ByteAddressBuffer buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout)
+ {
+ return __loadLayout<Dim, ClampMode>(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout);
+ }
+
+ // Loads a matrix through `tensorLayout` from a $(RW)StructuredBuffer<T>.
+ // Thin wrapper that forwards to __loadLayout.
+ [require(cooperative_matrix_tensor_addressing)]
+ static This Load<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout)
+ {
+ return __loadLayout<Dim, ClampMode>(buffer, element, tensorLayout);
+ }
+
+ // Shared implementation of the layout-only Load overloads.
+ // Forms a StorageBuffer pointer to `buffer[element]` (access chain indices: 0, element)
+ // and issues OpCooperativeMatrixLoadTensorNV with `Aligned 16` as the memory operand and
+ // `None` as the tensor addressing operands.
+ [require(cooperative_matrix_tensor_addressing)]
+ static This __loadLayout<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout)
+ {
+ let zero = 0;
+ let alignment = 16U;
+
+ // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixLoadTensorNV
+ // `ret` is declared without an initializer and passed as the instruction's object operand.
+ This ret;
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixTensorAddressingNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ result:$$This = OpCooperativeMatrixLoadTensorNV %pointer $ret $tensorLayout Aligned !alignment None;
+ };
+ }
+
+ // Loads a matrix through `tensorLayout` and `tensorView` from a $(RW)ByteAddressBuffer.
+ // The buffer is converted with __getEquivalentStructuredBuffer<T> and the call is
+ // forwarded to __loadView.
+ [require(cooperative_matrix_tensor_addressing)]
+ static This Load<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ $(RW)ByteAddressBuffer buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView)
+ {
+ return __loadView<Dim, ClampMode, DimView, HasDimensions $(tensorViewParams)>(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout, tensorView);
+ }
+
+ // Loads a matrix through `tensorLayout` and `tensorView` from a $(RW)StructuredBuffer<T>.
+ // Thin wrapper that forwards to __loadView.
+ [require(cooperative_matrix_tensor_addressing)]
+ static This Load<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams) > tensorView)
+ {
+ return __loadView<Dim, ClampMode, DimView, HasDimensions $(tensorViewParams)>(buffer, element, tensorLayout, tensorView);
+ }
+
+ // Shared implementation of the layout + view Load overloads.
+ // Same pointer setup as the layout-only load, but passes `TensorView $tensorView`
+ // as the tensor addressing operands of OpCooperativeMatrixLoadTensorNV.
+ [require(cooperative_matrix_tensor_addressing)]
+ static This __loadView<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams) > tensorView)
+ {
+ let zero = 0;
+ let alignment = 16U;
+
+ // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixLoadTensorNV
+ // `ret` is declared without an initializer and passed as the instruction's object operand.
+ This ret;
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixTensorAddressingNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ result:$$This = OpCooperativeMatrixLoadTensorNV %pointer $ret $tensorLayout Aligned !alignment TensorView $tensorView;
+ };
+ }
+
+ // Loads a matrix through `tensorLayout` from a $(RW)ByteAddressBuffer, passing
+ // `decodeFunc` along. The buffer is converted with __getEquivalentStructuredBuffer<T>
+ // and the call is forwarded to __loadLayoutDecode.
+ [require(cooperative_matrix_block_load)]
+ static This Load<
+ U,
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ $(RW)ByteAddressBuffer buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc)
+ {
+ return __loadLayoutDecode<U, Dim, ClampMode>(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout, decodeFunc);
+ }
+
+ // Loads a matrix through `tensorLayout` from a $(RW)StructuredBuffer<T>, passing
+ // `decodeFunc` along. Thin wrapper that forwards to __loadLayoutDecode.
+ [require(cooperative_matrix_block_load)]
+ static This Load<
+ U,
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc)
+ {
+ return __loadLayoutDecode<U, Dim, ClampMode>(buffer, element, tensorLayout, decodeFunc);
+ }
+
+ // Shared implementation of the layout + decode-function Load overloads.
+ // Forms a StorageBuffer pointer to `buffer[element]` (access chain indices: 0, element)
+ // and issues OpCooperativeMatrixLoadTensorNV with `DecodeFunc $decodeFunc` as the
+ // tensor addressing operands.
+ [require(cooperative_matrix_block_load)]
+ static This __loadLayoutDecode<
+ U,
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc)
+ {
+ let zero = 0;
+ let alignment = 16U;
+
+ // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixLoadTensorNV
+ // `ret` is declared without an initializer and passed as the instruction's object operand.
+ This ret;
+ return spirv_asm
+ {
+ // The load instruction itself needs the tensor addressing capability; the
+ // DecodeFunc operand additionally needs the block loads capability. This matches
+ // the capability list declared by __loadViewDecode.
+ OpCapability CooperativeMatrixTensorAddressingNV;
+ OpCapability CooperativeMatrixBlockLoadsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ result:$$This = OpCooperativeMatrixLoadTensorNV %pointer $ret $tensorLayout Aligned !alignment DecodeFunc $decodeFunc;
+ };
+ }
+
+ // Loads a matrix through `tensorLayout` and `tensorView` from a $(RW)ByteAddressBuffer,
+ // passing `decodeFunc` along. The buffer is converted with
+ // __getEquivalentStructuredBuffer<T> and the call is forwarded to __loadViewDecode.
+ [require(cooperative_matrix_block_load)]
+ static This Load<
+ U,
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ $(RW)ByteAddressBuffer buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView,
+ functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc)
+ {
+ return __loadViewDecode<U, Dim, ClampMode, DimView, HasDimensions $(tensorViewParams)>(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout, tensorView, decodeFunc);
+ }
+
+ // Loads a matrix through `tensorLayout` and `tensorView` from a $(RW)StructuredBuffer<T>,
+ // passing `decodeFunc` along. Thin wrapper that forwards to __loadViewDecode.
+ [require(cooperative_matrix_block_load)]
+ static This Load<
+ U,
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView,
+ functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc)
+ {
+ return __loadViewDecode<U, Dim, ClampMode, DimView, HasDimensions $(tensorViewParams)>(buffer, element, tensorLayout, tensorView, decodeFunc);
+ }
+
+ // Shared implementation of the layout + view + decode-function Load overloads.
+ // Issues OpCooperativeMatrixLoadTensorNV with the combined `TensorView|DecodeFunc`
+ // tensor addressing operands, followed by `$tensorView` and `$decodeFunc`.
+ [require(cooperative_matrix_block_load)]
+ static This __loadViewDecode<
+ U,
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView,
+ functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc)
+ {
+ let zero = 0;
+ let alignment = 16U;
+
+ // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixLoadTensorNV
+ // `ret` is declared without an initializer and passed as the instruction's object operand.
+ This ret;
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixTensorAddressingNV;
+ OpCapability CooperativeMatrixBlockLoadsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ result:$$This = OpCooperativeMatrixLoadTensorNV %pointer $ret $tensorLayout Aligned !alignment TensorView|DecodeFunc $tensorView $decodeFunc;
+ };
+ }
+
+${{{{
+ } // RW
+}}}}
+
+ // Stores this matrix through `tensorLayout` into a RWByteAddressBuffer.
+ // The buffer is converted with __getEquivalentStructuredBuffer<T> and the call is
+ // forwarded to the RWStructuredBuffer<T> overload.
+ [require(cooperative_matrix_tensor_addressing)]
+ void Store<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ RWByteAddressBuffer buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout)
+ {
+ Store(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout);
+ }
+
+ // Stores this matrix through `tensorLayout` into a RWStructuredBuffer<T>.
+ // Forms a StorageBuffer pointer to `buffer[element]` (access chain indices: 0, element)
+ // and issues OpCooperativeMatrixStoreTensorNV with `Aligned 16` as the memory operand
+ // and `None` as the tensor addressing operands. Only a `spirv` case is provided.
+ [require(cooperative_matrix_tensor_addressing)]
+ void Store<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ RWStructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout)
+ {
+ let zero = 0;
+ let alignment = 16U;
+
+ __target_switch
+ {
+ case spirv:
+ spirv_asm
+ {
+ OpCapability CooperativeMatrixTensorAddressingNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ OpCooperativeMatrixStoreTensorNV %pointer $this $tensorLayout Aligned !alignment None;
+ };
+ }
+ }
+
+ // Stores this matrix through `tensorLayout` and `tensorView` into a RWByteAddressBuffer.
+ // The buffer is converted with __getEquivalentStructuredBuffer<T> and the call is
+ // forwarded to the RWStructuredBuffer<T> overload.
+ [require(cooperative_matrix_tensor_addressing)]
+ void Store<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ RWByteAddressBuffer buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView)
+ {
+ Store(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout, tensorView);
+ }
+
+ // Stores this matrix through `tensorLayout` and `tensorView` into a RWStructuredBuffer<T>.
+ // Same pointer setup as the layout-only store, but passes `TensorView $tensorView`
+ // as the tensor addressing operands of OpCooperativeMatrixStoreTensorNV.
+ // Only a `spirv` case is provided.
+ [require(cooperative_matrix_tensor_addressing)]
+ void Store<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ RWStructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView)
+ {
+ let zero = 0;
+ let alignment = 16U;
+
+ __target_switch
+ {
+ case spirv:
+ spirv_asm
+ {
+ OpCapability CooperativeMatrixTensorAddressingNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ OpCooperativeMatrixStoreTensorNV %pointer $this $tensorLayout Aligned !alignment TensorView $tensorView;
+ };
+ }
+ }
+
} // struct CoopMat
@@ -23098,7 +23661,7 @@ CoopMat<T, S, M, N, R> coopMatLoad<
uint element,
uint stride)
{
- return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
}
[ForceInline]
@@ -23115,7 +23678,7 @@ CoopMat<T, S, M, N, R> coopMatLoad<
uint element,
uint stride)
{
- return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
}
[ForceInline]
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index b909fa0f9..937584908 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -567,25 +567,25 @@ def SPV_GOOGLE_user_type : _spirv_1_0;
/// [EXT]
def SPV_EXT_replicated_composites : _spirv_1_0;
-/// Represents the SPIR-V extension for SPV_NV_cooperative_vector.
+/// Represents the SPIR-V extension for SPV_KHR_vulkan_memory_model.
/// [EXT]
-def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites;
+def SPV_KHR_vulkan_memory_model : _spirv_1_3;
-/// Represents the SPIR-V extension for SPV_KHR_cooperative_matrix.
+/// Represents the SPIR-V extension for SPV_NV_cooperative_vector.
/// [EXT]
-def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer;
+def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites + SPV_KHR_vulkan_memory_model;
-/// Represents the SPIR-V extension for SPV_NV_cooperative_matrix2.
+/// Represents the SPIR-V extension for SPV_KHR_cooperative_matrix.
/// [EXT]
-def SPV_NV_cooperative_matrix2 : _spirv_1_6 + SPV_KHR_cooperative_matrix;
+def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer + SPV_KHR_vulkan_memory_model;
/// Represents the SPIR-V extension for SPV_NV_tensor_addressing.
/// [EXT]
def SPV_NV_tensor_addressing : _spirv_1_6;
-/// Represents the SPIR-V extension for SPV_KHR_vulkan_memory_model.
+/// Represents the SPIR-V extension for SPV_NV_cooperative_matrix2.
/// [EXT]
-def SPV_KHR_vulkan_memory_model : _spirv_1_3;
+def SPV_NV_cooperative_matrix2 : SPV_NV_tensor_addressing + SPV_KHR_cooperative_matrix;
// SPIRV Capabilities.
@@ -1183,6 +1183,9 @@ alias cooperative_matrix_block_load = spvCooperativeMatrixBlockLoadsNV;
/// Capabilities needed to use tensor addressing
/// [Compound]
alias tensor_addressing = spvTensorAddressingNV;
+/// Capabilities needed to use cooperative matrix 2
+/// [Compound]
+alias cooperative_matrix_2 = spvCooperativeMatrixKHR + spvCooperativeMatrixReductionsNV + spvCooperativeMatrixConversionsNV + spvCooperativeMatrixPerElementOperationsNV + spvCooperativeMatrixTensorAddressingNV + spvCooperativeMatrixBlockLoadsNV + spvTensorAddressingNV;
// Non-internal shader stages
//
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index b716e470d..81ad4bddf 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -174,6 +174,56 @@ SpvInst* emitOpTypeCoopMat(
matrixUse);
}
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpTypeTensorLayoutNV
+// Emits a memoized OpTypeTensorLayoutNV into the constants-and-types section for `inst`,
+// with `dim` and `clampMode` as its operands. Both operands must be singular.
+template<typename T1, typename T2>
+SpvInst* emitOpTypeTensorLayout(IRInst* inst, const T1& dim, const T2& clampMode)
+{
+ static_assert(isSingular<T1>);
+ static_assert(isSingular<T2>);
+ return emitInstMemoized(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ inst,
+ SpvOpTypeTensorLayoutNV,
+ kResultID,
+ dim,
+ clampMode);
+}
+
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpTypeTensorViewNV
+// Emits a memoized OpTypeTensorViewNV into the constants-and-types section for `inst`.
+// Operands are `dim`, `hasDimensions`, then every element of `perms` in order.
+template<typename T1, typename T2, typename... TPerms>
+SpvInst* emitOpTypeTensorView(
+ IRInst* inst,
+ const T1& dim,
+ const T2& hasDimensions,
+ const TPerms&... perms)
+{
+ static_assert(isSingular<T1>);
+ static_assert(isSingular<T2>);
+ return emitInstMemoized(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ inst,
+ SpvOpTypeTensorViewNV,
+ kResultID,
+ dim,
+ hasDimensions,
+ perms...);
+}
+
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpCreateTensorLayoutNV
+// Emits OpCreateTensorLayoutNV under `parent` for `inst`, with `idResultType` as the result type.
+template<typename T1>
+SpvInst* emitOpCreateTensorLayout(SpvInstParent* parent, IRInst* inst, const T1& idResultType)
+{
+ static_assert(isSingular<T1>);
+ return emitInst(parent, inst, SpvOpCreateTensorLayoutNV, idResultType, kResultID);
+}
+
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpCreateTensorViewNV
+// Emits OpCreateTensorViewNV under `parent` for `inst`, with `idResultType` as the result type.
+template<typename T1>
+SpvInst* emitOpCreateTensorView(SpvInstParent* parent, IRInst* inst, const T1& idResultType)
+{
+ static_assert(isSingular<T1>);
+ return emitInst(parent, inst, SpvOpCreateTensorViewNV, idResultType, kResultID);
+}
+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeMatrix
template<typename T>
SpvInst* emitOpTypeMatrix(IRInst* inst, const T& columnType, const SpvLiteralInteger& columnCount)
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 32d3ba7c3..7d202c7c1 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1854,6 +1854,100 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
static_cast<IRIntLit*>(coopMatType->getMatrixUse())->getValue(),
builder.getIntType()));
}
+ case kIROp_TensorAddressingTensorLayoutType:
+ {
+ // Tensor layout types need the TensorAddressingNV capability and its extension.
+ requireSPIRVCapability(SpvCapabilityTensorAddressingNV);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_tensor_addressing"));
+
+ IRBuilder builder(m_irModule);
+ auto tensorLayoutType = static_cast<IRTensorAddressingTensorLayoutType*>(inst);
+ // Both operands are read as integer literals and re-emitted as int constants.
+ // NOTE(review): the casts to IRIntLit are unchecked - assumes both operands
+ // have already been resolved to literals by this point; confirm.
+ return emitOpTypeTensorLayout(
+ tensorLayoutType,
+ emitIntConstant(
+ static_cast<IRIntLit*>(tensorLayoutType->getDimension())->getValue(),
+ builder.getIntType()),
+ emitIntConstant(
+ static_cast<IRIntLit*>(tensorLayoutType->getClampMode())->getValue(),
+ builder.getIntType()));
+ }
+ case kIROp_TensorAddressingTensorViewType:
+ {
+ // Tensor view types need the TensorAddressingNV capability and its extension.
+ requireSPIRVCapability(SpvCapabilityTensorAddressingNV);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_tensor_addressing"));
+
+ IRBuilder builder(m_irModule);
+ auto tensorViewType = static_cast<IRTensorAddressingTensorViewType*>(inst);
+
+ IRIntegerValue dim =
+ static_cast<IRIntLit*>(tensorViewType->getDimension())->getValue();
+
+ // Reject out-of-range dimensions before `dim` is used to index the fixed-size
+ // permutation array below; otherwise a bad value would write past its end.
+ if (dim < 1 || dim > 5)
+ {
+ SLANG_UNEXPECTED("Unsupported tensor dimension");
+ }
+
+ SpvInst* spvDim = emitIntConstant(dim, builder.getIntType());
+
+ SpvInst* spvHasDimension =
+ ensureInst(static_cast<IRBoolLit*>(tensorViewType->getHasDimension()));
+
+ // Only the first `dim` permutation operands are emitted; the rest stay null and unused.
+ SpvInst* spvPermutations[5] = {nullptr, nullptr, nullptr, nullptr, nullptr};
+ for (int i = 0; i < (int)dim; i++)
+ {
+ spvPermutations[i] = emitIntConstant(
+ static_cast<IRIntLit*>(tensorViewType->getPermutation(i))->getValue(),
+ builder.getIntType());
+ }
+
+ // emitOpTypeTensorView takes the permutations as a parameter pack, so each
+ // dimension count needs its own call with the matching number of arguments.
+ if (dim == 1)
+ {
+ return emitOpTypeTensorView(
+ tensorViewType,
+ spvDim,
+ spvHasDimension,
+ spvPermutations[0]);
+ }
+ else if (dim == 2)
+ {
+ return emitOpTypeTensorView(
+ tensorViewType,
+ spvDim,
+ spvHasDimension,
+ spvPermutations[0],
+ spvPermutations[1]);
+ }
+ else if (dim == 3)
+ {
+ return emitOpTypeTensorView(
+ tensorViewType,
+ spvDim,
+ spvHasDimension,
+ spvPermutations[0],
+ spvPermutations[1],
+ spvPermutations[2]);
+ }
+ else if (dim == 4)
+ {
+ return emitOpTypeTensorView(
+ tensorViewType,
+ spvDim,
+ spvHasDimension,
+ spvPermutations[0],
+ spvPermutations[1],
+ spvPermutations[2],
+ spvPermutations[3]);
+ }
+ else if (dim == 5)
+ {
+ return emitOpTypeTensorView(
+ tensorViewType,
+ spvDim,
+ spvHasDimension,
+ spvPermutations[0],
+ spvPermutations[1],
+ spvPermutations[2],
+ spvPermutations[3],
+ spvPermutations[4]);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unsupported tensor dimension");
+ }
+ }
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(inst);
@@ -4070,6 +4164,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_MakeArray:
result = emitConstruct(parent, inst);
break;
+ case kIROp_MakeTensorAddressingTensorLayout:
+ // Emit OpCreateTensorLayoutNV; the result type is the emitted form of the inst's data type.
+ result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType())));
+ break;
+ case kIROp_MakeTensorAddressingTensorView:
+ // Emit OpCreateTensorViewNV; the result type is the emitted form of the inst's data type.
+ result = emitOpCreateTensorView(parent, inst, getID(ensureInst(inst->getDataType())));
+ break;
case kIROp_Select:
result = emitInst(
parent,
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f863858e4..e44954521 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -238,6 +238,10 @@ INST(RayQueryType, RayQuery, 1, HOISTABLE)
INST(HitObjectType, HitObject, 0, HOISTABLE)
INST(CoopVectorType, CoopVectorType, 2, HOISTABLE)
INST(CoopMatrixType, CoopMatrixType, 5, HOISTABLE)
+// Tensor addressing types and the instructions that construct their values.
+// NOTE(review): the numeric column is presumably the minimum operand count
+// (layout: dimension + clamp mode; view: dimension + has-dimension + at least one
+// permutation) - confirm against the INST macro definition.
+INST(TensorAddressingTensorLayoutType, TensorAddressingTensorLayoutType, 2, HOISTABLE)
+INST(TensorAddressingTensorViewType, TensorAddressingTensorViewType, 3, HOISTABLE)
+INST(MakeTensorAddressingTensorLayout, MakeTensorAddressingTensorLayout, 0, 0)
+INST(MakeTensorAddressingTensorView, MakeTensorAddressingTensorView, 0, 0)
// Opaque type that can be dynamically cast to other resource types.
INST(DynamicResourceType, DynamicResource, 0, HOISTABLE)
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 0287ae81a..3319aa89d 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -1806,6 +1806,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}
}
+
// Scan through the entry points and find the max version required.
auto processInst = [&](IRInst* globalInst)
{
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 461ed567a..9fd2263c1 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1882,6 +1882,23 @@ struct IRCoopMatrixType : IRType
IR_LEAF_ISA(CoopMatrixType)
};
+// IR type for a tensor layout. Operand 0 is the dimension, operand 1 the clamp mode.
+struct IRTensorAddressingTensorLayoutType : IRType
+{
+ IRInst* getDimension() { return getOperand(0); }
+ IRInst* getClampMode() { return getOperand(1); }
+
+ IR_LEAF_ISA(TensorAddressingTensorLayoutType)
+};
+
+// IR type for a tensor view. Operand 0 is the dimension, operand 1 the has-dimension
+// flag, and the permutation entries follow starting at operand 2.
+struct IRTensorAddressingTensorViewType : IRType
+{
+ IRInst* getDimension() { return getOperand(0); }
+ IRInst* getHasDimension() { return getOperand(1); }
+ // Returns permutation entry `index`; no bounds check is performed here.
+ IRInst* getPermutation(int index) { return getOperand(2 + index); }
+
+ IR_LEAF_ISA(TensorAddressingTensorViewType)
+};
+
bool isDefinition(IRInst* inVal);
// A structure type is represented as a parent instruction,