summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGangzheng Tong <tonggangzheng@gmail.com>2025-06-25 20:09:57 -0700
committerGitHub <noreply@github.com>2025-06-26 03:09:57 +0000
commit5afebcf02748384471a98858eedb685024f7f854 (patch)
treeb423667a22ebc05b283bbdfdb821f3040e044940
parentee51fe592747fc66bd0b5757207583198068b5bd (diff)
Add matrix operand for OpCooperativeVectorMatrixMulAddNV (#7524)
* Add matrix operand for OpCooperativeVectorMatrixMulAddNV * update tests to use the supported UINT32 input component type * Add MatrixCSignedComponentsKHR for coopVecMatMulAddPacked --------- Co-authored-by: Jay Kwak <82421531+jkwak-work@users.noreply.github.com>
-rw-r--r--source/slang/hlsl.meta.slang41
-rw-r--r--tests/cooperative-vector/matrix-mul-bias-packed-mut.slang2
-rw-r--r--tests/cooperative-vector/matrix-mul-bias-packed.slang2
3 files changed, 41 insertions, 4 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 46d251298..9800f2e65 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -26493,7 +26493,13 @@ CoopVec<T, M> coopVecMatMulPacked(
let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation);
let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout);
let matrixPtr = matrix.GetBufferPointer();
+
+ // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands
int operands = 0; // NoneKHR
+ if (__isSignedInt<U>())
+ {
+ operands |= 0x02; // MatrixBSignedComponentsKHR
+ }
if (__isSignedInt<T>())
{
operands |= 0x08; // MatrixResultSignedComponentsKHR
@@ -26693,9 +26699,21 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l
let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout);
let matrixPtr = matrix.GetBufferPointer();
let biasPtr = bias.GetBufferPointer();
+
+ // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands
+ int operands = 0; // NoneKHR
+ if (__isSignedInt<U>())
+ {
+ operands |= 0x02; // MatrixBSignedComponentsKHR
+ operands |= 0x04; // MatrixCSignedComponentsKHR
+ }
+ if (__isSignedInt<T>())
+ {
+ operands |= 0x08; // MatrixResultSignedComponentsKHR
+ }
return spirv_asm
{
- result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride;
+ result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands;
};
case hlsl:
@@ -27036,7 +27054,13 @@ CoopVec<T, M> coopVecMatMulPacked(
let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation);
let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout);
let matrixPtr = __getStructuredBufferPtr(matrix);
+
+ // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands
int operands = 0; // NoneKHR
+ if (__isSignedInt<U>())
+ {
+ operands |= 0x02; // MatrixBSignedComponentsKHR
+ }
if (__isSignedInt<T>())
{
operands |= 0x08; // MatrixResultSignedComponentsKHR
@@ -27109,9 +27133,22 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l
let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout);
let matrixPtr = __getStructuredBufferPtr(matrix);
let biasPtr = __getStructuredBufferPtr(bias);
+
+ // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands
+ int operands = 0; // NoneKHR
+ if (__isSignedInt<U>())
+ {
+ operands |= 0x02; // MatrixBSignedComponentsKHR
+ operands |= 0x04; // MatrixCSignedComponentsKHR
+ }
+ if (__isSignedInt<T>())
+ {
+ operands |= 0x08; // MatrixResultSignedComponentsKHR
+ }
+
return spirv_asm
{
- result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride;
+ result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands;
};
}
}
diff --git a/tests/cooperative-vector/matrix-mul-bias-packed-mut.slang b/tests/cooperative-vector/matrix-mul-bias-packed-mut.slang
index b3443e08a..54d3b11dd 100644
--- a/tests/cooperative-vector/matrix-mul-bias-packed-mut.slang
+++ b/tests/cooperative-vector/matrix-mul-bias-packed-mut.slang
@@ -25,7 +25,7 @@ ByteAddressBuffer bias;
[numthreads(1, 1, 1)]
void computeMain()
{
- let vec = coopVecLoad<1, int32_t>(input);
+ let vec = coopVecLoad<1, uint32_t>(input);
var result = CoopVec<int32_t, 4>(8000);
result.matMulAddAccumPacked(
vec,
diff --git a/tests/cooperative-vector/matrix-mul-bias-packed.slang b/tests/cooperative-vector/matrix-mul-bias-packed.slang
index a47a82714..e80fb02ac 100644
--- a/tests/cooperative-vector/matrix-mul-bias-packed.slang
+++ b/tests/cooperative-vector/matrix-mul-bias-packed.slang
@@ -25,7 +25,7 @@ ByteAddressBuffer bias;
[numthreads(1, 1, 1)]
void computeMain()
{
- let vec = coopVecLoad<1, int32_t>(input);
+ let vec = coopVecLoad<1, uint32_t>(input);
let result = coopVecMatMulAddPacked<int32_t, 4>(
vec,
CoopVecComponentType::SignedInt8Packed,