summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorDarren Wihandi <65404740+fairywreath@users.noreply.github.com>2025-04-15 15:57:45 -0600
committerGitHub <noreply@github.com>2025-04-15 14:57:45 -0700
commitd0b6a0b1ab49b5958015f31364c5ad73d9cd03eb (patch)
treee419bb3c89fa8c389eb0ccbbe8aaa29a1dcd515f /source
parenta6174ff9443507dece534aa193f8c45e8f0ce7db (diff)
Add cooperative matrix 1 support (#6565)
* initial wip for spirv * working tiled example * clean up store and load * minor fixes * fix loadAny name * add initial tests, including broken/unimplemented intrinsics * fix subscript * run tests at 16x16, remove not supported arithmetic tests * minor fixups on implementation * rename CoopMatMatrixUse * Update tests to pass validation layers locally * Add mat-mul-add test and minor fixes * Add more tests * Remove dead code * Add coopMatLoad function and tests, enforce constexpr for matrix layout * Use getVectorOrCoopMatrixElementType in place of getVectorElementType
Diffstat (limited to 'source')
-rw-r--r--source/slang/hlsl.meta.slang541
-rw-r--r--source/slang/slang-capabilities.capdef11
-rw-r--r--source/slang/slang-emit-spirv-ops.h22
-rw-r--r--source/slang/slang-emit-spirv.cpp51
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-util.cpp12
-rw-r--r--source/slang/slang-ir-util.h3
-rw-r--r--source/slang/slang-ir.cpp4
-rw-r--r--source/slang/slang-ir.h11
9 files changed, 645 insertions, 11 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index bdaa2bad0..e71997c6c 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -22009,6 +22009,546 @@ extension<T, L : IBufferDataLayout> RasterizerOrderedStructuredBuffer<T, L> : IR
}
//
+// Cooperative Matrix type
+//
+
+__intrinsic_type($(kIROp_CoopMatrixType))
+[require(cooperative_matrix)]
+struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> : IArray<T>, IArithmetic
+{
+ //
+ // Initialization
+ //
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ __init()
+ {
+ }
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ __init(T t)
+ {
+ this.fill(t);
+ }
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ __init<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other)
+ {
+ this.copyFrom(other);
+ }
+
+ [ForceInline]
+ __init(This x)
+ {
+ this = x;
+ }
+
+ // Required for `IArithmetic`.
+ [OverloadRank(-10)]
+ [ForceInline]
+ __init(int i)
+ {
+ this = CoopMat<T, S, M, N, R>(T(i));
+ }
+
+ //
+ // Simple setters
+ //
+
+ [require(cooperative_matrix)]
+ [mutating]
+ [ForceInline]
+ void fill(T t)
+ {
+ this = spirv_asm
+ {
+ result:$$CoopMat<T, S, M, N, R> = OpConstantComposite $t;
+ };
+ }
+
+ [require(cooperative_matrix)]
+ [mutating]
+ [ForceInline]
+ void copyFrom<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other)
+ {
+ if (__isFloat<T>() && __isInt<U>())
+ this = __int_to_float_cast<T>(other);
+ else if (__isInt<T>() && __isFloat<U>())
+ this = __float_to_int_cast<T>(other);
+ else if (__isFloat<T>() && __isFloat<U>())
+ this = __real_cast<T>(other);
+ else if (__isInt<T>() && __isInt<U>())
+ this = __int_cast<T>(other);
+ }
+
+ //
+ // Subscript
+ //
+
+ __intrinsic_op($(kIROp_GetElement))
+ [__NoSideEffect]
+ T __indexRead(int index);
+
+ __intrinsic_op($(kIROp_GetElementPtr))
+ [__ref]
+ [__NoSideEffect]
+ Ref<T> __indexRef(int index);
+
+ [ForceInline]
+ [__NoSideEffect]
+ int getCount()
+ {
+ return getLength();
+ }
+
+ [ForceInline]
+ [__NoSideEffect]
+ int getRowCount()
+ {
+ return M;
+ }
+
+ [ForceInline]
+ [__NoSideEffect]
+ int getColumnCount()
+ {
+ return N;
+ }
+
+ __subscript(int index) -> T
+ {
+ [__NoSideEffect]
+ [nonmutating]
+ get
+ {
+ return __indexRead(index);
+ }
+
+ [mutating]
+ set
+ {
+ __indexRef(index) = newValue;
+ }
+ }
+
+ /// Returns the number of components owned by each invocation.
+ [ForceInline]
+ [require(cooperative_matrix)]
+ uint getLength()
+ {
+ return spirv_asm
+ {
+ result:$$uint = OpCooperativeMatrixLengthKHR $$This;
+ };
+ }
+
+ //
+ // Store
+ //
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ void store(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ return store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ }
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ void store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let zero = 0;
+ let alignment = 16U;
+ spirv_asm
+ {
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [__NoSideEffect]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ void store(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ return spirv_asm
+ {
+ %pointer:$$T* = OpPtrAccessChain $buffer $element;
+ OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [require(cooperative_matrix)]
+ [ForceInline]
+ void store<let U : int>(__ref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ spirv_asm
+ {
+ %workgroupPointerType = OpTypePointer Workgroup $$T;
+ %pointer:%workgroupPointerType = OpAccessChain &data $element;
+ OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [ForceInline]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ void storeAny<U, let V : int>(__ref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ spirv_asm
+ {
+ %workgroupPointerType = OpTypePointer Workgroup $$U;
+ %pointer:%workgroupPointerType = OpAccessChain &data $element;
+ OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [ForceInline]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ void storeAny<U, let V : int, let L : int>(__ref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ spirv_asm
+ {
+ %workgroupPointerType = OpTypePointer Workgroup $$vector<U, L>;
+ %pointer:%workgroupPointerType = OpAccessChain &data $element;
+ OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ //
+ // Load
+ //
+
+ [__NoSideEffect]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> load(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ }
+
+ [__NoSideEffect]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> load(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout);
+ }
+
+ [__NoSideEffect]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> load(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let zero = 0;
+ let alignment = 16U;
+ return spirv_asm
+ {
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [__NoSideEffect]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> load(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let zero = 0;
+ let alignment = 16U;
+ return spirv_asm
+ {
+ %storagePointerType = OpTypePointer StorageBuffer $$T;
+ %pointer:%storagePointerType = OpAccessChain $buffer $zero $element;
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [__NoSideEffect]
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> load(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16;
+ return spirv_asm
+ {
+ %pointer:$$T* = OpPtrAccessChain $buffer $element;
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> load<let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ return spirv_asm
+ {
+ %workgroupPointerType = OpTypePointer Workgroup $$T;
+ %pointer:%workgroupPointerType = OpAccessChain &data $element;
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> loadAny<U, let V : int>(__constref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ return spirv_asm
+ {
+ %workgroupPointerType = OpTypePointer Workgroup $$U;
+ %pointer:%workgroupPointerType = OpAccessChain &data $element;
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ [ForceInline]
+ [require(cooperative_matrix)]
+ static CoopMat<T, S, M, N, R> loadAny<U, let V : int, let L : int>(__constref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+ {
+ let alignment = 16U;
+ return spirv_asm
+ {
+ %workgroupPointerType = OpTypePointer Workgroup $$vector<U, L>;
+ %pointer:%workgroupPointerType = OpAccessChain &data $element;
+ result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment;
+ };
+ }
+
+ //
+ // Arithmetic
+ //
+
+ __intrinsic_op($(kIROp_Add))
+ This add(This other);
+
+ __intrinsic_op($(kIROp_Sub))
+ This sub(This other);
+
+ __intrinsic_op($(kIROp_Mul))
+ This mul(This other);
+
+ __intrinsic_op($(kIROp_Div))
+ This div(This other);
+
+ __intrinsic_op($(kIROp_Neg))
+ This neg();
+
+ This mod(This other)
+ {
+ This ret;
+ for (int i = 0; i < getLength(); ++i)
+ {
+ ret[i] = this[i] % other[i];
+ }
+ return ret;
+ }
+
+ //
+ // Equality and ordering
+ //
+
+ bool equals(This other)
+ {
+ for (int i = 0; i < getLength(); i++)
+ {
+ if (this[i] != other[i])
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool lessThan(This other)
+ {
+ for (int i = 0; i < getLength(); i++)
+ {
+ if (this[i] < other[i])
+ {
+ return true;
+ }
+ else if (this[i] > other[i])
+ {
+ return false;
+ }
+ }
+ return false;
+ }
+
+ bool lessThanOrEquals(This other)
+ {
+ for (int i = 0; i < getLength(); i++)
+ {
+ if (this[i] < other[i])
+ {
+ return true;
+ }
+ else if (this[i] > other[i])
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+}
+
+//
+// Convenience loading functions for cooperative matrices which infer the
+// element type for structured buffers, pointers and groupshared arrays.
+//
+
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+{
+ return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+}
+
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+{
+ return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+}
+
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+{
+ return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+}
+
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+{
+ return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+}
+
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+{
+ return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout);
+}
+
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse, let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout)
+{
+ return CoopMat<T, S, M, N, R>.load(data, element, stride, matrixLayout);
+}
+
+//
+// Cooperative Matrix casting
+//
+
+__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+__intrinsic_op($(kIROp_IntCast))
+[require(cooperative_matrix)]
+CoopMat<T,S,M,N,R> __int_cast(CoopMat<U,S,M,N,R> val);
+
+__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+__intrinsic_op($(kIROp_FloatCast))
+[require(cooperative_matrix)]
+CoopMat<T,S,M,N,R> __real_cast(CoopMat<U,S,M,N,R> val);
+
+__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+__intrinsic_op($(kIROp_CastIntToFloat))
+[require(cooperative_matrix)]
+CoopMat<T,S,M,N,R> __int_to_float_cast(CoopMat<U,S,M,N,R> val);
+
+__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+__intrinsic_op($(kIROp_CastFloatToInt))
+[require(cooperative_matrix)]
+CoopMat<T,S,M,N,R> __float_to_int_cast(CoopMat<U,S,M,N,R> val);
+
+//
+// Cooperative Matrix multiplication with scalar
+//
+
+__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> operator *(CoopMat<T, S, M, N, R> lhs, const T rhs)
+{
+ return spirv_asm
+ {
+ result:$$CoopMat<T, S, M, N, R> = OpMatrixTimesScalar $lhs $rhs;
+ };
+}
+
+__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>
+[ForceInline]
+[require(cooperative_matrix)]
+CoopMat<T, S, M, N, R> operator *(const T lhs, CoopMat<T, S, M, N, R> rhs)
+{
+ return rhs * lhs;
+}
+
+//
+// Cooperative Matrix enums
+//
+
+enum CoopMatScope
+{
+ Device = 1,
+ Workgroup = 2,
+ Subgroup = 3,
+ QueueFamily = 5,
+};
+
+enum CoopMatMatrixUse
+{
+ MatrixA = 0,
+ MatrixB = 1,
+ MatrixAccumulator = 2,
+};
+
+enum CoopMatMatrixLayout
+{
+ RowMajor = 0,
+ ColumnMajor = 1,
+};
+
+enum CoopMatMatrixOperands
+{
+ None = 0x0,
+ MatrixASigned = 0x1,
+ MatrixBSigned = 0x2,
+ MatrixCSigned = 0x4,
+ MatrixResultSigned = 0x8,
+ SaturatingAccumulation = 0x10,
+};
+
+//
+// Cooperative Matrix multiply accumulate
+//
+
+[require(cooperative_matrix)]
+__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, V : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let K : int, let N : int, let RA : CoopMatMatrixUse, let RB : CoopMatMatrixUse, let RC : CoopMatMatrixUse>
+CoopMat<V, S, M, N, RC> coopMatMulAdd(CoopMat<T, S, M, K, RA> matA, CoopMat<U, S, K, N, RB> matB, CoopMat<V, S, M, N, RC> matC, constexpr CoopMatMatrixOperands operands)
+{
+ static_assert((RA == CoopMatMatrixUse::MatrixA) && (RB == CoopMatMatrixUse::MatrixB) && (RC == CoopMatMatrixUse::MatrixAccumulator), "matrix uses for `coopMatMulAdd` matrix parameters must be `MatrixA`, `MatrixB` and `MatrixAccumulator`");
+ return spirv_asm
+ {
+ result:$$CoopMat<V, S, M, N, RC> = OpCooperativeMatrixMulAddKHR $matA $matB $matC !operands;
+ };
+}
+
+//
// Cooperative Vector
//
@@ -23435,6 +23975,7 @@ CoopVec<T, N> coopVecLoadGroupshared<let N : int, T : __BuiltinArithmeticType, l
// Coop Vector matrix multiplication
//
+
/// Specifies the memory layout for matrices used in cooperative vector operations.
/// @remarks This enum defines different matrix layout options that affect how matrix data is stored and accessed,
/// including standard row-major and column-major layouts as well as specialized layouts optimized for specific operations.
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index 749a72e3b..fc1bc71a8 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -545,6 +545,10 @@ def SPV_EXT_replicated_composites : _spirv_1_0;
/// [EXT]
def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites;
+/// Represents the SPIR-V extension for SPV_KHR_cooperative_matrix.
+/// [EXT]
+def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer;
+
// SPIRV Capabilities.
/// Represents the SPIR-V capability for atomic float 32 add operations.
@@ -695,6 +699,10 @@ def spvCooperativeVectorNV : SPV_NV_cooperative_vector;
/// [EXT]
def spvCooperativeVectorTrainingNV : SPV_NV_cooperative_vector;
+/// Represents the SPIR-V capability for cooperative matrices
+/// [EXT]
+def spvCooperativeMatrixKHR : SPV_KHR_cooperative_matrix;
+
/// Represents the SPIR-V capability for maximal reconvergence.
/// [EXT]
def spvMaximalReconvergenceKHR : SPV_KHR_maximal_reconvergence;
@@ -1075,6 +1083,9 @@ alias cooperative_vector = _sm_6_8 | cpp | _cuda_sm_9_0 | spvCooperativeVectorNV
/// [Compound]
alias cooperative_vector_training = spvCooperativeVectorTrainingNV;
+/// Capabilities needed to use cooperative matrices
+alias cooperative_matrix = spvCooperativeMatrixKHR;
+
// Non-internal shader stages
//
/// Pixel shader stage
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index 880f6b083..017e58667 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -151,6 +151,28 @@ SpvInst* emitOpTypeCoopVec(IRInst* inst, const T1& componentType, const T2& comp
componentCount);
}
+template<typename T1, typename T2>
+SpvInst* emitOpTypeCoopMat(
+ IRInst* inst,
+ const T1& componentType,
+ const T2& scope,
+ const T2& rowCount,
+ const T2& columnCount,
+ const T2& matrixUse)
+{
+ static_assert(isSingular<T1>);
+ return emitInstMemoized(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ inst,
+ SpvOpTypeCooperativeMatrixKHR,
+ kResultID,
+ componentType,
+ scope,
+ rowCount,
+ columnCount,
+ matrixUse);
+}
+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeMatrix
template<typename T>
SpvInst* emitOpTypeMatrix(IRInst* inst, const T& columnType, const SpvLiteralInteger& columnCount)
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index baef62f1c..d07d587e5 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1683,6 +1683,29 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
static_cast<IRIntLit*>(coopVecType->getElementCount())->getValue(),
coopVecType);
}
+ case kIROp_CoopMatrixType:
+ {
+ requireSPIRVCapability(SpvCapabilityCooperativeMatrixKHR);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_cooperative_matrix"));
+
+ IRBuilder builder(m_irModule);
+ auto coopMatType = static_cast<IRCoopMatrixType*>(inst);
+ return emitOpTypeCoopMat(
+ coopMatType,
+ coopMatType->getElementType(),
+ emitIntConstant(
+ static_cast<IRIntLit*>(coopMatType->getScope())->getValue(),
+ builder.getIntType()),
+ emitIntConstant(
+ static_cast<IRIntLit*>(coopMatType->getRowCount())->getValue(),
+ builder.getIntType()),
+ emitIntConstant(
+ static_cast<IRIntLit*>(coopMatType->getColumnCount())->getValue(),
+ builder.getIntType()),
+ emitIntConstant(
+ static_cast<IRIntLit*>(coopMatType->getMatrixUse())->getValue(),
+ builder.getIntType()));
+ }
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(inst);
@@ -6264,7 +6287,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto baseTy = base->getDataType();
SLANG_ASSERT(
as<IRPointerLikeType>(baseTy) || as<IRArrayType>(baseTy) || as<IRVectorType>(baseTy) ||
- as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy));
+ as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy) ||
+ as<IRCoopMatrixType>(baseTy));
IRBuilder builder(m_irModule);
builder.setInsertBefore(inst);
@@ -6553,8 +6577,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto fromTypeV = inst->getOperand(0)->getDataType();
const auto toTypeV = inst->getDataType();
SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
- const auto fromType = getVectorElementType(fromTypeV);
- const auto toType = getVectorElementType(toTypeV);
+ const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV);
+ const auto toType = getVectorOrCoopMatrixElementType(toTypeV);
if (as<IRBoolType>(fromType))
{
@@ -6687,10 +6711,14 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
bool isMatrixCast = false;
if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV) ||
- as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV))
+ as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV) ||
+ // Cooperative matrices behave like vectors where arithmetic operations can be performed
+ // directly without having to loop through the matrix and performing operations on the
+ // vectors.
+ as<IRCoopMatrixType>(fromTypeV) || as<IRCoopMatrixType>(toTypeV))
{
- fromType = getVectorElementType(fromTypeV);
- toType = getVectorElementType(toTypeV);
+ fromType = getVectorOrCoopMatrixElementType(fromTypeV);
+ toType = getVectorOrCoopMatrixElementType(toTypeV);
}
else if (as<IRMatrixType>(fromTypeV) || as<IRMatrixType>(toTypeV))
{
@@ -6737,8 +6765,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto fromTypeV = inst->getOperand(0)->getDataType();
const auto toTypeV = inst->getDataType();
SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
- const auto fromType = getVectorElementType(fromTypeV);
- const auto toType = getVectorElementType(toTypeV);
+ const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV);
+ const auto toType = getVectorOrCoopMatrixElementType(toTypeV);
+
SLANG_ASSERT(isFloatingType(toType));
if (isIntegralType(fromType))
@@ -6781,8 +6810,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto fromTypeV = inst->getOperand(0)->getDataType();
const auto toTypeV = inst->getDataType();
SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
- const auto fromType = getVectorElementType(fromTypeV);
- const auto toType = getVectorElementType(toTypeV);
+ const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV);
+ const auto toType = getVectorOrCoopMatrixElementType(toTypeV);
SLANG_ASSERT(isFloatingType(fromType));
if (as<IRBoolType>(toType))
@@ -7085,7 +7114,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
UInt operandCount,
ArrayView<IRInst*> operands)
{
- IRType* elementType = getVectorElementType(operands[0]->getDataType());
+ IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType());
IRBasicType* basicType = as<IRBasicType>(elementType);
bool isFloatingPoint = false;
bool isBool = false;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index c7ed5affe..3de40d2c0 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -235,6 +235,7 @@ INST(Nop, nop, 0, 0)
INST(RayQueryType, RayQuery, 1, HOISTABLE)
INST(HitObjectType, HitObject, 0, HOISTABLE)
INST(CoopVectorType, CoopVectorType, 2, HOISTABLE)
+INST(CoopMatrixType, CoopMatrixType, 5, HOISTABLE)
// Opaque type that can be dynamically cast to other resource types.
INST(DynamicResourceType, DynamicResource, 0, HOISTABLE)
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 58bb7aaf2..4919850eb 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -23,6 +23,18 @@ IRType* getVectorElementType(IRType* type)
return vectorType->getElementType();
if (auto coopVecType = as<IRCoopVectorType>(type))
return coopVecType->getElementType();
+ if (auto coopMatType = as<IRCoopMatrixType>(type))
+ return coopMatType->getElementType();
+ return type;
+}
+
+IRType* getVectorOrCoopMatrixElementType(IRType* type)
+{
+ auto vectorElementType = getVectorElementType(type);
+ if (vectorElementType != type)
+ return vectorElementType;
+ if (auto coopMatrixType = as<IRCoopMatrixType>(type))
+ return coopMatrixType->getElementType();
return type;
}
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 549981f58..0a8bc9b1d 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -78,6 +78,9 @@ bool isComInterfaceType(IRType* type);
// If `type` is a vector, returns its element type. Otherwise, return `type`.
IRType* getVectorElementType(IRType* type);
+// If `type` is a vector or a coop matrix, returns its element type. Otherwise, return `type`.
+IRType* getVectorOrCoopMatrixElementType(IRType* type);
+
// If `type` is a matrix, returns its element type. Otherwise, return `type`.
IRType* getMatrixElementType(IRType* type);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index f75fe2f48..c105a698a 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5351,6 +5351,10 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index)
{
type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount());
}
+ else if (auto coopMatType = as<IRCoopMatrixType>(valueType))
+ {
+ type = coopMatType->getElementType();
+ }
else if (const auto basicType = as<IRBasicType>(valueType))
{
// HLSL support things like float.x, in which case we just return the base pointer.
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index dbc66c6a3..dbf2b91be 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1869,6 +1869,17 @@ struct IRCoopVectorType : IRType
IR_LEAF_ISA(CoopVectorType)
};
+struct IRCoopMatrixType : IRType
+{
+ IRType* getElementType() { return (IRType*)getOperand(0); }
+ IRInst* getScope() { return getOperand(1); }
+ IRInst* getRowCount() { return getOperand(2); }
+ IRInst* getColumnCount() { return getOperand(3); }
+ IRInst* getMatrixUse() { return getOperand(4); }
+
+ IR_LEAF_ISA(CoopMatrixType)
+};
+
bool isDefinition(IRInst* inVal);
// A structure type is represented as a parent instruction,