summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang12
-rw-r--r--source/slang/hlsl.meta.slang537
-rw-r--r--source/slang/slang-capabilities.capdef54
-rw-r--r--source/slang/slang-emit-spirv-ops.h1
4 files changed, 470 insertions, 134 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 5911c997c..140c9ba16 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -3478,6 +3478,18 @@ enum MemoryOrder
SeqCst = $(kIRMemoryOrder_SeqCst),
}
+// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id
+enum MemoryScope
+{
+ CrossDevice = 0,
+ Device = 1,
+ Workgroup = 2,
+ Subgroup = 3,
+ Invocation = 4,
+ QueueFamily = 5,
+ ShaderCallKHR = 6,
+};
+
/// Represents types that can be used in any atomic operations.
/// Implemented by builtin scalar types: `int`, `uint`, `int64_t`, `uint64_t`, `int8_t`, `uint8_t`, `int16_t`, `uint16_t`, `float`, `double` and `half`.
[sealed] interface IAtomicable {}
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 2382d4a9a..829d5ce97 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -22509,13 +22509,52 @@ extension<T, L : IBufferDataLayout> RasterizerOrderedStructuredBuffer<T, L> : IR
int getCount() { uint count; uint stride; this.GetDimensions(count, stride); return count; }
}
+namespace linalg
+{
+
+//
+// Cooperative Matrix enums
+//
+
+enum CoopMatMatrixUse
+{
+ MatrixA = 0,
+ MatrixB = 1,
+ MatrixAccumulator = 2,
+};
+
+enum CoopMatMatrixLayout
+{
+ RowMajor = 0,
+ ColumnMajor = 1,
+};
+
+enum CoopMatClampMode
+{
+ Undefined,
+ Constant,
+ ClampToEdge,
+ Repeat,
+ RepeatMirrored
+};
+
+
//
// Cooperative Matrix type
//
__intrinsic_type($(kIROp_CoopMatrixType))
[require(cooperative_matrix)]
-struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> : IArray<T>, IArithmetic
+__generic<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>
+struct CoopMat
+ : IArray<T>
+ , IArithmetic
{
//
// Initialization
@@ -22535,21 +22574,25 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
}
[ForceInline]
- [require(cooperative_matrix)]
- __init<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other)
+ [require(cooperative_matrix_conversion)]
+ __init<
+ U : __BuiltinArithmeticType
+ >(CoopMat<U, S, M, N, R> other)
{
this.copyFrom(other);
}
[ForceInline]
+ [require(cooperative_matrix)]
__init(This x)
{
this = x;
}
// Required for `IArithmetic`.
- [OverloadRank(-10)]
[ForceInline]
+ [OverloadRank(-10)]
+ [require(cooperative_matrix)]
__init(int i)
{
this = CoopMat<T, S, M, N, R>(T(i));
@@ -22559,9 +22602,11 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
// Simple setters
//
- [require(cooperative_matrix)]
- [mutating]
+ /// Fills the cooperative matrix with the specified value.
+ /// @param t The value to fill the matrix with.
[ForceInline]
+ [mutating]
+ [require(cooperative_matrix)]
void fill(T t)
{
this = spirv_asm
@@ -22570,10 +22615,15 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
};
}
- [require(cooperative_matrix)]
- [mutating]
+ /// Copies the contents from another cooperative matrix into this matrix.
+ /// @param U The element type of the source cooperative matrix.
+ /// @param other The source cooperative matrix to copy from.
[ForceInline]
- void copyFrom<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other)
+ [mutating]
+ [require(cooperative_matrix_conversion)]
+ void copyFrom<
+ U : __BuiltinArithmeticType
+ >(CoopMat<U, S, M, N, R> other)
{
if (__isFloat<T>() && __isInt<U>())
this = __int_to_float_cast<T>(other);
@@ -22598,25 +22648,13 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
[__NoSideEffect]
Ref<T> __indexRef(int index);
+ /// Returns the count as an integer value.
[ForceInline]
+ [require(cooperative_matrix)]
[__NoSideEffect]
int getCount()
{
- return getLength();
- }
-
- [ForceInline]
- [__NoSideEffect]
- int getRowCount()
- {
- return M;
- }
-
- [ForceInline]
- [__NoSideEffect]
- int getColumnCount()
- {
- return N;
+ return GetLength();
}
__subscript(int index) -> T
@@ -22635,14 +22673,111 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
}
}
- /// Returns the number of components owned by each invocation.
+ //
+ // CoopMat operations
+ //
+
+ /// Returns the number of elements for the current thread.
+ /// Depending on the number of threads for the given matrix, each
+ /// thread will get smaller length.
+ ///
+ /// @remarks The return value is unlikely to be same to M * N.
[ForceInline]
[require(cooperative_matrix)]
- uint getLength()
+ static uint GetLength()
+ {
+ return spirv_asm
+ {
+ result:$$uint = OpCooperativeMatrixLengthKHR $$CoopMat<T, S, M, N, R>;
+ };
+ }
+
+ /// Returns the number of rows in the matrix.
+ [ForceInline]
+ [__NoSideEffect]
+ static int GetRowCount()
+ {
+ return M;
+ }
+
+ /// Returns the number of columns in the matrix.
+ [ForceInline]
+ [__NoSideEffect]
+ static int GetColumnCount()
+ {
+ return N;
+ }
+
+ [require(cooperative_matrix_conversion)]
+ CoopMat<T, S, N, M, CoopMatMatrixUse.MatrixB> Transpose()
{
return spirv_asm
{
- result:$$uint = OpCooperativeMatrixLengthKHR $$This;
+ OpCapability CooperativeMatrixConversionsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ result:$$CoopMat<T, S, N, M, CoopMatMatrixUse.MatrixB> = OpCooperativeMatrixTransposeNV $this;
+ };
+ }
+
+ [require(cooperative_matrix_reduction)]
+ CoopMat<T, S, M, RN, CoopMatMatrixUse.MatrixAccumulator> ReduceRow<
+ let RN : int
+ >(functype(T, T) -> T combineOp)
+ {
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixReductionsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ result:$$CoopMat<T, S, M, RN, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixReduceNV $this Row $combineOp;
+ };
+ }
+
+ [require(cooperative_matrix_reduction)]
+ CoopMat<T, S, RM, N, CoopMatMatrixUse.MatrixAccumulator> ReduceColumn<
+ let RM : int
+ >(functype(T, T) -> T combineOp)
+ {
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixReductionsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ result:$$CoopMat<T, S, RM, N, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixReduceNV $this Column $combineOp;
+ };
+ }
+
+ [require(cooperative_matrix_reduction)]
+ CoopMat<T, S, RM, RN, CoopMatMatrixUse.MatrixAccumulator> ReduceRowAndColumn<
+ let RM : int,
+ let RN : int
+ >(functype(T, T) -> T combineOp)
+ {
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixReductionsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ result:$$CoopMat<T, S, RM, RN, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixReduceNV $this Row|Column $combineOp;
+ };
+ }
+
+ [require(cooperative_matrix_reduction)]
+ CoopMat<T, S, M / 2, N / 2, CoopMatMatrixUse.MatrixAccumulator> Reduce2x2(functype(T, T)->T combineOp)
+ {
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixReductionsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ result:$$CoopMat<T, S, M / 2, N / 2, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixReduceNV $this 0x4 $combineOp;
+ };
+ }
+
+ [require(cooperative_matrix_map_element)]
+ This MapElement(functype(uint32_t, uint32_t, T)->T mapOp)
+ {
+ return spirv_asm
+ {
+ OpCapability CooperativeMatrixPerElementOperationsNV;
+ OpExtension "SPV_NV_cooperative_matrix2";
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixPerElementOpNV $this $mapOp;
};
}
@@ -22650,16 +22785,24 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
// Store
//
- [ForceInline]
[require(cooperative_matrix)]
- void store(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ void Store<
+ let matrixLayout : CoopMatMatrixLayout
+ >(RWStructuredBuffer<T> buffer, uint element, uint stride)
{
- return store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ __Store(buffer, element, stride, matrixLayout);
+ }
+
+ [require(cooperative_matrix)]
+ void Store<
+ let matrixLayout : CoopMatMatrixLayout
+ >(RWByteAddressBuffer buffer, uint element, uint stride)
+ {
+ __Store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
}
- [ForceInline]
[require(cooperative_matrix)]
- void store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ internal void __Store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
{
let zero = 0;
let alignment = 16U;
@@ -22671,10 +22814,10 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
};
}
- [__NoSideEffect]
- [ForceInline]
[require(cooperative_matrix)]
- void store(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ void Store<
+ let matrixLayout : CoopMatMatrixLayout
+ >(T* buffer, uint element, uint stride)
{
let alignment = 16U;
return spirv_asm
@@ -22684,9 +22827,12 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
};
}
- [require(cooperative_matrix)]
[ForceInline]
- void store<let U : int>(__ref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ [require(cooperative_matrix)]
+ void Store<
+ let matrixLayout : CoopMatMatrixLayout,
+ let V : int
+ >(__ref groupshared T[V] data, uint element, uint stride)
{
let alignment = 16U;
spirv_asm
@@ -22698,9 +22844,12 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
}
[ForceInline]
- [ForceInline]
[require(cooperative_matrix)]
- void storeAny<U, let V : int>(__ref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ void Store<
+ let matrixLayout : CoopMatMatrixLayout,
+ U,
+ let V : int
+ >(__ref groupshared U[V] data, uint element, uint stride)
{
let alignment = 16U;
spirv_asm
@@ -22712,9 +22861,13 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
}
[ForceInline]
- [ForceInline]
[require(cooperative_matrix)]
- void storeAny<U, let V : int, let L : int>(__ref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ void Store<
+ let matrixLayout : CoopMatMatrixLayout,
+ U,
+ let V : int,
+ let L : int
+ >(__ref groupshared vector<U, L>[V] data, uint element, uint stride)
{
let alignment = 16U;
spirv_asm
@@ -22725,30 +22878,34 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
};
}
+
//
// Load
//
[__NoSideEffect]
- [ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> load(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout
+ >(ByteAddressBuffer buffer, uint element, uint stride)
{
- return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ return Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
}
[__NoSideEffect]
- [ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> load(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout
+ >(RWByteAddressBuffer buffer, uint element, uint stride)
{
- return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ return Load<matrixLayout>(__getEquivalentStructuredBuffer<T>(buffer), element, stride);
}
[__NoSideEffect]
- [ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> load(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout
+ >(StructuredBuffer<T> buffer, uint element, uint stride)
{
let zero = 0;
let alignment = 16U;
@@ -22761,9 +22918,10 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
}
[__NoSideEffect]
- [ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> load(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout
+ >(RWStructuredBuffer<T> buffer, uint element, uint stride)
{
let zero = 0;
let alignment = 16U;
@@ -22775,10 +22933,12 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
};
}
- [__NoSideEffect]
[ForceInline]
+ [__NoSideEffect]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> load(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout
+ >(T* buffer, uint element, uint stride)
{
let alignment = 16;
return spirv_asm
@@ -22790,7 +22950,10 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
[ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> load<let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout,
+ let V : int
+ >(__constref groupshared T[V] data, uint element, uint stride)
{
let alignment = 16U;
return spirv_asm
@@ -22803,7 +22966,11 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
[ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> loadAny<U, let V : int>(__constref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout,
+ U,
+ let V : int
+ >(__constref groupshared U[V] data, uint element, uint stride)
{
let alignment = 16U;
return spirv_asm
@@ -22816,7 +22983,12 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
[ForceInline]
[require(cooperative_matrix)]
- static CoopMat<T, S, M, N, R> loadAny<U, let V : int, let L : int>(__constref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ static This Load<
+ let matrixLayout : CoopMatMatrixLayout,
+ U,
+ let V : int,
+ let L : int
+ >(__constref groupshared vector<U, L>[V] data, uint element, uint stride)
{
let alignment = 16U;
return spirv_asm
@@ -22849,7 +23021,7 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
This mod(This other)
{
This ret;
- for (int i = 0; i < getLength(); ++i)
+ for (int i = 0; i < GetLength(); ++i)
{
ret[i] = this[i] % other[i];
}
@@ -22862,7 +23034,7 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
bool equals(This other)
{
- for (int i = 0; i < getLength(); i++)
+ for (int i = 0; i < GetLength(); i++)
{
if (this[i] != other[i])
{
@@ -22874,7 +23046,7 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
bool lessThan(This other)
{
- for (int i = 0; i < getLength(); i++)
+ for (int i = 0; i < GetLength(); i++)
{
if (this[i] < other[i])
{
@@ -22890,7 +23062,7 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
bool lessThanOrEquals(This other)
{
- for (int i = 0; i < getLength(); i++)
+ for (int i = 0; i < GetLength(); i++)
{
if (this[i] < other[i])
{
@@ -22903,7 +23075,9 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
}
return true;
}
-}
+
+} // struct CoopMat
+
//
// Convenience loading functions for cooperative matrices which infer the
@@ -22912,78 +23086,168 @@ struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, l
[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+CoopMat<T, S, M, N, R> coopMatLoad<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse,
+ let matrixLayout : CoopMatMatrixLayout
+>(
+ ByteAddressBuffer buffer,
+ uint element,
+ uint stride)
{
- return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
}
[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+CoopMat<T, S, M, N, R> coopMatLoad<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse,
+ let matrixLayout : CoopMatMatrixLayout
+>(
+ RWByteAddressBuffer buffer,
+ uint element,
+ uint stride)
{
- return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
}
[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+CoopMat<T, S, M, N, R> coopMatLoad<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse,
+ let matrixLayout : CoopMatMatrixLayout
+>(
+ StructuredBuffer<T> buffer,
+ uint element,
+ uint stride)
{
- return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
}
[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+CoopMat<T, S, M, N, R> coopMatLoad<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse,
+ let matrixLayout : CoopMatMatrixLayout
+>(
+ RWStructuredBuffer<T> buffer,
+ uint element,
+ uint stride)
{
- return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
}
[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+CoopMat<T, S, M, N, R> coopMatLoad<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse,
+ let matrixLayout : CoopMatMatrixLayout
+>(
+ T* buffer,
+ uint element,
+ uint stride)
{
- return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(buffer, element, stride);
}
[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse, let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+CoopMat<T, S, M, N, R> coopMatLoad<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse,
+ let matrixLayout : CoopMatMatrixLayout,
+ let U : int
+>(
+ __constref groupshared T[U] data,
+ uint element,
+ uint stride)
{
- return CoopMat<T, S, M, N, R>.load(data, element, stride, matrixLayout);
+ return CoopMat<T, S, M, N, R>.Load<matrixLayout>(data, element, stride);
}
+
//
// Cooperative Matrix casting
//
-__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+[require(cooperative_matrix_conversion)]
__intrinsic_op($(kIROp_IntCast))
-[require(cooperative_matrix)]
-CoopMat<T,S,M,N,R> __int_cast(CoopMat<U,S,M,N,R> val);
-
-__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+CoopMat<T,S,M,N,R> __int_cast<
+ T : __BuiltinArithmeticType,
+ U : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>(CoopMat<U,S,M,N,R> val);
+
+[require(cooperative_matrix_conversion)]
__intrinsic_op($(kIROp_FloatCast))
-[require(cooperative_matrix)]
-CoopMat<T,S,M,N,R> __real_cast(CoopMat<U,S,M,N,R> val);
-
-__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+CoopMat<T,S,M,N,R> __real_cast<
+ T : __BuiltinArithmeticType,
+ U : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>(CoopMat<U,S,M,N,R> val);
+
+[require(cooperative_matrix_conversion)]
__intrinsic_op($(kIROp_CastIntToFloat))
-[require(cooperative_matrix)]
-CoopMat<T,S,M,N,R> __int_to_float_cast(CoopMat<U,S,M,N,R> val);
-
-__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+CoopMat<T,S,M,N,R> __int_to_float_cast<
+ T : __BuiltinArithmeticType,
+ U : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>(CoopMat<U,S,M,N,R> val);
+
+[require(cooperative_matrix_conversion)]
__intrinsic_op($(kIROp_CastFloatToInt))
-[require(cooperative_matrix)]
-CoopMat<T,S,M,N,R> __float_to_int_cast(CoopMat<U,S,M,N,R> val);
+CoopMat<T,S,M,N,R> __float_to_int_cast<
+ T : __BuiltinArithmeticType,
+ U : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>(CoopMat<U,S,M,N,R> val);
//
// Cooperative Matrix multiplication with scalar
//
-__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
-[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> operator *(CoopMat<T, S, M, N, R> lhs, const T rhs)
+CoopMat<T, S, M, N, R> operator *<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>(CoopMat<T, S, M, N, R> lhs, const T rhs)
{
return spirv_asm
{
@@ -22991,64 +23255,69 @@ CoopMat<T, S, M, N, R> operator *(CoopMat<T, S, M, N, R> lhs, const T rhs)
};
}
-__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
-[ForceInline]
[require(cooperative_matrix)]
-CoopMat<T, S, M, N, R> operator *(const T lhs, CoopMat<T, S, M, N, R> rhs)
+CoopMat<T, S, M, N, R> operator *<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : CoopMatMatrixUse
+>(const T lhs, CoopMat<T, S, M, N, R> rhs)
{
return rhs * lhs;
}
//
-// Cooperative Matrix enums
-//
-
-enum CoopMatScope
-{
- Device = 1,
- Workgroup = 2,
- Subgroup = 3,
- QueueFamily = 5,
-};
-
-enum CoopMatMatrixUse
-{
- MatrixA = 0,
- MatrixB = 1,
- MatrixAccumulator = 2,
-};
-
-enum CoopMatMatrixLayout
-{
- RowMajor = 0,
- ColumnMajor = 1,
-};
-
-enum CoopMatMatrixOperands
-{
- None = 0x0,
- MatrixASigned = 0x1,
- MatrixBSigned = 0x2,
- MatrixCSigned = 0x4,
- MatrixResultSigned = 0x8,
- SaturatingAccumulation = 0x10,
-};
-
-//
-// Cooperative Matrix multiply accumulate
+// Cooperative Matrix Multiply-Accumulate
//
[require(cooperative_matrix)]
-__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, V : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let K : int, let N : int, let RA : CoopMatMatrixUse, let RB : CoopMatMatrixUse, let RC : CoopMatMatrixUse>
-CoopMat<V, S, M, N, RC> coopMatMulAdd(CoopMat<T, S, M, K, RA> matA, CoopMat<U, S, K, N, RB> matB, CoopMat<V, S, M, N, RC> matC, constexpr CoopMatMatrixOperands operands)
+CoopMat<T, S, M, N, CoopMatMatrixUse.MatrixAccumulator> coopMatMulAdd<
+ T : __BuiltinArithmeticType,
+ let saturatingAccumulation : bool,
+ U : __BuiltinArithmeticType,
+ V : __BuiltinArithmeticType,
+ W : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let K : int,
+ let N : int
+>(
+ CoopMat<U, S, M, K, CoopMatMatrixUse.MatrixA> matA,
+ CoopMat<V, S, K, N, CoopMatMatrixUse.MatrixB> matB,
+ CoopMat<W, S, M, N, CoopMatMatrixUse.MatrixAccumulator> matC)
{
- static_assert((RA == CoopMatMatrixUse::MatrixA) && (RB == CoopMatMatrixUse::MatrixB) && (RC == CoopMatMatrixUse::MatrixAccumulator), "matrix uses for `coopMatMulAdd` matrix parameters must be `MatrixA`, `MatrixB` and `MatrixAccumulator`");
+ // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands
+ int operands = 0; // NoneKHR
+ if (__isSignedInt<U>())
+ {
+ operands |= 0x01; // MatrixASignedComponentsKHR
+ }
+ if (__isSignedInt<V>())
+ {
+ operands |= 0x02; // MatrixBSignedComponentsKHR
+ }
+ if (__isSignedInt<W>())
+ {
+ operands |= 0x04; // MatrixCSignedComponentsKHR
+ }
+ if (__isSignedInt<T>())
+ {
+ operands |= 0x08; // MatrixResultSignedComponentsKHR
+ }
+ if (saturatingAccumulation)
+ {
+ operands |= 0x10; // SaturatingAccumulationKHR
+ }
+
return spirv_asm
{
- result:$$CoopMat<V, S, M, N, RC> = OpCooperativeMatrixMulAddKHR $matA $matB $matC !operands;
+ result:$$CoopMat<T, S, M, N, CoopMatMatrixUse.MatrixAccumulator> = OpCooperativeMatrixMulAddKHR $matA $matB $matC !operands;
};
}
+} // namespace linalg
+
//
// Cooperative Vector
//
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index f509835aa..b909fa0f9 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -575,6 +575,18 @@ def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites;
/// [EXT]
def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer;
+/// Represents the SPIR-V extension for SPV_NV_cooperative_matrix2.
+/// [EXT]
+def SPV_NV_cooperative_matrix2 : _spirv_1_6 + SPV_KHR_cooperative_matrix;
+
+/// Represents the SPIR-V extension for SPV_NV_tensor_addressing.
+/// [EXT]
+def SPV_NV_tensor_addressing : _spirv_1_6;
+
+/// Represents the SPIR-V extension for SPV_KHR_vulkan_memory_model.
+/// [EXT]
+def SPV_KHR_vulkan_memory_model : _spirv_1_3;
+
// SPIRV Capabilities.
/// Represents the SPIR-V capability for atomic float 32 add operations.
@@ -737,6 +749,30 @@ def spvCooperativeVectorTrainingNV : SPV_NV_cooperative_vector;
/// [EXT]
def spvCooperativeMatrixKHR : SPV_KHR_cooperative_matrix;
+/// Represents the SPIR-V capability for cooperative matrix 2
+/// [EXT]
+def spvCooperativeMatrixReductionsNV : SPV_NV_cooperative_matrix2;
+
+/// Represents the SPIR-V capability for cooperative matrix 2
+/// [EXT]
+def spvCooperativeMatrixConversionsNV : SPV_NV_cooperative_matrix2;
+
+/// Represents the SPIR-V capability for cooperative matrix 2
+/// [EXT]
+def spvCooperativeMatrixPerElementOperationsNV : SPV_NV_cooperative_matrix2;
+
+/// Represents the SPIR-V capability for cooperative matrix 2
+/// [EXT]
+def spvCooperativeMatrixTensorAddressingNV : SPV_NV_cooperative_matrix2;
+
+/// Represents the SPIR-V capability for cooperative matrix 2
+/// [EXT]
+def spvCooperativeMatrixBlockLoadsNV : SPV_NV_cooperative_matrix2;
+
+/// Represents the SPIR-V capability for tensor addressing
+/// [EXT]
+def spvTensorAddressingNV : SPV_NV_tensor_addressing;
+
/// Represents the SPIR-V capability for maximal reconvergence.
/// [EXT]
def spvMaximalReconvergenceKHR : SPV_KHR_maximal_reconvergence;
@@ -1129,6 +1165,24 @@ alias cooperative_vector_training = spvCooperativeVectorTrainingNV;
/// Capabilities needed to use cooperative matrices
alias cooperative_matrix = spvCooperativeMatrixKHR;
+/// Capabilities needed to use reduction operations with cooperative matrix
+/// [Compound]
+alias cooperative_matrix_reduction = spvCooperativeMatrixReductionsNV;
+/// Capabilities needed to convert cooperative matrices
+/// [Compound]
+alias cooperative_matrix_conversion = spvCooperativeMatrixConversionsNV;
+/// Capabilities needed to use MapElement operation with cooperative matrix
+/// [Compound]
+alias cooperative_matrix_map_element = spvCooperativeMatrixPerElementOperationsNV;
+/// Capabilities needed to load or store with tensor_addressing extension
+/// [Compound]
+alias cooperative_matrix_tensor_addressing = spvCooperativeMatrixTensorAddressingNV;
+/// Capabilities needed to use decodeFunc with cooperative matrix load
+/// [Compound]
+alias cooperative_matrix_block_load = spvCooperativeMatrixBlockLoadsNV;
+/// Capabilities needed to use tensor addressing
+/// [Compound]
+alias tensor_addressing = spvTensorAddressingNV;
// Non-internal shader stages
//
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index 017e58667..b716e470d 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -151,6 +151,7 @@ SpvInst* emitOpTypeCoopVec(IRInst* inst, const T1& componentType, const T2& comp
componentCount);
}
+// https://github.khronos.org/SPIRV-Registry/extensions/NV/SPV_NV_cooperative_matrix.html#OpTypeCooperativeMatrixNV
template<typename T1, typename T2>
SpvInst* emitOpTypeCoopMat(
IRInst* inst,