From d0b6a0b1ab49b5958015f31364c5ad73d9cd03eb Mon Sep 17 00:00:00 2001 From: Darren Wihandi <65404740+fairywreath@users.noreply.github.com> Date: Tue, 15 Apr 2025 15:57:45 -0600 Subject: Add cooperative matrix 1 support (#6565) * initial wip for spirv * working tiled example * clean up store and load * minor fixes * fix loadAny name * add initial tests, including broken/unimplemented intrinsics * fix subscript * run tests at 16x16, remove not supported arithmetic tests * minor fixups on implementation * rename CoopMatMatrixUse * Update tests to pass validation layers locally * Add mat-mul-add test and minor fixes * Add more tests * Remove dead code * Add coopMatLoad function and tests, enforce constexpr for matrix layout * Use getVectorOrCoopMatrixElementType in place of getVectorElementType --- tests/cooperative-matrix/array.slang | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/cooperative-matrix/array.slang (limited to 'tests/cooperative-matrix/array.slang') diff --git a/tests/cooperative-matrix/array.slang b/tests/cooperative-matrix/array.slang new file mode 100644 index 000000000..b46c0f66b --- /dev/null +++ b/tests/cooperative-matrix/array.slang @@ -0,0 +1,36 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +// CHECK: type: float +// CHECK: 1.000000 +// CHECK-NEXT: 2.000000 +// CHECK-NEXT: 3.000000 +// CHECK-NEXT: 4.000000 +// CHECK-NEXT: 5.000000 +// CHECK-NEXT: 6.000000 +// CHECK-NEXT: 7.000000 +// CHECK-NEXT: 8.000000 + +//TEST_INPUT:ubuffer(data=[1.0 2.0 3.0 4.0], stride=256),name=input1 +ByteAddressBuffer input1; + +//TEST_INPUT:ubuffer(data=[5.0 6.0 7.0 8.0], stride=256),name=input1 +ByteAddressBuffer input2; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +typealias CoopMatType = CoopMat; + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let matrixLayout = CoopMatMatrixLayout::RowMajor; + + CoopMatType coopMatArray[2]; + coopMatArray[0] = CoopMatType.load(input1, 0, stride, matrixLayout); + coopMatArray[1] = CoopMatType.load(input2, 0, stride, matrixLayout); + + coopMatArray[0].store(outputBuffer, 0, stride, matrixLayout); + coopMatArray[1].store(outputBuffer, 4, stride, matrixLayout); +} -- cgit v1.2.3