1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
|
//TEST(compute):SIMPLE(filecheck=SPIRV):-target spirv-asm -entry computeMain -stage compute
//TEST(compute):SIMPLE(filecheck=SPIRV_BL):-target spirv-asm -entry computeMain -stage compute -DBLOCK_LOAD
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-tensor-addressing
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-tensor-addressing -Xslang -DRW
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK_BL):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-block-loads -Xslang -DBLOCK_LOAD
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK_BL):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-block-loads -Xslang -DBLOCK_LOAD -Xslang -DRW
//CHECK: 0
//CHECK-NEXT: 0
//CHECK-NEXT: 0
//CHECK-NEXT: 0
//CHECK-NEXT: 5
//CHECK-NEXT: 6
//CHECK-NEXT: 0
//CHECK-NEXT: 0
//CHECK-NEXT: 9
//CHECK_BL: 0
//CHECK_BL-NEXT: 0
//CHECK_BL-NEXT: 0
//CHECK_BL-NEXT: 0
//CHECK_BL-NEXT: 7
//CHECK_BL-NEXT: C
//CHECK_BL-NEXT: 0
//CHECK_BL-NEXT: 0
//CHECK_BL-NEXT: C
//CHECK_BL-NEXT: 11
//CHECK_BL-NEXT: 0
//CHECK_BL-NEXT: 0
//TEST_INPUT:ubuffer(data=[1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24], stride=4, count=256),name=buf
// Source buffer for the cooperative-matrix load. Defining RW switches its
// declared type to the writable byte-address buffer; the visible shader code
// only ever reads from it.
#if defined(RW)
RWByteAddressBuffer inputBuffer;
#else // #if defined(RW)
ByteAddressBuffer inputBuffer;
#endif // #else // #if defined(RW)
// Destination buffer that the loaded matrix is stored back into.
//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer
RWByteAddressBuffer outputBuffer;
using namespace linalg;
// Matrix type exercised by this test: 16x16 int32_t elements, subgroup
// memory scope, accumulator use.
typealias CoopMatType = CoopMat<int32_t, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator>;
// Decode callback handed to the block-load variant of CoopMatType.Load.
//
// encoded      - pointer to the uint32_t data being decoded.
// blockCoord   - two-component block coordinate; flattened below with a
//                fixed pitch of 4 along component 1.
// coordInBlock - two-component coordinate inside the block; only component 0
//                is used, to pick which 8-bit lane of the element is kept.
//
// Returns the selected element masked to a single byte lane, converted to
// int32_t, plus 1. The masked bits are NOT shifted back down to bit 0, so a
// non-zero lane index yields a value that is still in its original position.
int32_t decodeFunc(uint32_t* encoded, uint32_t blockCoord[2], uint32_t coordInBlock[2])
{
    // Linear index of the element to read from `encoded`.
    uint32_t elementIndex = blockCoord[0] + blockCoord[1] * 4;

    // 0xff moved up by a whole number of bytes.
    uint32_t byteLaneMask = (0xff << (coordInBlock[0] * 8));

    // Convert the masked value first, then add the +1 bias.
    return int32_t(encoded[elementIndex] & byteLaneMask) + 1;
}
// Compute entry point: builds a 2D tensor layout step by step, loads a
// cooperative matrix from inputBuffer through that layout, and stores it to
// outputBuffer through the same layout. With BLOCK_LOAD defined the load goes
// through decodeFunc instead of the plain overload.
//
// Each layout call below returns a new layout value that feeds the next call.
// The FileCheck directive above each statement names the SPIR-V instruction
// that statement is expected to produce, so statement order must be kept.
[numthreads(32, 1, 1)]
void computeMain()
{
    //SPIRV: = OpCreateTensorLayoutNV %
    TensorLayout<2, CoopMatClampMode.Undefined> baseLayout;

    //SPIRV: = OpTensorLayoutSetDimensionNV %
    let sizedLayout = baseLayout.Dimension(32, 16);

    //SPIRV: = OpTensorLayoutSetStrideNV %
    let stridedLayout = sizedLayout.Stride(4, 1);

    //SPIRV: = OpTensorLayoutSliceNV %
    let slicedLayout = stridedLayout.Slice(4, 24, 0, 16);

    //SPIRV: = OpTensorLayoutSetClampValueNV %
    let clampedLayout = slicedLayout.ClampValue(CoopMatClampMode.Repeat);

    //SPIRV: = OpTensorLayoutSetBlockSizeNV %
    let finalLayout = clampedLayout.BlockSize(4, 8);

#if defined(BLOCK_LOAD)
    // Block-load path: every element is produced by decodeFunc.
    //SPIRV_BL: = OpCooperativeMatrixLoadTensorNV %{{.*}} DecodeFunc %
    let mat = CoopMatType.Load<uint32_t>(inputBuffer, 0, finalLayout, decodeFunc);
#else // #if defined(BLOCK_LOAD)
    // Plain path: load directly through the layout, no decode callback.
    //SPIRV: = OpCooperativeMatrixLoadTensorNV %{{.*}} None
    let mat = CoopMatType.Load(inputBuffer, 0, finalLayout);
#endif // #else // #if defined(BLOCK_LOAD)

    // Write the matrix back out using the same layout it was loaded with.
    //SPIRV:OpCooperativeMatrixStoreTensorNV %{{.*}} None
    mat.Store(outputBuffer, 0, finalLayout);
}
|