diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 12 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 537 | ||||
| -rw-r--r-- | source/slang/slang-capabilities.capdef | 54 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv-ops.h | 1 |
4 files changed, 470 insertions, 134 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 5911c997c..140c9ba16 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -3478,6 +3478,18 @@ enum MemoryOrder SeqCst = $(kIRMemoryOrder_SeqCst), } +// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id +enum MemoryScope +{ + CrossDevice = 0, + Device = 1, + Workgroup = 2, + Subgroup = 3, + Invocation = 4, + QueueFamily = 5, + ShaderCallKHR = 6, +}; + /// Represents types that can be used in any atomic operations. /// Implemented by builtin scalar types: `int`, `uint`, `int64_t`, `uint64_t`, `int8_t`, `uint8_t`, `int16_t`, `uint16_t`, `float`, `double` and `half`. [sealed] interface IAtomicable {} diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 2382d4a9a..829d5ce97 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -22509,13 +22509,52 @@ extension<T, L : IBufferDataLayout> RasterizerOrderedStructuredBuffer<T, L> : IR int getCount() { uint count; uint stride; this.GetDimensions(count, stride); return count; } } +namespace linalg +{ + +// +// Cooperative Matrix enums +// + +enum CoopMatMatrixUse +{ + MatrixA = 0, + MatrixB = 1, + MatrixAccumulator = 2, +}; + +enum CoopMatMatrixLayout +{ + RowMajor = 0, + ColumnMajor = 1, +}; + +enum CoopMatClampMode +{ + Undefined, + Constant, + ClampToEdge, + Repeat, + RepeatMirrored +}; + + // // Cooperative Matrix type // __intrinsic_type($(kIROp_CoopMatrixType)) [require(cooperative_matrix)] -struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> : IArray<T>, IArithmetic +__generic< + T : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse +> +struct CoopMat + : IArray<T> + , IArithmetic { // // Initialization @@ -22535,21 +22574,25 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l } [ForceInline] - [require(cooperative_matrix)] - __init<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other) + [require(cooperative_matrix_conversion)] + __init< + U : __BuiltinArithmeticType + >(CoopMat<U, S, M, N, R> other) { this.copyFrom(other); } [ForceInline] + [require(cooperative_matrix)] __init(This x) { this = x; } // Required for `IArithmetic`. - [OverloadRank(-10)] [ForceInline] + [OverloadRank(-10)] + [require(cooperative_matrix)] __init(int i) { this = CoopMat<T, S, M, N, R>(T(i)); @@ -22559,9 +22602,11 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l // Simple setters // - [require(cooperative_matrix)] - [mutating] + /// Fills the cooperative matrix with the specified value. + /// @param t The value to fill the matrix with. [ForceInline] + [mutating] + [require(cooperative_matrix)] void fill(T t) { this = spirv_asm @@ -22570,10 +22615,15 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l }; } - [require(cooperative_matrix)] - [mutating] + /// Copies the contents from another cooperative matrix into this matrix. + /// @param U The element type of the source cooperative matrix. + /// @param other The source cooperative matrix to copy from. [ForceInline] - void copyFrom<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other) + [mutating] + [require(cooperative_matrix_conversion)] + void copyFrom< + U : __BuiltinArithmeticType + >(CoopMat<U, S, M, N, R> other) { if (__isFloat<T>() && __isInt<U>()) this = __int_to_float_cast<T>(other); @@ -22598,25 +22648,13 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l [__NoSideEffect] Ref<T> __indexRef(int index); + /// Returns the count as an integer value. [ForceInline] + [require(cooperative_matrix)] [__NoSideEffect] int getCount() { - return getLength(); - } - - [ForceInline] - [__NoSideEffect] - int getRowCount() - { - return M; - } - - [ForceInline] - [__NoSideEffect] - int getColumnCount() - { - return N; + return GetLength(); } __subscript(int index) -> T @@ -22635,14 +22673,111 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l } } - /// Returns the number of components owned by each invocation. + // + // CoopMat operations + // + + /// Returns the number of elements for the current thread. + /// Depending on the number of threads for the given matrix, each + /// thread will get smaller length. + /// + /// @remarks The return value is unlikely to be same to M * N. [ForceInline] [require(cooperative_matrix)] - uint getLength() + static uint GetLength() + { + return spirv_asm + { + result:$$uint = OpCooperativeMatrixLengthKHR $$CoopMat<T, S, M, N, R>; + }; + } + + /// Returns the number of rows in the matrix. + [ForceInline] + [__NoSideEffect] + static int GetRowCount() + { + return M; + } + + /// Returns the number of columns in the matrix. + [ForceInline] + [__NoSideEffect] + static int GetColumnCount() + { + return N; + } + + [require(cooperative_matrix_conversion)] + CoopMat<T, S, N, M, CoopMatMatrixUse.MatrixB> Transpose() { return spirv_asm { - result:$$uint = OpCooperativeMatrixLengthKHR $$This; + OpCapability CooperativeMatrixConversionsNV; + OpExtension "SPV_NV_cooperative_matrix2"; + result:$$CoopMat<T, S, N, M, CoopMatMatrixUse.MatrixB> = OpCooperativeMatrixTransposeNV $this; + }; + } + + [require(cooperative_matrix_reduction)] + CoopMat<T, S, M, RN, CoopMatMatrixUse.MatrixAccumulator> ReduceRow< + let RN : int + >(functype(T, T) -> T combineOp) + { + return spirv_asm + { + OpCapability CooperativeMatrixReductionsNV; + OpExtension "SPV_NV_cooperative_matrix2"; + result:$$CoopMat<T, S, M, RN, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixReduceNV $this Row $combineOp; + }; + } + + [require(cooperative_matrix_reduction)] + CoopMat<T, S, RM, N, CoopMatMatrixUse.MatrixAccumulator> ReduceColumn< + let RM : int + >(functype(T, T) -> T combineOp) + { + return spirv_asm + { + OpCapability CooperativeMatrixReductionsNV; + OpExtension "SPV_NV_cooperative_matrix2"; + result:$$CoopMat<T, S, RM, N, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixReduceNV $this Column $combineOp; + }; + } + + [require(cooperative_matrix_reduction)] + CoopMat<T, S, RM, RN, CoopMatMatrixUse.MatrixAccumulator> ReduceRowAndColumn< + let RM : int, + let RN : int + >(functype(T, T) -> T combineOp) + { + return spirv_asm + { + OpCapability CooperativeMatrixReductionsNV; + OpExtension "SPV_NV_cooperative_matrix2"; + result:$$CoopMat<T, S, RM, RN, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixReduceNV $this Row|Column $combineOp; + }; + } + + [require(cooperative_matrix_reduction)] + CoopMat<T, S, M / 2, N / 2, CoopMatMatrixUse.MatrixAccumulator> Reduce2x2(functype(T, T)->T combineOp) + { + return spirv_asm + { + OpCapability CooperativeMatrixReductionsNV; + OpExtension "SPV_NV_cooperative_matrix2"; + result:$$CoopMat<T, S, M / 2, N / 2, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixReduceNV $this 0x4 $combineOp; + }; + } + + [require(cooperative_matrix_map_element)] + This MapElement(functype(uint32_t, uint32_t, T)->T mapOp) + { + return spirv_asm + { + OpCapability CooperativeMatrixPerElementOperationsNV; + OpExtension "SPV_NV_cooperative_matrix2"; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixPerElementOpNV $this $mapOp; }; } @@ -22650,16 +22785,24 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l // Store // - [ForceInline] [require(cooperative_matrix)] - void store(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + void Store< + let matrixLayout : CoopMatMatrixLayout + >(RWStructuredBuffer<T> buffer, uint element, uint stride) { - return store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout); + __Store(buffer, element, stride, matrixLayout); + } + + [require(cooperative_matrix)] + void Store< + let matrixLayout : CoopMatMatrixLayout + >(RWByteAddressBuffer buffer, uint element, uint stride) + { + __Store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout); } - [ForceInline] [require(cooperative_matrix)] - void store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + internal void __Store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) { let zero = 0; let alignment = 16U; @@ -22671,10 +22814,10 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l }; } - [__NoSideEffect] - [ForceInline] [require(cooperative_matrix)] - void store(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + void Store< + let matrixLayout : CoopMatMatrixLayout + >(T* buffer, uint element, uint stride) { let alignment = 16U; return spirv_asm @@ -22684,9 +22827,12 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l }; } - [require(cooperative_matrix)] [ForceInline] - void store<let U : int>(__ref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + [require(cooperative_matrix)] + void Store< + let matrixLayout : CoopMatMatrixLayout, + let V : int + >(__ref groupshared T[V] data, uint element, uint stride) { let alignment = 16U; spirv_asm @@ -22698,9 +22844,12 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l } [ForceInline] - [ForceInline] [require(cooperative_matrix)] - void storeAny<U, let V : int>(__ref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + void Store< + let matrixLayout : CoopMatMatrixLayout, + U, + let V : int + >(__ref groupshared U[V] data, uint element, uint stride) { let alignment = 16U; spirv_asm @@ -22712,9 +22861,13 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l } [ForceInline] - [ForceInline] [require(cooperative_matrix)] - void storeAny<U, let V : int, let L : int>(__ref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + void Store< + let matrixLayout : CoopMatMatrixLayout, + U, + let V : int, + let L : int + >(__ref groupshared vector<U, L>[V] data, uint element, uint stride) { let alignment = 16U; spirv_asm @@ -22725,30 +22878,34 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l }; } + // // Load // [__NoSideEffect] - [ForceInline] [require(cooperative_matrix)] - static CoopMat<T, S, M, N, R> load(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + static This Load< + let matrixLayout : CoopMatMatrixLayout + >(ByteAddressBuffer buffer, uint element, uint stride) { - return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout); + return Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride); } [__NoSideEffect] - [ForceInline] [require(cooperative_matrix)] - static CoopMat<T, S, M, N, R> load(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + static This Load< + let matrixLayout : CoopMatMatrixLayout + >(RWByteAddressBuffer buffer, uint element, uint stride) { - return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout); + return Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride); } [__NoSideEffect] - [ForceInline] [require(cooperative_matrix)] - static CoopMat<T, S, M, N, R> load(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + static This Load< + let matrixLayout : CoopMatMatrixLayout + >(StructuredBuffer<T> buffer, uint element, uint stride) { let zero = 0; let alignment = 16U; @@ -22761,9 +22918,10 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l } [__NoSideEffect] - [ForceInline] [require(cooperative_matrix)] - static CoopMat<T, S, M, N, R> load(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + static This Load< + let matrixLayout : CoopMatMatrixLayout + >(RWStructuredBuffer<T> buffer, uint element, uint stride) { let zero = 0; let alignment = 16U; @@ -22775,10 +22933,12 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l }; } - [__NoSideEffect] [ForceInline] + [__NoSideEffect] [require(cooperative_matrix)] - static CoopMat<T, S, M, N, R> load(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + static This Load< + let matrixLayout : CoopMatMatrixLayout + >(T* buffer, uint element, uint stride) { let alignment = 16; return spirv_asm @@ -22790,7 +22950,10 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l [ForceInline] [require(cooperative_matrix)] - static CoopMat<T, S, M, N, R> load<let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + static This Load< + let matrixLayout : CoopMatMatrixLayout, + let V : int + >(__constref groupshared T[V] data, uint element, uint stride) { let alignment = 16U; return spirv_asm @@ -22803,7 +22966,11 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l [ForceInline] [require(cooperative_matrix)] - static CoopMat<T, S, M, N, R> loadAny<U, let V : int>(__constref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + static This Load< + let matrixLayout : CoopMatMatrixLayout, + U, + let V : int + >(__constref groupshared U[V] data, uint element, uint stride) { let alignment = 16U; return spirv_asm @@ -22816,7 +22983,12 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l [ForceInline] [require(cooperative_matrix)] - static CoopMat<T, S, M, N, R> loadAny<U, let V : int, let L : int>(__constref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + static This Load< + let matrixLayout : CoopMatMatrixLayout, + U, + let V : int, + let L : int + >(__constref groupshared vector<U, L>[V] data, uint element, uint stride) { let alignment = 16U; return spirv_asm @@ -22849,7 +23021,7 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l This mod(This other) { This ret; - for (int i = 0; i < getLength(); ++i) + for (int i = 0; i < GetLength(); ++i) { ret[i] = this[i] % other[i]; } @@ -22862,7 +23034,7 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l bool equals(This other) { - for (int i = 0; i < getLength(); i++) + for (int i = 0; i < GetLength(); i++) { if (this[i] != other[i]) { @@ -22874,7 +23046,7 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l bool lessThan(This other) { - for (int i = 0; i < getLength(); i++) + for (int i = 0; i < GetLength(); i++) { if (this[i] < other[i]) { @@ -22890,7 +23062,7 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l bool lessThanOrEquals(This other) { - for (int i = 0; i < getLength(); i++) + for (int i = 0; i < GetLength(); i++) { if (this[i] < other[i]) { @@ -22903,7 +23075,9 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l } return true; } -} + +} // struct CoopMat + // // Convenience loading functions for cooperative matrices which infer the @@ -22912,78 +23086,168 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l [ForceInline] [require(cooperative_matrix)] -CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +CoopMat<T, S, M, N, R> coopMatLoad< + T : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse, + let matrixLayout : CoopMatMatrixLayout +>( + ByteAddressBuffer buffer, + uint element, + uint stride) { - return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); + return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride); } [ForceInline] [require(cooperative_matrix)] -CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +CoopMat<T, S, M, N, R> coopMatLoad< + T : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse, + let matrixLayout : CoopMatMatrixLayout +>( + RWByteAddressBuffer buffer, + uint element, + uint stride) { - return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); + return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride); } [ForceInline] [require(cooperative_matrix)] -CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +CoopMat<T, S, M, N, R> coopMatLoad< + T : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse, + let matrixLayout : CoopMatMatrixLayout +>( + StructuredBuffer<T> buffer, + uint element, + uint stride) { - return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); + return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride); } [ForceInline] [require(cooperative_matrix)] -CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +CoopMat<T, S, M, N, R> coopMatLoad< + T : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse, + let matrixLayout : CoopMatMatrixLayout +>( + RWStructuredBuffer<T> buffer, + uint element, + uint stride) { - return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); + return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride); } [ForceInline] [require(cooperative_matrix)] -CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +CoopMat<T, S, M, N, R> coopMatLoad< + T : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse, + let matrixLayout : CoopMatMatrixLayout +>( + T* buffer, + uint element, + uint stride) { - return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); + return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride); } [ForceInline] [require(cooperative_matrix)] -CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse, let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +CoopMat<T, S, M, N, R> coopMatLoad< + T : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse, + let matrixLayout : CoopMatMatrixLayout, + let U : int +>( + __constref groupshared T[U] data, + uint element, + uint stride) { - return CoopMat<T, S, M, N, R>.load(data, element, stride, matrixLayout); + return CoopMat<T, S, M, N, R>.Load<matrixLayout>(data, element, stride); } + // // Cooperative Matrix casting // -__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +[require(cooperative_matrix_conversion)] __intrinsic_op($(kIROp_IntCast)) -[require(cooperative_matrix)] -CoopMat<T,S,M,N,R> __int_cast(CoopMat<U,S,M,N,R> val); - -__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +CoopMat<T,S,M,N,R> __int_cast< + T : __BuiltinArithmeticType, + U : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse +>(CoopMat<U,S,M,N,R> val); + +[require(cooperative_matrix_conversion)] __intrinsic_op($(kIROp_FloatCast)) -[require(cooperative_matrix)] -CoopMat<T,S,M,N,R> __real_cast(CoopMat<U,S,M,N,R> val); - -__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +CoopMat<T,S,M,N,R> __real_cast< + T : __BuiltinArithmeticType, + U : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse +>(CoopMat<U,S,M,N,R> val); + +[require(cooperative_matrix_conversion)] __intrinsic_op($(kIROp_CastIntToFloat)) -[require(cooperative_matrix)] -CoopMat<T,S,M,N,R> __int_to_float_cast(CoopMat<U,S,M,N,R> val); - -__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +CoopMat<T,S,M,N,R> __int_to_float_cast< + T : __BuiltinArithmeticType, + U : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse +>(CoopMat<U,S,M,N,R> val); + +[require(cooperative_matrix_conversion)] __intrinsic_op($(kIROp_CastFloatToInt)) -[require(cooperative_matrix)] -CoopMat<T,S,M,N,R> __float_to_int_cast(CoopMat<U,S,M,N,R> val); +CoopMat<T,S,M,N,R> __float_to_int_cast< + T : __BuiltinArithmeticType, + U : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse +>(CoopMat<U,S,M,N,R> val); // // Cooperative Matrix multiplication with scalar // -__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> -[ForceInline] [require(cooperative_matrix)] -CoopMat<T, S, M, N, R> operator *(CoopMat<T, S, M, N, R> lhs, const T rhs) +CoopMat<T, S, M, N, R> operator *< + T : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse +>(CoopMat<T, S, M, N, R> lhs, const T rhs) { return spirv_asm { @@ -22991,64 +23255,69 @@ CoopMat<T, S, M, N, R> operator *(CoopMat<T, S, M, N, R> lhs, const T rhs) }; } -__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> -[ForceInline] [require(cooperative_matrix)] -CoopMat<T, S, M, N, R> operator *(const T lhs, CoopMat<T, S, M, N, R> rhs) +CoopMat<T, S, M, N, R> operator *< + T : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : CoopMatMatrixUse +>(const T lhs, CoopMat<T, S, M, N, R> rhs) { return rhs * lhs; } // -// Cooperative Matrix enums -// - -enum CoopMatScope -{ - Device = 1, - Workgroup = 2, - Subgroup = 3, - QueueFamily = 5, -}; - -enum CoopMatMatrixUse -{ - MatrixA = 0, - MatrixB = 1, - MatrixAccumulator = 2, -}; - -enum CoopMatMatrixLayout -{ - RowMajor = 0, - ColumnMajor = 1, -}; - -enum CoopMatMatrixOperands -{ - None = 0x0, - MatrixASigned = 0x1, - MatrixBSigned = 0x2, - MatrixCSigned = 0x4, - MatrixResultSigned = 0x8, - SaturatingAccumulation = 0x10, -}; - -// -// Cooperative Matrix multiply accumulate +// Cooperative Matrix Multiply-Accumulate // [require(cooperative_matrix)] -__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, V : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let K : int, let N : int, let RA : CoopMatMatrixUse, let RB : CoopMatMatrixUse, let RC : CoopMatMatrixUse> -CoopMat<V, S, M, N, RC> coopMatMulAdd(CoopMat<T, S, M, K, RA> matA, CoopMat<U, S, K, N, RB> matB, CoopMat<V, S, M, N, RC> matC, constexpr CoopMatMatrixOperands operands) +CoopMat<T, S, M, N, CoopMatMatrixUse.MatrixAccumulator> coopMatMulAdd< + T : __BuiltinArithmeticType, + let saturatingAccumulation : bool, + U : __BuiltinArithmeticType, + V : __BuiltinArithmeticType, + W : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let K : int, + let N : int +>( + CoopMat<U, S, M, K, CoopMatMatrixUse.MatrixA> matA, + CoopMat<V, S, K, N, CoopMatMatrixUse.MatrixB> matB, + CoopMat<W, S, M, N, CoopMatMatrixUse.MatrixAccumulator> matC) { - static_assert((RA == CoopMatMatrixUse::MatrixA) && (RB == CoopMatMatrixUse::MatrixB) && (RC == CoopMatMatrixUse::MatrixAccumulator), "matrix uses for `coopMatMulAdd` matrix parameters must be `MatrixA`, `MatrixB` and `MatrixAccumulator`"); + // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands + int operands = 0; // NoneKHR + if (__isSignedInt<U>()) + { + operands |= 0x01; // MatrixASignedComponentsKHR + } + if (__isSignedInt<V>()) + { + operands |= 0x02; // MatrixBSignedComponentsKHR + } + if (__isSignedInt<W>()) + { + operands |= 0x04; // MatrixCSignedComponentsKHR + } + if (__isSignedInt<T>()) + { + operands |= 0x08; // MatrixResultSignedComponentsKHR + } + if (saturatingAccumulation) + { + operands |= 0x10; // SaturatingAccumulationKHR + } + return spirv_asm { - result:$$CoopMat<V, S, M, N, RC> = OpCooperativeMatrixMulAddKHR $matA $matB $matC !operands; + result:$$CoopMat<T, S, M, N, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixMulAddKHR $matA $matB $matC !operands; }; } +} // namespace linalg + // // Cooperative Vector // diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index f509835aa..b909fa0f9 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -575,6 +575,18 @@ def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites; /// [EXT] def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer; +/// Represents the SPIR-V extension for SPV_NV_cooperative_matrix2. +/// [EXT] +def SPV_NV_cooperative_matrix2 : _spirv_1_6 + SPV_KHR_cooperative_matrix; + +/// Represents the SPIR-V extension for SPV_NV_tensor_addressing. +/// [EXT] +def SPV_NV_tensor_addressing : _spirv_1_6; + +/// Represents the SPIR-V extension for SPV_KHR_vulkan_memory_model. +/// [EXT] +def SPV_KHR_vulkan_memory_model : _spirv_1_3; + // SPIRV Capabilities. /// Represents the SPIR-V capability for atomic float 32 add operations. @@ -737,6 +749,30 @@ def spvCooperativeVectorTrainingNV : SPV_NV_cooperative_vector; /// [EXT] def spvCooperativeMatrixKHR : SPV_KHR_cooperative_matrix; +/// Represents the SPIR-V capability for cooperative matrix 2 +/// [EXT] +def spvCooperativeMatrixReductionsNV : SPV_NV_cooperative_matrix2; + +/// Represents the SPIR-V capability for cooperative matrix 2 +/// [EXT] +def spvCooperativeMatrixConversionsNV : SPV_NV_cooperative_matrix2; + +/// Represents the SPIR-V capability for cooperative matrix 2 +/// [EXT] +def spvCooperativeMatrixPerElementOperationsNV : SPV_NV_cooperative_matrix2; + +/// Represents the SPIR-V capability for cooperative matrix 2 +/// [EXT] +def spvCooperativeMatrixTensorAddressingNV : SPV_NV_cooperative_matrix2; + +/// Represents the SPIR-V capability for cooperative matrix 2 +/// [EXT] +def spvCooperativeMatrixBlockLoadsNV : SPV_NV_cooperative_matrix2; + +/// Represents the SPIR-V capability for tensor addressing +/// [EXT] +def spvTensorAddressingNV : SPV_NV_tensor_addressing; + /// Represents the SPIR-V capability for maximal reconvergence. /// [EXT] def spvMaximalReconvergenceKHR : SPV_KHR_maximal_reconvergence; @@ -1129,6 +1165,24 @@ alias cooperative_vector_training = spvCooperativeVectorTrainingNV; /// Capabilities needed to use cooperative matrices alias cooperative_matrix = spvCooperativeMatrixKHR; +/// Capabilities needed to use reduction operations with cooperative matrix +/// [Compound] +alias cooperative_matrix_reduction = spvCooperativeMatrixReductionsNV; +/// Capabilities needed to convert cooperative matrices +/// [Compound] +alias cooperative_matrix_conversion = spvCooperativeMatrixConversionsNV; +/// Capabilities needed to use MapElement operation with cooperative matrix +/// [Compound] +alias cooperative_matrix_map_element = spvCooperativeMatrixPerElementOperationsNV; +/// Capabilities needed to load or store with tensor_addressing extension +/// [Compound] +alias cooperative_matrix_tensor_addressing = spvCooperativeMatrixTensorAddressingNV; +/// Capabilities needed to use decodeFunc with cooperative matrix load +/// [Compound] +alias cooperative_matrix_block_load = spvCooperativeMatrixBlockLoadsNV; +/// Capabilities needed to use tensor addressing +/// [Compound] +alias tensor_addressing = spvTensorAddressingNV; // Non-internal shader stages // diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 017e58667..b716e470d 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -151,6 +151,7 @@ SpvInst* emitOpTypeCoopVec(IRInst* inst, const T1& componentType, const T2& comp componentCount); } +// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_cooperative_matrix.html#OpTypeCooperativeMatrixNV template<typename T1, typename T2> SpvInst* emitOpTypeCoopMat( IRInst* inst, |
