summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2025-05-15 03:47:43 +0000
committerGitHub <noreply@github.com>2025-05-14 20:47:43 -0700
commit2580bb02f7a079ab1c0106b5960a21ed1627bca0 (patch)
treec4fe31e6314f514c9bb079d0fa15ee53adf7396f
parentb4d3d3017640581c21b52a12413d3f074ab1c5c1 (diff)
Add new coopmat2 functions: Reduce and Transpose (#7027)
This commit adds three new functions for CoopMat as described in the proposal document, Cooperative matrix 2 proposal spec#12 The new functions are: CoopMat<T,S,M,N,R>::Transpose CoopMat<T,S,M,N,R>::ReduceRow CoopMat<T,S,M,N,R>::ReduceColumn CoopMat<T,S,M,N,R>::ReduceRowAndColumn CoopMat<T,S,M,N,R>::Reduce2x2
-rw-r--r--docs/command-line-slangc-reference.md15
-rw-r--r--docs/user-guide/a3-02-reference-capability-atoms.md45
-rw-r--r--source/slang/core.meta.slang12
-rw-r--r--source/slang/hlsl.meta.slang537
-rw-r--r--source/slang/slang-capabilities.capdef54
-rw-r--r--source/slang/slang-emit-spirv-ops.h1
-rw-r--r--tests/cooperative-matrix/add.slang10
-rw-r--r--tests/cooperative-matrix/array.slang12
-rw-r--r--tests/cooperative-matrix/comparison.slang8
-rw-r--r--tests/cooperative-matrix/conversion.slang13
-rw-r--r--tests/cooperative-matrix/copyFrom.slang8
-rw-r--r--tests/cooperative-matrix/diagnostics/mat-mul-add-different-scope.slang20
-rw-r--r--tests/cooperative-matrix/diagnostics/mat-mul-add-incorrect-matrix-use.slang20
-rw-r--r--tests/cooperative-matrix/div.slang10
-rw-r--r--tests/cooperative-matrix/fill.slang6
-rw-r--r--tests/cooperative-matrix/inout.slang8
-rw-r--r--tests/cooperative-matrix/length.slang21
-rw-r--r--tests/cooperative-matrix/load-store-arbitrary-array-vec.slang14
-rw-r--r--tests/cooperative-matrix/load-store-arbitrary-array.slang13
-rw-r--r--tests/cooperative-matrix/load-store-groupshared.slang13
-rw-r--r--tests/cooperative-matrix/load-store-rwbyteaddressbuffer.slang13
-rw-r--r--tests/cooperative-matrix/load-store-rwstructuredbuffer.slang13
-rw-r--r--tests/cooperative-matrix/map-element-single.slang50
-rw-r--r--tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang67
-rw-r--r--tests/cooperative-matrix/mat-mul-add.slang12
-rw-r--r--tests/cooperative-matrix/mod.slang21
-rw-r--r--tests/cooperative-matrix/mul.slang10
-rw-r--r--tests/cooperative-matrix/out.slang8
-rw-r--r--tests/cooperative-matrix/parameter.slang8
-rw-r--r--tests/cooperative-matrix/reduce.slang52
-rw-r--r--tests/cooperative-matrix/return.slang8
-rw-r--r--tests/cooperative-matrix/scalar-mul.slang8
-rw-r--r--tests/cooperative-matrix/struct.slang12
-rw-r--r--tests/cooperative-matrix/sub.slang10
-rw-r--r--tests/cooperative-matrix/subscript-in-func.slang8
-rw-r--r--tests/cooperative-matrix/subscript.slang6
-rw-r--r--tests/cooperative-matrix/transpose.slang36
-rw-r--r--tests/cooperative-matrix/unary_neg.slang8
38 files changed, 889 insertions, 301 deletions
diff --git a/docs/command-line-slangc-reference.md b/docs/command-line-slangc-reference.md
index f296a6f02..de1424af4 100644
--- a/docs/command-line-slangc-reference.md
+++ b/docs/command-line-slangc-reference.md
@@ -1111,6 +1111,9 @@ A capability describes an optional feature that a target may or may not support.
* `SPV_EXT_replicated_composites` : enables the SPV_EXT_replicated_composites extension
* `SPV_NV_cooperative_vector` : enables the SPV_NV_cooperative_vector extension
* `SPV_KHR_cooperative_matrix` : enables the SPV_KHR_cooperative_matrix extension
+* `SPV_NV_cooperative_matrix2` : enables the SPV_NV_cooperative_matrix2 extension
+* `SPV_NV_tensor_addressing` : enables the SPV_NV_tensor_addressing extension
+* `SPV_KHR_vulkan_memory_model` : enables the SPV_KHR_vulkan_memory_model extension
* `spvAtomicFloat32AddEXT`
* `spvAtomicFloat16AddEXT`
* `spvAtomicFloat64AddEXT`
@@ -1151,6 +1154,12 @@ A capability describes an optional feature that a target may or may not support.
* `spvCooperativeVectorNV`
* `spvCooperativeVectorTrainingNV`
* `spvCooperativeMatrixKHR`
+* `spvCooperativeMatrixReductionsNV`
+* `spvCooperativeMatrixConversionsNV`
+* `spvCooperativeMatrixPerElementOperationsNV`
+* `spvCooperativeMatrixTensorAddressingNV`
+* `spvCooperativeMatrixBlockLoadsNV`
+* `spvTensorAddressingNV`
* `spvMaximalReconvergenceKHR`
* `spvQuadControlKHR`
* `metallib_latest`
@@ -1277,6 +1286,12 @@ A capability describes an optional feature that a target may or may not support.
* `cooperative_vector`
* `cooperative_vector_training`
* `cooperative_matrix`
+* `cooperative_matrix_reduction`
+* `cooperative_matrix_conversion`
+* `cooperative_matrix_map_element`
+* `cooperative_matrix_tensor_addressing`
+* `cooperative_matrix_block_load`
+* `tensor_addressing`
* `pixel`
* `tesscontrol`
* `tesseval`
diff --git a/docs/user-guide/a3-02-reference-capability-atoms.md b/docs/user-guide/a3-02-reference-capability-atoms.md
index 116701119..a504f6f41 100644
--- a/docs/user-guide/a3-02-reference-capability-atoms.md
+++ b/docs/user-guide/a3-02-reference-capability-atoms.md
@@ -453,6 +453,15 @@ Extensions
`SPV_KHR_cooperative_matrix`
> Represents the SPIR-V extension for SPV_KHR_cooperative_matrix.
+`SPV_NV_cooperative_matrix2`
+> Represents the SPIR-V extension for SPV_NV_cooperative_matrix2.
+
+`SPV_NV_tensor_addressing`
+> Represents the SPIR-V extension for SPV_NV_tensor_addressing.
+
+`SPV_KHR_vulkan_memory_model`
+> Represents the SPIR-V extension for SPV_KHR_vulkan_memory_model.
+
`spvAtomicFloat32AddEXT`
> Represents the SPIR-V capability for atomic float 32 add operations.
@@ -573,6 +582,24 @@ Extensions
`spvCooperativeMatrixKHR`
> Represents the SPIR-V capability for cooperative matrices
+`spvCooperativeMatrixReductionsNV`
+> Represents the SPIR-V capability for cooperative matrix 2
+
+`spvCooperativeMatrixConversionsNV`
+> Represents the SPIR-V capability for cooperative matrix 2
+
+`spvCooperativeMatrixPerElementOperationsNV`
+> Represents the SPIR-V capability for cooperative matrix 2
+
+`spvCooperativeMatrixTensorAddressingNV`
+> Represents the SPIR-V capability for cooperative matrix 2
+
+`spvCooperativeMatrixBlockLoadsNV`
+> Represents the SPIR-V capability for cooperative matrix 2
+
+`spvTensorAddressingNV`
+> Represents the SPIR-V capability for tensor addressing
+
`spvMaximalReconvergenceKHR`
> Represents the SPIR-V capability for maximal reconvergence.
@@ -945,6 +972,24 @@ Compound Capabilities
`cooperative_vector_training`
> Capabilities needed to train cooperative vectors
+`cooperative_matrix_reduction`
+> Capabilities needed to use reduction operations with cooperative matrix
+
+`cooperative_matrix_conversion`
+> Capabilities needed to convert cooperative matrices
+
+`cooperative_matrix_map_element`
+> Capabilities needed to use MapElement operation with cooperative matrix
+
+`cooperative_matrix_tensor_addressing`
+> Capabilities needed to load or store with tensor_addressing extension
+
+`cooperative_matrix_block_load`
+> Capabilities needed to use decodeFunc with cooperative matrix load
+
+`tensor_addressing`
+> Capabilities needed to use tensor addressing
+
`any_stage`
> Collection of all shader stages
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 5911c997c..140c9ba16 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -3478,6 +3478,18 @@ enum MemoryOrder
SeqCst = $(kIRMemoryOrder_SeqCst),
}
+// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id
+enum MemoryScope
+{
+ CrossDevice = 0,
+ Device = 1,
+ Workgroup = 2,
+ Subgroup = 3,
+ Invocation = 4,
+ QueueFamily = 5,
+ ShaderCallKHR = 6,
+};
+
/// Represents types that can be used in any atomic operations.
/// Implemented by builtin scalar types: `int`, `uint`, `int64_t`, `uint64_t`, `int8_t`, `uint8_t`, `int16_t`, `uint16_t`, `float`, `double` and `half`.
[sealed] interface IAtomicable {}
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 2382d4a9a..829d5ce97 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -22509,13 +22509,52 @@ extension<T, L : IBufferDataLayout> RasterizerOrderedStructuredBuffer<T, L> : IR
int getCount() { uint count; uint stride; this.GetDimensions(count, stride); return count; }
}
+namespace linalg
+{
+
+//
+// Cooperative Matrix enums
+//
+
+enum CoopMatMatrixUse
+{
+ MatrixA = 0,
+ MatrixB = 1,
+ MatrixAccumulator = 2,
+};
+
+enum CoopMatMatrixLayout
+{
+ RowMajor = 0,
+ ColumnMajor = 1,
+};
+
+enum CoopMatClampMode
+{
+ Undefined,
+ Constant,
+ ClampToEdge,
+ Repeat,
+ RepeatMirrored
+};
+
+
//
// Cooperative Matrix type
//
__intrinsic_type($(kIROp_CoopMatrixType))
[require(cooperative_matrix)]
-struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> : IArray<T>, IArithmetic
+__generic<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>
+struct CoopMat
+ : IArray<T>
+ , IArithmetic
{
//
// Initialization
@@ -22535,21 +22574,25 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
}
[ForceInline]
- [require(cooperative_matrix)]
- __init<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other)
+ [require(cooperative_matrix_conversion)]
+ __init<
+ U : __BuiltinArithmeticType
+ >(CoopMat<U, S, M, N, R> other)
{
this.copyFrom(other);
}
[ForceInline]
+ [require(cooperative_matrix)]
__init(This x)
{
this = x;
}
// Required for `IArithmetic`.
- [OverloadRank(-10)]
[ForceInline]
+ [OverloadRank(-10)]
+ [require(cooperative_matrix)]
__init(int i)
{
this = CoopMat<T, S, M, N, R>(T(i));
@@ -22559,9 +22602,11 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
// Simple setters
//
- [require(cooperative_matrix)]
- [mutating]
+ /// Fills the cooperative matrix with the specified value.
+ /// @param t The value to fill the matrix with.
[ForceInline]
+ [mutating]
+ [require(cooperative_matrix)]
void fill(T t)
{
this = spirv_asm
@@ -22570,10 +22615,15 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
};
}
- [require(cooperative_matrix)]
- [mutating]
+ /// Copies the contents from another cooperative matrix into this matrix.
+ /// @param U The element type of the source cooperative matrix.
+ /// @param other The source cooperative matrix to copy from.
[ForceInline]
- void copyFrom<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other)
+ [mutating]
+ [require(cooperative_matrix_conversion)]
+ void copyFrom<
+ U : __BuiltinArithmeticType
+ >(CoopMat<U, S, M, N, R> other)
{
if (__isFloat<T>() && __isInt<U>())
this = __int_to_float_cast<T>(other);
@@ -22598,25 +22648,13 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
[__NoSideEffect]
Ref<T> __indexRef(int index);
+ /// Returns the count as an integer value.
[ForceInline]
+ [require(cooperative_matrix)]
[__NoSideEffect]
int getCount()
{
- return getLength();
- }
-
- [ForceInline]
- [__NoSideEffect]
- int getRowCount()
- {
- return M;
- }
-
- [ForceInline]
- [__NoSideEffect]
- int getColumnCount()
- {
- return N;
+ return GetLength();
}
__subscript(int index) -> T
@@ -22635,14 +22673,111 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
}
}
- /// Returns the number of components owned by each invocation.
+ //
+ // CoopMat operations
+ //
+
+ /// Returns the number of elements for the current thread.
+ /// Depending on the number of threads for the given matrix, each
+ /// thread will get smaller length.
+ ///
+ /// @remarks The return value is unlikely to be same to M * N.
[ForceInline]
[require(cooperative_matrix)]
- uint getLength()
+ static uint GetLength()
+ {
+ return spirv_asm
+ {
+ result:$$uint = OpCooperativeMatrixLengthKHR $$CoopMat<T, S, M, N, R>;
+ };
+ }
+
+ /// Returns the number of rows in the matrix.
+ [ForceInline]
+ [__NoSideEffect]
+ static int GetRowCount()
+ {
+ return M;
+ }
+
+ /// Returns the number of columns in the matrix.
+ [ForceInline]
+ [__NoSideEffect]
+ static int GetColumnCount()
+ {
+ return N;
+ }
+
+ [require(cooperative_matrix_conversion)]
+ CoopMat<T, S, N, M, CoopMatMatrixUse.MatrixB> Transpose()
{
return spirv_asm
{
- result:$$uint = OpCooperativeMatrixLengthKHR $$This;
+ OpCapability CooperativeMatrixConversionsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ result:$$CoopMat<T, S, N, M, CoopMatMatrixUse.MatrixB> = OpCooperativeMatrixTransposeNV $this;
+ };
+ }
+
+ [require(cooperative_matrix_reduction)]
+ CoopMat<T, S, M, RN, CoopMatMatrixUse.MatrixAccumulator> ReduceRow<
+ let RN : int
+ >(functype(T, T) -> T combineOp)
+ {
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixReductionsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ result:$$CoopMat<T, S, M, RN, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixReduceNV $this Row $combineOp;
+ };
+ }
+
+ [require(cooperative_matrix_reduction)]
+ CoopMat<T, S, RM, N, CoopMatMatrixUse.MatrixAccumulator> ReduceColumn<
+ let RM : int
+ >(functype(T, T) -> T combineOp)
+ {
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixReductionsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ result:$$CoopMat<T, S, RM, N, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixReduceNV $this Column $combineOp;
+ };
+ }
+
+ [require(cooperative_matrix_reduction)]
+ CoopMat<T, S, RM, RN, CoopMatMatrixUse.MatrixAccumulator> ReduceRowAndColumn<
+ let RM : int,
+ let RN : int
+ >(functype(T, T) -> T combineOp)
+ {
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixReductionsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ result:$$CoopMat<T, S, RM, RN, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixReduceNV $this Row|Column $combineOp;
+ };
+ }
+
+ [require(cooperative_matrix_reduction)]
+ CoopMat<T, S, M / 2, N / 2, CoopMatMatrixUse.MatrixAccumulator> Reduce2x2(functype(T, T)->T combineOp)
+ {
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixReductionsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ result:$$CoopMat<T, S, M / 2, N / 2, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixReduceNV $this 0x4 $combineOp;
+ };
+ }
+
+ [require(cooperative_matrix_map_element)]
+ This MapElement(functype(uint32_t, uint32_t, T)->T mapOp)
+ {
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixPerElementOperationsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixPerElementOpNV $this $mapOp;
};
}
@@ -22650,16 +22785,24 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
// Store
//
- [ForceInline]
[require(cooperative_matrix)]
- void store(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ void Store<
+ let matrixLayout : CoopMatMatrixLayout
+ >(RWStructuredBuffer<T> buffer, uint element, uint stride)
{
- return store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ __Store(buffer, element, stride, matrixLayout);
+ }
+
+ [require(cooperative_matrix)]
+ void Store<
+ let matrixLayout : CoopMatMatrixLayout
+ >(RWByteAddressBuffer buffer, uint element, uint stride)
+ {
+ __Store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
}
- [ForceInline]
[require(cooperative_matrix)]
- void store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ internal void __Store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
{
let zero = 0;
let alignment = 16U;
@@ -22671,10 +22814,10 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
};
}
- [__NoSideEffect]
- [ForceInline]
[require(cooperative_matrix)]
- void store(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ void Store<
+ let matrixLayout : CoopMatMatrixLayout
+ >(T* buffer, uint element, uint stride)
{
let alignment = 16U;
return spirv_asm
@@ -22684,9 +22827,12 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
};
}
- [require(cooperative_matrix)]
[ForceInline]
- void store<let U : int>(__ref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ [require(cooperative_matrix)]
+ void Store<
+ let matrixLayout : CoopMatMatrixLayout,
+ let V : int
+ >(__ref groupshared T[V] data, uint element, uint stride)
{
let alignment = 16U;
spirv_asm
@@ -22698,9 +22844,12 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
}
[ForceInline]
- [ForceInline]
[require(cooperative_matrix)]
- void storeAny<U, let V : int>(__ref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ void Store<
+ let matrixLayout : CoopMatMatrixLayout,
+ U,
+ let V : int
+ >(__ref groupshared U[V] data, uint element, uint stride)
{
let alignment = 16U;
spirv_asm
@@ -22712,9 +22861,13 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
}
[ForceInline]
- [ForceInline]
[require(cooperative_matrix)]
- void storeAny<U, let V : int, let L : int>(__ref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ void Store<
+ let matrixLayout : CoopMatMatrixLayout,
+ U,
+ let V : int,
+ let L : int
+ >(__ref groupshared vector<U, L>[V] data, uint element, uint stride)
{
let alignment = 16U;
spirv_asm
@@ -22725,30 +22878,34 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
};
}
+
//
// Load
//
[__NoSideEffect]
- [ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> load(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout
+ >(ByteAddressBuffer buffer, uint element, uint stride)
{
- return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ return Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
}
[__NoSideEffect]
- [ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> load(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout
+ >(RWByteAddressBuffer buffer, uint element, uint stride)
{
- return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ return Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
}
[__NoSideEffect]
- [ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> load(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout
+ >(StructuredBuffer<T> buffer, uint element, uint stride)
{
let zero = 0;
let alignment = 16U;
@@ -22761,9 +22918,10 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
}
[__NoSideEffect]
- [ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> load(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout
+ >(RWStructuredBuffer<T> buffer, uint element, uint stride)
{
let zero = 0;
let alignment = 16U;
@@ -22775,10 +22933,12 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
};
}
- [__NoSideEffect]
[ForceInline]
+ [__NoSideEffect]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> load(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout
+ >(T* buffer, uint element, uint stride)
{
let alignment = 16;
return spirv_asm
@@ -22790,7 +22950,10 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
[ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> load<let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout,
+ let V : int
+ >(__constref groupshared T[V] data, uint element, uint stride)
{
let alignment = 16U;
return spirv_asm
@@ -22803,7 +22966,11 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
[ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> loadAny<U, let V : int>(__constref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout,
+ U,
+ let V : int
+ >(__constref groupshared U[V] data, uint element, uint stride)
{
let alignment = 16U;
return spirv_asm
@@ -22816,7 +22983,12 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
[ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> loadAny<U, let V : int, let L : int>(__constref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout,
+ U,
+ let V : int,
+ let L : int
+ >(__constref groupshared vector<U, L>[V] data, uint element, uint stride)
{
let alignment = 16U;
return spirv_asm
@@ -22849,7 +23021,7 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
This mod(This other)
{
This ret;
- for (int i = 0; i < getLength(); ++i)
+ for (int i = 0; i < GetLength(); ++i)
{
ret[i] = this[i] % other[i];
}
@@ -22862,7 +23034,7 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
bool equals(This other)
{
- for (int i = 0; i < getLength(); i++)
+ for (int i = 0; i < GetLength(); i++)
{
if (this[i] != other[i])
{
@@ -22874,7 +23046,7 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
bool lessThan(This other)
{
- for (int i = 0; i < getLength(); i++)
+ for (int i = 0; i < GetLength(); i++)
{
if (this[i] < other[i])
{
@@ -22890,7 +23062,7 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
bool lessThanOrEquals(This other)
{
- for (int i = 0; i < getLength(); i++)
+ for (int i = 0; i < GetLength(); i++)
{
if (this[i] < other[i])
{
@@ -22903,7 +23075,9 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
}
return true;
}
-}
+
+} // struct CoopMat
+
//
// Convenience loading functions for cooperative matrices which infer the
@@ -22912,78 +23086,168 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+CoopMat<T, S, M, N, R> coopMatLoad<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse,
+ let matrixLayout : CoopMatMatrixLayout
+>(
+ ByteAddressBuffer buffer,
+ uint element,
+ uint stride)
{
- return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
}
[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+CoopMat<T, S, M, N, R> coopMatLoad<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse,
+ let matrixLayout : CoopMatMatrixLayout
+>(
+ RWByteAddressBuffer buffer,
+ uint element,
+ uint stride)
{
- return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
}
[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+CoopMat<T, S, M, N, R> coopMatLoad<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse,
+ let matrixLayout : CoopMatMatrixLayout
+>(
+ StructuredBuffer<T> buffer,
+ uint element,
+ uint stride)
{
- return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
}
[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+CoopMat<T, S, M, N, R> coopMatLoad<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse,
+ let matrixLayout : CoopMatMatrixLayout
+>(
+ RWStructuredBuffer<T> buffer,
+ uint element,
+ uint stride)
{
- return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
}
[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+CoopMat<T, S, M, N, R> coopMatLoad<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse,
+ let matrixLayout : CoopMatMatrixLayout
+>(
+ T* buffer,
+ uint element,
+ uint stride)
{
- return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
}
[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse, let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+CoopMat<T, S, M, N, R> coopMatLoad<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse,
+ let matrixLayout : CoopMatMatrixLayout,
+ let U : int
+>(
+ __constref groupshared T[U] data,
+ uint element,
+ uint stride)
{
- return CoopMat<T, S, M, N, R>.load(data, element, stride, matrixLayout);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(data, element, stride);
}
+
//
// Cooperative Matrix casting
//
-__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+[require(cooperative_matrix_conversion)]
__intrinsic_op($(kIROp_IntCast))
-[require(cooperative_matrix)]
-CoopMat<T,S,M,N,R> __int_cast(CoopMat<U,S,M,N,R> val);
-
-__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+CoopMat<T,S,M,N,R> __int_cast<
+ T : __BuiltinArithmeticType,
+ U : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>(CoopMat<U,S,M,N,R> val);
+
+[require(cooperative_matrix_conversion)]
__intrinsic_op($(kIROp_FloatCast))
-[require(cooperative_matrix)]
-CoopMat<T,S,M,N,R> __real_cast(CoopMat<U,S,M,N,R> val);
-
-__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+CoopMat<T,S,M,N,R> __real_cast<
+ T : __BuiltinArithmeticType,
+ U : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>(CoopMat<U,S,M,N,R> val);
+
+[require(cooperative_matrix_conversion)]
__intrinsic_op($(kIROp_CastIntToFloat))
-[require(cooperative_matrix)]
-CoopMat<T,S,M,N,R> __int_to_float_cast(CoopMat<U,S,M,N,R> val);
-
-__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+CoopMat<T,S,M,N,R> __int_to_float_cast<
+ T : __BuiltinArithmeticType,
+ U : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>(CoopMat<U,S,M,N,R> val);
+
+[require(cooperative_matrix_conversion)]
__intrinsic_op($(kIROp_CastFloatToInt))
-[require(cooperative_matrix)]
-CoopMat<T,S,M,N,R> __float_to_int_cast(CoopMat<U,S,M,N,R> val);
+CoopMat<T,S,M,N,R> __float_to_int_cast<
+ T : __BuiltinArithmeticType,
+ U : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>(CoopMat<U,S,M,N,R> val);
//
// Cooperative Matrix multiplication with scalar
//
-__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
-[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> operator *(CoopMat<T, S, M, N, R> lhs, const T rhs)
+CoopMat<T, S, M, N, R> operator *<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>(CoopMat<T, S, M, N, R> lhs, const T rhs)
{
return spirv_asm
{
@@ -22991,64 +23255,69 @@ CoopMat<T, S, M, N, R> operator *(CoopMat<T, S, M, N, R> lhs, const T rhs)
};
}
-__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
-[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> operator *(const T lhs, CoopMat<T, S, M, N, R> rhs)
+CoopMat<T, S, M, N, R> operator *<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>(const T lhs, CoopMat<T, S, M, N, R> rhs)
{
return rhs * lhs;
}
//
-// Cooperative Matrix enums
-//
-
-enum CoopMatScope
-{
- Device = 1,
- Workgroup = 2,
- Subgroup = 3,
- QueueFamily = 5,
-};
-
-enum CoopMatMatrixUse
-{
- MatrixA = 0,
- MatrixB = 1,
- MatrixAccumulator = 2,
-};
-
-enum CoopMatMatrixLayout
-{
- RowMajor = 0,
- ColumnMajor = 1,
-};
-
-enum CoopMatMatrixOperands
-{
- None = 0x0,
- MatrixASigned = 0x1,
- MatrixBSigned = 0x2,
- MatrixCSigned = 0x4,
- MatrixResultSigned = 0x8,
- SaturatingAccumulation = 0x10,
-};
-
-//
-// Cooperative Matrix multiply accumulate
+// Cooperative Matrix Multiply-Accumulate
//
[require(cooperative_matrix)]
-__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, V : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let K : int, let N : int, let RA : CoopMatMatrixUse, let RB : CoopMatMatrixUse, let RC : CoopMatMatrixUse>
-CoopMat<V, S, M, N, RC> coopMatMulAdd(CoopMat<T, S, M, K, RA> matA, CoopMat<U, S, K, N, RB> matB, CoopMat<V, S, M, N, RC> matC, constexpr CoopMatMatrixOperands operands)
+CoopMat<T, S, M, N, CoopMatMatrixUse.MatrixAccumulator> coopMatMulAdd<
+ T : __BuiltinArithmeticType,
+ let saturatingAccumulation : bool,
+ U : __BuiltinArithmeticType,
+ V : __BuiltinArithmeticType,
+ W : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let K : int,
+ let N : int
+>(
+ CoopMat<U, S, M, K, CoopMatMatrixUse.MatrixA> matA,
+ CoopMat<V, S, K, N, CoopMatMatrixUse.MatrixB> matB,
+ CoopMat<W, S, M, N, CoopMatMatrixUse.MatrixAccumulator> matC)
{
- static_assert((RA == CoopMatMatrixUse::MatrixA) && (RB == CoopMatMatrixUse::MatrixB) && (RC == CoopMatMatrixUse::MatrixAccumulator), "matrix uses for `coopMatMulAdd` matrix parameters must be `MatrixA`, `MatrixB` and `MatrixAccumulator`");
+ // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands
+ int operands = 0; // NoneKHR
+ if (__isSignedInt<U>())
+ {
+ operands |= 0x01; // MatrixASignedComponentsKHR
+ }
+ if (__isSignedInt<V>())
+ {
+ operands |= 0x02; // MatrixBSignedComponentsKHR
+ }
+ if (__isSignedInt<W>())
+ {
+ operands |= 0x04; // MatrixCSignedComponentsKHR
+ }
+ if (__isSignedInt<T>())
+ {
+ operands |= 0x08; // MatrixResultSignedComponentsKHR
+ }
+ if (saturatingAccumulation)
+ {
+ operands |= 0x10; // SaturatingAccumulationKHR
+ }
+
return spirv_asm
{
- result:$$CoopMat<V, S, M, N, RC> = OpCooperativeMatrixMulAddKHR $matA $matB $matC !operands;
+ result:$$CoopMat<T, S, M, N, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixMulAddKHR $matA $matB $matC !operands;
};
}
+} // namespace linalg
+
//
// Cooperative Vector
//
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index f509835aa..b909fa0f9 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -575,6 +575,18 @@ def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites;
/// [EXT]
def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer;
+/// Represents the SPIR-V extension for SPV_NV_cooperative_matrix2.
+/// [EXT]
+def SPV_NV_cooperative_matrix2 : _spirv_1_6 + SPV_KHR_cooperative_matrix;
+
+/// Represents the SPIR-V extension for SPV_NV_tensor_addressing.
+/// [EXT]
+def SPV_NV_tensor_addressing : _spirv_1_6;
+
+/// Represents the SPIR-V extension for SPV_KHR_vulkan_memory_model.
+/// [EXT]
+def SPV_KHR_vulkan_memory_model : _spirv_1_3;
+
// SPIRV Capabilities.
/// Represents the SPIR-V capability for atomic float 32 add operations.
@@ -737,6 +749,30 @@ def spvCooperativeVectorTrainingNV : SPV_NV_cooperative_vector;
/// [EXT]
def spvCooperativeMatrixKHR : SPV_KHR_cooperative_matrix;
+/// Represents the SPIR-V capability for cooperative matrix 2
+/// [EXT]
+def spvCooperativeMatrixReductionsNV : SPV_NV_cooperative_matrix2;
+
+/// Represents the SPIR-V capability for cooperative matrix 2
+/// [EXT]
+def spvCooperativeMatrixConversionsNV : SPV_NV_cooperative_matrix2;
+
+/// Represents the SPIR-V capability for cooperative matrix 2
+/// [EXT]
+def spvCooperativeMatrixPerElementOperationsNV : SPV_NV_cooperative_matrix2;
+
+/// Represents the SPIR-V capability for cooperative matrix 2
+/// [EXT]
+def spvCooperativeMatrixTensorAddressingNV : SPV_NV_cooperative_matrix2;
+
+/// Represents the SPIR-V capability for cooperative matrix 2
+/// [EXT]
+def spvCooperativeMatrixBlockLoadsNV : SPV_NV_cooperative_matrix2;
+
+/// Represents the SPIR-V capability for tensor addressing
+/// [EXT]
+def spvTensorAddressingNV : SPV_NV_tensor_addressing;
+
/// Represents the SPIR-V capability for maximal reconvergence.
/// [EXT]
def spvMaximalReconvergenceKHR : SPV_KHR_maximal_reconvergence;
@@ -1129,6 +1165,24 @@ alias cooperative_vector_training = spvCooperativeVectorTrainingNV;
/// Capabilities needed to use cooperative matrices
alias cooperative_matrix = spvCooperativeMatrixKHR;
+/// Capabilities needed to use reduction operations with cooperative matrix
+/// [Compound]
+alias cooperative_matrix_reduction = spvCooperativeMatrixReductionsNV;
+/// Capabilities needed to convert cooperative matrices
+/// [Compound]
+alias cooperative_matrix_conversion = spvCooperativeMatrixConversionsNV;
+/// Capabilities needed to use MapElement operation with cooperative matrix
+/// [Compound]
+alias cooperative_matrix_map_element = spvCooperativeMatrixPerElementOperationsNV;
+/// Capabilities needed to load or store with tensor_addressing extension
+/// [Compound]
+alias cooperative_matrix_tensor_addressing = spvCooperativeMatrixTensorAddressingNV;
+/// Capabilities needed to use decodeFunc with cooperative matrix load
+/// [Compound]
+alias cooperative_matrix_block_load = spvCooperativeMatrixBlockLoadsNV;
+/// Capabilities needed to use tensor addressing
+/// [Compound]
+alias tensor_addressing = spvTensorAddressingNV;
// Non-internal shader stages
//
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index 017e58667..b716e470d 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -151,6 +151,7 @@ SpvInst* emitOpTypeCoopVec(IRInst* inst, const T1& componentType, const T2& comp
componentCount);
}
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_cooperative_matrix.html#OpTypeCooperativeMatrixNV
template<typename T1, typename T2>
SpvInst* emitOpTypeCoopMat(
IRInst* inst,
diff --git a/tests/cooperative-matrix/add.slang b/tests/cooperative-matrix/add.slang
index ecf2f16ba..c904da0f7 100644
--- a/tests/cooperative-matrix/add.slang
+++ b/tests/cooperative-matrix/add.slang
@@ -15,7 +15,9 @@ ByteAddressBuffer input1;
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4, count=256),name=input2
ByteAddressBuffer input2;
-typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator>;
[numthreads(32, 1, 1)]
void computeMain()
@@ -23,10 +25,10 @@ void computeMain()
let stride = 16;
let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout);
- let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout);
+ let mat1 = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input1, 0, stride);
+ let mat2 = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input2, 0, stride);
let result = mat1 + mat2;
- result.store(outputBuffer, 0, stride, matrixLayout);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, stride);
}
diff --git a/tests/cooperative-matrix/array.slang b/tests/cooperative-matrix/array.slang
index ab1f92a99..eb26bacbf 100644
--- a/tests/cooperative-matrix/array.slang
+++ b/tests/cooperative-matrix/array.slang
@@ -19,7 +19,9 @@ ByteAddressBuffer input2;
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
-typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
[numthreads(32, 1, 1)]
void computeMain()
@@ -28,9 +30,9 @@ void computeMain()
let matrixLayout = CoopMatMatrixLayout::RowMajor;
CoopMatType coopMatArray[2];
- coopMatArray[0] = CoopMatType.load(input1, 0, stride, matrixLayout);
- coopMatArray[1] = CoopMatType.load(input2, 0, stride, matrixLayout);
+ coopMatArray[0] = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input1, 0, stride);
+ coopMatArray[1] = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input2, 0, stride);
- coopMatArray[0].store(outputBuffer, 0, stride, matrixLayout);
- coopMatArray[1].store(outputBuffer, 4, stride, matrixLayout);
+ coopMatArray[0].Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, stride);
+ coopMatArray[1].Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 4, stride);
}
diff --git a/tests/cooperative-matrix/comparison.slang b/tests/cooperative-matrix/comparison.slang
index ce99f1550..8b29876df 100644
--- a/tests/cooperative-matrix/comparison.slang
+++ b/tests/cooperative-matrix/comparison.slang
@@ -14,7 +14,9 @@ ByteAddressBuffer input2;
//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<uint> outputBuffer;
-typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
[numthreads(32, 1, 1)]
void computeMain(uint3 threadIndex : SV_DispatchThreadID)
@@ -22,8 +24,8 @@ void computeMain(uint3 threadIndex : SV_DispatchThreadID)
let stride = 4;
let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout);
- let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout);
+ let mat1 = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input1, 0, stride);
+ let mat2 = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input2, 0, stride);
uint32_t equals = mat1 == mat2 ? 1 : 0;
uint32_t lessThan = mat1 < mat2 ? 1 : 0;
diff --git a/tests/cooperative-matrix/conversion.slang b/tests/cooperative-matrix/conversion.slang
index fbc422b7e..24fba3cc6 100644
--- a/tests/cooperative-matrix/conversion.slang
+++ b/tests/cooperative-matrix/conversion.slang
@@ -12,6 +12,7 @@ RWStructuredBuffer<float> outputBuffer;
//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4, count=256),name=input
ByteAddressBuffer input;
+using namespace linalg;
[numthreads(32, 1, 1)]
void computeMain()
@@ -19,12 +20,12 @@ void computeMain()
let stride = 16;
let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let intMat = CoopMat<int, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>.load(input, 0, stride, matrixLayout);
- let floatMat = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(intMat);
- let uintMat = CoopMat<uint, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(intMat);
- let halfMat = CoopMat<half, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(uintMat);
- let floatMat2 = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(halfMat);
+ let intMat = CoopMat<int, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>.Load<CoopMatMatrixLayout::RowMajor>(input, 0, stride);
+ let floatMat = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(intMat);
+ let uintMat = CoopMat<uint, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(intMat);
+ let halfMat = CoopMat<half, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(uintMat);
+ let floatMat2 = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(halfMat);
let result = floatMat + floatMat2;
- result.store(outputBuffer, 0, stride, matrixLayout);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, stride);
}
diff --git a/tests/cooperative-matrix/copyFrom.slang b/tests/cooperative-matrix/copyFrom.slang
index 4be537489..9ab8acd4a 100644
--- a/tests/cooperative-matrix/copyFrom.slang
+++ b/tests/cooperative-matrix/copyFrom.slang
@@ -6,12 +6,14 @@
//TEST_INPUT:ubuffer(stride=4, count = 256):out,name=outputBuffer
RWStructuredBuffer<int32_t> outputBuffer;
+using namespace linalg;
+
[numthreads(32, 1, 1)]
void computeMain()
{
- let mat = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(4.0);
- var result = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(0);
+ let mat = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(4.0);
+ var result = CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(0);
result.copyFrom(mat);
- result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, 16);
}
diff --git a/tests/cooperative-matrix/diagnostics/mat-mul-add-different-scope.slang b/tests/cooperative-matrix/diagnostics/mat-mul-add-different-scope.slang
deleted file mode 100644
index 0c4308308..000000000
--- a/tests/cooperative-matrix/diagnostics/mat-mul-add-different-scope.slang
+++ /dev/null
@@ -1,20 +0,0 @@
-//DIAGNOSTIC_TEST(compute):SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target spirv
-
-RWStructuredBuffer<float> outputBuffer;
-
-typealias CoopMatAType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>;
-typealias CoopMatBType = CoopMat<float16_t, CoopMatScope::Workgroup, 16, 16, CoopMatMatrixUse::MatrixB>;
-typealias CoopMatCType = CoopMat<float32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
-
-// CHECK: error 39999: could not specialize generic for arguments of type
-
-[numthreads(32, 1, 1)]
-void computeMain()
-{
- let matA = CoopMatAType(3.0);
- let matB = CoopMatBType(5.0);
- let matC = CoopMatCType(1.0);
-
- const let result = coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::None);
- result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor);
-}
diff --git a/tests/cooperative-matrix/diagnostics/mat-mul-add-incorrect-matrix-use.slang b/tests/cooperative-matrix/diagnostics/mat-mul-add-incorrect-matrix-use.slang
deleted file mode 100644
index 5b7dc7a5b..000000000
--- a/tests/cooperative-matrix/diagnostics/mat-mul-add-incorrect-matrix-use.slang
+++ /dev/null
@@ -1,20 +0,0 @@
-//DIAGNOSTIC_TEST(compute):SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target spirv
-
-RWStructuredBuffer<float> outputBuffer;
-
-typealias CoopMatAType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>;
-typealias CoopMatBType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>;
-typealias CoopMatCType = CoopMat<float32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
-
-// CHECK: error 41400: static assertion failed, matrix uses for `coopMatMulAdd` matrix parameters must be `MatrixA`, `MatrixB` and `MatrixAccumulator`
-
-[numthreads(32, 1, 1)]
-void computeMain()
-{
- let matA = CoopMatAType(3.0);
- let matB = CoopMatBType(5.0);
- let matC = CoopMatCType(1.0);
-
- let result = coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::None);
- result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor);
-}
diff --git a/tests/cooperative-matrix/div.slang b/tests/cooperative-matrix/div.slang
index 4f697fb4e..17b237280 100644
--- a/tests/cooperative-matrix/div.slang
+++ b/tests/cooperative-matrix/div.slang
@@ -15,7 +15,9 @@ ByteAddressBuffer input1;
//TEST_INPUT:ubuffer(data=[2 3 4 5], stride=4, count=256),name=input2
ByteAddressBuffer input2;
-typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
[numthreads(32, 1, 1)]
void computeMain()
@@ -23,9 +25,9 @@ void computeMain()
let stride = 16;
let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout);
- let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout);
+ let mat1 = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input1, 0, stride);
+ let mat2 = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input2, 0, stride);
let result = mat1 / mat2;
- result.store(outputBuffer, 0, stride, matrixLayout);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, stride);
}
diff --git a/tests/cooperative-matrix/fill.slang b/tests/cooperative-matrix/fill.slang
index e5c7c4765..f9068577b 100644
--- a/tests/cooperative-matrix/fill.slang
+++ b/tests/cooperative-matrix/fill.slang
@@ -6,11 +6,13 @@
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<int32_t> outputBuffer;
+using namespace linalg;
+
[numthreads(32, 1, 1)]
void computeMain()
{
- var result : CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+ var result : CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
result.fill(10);
- result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, 16);
}
diff --git a/tests/cooperative-matrix/inout.slang b/tests/cooperative-matrix/inout.slang
index 8e04a9f7a..3ff9a16d6 100644
--- a/tests/cooperative-matrix/inout.slang
+++ b/tests/cooperative-matrix/inout.slang
@@ -12,7 +12,9 @@ ByteAddressBuffer input;
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
-typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
void doubleCoopMat(inout CoopMatType mat)
{
@@ -25,7 +27,7 @@ void computeMain()
let stride = 16;
let matrixLayout = CoopMatMatrixLayout::RowMajor;
- var mat = CoopMatType.load(input, 0, stride, matrixLayout);
+ var mat = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input, 0, stride);
doubleCoopMat(mat);
- mat.store(outputBuffer, 0, stride, matrixLayout);
+ mat.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, stride);
}
diff --git a/tests/cooperative-matrix/length.slang b/tests/cooperative-matrix/length.slang
new file mode 100644
index 000000000..580b713f3
--- /dev/null
+++ b/tests/cooperative-matrix/length.slang
@@ -0,0 +1,21 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -output-using-type -emit-spirv-directly -skip-spirv-validation
+
+// Note the length is NOT row * column.
+// When the memory scope is set to subgroup, each thread gets 16 * 16 / 32 = 8 where 32 is the value used in `numthreads`.
+
+//CHK:8
+
+//TEST_INPUT:ubuffer(stride=4, count=1):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+using namespace linalg;
+
+// It appears that only Subgroup is supposed at the moment
+typealias CoopMatSubgroup = CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ outputBuffer[0] = CoopMatSubgroup.GetLength();
+}
+
diff --git a/tests/cooperative-matrix/load-store-arbitrary-array-vec.slang b/tests/cooperative-matrix/load-store-arbitrary-array-vec.slang
index 51af16ada..d34494fd8 100644
--- a/tests/cooperative-matrix/load-store-arbitrary-array-vec.slang
+++ b/tests/cooperative-matrix/load-store-arbitrary-array-vec.slang
@@ -23,7 +23,9 @@ RWByteAddressBuffer input;
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<uint32_t> outputBuffer;
-typealias CoopMatType = CoopMat<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<uint32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator>;
groupshared float3[128] tempShared;
@@ -31,11 +33,9 @@ groupshared float3[128] tempShared;
void computeMain()
{
let stride = 16;
- let matrixLayout = CoopMatMatrixLayout::RowMajor;
-
- let mat = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(input, 0, stride, matrixLayout);
- mat.storeAny(tempShared, 0, stride, matrixLayout);
+ let mat = coopMatLoad<uint32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator, CoopMatMatrixLayout.RowMajor>(input, 0, stride);
+ mat.Store<CoopMatMatrixLayout.RowMajor>(tempShared, 0, stride);
- let result = CoopMatType.loadAny(tempShared, 0, stride, matrixLayout);
- result.store(outputBuffer, 0, stride, matrixLayout);
+ let result = CoopMatType.Load<CoopMatMatrixLayout.RowMajor>(tempShared, 0, stride);
+ result.Store<CoopMatMatrixLayout.RowMajor>(outputBuffer, 0, stride);
}
diff --git a/tests/cooperative-matrix/load-store-arbitrary-array.slang b/tests/cooperative-matrix/load-store-arbitrary-array.slang
index 160aad9da..8de4e84cc 100644
--- a/tests/cooperative-matrix/load-store-arbitrary-array.slang
+++ b/tests/cooperative-matrix/load-store-arbitrary-array.slang
@@ -23,7 +23,9 @@ RWByteAddressBuffer input;
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<uint32_t> outputBuffer;
-typealias CoopMatType = CoopMat<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<uint32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator>;
groupshared float[256] tempShared;
@@ -31,11 +33,10 @@ groupshared float[256] tempShared;
void computeMain()
{
let stride = 16;
- let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let mat = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(input, 0, stride, matrixLayout);
- mat.storeAny(tempShared, 0, stride, matrixLayout);
+ let mat = coopMatLoad<uint32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator, CoopMatMatrixLayout.RowMajor>(input, 0, stride);
+ mat.Store<CoopMatMatrixLayout.RowMajor>(tempShared, 0, stride);
- let result = CoopMatType.loadAny(tempShared, 0, stride, matrixLayout);
- result.store(outputBuffer, 0, stride, matrixLayout);
+ let result = CoopMatType.Load<CoopMatMatrixLayout.RowMajor>(tempShared, 0, stride);
+ result.Store<CoopMatMatrixLayout.RowMajor>(outputBuffer, 0, stride);
}
diff --git a/tests/cooperative-matrix/load-store-groupshared.slang b/tests/cooperative-matrix/load-store-groupshared.slang
index 8d867abb7..db32e85fd 100644
--- a/tests/cooperative-matrix/load-store-groupshared.slang
+++ b/tests/cooperative-matrix/load-store-groupshared.slang
@@ -10,22 +10,23 @@
// CHECK-NEXT: 8
//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8], stride=4, count=256):name=input
-RWByteAddressBuffer input;
+ByteAddressBuffer input;
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<uint32_t> outputBuffer;
+using namespace linalg;
+
groupshared uint32_t[256] tempShared;
[numthreads(32, 1, 1)]
void computeMain()
{
let stride = 16;
- let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let mat = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(input, 0, stride, matrixLayout);
- mat.store(tempShared, 0, stride, matrixLayout);
+ let mat = coopMatLoad<uint32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator, CoopMatMatrixLayout.RowMajor>(input, 0, stride);
+ mat.Store<CoopMatMatrixLayout.RowMajor>(tempShared, 0, stride);
- let result = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(tempShared, 0, stride, matrixLayout);
- result.store(outputBuffer, 0, stride, matrixLayout);
+ let result = coopMatLoad<uint32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator, CoopMatMatrixLayout.RowMajor>(tempShared, 0, stride);
+ result.Store<CoopMatMatrixLayout.RowMajor>(outputBuffer, 0, stride);
}
diff --git a/tests/cooperative-matrix/load-store-rwbyteaddressbuffer.slang b/tests/cooperative-matrix/load-store-rwbyteaddressbuffer.slang
index 6bba8331e..6894bdfe5 100644
--- a/tests/cooperative-matrix/load-store-rwbyteaddressbuffer.slang
+++ b/tests/cooperative-matrix/load-store-rwbyteaddressbuffer.slang
@@ -1,4 +1,5 @@
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -skip-spirv-validation
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -skip-spirv-validation -Xslang -DRWBAB
// CHECK: 1
// CHECK-NEXT: 2
@@ -10,17 +11,21 @@
// CHECK-NEXT: 8
//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8], stride=4, count=256):name=inputBuffer
+#if defined(RWBAB)
RWByteAddressBuffer inputBuffer;
+#else
+ByteAddressBuffer inputBuffer;
+#endif
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWByteAddressBuffer outputBuffer;
+using namespace linalg;
+
[numthreads(32, 1, 1)]
void computeMain()
{
let stride = 16;
- let matrixLayout = CoopMatMatrixLayout::RowMajor;
-
- let mat = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(inputBuffer, 0, stride, matrixLayout);
- mat.store(outputBuffer, 0, 16, matrixLayout);
+ let mat = coopMatLoad<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator, CoopMatMatrixLayout.RowMajor>(inputBuffer, 0, stride);
+ mat.Store<CoopMatMatrixLayout.RowMajor>(outputBuffer, 0, 16);
}
diff --git a/tests/cooperative-matrix/load-store-rwstructuredbuffer.slang b/tests/cooperative-matrix/load-store-rwstructuredbuffer.slang
index e161fb7b2..6a94fd30e 100644
--- a/tests/cooperative-matrix/load-store-rwstructuredbuffer.slang
+++ b/tests/cooperative-matrix/load-store-rwstructuredbuffer.slang
@@ -1,4 +1,5 @@
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -skip-spirv-validation
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -skip-spirv-validation -Xslang -DRWSB
// CHECK: type: int32_t
// CHECK-NEXT: 1
@@ -11,17 +12,21 @@
// CHECK-NEXT: 8
//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8], stride=4, count=256),name=buf
+#if defined(RWSB)
RWStructuredBuffer<int32_t> inputBuffer;
+#else
+StructuredBuffer<int32_t> inputBuffer;
+#endif
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<int32_t> outputBuffer;
+using namespace linalg;
+
[numthreads(32, 1, 1)]
void computeMain()
{
let stride = 16;
- let matrixLayout = CoopMatMatrixLayout::RowMajor;
-
- let mat = coopMatLoad<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(inputBuffer, 0, stride, matrixLayout);
- mat.store(outputBuffer, 0, stride, matrixLayout);
+ let mat = coopMatLoad<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator, CoopMatMatrixLayout.RowMajor>(inputBuffer, 0, stride);
+ mat.Store<CoopMatMatrixLayout.RowMajor>(outputBuffer, 0, stride);
}
diff --git a/tests/cooperative-matrix/map-element-single.slang b/tests/cooperative-matrix/map-element-single.slang
new file mode 100644
index 000000000..583630d14
--- /dev/null
+++ b/tests/cooperative-matrix/map-element-single.slang
@@ -0,0 +1,50 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-per-element-operations -skip-spirv-validation -Xslang -DTEST_MODE=0
+//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-per-element-operations -skip-spirv-validation -Xslang -DTEST_MODE=1
+//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-per-element-operations -skip-spirv-validation -Xslang -DTEST_MODE=2
+
+//CHECK: type: int32_t
+//CHECK-NEXT: 2
+//CHECK-NEXT: 4
+//CHECK-NEXT: 6
+//CHECK-NEXT: 8
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4),name=input1
+StructuredBuffer<int> input1;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<int, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator>;
+
+int MapOp(uint32_t row, uint32_t col, int value)
+{
+ return value * 2;
+}
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ CoopMatType mat1 = CoopMatType.Load<CoopMatMatrixLayout.RowMajor>(input1, 0, stride);
+
+ CoopMatType result;
+
+#if TEST_MODE == 0
+ result = mat1.MapElement(MapOp);
+
+#elif TEST_MODE == 1
+ // Lambda through IFunc.
+ // TODO: Not working due to issue #7024
+ IFunc<int, uint32_t, uint32_t, int> func = ((uint32_t row, uint32_t column, int value) => value * 2);
+ result = mat1.MapElement(func);
+
+#elif TEST_MODE == 2
+ // Directly use lambda.
+ // TODO: Not working due to issue #7024
+ result = mat1.MapElement((uint32_t row, uint32_t column, int value) => (int)(value));
+#endif
+
+ result.Store<CoopMatMatrixLayout.RowMajor>(outputBuffer, 0, stride);
+}
diff --git a/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang b/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang
index 7ad6c639f..a93efb4b9 100644
--- a/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang
+++ b/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang
@@ -1,45 +1,68 @@
//TEST(compute):SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target spirv -skip-spirv-validation
// This test checks that the correct SPIRV Cooperative Matrix Operands are emitted for OpCooperativeMatrixMulAddKHR operaions
-RWStructuredBuffer<int> outputBuffer1;
-RWStructuredBuffer<int> outputBuffer2;
-RWStructuredBuffer<int> outputBuffer3;
+RWStructuredBuffer<uint32_t> outputBuffer1;
+RWStructuredBuffer<uint32_t> outputBuffer2;
+RWStructuredBuffer<uint32_t> outputBuffer3;
RWStructuredBuffer<int> outputBuffer4;
-RWStructuredBuffer<int> outputBuffer5;
+RWStructuredBuffer<uint32_t> outputBuffer5;
+RWStructuredBuffer<int> outputBuffer6;
-typealias CoopMatAType = CoopMat<int16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>;
-typealias CoopMatBType = CoopMat<int16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixB>;
-typealias CoopMatCType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+__generic<T : __BuiltinArithmeticType>
+typealias CoopMatAType = CoopMat<T, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>;
+
+__generic<T : __BuiltinArithmeticType>
+typealias CoopMatBType = CoopMat<T, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixB>;
+
+__generic<T : __BuiltinArithmeticType>
+typealias CoopMatCType = CoopMat<T, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
[numthreads(32, 1, 1)]
void computeMain()
{
- let matA = CoopMatAType(3);
- let matB = CoopMatBType(5);
- let matC = CoopMatCType(1);
-
- let matrixLayout = CoopMatMatrixLayout::RowMajor;
-
// CHECK: OpCooperativeMatrixMulAddKHR {{.*}} NoneKHR
- coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::None).store(outputBuffer1, 0, 16, matrixLayout);
+ coopMatMulAdd<uint32_t, false>(
+ CoopMatAType<uint16_t>(2),
+ CoopMatBType<uint32_t>(3),
+ CoopMatCType<uint16_t>(4)
+ ).Store<CoopMatMatrixLayout::RowMajor>(outputBuffer1, 0, 16);
// CHECK: OpCooperativeMatrixMulAddKHR {{.*}} MatrixASignedComponentsKHR|MatrixBSignedComponentsKHR
- coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::MatrixASigned | CoopMatMatrixOperands::MatrixBSigned).store(outputBuffer2, 0, 16, matrixLayout);
-
+ coopMatMulAdd<uint32_t, false>(
+ CoopMatAType<int16_t>(2),
+ CoopMatBType<int32_t>(3),
+ CoopMatCType<uint16_t>(4)
+ ).Store<CoopMatMatrixLayout::RowMajor>(outputBuffer2, 0, 16);
// CHECK: OpCooperativeMatrixMulAddKHR {{.*}} MatrixCSignedComponentsKHR
- coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::MatrixCSigned).store(outputBuffer2, 0, 16, matrixLayout);
-
+ coopMatMulAdd<uint32_t, false>(
+ CoopMatAType<uint16_t>(2),
+ CoopMatBType<uint32_t>(3),
+ CoopMatCType<int16_t>(4)
+ ).Store<CoopMatMatrixLayout::RowMajor>(outputBuffer3, 0, 16);
// CHECK: OpCooperativeMatrixMulAddKHR {{.*}} MatrixResultSignedComponentsKHR
- coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::MatrixResultSigned).store(outputBuffer3, 0, 16, matrixLayout);
+ coopMatMulAdd<int, false>(
+ CoopMatAType<uint16_t>(2),
+ CoopMatBType<uint32_t>(3),
+ CoopMatCType<uint16_t>(4)
+ ).Store<CoopMatMatrixLayout::RowMajor>(outputBuffer4, 0, 16);
// CHECK: OpCooperativeMatrixMulAddKHR {{.*}} SaturatingAccumulationKHR
- coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::SaturatingAccumulation).store(outputBuffer4, 0, 16, matrixLayout);
+ coopMatMulAdd<uint32_t, true>(
+ CoopMatAType<uint16_t>(2),
+ CoopMatBType<uint32_t>(3),
+ CoopMatCType<uint16_t>(4)
+ ).Store<CoopMatMatrixLayout::RowMajor>(outputBuffer5, 0, 16);
- let allOperands = CoopMatMatrixOperands::MatrixASigned | CoopMatMatrixOperands::MatrixBSigned | CoopMatMatrixOperands::MatrixCSigned | CoopMatMatrixOperands::MatrixResultSigned | CoopMatMatrixOperands::SaturatingAccumulation;
// CHECK: OpCooperativeMatrixMulAddKHR {{.*}} MatrixASignedComponentsKHR|MatrixBSignedComponentsKHR|MatrixCSignedComponentsKHR|MatrixResultSignedComponentsKHR|SaturatingAccumulationKHR
- coopMatMulAdd(matA, matB, matC, allOperands).store(outputBuffer5, 0, 16, matrixLayout);
+ coopMatMulAdd<int, true>(
+ CoopMatAType<int16_t>(2),
+ CoopMatBType<int32_t>(3),
+ CoopMatCType<int16_t>(4)
+ ).Store<CoopMatMatrixLayout::RowMajor>(outputBuffer6, 0, 16);
}
diff --git a/tests/cooperative-matrix/mat-mul-add.slang b/tests/cooperative-matrix/mat-mul-add.slang
index 417f06d10..3dc472de7 100644
--- a/tests/cooperative-matrix/mat-mul-add.slang
+++ b/tests/cooperative-matrix/mat-mul-add.slang
@@ -6,9 +6,11 @@
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
-typealias CoopMatAType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>;
-typealias CoopMatBType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixB>;
-typealias CoopMatCType = CoopMat<float32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatAType = CoopMat<float16_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>;
+typealias CoopMatBType = CoopMat<float16_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixB>;
+typealias CoopMatCType = CoopMat<float32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
[numthreads(32, 1, 1)]
void computeMain()
@@ -18,6 +20,6 @@ void computeMain()
let matB = CoopMatBType(5.0);
let matC = CoopMatCType(1.0);
- let result = coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::None);
- result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor);
+ let result = coopMatMulAdd<float, false>(matA, matB, matC);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, 16);
}
diff --git a/tests/cooperative-matrix/mod.slang b/tests/cooperative-matrix/mod.slang
index 5167c8ef8..015dd227a 100644
--- a/tests/cooperative-matrix/mod.slang
+++ b/tests/cooperative-matrix/mod.slang
@@ -16,7 +16,7 @@
// CHECK-NEXT: 2
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
-RWByteAddressBuffer outputBuffer;
+RWStructuredBuffer<int> outputBuffer;
//TEST_INPUT:ubuffer(data=[4 3 5 7], stride=4, count=256),name=input1
ByteAddressBuffer input1;
@@ -24,18 +24,19 @@ ByteAddressBuffer input1;
//TEST_INPUT:ubuffer(data=[2 3 4 5], stride=4, count=256),name=input2
ByteAddressBuffer input2;
-typealias CoopMatIntType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
-typealias CoopMatUintType = CoopMat<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
-typealias CoopMatFloatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatIntType = CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+typealias CoopMatUintType = CoopMat<uint32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+typealias CoopMatFloatType = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
[numthreads(32, 1, 1)]
void computeMain()
{
let stride = 16;
- let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let mat1 = CoopMatIntType.load(input1, 0, stride, matrixLayout);
- let mat2 = CoopMatIntType.load(input2, 0, stride, matrixLayout);
+ let mat1 = CoopMatIntType.Load<CoopMatMatrixLayout::RowMajor>(input1, 0, stride);
+ let mat2 = CoopMatIntType.Load<CoopMatMatrixLayout::RowMajor>(input2, 0, stride);
let mat3 = CoopMatFloatType(mat1);
let mat4 = CoopMatFloatType(mat2);
@@ -47,7 +48,7 @@ void computeMain()
let result2 = CoopMatIntType(mat3 % mat4);
let result3 = CoopMatIntType(mat5 % mat6);
- result.store(outputBuffer, 0, stride, matrixLayout);
- result2.store(outputBuffer, 16, stride, matrixLayout);
- result3.store(outputBuffer, 32, stride, matrixLayout);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, stride);
+ result2.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 16, stride);
+ result3.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 32, stride);
}
diff --git a/tests/cooperative-matrix/mul.slang b/tests/cooperative-matrix/mul.slang
index 0ac332698..4f5751c6e 100644
--- a/tests/cooperative-matrix/mul.slang
+++ b/tests/cooperative-matrix/mul.slang
@@ -15,7 +15,9 @@ ByteAddressBuffer input1;
//TEST_INPUT:ubuffer(data=[2 3 4 5], stride=4, count=256),name=input2
ByteAddressBuffer input2;
-typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
[numthreads(32, 1, 1)]
void computeMain()
@@ -23,9 +25,9 @@ void computeMain()
let stride = 16;
let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout);
- let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout);
+ let mat1 = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input1, 0, stride);
+ let mat2 = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input2, 0, stride);
let result = mat1 * mat2;
- result.store(outputBuffer, 0, 4, matrixLayout);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, 4);
}
diff --git a/tests/cooperative-matrix/out.slang b/tests/cooperative-matrix/out.slang
index 5b342afc0..7d4a83fc2 100644
--- a/tests/cooperative-matrix/out.slang
+++ b/tests/cooperative-matrix/out.slang
@@ -13,7 +13,9 @@ StructuredBuffer<float> inputBuffer;
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
-typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
void doubleCoopMat(CoopMatType mat, out CoopMatType result)
{
@@ -26,9 +28,9 @@ void computeMain()
let stride = 16;
let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout);
+ let mat = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(inputBuffer, 0, stride);
CoopMatType result;
doubleCoopMat(mat, result);
- result.store(outputBuffer, 0, stride, matrixLayout);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, stride);
}
diff --git a/tests/cooperative-matrix/parameter.slang b/tests/cooperative-matrix/parameter.slang
index 8a4bb3315..eb6823aa1 100644
--- a/tests/cooperative-matrix/parameter.slang
+++ b/tests/cooperative-matrix/parameter.slang
@@ -12,7 +12,9 @@ StructuredBuffer<float> inputBuffer;
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
-typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
static let stride = 16;
//static let matrixLayout = CoopMatMatrixLayout::RowMajor;
@@ -21,12 +23,12 @@ static const CoopMatMatrixLayout matrixLayout = CoopMatMatrixLayout::RowMajor;
void processCoopMat(CoopMatType mat)
{
// XXX: hmmm, some error when matrixLAyout is static let
- (mat * 3.0).store(outputBuffer, 0, stride, matrixLayout);
+ (mat * 3.0).Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, stride);
}
[numthreads(32, 1, 1)]
void computeMain()
{
- let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout);
+ let mat = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(inputBuffer, 0, stride);
processCoopMat(mat);
}
diff --git a/tests/cooperative-matrix/reduce.slang b/tests/cooperative-matrix/reduce.slang
new file mode 100644
index 000000000..bbba587ce
--- /dev/null
+++ b/tests/cooperative-matrix/reduce.slang
@@ -0,0 +1,52 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK_ROW):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-reductions -skip-spirv-validation -Xslang -DTEST_MODE=0
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK_COLUMN):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-reductions -skip-spirv-validation -Xslang -DTEST_MODE=1
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK_ROW_AND_COLUMN):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-reductions -skip-spirv-validation -Xslang -DTEST_MODE=2
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK_2X2):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-reductions -skip-spirv-validation -Xslang -DTEST_MODE=3
+
+//CHK_ROW-COUNT-8: 36
+
+//CHK_COLUMN:1
+//CHK_COLUMN-NEXT:2
+//CHK_COLUMN-NEXT:3
+//CHK_COLUMN-NEXT:4
+//CHK_COLUMN-NEXT:5
+//CHK_COLUMN-NEXT:6
+//CHK_COLUMN-NEXT:7
+//CHK_COLUMN-NEXT:8
+
+//CHK_ROW_AND_COLUMN: 36
+
+//CHK_2X2:3
+//CHK_2X2:7
+//CHK_2X2:11
+//CHK_2X2:15
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8], stride=4, count=256),name=buf
+StructuredBuffer<int32_t> inputBuffer;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+using namespace linalg;
+
+int32_t CombineOp(int32_t lhs, int32_t rhs)
+{
+ return lhs + rhs;
+}
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let mat = coopMatLoad<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator, CoopMatMatrixLayout.RowMajor>(inputBuffer, 0, stride);
+#if TEST_MODE == 0
+ let result = mat.ReduceRow<8>(CombineOp);
+#elif TEST_MODE == 1
+ let result = mat.ReduceColumn<8>(CombineOp);
+#elif TEST_MODE == 2
+ let result = mat.ReduceRowAndColumn<8, 8>(CombineOp);
+#elif TEST_MODE == 3
+ let result = mat.Reduce2x2(CombineOp);
+#endif
+ result.Store<CoopMatMatrixLayout.RowMajor>(outputBuffer, 0, stride);
+}
diff --git a/tests/cooperative-matrix/return.slang b/tests/cooperative-matrix/return.slang
index 722a31b8b..b67117892 100644
--- a/tests/cooperative-matrix/return.slang
+++ b/tests/cooperative-matrix/return.slang
@@ -12,7 +12,9 @@ StructuredBuffer<float> inputBuffer;
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
-typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
CoopMatType doubleCoopmat(CoopMatType mat)
{
@@ -25,8 +27,8 @@ void computeMain()
let stride = 16;
let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout);
+ let mat = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(inputBuffer, 0, stride);
let result = doubleCoopmat(mat);
- result.store(outputBuffer, 0, 4, matrixLayout);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, 4);
}
diff --git a/tests/cooperative-matrix/scalar-mul.slang b/tests/cooperative-matrix/scalar-mul.slang
index 9d266920e..29d8f03f0 100644
--- a/tests/cooperative-matrix/scalar-mul.slang
+++ b/tests/cooperative-matrix/scalar-mul.slang
@@ -12,7 +12,9 @@ ByteAddressBuffer inputBuffer;
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
-typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
[numthreads(32, 1, 1)]
void computeMain()
@@ -20,8 +22,8 @@ void computeMain()
let stride = 16;
let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout);
+ let mat = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(inputBuffer, 0, stride);
let result = mat * 4.5;
- result.store(outputBuffer, 0, 4, matrixLayout);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, 4);
}
diff --git a/tests/cooperative-matrix/struct.slang b/tests/cooperative-matrix/struct.slang
index 38d98b44f..592e553ae 100644
--- a/tests/cooperative-matrix/struct.slang
+++ b/tests/cooperative-matrix/struct.slang
@@ -19,7 +19,9 @@ ByteAddressBuffer input2;
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
-typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
struct MyStruct
{
@@ -34,9 +36,9 @@ void computeMain()
let matrixLayout = CoopMatMatrixLayout::RowMajor;
MyStruct s;
- s.mat1 = CoopMatType.load(input1, 0, stride, matrixLayout);
- s.mat2 = CoopMatType.load(input2, 0, stride, matrixLayout);
+ s.mat1 = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input1, 0, stride);
+ s.mat2 = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input2, 0, stride);
- s.mat1.store(outputBuffer, 0, stride, matrixLayout);
- s.mat2.store(outputBuffer, 4, stride, matrixLayout);
+ s.mat1.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, stride);
+ s.mat2.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 4, stride);
}
diff --git a/tests/cooperative-matrix/sub.slang b/tests/cooperative-matrix/sub.slang
index 7fe6f3fea..a42fee9e8 100644
--- a/tests/cooperative-matrix/sub.slang
+++ b/tests/cooperative-matrix/sub.slang
@@ -15,7 +15,9 @@ ByteAddressBuffer input1;
//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4, count=256),name=input2
ByteAddressBuffer input2;
-typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
[numthreads(32, 1, 1)]
void computeMain()
@@ -23,9 +25,9 @@ void computeMain()
let stride = 16;
let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout);
- let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout);
+ let mat1 = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input1, 0, stride);
+ let mat2 = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(input2, 0, stride);
let result = mat1 - mat2;
- result.store(outputBuffer, 0, 4, matrixLayout);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, 4);
}
diff --git a/tests/cooperative-matrix/subscript-in-func.slang b/tests/cooperative-matrix/subscript-in-func.slang
index 585eabf92..1bb8df433 100644
--- a/tests/cooperative-matrix/subscript-in-func.slang
+++ b/tests/cooperative-matrix/subscript-in-func.slang
@@ -12,7 +12,9 @@ StructuredBuffer<float> inputBuffer;
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
-typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<float, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
static const int stride = 16;
static const CoopMatMatrixLayout matrixLayout = CoopMatMatrixLayout::RowMajor;
@@ -23,12 +25,12 @@ void squareCoopMatElements(CoopMatType mat)
{
mat[i] = mat[i] * mat[i];
}
- mat.store(outputBuffer, 0, stride, matrixLayout);
+ mat.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, stride);
}
[numthreads(32, 1, 1)]
void computeMain()
{
- let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout);
+ let mat = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(inputBuffer, 0, stride);
squareCoopMatElements(mat);
}
diff --git a/tests/cooperative-matrix/subscript.slang b/tests/cooperative-matrix/subscript.slang
index cfe164f04..0d765b6b3 100644
--- a/tests/cooperative-matrix/subscript.slang
+++ b/tests/cooperative-matrix/subscript.slang
@@ -9,7 +9,9 @@
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<int32_t> outputBuffer;
-typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
[numthreads(32, 1, 1)]
void computeMain()
@@ -19,5 +21,5 @@ void computeMain()
mat[1] = mat[0]+2;
mat[2] = mat[1]+3;
mat[3] = mat[2]+4;
- mat.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor);
+ mat.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, 16);
}
diff --git a/tests/cooperative-matrix/transpose.slang b/tests/cooperative-matrix/transpose.slang
new file mode 100644
index 000000000..cd317cac2
--- /dev/null
+++ b/tests/cooperative-matrix/transpose.slang
@@ -0,0 +1,36 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -skip-spirv-validation -render-feature cooperative-matrix-conversions
+
+//CHECK: type: int32_t
+//CHECK: 1
+//CHECK-COUNT-15: 0
+//CHECK: 2
+//CHECK-COUNT-15: 0
+//CHECK: 3
+//CHECK-COUNT-15: 0
+//CHECK: 4
+//CHECK-COUNT-15: 0
+//CHECK: 5
+//CHECK-COUNT-15: 0
+//CHECK: 6
+//CHECK-COUNT-15: 0
+//CHECK: 7
+//CHECK-COUNT-15: 0
+//CHECK: 8
+//CHECK-COUNT-15: 0
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8], stride=4, count=256),name=buf
+StructuredBuffer<int32_t> inputBuffer;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+using namespace linalg;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let mat = coopMatLoad<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator, CoopMatMatrixLayout.RowMajor>(inputBuffer, 0, stride);
+ let result = mat.Transpose();
+ result.Store<CoopMatMatrixLayout.RowMajor>(outputBuffer, 0, stride);
+}
diff --git a/tests/cooperative-matrix/unary_neg.slang b/tests/cooperative-matrix/unary_neg.slang
index 5d58f1395..be803d8c6 100644
--- a/tests/cooperative-matrix/unary_neg.slang
+++ b/tests/cooperative-matrix/unary_neg.slang
@@ -12,7 +12,9 @@ ByteAddressBuffer inputBuffer;
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWStructuredBuffer<int32_t> outputBuffer;
-typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+using namespace linalg;
+
+typealias CoopMatType = CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
[numthreads(32, 1, 1)]
void computeMain()
@@ -20,8 +22,8 @@ void computeMain()
let stride = 4;
let matrixLayout = CoopMatMatrixLayout::RowMajor;
- let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout);
+ let mat = CoopMatType.Load<CoopMatMatrixLayout::RowMajor>(inputBuffer, 0, stride);
let result = -mat;
- result.store(outputBuffer, 0, 4, matrixLayout);
+ result.Store<CoopMatMatrixLayout::RowMajor>(outputBuffer, 0, 4);
}