summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/user-guide/a3-02-reference-capability-atoms.md9
-rw-r--r--source/slang/hlsl.meta.slang541
-rw-r--r--source/slang/slang-capabilities.capdef11
-rw-r--r--source/slang/slang-emit-spirv-ops.h22
-rw-r--r--source/slang/slang-emit-spirv.cpp51
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-util.cpp12
-rw-r--r--source/slang/slang-ir-util.h3
-rw-r--r--source/slang/slang-ir.cpp4
-rw-r--r--source/slang/slang-ir.h11
-rw-r--r--tests/cooperative-matrix/add.slang32
-rw-r--r--tests/cooperative-matrix/array.slang36
-rw-r--r--tests/cooperative-matrix/comparison.slang35
-rw-r--r--tests/cooperative-matrix/conversion.slang30
-rw-r--r--tests/cooperative-matrix/copyFrom.slang17
-rw-r--r--tests/cooperative-matrix/diagnostics/mat-mul-add-different-scope.slang20
-rw-r--r--tests/cooperative-matrix/diagnostics/mat-mul-add-incorrect-matrix-use.slang20
-rw-r--r--tests/cooperative-matrix/div.slang31
-rw-r--r--tests/cooperative-matrix/fill.slang16
-rw-r--r--tests/cooperative-matrix/inout.slang31
-rw-r--r--tests/cooperative-matrix/load-store-arbitrary-array-vec.slang41
-rw-r--r--tests/cooperative-matrix/load-store-arbitrary-array.slang41
-rw-r--r--tests/cooperative-matrix/load-store-groupshared.slang31
-rw-r--r--tests/cooperative-matrix/load-store-rwbyteaddressbuffer.slang26
-rw-r--r--tests/cooperative-matrix/load-store-rwstructuredbuffer.slang27
-rw-r--r--tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang45
-rw-r--r--tests/cooperative-matrix/mat-mul-add.slang23
-rw-r--r--tests/cooperative-matrix/mod.slang53
-rw-r--r--tests/cooperative-matrix/mul.slang31
-rw-r--r--tests/cooperative-matrix/out.slang34
-rw-r--r--tests/cooperative-matrix/parameter.slang32
-rw-r--r--tests/cooperative-matrix/return.slang32
-rw-r--r--tests/cooperative-matrix/scalar-mul.slang27
-rw-r--r--tests/cooperative-matrix/struct.slang42
-rw-r--r--tests/cooperative-matrix/sub.slang31
-rw-r--r--tests/cooperative-matrix/subscript-in-func.slang34
-rw-r--r--tests/cooperative-matrix/subscript.slang23
-rw-r--r--tests/cooperative-matrix/unary_neg.slang27
38 files changed, 1522 insertions, 11 deletions
diff --git a/docs/user-guide/a3-02-reference-capability-atoms.md b/docs/user-guide/a3-02-reference-capability-atoms.md
index 1cb7f5bd5..d72a2768f 100644
--- a/docs/user-guide/a3-02-reference-capability-atoms.md
+++ b/docs/user-guide/a3-02-reference-capability-atoms.md
@@ -424,6 +424,9 @@ Extensions
`SPV_NV_cooperative_vector`
> Represents the SPIR-V extension for SPV_NV_cooperative_vector.
+`SPV_KHR_cooperative_matrix`
+> Represents the SPIR-V extension for SPV_KHR_cooperative_matrix.
+
`spvAtomicFloat32AddEXT`
> Represents the SPIR-V capability for atomic float 32 add operations.
@@ -535,6 +538,9 @@ Extensions
`spvCooperativeVectorTrainingNV`
> Represents the SPIR-V capability for cooperative vector training
+`spvCooperativeMatrixKHR`
+> Represents the SPIR-V capability for cooperative matrices
+
`spvMaximalReconvergenceKHR`
> Represents the SPIR-V capability for maximal reconvergence.
@@ -1206,6 +1212,9 @@ Other
----------------------
*Capabilities that may be deprecated*
+`cooperative_matrix`
+> Capabilities needed to use cooperative matrices
+
`SPIRV_1_0`
> Use `spirv_1_0` instead
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index bdaa2bad0..e71997c6c 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -22009,6 +22009,546 @@ extension<T, L : IBufferDataLayout> RasterizerOrderedStructuredBuffer<T, L> : IR
}
//
+// Cooperative Matrix type
+//
+
+__intrinsic_type($(kIROp_CoopMatrixType))
+[require(cooperative_matrix)]
+struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> : IArray<T>, IArithmetic
+{
+ //
+ // Initialization
+ //
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ __init()
+ {
+ }
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ __init(T t)
+ {
+ this.fill(t);
+ }
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ __init<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other)
+ {
+ this.copyFrom(other);
+ }
+
+ [ForceInline]
+ __init(This x)
+ {
+ this = x;
+ }
+
+ // Required for `IArithmetic`.
+ [OverloadRank(-10)]
+ [ForceInline]
+ __init(int i)
+ {
+ this = CoopMat<T, S, M, N, R>(T(i));
+ }
+
+ //
+ // Simple setters
+ //
+
+ [require(cooperative_matrix)]
+ [mutating]
+ [ForceInline]
+ void fill(T t)
+ {
+ this = spirv_asm
+ {
+ result:$$CoopMat<T, S, M, N, R> = OpConstantComposite $t;
+ };
+ }
+
+ [require(cooperative_matrix)]
+ [mutating]
+ [ForceInline]
+ void copyFrom<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other)
+ {
+ if (__isFloat<T>() && __isInt<U>())
+ this = __int_to_float_cast<T>(other);
+ else if (__isInt<T>() && __isFloat<U>())
+ this = __float_to_int_cast<T>(other);
+ else if (__isFloat<T>() && __isFloat<U>())
+ this = __real_cast<T>(other);
+ else if (__isInt<T>() && __isInt<U>())
+ this = __int_cast<T>(other);
+ }
+
+ //
+ // Subscript
+ //
+
+ __intrinsic_op($(kIROp_GetElement))
+ [__NoSideEffect]
+ T __indexRead(int index);
+
+ __intrinsic_op($(kIROp_GetElementPtr))
+ [__ref]
+ [__NoSideEffect]
+ Ref<T> __indexRef(int index);
+
+ [ForceInline]
+ [__NoSideEffect]
+ int getCount()
+ {
+ return getLength();
+ }
+
+ [ForceInline]
+ [__NoSideEffect]
+ int getRowCount()
+ {
+ return M;
+ }
+
+ [ForceInline]
+ [__NoSideEffect]
+ int getColumnCount()
+ {
+ return N;
+ }
+
+ __subscript(int index) -> T
+ {
+ [__NoSideEffect]
+ [nonmutating]
+ get
+ {
+ return __indexRead(index);
+ }
+
+ [mutating]
+ set
+ {
+ __indexRef(index) = newValue;
+ }
+ }
+
+ /// Returns the number of components owned by each invocation.
+ [ForceInline]
+ [require(cooperative_matrix)]
+ uint getLength()
+ {
+ return spirv_asm
+ {
+ result:$$uint = OpCooperativeMatrixLengthKHR $$This;
+ };
+ }
+
+ //
+ // Store
+ //
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ void store(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ return store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ }
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ void store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let zero = 0;
+ let alignment = 16U;
+ spirv_asm
+ {
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [__NoSideEffect]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ void store(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ return spirv_asm
+ {
+ %pointer:$$T* = OpPtrAccessChain $buffer $element;
+ OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [require(cooperative_matrix)]
+ [ForceInline]
+ void store<let U : int>(__ref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ spirv_asm
+ {
+ %workgroupPointerType = OpTypePointer Workgroup $$T;
+ %pointer:%workgroupPointerType = OpAccessChain &data $element;
+ OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [ForceInline]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ void storeAny<U, let V : int>(__ref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ spirv_asm
+ {
+ %workgroupPointerType = OpTypePointer Workgroup $$U;
+ %pointer:%workgroupPointerType = OpAccessChain &data $element;
+ OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [ForceInline]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ void storeAny<U, let V : int, let L : int>(__ref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ spirv_asm
+ {
+ %workgroupPointerType = OpTypePointer Workgroup $$vector<U, L>;
+ %pointer:%workgroupPointerType = OpAccessChain &data $element;
+ OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ //
+ // Load
+ //
+
+ [__NoSideEffect]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> load(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ }
+
+ [__NoSideEffect]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> load(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ }
+
+ [__NoSideEffect]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> load(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let zero = 0;
+ let alignment = 16U;
+ return spirv_asm
+ {
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [__NoSideEffect]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> load(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let zero = 0;
+ let alignment = 16U;
+ return spirv_asm
+ {
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [__NoSideEffect]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> load(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16;
+ return spirv_asm
+ {
+ %pointer:$$T* = OpPtrAccessChain $buffer $element;
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> load<let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ return spirv_asm
+ {
+ %workgroupPointerType = OpTypePointer Workgroup $$T;
+ %pointer:%workgroupPointerType = OpAccessChain &data $element;
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> loadAny<U, let V : int>(__constref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ return spirv_asm
+ {
+ %workgroupPointerType = OpTypePointer Workgroup $$U;
+ %pointer:%workgroupPointerType = OpAccessChain &data $element;
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> loadAny<U, let V : int, let L : int>(__constref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ return spirv_asm
+ {
+ %workgroupPointerType = OpTypePointer Workgroup $$vector<U, L>;
+ %pointer:%workgroupPointerType = OpAccessChain &data $element;
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ //
+ // Arithmetic
+ //
+
+ __intrinsic_op($(kIROp_Add))
+ This add(This other);
+
+ __intrinsic_op($(kIROp_Sub))
+ This sub(This other);
+
+ __intrinsic_op($(kIROp_Mul))
+ This mul(This other);
+
+ __intrinsic_op($(kIROp_Div))
+ This div(This other);
+
+ __intrinsic_op($(kIROp_Neg))
+ This neg();
+
+ This mod(This other)
+ {
+ This ret;
+ for (int i = 0; i < getLength(); ++i)
+ {
+ ret[i] = this[i] % other[i];
+ }
+ return ret;
+ }
+
+ //
+ // Equality and ordering
+ //
+
+ bool equals(This other)
+ {
+ for (int i = 0; i < getLength(); i++)
+ {
+ if (this[i] != other[i])
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool lessThan(This other)
+ {
+ for (int i = 0; i < getLength(); i++)
+ {
+ if (this[i] < other[i])
+ {
+ return true;
+ }
+ else if (this[i] > other[i])
+ {
+ return false;
+ }
+ }
+ return false;
+ }
+
+ bool lessThanOrEquals(This other)
+ {
+ for (int i = 0; i < getLength(); i++)
+ {
+ if (this[i] < other[i])
+ {
+ return true;
+ }
+ else if (this[i] > other[i])
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+}
+
+//
+// Convenience loading functions for cooperative matrices which infer the
+// element type for structured buffers, pointers and groupshared arrays.
+//
+
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+{
+ return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+}
+
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+{
+ return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+}
+
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+{
+ return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+}
+
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+{
+ return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+}
+
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+{
+ return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+}
+
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse, let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+{
+ return CoopMat<T, S, M, N, R>.load(data, element, stride, matrixLayout);
+}
+
+//
+// Cooperative Matrix casting
+//
+
+__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+__intrinsic_op($(kIROp_IntCast))
+[require(cooperative_matrix)]
+CoopMat<T,S,M,N,R> __int_cast(CoopMat<U,S,M,N,R> val);
+
+__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+__intrinsic_op($(kIROp_FloatCast))
+[require(cooperative_matrix)]
+CoopMat<T,S,M,N,R> __real_cast(CoopMat<U,S,M,N,R> val);
+
+__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+__intrinsic_op($(kIROp_CastIntToFloat))
+[require(cooperative_matrix)]
+CoopMat<T,S,M,N,R> __int_to_float_cast(CoopMat<U,S,M,N,R> val);
+
+__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+__intrinsic_op($(kIROp_CastFloatToInt))
+[require(cooperative_matrix)]
+CoopMat<T,S,M,N,R> __float_to_int_cast(CoopMat<U,S,M,N,R> val);
+
+//
+// Cooperative Matrix multiplication with scalar
+//
+
+__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> operator *(CoopMat<T, S, M, N, R> lhs, const T rhs)
+{
+ return spirv_asm
+ {
+ result:$$CoopMat<T, S, M, N, R> = OpMatrixTimesScalar $lhs $rhs;
+ };
+}
+
+__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> operator *(const T lhs, CoopMat<T, S, M, N, R> rhs)
+{
+ return rhs * lhs;
+}
+
+//
+// Cooperative Matrix enums
+//
+
+enum CoopMatScope
+{
+ Device = 1,
+ Workgroup = 2,
+ Subgroup = 3,
+ QueueFamily = 5,
+};
+
+enum CoopMatMatrixUse
+{
+ MatrixA = 0,
+ MatrixB = 1,
+ MatrixAccumulator = 2,
+};
+
+enum CoopMatMatrixLayout
+{
+ RowMajor = 0,
+ ColumnMajor = 1,
+};
+
+enum CoopMatMatrixOperands
+{
+ None = 0x0,
+ MatrixASigned = 0x1,
+ MatrixBSigned = 0x2,
+ MatrixCSigned = 0x4,
+ MatrixResultSigned = 0x8,
+ SaturatingAccumulation = 0x10,
+};
+
+//
+// Cooperative Matrix multiply accumulate
+//
+
+[require(cooperative_matrix)]
+__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, V : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let K : int, let N : int, let RA : CoopMatMatrixUse, let RB : CoopMatMatrixUse, let RC : CoopMatMatrixUse>
+CoopMat<V, S, M, N, RC> coopMatMulAdd(CoopMat<T, S, M, K, RA> matA, CoopMat<U, S, K, N, RB> matB, CoopMat<V, S, M, N, RC> matC, constexpr CoopMatMatrixOperands operands)
+{
+ static_assert((RA == CoopMatMatrixUse::MatrixA) && (RB == CoopMatMatrixUse::MatrixB) && (RC == CoopMatMatrixUse::MatrixAccumulator), "matrix uses for `coopMatMulAdd` matrix parameters must be `MatrixA`, `MatrixB` and `MatrixAccumulator`");
+ return spirv_asm
+ {
+ result:$$CoopMat<V, S, M, N, RC> = OpCooperativeMatrixMulAddKHR $matA $matB $matC !operands;
+ };
+}
+
+//
// Cooperative Vector
//
@@ -23435,6 +23975,7 @@ CoopVec<T, N> coopVecLoadGroupshared<let N : int, T : __BuiltinArithmeticType, l
// Coop Vector matrix multiplication
//
+
/// Specifies the memory layout for matrices used in cooperative vector operations.
/// @remarks This enum defines different matrix layout options that affect how matrix data is stored and accessed,
/// including standard row-major and column-major layouts as well as specialized layouts optimized for specific operations.
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index 749a72e3b..fc1bc71a8 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -545,6 +545,10 @@ def SPV_EXT_replicated_composites : _spirv_1_0;
/// [EXT]
def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites;
+/// Represents the SPIR-V extension for SPV_KHR_cooperative_matrix.
+/// [EXT]
+def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer;
+
// SPIRV Capabilities.
/// Represents the SPIR-V capability for atomic float 32 add operations.
@@ -695,6 +699,10 @@ def spvCooperativeVectorNV : SPV_NV_cooperative_vector;
/// [EXT]
def spvCooperativeVectorTrainingNV : SPV_NV_cooperative_vector;
+/// Represents the SPIR-V capability for cooperative matrices
+/// [EXT]
+def spvCooperativeMatrixKHR : SPV_KHR_cooperative_matrix;
+
/// Represents the SPIR-V capability for maximal reconvergence.
/// [EXT]
def spvMaximalReconvergenceKHR : SPV_KHR_maximal_reconvergence;
@@ -1075,6 +1083,9 @@ alias cooperative_vector = _sm_6_8 | cpp | _cuda_sm_9_0 | spvCooperativeVectorNV
/// [Compound]
alias cooperative_vector_training = spvCooperativeVectorTrainingNV;
+/// Capabilities needed to use cooperative matrices
+alias cooperative_matrix = spvCooperativeMatrixKHR;
+
// Non-internal shader stages
//
/// Pixel shader stage
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index 880f6b083..017e58667 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -151,6 +151,28 @@ SpvInst* emitOpTypeCoopVec(IRInst* inst, const T1& componentType, const T2& comp
componentCount);
}
+template<typename T1, typename T2>
+SpvInst* emitOpTypeCoopMat(
+ IRInst* inst,
+ const T1& componentType,
+ const T2& scope,
+ const T2& rowCount,
+ const T2& columnCount,
+ const T2& matrixUse)
+{
+ static_assert(isSingular<T1>);
+ return emitInstMemoized(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ inst,
+ SpvOpTypeCooperativeMatrixKHR,
+ kResultID,
+ componentType,
+ scope,
+ rowCount,
+ columnCount,
+ matrixUse);
+}
+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeMatrix
template<typename T>
SpvInst* emitOpTypeMatrix(IRInst* inst, const T& columnType, const SpvLiteralInteger& columnCount)
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index baef62f1c..d07d587e5 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1683,6 +1683,29 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
static_cast<IRIntLit*>(coopVecType->getElementCount())->getValue(),
coopVecType);
}
+ case kIROp_CoopMatrixType:
+ {
+ requireSPIRVCapability(SpvCapabilityCooperativeMatrixKHR);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_cooperative_matrix"));
+
+ IRBuilder builder(m_irModule);
+ auto coopMatType = static_cast<IRCoopMatrixType*>(inst);
+ return emitOpTypeCoopMat(
+ coopMatType,
+ coopMatType->getElementType(),
+ emitIntConstant(
+ static_cast<IRIntLit*>(coopMatType->getScope())->getValue(),
+ builder.getIntType()),
+ emitIntConstant(
+ static_cast<IRIntLit*>(coopMatType->getRowCount())->getValue(),
+ builder.getIntType()),
+ emitIntConstant(
+ static_cast<IRIntLit*>(coopMatType->getColumnCount())->getValue(),
+ builder.getIntType()),
+ emitIntConstant(
+ static_cast<IRIntLit*>(coopMatType->getMatrixUse())->getValue(),
+ builder.getIntType()));
+ }
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(inst);
@@ -6264,7 +6287,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto baseTy = base->getDataType();
SLANG_ASSERT(
as<IRPointerLikeType>(baseTy) || as<IRArrayType>(baseTy) || as<IRVectorType>(baseTy) ||
- as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy));
+ as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy) ||
+ as<IRCoopMatrixType>(baseTy));
IRBuilder builder(m_irModule);
builder.setInsertBefore(inst);
@@ -6553,8 +6577,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto fromTypeV = inst->getOperand(0)->getDataType();
const auto toTypeV = inst->getDataType();
SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
- const auto fromType = getVectorElementType(fromTypeV);
- const auto toType = getVectorElementType(toTypeV);
+ const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV);
+ const auto toType = getVectorOrCoopMatrixElementType(toTypeV);
if (as<IRBoolType>(fromType))
{
@@ -6687,10 +6711,14 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
bool isMatrixCast = false;
if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV) ||
- as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV))
+ as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV) ||
+ // Cooperative matrices behave like vectors where arithmetic operations can be performed
+ // directly without having to loop through the matrix and performing operations on the
+ // vectors.
+ as<IRCoopMatrixType>(fromTypeV) || as<IRCoopMatrixType>(toTypeV))
{
- fromType = getVectorElementType(fromTypeV);
- toType = getVectorElementType(toTypeV);
+ fromType = getVectorOrCoopMatrixElementType(fromTypeV);
+ toType = getVectorOrCoopMatrixElementType(toTypeV);
}
else if (as<IRMatrixType>(fromTypeV) || as<IRMatrixType>(toTypeV))
{
@@ -6737,8 +6765,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto fromTypeV = inst->getOperand(0)->getDataType();
const auto toTypeV = inst->getDataType();
SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
- const auto fromType = getVectorElementType(fromTypeV);
- const auto toType = getVectorElementType(toTypeV);
+ const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV);
+ const auto toType = getVectorOrCoopMatrixElementType(toTypeV);
+
SLANG_ASSERT(isFloatingType(toType));
if (isIntegralType(fromType))
@@ -6781,8 +6810,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto fromTypeV = inst->getOperand(0)->getDataType();
const auto toTypeV = inst->getDataType();
SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
- const auto fromType = getVectorElementType(fromTypeV);
- const auto toType = getVectorElementType(toTypeV);
+ const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV);
+ const auto toType = getVectorOrCoopMatrixElementType(toTypeV);
SLANG_ASSERT(isFloatingType(fromType));
if (as<IRBoolType>(toType))
@@ -7085,7 +7114,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
UInt operandCount,
ArrayView<IRInst*> operands)
{
- IRType* elementType = getVectorElementType(operands[0]->getDataType());
+ IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType());
IRBasicType* basicType = as<IRBasicType>(elementType);
bool isFloatingPoint = false;
bool isBool = false;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index c7ed5affe..3de40d2c0 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -235,6 +235,7 @@ INST(Nop, nop, 0, 0)
INST(RayQueryType, RayQuery, 1, HOISTABLE)
INST(HitObjectType, HitObject, 0, HOISTABLE)
INST(CoopVectorType, CoopVectorType, 2, HOISTABLE)
+INST(CoopMatrixType, CoopMatrixType, 5, HOISTABLE)
// Opaque type that can be dynamically cast to other resource types.
INST(DynamicResourceType, DynamicResource, 0, HOISTABLE)
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 58bb7aaf2..4919850eb 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -23,6 +23,18 @@ IRType* getVectorElementType(IRType* type)
return vectorType->getElementType();
if (auto coopVecType = as<IRCoopVectorType>(type))
return coopVecType->getElementType();
+ if (auto coopMatType = as<IRCoopMatrixType>(type))
+ return coopMatType->getElementType();
+ return type;
+}
+
+IRType* getVectorOrCoopMatrixElementType(IRType* type)
+{
+ auto vectorElementType = getVectorElementType(type);
+ if (vectorElementType != type)
+ return vectorElementType;
+ if (auto coopMatrixType = as<IRCoopMatrixType>(type))
+ return coopMatrixType->getElementType();
return type;
}
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 549981f58..0a8bc9b1d 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -78,6 +78,9 @@ bool isComInterfaceType(IRType* type);
// If `type` is a vector, returns its element type. Otherwise, return `type`.
IRType* getVectorElementType(IRType* type);
+// If `type` is a vector or a coop matrix, returns its element type. Otherwise, return `type`.
+IRType* getVectorOrCoopMatrixElementType(IRType* type);
+
// If `type` is a matrix, returns its element type. Otherwise, return `type`.
IRType* getMatrixElementType(IRType* type);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index f75fe2f48..c105a698a 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5351,6 +5351,10 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index)
{
type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount());
}
+ else if (auto coopMatType = as<IRCoopMatrixType>(valueType))
+ {
+ type = coopMatType->getElementType();
+ }
else if (const auto basicType = as<IRBasicType>(valueType))
{
// HLSL support things like float.x, in which case we just return the base pointer.
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index dbc66c6a3..dbf2b91be 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1869,6 +1869,17 @@ struct IRCoopVectorType : IRType
IR_LEAF_ISA(CoopVectorType)
};
+struct IRCoopMatrixType : IRType
+{
+ IRType* getElementType() { return (IRType*)getOperand(0); }
+ IRInst* getScope() { return getOperand(1); }
+ IRInst* getRowCount() { return getOperand(2); }
+ IRInst* getColumnCount() { return getOperand(3); }
+ IRInst* getMatrixUse() { return getOperand(4); }
+
+ IR_LEAF_ISA(CoopMatrixType)
+};
+
bool isDefinition(IRInst* inVal);
// A structure type is represented as a parent instruction,
diff --git a/tests/cooperative-matrix/add.slang b/tests/cooperative-matrix/add.slang
new file mode 100644
index 000000000..3d8348d13
--- /dev/null
+++ b/tests/cooperative-matrix/add.slang
@@ -0,0 +1,32 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: int32_t
+// CHECK-NEXT: 1
+// CHECK-NEXT: 3
+// CHECK-NEXT: 5
+// CHECK-NEXT: 7
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4, count=256),name=input1
+ByteAddressBuffer input1;
+
+//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4, count=256),name=input2
+ByteAddressBuffer input2;
+
+typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout);
+ let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout);
+ let result = mat1 + mat2;
+
+ result.store(outputBuffer, 0, stride, matrixLayout);
+}
+
diff --git a/tests/cooperative-matrix/array.slang b/tests/cooperative-matrix/array.slang
new file mode 100644
index 000000000..b46c0f66b
--- /dev/null
+++ b/tests/cooperative-matrix/array.slang
@@ -0,0 +1,36 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: float
+// CHECK: 1.000000
+// CHECK-NEXT: 2.000000
+// CHECK-NEXT: 3.000000
+// CHECK-NEXT: 4.000000
+// CHECK-NEXT: 5.000000
+// CHECK-NEXT: 6.000000
+// CHECK-NEXT: 7.000000
+// CHECK-NEXT: 8.000000
+
+//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=256),name=input1
+ByteAddressBuffer input1;
+
+//TEST_INPUT:ubuffer(data=[5.0 6.0 7.0 8.0], stride=256),name=input1
+ByteAddressBuffer input2;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ CoopMatType coopMatArray[2];
+ coopMatArray[0] = CoopMatType.load(input1, 0, stride, matrixLayout);
+ coopMatArray[1] = CoopMatType.load(input2, 0, stride, matrixLayout);
+
+ coopMatArray[0].store(outputBuffer, 0, stride, matrixLayout);
+ coopMatArray[1].store(outputBuffer, 4, stride, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/comparison.slang b/tests/cooperative-matrix/comparison.slang
new file mode 100644
index 000000000..bcf0c90ae
--- /dev/null
+++ b/tests/cooperative-matrix/comparison.slang
@@ -0,0 +1,35 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: uint32_t
+// CHECK-NEXT: 0
+// CHECK-NEXT: 1
+// CHECK-NEXT: 1
+
+//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256),name=input1
+ByteAddressBuffer input1;
+
+//TEST_INPUT:ubuffer(data=[1.0 3.0 2.0 4.0], stride=4, count=256),name=input2
+ByteAddressBuffer input2;
+
+//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain(uint3 threadIndex : SV_DispatchThreadID)
+{
+ let stride = 4;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout);
+ let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout);
+
+ uint32_t equals = mat1 == mat2 ? 1 : 0;
+ uint32_t lessThan = mat1 < mat2 ? 1 : 0;
+ uint32_t lessThanOrEquals = mat1 <= mat2 ? 1 : 0;
+
+ outputBuffer[0] = equals;
+ outputBuffer[1] = lessThan;
+ outputBuffer[2] = lessThanOrEquals;
+}
diff --git a/tests/cooperative-matrix/conversion.slang b/tests/cooperative-matrix/conversion.slang
new file mode 100644
index 000000000..745882ab8
--- /dev/null
+++ b/tests/cooperative-matrix/conversion.slang
@@ -0,0 +1,30 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: float
+// CHECK-NEXT: 2.000000
+// CHECK-NEXT: 4.000000
+// CHECK-NEXT: 6.000000
+// CHECK-NEXT: 8.000000
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4, count=256),name=input
+ByteAddressBuffer input;
+
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let intMat = CoopMat<int, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>.load(input, 0, stride, matrixLayout);
+ let floatMat = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(intMat);
+ let uintMat = CoopMat<uint, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(intMat);
+ let halfMat = CoopMat<half, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(uintMat);
+ let floatMat2 = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(halfMat);
+
+ let result = floatMat + floatMat2;
+ result.store(outputBuffer, 0, stride, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/copyFrom.slang b/tests/cooperative-matrix/copyFrom.slang
new file mode 100644
index 000000000..f7270545e
--- /dev/null
+++ b/tests/cooperative-matrix/copyFrom.slang
@@ -0,0 +1,17 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: int32_t
+// CHECK-COUNT-256: 4
+
+//TEST_INPUT:ubuffer(stride=4, count = 256):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let mat = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(4.0);
+ var result = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(0);
+ result.copyFrom(mat);
+ result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor);
+}
+
diff --git a/tests/cooperative-matrix/diagnostics/mat-mul-add-different-scope.slang b/tests/cooperative-matrix/diagnostics/mat-mul-add-different-scope.slang
new file mode 100644
index 000000000..0c4308308
--- /dev/null
+++ b/tests/cooperative-matrix/diagnostics/mat-mul-add-different-scope.slang
@@ -0,0 +1,20 @@
+//DIAGNOSTIC_TEST(compute):SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target spirv
+
+RWStructuredBuffer<float> outputBuffer;
+
+typealias CoopMatAType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>;
+typealias CoopMatBType = CoopMat<float16_t, CoopMatScope::Workgroup, 16, 16, CoopMatMatrixUse::MatrixB>;
+typealias CoopMatCType = CoopMat<float32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+// CHECK: error 39999: could not specialize generic for arguments of type
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let matA = CoopMatAType(3.0);
+ let matB = CoopMatBType(5.0);
+ let matC = CoopMatCType(1.0);
+
+ const let result = coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::None);
+ result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor);
+}
diff --git a/tests/cooperative-matrix/diagnostics/mat-mul-add-incorrect-matrix-use.slang b/tests/cooperative-matrix/diagnostics/mat-mul-add-incorrect-matrix-use.slang
new file mode 100644
index 000000000..5b7dc7a5b
--- /dev/null
+++ b/tests/cooperative-matrix/diagnostics/mat-mul-add-incorrect-matrix-use.slang
@@ -0,0 +1,20 @@
+//DIAGNOSTIC_TEST(compute):SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target spirv
+
+RWStructuredBuffer<float> outputBuffer;
+
+typealias CoopMatAType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>;
+typealias CoopMatBType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>;
+typealias CoopMatCType = CoopMat<float32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+// CHECK: error 41400: static assertion failed, matrix uses for `coopMatMulAdd` matrix parameters must be `MatrixA`, `MatrixB` and `MatrixAccumulator`
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let matA = CoopMatAType(3.0);
+ let matB = CoopMatBType(5.0);
+ let matC = CoopMatCType(1.0);
+
+ let result = coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::None);
+ result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor);
+}
diff --git a/tests/cooperative-matrix/div.slang b/tests/cooperative-matrix/div.slang
new file mode 100644
index 000000000..29207e0e4
--- /dev/null
+++ b/tests/cooperative-matrix/div.slang
@@ -0,0 +1,31 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: int32_t
+// CHECK-NEXT: 2
+// CHECK-NEXT: 1
+// CHECK-NEXT: 1
+// CHECK-NEXT: 0
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+//TEST_INPUT:ubuffer(data=[4 3 5 2], stride=4, count=256),name=input1
+ByteAddressBuffer input1;
+
+//TEST_INPUT:ubuffer(data=[2 3 4 5], stride=4, count=256),name=input2
+ByteAddressBuffer input2;
+
+typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout);
+ let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout);
+ let result = mat1 / mat2;
+
+ result.store(outputBuffer, 0, stride, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/fill.slang b/tests/cooperative-matrix/fill.slang
new file mode 100644
index 000000000..d1a46d053
--- /dev/null
+++ b/tests/cooperative-matrix/fill.slang
@@ -0,0 +1,16 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: int32_t
+// CHECK-COUNT-256: 10
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ var result : CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+ result.fill(10);
+ result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor);
+}
+
diff --git a/tests/cooperative-matrix/inout.slang b/tests/cooperative-matrix/inout.slang
new file mode 100644
index 000000000..7284953b4
--- /dev/null
+++ b/tests/cooperative-matrix/inout.slang
@@ -0,0 +1,31 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: float
+// CHECK-NEXT: 2.000000
+// CHECK-NEXT: 4.000000
+// CHECK-NEXT: 6.000000
+// CHECK-NEXT: 8.000000
+
+//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256),name=input1
+ByteAddressBuffer input;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+void doubleCoopMat(inout CoopMatType mat)
+{
+ mat = mat * 2.0;
+}
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ var mat = CoopMatType.load(input, 0, stride, matrixLayout);
+ doubleCoopMat(mat);
+ mat.store(outputBuffer, 0, stride, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/load-store-arbitrary-array-vec.slang b/tests/cooperative-matrix/load-store-arbitrary-array-vec.slang
new file mode 100644
index 000000000..0afad3284
--- /dev/null
+++ b/tests/cooperative-matrix/load-store-arbitrary-array-vec.slang
@@ -0,0 +1,41 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -emit-spirv-directly
+
+// CHECK: 1
+// CHECK-NEXT: 2
+// CHECK-NEXT: 3
+// CHECK-NEXT: 4
+// CHECK-NEXT: 5
+// CHECK-NEXT: 6
+// CHECK-NEXT: 7
+// CHECK-NEXT: 8
+// CHECK-NEXT: 9
+// CHECK-NEXT: A
+// CHECK-NEXT: B
+// CHECK-NEXT: C
+// CHECK-NEXT: D
+// CHECK-NEXT: E
+// CHECK-NEXT: F
+// CHECK-NEXT: 10
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16], stride=4, count=256):name=input
+RWByteAddressBuffer input;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<uint32_t> outputBuffer;
+
+typealias CoopMatType = CoopMat<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+groupshared float3[128] tempShared;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(input, 0, stride, matrixLayout);
+ mat.storeAny(tempShared, 0, stride, matrixLayout);
+
+ let result = CoopMatType.loadAny(tempShared, 0, stride, matrixLayout);
+ result.store(outputBuffer, 0, stride, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/load-store-arbitrary-array.slang b/tests/cooperative-matrix/load-store-arbitrary-array.slang
new file mode 100644
index 000000000..496e62387
--- /dev/null
+++ b/tests/cooperative-matrix/load-store-arbitrary-array.slang
@@ -0,0 +1,41 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -emit-spirv-directly
+
+// CHECK: 1
+// CHECK-NEXT: 2
+// CHECK-NEXT: 3
+// CHECK-NEXT: 4
+// CHECK-NEXT: 5
+// CHECK-NEXT: 6
+// CHECK-NEXT: 7
+// CHECK-NEXT: 8
+// CHECK-NEXT: 9
+// CHECK-NEXT: A
+// CHECK-NEXT: B
+// CHECK-NEXT: C
+// CHECK-NEXT: D
+// CHECK-NEXT: E
+// CHECK-NEXT: F
+// CHECK-NEXT: 10
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16], stride=4, count=256):name=input
+RWByteAddressBuffer input;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<uint32_t> outputBuffer;
+
+typealias CoopMatType = CoopMat<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+groupshared float[256] tempShared;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(input, 0, stride, matrixLayout);
+ mat.storeAny(tempShared, 0, stride, matrixLayout);
+
+ let result = CoopMatType.loadAny(tempShared, 0, stride, matrixLayout);
+ result.store(outputBuffer, 0, stride, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/load-store-groupshared.slang b/tests/cooperative-matrix/load-store-groupshared.slang
new file mode 100644
index 000000000..c2334c0ce
--- /dev/null
+++ b/tests/cooperative-matrix/load-store-groupshared.slang
@@ -0,0 +1,31 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -emit-spirv-directly
+
+// CHECK: 1
+// CHECK-NEXT: 2
+// CHECK-NEXT: 3
+// CHECK-NEXT: 4
+// CHECK-NEXT: 5
+// CHECK-NEXT: 6
+// CHECK-NEXT: 7
+// CHECK-NEXT: 8
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8], stride=4, count=256):name=input
+RWByteAddressBuffer input;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<uint32_t> outputBuffer;
+
+groupshared uint32_t[256] tempShared;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(input, 0, stride, matrixLayout);
+ mat.store(tempShared, 0, stride, matrixLayout);
+
+ let result = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(tempShared, 0, stride, matrixLayout);
+ result.store(outputBuffer, 0, stride, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/load-store-rwbyteaddressbuffer.slang b/tests/cooperative-matrix/load-store-rwbyteaddressbuffer.slang
new file mode 100644
index 000000000..08c90992a
--- /dev/null
+++ b/tests/cooperative-matrix/load-store-rwbyteaddressbuffer.slang
@@ -0,0 +1,26 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: 1
+// CHECK-NEXT: 2
+// CHECK-NEXT: 3
+// CHECK-NEXT: 4
+// CHECK-NEXT: 5
+// CHECK-NEXT: 6
+// CHECK-NEXT: 7
+// CHECK-NEXT: 8
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8], stride=4, count=256):name=inputBuffer
+RWByteAddressBuffer inputBuffer;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWByteAddressBuffer outputBuffer;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(inputBuffer, 0, stride, matrixLayout);
+ mat.store(outputBuffer, 0, 16, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/load-store-rwstructuredbuffer.slang b/tests/cooperative-matrix/load-store-rwstructuredbuffer.slang
new file mode 100644
index 000000000..d71634082
--- /dev/null
+++ b/tests/cooperative-matrix/load-store-rwstructuredbuffer.slang
@@ -0,0 +1,27 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: int32_t
+// CHECK-NEXT: 1
+// CHECK-NEXT: 2
+// CHECK-NEXT: 3
+// CHECK-NEXT: 4
+// CHECK-NEXT: 5
+// CHECK-NEXT: 6
+// CHECK-NEXT: 7
+// CHECK-NEXT: 8
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8], stride=4, count=256),name=buf
+RWStructuredBuffer<int32_t> inputBuffer;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat = coopMatLoad<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(inputBuffer, 0, stride, matrixLayout);
+ mat.store(outputBuffer, 0, stride, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang b/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang
new file mode 100644
index 000000000..934104a28
--- /dev/null
+++ b/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang
@@ -0,0 +1,45 @@
+//TEST(compute):SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target spirv
+
+// This test checks that the correct SPIRV Cooperative Matrix Operands are emitted for OpCooperativeMatrixMulAddKHR operaions
+RWStructuredBuffer<int> outputBuffer1;
+RWStructuredBuffer<int> outputBuffer2;
+RWStructuredBuffer<int> outputBuffer3;
+RWStructuredBuffer<int> outputBuffer4;
+RWStructuredBuffer<int> outputBuffer5;
+
+typealias CoopMatAType = CoopMat<int16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>;
+typealias CoopMatBType = CoopMat<int16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixB>;
+typealias CoopMatCType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let matA = CoopMatAType(3);
+ let matB = CoopMatBType(5);
+ let matC = CoopMatCType(1);
+
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ // CHECK: OpCooperativeMatrixMulAddKHR {{.*}} NoneKHR
+ coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::None).store(outputBuffer1, 0, 16, matrixLayout);
+
+ // CHECK: OpCooperativeMatrixMulAddKHR {{.*}} MatrixASignedComponentsKHR|MatrixBSignedComponentsKHR
+ coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::MatrixASigned | CoopMatMatrixOperands::MatrixBSigned).store(outputBuffer2, 0, 16, matrixLayout);
+
+
+ // CHECK: OpCooperativeMatrixMulAddKHR {{.*}} MatrixCSignedComponentsKHR
+ coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::MatrixCSigned).store(outputBuffer2, 0, 16, matrixLayout);
+
+
+ // CHECK: OpCooperativeMatrixMulAddKHR {{.*}} MatrixResultSignedComponentsKHR
+ coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::MatrixResultSigned).store(outputBuffer3, 0, 16, matrixLayout);
+
+ // CHECK: OpCooperativeMatrixMulAddKHR {{.*}} SaturatingAccumulationKHR
+ coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::SaturatingAccumulation).store(outputBuffer4, 0, 16, matrixLayout);
+
+ let allOperands = CoopMatMatrixOperands::MatrixASigned | CoopMatMatrixOperands::MatrixBSigned | CoopMatMatrixOperands::MatrixCSigned | CoopMatMatrixOperands::MatrixResultSigned | CoopMatMatrixOperands::SaturatingAccumulation;
+ // CHECK: OpCooperativeMatrixMulAddKHR {{.*}} MatrixASignedComponentsKHR|MatrixBSignedComponentsKHR|MatrixCSignedComponentsKHR|MatrixResultSignedComponentsKHR|SaturatingAccumulationKHR
+ coopMatMulAdd(matA, matB, matC, allOperands).store(outputBuffer5, 0, 16, matrixLayout);
+}
+
+
diff --git a/tests/cooperative-matrix/mat-mul-add.slang b/tests/cooperative-matrix/mat-mul-add.slang
new file mode 100644
index 000000000..aea61989d
--- /dev/null
+++ b/tests/cooperative-matrix/mat-mul-add.slang
@@ -0,0 +1,23 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: float
+// CHECK-COUNT-256: 241.0
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typealias CoopMatAType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>;
+typealias CoopMatBType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixB>;
+typealias CoopMatCType = CoopMat<float32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ // ( 3.0 * 5.0 ) * 16 + 1.0 = 241.0
+ let matA = CoopMatAType(3.0);
+ let matB = CoopMatBType(5.0);
+ let matC = CoopMatCType(1.0);
+
+ let result = coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::None);
+ result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor);
+}
diff --git a/tests/cooperative-matrix/mod.slang b/tests/cooperative-matrix/mod.slang
new file mode 100644
index 000000000..116713481
--- /dev/null
+++ b/tests/cooperative-matrix/mod.slang
@@ -0,0 +1,53 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -emit-spirv-directly
+
+// CHECK: 0
+// CHECK-NEXT: 0
+// CHECK-NEXT: 1
+// CHECK-NEXT: 2
+
+// CHECK: 0
+// CHECK-NEXT: 0
+// CHECK: 1
+// CHECK-NEXT: 2
+
+// CHECK: 0
+// CHECK-NEXT: 0
+// CHECK: 1
+// CHECK-NEXT: 2
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWByteAddressBuffer outputBuffer;
+
+//TEST_INPUT:ubuffer(data=[4 3 5 7], stride=4, count=256),name=input1
+ByteAddressBuffer input1;
+
+//TEST_INPUT:ubuffer(data=[2 3 4 5], stride=4, count=256),name=input2
+ByteAddressBuffer input2;
+
+typealias CoopMatIntType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+typealias CoopMatUintType = CoopMat<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+typealias CoopMatFloatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat1 = CoopMatIntType.load(input1, 0, stride, matrixLayout);
+ let mat2 = CoopMatIntType.load(input2, 0, stride, matrixLayout);
+
+ let mat3 = CoopMatFloatType(mat1);
+ let mat4 = CoopMatFloatType(mat2);
+
+ let mat5 = CoopMatUintType(mat1);
+ let mat6 = CoopMatUintType(mat2);
+
+ let result = mat1 % mat2;
+ let result2 = CoopMatIntType(mat3 % mat4);
+ let result3 = CoopMatIntType(mat5 % mat6);
+
+ result.store(outputBuffer, 0, stride, matrixLayout);
+ result2.store(outputBuffer, 16, stride, matrixLayout);
+ result3.store(outputBuffer, 32, stride, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/mul.slang b/tests/cooperative-matrix/mul.slang
new file mode 100644
index 000000000..9b9fd67af
--- /dev/null
+++ b/tests/cooperative-matrix/mul.slang
@@ -0,0 +1,31 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: int32_t
+// CHECK-NEXT: 2
+// CHECK-NEXT: 6
+// CHECK-NEXT: 12
+// CHECK-NEXT: 20
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4, count=256),name=input1
+ByteAddressBuffer input1;
+
+//TEST_INPUT:ubuffer(data=[2 3 4 5], stride=4, count=256),name=input2
+ByteAddressBuffer input2;
+
+typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout);
+ let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout);
+
+ let result = mat1 * mat2;
+ result.store(outputBuffer, 0, 4, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/out.slang b/tests/cooperative-matrix/out.slang
new file mode 100644
index 000000000..147c2dd77
--- /dev/null
+++ b/tests/cooperative-matrix/out.slang
@@ -0,0 +1,34 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: float
+// CHECK-NEXT: 2.000000
+// CHECK-NEXT: 4.000000
+// CHECK-NEXT: 6.000000
+// CHECK-NEXT: 8.000000
+
+// XXX FW: having out instead of in below actually properly creates two output buffers, nice
+//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256):name=inputBuffer
+StructuredBuffer<float> inputBuffer;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+void doubleCoopMat(CoopMatType mat, out CoopMatType result)
+{
+ result = mat * 2.0;
+}
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout);
+
+ CoopMatType result;
+ doubleCoopMat(mat, result);
+ result.store(outputBuffer, 0, stride, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/parameter.slang b/tests/cooperative-matrix/parameter.slang
new file mode 100644
index 000000000..19adf4177
--- /dev/null
+++ b/tests/cooperative-matrix/parameter.slang
@@ -0,0 +1,32 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: float
+// CHECK-NEXT: 3.000000
+// CHECK-NEXT: 6.000000
+// CHECK-NEXT: 9.000000
+// CHECK-NEXT: 12.000000
+
+//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256):name=inputBuffer
+StructuredBuffer<float> inputBuffer;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+static let stride = 16;
+//static let matrixLayout = CoopMatMatrixLayout::RowMajor;
+static const CoopMatMatrixLayout matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+void processCoopMat(CoopMatType mat)
+{
+ // XXX: hmmm, some error when matrixLAyout is static let
+ (mat * 3.0).store(outputBuffer, 0, stride, matrixLayout);
+}
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout);
+ processCoopMat(mat);
+}
diff --git a/tests/cooperative-matrix/return.slang b/tests/cooperative-matrix/return.slang
new file mode 100644
index 000000000..339c9d04d
--- /dev/null
+++ b/tests/cooperative-matrix/return.slang
@@ -0,0 +1,32 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type
+
+// CHECK: type: float
+// CHECK-NEXT: 3.000000
+// CHECK-NEXT: 6.000000
+// CHECK-NEXT: 9.000000
+// CHECK-NEXT: 12.000000
+
+//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256):name=inputBuffer
+StructuredBuffer<float> inputBuffer;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+CoopMatType doubleCoopmat(CoopMatType mat)
+{
+ return mat * 3.0;
+}
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout);
+
+ let result = doubleCoopmat(mat);
+ result.store(outputBuffer, 0, 4, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/scalar-mul.slang b/tests/cooperative-matrix/scalar-mul.slang
new file mode 100644
index 000000000..73f0fbbfc
--- /dev/null
+++ b/tests/cooperative-matrix/scalar-mul.slang
@@ -0,0 +1,27 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: float
+// CHECK-NEXT: 4.500000
+// CHECK-NEXT: 9.000000
+// CHECK-NEXT: 13.500000
+// CHECK-NEXT: 18.000000
+
+//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256),name=inputBuffer
+ByteAddressBuffer inputBuffer;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout);
+
+ let result = mat * 4.5;
+ result.store(outputBuffer, 0, 4, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/struct.slang b/tests/cooperative-matrix/struct.slang
new file mode 100644
index 000000000..24bc2c367
--- /dev/null
+++ b/tests/cooperative-matrix/struct.slang
@@ -0,0 +1,42 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: float
+// CHECK-NEXT: 1.000000
+// CHECK-NEXT: 2.000000
+// CHECK-NEXT: 3.000000
+// CHECK-NEXT: 4.000000
+// CHECK-NEXT: 5.000000
+// CHECK-NEXT: 6.000000
+// CHECK-NEXT: 7.000000
+// CHECK-NEXT: 8.000000
+
+//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256),name=input1
+ByteAddressBuffer input1;
+
+//TEST_INPUT:ubuffer(data=[5.0 6.0 7.0 8.0], stride=4, count=256),name=input2
+ByteAddressBuffer input2;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+struct MyStruct
+{
+ CoopMatType mat1;
+ CoopMatType mat2;
+};
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ MyStruct s;
+ s.mat1 = CoopMatType.load(input1, 0, stride, matrixLayout);
+ s.mat2 = CoopMatType.load(input2, 0, stride, matrixLayout);
+
+ s.mat1.store(outputBuffer, 0, stride, matrixLayout);
+ s.mat2.store(outputBuffer, 4, stride, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/sub.slang b/tests/cooperative-matrix/sub.slang
new file mode 100644
index 000000000..7b20d7c11
--- /dev/null
+++ b/tests/cooperative-matrix/sub.slang
@@ -0,0 +1,31 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: int32_t
+// CHECK-NEXT: 1
+// CHECK-NEXT: 1
+// CHECK-NEXT: 1
+// CHECK-NEXT: 1
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+//TEST_INPUT:ubuffer(data=[2 3 4 5], stride=4, count=256),name=input1
+ByteAddressBuffer input1;
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4, count=256),name=input2
+ByteAddressBuffer input2;
+
+typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 16;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout);
+ let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout);
+
+ let result = mat1 - mat2;
+ result.store(outputBuffer, 0, 4, matrixLayout);
+}
diff --git a/tests/cooperative-matrix/subscript-in-func.slang b/tests/cooperative-matrix/subscript-in-func.slang
new file mode 100644
index 000000000..ef63d62de
--- /dev/null
+++ b/tests/cooperative-matrix/subscript-in-func.slang
@@ -0,0 +1,34 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: float
+// CHECK-NEXT: 1.000000
+// CHECK-NEXT: 4.000000
+// CHECK-NEXT: 9.000000
+// CHECK-NEXT: 16.000000
+
+//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256):name=inputBuffer
+StructuredBuffer<float> inputBuffer;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+static const int stride = 16;
+static const CoopMatMatrixLayout matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+void squareCoopMatElements(CoopMatType mat)
+{
+ for (int i = 0; i < 4; ++i)
+ {
+ mat[i] = mat[i] * mat[i];
+ }
+ mat.store(outputBuffer, 0, stride, matrixLayout);
+}
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout);
+ squareCoopMatElements(mat);
+}
diff --git a/tests/cooperative-matrix/subscript.slang b/tests/cooperative-matrix/subscript.slang
new file mode 100644
index 000000000..731edee82
--- /dev/null
+++ b/tests/cooperative-matrix/subscript.slang
@@ -0,0 +1,23 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: int32_t
+// CHECK-NEXT: 2
+// CHECK-NEXT: 4
+// CHECK: 7
+// CHECK-NEXT: 11
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ CoopMatType mat;
+ mat[0] = 2;
+ mat[1] = mat[0]+2;
+ mat[2] = mat[1]+3;
+ mat[3] = mat[2]+4;
+ mat.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor);
+}
diff --git a/tests/cooperative-matrix/unary_neg.slang b/tests/cooperative-matrix/unary_neg.slang
new file mode 100644
index 000000000..8c6436caf
--- /dev/null
+++ b/tests/cooperative-matrix/unary_neg.slang
@@ -0,0 +1,27 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// CHECK: type: int32_t
+// CHECK-NEXT: -1
+// CHECK-NEXT: -2
+// CHECK-NEXT: -3
+// CHECK-NEXT: -4
+
+//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4, count=256),name=inputBuffer
+ByteAddressBuffer inputBuffer;
+
+//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
+RWStructuredBuffer<int32_t> outputBuffer;
+
+typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>;
+
+[numthreads(32, 1, 1)]
+void computeMain()
+{
+ let stride = 4;
+ let matrixLayout = CoopMatMatrixLayout::RowMajor;
+
+ let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout);
+
+ let result = -mat;
+ result.store(outputBuffer, 0, 4, matrixLayout);
+}