summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/command-line-slangc-reference.md5
-rw-r--r--docs/user-guide/a3-02-reference-capability-atoms.md13
-rw-r--r--source/slang/hlsl.meta.slang625
-rw-r--r--source/slang/slang-capabilities.capdef19
-rw-r--r--source/slang/slang-emit-spirv-ops.h50
-rw-r--r--source/slang/slang-emit-spirv.cpp100
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp1
-rw-r--r--source/slang/slang-ir.h17
-rw-r--r--tests/cooperative-matrix/load-store-tensorlayout.slang88
-rw-r--r--tests/cooperative-matrix/load-store-tensorview.slang88
11 files changed, 964 insertions, 46 deletions
diff --git a/docs/command-line-slangc-reference.md b/docs/command-line-slangc-reference.md
index de1424af4..1562ce02f 100644
--- a/docs/command-line-slangc-reference.md
+++ b/docs/command-line-slangc-reference.md
@@ -1109,11 +1109,11 @@ A capability describes an optional feature that a target may or may not support.
* `SPV_KHR_compute_shader_derivatives` : enables the SPV_KHR_compute_shader_derivatives extension
* `SPV_GOOGLE_user_type` : enables the SPV_GOOGLE_user_type extension
* `SPV_EXT_replicated_composites` : enables the SPV_EXT_replicated_composites extension
+* `SPV_KHR_vulkan_memory_model` : enables the SPV_KHR_vulkan_memory_model extension
* `SPV_NV_cooperative_vector` : enables the SPV_NV_cooperative_vector extension
* `SPV_KHR_cooperative_matrix` : enables the SPV_KHR_cooperative_matrix extension
-* `SPV_NV_cooperative_matrix2` : enables the SPV_NV_cooperative_matrix2 extension
* `SPV_NV_tensor_addressing` : enables the SPV_NV_tensor_addressing extension
-* `SPV_KHR_vulkan_memory_model` : enables the SPV_KHR_vulkan_memory_model extension
+* `SPV_NV_cooperative_matrix2` : enables the SPV_NV_cooperative_matrix2 extension
* `spvAtomicFloat32AddEXT`
* `spvAtomicFloat16AddEXT`
* `spvAtomicFloat64AddEXT`
@@ -1292,6 +1292,7 @@ A capability describes an optional feature that a target may or may not support.
* `cooperative_matrix_tensor_addressing`
* `cooperative_matrix_block_load`
* `tensor_addressing`
+* `cooperative_matrix_2`
* `pixel`
* `tesscontrol`
* `tesseval`
diff --git a/docs/user-guide/a3-02-reference-capability-atoms.md b/docs/user-guide/a3-02-reference-capability-atoms.md
index a504f6f41..549d98ece 100644
--- a/docs/user-guide/a3-02-reference-capability-atoms.md
+++ b/docs/user-guide/a3-02-reference-capability-atoms.md
@@ -447,20 +447,20 @@ Extensions
`SPV_EXT_replicated_composites`
> Represents the SPIR-V extension for SPV_EXT_replicated_composites.
+`SPV_KHR_vulkan_memory_model`
+> Represents the SPIR-V extension for SPV_KHR_vulkan_memory_model.
+
`SPV_NV_cooperative_vector`
> Represents the SPIR-V extension for SPV_NV_cooperative_vector.
`SPV_KHR_cooperative_matrix`
> Represents the SPIR-V extension for SPV_KHR_cooperative_matrix.
-`SPV_NV_cooperative_matrix2`
-> Represents the SPIR-V extension for SPV_NV_cooperative_matrix2.
-
`SPV_NV_tensor_addressing`
> Represents the SPIR-V extension for SPV_NV_tensor_addressing.
-`SPV_KHR_vulkan_memory_model`
-> Represents the SPIR-V extension for SPV_KHR_vulkan_memory_model.
+`SPV_NV_cooperative_matrix2`
+> Represents the SPIR-V extension for SPV_NV_cooperative_matrix2.
`spvAtomicFloat32AddEXT`
> Represents the SPIR-V capability for atomic float 32 add operations.
@@ -990,6 +990,9 @@ Compound Capabilities
`tensor_addressing`
> Capabilities needed to use tensor addressing
+`cooperative_matrix_2`
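+> Capabilities needed to use all cooperative matrix 2 features (reductions, conversions, per-element operations, tensor addressing and block loads)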
+
`any_stage`
> Collection of all shader stages
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 829d5ce97..d7a65c031 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -22539,6 +22539,243 @@ enum CoopMatClampMode
};
+${{{{
+// The SPIR-V extension specification states that the maximum value for `Dim` is 5.
+//
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpTypeTensorLayoutNV
+// OpTypeTensorLayoutNV:
+// Dim is the number of dimensions in the tensor layout, and must be a constant
+// instruction with scalar 32-bit integer type. The value must be greater than
+// zero and less than or equal to 5.
+//
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpTypeTensorViewNV
+// OpTypeTensorViewNV:
+// Dim is the number of dimensions in the tensor view, and must be a constant
+// instruction with scalar 32-bit integer type. The value must be greater than
+// zero and less than or equal to 5.
+//
+const int kMaxCoopMatTensorDimension = 5;
+}}}}
+
+//
+// TensorLayout
+//
+
+/// Tensor layout type, parameterized by its dimension count `Dim` and a clamp mode
+/// `ClampMode` (which defaults to `CoopMatClampMode.Undefined`).
+/// Default construction maps to the `MakeTensorAddressingTensorLayout` IR op.
+__intrinsic_type($(kIROp_TensorAddressingTensorLayoutType))
+[require(tensor_addressing)]
+__generic<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode = CoopMatClampMode.Undefined
+>
+struct TensorLayout
+{
+ __intrinsic_op($(kIROp_MakeTensorAddressingTensorLayout))
+ __init();
+};
+
+
+${{{{
+    // Generate one `TensorLayout` extension per supported dimension count.
+    // `Dim` may be any value from 1 up to and including kMaxCoopMatTensorDimension,
+    // so the upper bound of this loop is inclusive.
+    for (int iDim = 1; iDim <= kMaxCoopMatTensorDimension; ++iDim)
+    {
+        // Extra parameter declarations and matching spirv_asm operands for
+        // components 1..iDim-1; component 0 is spelled out in each signature below.
+        StringBuilder dimParams;
+        StringBuilder dimAsms;
+        StringBuilder strideParams;
+        StringBuilder strideAsms;
+        StringBuilder sliceParams;
+        StringBuilder sliceAsms;
+        StringBuilder blockSizeParams;
+        StringBuilder blockSizeAsms;
+        for (int j = 1; j < iDim; ++j)
+        {
+            dimParams << ", uint32_t dim" << j;
+            dimAsms << " $dim" << j;
+            strideParams << ", uint32_t stride" << j;
+            strideAsms << " $stride" << j;
+            sliceParams << ", uint32_t offset" << j << ", uint32_t span" << j;
+            sliceAsms << " $offset" << j << " $span" << j;
+            blockSizeParams << ", uint32_t blockSize" << j;
+            blockSizeAsms << " $blockSize" << j;
+        }
+}}}}
+
+
+extension<
+    let ClampMode : CoopMatClampMode
+> TensorLayout<$(iDim), ClampMode>
+{
+    /// Returns a layout produced by `OpTensorLayoutSetDimensionNV` from this layout
+    /// and one size argument per dimension.
+    [require(tensor_addressing)]
+    This Dimension(uint32_t dim0 $(dimParams))
+    {
+        __target_switch
+        {
+        case spirv:
+            return spirv_asm
+            {
+                result:$$This = OpTensorLayoutSetDimensionNV $this $dim0 $(dimAsms);
+            };
+        }
+    }
+
+    /// Returns a layout produced by `OpTensorLayoutSetStrideNV` from this layout
+    /// and one stride argument per dimension.
+    [require(tensor_addressing)]
+    This Stride(uint32_t stride0 $(strideParams))
+    {
+        __target_switch
+        {
+        case spirv:
+            return spirv_asm
+            {
+                result:$$This = OpTensorLayoutSetStrideNV $this $stride0 $(strideAsms);
+            };
+        }
+    }
+
+    /// Returns a layout produced by `OpTensorLayoutSliceNV` from this layout
+    /// and one (offset, span) pair per dimension.
+    [require(tensor_addressing)]
+    This Slice(uint32_t offset0, uint32_t span0 $(sliceParams))
+    {
+        __target_switch
+        {
+        case spirv:
+            return spirv_asm
+            {
+                result:$$This = OpTensorLayoutSliceNV $this $offset0 $span0 $(sliceAsms);
+            };
+        }
+    }
+
+    /// Returns a layout produced by `OpTensorLayoutSetClampValueNV`.
+    // NOTE(review): the parameter is typed `CoopMatClampMode` although the instruction
+    // name suggests it takes a clamp *value* - confirm against SPV_NV_tensor_addressing.
+    [require(tensor_addressing)]
+    This ClampValue(CoopMatClampMode clampMode)
+    {
+        __target_switch
+        {
+        case spirv:
+            return spirv_asm
+            {
+                result:$$This = OpTensorLayoutSetClampValueNV $this $clampMode;
+            };
+        }
+    }
+
+    /// Returns a layout produced by `OpTensorLayoutSetBlockSizeNV` from this layout
+    /// and one block-size argument per dimension.
+    [require(tensor_addressing)]
+    This BlockSize(uint32_t blockSize0 $(blockSizeParams))
+    {
+        __target_switch
+        {
+        case spirv:
+            return spirv_asm
+            {
+                result:$$This = OpTensorLayoutSetBlockSizeNV $this $blockSize0 $(blockSizeAsms);
+            };
+        }
+    }
+};
+
+${{{{
+    } // iDim
+}}}}
+
+//
+// TensorView
+//
+
+${{{{
+ // Builds the trailing generic parameters `p0..p(kMaxCoopMatTensorDimension-1)` of `TensorView`.
+ StringBuilder tensorViewStruct;
+ for (int j = 0; j < kMaxCoopMatTensorDimension; ++j)
+ {
+ // Every permutation parameter defaults to 0xff. That value is treated as an
+ // invalid permutation entry, which makes it possible to tell whether the user
+ // set the parameter explicitly or left it at its default.
+ tensorViewStruct << ", let p" << j << " : uint32_t = 0xff";
+ }
+}}}}
+
+/// Tensor view type, parameterized by its dimension count `Dim`, a `HasDimensions`
+/// flag and the generated permutation parameters.
+/// Default construction maps to the `MakeTensorAddressingTensorView` IR op.
+__intrinsic_type($(kIROp_TensorAddressingTensorViewType))
+__generic<
+ let Dim : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewStruct)
+>
+struct TensorView
+{
+ __intrinsic_op($(kIROp_MakeTensorAddressingTensorView))
+ __init();
+};
+
+${{{{
+    // Generate one `TensorView` extension per supported dimension count.
+    // `Dim` may be any value from 1 up to and including kMaxCoopMatTensorDimension,
+    // so the upper bound of this loop is inclusive.
+    for (int iDim = 1; iDim <= kMaxCoopMatTensorDimension; ++iDim)
+    {
+        StringBuilder tensorViewTypes;
+        StringBuilder tensorViewExtensions;
+        StringBuilder dimParams;
+        StringBuilder dimAsms;
+        StringBuilder strideParams;
+        StringBuilder strideAsms;
+        // Component 0 is spelled out below; append components 1..iDim-1 here.
+        for (int j = 1; j < iDim; ++j)
+        {
+            tensorViewTypes << ", let Dim" << j << " : uint32_t";
+            tensorViewExtensions << ", Dim" << j;
+            dimParams << ", uint32_t dim" << j;
+            dimAsms << " $dim" << j;
+            strideParams << ", uint32_t stride" << j;
+            strideAsms << " $stride" << j;
+        }
+        // Pad the unused permutation slots with the 0xff default used by the
+        // `TensorView` declaration so the extension targets the right specialization.
+        for (int j = iDim; j < kMaxCoopMatTensorDimension; ++j)
+        {
+            tensorViewExtensions << ", 0xff";
+        }
+}}}}
+
+[require(tensor_addressing)]
+extension<
+    let HasDimensions : bool,
+    let Dim0 : uint32_t
+    $(tensorViewTypes)
+> TensorView<$(iDim), HasDimensions, Dim0 $(tensorViewExtensions)>
+{
+    /// Returns a view produced by `OpTensorViewSetDimensionNV` from this view
+    /// and one size argument per dimension.
+    [require(tensor_addressing)]
+    This Dimension(uint32_t dim0 $(dimParams))
+    {
+        __target_switch
+        {
+        case spirv:
+            return spirv_asm
+            {
+                result:$$This = OpTensorViewSetDimensionNV $this $dim0 $(dimAsms);
+            };
+        }
+    }
+
+    /// Returns a view produced by `OpTensorViewSetStrideNV` from this view
+    /// and one stride argument per dimension.
+    [require(tensor_addressing)]
+    This Stride(uint32_t stride0 $(strideParams))
+    {
+        __target_switch
+        {
+        case spirv:
+            return spirv_asm
+            {
+                result:$$This = OpTensorViewSetStrideNV $this $stride0 $(strideAsms);
+            };
+        }
+    }
+
+    /// Returns a view produced by `OpTensorViewSetClipNV` from this view and the
+    /// given row/column offset and span values.
+    [require(tensor_addressing)]
+    This Clip(uint clipRowOffset, uint clipRowSpan, uint clipColOffset, uint clipColSpan)
+    {
+        __target_switch
+        {
+        case spirv:
+            return spirv_asm
+            {
+                result:$$This = OpTensorViewSetClipNV $this $clipRowOffset $clipRowSpan $clipColOffset $clipColSpan;
+            };
+        }
+    }
+};
+
+${{{{
+    } // iDim
+}}}}
+
+
//
// Cooperative Matrix type
//
@@ -22788,21 +23025,23 @@ struct CoopMat
[require(cooperative_matrix)]
void Store<
let matrixLayout : CoopMatMatrixLayout
- >(RWStructuredBuffer<T> buffer, uint element, uint stride)
+ >(RWByteAddressBuffer buffer, uint element, uint stride)
{
- __Store(buffer, element, stride, matrixLayout);
+ __store<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
}
[require(cooperative_matrix)]
void Store<
let matrixLayout : CoopMatMatrixLayout
- >(RWByteAddressBuffer buffer, uint element, uint stride)
+ >(RWStructuredBuffer<T> buffer, uint element, uint stride)
{
- __Store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ __store<matrixLayout>(buffer, element, stride);
}
[require(cooperative_matrix)]
- internal void __Store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ internal void __store<
+ let matrixLayout : CoopMatMatrixLayout
+ >(RWStructuredBuffer<T> buffer, uint element, uint stride)
{
let zero = 0;
let alignment = 16U;
@@ -22883,20 +23122,19 @@ struct CoopMat
// Load
//
- [__NoSideEffect]
- [require(cooperative_matrix)]
- static This Load<
- let matrixLayout : CoopMatMatrixLayout
- >(ByteAddressBuffer buffer, uint element, uint stride)
+${{{{
+ for (const char* RW : { "", "RW" })
{
- return Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
- }
+}}}}
[__NoSideEffect]
[require(cooperative_matrix)]
static This Load<
let matrixLayout : CoopMatMatrixLayout
- >(RWByteAddressBuffer buffer, uint element, uint stride)
+ >(
+ $(RW)ByteAddressBuffer buffer,
+ uint element,
+ uint stride)
{
return Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
}
@@ -22905,7 +23143,10 @@ struct CoopMat
[require(cooperative_matrix)]
static This Load<
let matrixLayout : CoopMatMatrixLayout
- >(StructuredBuffer<T> buffer, uint element, uint stride)
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ uint stride)
{
let zero = 0;
let alignment = 16U;
@@ -22917,21 +23158,9 @@ struct CoopMat
};
}
- [__NoSideEffect]
- [require(cooperative_matrix)]
- static This Load<
- let matrixLayout : CoopMatMatrixLayout
- >(RWStructuredBuffer<T> buffer, uint element, uint stride)
- {
- let zero = 0;
- let alignment = 16U;
- return spirv_asm
- {
- %storagePointerType = OpTypePointer StorageBuffer $$T;
- %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
- result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
- };
- }
+${{{{
+ } // RW
+}}}}
[ForceInline]
[__NoSideEffect]
@@ -23076,6 +23305,340 @@ struct CoopMat
return true;
}
+ //
+ // Load with TensorLayout and TensorView
+ //
+
+${{{{
+    // `tensorViewTypes` is the generic parameter list and `tensorViewParams` the
+    // matching argument list for the `TensorView` permutation values.
+    StringBuilder tensorViewTypes;
+    StringBuilder tensorViewParams;
+    for (int j = 0; j < kMaxCoopMatTensorDimension; ++j)
+    {
+        // Use the same 0xff default as the `TensorView` declaration, so that a
+        // defaulted permutation argument names the same `TensorView` specialization.
+        tensorViewTypes << ", let p" << j << " : uint32_t = 0xff";
+        tensorViewParams << ", p" << j;
+    }
+
+    // Emit every overload below twice: once for read-only and once for RW buffers.
+    for (const char* RW : { "", "RW" })
+    {
+}}}}
+
+ /// Loads a cooperative matrix from a byte-address buffer using `tensorLayout`.
+ /// Forwards to `__loadLayout` on the equivalent structured buffer of `T`.
+ [require(cooperative_matrix_tensor_addressing)]
+ static This Load<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ $(RW)ByteAddressBuffer buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout)
+ {
+ return __loadLayout<Dim, ClampMode>(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout);
+ }
+
+ /// Loads a cooperative matrix from a structured buffer using `tensorLayout`.
+ /// Forwards to `__loadLayout` with the same arguments.
+ [require(cooperative_matrix_tensor_addressing)]
+ static This Load<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout)
+ {
+ return __loadLayout<Dim, ClampMode>(buffer, element, tensorLayout);
+ }
+
+ /// Emits `OpCooperativeMatrixLoadTensorNV` for a load addressed by `tensorLayout` only.
+ /// The pointer is an access chain into `buffer` at `element`; the memory operand is
+ /// `Aligned` with an alignment of 16 and the tensor addressing operands are `None`.
+ [require(cooperative_matrix_tensor_addressing)]
+ static This __loadLayout<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout)
+ {
+ let zero = 0;
+ let alignment = 16U;
+
+ // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixLoadTensorNV
+ // `ret` is declared without an initializer and passed as the instruction's object operand.
+ This ret;
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixTensorAddressingNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ result:$$This = OpCooperativeMatrixLoadTensorNV %pointer $ret $tensorLayout Aligned !alignment None;
+ };
+ }
+
+ /// Loads a cooperative matrix from a byte-address buffer using `tensorLayout` and `tensorView`.
+ /// Forwards to `__loadView` on the equivalent structured buffer of `T`.
+ [require(cooperative_matrix_tensor_addressing)]
+ static This Load<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ $(RW)ByteAddressBuffer buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView)
+ {
+ return __loadView<Dim, ClampMode, DimView, HasDimensions $(tensorViewParams)>(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout, tensorView);
+ }
+
+ /// Loads a cooperative matrix from a structured buffer using `tensorLayout` and `tensorView`.
+ /// Forwards to `__loadView` with the same arguments.
+ [require(cooperative_matrix_tensor_addressing)]
+ static This Load<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams) > tensorView)
+ {
+ return __loadView<Dim, ClampMode, DimView, HasDimensions $(tensorViewParams)>(buffer, element, tensorLayout, tensorView);
+ }
+
+ /// Emits `OpCooperativeMatrixLoadTensorNV` with the `TensorView` tensor addressing operand.
+ /// The pointer is an access chain into `buffer` at `element`; the memory operand is
+ /// `Aligned` with an alignment of 16.
+ [require(cooperative_matrix_tensor_addressing)]
+ static This __loadView<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams) > tensorView)
+ {
+ let zero = 0;
+ let alignment = 16U;
+
+ // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixLoadTensorNV
+ // `ret` is declared without an initializer and passed as the instruction's object operand.
+ This ret;
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixTensorAddressingNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ result:$$This = OpCooperativeMatrixLoadTensorNV %pointer $ret $tensorLayout Aligned !alignment TensorView $tensorView;
+ };
+ }
+
+ /// Loads a cooperative matrix from a byte-address buffer using `tensorLayout` and a
+ /// decode function. Forwards to `__loadLayoutDecode` on the equivalent structured buffer of `T`.
+ [require(cooperative_matrix_block_load)]
+ static This Load<
+ U,
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ $(RW)ByteAddressBuffer buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc)
+ {
+ return __loadLayoutDecode<U, Dim, ClampMode>(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout, decodeFunc);
+ }
+
+ /// Loads a cooperative matrix from a structured buffer using `tensorLayout` and a
+ /// decode function. Forwards to `__loadLayoutDecode` with the same arguments.
+ [require(cooperative_matrix_block_load)]
+ static This Load<
+ U,
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc)
+ {
+ return __loadLayoutDecode<U, Dim, ClampMode>(buffer, element, tensorLayout, decodeFunc);
+ }
+
+ /// Emits `OpCooperativeMatrixLoadTensorNV` with the `DecodeFunc` tensor addressing operand.
+ /// The pointer is an access chain into `buffer` at `element`; the memory operand is
+ /// `Aligned` with an alignment of 16.
+ [require(cooperative_matrix_block_load)]
+ static This __loadLayoutDecode<
+ U,
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc)
+ {
+ let zero = 0;
+ let alignment = 16U;
+
+ // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixLoadTensorNV
+ // `ret` is declared without an initializer and passed as the instruction's object operand.
+ This ret;
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixBlockLoadsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ result:$$This = OpCooperativeMatrixLoadTensorNV %pointer $ret $tensorLayout Aligned !alignment DecodeFunc $decodeFunc;
+ };
+ }
+
+ /// Loads a cooperative matrix from a byte-address buffer using `tensorLayout`, `tensorView`
+ /// and a decode function. Forwards to `__loadViewDecode` on the equivalent structured buffer of `T`.
+ [require(cooperative_matrix_block_load)]
+ static This Load<
+ U,
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ $(RW)ByteAddressBuffer buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView,
+ functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc)
+ {
+ return __loadViewDecode<U, Dim, ClampMode, DimView, HasDimensions $(tensorViewParams)>(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout, tensorView, decodeFunc);
+ }
+
+ /// Loads a cooperative matrix from a structured buffer using `tensorLayout`, `tensorView`
+ /// and a decode function. Forwards to `__loadViewDecode` with the same arguments.
+ [require(cooperative_matrix_block_load)]
+ static This Load<
+ U,
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView,
+ functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc)
+ {
+ return __loadViewDecode<U, Dim, ClampMode, DimView, HasDimensions $(tensorViewParams)>(buffer, element, tensorLayout, tensorView, decodeFunc);
+ }
+
+ /// Emits `OpCooperativeMatrixLoadTensorNV` with both the `TensorView` and `DecodeFunc`
+ /// tensor addressing operands. The pointer is an access chain into `buffer` at `element`;
+ /// the memory operand is `Aligned` with an alignment of 16.
+ // NOTE(review): the asm block declares CooperativeMatrixTensorAddressingNV as well, while
+ // the `require` attribute only names cooperative_matrix_block_load - confirm this is intended.
+ [require(cooperative_matrix_block_load)]
+ static This __loadViewDecode<
+ U,
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ $(RW)StructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView,
+ functype(U*, uint32_t[Dim], uint32_t[Dim]) -> T decodeFunc)
+ {
+ let zero = 0;
+ let alignment = 16U;
+
+ // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixLoadTensorNV
+ // `ret` is declared without an initializer and passed as the instruction's object operand.
+ This ret;
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixTensorAddressingNV;
+ OpCapability CooperativeMatrixBlockLoadsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ result:$$This = OpCooperativeMatrixLoadTensorNV %pointer $ret $tensorLayout Aligned !alignment TensorView|DecodeFunc $tensorView $decodeFunc;
+ };
+ }
+
+${{{{
+ } // RW
+}}}}
+
+ /// Stores this cooperative matrix to a byte-address buffer using `tensorLayout`.
+ /// Forwards to the structured-buffer `Store` overload on the equivalent structured buffer of `T`.
+ [require(cooperative_matrix_tensor_addressing)]
+ void Store<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ RWByteAddressBuffer buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout)
+ {
+ Store(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout);
+ }
+
+ /// Stores this cooperative matrix to a structured buffer using `tensorLayout`.
+ /// Emits `OpCooperativeMatrixStoreTensorNV` on an access chain into `buffer` at `element`,
+ /// with an `Aligned` memory operand of 16 and tensor addressing operands `None`.
+ /// Only a `spirv` target case is provided.
+ [require(cooperative_matrix_tensor_addressing)]
+ void Store<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode
+ >(
+ RWStructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout)
+ {
+ let zero = 0;
+ let alignment = 16U;
+
+ __target_switch
+ {
+ case spirv:
+ spirv_asm
+ {
+ OpCapability CooperativeMatrixTensorAddressingNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ OpCooperativeMatrixStoreTensorNV %pointer $this $tensorLayout Aligned !alignment None;
+ };
+ }
+ }
+
+ /// Stores this cooperative matrix to a byte-address buffer using `tensorLayout` and `tensorView`.
+ /// Forwards to the structured-buffer `Store` overload on the equivalent structured buffer of `T`.
+ [require(cooperative_matrix_tensor_addressing)]
+ void Store<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ RWByteAddressBuffer buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView)
+ {
+ Store(__getEquivalentStructuredBuffer<T>(buffer), element, tensorLayout, tensorView);
+ }
+
+ /// Stores this cooperative matrix to a structured buffer using `tensorLayout` and `tensorView`.
+ /// Emits `OpCooperativeMatrixStoreTensorNV` on an access chain into `buffer` at `element`,
+ /// with an `Aligned` memory operand of 16 and the `TensorView` tensor addressing operand.
+ /// Only a `spirv` target case is provided.
+ [require(cooperative_matrix_tensor_addressing)]
+ void Store<
+ let Dim : uint32_t,
+ let ClampMode : CoopMatClampMode,
+ let DimView : uint32_t,
+ let HasDimensions : bool
+ $(tensorViewTypes)
+ >(
+ RWStructuredBuffer<T> buffer,
+ uint element,
+ TensorLayout<Dim, ClampMode> tensorLayout,
+ TensorView<DimView, HasDimensions $(tensorViewParams)> tensorView)
+ {
+ let zero = 0;
+ let alignment = 16U;
+
+ __target_switch
+ {
+ case spirv:
+ spirv_asm
+ {
+ OpCapability CooperativeMatrixTensorAddressingNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ OpCooperativeMatrixStoreTensorNV %pointer $this $tensorLayout Aligned !alignment TensorView $tensorView;
+ };
+ }
+ }
+
} // struct CoopMat
@@ -23098,7 +23661,7 @@ CoopMat<T, S, M, N, R> coopMatLoad<
uint element,
uint stride)
{
- return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
}
[ForceInline]
@@ -23115,7 +23678,7 @@ CoopMat<T, S, M, N, R> coopMatLoad<
uint element,
uint stride)
{
- return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
}
[ForceInline]
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index b909fa0f9..937584908 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -567,25 +567,25 @@ def SPV_GOOGLE_user_type : _spirv_1_0;
/// [EXT]
def SPV_EXT_replicated_composites : _spirv_1_0;
-/// Represents the SPIR-V extension for SPV_NV_cooperative_vector.
+/// Represents the SPIR-V extension for SPV_KHR_vulkan_memory_model.
/// [EXT]
-def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites;
+def SPV_KHR_vulkan_memory_model : _spirv_1_3;
-/// Represents the SPIR-V extension for SPV_KHR_cooperative_matrix.
+/// Represents the SPIR-V extension for SPV_NV_cooperative_vector.
/// [EXT]
-def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer;
+def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites + SPV_KHR_vulkan_memory_model;
-/// Represents the SPIR-V extension for SPV_NV_cooperative_matrix2.
+/// Represents the SPIR-V extension for SPV_KHR_cooperative_matrix.
/// [EXT]
-def SPV_NV_cooperative_matrix2 : _spirv_1_6 + SPV_KHR_cooperative_matrix;
+def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer + SPV_KHR_vulkan_memory_model;
/// Represents the SPIR-V extension for SPV_NV_tensor_addressing.
/// [EXT]
def SPV_NV_tensor_addressing : _spirv_1_6;
-/// Represents the SPIR-V extension for SPV_KHR_vulkan_memory_model.
+/// Represents the SPIR-V extension for SPV_NV_cooperative_matrix2.
/// [EXT]
-def SPV_KHR_vulkan_memory_model : _spirv_1_3;
+def SPV_NV_cooperative_matrix2 : SPV_NV_tensor_addressing + SPV_KHR_cooperative_matrix;
// SPIRV Capabilities.
@@ -1183,6 +1183,9 @@ alias cooperative_matrix_block_load = spvCooperativeMatrixBlockLoadsNV;
/// Capabilities needed to use tensor addressing
/// [Compound]
alias tensor_addressing = spvTensorAddressingNV;
+/// Capabilities needed to use all cooperative matrix 2 features (reductions, conversions, per-element operations, tensor addressing and block loads)
+/// [Compound]
+alias cooperative_matrix_2 = spvCooperativeMatrixKHR + spvCooperativeMatrixReductionsNV + spvCooperativeMatrixConversionsNV + spvCooperativeMatrixPerElementOperationsNV + spvCooperativeMatrixTensorAddressingNV + spvCooperativeMatrixBlockLoadsNV + spvTensorAddressingNV;
// Non-internal shader stages
//
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index b716e470d..81ad4bddf 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -174,6 +174,56 @@ SpvInst* emitOpTypeCoopMat(
matrixUse);
}
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpTypeTensorLayoutNV
+// Emits the `OpTypeTensorLayoutNV` declaration for `inst` into the constants-and-types
+// section through `emitInstMemoized`. `dim` and `clampMode` must each be a single operand.
+template<typename T1, typename T2>
+SpvInst* emitOpTypeTensorLayout(IRInst* inst, const T1& dim, const T2& clampMode)
+{
+    static_assert(isSingular<T1>);
+    // Checked as well, matching how emitOpTypeTensorView validates its fixed operands.
+    static_assert(isSingular<T2>);
+    return emitInstMemoized(
+        getSection(SpvLogicalSectionID::ConstantsAndTypes),
+        inst,
+        SpvOpTypeTensorLayoutNV,
+        kResultID,
+        dim,
+        clampMode);
+}
+
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpTypeTensorViewNV
+// Emits the `OpTypeTensorViewNV` declaration for `inst` into the constants-and-types
+// section through `emitInstMemoized`. `dim` and `hasDimensions` must each be a single
+// operand; `perms` are appended after them in the order given.
+template<typename T1, typename T2, typename... TPerms>
+SpvInst* emitOpTypeTensorView(
+    IRInst* inst,
+    const T1& dim,
+    const T2& hasDimensions,
+    const TPerms&... perms)
+{
+    static_assert(isSingular<T1>);
+    static_assert(isSingular<T2>);
+
+    auto typeSection = getSection(SpvLogicalSectionID::ConstantsAndTypes);
+    return emitInstMemoized(
+        typeSection,
+        inst,
+        SpvOpTypeTensorViewNV,
+        kResultID,
+        dim,
+        hasDimensions,
+        perms...);
+}
+
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpCreateTensorLayoutNV
+// Emits `OpCreateTensorLayoutNV` under `parent` for `inst`; the only operands are the
+// result type id and the result id.
+template<typename T1>
+SpvInst* emitOpCreateTensorLayout(SpvInstParent* parent, IRInst* inst, const T1& idResultType)
+{
+    static_assert(isSingular<T1>);
+    SpvInst* const created =
+        emitInst(parent, inst, SpvOpCreateTensorLayoutNV, idResultType, kResultID);
+    return created;
+}
+
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_tensor_addressing.html#OpCreateTensorViewNV
+// Emits `OpCreateTensorViewNV` under `parent` for `inst`; the only operands are the
+// result type id and the result id.
+template<typename T1>
+SpvInst* emitOpCreateTensorView(SpvInstParent* parent, IRInst* inst, const T1& idResultType)
+{
+    static_assert(isSingular<T1>);
+    SpvInst* const created =
+        emitInst(parent, inst, SpvOpCreateTensorViewNV, idResultType, kResultID);
+    return created;
+}
+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeMatrix
template<typename T>
SpvInst* emitOpTypeMatrix(IRInst* inst, const T& columnType, const SpvLiteralInteger& columnCount)
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 32d3ba7c3..7d202c7c1 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1854,6 +1854,100 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
static_cast<IRIntLit*>(coopMatType->getMatrixUse())->getValue(),
builder.getIntType()));
}
+        case kIROp_TensorAddressingTensorLayoutType:
+            {
+                // Declare the capability and extension the tensor layout type relies on.
+                requireSPIRVCapability(SpvCapabilityTensorAddressingNV);
+                ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_tensor_addressing"));
+
+                IRBuilder builder(m_irModule);
+                auto layoutType = static_cast<IRTensorAddressingTensorLayoutType*>(inst);
+
+                // Both type operands are integer literals; emit each as an int constant.
+                auto dimLit = static_cast<IRIntLit*>(layoutType->getDimension());
+                auto clampModeLit = static_cast<IRIntLit*>(layoutType->getClampMode());
+                SpvInst* spvDim = emitIntConstant(dimLit->getValue(), builder.getIntType());
+                SpvInst* spvClampMode =
+                    emitIntConstant(clampModeLit->getValue(), builder.getIntType());
+
+                return emitOpTypeTensorLayout(layoutType, spvDim, spvClampMode);
+            }
+        case kIROp_TensorAddressingTensorViewType:
+            {
+                // Declare the capability and extension the tensor view type relies on.
+                requireSPIRVCapability(SpvCapabilityTensorAddressingNV);
+                ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_tensor_addressing"));
+
+                IRBuilder builder(m_irModule);
+                auto tensorViewType = static_cast<IRTensorAddressingTensorViewType*>(inst);
+
+                // Size of the fixed permutation array below; it must cover every
+                // dimension count the branches at the end of this case handle.
+                constexpr IRIntegerValue kMaxTensorViewDim = 5;
+
+                const IRIntegerValue dim =
+                    static_cast<IRIntLit*>(tensorViewType->getDimension())->getValue();
+
+                // Reject out-of-range dimensions before `dim` is used as a loop bound,
+                // otherwise the loop below would write past the end of the array.
+                if (dim < 1 || dim > kMaxTensorViewDim)
+                {
+                    SLANG_UNEXPECTED("Unsupported tensor dimension");
+                }
+
+                SpvInst* spvDim = emitIntConstant(dim, builder.getIntType());
+
+                SpvInst* spvHasDimension =
+                    ensureInst(static_cast<IRBoolLit*>(tensorViewType->getHasDimension()));
+
+                // Only the first `dim` permutation operands are emitted.
+                SpvInst* spvPermutations[kMaxTensorViewDim] = {};
+                for (int i = 0; i < (int)dim; i++)
+                {
+                    spvPermutations[i] = emitIntConstant(
+                        static_cast<IRIntLit*>(tensorViewType->getPermutation(i))->getValue(),
+                        builder.getIntType());
+                }
+
+                // The emit helper takes the permutations as a parameter pack, so each
+                // dimension count needs its own call.
+                if (dim == 1)
+                {
+                    return emitOpTypeTensorView(
+                        tensorViewType,
+                        spvDim,
+                        spvHasDimension,
+                        spvPermutations[0]);
+                }
+                else if (dim == 2)
+                {
+                    return emitOpTypeTensorView(
+                        tensorViewType,
+                        spvDim,
+                        spvHasDimension,
+                        spvPermutations[0],
+                        spvPermutations[1]);
+                }
+                else if (dim == 3)
+                {
+                    return emitOpTypeTensorView(
+                        tensorViewType,
+                        spvDim,
+                        spvHasDimension,
+                        spvPermutations[0],
+                        spvPermutations[1],
+                        spvPermutations[2]);
+                }
+                else if (dim == 4)
+                {
+                    return emitOpTypeTensorView(
+                        tensorViewType,
+                        spvDim,
+                        spvHasDimension,
+                        spvPermutations[0],
+                        spvPermutations[1],
+                        spvPermutations[2],
+                        spvPermutations[3]);
+                }
+                else if (dim == 5)
+                {
+                    return emitOpTypeTensorView(
+                        tensorViewType,
+                        spvDim,
+                        spvHasDimension,
+                        spvPermutations[0],
+                        spvPermutations[1],
+                        spvPermutations[2],
+                        spvPermutations[3],
+                        spvPermutations[4]);
+                }
+                else
+                {
+                    // Not expected to be reached after the range check above; kept so
+                    // that every path either returns or signals.
+                    SLANG_UNEXPECTED("Unsupported tensor dimension");
+                }
+            }
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(inst);
@@ -4070,6 +4164,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_MakeArray:
result = emitConstruct(parent, inst);
break;
+ // Tensor layout/view construction: only the emitted result type is passed on to the
+ // matching create instruction.
+ case kIROp_MakeTensorAddressingTensorLayout:
+ result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType())));
+ break;
+ case kIROp_MakeTensorAddressingTensorView:
+ result = emitOpCreateTensorView(parent, inst, getID(ensureInst(inst->getDataType())));
+ break;
case kIROp_Select:
result = emitInst(
parent,
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f863858e4..e44954521 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -238,6 +238,10 @@ INST(RayQueryType, RayQuery, 1, HOISTABLE)
INST(HitObjectType, HitObject, 0, HOISTABLE)
INST(CoopVectorType, CoopVectorType, 2, HOISTABLE)
INST(CoopMatrixType, CoopMatrixType, 5, HOISTABLE)
+// Tensor addressing types and the instructions that construct values of them.
+// NOTE(review): the numeric argument is presumably the minimum operand count
+// (layout: dimension + clamp mode; view: dimension + hasDimension + at least one
+// permutation) - confirm against the INST macro definition.
+INST(TensorAddressingTensorLayoutType, TensorAddressingTensorLayoutType, 2, HOISTABLE)
+INST(TensorAddressingTensorViewType, TensorAddressingTensorViewType, 3, HOISTABLE)
+INST(MakeTensorAddressingTensorLayout, MakeTensorAddressingTensorLayout, 0, 0)
+INST(MakeTensorAddressingTensorView, MakeTensorAddressingTensorView, 0, 0)
// Opaque type that can be dynamically cast to other resource types.
INST(DynamicResourceType, DynamicResource, 0, HOISTABLE)
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 0287ae81a..3319aa89d 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -1806,6 +1806,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}
}
+
// Scan through the entry points and find the max version required.
auto processInst = [&](IRInst* globalInst)
{
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 461ed567a..9fd2263c1 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1882,6 +1882,23 @@ struct IRCoopMatrixType : IRType
IR_LEAF_ISA(CoopMatrixType)
};
+// IR type for a tensor layout.
+// Operand layout: [0] dimension, [1] clamp mode.
+struct IRTensorAddressingTensorLayoutType : IRType
+{
+ IRInst* getDimension() { return getOperand(0); }
+ IRInst* getClampMode() { return getOperand(1); }
+
+ IR_LEAF_ISA(TensorAddressingTensorLayoutType)
+};
+
+// IR type for a tensor view.
+// Operand layout: [0] dimension, [1] has-dimension flag, [2...] permutation values.
+struct IRTensorAddressingTensorViewType : IRType
+{
+ IRInst* getDimension() { return getOperand(0); }
+ IRInst* getHasDimension() { return getOperand(1); }
+ // Returns permutation operand `index`; no bounds check is performed here.
+ IRInst* getPermutation(int index) { return getOperand(2 + index); }
+
+ IR_LEAF_ISA(TensorAddressingTensorViewType)
+};
+
bool isDefinition(IRInst* inVal);
// A structure type is represented as a parent instruction,
diff --git a/tests/cooperative-matrix/load-store-tensorlayout.slang b/tests/cooperative-matrix/load-store-tensorlayout.slang
new file mode 100644
index 000000000..849c85c0e
--- /dev/null
+++ b/tests/cooperative-matrix/load-store-tensorlayout.slang
@@ -0,0 +1,88 @@
+//TEST(compute):SIMPLE(filecheck=SPIRV):-target spirv-asm -entry computeMain -stage compute -skip-spirv-validation
+//TEST(compute):SIMPLE(filecheck=SPIRV_BL):-target spirv-asm -entry computeMain -stage compute -skip-spirv-validation -DBLOCK_LOAD
+
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -skip-spirv-validation -render-feature cooperative-matrix-tensor-addressing
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -skip-spirv-validation -render-feature cooperative-matrix-tensor-addressing -Xslang -DRW
+
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK_BL):-vk -output-using-type -emit-spirv-directly -skip-spirv-validation -render-feature cooperative-matrix-block-loads -Xslang -DBLOCK_LOAD
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK_BL):-vk -output-using-type -emit-spirv-directly -skip-spirv-validation -render-feature cooperative-matrix-block-loads -Xslang -DBLOCK_LOAD -Xslang -DRW
+
+//CHECK: 0
+//CHECK-NEXT: 0
+//CHECK-NEXT: 0
+//CHECK-NEXT: 0
+//CHECK-NEXT: 5
+//CHECK-NEXT: 6
+//CHECK-NEXT: 0
+//CHECK-NEXT: 0
+//CHECK-NEXT: 9
+
+//CHECK_BL: 0
+//CHECK_BL-NEXT: 0
+//CHECK_BL-NEXT: 0
+//CHECK_BL-NEXT: 0
+//CHECK_BL-NEXT: 7
+//CHECK_BL-NEXT: C
+//CHECK_BL-NEXT: 0
+//CHECK_BL-NEXT: 0
+//CHECK_BL-NEXT: C
+//CHECK_BL-NEXT: 11
+//CHECK_BL-NEXT: 0
+//CHECK_BL-NEXT: 0
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24], stride=4, count=256),name=buf
+
+#if defined(RW)
+ RWByteAddressBuffer inputBuffer; // RW defined: the load source is declared read-write
+#else // #if defined(RW)
+ ByteAddressBuffer inputBuffer; // NOTE(review): the input directive above this #if ends in ",name=buf", not "inputBuffer" -- confirm the harness still binds it here
+#endif // #else // #if defined(RW)
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWByteAddressBuffer outputBuffer; // destination of mat.Store in computeMain
+
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator>; // the one matrix type this test loads and stores
+
+int32_t decodeFunc(uint32_t* encoded, uint32_t blockCoord[2], uint32_t coordInBlock[2]) // passed to CoopMatType.Load only when BLOCK_LOAD is defined
+{
+ uint32_t coord = blockCoord[1] * 4 + blockCoord[0]; // flattens the block coordinate with a pitch of 4
+ uint32_t mask = (0xff << (coordInBlock[0] * 8)); // 0xff moved into byte lane coordInBlock[0]; coordInBlock[1] is unused
+ return int32_t(encoded[coord] & mask) + 1; // masked word plus 1; the byte is not shifted back down -- NOTE(review): presumably intentional, the expected block-load output depends on it
+}
+
+[numthreads(32, 1, 1)]
+void computeMain() // chains TensorLayout setters, loads a CoopMatType through the result, stores it to outputBuffer
+{
+ //SPIRV: = OpCreateTensorLayoutNV %
+ TensorLayout<2, CoopMatClampMode.Undefined> tl; // starting layout; never reassigned below
+
+ //SPIRV: = OpTensorLayoutSetDimensionNV %
+ let tl1 = tl.Dimension(32, 16); // each setter result is bound to a fresh name (tl1..tl5)
+
+ //SPIRV: = OpTensorLayoutSetStrideNV %
+ let tl2 = tl1.Stride(4, 1);
+
+ //SPIRV: = OpTensorLayoutSliceNV %
+ let tl3 = tl2.Slice(4, 24, 0, 16);
+
+ //SPIRV: = OpTensorLayoutSetClampValueNV %
+ let tl4 = tl3.ClampValue(CoopMatClampMode.Repeat);
+
+ //SPIRV: = OpTensorLayoutSetBlockSizeNV %
+ let tl5 = tl4.BlockSize(4, 8); // tl5 is the layout handed to both Load and Store
+
+#if defined(BLOCK_LOAD)
+ //SPIRV_BL: = OpCooperativeMatrixLoadTensorNV %{{.*}} DecodeFunc %
+ let mat = CoopMatType.Load<uint32_t>(inputBuffer, 0, tl5, decodeFunc); // block-load variant: Load<uint32_t> overload taking decodeFunc
+
+#else // #if defined(BLOCK_LOAD)
+ //SPIRV: = OpCooperativeMatrixLoadTensorNV %{{.*}} None
+ let mat = CoopMatType.Load(inputBuffer, 0, tl5); // plain variant: no decode callback
+
+#endif // #else // #if defined(BLOCK_LOAD)
+
+ //SPIRV:OpCooperativeMatrixStoreTensorNV %{{.*}} None
+ mat.Store(outputBuffer, 0, tl5); // stored with the same layout it was loaded with
+}
diff --git a/tests/cooperative-matrix/load-store-tensorview.slang b/tests/cooperative-matrix/load-store-tensorview.slang
new file mode 100644
index 000000000..6757338be
--- /dev/null
+++ b/tests/cooperative-matrix/load-store-tensorview.slang
@@ -0,0 +1,88 @@
+//TEST(compute):SIMPLE(filecheck=SPIRV):-target spirv-asm -entry computeMain -stage compute -skip-spirv-validation
+//TEST(compute):SIMPLE(filecheck=SPIRV_BL):-target spirv-asm -entry computeMain -stage compute -skip-spirv-validation -DBLOCK_LOAD
+
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -skip-spirv-validation -render-feature cooperative-matrix-tensor-addressing
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -skip-spirv-validation -render-feature cooperative-matrix-tensor-addressing -Xslang -DRW
+
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK_BL):-vk -output-using-type -emit-spirv-directly -Xslang -DBLOCK_LOAD -render-feature cooperative-matrix-block-loads -skip-spirv-validation
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK_BL):-vk -output-using-type -emit-spirv-directly -Xslang -DBLOCK_LOAD -render-feature cooperative-matrix-block-loads -skip-spirv-validation -Xslang -DRW
+
+//CHECK: 2
+//CHECK-NEXT: 2
+//CHECK-NEXT: 2
+//CHECK-NEXT: 2
+//CHECK-NEXT: 12
+//CHECK-NEXT: 12
+//CHECK-NEXT: 12
+//CHECK-NEXT: 12
+//CHECK-NEXT: 0
+
+//CHECK_BL: 7
+//CHECK_BL-NEXT: 1
+//CHECK_BL-NEXT: 1
+//CHECK_BL-NEXT: 1
+//CHECK_BL-NEXT: 18
+//CHECK_BL-NEXT: 1
+//CHECK_BL-NEXT: 1
+//CHECK_BL-NEXT: 1
+//CHECK_BL-NEXT: 0
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24], stride=4, count=256),name=buf
+
+#if defined(RW)
+ RWByteAddressBuffer inputBuffer; // RW defined: the load source is declared read-write
+#else // #if defined(RW)
+ ByteAddressBuffer inputBuffer; // NOTE(review): the input directive above this #if ends in ",name=buf", not "inputBuffer" -- confirm the harness still binds it here
+#endif // #else // #if defined(RW)
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWByteAddressBuffer outputBuffer; // destination of mat.Store in computeMain
+
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator>; // the one matrix type this test loads and stores
+
+int32_t decodeFunc(uint32_t* encoded, uint32_t blockCoord[2], uint32_t coordInBlock[2]) // passed to CoopMatType.Load only when BLOCK_LOAD is defined
+{
+ uint32_t coord = blockCoord[1] * 4 + blockCoord[0]; // flattens the block coordinate with a pitch of 4
+ uint32_t mask = (0xff << (coordInBlock[0] * 8)); // 0xff moved into byte lane coordInBlock[0]; coordInBlock[1] is unused
+ return int32_t(encoded[coord] & mask) + 1; // masked word plus 1; the byte is not shifted back down -- NOTE(review): presumably intentional, the expected block-load output depends on it
+}
+
+[numthreads(32, 1, 1)]
+void computeMain() // loads a CoopMatType through tl3 + tvRowMajor, stores it through tl3 + tvColumnMajor3
+{
+ TensorLayout<2, CoopMatClampMode.Undefined> tl; // starting layout; never reassigned below
+
+ let tl1 = tl.Dimension(16, 16);
+ let tl2 = tl1.Slice(0, 16, 0, 16);
+ let tl3 = tl2.BlockSize(4, 1); // tl3 is the layout handed to both Load and Store
+
+ //SPIRV: = OpTypeTensorViewNV %{{[^%]*}} %false %{{[^%]*}} %{{[^%]*$}}
+ //SPIRV: = OpCreateTensorViewNV %
+ TensorView<2, false, 0, 1> tvRowMajor; // used as declared by Load; no setters are applied to it
+ TensorView<2, false, 1, 0> tvColumnMajor; // last two arguments swapped relative to tvRowMajor
+
+ //SPIRV: = OpTensorViewSetDimensionNV %
+ let tvColumnMajor1 = tvColumnMajor.Dimension(16, 8); // each setter result is bound to a fresh name
+
+ //SPIRV: = OpTensorViewSetStrideNV %
+ let tvColumnMajor2 = tvColumnMajor1.Stride(8, 1);
+
+ //SPIRV: = OpTensorViewSetClipNV %
+ let tvColumnMajor3 = tvColumnMajor2.Clip(0, 8, 0, 64); // final view, consumed only by Store
+
+#if defined(BLOCK_LOAD)
+ //SPIRV_BL: = OpCooperativeMatrixLoadTensorNV %{{.*}} TensorView|DecodeFunc %
+ let mat = CoopMatType.Load<uint32_t>(inputBuffer, 0, tl3, tvRowMajor, decodeFunc); // block-load variant: Load<uint32_t> overload taking decodeFunc
+
+#else // #if defined(BLOCK_LOAD)
+ //SPIRV: OpCooperativeMatrixLoadTensorNV %
+ //SPIRV-SAME: TensorView
+ let mat = CoopMatType.Load(inputBuffer, 0, tl3, tvRowMajor); // plain variant: no decode callback
+
+#endif // #else // #if defined(BLOCK_LOAD)
+
+ //SPIRV: OpCooperativeMatrixStoreTensorNV {{.*}} TensorView %
+ mat.Store(outputBuffer, 0, tl3, tvColumnMajor3); // same layout as the load, different view
+}