From 5afebcf02748384471a98858eedb685024f7f854 Mon Sep 17 00:00:00 2001 From: Gangzheng Tong Date: Wed, 25 Jun 2025 20:09:57 -0700 Subject: Add matrix operand for OpCooperativeVectorMatrixMulAddNV (#7524) * Add matrix operand for OpCooperativeVectorMatrixMulAddNV * update tests to use the supported UINT32 input component type * Add MatrixCSignedComponentsKHR for coopVecMatMulAddPacked --------- Co-authored-by: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> --- source/slang/hlsl.meta.slang | 41 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) (limited to 'source') diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 46d251298..9800f2e65 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -26493,7 +26493,13 @@ CoopVec coopVecMatMulPacked( let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); let matrixPtr = matrix.GetBufferPointer(); + + // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands int operands = 0; // NoneKHR + if (__isSignedInt()) + { + operands |= 0x02; // MatrixBSignedComponentsKHR + } if (__isSignedInt()) { operands |= 0x08; // MatrixResultSignedComponentsKHR @@ -26693,9 +26699,21 @@ CoopVec coopVecMatMulAddPacked()) + { + operands |= 0x02; // MatrixBSignedComponentsKHR + operands |= 0x04; // MatrixCSignedComponentsKHR + } + if (__isSignedInt()) + { + operands |= 0x08; // MatrixResultSignedComponentsKHR + } return spirv_asm { - result:$$CoopVec = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride; + result:$$CoopVec = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; }; case hlsl: @@ -27036,7 +27054,13 @@ CoopVec coopVecMatMulPacked( let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); let matrixPtr = __getStructuredBufferPtr(matrix); + + // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands int operands = 0; // NoneKHR + if (__isSignedInt()) + { + operands |= 0x02; // MatrixBSignedComponentsKHR + } if (__isSignedInt()) { operands |= 0x08; // MatrixResultSignedComponentsKHR @@ -27109,9 +27133,22 @@ CoopVec coopVecMatMulAddPacked()) + { + operands |= 0x02; // MatrixBSignedComponentsKHR + operands |= 0x04; // MatrixCSignedComponentsKHR + } + if (__isSignedInt()) + { + operands |= 0x08; // MatrixResultSignedComponentsKHR + } + return spirv_asm { - result:$$CoopVec = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride; + result:$$CoopVec = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; }; } } -- cgit v1.2.3