diff options
| author | Darren Wihandi <65404740+fairywreath@users.noreply.github.com> | 2025-04-15 15:57:45 -0600 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-15 14:57:45 -0700 |
| commit | d0b6a0b1ab49b5958015f31364c5ad73d9cd03eb (patch) | |
| tree | e419bb3c89fa8c389eb0ccbbe8aaa29a1dcd515f /source | |
| parent | a6174ff9443507dece534aa193f8c45e8f0ce7db (diff) | |
Add cooperative matrix 1 support (#6565)
* initial wip for spirv
* working tiled example
* clean up store and load
* minor fixes
* fix loadAny name
* add initial tests, including broken/unimplemented intrinsics
* fix subscript
* run tests at 16x16, remove not supported arithmetic tests
* minor fixups on implementation
* rename CoopMatMatrixUse
* Update tests to pass validation layers locally
* Add mat-mul-add test and minor fixes
* Add more tests
* Remove dead code
* Add coopMatLoad function and tests, enforce constexpr for matrix layout
* Use getVectorOrCoopMatrixElementType in place of getVectorElementType
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 541 | ||||
| -rw-r--r-- | source/slang/slang-capabilities.capdef | 11 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv-ops.h | 22 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 51 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 11 |
9 files changed, 645 insertions, 11 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index bdaa2bad0..e71997c6c 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -22009,6 +22009,546 @@ extension<T, L : IBufferDataLayout> RasterizerOrderedStructuredBuffer<T, L> : IR } // +// Cooperative Matrix type +// + +__intrinsic_type($(kIROp_CoopMatrixType)) +[require(cooperative_matrix)] +struct CoopMat<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> : IArray<T>, IArithmetic +{ + // + // Initialization + // + + [ForceInline] + [require(cooperative_matrix)] + __init() + { + } + + [ForceInline] + [require(cooperative_matrix)] + __init(T t) + { + this.fill(t); + } + + [ForceInline] + [require(cooperative_matrix)] + __init<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other) + { + this.copyFrom(other); + } + + [ForceInline] + __init(This x) + { + this = x; + } + + // Required for `IArithmetic`. + [OverloadRank(-10)] + [ForceInline] + __init(int i) + { + this = CoopMat<T, S, M, N, R>(T(i)); + } + + // + // Simple setters + // + + [require(cooperative_matrix)] + [mutating] + [ForceInline] + void fill(T t) + { + this = spirv_asm + { + result:$$CoopMat<T, S, M, N, R> = OpConstantComposite $t; + }; + } + + [require(cooperative_matrix)] + [mutating] + [ForceInline] + void copyFrom<U : __BuiltinArithmeticType>(CoopMat<U, S, M, N, R> other) + { + if (__isFloat<T>() && __isInt<U>()) + this = __int_to_float_cast<T>(other); + else if (__isInt<T>() && __isFloat<U>()) + this = __float_to_int_cast<T>(other); + else if (__isFloat<T>() && __isFloat<U>()) + this = __real_cast<T>(other); + else if (__isInt<T>() && __isInt<U>()) + this = __int_cast<T>(other); + } + + // + // Subscript + // + + __intrinsic_op($(kIROp_GetElement)) + [__NoSideEffect] + T __indexRead(int index); + + __intrinsic_op($(kIROp_GetElementPtr)) + [__ref] + [__NoSideEffect] + Ref<T> __indexRef(int index); + + [ForceInline] + [__NoSideEffect] + int getCount() + { + return getLength(); + } + + [ForceInline] + [__NoSideEffect] + int getRowCount() + { + return M; + } + + [ForceInline] + [__NoSideEffect] + int getColumnCount() + { + return N; + } + + __subscript(int index) -> T + { + [__NoSideEffect] + [nonmutating] + get + { + return __indexRead(index); + } + + [mutating] + set + { + __indexRef(index) = newValue; + } + } + + /// Returns the number of components owned by each invocation. + [ForceInline] + [require(cooperative_matrix)] + uint getLength() + { + return spirv_asm + { + result:$$uint = OpCooperativeMatrixLengthKHR $$This; + }; + } + + // + // Store + // + + [ForceInline] + [require(cooperative_matrix)] + void store(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + return store(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout); + } + + [ForceInline] + [require(cooperative_matrix)] + void store(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let zero = 0; + let alignment = 16U; + spirv_asm + { + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + void store(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + return spirv_asm + { + %pointer:$$T* = OpPtrAccessChain $buffer $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + [require(cooperative_matrix)] + [ForceInline] + void store<let U : int>(__ref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$T; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [ForceInline] + [require(cooperative_matrix)] + void storeAny<U, let V : int>(__ref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$U; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [ForceInline] + [require(cooperative_matrix)] + void storeAny<U, let V : int, let L : int>(__ref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$vector<U, L>; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + // + // Load + // + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> load(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout); + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> load(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + return load(__getEquivalentStructuredBuffer<T>(buffer), element, stride, matrixLayout); + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> load(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let zero = 0; + let alignment = 16U; + return spirv_asm + { + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> load(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let zero = 0; + let alignment = 16U; + return spirv_asm + { + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> load(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16; + return spirv_asm + { + %pointer:$$T* = OpPtrAccessChain $buffer $element; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> load<let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + return spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$T; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> loadAny<U, let V : int>(__constref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + return spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$U; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [require(cooperative_matrix)] + static CoopMat<T, S, M, N, R> loadAny<U, let V : int, let L : int>(__constref groupshared vector<U, L>[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + return spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$vector<U, L>; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + result:$$CoopMat<T, S, M, N, R> = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + // + // Arithmetic + // + + __intrinsic_op($(kIROp_Add)) + This add(This other); + + __intrinsic_op($(kIROp_Sub)) + This sub(This other); + + __intrinsic_op($(kIROp_Mul)) + This mul(This other); + + __intrinsic_op($(kIROp_Div)) + This div(This other); + + __intrinsic_op($(kIROp_Neg)) + This neg(); + + This mod(This other) + { + This ret; + for (int i = 0; i < getLength(); ++i) + { + ret[i] = this[i] % other[i]; + } + return ret; + } + + // + // Equality and ordering + // + + bool equals(This other) + { + for (int i = 0; i < getLength(); i++) + { + if (this[i] != other[i]) + { + return false; + } + } + return true; + } + + bool lessThan(This other) + { + for (int i = 0; i < getLength(); i++) + { + if (this[i] < other[i]) + { + return true; + } + else if (this[i] > other[i]) + { + return false; + } + } + return false; + } + + bool lessThanOrEquals(This other) + { + for (int i = 0; i < getLength(); i++) + { + if (this[i] < other[i]) + { + return true; + } + else if (this[i] > other[i]) + { + return false; + } + } + return true; + } +} + +// +// Convenience loading functions for cooperative matrices which infer the +// element type for structured buffers, pointers and groupshared arrays. +// + +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(StructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(RWStructuredBuffer<T> buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse>(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat<T, S, M, N, R>.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> coopMatLoad<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse, let U : int>(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat<T, S, M, N, R>.load(data, element, stride, matrixLayout); +} + +// +// Cooperative Matrix casting +// + +__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +__intrinsic_op($(kIROp_IntCast)) +[require(cooperative_matrix)] +CoopMat<T,S,M,N,R> __int_cast(CoopMat<U,S,M,N,R> val); + +__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +__intrinsic_op($(kIROp_FloatCast)) +[require(cooperative_matrix)] +CoopMat<T,S,M,N,R> __real_cast(CoopMat<U,S,M,N,R> val); + +__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +__intrinsic_op($(kIROp_CastIntToFloat)) +[require(cooperative_matrix)] +CoopMat<T,S,M,N,R> __int_to_float_cast(CoopMat<U,S,M,N,R> val); + +__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +__intrinsic_op($(kIROp_CastFloatToInt)) +[require(cooperative_matrix)] +CoopMat<T,S,M,N,R> __float_to_int_cast(CoopMat<U,S,M,N,R> val); + +// +// Cooperative Matrix multiplication with scalar +// + +__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> operator *(CoopMat<T, S, M, N, R> lhs, const T rhs) +{ + return spirv_asm + { + result:$$CoopMat<T, S, M, N, R> = OpMatrixTimesScalar $lhs $rhs; + }; +} + +__generic<T : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let N : int, let R : CoopMatMatrixUse> +[ForceInline] +[require(cooperative_matrix)] +CoopMat<T, S, M, N, R> operator *(const T lhs, CoopMat<T, S, M, N, R> rhs) +{ + return rhs * lhs; +} + +// +// Cooperative Matrix enums +// + +enum CoopMatScope +{ + Device = 1, + Workgroup = 2, + Subgroup = 3, + QueueFamily = 5, +}; + +enum CoopMatMatrixUse +{ + MatrixA = 0, + MatrixB = 1, + MatrixAccumulator = 2, +}; + +enum CoopMatMatrixLayout +{ + RowMajor = 0, + ColumnMajor = 1, +}; + +enum CoopMatMatrixOperands +{ + None = 0x0, + MatrixASigned = 0x1, + MatrixBSigned = 0x2, + MatrixCSigned = 0x4, + MatrixResultSigned = 0x8, + SaturatingAccumulation = 0x10, +}; + +// +// Cooperative Matrix multiply accumulate +// + +[require(cooperative_matrix)] +__generic<T : __BuiltinArithmeticType, U : __BuiltinArithmeticType, V : __BuiltinArithmeticType, let S : CoopMatScope, let M : int, let K : int, let N : int, let RA : CoopMatMatrixUse, let RB : CoopMatMatrixUse, let RC : CoopMatMatrixUse> +CoopMat<V, S, M, N, RC> coopMatMulAdd(CoopMat<T, S, M, K, RA> matA, CoopMat<U, S, K, N, RB> matB, CoopMat<V, S, M, N, RC> matC, constexpr CoopMatMatrixOperands operands) +{ + static_assert((RA == CoopMatMatrixUse::MatrixA) && (RB == CoopMatMatrixUse::MatrixB) && (RC == CoopMatMatrixUse::MatrixAccumulator), "matrix uses for `coopMatMulAdd` matrix parameters must be `MatrixA`, `MatrixB` and `MatrixAccumulator`"); + return spirv_asm + { + result:$$CoopMat<V, S, M, N, RC> = OpCooperativeMatrixMulAddKHR $matA $matB $matC !operands; + }; +} + +// // Cooperative Vector // @@ -23435,6 +23975,7 @@ CoopVec<T, N> coopVecLoadGroupshared<let N : int, T : __BuiltinArithmeticType, l // Coop Vector matrix multiplication // + /// Specifies the memory layout for matrices used in cooperative vector operations. /// @remarks This enum defines different matrix layout options that affect how matrix data is stored and accessed, /// including standard row-major and column-major layouts as well as specialized layouts optimized for specific operations. diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index 749a72e3b..fc1bc71a8 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -545,6 +545,10 @@ def SPV_EXT_replicated_composites : _spirv_1_0; /// [EXT] def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites; +/// Represents the SPIR-V extension for SPV_KHR_cooperative_matrix. +/// [EXT] +def SPV_KHR_cooperative_matrix : _spirv_1_6 + SPV_EXT_physical_storage_buffer; + // SPIRV Capabilities. /// Represents the SPIR-V capability for atomic float 32 add operations. @@ -695,6 +699,10 @@ def spvCooperativeVectorNV : SPV_NV_cooperative_vector; /// [EXT] def spvCooperativeVectorTrainingNV : SPV_NV_cooperative_vector; +/// Represents the SPIR-V capability for cooperative matrices +/// [EXT] +def spvCooperativeMatrixKHR : SPV_KHR_cooperative_matrix; + /// Represents the SPIR-V capability for maximal reconvergence. /// [EXT] def spvMaximalReconvergenceKHR : SPV_KHR_maximal_reconvergence; @@ -1075,6 +1083,9 @@ alias cooperative_vector = _sm_6_8 | cpp | _cuda_sm_9_0 | spvCooperativeVectorNV /// [Compound] alias cooperative_vector_training = spvCooperativeVectorTrainingNV; +/// Capabilities needed to use cooperative matrices +alias cooperative_matrix = spvCooperativeMatrixKHR; + // Non-internal shader stages // /// Pixel shader stage diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 880f6b083..017e58667 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -151,6 +151,28 @@ SpvInst* emitOpTypeCoopVec(IRInst* inst, const T1& componentType, const T2& comp componentCount); } +template<typename T1, typename T2> +SpvInst* emitOpTypeCoopMat( + IRInst* inst, + const T1& componentType, + const T2& scope, + const T2& rowCount, + const T2& columnCount, + const T2& matrixUse) +{ + static_assert(isSingular<T1>); + return emitInstMemoized( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + inst, + SpvOpTypeCooperativeMatrixKHR, + kResultID, + componentType, + scope, + rowCount, + columnCount, + matrixUse); +} + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeMatrix template<typename T> SpvInst* emitOpTypeMatrix(IRInst* inst, const T& columnType, const SpvLiteralInteger& columnCount) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index baef62f1c..d07d587e5 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1683,6 +1683,29 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex static_cast<IRIntLit*>(coopVecType->getElementCount())->getValue(), coopVecType); } + case kIROp_CoopMatrixType: + { + requireSPIRVCapability(SpvCapabilityCooperativeMatrixKHR); + ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_cooperative_matrix")); + + IRBuilder builder(m_irModule); + auto coopMatType = static_cast<IRCoopMatrixType*>(inst); + return emitOpTypeCoopMat( + coopMatType, + coopMatType->getElementType(), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getScope())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getRowCount())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getColumnCount())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast<IRIntLit*>(coopMatType->getMatrixUse())->getValue(), + builder.getIntType())); + } case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); @@ -6264,7 +6287,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto baseTy = base->getDataType(); SLANG_ASSERT( as<IRPointerLikeType>(baseTy) || as<IRArrayType>(baseTy) || as<IRVectorType>(baseTy) || - as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy)); + as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy) || + as<IRCoopMatrixType>(baseTy)); IRBuilder builder(m_irModule); builder.setInsertBefore(inst); @@ -6553,8 +6577,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); if (as<IRBoolType>(fromType)) { @@ -6687,10 +6711,14 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex bool isMatrixCast = false; if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV) || - as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV)) + as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV) || + // Cooperative matrices behave like vectors where arithmetic operations can be performed + // directly without having to loop through the matrix and performing operations on the + // vectors. + as<IRCoopMatrixType>(fromTypeV) || as<IRCoopMatrixType>(toTypeV)) { - fromType = getVectorElementType(fromTypeV); - toType = getVectorElementType(toTypeV); + fromType = getVectorOrCoopMatrixElementType(fromTypeV); + toType = getVectorOrCoopMatrixElementType(toTypeV); } else if (as<IRMatrixType>(fromTypeV) || as<IRMatrixType>(toTypeV)) { @@ -6737,8 +6765,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); + SLANG_ASSERT(isFloatingType(toType)); if (isIntegralType(fromType)) @@ -6781,8 +6810,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); SLANG_ASSERT(isFloatingType(fromType)); if (as<IRBoolType>(toType)) @@ -7085,7 +7114,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex UInt operandCount, ArrayView<IRInst*> operands) { - IRType* elementType = getVectorElementType(operands[0]->getDataType()); + IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); IRBasicType* basicType = as<IRBasicType>(elementType); bool isFloatingPoint = false; bool isBool = false; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index c7ed5affe..3de40d2c0 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -235,6 +235,7 @@ INST(Nop, nop, 0, 0) INST(RayQueryType, RayQuery, 1, HOISTABLE) INST(HitObjectType, HitObject, 0, HOISTABLE) INST(CoopVectorType, CoopVectorType, 2, HOISTABLE) +INST(CoopMatrixType, CoopMatrixType, 5, HOISTABLE) // Opaque type that can be dynamically cast to other resource types. INST(DynamicResourceType, DynamicResource, 0, HOISTABLE) diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 58bb7aaf2..4919850eb 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -23,6 +23,18 @@ IRType* getVectorElementType(IRType* type) return vectorType->getElementType(); if (auto coopVecType = as<IRCoopVectorType>(type)) return coopVecType->getElementType(); + if (auto coopMatType = as<IRCoopMatrixType>(type)) + return coopMatType->getElementType(); + return type; +} + +IRType* getVectorOrCoopMatrixElementType(IRType* type) +{ + auto vectorElementType = getVectorElementType(type); + if (vectorElementType != type) + return vectorElementType; + if (auto coopMatrixType = as<IRCoopMatrixType>(type)) + return coopMatrixType->getElementType(); return type; } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 549981f58..0a8bc9b1d 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -78,6 +78,9 @@ bool isComInterfaceType(IRType* type); // If `type` is a vector, returns its element type. Otherwise, return `type`. IRType* getVectorElementType(IRType* type); +// If `type` is a vector or a coop matrix, returns its element type. Otherwise, return `type`. +IRType* getVectorOrCoopMatrixElementType(IRType* type); + // If `type` is a matrix, returns its element type. Otherwise, return `type`. IRType* getMatrixElementType(IRType* type); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index f75fe2f48..c105a698a 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5351,6 +5351,10 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index) { type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount()); } + else if (auto coopMatType = as<IRCoopMatrixType>(valueType)) + { + type = coopMatType->getElementType(); + } else if (const auto basicType = as<IRBasicType>(valueType)) { // HLSL support things like float.x, in which case we just return the base pointer. diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index dbc66c6a3..dbf2b91be 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1869,6 +1869,17 @@ struct IRCoopVectorType : IRType IR_LEAF_ISA(CoopVectorType) }; +struct IRCoopMatrixType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + IRInst* getScope() { return getOperand(1); } + IRInst* getRowCount() { return getOperand(2); } + IRInst* getColumnCount() { return getOperand(3); } + IRInst* getMatrixUse() { return getOperand(4); } + + IR_LEAF_ISA(CoopMatrixType) +}; + bool isDefinition(IRInst* inVal); // A structure type is represented as a parent instruction, |
