diff options
| author | Darren Wihandi <65404740+fairywreath@users.noreply.github.com> | 2025-04-15 15:57:45 -0600 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-15 14:57:45 -0700 |
| commit | d0b6a0b1ab49b5958015f31364c5ad73d9cd03eb (patch) | |
| tree | e419bb3c89fa8c389eb0ccbbe8aaa29a1dcd515f | |
| parent | a6174ff9443507dece534aa193f8c45e8f0ce7db (diff) | |
Add cooperative matrix 1 support (#6565)
* initial wip for spirv
* working tiled example
* clean up store and load
* minor fixes
* fix loadAny name
* add initial tests, including broken/unimplemented intrinsics
* fix subscript
* run tests at 16x16, remove not supported arithmetic tests
* minor fixups on implementation
* rename CoopMatMatrixUse
* Update tests to pass validation layers locally
* Add mat-mul-add test and minor fixes
* Add more tests
* Remove dead code
* Add coopMatLoad function and tests, enforce constexpr for matrix layout
* Use getVectorOrCoopMatrixElementType in place of getVectorElementType
38 files changed, 1522 insertions, 11 deletions
diff --git a/docs/user-guide/a3-02-reference-capability-atoms.md b/docs/user-guide/a3-02-reference-capability-atoms.md index 1cb7f5bd5..d72a2768f 100644 --- a/docs/user-guide/a3-02-reference-capability-atoms.md +++ b/docs/user-guide/a3-02-reference-capability-atoms.md @@ -424,6 +424,9 @@ Extensions `SPV_NV_cooperative_vector` > Represents the SPIR-V extension for SPV_NV_cooperative_vector. +`SPV_KHR_cooperative_matrix` +> Represents the SPIR-V extension for SPV_KHR_cooperative_matrix. + `spvAtomicFloat32AddEXT` > Represents the SPIR-V capability for atomic float 32 add operations. @@ -535,6 +538,9 @@ Extensions `spvCooperativeVectorTrainingNV` > Represents the SPIR-V capability for cooperative vector training +`spvCooperativeMatrixKHR` +> Represents the SPIR-V capability for cooperative matrices + `spvMaximalReconvergenceKHR` > Represents the SPIR-V capability for maximal reconvergence. @@ -1206,6 +1212,9 @@ Other ---------------------- *Capabilities that may be deprecated* +`cooperative_matrix` +> Capabilities needed to use cooperative matrices + `SPIRV_1_0` > Use `spirv_1_0` instead diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index bdaa2bad0..e71997c6c 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -22009,6 +22009,546 @@ extension<T, L : IBufferDataLayout> RasterizerOrderedStructuredBuffer<T, L> : IR } // +// Cooperative Matrix type +// + +__intrinsic_type($(kIROp_CoopMatrixType)) +[require(cooperative_matrix)] +struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> : IArray<T>, IArithmetic +{ + // + // Initialization + // + + [ForceInline] + [require(cooperative_matrix)] + __init() + { + } + + [ForceInline] + [require(cooperative_matrix)] + __init(T t) + { + this.fill(t); + } + + [ForceInline] + [require(cooperative_matrix)] + __init<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other) + { + this.copyFrom(other); + } + + 
[ForceInline] + __init(This x) + { + this = x; + } + + // Required for `IArithmetic`. + [OverloadRank(-10)] + [ForceInline] + __init(int i) + { + this = CoopMat<T, S, M, N, R>(T(i)); + } + + // + // Simple setters + // + + [require(cooperative_matrix)] + [mutating] + [ForceInline] + void fill(T t) + { // Splat `t` into every component. OpCompositeConstruct is used because `t` may be a runtime value, which OpConstantComposite cannot take. + this = spirv_asm + { + result:$$CoopMat<T, S, M, N, R> = OpCompositeConstruct $t; + }; + } + + [require(cooperative_matrix)] + [mutating] + [ForceInline] + void copyFrom<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other) + { + if (__isFloat<T>() && __isInt<U>()) + this = __int_to_float_cast<T>(other); + else if (__isInt<T>() && __isFloat<U>()) + this = __float_to_int_cast<T>(other); + else if (__isFloat<T>() && __isFloat<U>()) + this = __real_cast<T>(other); + else if (__isInt<T>() && __isInt<U>()) + this = __int_cast<T>(other); + } + + // + // Subscript + // + + __intrinsic_op($(kIROp_GetElement)) + [__NoSideEffect] + T __indexRead(int index); + + __intrinsic_op($(kIROp_GetElementPtr)) + [__ref] + [__NoSideEffect] + Ref<T> __indexRef(int index); + + [ForceInline] + [__NoSideEffect] + int getCount() + { + return getLength(); + } + + [ForceInline] + [__NoSideEffect] + int getRowCount() + { + return M; + } + + [ForceInline] + [__NoSideEffect] + int getColumnCount() + { + return N; + } + + __subscript(int index) -> T + { + [__NoSideEffect] + [nonmutating] + get + { + return __indexRead(index); + } + + [mutating] + set + { + __indexRef(index) = newValue; + } + } + + /// Returns the number of components owned by each invocation.
+ [ForceInline] + [require(cooperative_matrix)] + uint getLength() + { + return spirv_asm + { + result:$$uint = OpCooperativeMatrixLengthKHR $$This; + }; + } + + // + // Store + // + + [ForceInline] + [require(cooperative_matrix)] + void store(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + return store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout); + } + + [ForceInline] + [require(cooperative_matrix)] + void store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let zero = 0; + let alignment = 16U; + spirv_asm + { + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + /// Stores this matrix through `buffer` offset by `element`. Writes memory, so it must not be marked `[__NoSideEffect]`. + [ForceInline] + [require(cooperative_matrix)] + void store(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + spirv_asm + { + %pointer:$$T* = OpPtrAccessChain $buffer $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + [require(cooperative_matrix)] + [ForceInline] + void store<let U : int>(__ref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$T; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + /// Stores this matrix into a groupshared array whose element type `U` is independent of `T`. + [ForceInline] + [require(cooperative_matrix)] + void storeAny<U, let V : int>(__ref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$U; +
%pointer:%workgroupPointerType = OpAccessChain &data $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + /// Stores this matrix into a groupshared array of `vector<U, L>` elements. + [ForceInline] + [require(cooperative_matrix)] + void storeAny<U, let V : int, let L : int>(__ref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$vector<U, L>; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + // + // Load + // + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> load(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout); + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> load(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout); + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> load(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let zero = 0; + let alignment = 16U; + return spirv_asm + { + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> load(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout
matrixLayout) + { + let zero = 0; + let alignment = 16U; + return spirv_asm + { + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> load(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16; + return spirv_asm + { + %pointer:$$T* = OpPtrAccessChain $buffer $element; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> load<let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + return spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$T; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> loadAny<U, let V : int>(__constref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + return spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$U; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> loadAny<U, let V : int, let L : int>(__constref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let 
alignment = 16U; + return spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$vector<U, L>; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + // + // Arithmetic + // + + __intrinsic_op($(kIROp_Add)) + This add(This other); + + __intrinsic_op($(kIROp_Sub)) + This sub(This other); + + __intrinsic_op($(kIROp_Mul)) + This mul(This other); + + __intrinsic_op($(kIROp_Div)) + This div(This other); + + __intrinsic_op($(kIROp_Neg)) + This neg(); + + This mod(This other) + { + This ret; + for (int i = 0; i < getLength(); ++i) + { + ret[i] = this[i] % other[i]; + } + return ret; + } + + // + // Equality and ordering + // + + bool equals(This other) + { + for (int i = 0; i < getLength(); i++) + { + if (this[i] != other[i]) + { + return false; + } + } + return true; + } + + bool lessThan(This other) + { + for (int i = 0; i < getLength(); i++) + { + if (this[i] < other[i]) + { + return true; + } + else if (this[i] > other[i]) + { + return false; + } + } + return false; + } + + bool lessThanOrEquals(This other) + { + for (int i = 0; i < getLength(); i++) + { + if (this[i] < other[i]) + { + return true; + } + else if (this[i] > other[i]) + { + return false; + } + } + return true; + } +} + +// +// Convenience loading functions for cooperative matrices which infer the +// element type for structured buffers, pointers and groupshared arrays. 
+// + +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R 
: CoopMatMatrixUse, let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat<T, S, M, N, R>.load(data, element, stride, matrixLayout); +} + +// +// Cooperative Matrix casting +// + +__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +__intrinsic_op($(kIROp_IntCast)) +[require(cooperative_matrix)] +CoopMat<T,S,M,N,R> __int_cast(CoopMat<U,S,M,N,R> val); + +__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +__intrinsic_op($(kIROp_FloatCast)) +[require(cooperative_matrix)] +CoopMat<T,S,M,N,R> __real_cast(CoopMat<U,S,M,N,R> val); + +__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +__intrinsic_op($(kIROp_CastIntToFloat)) +[require(cooperative_matrix)] +CoopMat<T,S,M,N,R> __int_to_float_cast(CoopMat<U,S,M,N,R> val); + +__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +__intrinsic_op($(kIROp_CastFloatToInt)) +[require(cooperative_matrix)] +CoopMat<T,S,M,N,R> __float_to_int_cast(CoopMat<U,S,M,N,R> val); + +// +// Cooperative Matrix multiplication with scalar +// + +__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> operator *(CoopMat<T, S, M, N, R> lhs, const T rhs) +{ + return spirv_asm + { + result:$$CoopMat<T, S, M, N, R> = OpMatrixTimesScalar $lhs $rhs; + }; +} + +__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> operator *(const T lhs, CoopMat<T, S, M, N, R> rhs) +{ 
+ return rhs * lhs; +} + +// +// Cooperative Matrix enums +// + +enum CoopMatScope +{ + Device = 1, + Workgroup = 2, + Subgroup = 3, + QueueFamily = 5, +}; + +enum CoopMatMatrixUse +{ + MatrixA = 0, + MatrixB = 1, + MatrixAccumulator = 2, +}; + +enum CoopMatMatrixLayout +{ + RowMajor = 0, + ColumnMajor = 1, +}; + +enum CoopMatMatrixOperands +{ + None = 0x0, + MatrixASigned = 0x1, + MatrixBSigned = 0x2, + MatrixCSigned = 0x4, + MatrixResultSigned = 0x8, + SaturatingAccumulation = 0x10, +}; + +// +// Cooperative Matrix multiply accumulate +// + +[require(cooperative_matrix)] +__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, V : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let K : int, let N : int, let RA : CoopMatMatrixUse, let RB : CoopMatMatrixUse, let RC : CoopMatMatrixUse> +CoopMat<V, S, M, N, RC> coopMatMulAdd(CoopMat<T, S, M, K, RA> matA, CoopMat<U, S, K, N, RB> matB, CoopMat<V, S, M, N, RC> matC, constexpr CoopMatMatrixOperands operands) +{ + static_assert((RA == CoopMatMatrixUse::MatrixA) && (RB == CoopMatMatrixUse::MatrixB) && (RC == CoopMatMatrixUse::MatrixAccumulator), "matrix uses for `coopMatMulAdd` matrix parameters must be `MatrixA`, `MatrixB` and `MatrixAccumulator`"); + return spirv_asm + { + result:$$CoopMat<V, S, M, N, RC> = OpCooperativeMatrixMulAddKHR $matA $matB $matC !operands; + }; +} + +// // Cooperative Vector // @@ -23435,6 +23975,7 @@ CoopVec<T, N> coopVecLoadGroupshared<let N : int, T : __BuiltinArithmeticType, l // Coop Vector matrix multiplication // + /// Specifies the memory layout for matrices used in cooperative vector operations. /// @remarks This enum defines different matrix layout options that affect how matrix data is stored and accessed, /// including standard row-major and column-major layouts as well as specialized layouts optimized for specific operations. 
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index 749a72e3b..fc1bc71a8 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -545,6 +545,10 @@ def SPV_EXT_replicated_composites : _spirv_1_0; /// [EXT] def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites; +/// Represents the SPIR-V extension for SPV_KHR_cooperative_matrix. +/// [EXT] +def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer; + // SPIRV Capabilities. /// Represents the SPIR-V capability for atomic float 32 add operations. @@ -695,6 +699,10 @@ def spvCooperativeVectorNV : SPV_NV_cooperative_vector; /// [EXT] def spvCooperativeVectorTrainingNV : SPV_NV_cooperative_vector; +/// Represents the SPIR-V capability for cooperative matrices +/// [EXT] +def spvCooperativeMatrixKHR : SPV_KHR_cooperative_matrix; + /// Represents the SPIR-V capability for maximal reconvergence. /// [EXT] def spvMaximalReconvergenceKHR : SPV_KHR_maximal_reconvergence; @@ -1075,6 +1083,9 @@ alias cooperative_vector = _sm_6_8 | cpp | _cuda_sm_9_0 | spvCooperativeVectorNV /// [Compound] alias cooperative_vector_training = spvCooperativeVectorTrainingNV; +/// Capabilities needed to use cooperative matrices +alias cooperative_matrix = spvCooperativeMatrixKHR; + // Non-internal shader stages // /// Pixel shader stage diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 880f6b083..017e58667 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -151,6 +151,28 @@ SpvInst* emitOpTypeCoopVec(IRInst* inst, const T1& componentType, const T2& comp componentCount); } +template<typename T1, typename T2> +SpvInst* emitOpTypeCoopMat( + IRInst* inst, + const T1& componentType, + const T2& scope, + const T2& rowCount, + const T2& columnCount, + const T2& matrixUse) +{ + static_assert(isSingular<T1>); + return emitInstMemoized( + 
getSection(SpvLogicalSectionID::ConstantsAndTypes), + inst, + SpvOpTypeCooperativeMatrixKHR, + kResultID, + componentType, + scope, + rowCount, + columnCount, + matrixUse); +} + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeMatrix template<typename T> SpvInst* emitOpTypeMatrix(IRInst* inst, const T& columnType, const SpvLiteralInteger& columnCount) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index baef62f1c..d07d587e5 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1683,6 +1683,29 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex static_cast<IRIntLit*>(coopVecType->getElementCount())->getValue(), coopVecType); } + case kIROp_CoopMatrixType: + { + requireSPIRVCapability(SpvCapabilityCooperativeMatrixKHR); + ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_cooperative_matrix")); + + IRBuilder builder(m_irModule); + auto coopMatType = static_cast<IRCoopMatrixType*>(inst); + return emitOpTypeCoopMat( + coopMatType, + coopMatType->getElementType(), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getScope())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getRowCount())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getColumnCount())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getMatrixUse())->getValue(), + builder.getIntType())); + } case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); @@ -6264,7 +6287,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto baseTy = base->getDataType(); SLANG_ASSERT( as<IRPointerLikeType>(baseTy) || as<IRArrayType>(baseTy) || as<IRVectorType>(baseTy) || - as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy)); + as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy) || + 
as<IRCoopMatrixType>(baseTy)); IRBuilder builder(m_irModule); builder.setInsertBefore(inst); @@ -6553,8 +6577,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); if (as<IRBoolType>(fromType)) { @@ -6687,10 +6711,14 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex bool isMatrixCast = false; if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV) || - as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV)) + as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV) || + // Cooperative matrices behave like vectors where arithmetic operations can be performed + // directly without having to loop through the matrix and performing operations on the + // vectors. 
+ as<IRCoopMatrixType>(fromTypeV) || as<IRCoopMatrixType>(toTypeV)) { - fromType = getVectorElementType(fromTypeV); - toType = getVectorElementType(toTypeV); + fromType = getVectorOrCoopMatrixElementType(fromTypeV); + toType = getVectorOrCoopMatrixElementType(toTypeV); } else if (as<IRMatrixType>(fromTypeV) || as<IRMatrixType>(toTypeV)) { @@ -6737,8 +6765,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); + SLANG_ASSERT(isFloatingType(toType)); if (isIntegralType(fromType)) @@ -6781,8 +6810,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); SLANG_ASSERT(isFloatingType(fromType)); if (as<IRBoolType>(toType)) @@ -7085,7 +7114,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex UInt operandCount, ArrayView<IRInst*> operands) { - IRType* elementType = getVectorElementType(operands[0]->getDataType()); + IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); IRBasicType* basicType = as<IRBasicType>(elementType); bool isFloatingPoint = false; bool isBool = false; diff --git a/source/slang/slang-ir-inst-defs.h 
b/source/slang/slang-ir-inst-defs.h index c7ed5affe..3de40d2c0 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -235,6 +235,7 @@ INST(Nop, nop, 0, 0) INST(RayQueryType, RayQuery, 1, HOISTABLE) INST(HitObjectType, HitObject, 0, HOISTABLE) INST(CoopVectorType, CoopVectorType, 2, HOISTABLE) +INST(CoopMatrixType, CoopMatrixType, 5, HOISTABLE) // Opaque type that can be dynamically cast to other resource types. INST(DynamicResourceType, DynamicResource, 0, HOISTABLE) diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 58bb7aaf2..4919850eb 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -23,6 +23,18 @@ IRType* getVectorElementType(IRType* type) return vectorType->getElementType(); if (auto coopVecType = as<IRCoopVectorType>(type)) return coopVecType->getElementType(); + if (auto coopMatType = as<IRCoopMatrixType>(type)) + return coopMatType->getElementType(); + return type; +} + +IRType* getVectorOrCoopMatrixElementType(IRType* type) +{ + auto vectorElementType = getVectorElementType(type); + if (vectorElementType != type) + return vectorElementType; + if (auto coopMatrixType = as<IRCoopMatrixType>(type)) + return coopMatrixType->getElementType(); return type; } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 549981f58..0a8bc9b1d 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -78,6 +78,9 @@ bool isComInterfaceType(IRType* type); // If `type` is a vector, returns its element type. Otherwise, return `type`. IRType* getVectorElementType(IRType* type); +// If `type` is a vector or a coop matrix, returns its element type. Otherwise, return `type`. +IRType* getVectorOrCoopMatrixElementType(IRType* type); + // If `type` is a matrix, returns its element type. Otherwise, return `type`. 
IRType* getMatrixElementType(IRType* type); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index f75fe2f48..c105a698a 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5351,6 +5351,10 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index) { type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount()); } + else if (auto coopMatType = as<IRCoopMatrixType>(valueType)) + { + type = coopMatType->getElementType(); + } else if (const auto basicType = as<IRBasicType>(valueType)) { // HLSL support things like float.x, in which case we just return the base pointer. diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index dbc66c6a3..dbf2b91be 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1869,6 +1869,17 @@ struct IRCoopVectorType : IRType IR_LEAF_ISA(CoopVectorType) }; +struct IRCoopMatrixType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + IRInst* getScope() { return getOperand(1); } + IRInst* getRowCount() { return getOperand(2); } + IRInst* getColumnCount() { return getOperand(3); } + IRInst* getMatrixUse() { return getOperand(4); } + + IR_LEAF_ISA(CoopMatrixType) +}; + bool isDefinition(IRInst* inVal); // A structure type is represented as a parent instruction, diff --git a/tests/cooperative-matrix/add.slang b/tests/cooperative-matrix/add.slang new file mode 100644 index 000000000..3d8348d13 --- /dev/null +++ b/tests/cooperative-matrix/add.slang @@ -0,0 +1,32 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: int32_t +// CHECK-NEXT: 1 +// CHECK-NEXT: 3 +// CHECK-NEXT: 5 +// CHECK-NEXT: 7 + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<int32_t> outputBuffer; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4, count=256),name=input1 +ByteAddressBuffer input1; + +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4, 
count=256),name=input2 +ByteAddressBuffer input2; + +typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout); + let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout); + let result = mat1 + mat2; + + result.store(outputBuffer, 0, stride, matrixLayout); +} + diff --git a/tests/cooperative-matrix/array.slang b/tests/cooperative-matrix/array.slang new file mode 100644 index 000000000..b46c0f66b --- /dev/null +++ b/tests/cooperative-matrix/array.slang @@ -0,0 +1,36 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: float +// CHECK: 1.000000 +// CHECK-NEXT: 2.000000 +// CHECK-NEXT: 3.000000 +// CHECK-NEXT: 4.000000 +// CHECK-NEXT: 5.000000 +// CHECK-NEXT: 6.000000 +// CHECK-NEXT: 7.000000 +// CHECK-NEXT: 8.000000 + +//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=256),name=input1 +ByteAddressBuffer input1; + +//TEST_INPUT:ubuffer(data=[5.0 6.0 7.0 8.0], stride=256),name=input2 +ByteAddressBuffer input2; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + CoopMatType coopMatArray[2]; + coopMatArray[0] = CoopMatType.load(input1, 0, stride, matrixLayout); + coopMatArray[1] = CoopMatType.load(input2, 0, stride, matrixLayout); + + coopMatArray[0].store(outputBuffer, 0, stride, matrixLayout); + coopMatArray[1].store(outputBuffer, 4, stride, matrixLayout); +} diff --git a/tests/cooperative-matrix/comparison.slang b/tests/cooperative-matrix/comparison.slang new file mode
100644 index 000000000..bcf0c90ae --- /dev/null +++ b/tests/cooperative-matrix/comparison.slang @@ -0,0 +1,35 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: uint32_t +// CHECK-NEXT: 0 +// CHECK-NEXT: 1 +// CHECK-NEXT: 1 + +//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256),name=input1 +ByteAddressBuffer input1; + +//TEST_INPUT:ubuffer(data=[1.0 3.0 2.0 4.0], stride=4, count=256),name=input2 +ByteAddressBuffer input2; + +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<uint> outputBuffer; + +typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +[numthreads(32, 1, 1)] +void computeMain(uint3 threadIndex : SV_DispatchThreadID) +{ + let stride = 4; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout); + let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout); + + uint32_t equals = mat1 == mat2 ? 1 : 0; + uint32_t lessThan = mat1 < mat2 ? 1 : 0; + uint32_t lessThanOrEquals = mat1 <= mat2 ? 
1 : 0; + + outputBuffer[0] = equals; + outputBuffer[1] = lessThan; + outputBuffer[2] = lessThanOrEquals; +} diff --git a/tests/cooperative-matrix/conversion.slang b/tests/cooperative-matrix/conversion.slang new file mode 100644 index 000000000..745882ab8 --- /dev/null +++ b/tests/cooperative-matrix/conversion.slang @@ -0,0 +1,30 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: float +// CHECK-NEXT: 2.000000 +// CHECK-NEXT: 4.000000 +// CHECK-NEXT: 6.000000 +// CHECK-NEXT: 8.000000 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4, count=256),name=input +ByteAddressBuffer input; + + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let intMat = CoopMat<int, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>.load(input, 0, stride, matrixLayout); + let floatMat = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(intMat); + let uintMat = CoopMat<uint, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(intMat); + let halfMat = CoopMat<half, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(uintMat); + let floatMat2 = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(halfMat); + + let result = floatMat + floatMat2; + result.store(outputBuffer, 0, stride, matrixLayout); +} diff --git a/tests/cooperative-matrix/copyFrom.slang b/tests/cooperative-matrix/copyFrom.slang new file mode 100644 index 000000000..f7270545e --- /dev/null +++ b/tests/cooperative-matrix/copyFrom.slang @@ -0,0 +1,17 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: int32_t +// CHECK-COUNT-256: 4 + +//TEST_INPUT:ubuffer(stride=4, count = 
256):out,name=outputBuffer +RWStructuredBuffer<int32_t> outputBuffer; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let mat = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(4.0); + var result = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(0); + result.copyFrom(mat); + result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor); +} + diff --git a/tests/cooperative-matrix/diagnostics/mat-mul-add-different-scope.slang b/tests/cooperative-matrix/diagnostics/mat-mul-add-different-scope.slang new file mode 100644 index 000000000..0c4308308 --- /dev/null +++ b/tests/cooperative-matrix/diagnostics/mat-mul-add-different-scope.slang @@ -0,0 +1,20 @@ +//DIAGNOSTIC_TEST(compute):SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target spirv + +RWStructuredBuffer<float> outputBuffer; + +typealias CoopMatAType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>; +typealias CoopMatBType = CoopMat<float16_t, CoopMatScope::Workgroup, 16, 16, CoopMatMatrixUse::MatrixB>; +typealias CoopMatCType = CoopMat<float32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +// CHECK: error 39999: could not specialize generic for arguments of type + +[numthreads(32, 1, 1)] +void computeMain() +{ + let matA = CoopMatAType(3.0); + let matB = CoopMatBType(5.0); + let matC = CoopMatCType(1.0); + + const let result = coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::None); + result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor); +} diff --git a/tests/cooperative-matrix/diagnostics/mat-mul-add-incorrect-matrix-use.slang b/tests/cooperative-matrix/diagnostics/mat-mul-add-incorrect-matrix-use.slang new file mode 100644 index 000000000..5b7dc7a5b --- /dev/null +++ b/tests/cooperative-matrix/diagnostics/mat-mul-add-incorrect-matrix-use.slang @@ -0,0 +1,20 @@ +//DIAGNOSTIC_TEST(compute):SIMPLE(filecheck=CHECK): -entry computeMain -stage 
compute -target spirv + +RWStructuredBuffer<float> outputBuffer; + +typealias CoopMatAType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>; +typealias CoopMatBType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>; +typealias CoopMatCType = CoopMat<float32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +// CHECK: error 41400: static assertion failed, matrix uses for `coopMatMulAdd` matrix parameters must be `MatrixA`, `MatrixB` and `MatrixAccumulator` + +[numthreads(32, 1, 1)] +void computeMain() +{ + let matA = CoopMatAType(3.0); + let matB = CoopMatBType(5.0); + let matC = CoopMatCType(1.0); + + let result = coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::None); + result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor); +} diff --git a/tests/cooperative-matrix/div.slang b/tests/cooperative-matrix/div.slang new file mode 100644 index 000000000..29207e0e4 --- /dev/null +++ b/tests/cooperative-matrix/div.slang @@ -0,0 +1,31 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: int32_t +// CHECK-NEXT: 2 +// CHECK-NEXT: 1 +// CHECK-NEXT: 1 +// CHECK-NEXT: 0 + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<int32_t> outputBuffer; + +//TEST_INPUT:ubuffer(data=[4 3 5 2], stride=4, count=256),name=input1 +ByteAddressBuffer input1; + +//TEST_INPUT:ubuffer(data=[2 3 4 5], stride=4, count=256),name=input2 +ByteAddressBuffer input2; + +typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout); + let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout); + let result = mat1 / mat2; + + result.store(outputBuffer, 0, stride, 
matrixLayout); +} diff --git a/tests/cooperative-matrix/fill.slang b/tests/cooperative-matrix/fill.slang new file mode 100644 index 000000000..d1a46d053 --- /dev/null +++ b/tests/cooperative-matrix/fill.slang @@ -0,0 +1,16 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: int32_t +// CHECK-COUNT-256: 10 + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<int32_t> outputBuffer; + +[numthreads(32, 1, 1)] +void computeMain() +{ + var result : CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + result.fill(10); + result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor); +} + diff --git a/tests/cooperative-matrix/inout.slang b/tests/cooperative-matrix/inout.slang new file mode 100644 index 000000000..7284953b4 --- /dev/null +++ b/tests/cooperative-matrix/inout.slang @@ -0,0 +1,31 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: float +// CHECK-NEXT: 2.000000 +// CHECK-NEXT: 4.000000 +// CHECK-NEXT: 6.000000 +// CHECK-NEXT: 8.000000 + +//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256),name=input1 +ByteAddressBuffer input; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +void doubleCoopMat(inout CoopMatType mat) +{ + mat = mat * 2.0; +} + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + var mat = CoopMatType.load(input, 0, stride, matrixLayout); + doubleCoopMat(mat); + mat.store(outputBuffer, 0, stride, matrixLayout); +} diff --git a/tests/cooperative-matrix/load-store-arbitrary-array-vec.slang b/tests/cooperative-matrix/load-store-arbitrary-array-vec.slang new file mode 100644 index 
000000000..0afad3284 --- /dev/null +++ b/tests/cooperative-matrix/load-store-arbitrary-array-vec.slang @@ -0,0 +1,41 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -emit-spirv-directly + +// CHECK: 1 +// CHECK-NEXT: 2 +// CHECK-NEXT: 3 +// CHECK-NEXT: 4 +// CHECK-NEXT: 5 +// CHECK-NEXT: 6 +// CHECK-NEXT: 7 +// CHECK-NEXT: 8 +// CHECK-NEXT: 9 +// CHECK-NEXT: A +// CHECK-NEXT: B +// CHECK-NEXT: C +// CHECK-NEXT: D +// CHECK-NEXT: E +// CHECK-NEXT: F +// CHECK-NEXT: 10 + +//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16], stride=4, count=256):name=input +RWByteAddressBuffer input; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<uint32_t> outputBuffer; + +typealias CoopMatType = CoopMat<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +groupshared float3[128] tempShared; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(input, 0, stride, matrixLayout); + mat.storeAny(tempShared, 0, stride, matrixLayout); + + let result = CoopMatType.loadAny(tempShared, 0, stride, matrixLayout); + result.store(outputBuffer, 0, stride, matrixLayout); +} diff --git a/tests/cooperative-matrix/load-store-arbitrary-array.slang b/tests/cooperative-matrix/load-store-arbitrary-array.slang new file mode 100644 index 000000000..496e62387 --- /dev/null +++ b/tests/cooperative-matrix/load-store-arbitrary-array.slang @@ -0,0 +1,41 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -emit-spirv-directly + +// CHECK: 1 +// CHECK-NEXT: 2 +// CHECK-NEXT: 3 +// CHECK-NEXT: 4 +// CHECK-NEXT: 5 +// CHECK-NEXT: 6 +// CHECK-NEXT: 7 +// CHECK-NEXT: 8 +// CHECK-NEXT: 9 +// CHECK-NEXT: A +// CHECK-NEXT: B +// CHECK-NEXT: C +// CHECK-NEXT: D +// CHECK-NEXT: E +// CHECK-NEXT: F +// CHECK-NEXT: 10 + 
+//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16], stride=4, count=256):name=input +RWByteAddressBuffer input; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<uint32_t> outputBuffer; + +typealias CoopMatType = CoopMat<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +groupshared float[256] tempShared; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(input, 0, stride, matrixLayout); + mat.storeAny(tempShared, 0, stride, matrixLayout); + + let result = CoopMatType.loadAny(tempShared, 0, stride, matrixLayout); + result.store(outputBuffer, 0, stride, matrixLayout); +} diff --git a/tests/cooperative-matrix/load-store-groupshared.slang b/tests/cooperative-matrix/load-store-groupshared.slang new file mode 100644 index 000000000..c2334c0ce --- /dev/null +++ b/tests/cooperative-matrix/load-store-groupshared.slang @@ -0,0 +1,31 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -emit-spirv-directly + +// CHECK: 1 +// CHECK-NEXT: 2 +// CHECK-NEXT: 3 +// CHECK-NEXT: 4 +// CHECK-NEXT: 5 +// CHECK-NEXT: 6 +// CHECK-NEXT: 7 +// CHECK-NEXT: 8 + +//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8], stride=4, count=256):name=input +RWByteAddressBuffer input; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<uint32_t> outputBuffer; + +groupshared uint32_t[256] tempShared; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(input, 0, stride, matrixLayout); + mat.store(tempShared, 0, stride, matrixLayout); + + let result = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, 
CoopMatMatrixUse::MatrixAccumulator>(tempShared, 0, stride, matrixLayout); + result.store(outputBuffer, 0, stride, matrixLayout); +} diff --git a/tests/cooperative-matrix/load-store-rwbyteaddressbuffer.slang b/tests/cooperative-matrix/load-store-rwbyteaddressbuffer.slang new file mode 100644 index 000000000..08c90992a --- /dev/null +++ b/tests/cooperative-matrix/load-store-rwbyteaddressbuffer.slang @@ -0,0 +1,26 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: 1 +// CHECK-NEXT: 2 +// CHECK-NEXT: 3 +// CHECK-NEXT: 4 +// CHECK-NEXT: 5 +// CHECK-NEXT: 6 +// CHECK-NEXT: 7 +// CHECK-NEXT: 8 + +//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8], stride=4, count=256):name=inputBuffer +RWByteAddressBuffer inputBuffer; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWByteAddressBuffer outputBuffer; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat = coopMatLoad<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(inputBuffer, 0, stride, matrixLayout); + mat.store(outputBuffer, 0, 16, matrixLayout); +} diff --git a/tests/cooperative-matrix/load-store-rwstructuredbuffer.slang b/tests/cooperative-matrix/load-store-rwstructuredbuffer.slang new file mode 100644 index 000000000..d71634082 --- /dev/null +++ b/tests/cooperative-matrix/load-store-rwstructuredbuffer.slang @@ -0,0 +1,27 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: int32_t +// CHECK-NEXT: 1 +// CHECK-NEXT: 2 +// CHECK-NEXT: 3 +// CHECK-NEXT: 4 +// CHECK-NEXT: 5 +// CHECK-NEXT: 6 +// CHECK-NEXT: 7 +// CHECK-NEXT: 8 + +//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8], stride=4, count=256),name=buf +RWStructuredBuffer<int32_t> inputBuffer; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<int32_t> outputBuffer; + 
+[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat = coopMatLoad<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>(inputBuffer, 0, stride, matrixLayout); + mat.store(outputBuffer, 0, stride, matrixLayout); +} diff --git a/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang b/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang new file mode 100644 index 000000000..934104a28 --- /dev/null +++ b/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang @@ -0,0 +1,45 @@ +//TEST(compute):SIMPLE(filecheck=CHECK): -entry computeMain -stage compute -target spirv + +// This test checks that the correct SPIRV Cooperative Matrix Operands are emitted for OpCooperativeMatrixMulAddKHR operations +RWStructuredBuffer<int> outputBuffer1; +RWStructuredBuffer<int> outputBuffer2; +RWStructuredBuffer<int> outputBuffer3; +RWStructuredBuffer<int> outputBuffer4; +RWStructuredBuffer<int> outputBuffer5; + +typealias CoopMatAType = CoopMat<int16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>; +typealias CoopMatBType = CoopMat<int16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixB>; +typealias CoopMatCType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let matA = CoopMatAType(3); + let matB = CoopMatBType(5); + let matC = CoopMatCType(1); + + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + // CHECK: OpCooperativeMatrixMulAddKHR {{.*}} NoneKHR + coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::None).store(outputBuffer1, 0, 16, matrixLayout); + + // CHECK: OpCooperativeMatrixMulAddKHR {{.*}} MatrixASignedComponentsKHR|MatrixBSignedComponentsKHR + coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::MatrixASigned | CoopMatMatrixOperands::MatrixBSigned).store(outputBuffer2, 0, 16, matrixLayout); + + + // CHECK: 
OpCooperativeMatrixMulAddKHR {{.*}} MatrixCSignedComponentsKHR + coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::MatrixCSigned).store(outputBuffer2, 0, 16, matrixLayout); + + + // CHECK: OpCooperativeMatrixMulAddKHR {{.*}} MatrixResultSignedComponentsKHR + coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::MatrixResultSigned).store(outputBuffer3, 0, 16, matrixLayout); + + // CHECK: OpCooperativeMatrixMulAddKHR {{.*}} SaturatingAccumulationKHR + coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::SaturatingAccumulation).store(outputBuffer4, 0, 16, matrixLayout); + + let allOperands = CoopMatMatrixOperands::MatrixASigned | CoopMatMatrixOperands::MatrixBSigned | CoopMatMatrixOperands::MatrixCSigned | CoopMatMatrixOperands::MatrixResultSigned | CoopMatMatrixOperands::SaturatingAccumulation; + // CHECK: OpCooperativeMatrixMulAddKHR {{.*}} MatrixASignedComponentsKHR|MatrixBSignedComponentsKHR|MatrixCSignedComponentsKHR|MatrixResultSignedComponentsKHR|SaturatingAccumulationKHR + coopMatMulAdd(matA, matB, matC, allOperands).store(outputBuffer5, 0, 16, matrixLayout); +} + + diff --git a/tests/cooperative-matrix/mat-mul-add.slang b/tests/cooperative-matrix/mat-mul-add.slang new file mode 100644 index 000000000..aea61989d --- /dev/null +++ b/tests/cooperative-matrix/mat-mul-add.slang @@ -0,0 +1,23 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: float +// CHECK-COUNT-256: 241.0 + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typealias CoopMatAType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixA>; +typealias CoopMatBType = CoopMat<float16_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixB>; +typealias CoopMatCType = CoopMat<float32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +[numthreads(32, 1, 1)] +void computeMain() +{ + // ( 3.0 * 5.0 ) * 16 + 1.0 = 241.0 
+ let matA = CoopMatAType(3.0); + let matB = CoopMatBType(5.0); + let matC = CoopMatCType(1.0); + + let result = coopMatMulAdd(matA, matB, matC, CoopMatMatrixOperands::None); + result.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor); +} diff --git a/tests/cooperative-matrix/mod.slang b/tests/cooperative-matrix/mod.slang new file mode 100644 index 000000000..116713481 --- /dev/null +++ b/tests/cooperative-matrix/mod.slang @@ -0,0 +1,53 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -emit-spirv-directly + +// CHECK: 0 +// CHECK-NEXT: 0 +// CHECK-NEXT: 1 +// CHECK-NEXT: 2 + +// CHECK: 0 +// CHECK-NEXT: 0 +// CHECK: 1 +// CHECK-NEXT: 2 + +// CHECK: 0 +// CHECK-NEXT: 0 +// CHECK: 1 +// CHECK-NEXT: 2 + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWByteAddressBuffer outputBuffer; + +//TEST_INPUT:ubuffer(data=[4 3 5 7], stride=4, count=256),name=input1 +ByteAddressBuffer input1; + +//TEST_INPUT:ubuffer(data=[2 3 4 5], stride=4, count=256),name=input2 +ByteAddressBuffer input2; + +typealias CoopMatIntType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; +typealias CoopMatUintType = CoopMat<uint32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; +typealias CoopMatFloatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat1 = CoopMatIntType.load(input1, 0, stride, matrixLayout); + let mat2 = CoopMatIntType.load(input2, 0, stride, matrixLayout); + + let mat3 = CoopMatFloatType(mat1); + let mat4 = CoopMatFloatType(mat2); + + let mat5 = CoopMatUintType(mat1); + let mat6 = CoopMatUintType(mat2); + + let result = mat1 % mat2; + let result2 = CoopMatIntType(mat3 % mat4); + let result3 = CoopMatIntType(mat5 % mat6); + + result.store(outputBuffer, 0, stride, matrixLayout); + result2.store(outputBuffer, 16, 
stride, matrixLayout); + result3.store(outputBuffer, 32, stride, matrixLayout); +} diff --git a/tests/cooperative-matrix/mul.slang b/tests/cooperative-matrix/mul.slang new file mode 100644 index 000000000..9b9fd67af --- /dev/null +++ b/tests/cooperative-matrix/mul.slang @@ -0,0 +1,31 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: int32_t +// CHECK-NEXT: 2 +// CHECK-NEXT: 6 +// CHECK-NEXT: 12 +// CHECK-NEXT: 20 + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<int32_t> outputBuffer; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4, count=256),name=input1 +ByteAddressBuffer input1; + +//TEST_INPUT:ubuffer(data=[2 3 4 5], stride=4, count=256),name=input2 +ByteAddressBuffer input2; + +typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout); + let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout); + + let result = mat1 * mat2; + result.store(outputBuffer, 0, 4, matrixLayout); +} diff --git a/tests/cooperative-matrix/out.slang b/tests/cooperative-matrix/out.slang new file mode 100644 index 000000000..147c2dd77 --- /dev/null +++ b/tests/cooperative-matrix/out.slang @@ -0,0 +1,34 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: float +// CHECK-NEXT: 2.000000 +// CHECK-NEXT: 4.000000 +// CHECK-NEXT: 6.000000 +// CHECK-NEXT: 8.000000 + +// XXX FW: having out instead of in below actually properly creates two output buffers, nice +//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256):name=inputBuffer +StructuredBuffer<float> inputBuffer; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<float> 
outputBuffer; + +typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +void doubleCoopMat(CoopMatType mat, out CoopMatType result) +{ + result = mat * 2.0; +} + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout); + + CoopMatType result; + doubleCoopMat(mat, result); + result.store(outputBuffer, 0, stride, matrixLayout); +} diff --git a/tests/cooperative-matrix/parameter.slang b/tests/cooperative-matrix/parameter.slang new file mode 100644 index 000000000..19adf4177 --- /dev/null +++ b/tests/cooperative-matrix/parameter.slang @@ -0,0 +1,32 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: float +// CHECK-NEXT: 3.000000 +// CHECK-NEXT: 6.000000 +// CHECK-NEXT: 9.000000 +// CHECK-NEXT: 12.000000 + +//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256):name=inputBuffer +StructuredBuffer<float> inputBuffer; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +static let stride = 16; +//static let matrixLayout = CoopMatMatrixLayout::RowMajor; +static const CoopMatMatrixLayout matrixLayout = CoopMatMatrixLayout::RowMajor; + +void processCoopMat(CoopMatType mat) +{ + // XXX: hmmm, some error when matrixLayout is static let + (mat * 3.0).store(outputBuffer, 0, stride, matrixLayout); +} + +[numthreads(32, 1, 1)] +void computeMain() +{ + let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout); + processCoopMat(mat); +} diff --git a/tests/cooperative-matrix/return.slang b/tests/cooperative-matrix/return.slang new file mode 100644 index 000000000..339c9d04d --- /dev/null +++ b/tests/cooperative-matrix/return.slang @@ -0,0 +1,32 
@@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type + +// CHECK: type: float +// CHECK-NEXT: 3.000000 +// CHECK-NEXT: 6.000000 +// CHECK-NEXT: 9.000000 +// CHECK-NEXT: 12.000000 + +//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256):name=inputBuffer +StructuredBuffer<float> inputBuffer; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +CoopMatType doubleCoopmat(CoopMatType mat) +{ + return mat * 3.0; +} + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout); + + let result = doubleCoopmat(mat); + result.store(outputBuffer, 0, 4, matrixLayout); +} diff --git a/tests/cooperative-matrix/scalar-mul.slang b/tests/cooperative-matrix/scalar-mul.slang new file mode 100644 index 000000000..73f0fbbfc --- /dev/null +++ b/tests/cooperative-matrix/scalar-mul.slang @@ -0,0 +1,27 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: float +// CHECK-NEXT: 4.500000 +// CHECK-NEXT: 9.000000 +// CHECK-NEXT: 13.500000 +// CHECK-NEXT: 18.000000 + +//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256),name=inputBuffer +ByteAddressBuffer inputBuffer; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout); + + let result = mat * 4.5; + result.store(outputBuffer, 0, 4, matrixLayout); +} diff --git 
a/tests/cooperative-matrix/struct.slang b/tests/cooperative-matrix/struct.slang new file mode 100644 index 000000000..24bc2c367 --- /dev/null +++ b/tests/cooperative-matrix/struct.slang @@ -0,0 +1,42 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: float +// CHECK-NEXT: 1.000000 +// CHECK-NEXT: 2.000000 +// CHECK-NEXT: 3.000000 +// CHECK-NEXT: 4.000000 +// CHECK-NEXT: 5.000000 +// CHECK-NEXT: 6.000000 +// CHECK-NEXT: 7.000000 +// CHECK-NEXT: 8.000000 + +//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256),name=input1 +ByteAddressBuffer input1; + +//TEST_INPUT:ubuffer(data=[5.0 6.0 7.0 8.0], stride=4, count=256),name=input2 +ByteAddressBuffer input2; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +struct MyStruct +{ + CoopMatType mat1; + CoopMatType mat2; +}; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + MyStruct s; + s.mat1 = CoopMatType.load(input1, 0, stride, matrixLayout); + s.mat2 = CoopMatType.load(input2, 0, stride, matrixLayout); + + s.mat1.store(outputBuffer, 0, stride, matrixLayout); + s.mat2.store(outputBuffer, 4, stride, matrixLayout); +} diff --git a/tests/cooperative-matrix/sub.slang b/tests/cooperative-matrix/sub.slang new file mode 100644 index 000000000..7b20d7c11 --- /dev/null +++ b/tests/cooperative-matrix/sub.slang @@ -0,0 +1,31 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: int32_t +// CHECK-NEXT: 1 +// CHECK-NEXT: 1 +// CHECK-NEXT: 1 +// CHECK-NEXT: 1 + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<int32_t> outputBuffer; + +//TEST_INPUT:ubuffer(data=[2 3 4 5], stride=4, count=256),name=input1 
+ByteAddressBuffer input1; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4, count=256),name=input2 +ByteAddressBuffer input2; + +typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat1 = CoopMatType.load(input1, 0, stride, matrixLayout); + let mat2 = CoopMatType.load(input2, 0, stride, matrixLayout); + + let result = mat1 - mat2; + result.store(outputBuffer, 0, 4, matrixLayout); +} diff --git a/tests/cooperative-matrix/subscript-in-func.slang b/tests/cooperative-matrix/subscript-in-func.slang new file mode 100644 index 000000000..ef63d62de --- /dev/null +++ b/tests/cooperative-matrix/subscript-in-func.slang @@ -0,0 +1,34 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: float +// CHECK-NEXT: 1.000000 +// CHECK-NEXT: 4.000000 +// CHECK-NEXT: 9.000000 +// CHECK-NEXT: 16.000000 + +//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=4, count=256):name=inputBuffer +StructuredBuffer<float> inputBuffer; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typealias CoopMatType = CoopMat<float, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +static const int stride = 16; +static const CoopMatMatrixLayout matrixLayout = CoopMatMatrixLayout::RowMajor; + +void squareCoopMatElements(CoopMatType mat) +{ + for (int i = 0; i < 4; ++i) + { + mat[i] = mat[i] * mat[i]; + } + mat.store(outputBuffer, 0, stride, matrixLayout); +} + +[numthreads(32, 1, 1)] +void computeMain() +{ + let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout); + squareCoopMatElements(mat); +} diff --git a/tests/cooperative-matrix/subscript.slang b/tests/cooperative-matrix/subscript.slang new file mode 100644 index 000000000..731edee82 --- /dev/null +++ 
b/tests/cooperative-matrix/subscript.slang @@ -0,0 +1,23 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: int32_t +// CHECK-NEXT: 2 +// CHECK-NEXT: 4 +// CHECK: 7 +// CHECK-NEXT: 11 + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<int32_t> outputBuffer; + +typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +[numthreads(32, 1, 1)] +void computeMain() +{ + CoopMatType mat; + mat[0] = 2; + mat[1] = mat[0]+2; + mat[2] = mat[1]+3; + mat[3] = mat[2]+4; + mat.store(outputBuffer, 0, 16, CoopMatMatrixLayout::RowMajor); +} diff --git a/tests/cooperative-matrix/unary_neg.slang b/tests/cooperative-matrix/unary_neg.slang new file mode 100644 index 000000000..8c6436caf --- /dev/null +++ b/tests/cooperative-matrix/unary_neg.slang @@ -0,0 +1,27 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: int32_t +// CHECK-NEXT: -1 +// CHECK-NEXT: -2 +// CHECK-NEXT: -3 +// CHECK-NEXT: -4 + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4, count=256),name=inputBuffer +ByteAddressBuffer inputBuffer; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<int32_t> outputBuffer; + +typealias CoopMatType = CoopMat<int32_t, CoopMatScope::Subgroup, 16, 16, CoopMatMatrixUse::MatrixAccumulator>; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 4; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + let mat = CoopMatType.load(inputBuffer, 0, stride, matrixLayout); + + let result = -mat; + result.store(outputBuffer, 0, 4, matrixLayout); +} |
