diff options
| author | Gangzheng Tong <tonggangzheng@gmail.com> | 2025-06-25 20:09:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-06-26 03:09:57 +0000 |
| commit | 5afebcf02748384471a98858eedb685024f7f854 (patch) | |
| tree | b423667a22ebc05b283bbdfdb821f3040e044940 | |
| parent | ee51fe592747fc66bd0b5757207583198068b5bd (diff) | |
Add matrix operand for OpCooperativeVectorMatrixMulAddNV (#7524)
* Add matrix operand for OpCooperativeVectorMatrixMulAddNV
* update tests to use the supported UINT32 input component type
* Add MatrixCSignedComponentsKHR for coopVecMatMulAddPacked
---------
Co-authored-by: Jay Kwak <82421531+jkwak-work@users.noreply.github.com>
| -rw-r--r-- | source/slang/hlsl.meta.slang | 41 | ||||
| -rw-r--r-- | tests/cooperative-vector/matrix-mul-bias-packed-mut.slang | 2 | ||||
| -rw-r--r-- | tests/cooperative-vector/matrix-mul-bias-packed.slang | 2 |
3 files changed, 41 insertions, 4 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 46d251298..9800f2e65 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -26493,7 +26493,13 @@ CoopVec<T, M> coopVecMatMulPacked( let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); let matrixPtr = matrix.GetBufferPointer(); + + // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands int operands = 0; // NoneKHR + if (__isSignedInt<U>()) + { + operands |= 0x02; // MatrixBSignedComponentsKHR + } if (__isSignedInt<T>()) { operands |= 0x08; // MatrixResultSignedComponentsKHR @@ -26693,9 +26699,21 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); let matrixPtr = matrix.GetBufferPointer(); let biasPtr = bias.GetBufferPointer(); + + // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands + int operands = 0; // NoneKHR + if (__isSignedInt<U>()) + { + operands |= 0x02; // MatrixBSignedComponentsKHR + operands |= 0x04; // MatrixCSignedComponentsKHR + } + if (__isSignedInt<T>()) + { + operands |= 0x08; // MatrixResultSignedComponentsKHR + } return spirv_asm { - result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride; + result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; }; case hlsl: @@ -27036,7 +27054,13 @@ CoopVec<T, M> coopVecMatMulPacked( let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); let matrixPtr = __getStructuredBufferPtr(matrix); + + // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands int operands = 0; // NoneKHR + if (__isSignedInt<U>()) + { + operands |= 0x02; // MatrixBSignedComponentsKHR + } if (__isSignedInt<T>()) { operands |= 0x08; // MatrixResultSignedComponentsKHR @@ -27109,9 +27133,22 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); let matrixPtr = __getStructuredBufferPtr(matrix); let biasPtr = __getStructuredBufferPtr(bias); + + // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands + int operands = 0; // NoneKHR + if (__isSignedInt<U>()) + { + operands |= 0x02; // MatrixBSignedComponentsKHR + operands |= 0x04; // MatrixCSignedComponentsKHR + } + if (__isSignedInt<T>()) + { + operands |= 0x08; // MatrixResultSignedComponentsKHR + } + return spirv_asm { - result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride; + result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; }; } } diff --git a/tests/cooperative-vector/matrix-mul-bias-packed-mut.slang b/tests/cooperative-vector/matrix-mul-bias-packed-mut.slang index b3443e08a..54d3b11dd 100644 --- a/tests/cooperative-vector/matrix-mul-bias-packed-mut.slang +++ b/tests/cooperative-vector/matrix-mul-bias-packed-mut.slang @@ -25,7 +25,7 @@ ByteAddressBuffer bias; [numthreads(1, 1, 1)] void computeMain() { - let vec = coopVecLoad<1, int32_t>(input); + let vec = coopVecLoad<1, uint32_t>(input); var result = CoopVec<int32_t, 4>(8000); result.matMulAddAccumPacked( vec, diff --git a/tests/cooperative-vector/matrix-mul-bias-packed.slang b/tests/cooperative-vector/matrix-mul-bias-packed.slang index a47a82714..e80fb02ac 100644 --- a/tests/cooperative-vector/matrix-mul-bias-packed.slang +++ b/tests/cooperative-vector/matrix-mul-bias-packed.slang @@ -25,7 +25,7 @@ ByteAddressBuffer bias; [numthreads(1, 1, 1)] void computeMain() { - let vec = coopVecLoad<1, int32_t>(input); + let vec = coopVecLoad<1, uint32_t>(input); let result = coopVecMatMulAddPacked<int32_t, 4>( vec, CoopVecComponentType::SignedInt8Packed, |
