From d0b6a0b1ab49b5958015f31364c5ad73d9cd03eb Mon Sep 17 00:00:00 2001 From: Darren Wihandi <65404740+fairywreath@users.noreply.github.com> Date: Tue, 15 Apr 2025 15:57:45 -0600 Subject: Add cooperative matrix 1 support (#6565) * initial wip for spirv * working tiled example * clean up store and load * minor fixes * fix loadAny name * add initial tests, including broken/unimplemented intrinsics * fix subscript * run tests at 16x16, remove not supported arithmetic tests * minor fixups on implementation * rename CoopMatMatrixUse * Update tests to pass validation layers locally * Add mat-mul-add test and minor fixes * Add more tests * Remove dead code * Add coopMatLoad function and tests, enforce constexpr for matrix layout * Use getVectorOrCoopMatrixElementType in place of getVectorElementType --- source/slang/hlsl.meta.slang | 541 +++++++++++++++++++++++++++++++++ source/slang/slang-capabilities.capdef | 11 + source/slang/slang-emit-spirv-ops.h | 22 ++ source/slang/slang-emit-spirv.cpp | 51 +++- source/slang/slang-ir-inst-defs.h | 1 + source/slang/slang-ir-util.cpp | 12 + source/slang/slang-ir-util.h | 3 + source/slang/slang-ir.cpp | 4 + source/slang/slang-ir.h | 11 + 9 files changed, 645 insertions(+), 11 deletions(-) (limited to 'source') diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index bdaa2bad0..e71997c6c 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -22008,6 +22008,546 @@ extension RasterizerOrderedStructuredBuffer : IR int getCount() { uint count; uint stride; this.GetDimensions(count, stride); return count; } } +// +// Cooperative Matrix type +// + +__intrinsic_type($(kIROp_CoopMatrixType)) +[require(cooperative_matrix)] +struct CoopMat : IArray, IArithmetic +{ + // + // Initialization + // + + [ForceInline] + [require(cooperative_matrix)] + __init() + { + } + + [ForceInline] + [require(cooperative_matrix)] + __init(T t) + { + this.fill(t); + } + + [ForceInline] + [require(cooperative_matrix)] + __init(CoopMat other) + { + this.copyFrom(other); + } + + [ForceInline] + __init(This x) + { + this = x; + } + + // Required for `IArithmetic`. + [OverloadRank(-10)] + [ForceInline] + __init(int i) + { + this = CoopMat(T(i)); + } + + // + // Simple setters + // + + [require(cooperative_matrix)] + [mutating] + [ForceInline] + void fill(T t) + { + this = spirv_asm + { + result:$$CoopMat = OpConstantComposite $t; + }; + } + + [require(cooperative_matrix)] + [mutating] + [ForceInline] + void copyFrom(CoopMat other) + { + if (__isFloat() && __isInt()) + this = __int_to_float_cast(other); + else if (__isInt() && __isFloat()) + this = __float_to_int_cast(other); + else if (__isFloat() && __isFloat()) + this = __real_cast(other); + else if (__isInt() && __isInt()) + this = __int_cast(other); + } + + // + // Subscript + // + + __intrinsic_op($(kIROp_GetElement)) + [__NoSideEffect] + T __indexRead(int index); + + __intrinsic_op($(kIROp_GetElementPtr)) + [__ref] + [__NoSideEffect] + Ref __indexRef(int index); + + [ForceInline] + [__NoSideEffect] + int getCount() + { + return getLength(); + } + + [ForceInline] + [__NoSideEffect] + int getRowCount() + { + return M; + } + + [ForceInline] + [__NoSideEffect] + int getColumnCount() + { + return N; + } + + __subscript(int index) -> T + { + [__NoSideEffect] + [nonmutating] + get + { + return __indexRead(index); + } + + [mutating] + set + { + __indexRef(index) = newValue; + } + } + + /// Returns the number of components owned by each invocation. + [ForceInline] + [require(cooperative_matrix)] + uint getLength() + { + return spirv_asm + { + result:$$uint = OpCooperativeMatrixLengthKHR $$This; + }; + } + + // + // Store + // + + [ForceInline] + [require(cooperative_matrix)] + void store(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + return store(__getEquivalentStructuredBuffer(buffer), element, stride, matrixLayout); + } + + [ForceInline] + [require(cooperative_matrix)] + void store(RWStructuredBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let zero = 0; + let alignment = 16U; + spirv_asm + { + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + void store(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + return spirv_asm + { + %pointer:$$T* = OpPtrAccessChain $buffer $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + [require(cooperative_matrix)] + [ForceInline] + void store(__ref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$T; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [ForceInline] + [require(cooperative_matrix)] + void storeAny(__ref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$U; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [ForceInline] + [require(cooperative_matrix)] + void storeAny(__ref groupshared vector[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$vector; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + OpCooperativeMatrixStoreKHR %pointer $this $matrixLayout $stride Aligned !alignment; + }; + } + + // + // Load + // + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat load(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + return load(__getEquivalentStructuredBuffer(buffer), element, stride, matrixLayout); + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat load(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + return load(__getEquivalentStructuredBuffer(buffer), element, stride, matrixLayout); + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat load(StructuredBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let zero = 0; + let alignment = 16U; + return spirv_asm + { + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + result:$$CoopMat = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat load(RWStructuredBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let zero = 0; + let alignment = 16U; + return spirv_asm + { + %storagePointerType = OpTypePointer StorageBuffer $$T; + %pointer:%storagePointerType = OpAccessChain $buffer $zero $element; + result:$$CoopMat = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [__NoSideEffect] + [ForceInline] + [require(cooperative_matrix)] + static CoopMat load(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16; + return spirv_asm + { + %pointer:$$T* = OpPtrAccessChain $buffer $element; + result:$$CoopMat = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [require(cooperative_matrix)] + static CoopMat load(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + return spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$T; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + result:$$CoopMat = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [require(cooperative_matrix)] + static CoopMat loadAny(__constref groupshared U[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + return spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$U; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + result:$$CoopMat = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + [ForceInline] + [require(cooperative_matrix)] + static CoopMat loadAny(__constref groupshared vector[V] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) + { + let alignment = 16U; + return spirv_asm + { + %workgroupPointerType = OpTypePointer Workgroup $$vector; + %pointer:%workgroupPointerType = OpAccessChain &data $element; + result:$$CoopMat = OpCooperativeMatrixLoadKHR %pointer $matrixLayout $stride Aligned !alignment; + }; + } + + // + // Arithmetic + // + + __intrinsic_op($(kIROp_Add)) + This add(This other); + + __intrinsic_op($(kIROp_Sub)) + This sub(This other); + + __intrinsic_op($(kIROp_Mul)) + This mul(This other); + + __intrinsic_op($(kIROp_Div)) + This div(This other); + + __intrinsic_op($(kIROp_Neg)) + This neg(); + + This mod(This other) + { + This ret; + for (int i = 0; i < getLength(); ++i) + { + ret[i] = this[i] % other[i]; + } + return ret; + } + + // + // Equality and ordering + // + + bool equals(This other) + { + for (int i = 0; i < getLength(); i++) + { + if (this[i] != other[i]) + { + return false; + } + } + return true; + } + + bool lessThan(This other) + { + for (int i = 0; i < getLength(); i++) + { + if (this[i] < other[i]) + { + return true; + } + else if (this[i] > other[i]) + { + return false; + } + } + return false; + } + + bool lessThanOrEquals(This other) + { + for (int i = 0; i < getLength(); i++) + { + if (this[i] < other[i]) + { + return true; + } + else if (this[i] > other[i]) + { + return false; + } + } + return true; + } +} + +// +// Convenience loading functions for cooperative matrices which infer the +// element type for structured buffers, pointers and groupshared arrays. +// + +[ForceInline] +[require(cooperative_matrix)] +CoopMat coopMatLoad(ByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat coopMatLoad(RWByteAddressBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat coopMatLoad(StructuredBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat coopMatLoad(RWStructuredBuffer buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat coopMatLoad(T* buffer, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat.load(buffer, element, stride, matrixLayout); +} + +[ForceInline] +[require(cooperative_matrix)] +CoopMat coopMatLoad(__constref groupshared T[U] data, uint element, uint stride, constexpr CoopMatMatrixLayout matrixLayout) +{ + return CoopMat.load(data, element, stride, matrixLayout); +} + +// +// Cooperative Matrix casting +// + +__generic +__intrinsic_op($(kIROp_IntCast)) +[require(cooperative_matrix)] +CoopMat __int_cast(CoopMat val); + +__generic +__intrinsic_op($(kIROp_FloatCast)) +[require(cooperative_matrix)] +CoopMat __real_cast(CoopMat val); + +__generic +__intrinsic_op($(kIROp_CastIntToFloat)) +[require(cooperative_matrix)] +CoopMat __int_to_float_cast(CoopMat val); + +__generic +__intrinsic_op($(kIROp_CastFloatToInt)) +[require(cooperative_matrix)] +CoopMat __float_to_int_cast(CoopMat val); + +// +// Cooperative Matrix multiplication with scalar +// + +__generic +[ForceInline] +[require(cooperative_matrix)] +CoopMat operator *(CoopMat lhs, const T rhs) +{ + return spirv_asm + { + result:$$CoopMat = OpMatrixTimesScalar $lhs $rhs; + }; +} + +__generic +[ForceInline] +[require(cooperative_matrix)] +CoopMat operator *(const T lhs, CoopMat rhs) +{ + return rhs * lhs; +} + +// +// Cooperative Matrix enums +// + +enum CoopMatScope +{ + Device = 1, + Workgroup = 2, + Subgroup = 3, + QueueFamily = 5, +}; + +enum CoopMatMatrixUse +{ + MatrixA = 0, + MatrixB = 1, + MatrixAccumulator = 2, +}; + +enum CoopMatMatrixLayout +{ + RowMajor = 0, + ColumnMajor = 1, +}; + +enum CoopMatMatrixOperands +{ + None = 0x0, + MatrixASigned = 0x1, + MatrixBSigned = 0x2, + MatrixCSigned = 0x4, + MatrixResultSigned = 0x8, + SaturatingAccumulation = 0x10, +}; + +// +// Cooperative Matrix multiply accumulate +// + +[require(cooperative_matrix)] +__generic +CoopMat coopMatMulAdd(CoopMat matA, CoopMat matB, CoopMat matC, constexpr CoopMatMatrixOperands operands) +{ + static_assert((RA == CoopMatMatrixUse::MatrixA) && (RB == CoopMatMatrixUse::MatrixB) && (RC == CoopMatMatrixUse::MatrixAccumulator), "matrix uses for `coopMatMulAdd` matrix parameters must be `MatrixA`, `MatrixB` and `MatrixAccumulator`"); + return spirv_asm + { + result:$$CoopMat = OpCooperativeMatrixMulAddKHR $matA $matB $matC !operands; + }; +} + // // Cooperative Vector // @@ -23435,6 +23975,7 @@ CoopVec coopVecLoadGroupshared +SpvInst* emitOpTypeCoopMat( + IRInst* inst, + const T1& componentType, + const T2& scope, + const T2& rowCount, + const T2& columnCount, + const T2& matrixUse) +{ + static_assert(isSingular); + return emitInstMemoized( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + inst, + SpvOpTypeCooperativeMatrixKHR, + kResultID, + componentType, + scope, + rowCount, + columnCount, + matrixUse); +} + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeMatrix template SpvInst* emitOpTypeMatrix(IRInst* inst, const T& columnType, const SpvLiteralInteger& columnCount) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index baef62f1c..d07d587e5 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1683,6 +1683,29 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex static_cast(coopVecType->getElementCount())->getValue(), coopVecType); } + case kIROp_CoopMatrixType: + { + requireSPIRVCapability(SpvCapabilityCooperativeMatrixKHR); + ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_cooperative_matrix")); + + IRBuilder builder(m_irModule); + auto coopMatType = static_cast(inst); + return emitOpTypeCoopMat( + coopMatType, + coopMatType->getElementType(), + emitIntConstant( + static_cast(coopMatType->getScope())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast(coopMatType->getRowCount())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast(coopMatType->getColumnCount())->getValue(), + builder.getIntType()), + emitIntConstant( + static_cast(coopMatType->getMatrixUse())->getValue(), + builder.getIntType())); + } case kIROp_MatrixType: { auto matrixType = static_cast(inst); @@ -6264,7 +6287,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto baseTy = base->getDataType(); SLANG_ASSERT( as(baseTy) || as(baseTy) || as(baseTy) || - as(baseTy) || as(baseTy)); + as(baseTy) || as(baseTy) || + as(baseTy)); IRBuilder builder(m_irModule); builder.setInsertBefore(inst); @@ -6553,8 +6577,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as(fromTypeV) == !as(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); if (as(fromType)) { @@ -6687,10 +6711,14 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex bool isMatrixCast = false; if (as(fromTypeV) || as(toTypeV) || - as(fromTypeV) || as(toTypeV)) + as(fromTypeV) || as(toTypeV) || + // Cooperative matrices behave like vectors where arithmetic operations can be performed + // directly without having to loop through the matrix and performing operations on the + // vectors. + as(fromTypeV) || as(toTypeV)) { - fromType = getVectorElementType(fromTypeV); - toType = getVectorElementType(toTypeV); + fromType = getVectorOrCoopMatrixElementType(fromTypeV); + toType = getVectorOrCoopMatrixElementType(toTypeV); } else if (as(fromTypeV) || as(toTypeV)) { @@ -6737,8 +6765,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as(fromTypeV) == !as(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); + SLANG_ASSERT(isFloatingType(toType)); if (isIntegralType(fromType)) @@ -6781,8 +6810,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto fromTypeV = inst->getOperand(0)->getDataType(); const auto toTypeV = inst->getDataType(); SLANG_ASSERT(!as(fromTypeV) == !as(toTypeV)); - const auto fromType = getVectorElementType(fromTypeV); - const auto toType = getVectorElementType(toTypeV); + const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); + const auto toType = getVectorOrCoopMatrixElementType(toTypeV); SLANG_ASSERT(isFloatingType(fromType)); if (as(toType)) @@ -7085,7 +7114,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex UInt operandCount, ArrayView operands) { - IRType* elementType = getVectorElementType(operands[0]->getDataType()); + IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); IRBasicType* basicType = as(elementType); bool isFloatingPoint = false; bool isBool = false; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index c7ed5affe..3de40d2c0 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -235,6 +235,7 @@ INST(Nop, nop, 0, 0) INST(RayQueryType, RayQuery, 1, HOISTABLE) INST(HitObjectType, HitObject, 0, HOISTABLE) INST(CoopVectorType, CoopVectorType, 2, HOISTABLE) +INST(CoopMatrixType, CoopMatrixType, 5, HOISTABLE) // Opaque type that can be dynamically cast to other resource types. INST(DynamicResourceType, DynamicResource, 0, HOISTABLE) diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 58bb7aaf2..4919850eb 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -23,6 +23,18 @@ IRType* getVectorElementType(IRType* type) return vectorType->getElementType(); if (auto coopVecType = as(type)) return coopVecType->getElementType(); + if (auto coopMatType = as(type)) + return coopMatType->getElementType(); + return type; +} + +IRType* getVectorOrCoopMatrixElementType(IRType* type) +{ + auto vectorElementType = getVectorElementType(type); + if (vectorElementType != type) + return vectorElementType; + if (auto coopMatrixType = as(type)) + return coopMatrixType->getElementType(); return type; } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 549981f58..0a8bc9b1d 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -78,6 +78,9 @@ bool isComInterfaceType(IRType* type); // If `type` is a vector, returns its element type. Otherwise, return `type`. IRType* getVectorElementType(IRType* type); +// If `type` is a vector or a coop matrix, returns its element type. Otherwise, return `type`. +IRType* getVectorOrCoopMatrixElementType(IRType* type); + // If `type` is a matrix, returns its element type. Otherwise, return `type`. IRType* getMatrixElementType(IRType* type); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index f75fe2f48..c105a698a 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5351,6 +5351,10 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index) { type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount()); } + else if (auto coopMatType = as(valueType)) + { + type = coopMatType->getElementType(); + } else if (const auto basicType = as(valueType)) { // HLSL support things like float.x, in which case we just return the base pointer. diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index dbc66c6a3..dbf2b91be 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1869,6 +1869,17 @@ struct IRCoopVectorType : IRType IR_LEAF_ISA(CoopVectorType) }; +struct IRCoopMatrixType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + IRInst* getScope() { return getOperand(1); } + IRInst* getRowCount() { return getOperand(2); } + IRInst* getColumnCount() { return getOperand(3); } + IRInst* getMatrixUse() { return getOperand(4); } + + IR_LEAF_ISA(CoopMatrixType) +}; + bool isDefinition(IRInst* inVal); // A structure type is represented as a parent instruction, -- cgit v1.2.3