1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
|
// Matrix * row product, like [ E0, E1, E2, E3 ] * [ E0, 1, E2, E3 ] = [ E1, 1, E2, E3 ]
// Dispatch [ ( E1 + TILE_Y - 1 ) / TILE_Y, E2, E3 ] thread groups of this shader
// Tuning constants. Each one is only defined here when no macro of the same name
// exists yet, so the build can override it (e.g. with a compiler define).
// TILE_Y: number of result rows (elements along E1) computed by one thread group.
#ifndef TILE_Y
static const uint TILE_Y = 64;
#endif
// THREADS_X: threads along E0; per tile, each thread loads one element of arg1
// and the group consumes THREADS_X elements of E0.
#ifndef THREADS_X
static const uint THREADS_X = 32;
#endif
// THREADS_Y: threads along E1; a thread handles every THREADS_Y-th row of the tile.
#ifndef THREADS_Y
static const uint THREADS_Y = 16;
#endif
// Input matrix, logical shape [ E0, E1, E2, E3 ], addressed through arg0Strides
Buffer<float> arg0: register( t0 );
// Input row, logical shape [ E0, 1, E2, E3 ], addressed through arg1Strides
Buffer<float> arg1: register( t1 );
// Output column, logical shape [ E1, 1, E2, E3 ], addressed through resultStrides
RWBuffer<float> result: register( u0 );
// Shapes and strides of the three tensors. Strides are in elements: the shader multiplies
// them with indices and uses the sums directly as Buffer indices. For strides, component
// .x steps along E0 (or E1 for the result), .y along E1, .z/.w select the layer from
// group.y/group.z. Sizes presumably use the same component order (arg0Size.x = E0 and
// arg0Size.y = E1 are read by main); resultSize is not read by the shader code here.
cbuffer Constants: register( b0 )
{
uint4 arg0Size: packoffset( c0 );
uint4 arg0Strides: packoffset( c1 );
uint4 arg1Size: packoffset( c2 );
uint4 arg1Strides: packoffset( c3 );
uint4 resultSize: packoffset( c4 );
uint4 resultStrides: packoffset( c5 );
}
// Per-group scratch of partial dot products: first index is the result row within the
// tile, second index is the thread.x lane that accumulated it.
groupshared float resTemp[ TILE_Y ][ THREADS_X ];
// Horizontal add: folds both components of an unsigned 2-vector into one value.
// Callers use it to turn per-dimension ( index * stride ) products into a buffer offset.
inline uint hadd( uint2 vec )
{
	const uint total = vec.y + vec.x;
	return total;
}
// Computes up to TILE_Y consecutive elements of the result column.
// Thread group ( group.x, group.y, group.z ) handles result rows
// [ group.x * TILE_Y, group.x * TILE_Y + height ) of layer ( group.y, group.z ).
// Phase 1: each thread accumulates partial dot products into resTemp[ row ][ thread.x ]
//          for the rows row = thread.y, thread.y + THREADS_Y, ...
// Phase 2: the THREADS_X partial sums of every row are reduced into column 0.
// Phase 3: column 0 is written to `result`.
[ numthreads( THREADS_X, THREADS_Y, 1 ) ]
void main( uint3 group: SV_GroupID, uint3 thread : SV_GroupThreadID, uint threadFlattenned : SV_GroupIndex )
{
	uint i;
	// Zero out the shared buffer.
	// Ownership invariant: until the horizontal reduction, cell resTemp[ r ][ c ] is only
	// touched by the thread with thread.x == c and thread.y == r % THREADS_Y, so zeroing
	// and accumulation need no group barrier between them.
	for( i = thread.y; i < TILE_Y; i += THREADS_Y )
		resTemp[ i ][ thread.x ] = 0.0;

	// Count of rows to compute in this thread group
	const uint height = min( TILE_Y, arg0Size.y - group.x * TILE_Y );
	uint s0 = hadd( group.yz * arg0Strides.zw );	//< arg0 layer for the thread group
	s0 += group.x * TILE_Y * arg0Strides.y;		//< arg0 first row for the thread group
	s0 += hadd( arg0Strides.xy * thread.xy );	//< arg0 load index for the thread
	uint s1 = hadd( group.yz * arg1Strides.zw );	//< arg1 layer for the thread group
	s1 += thread.x * arg1Strides.x;			//< arg1 load index for the thread

	// Each iteration loads THREADS_X elements from arg1 (one per thread.x lane) and a block
	// of [ THREADS_X, height ] elements from arg0, accumulating into the shared buffer.
	const uint completeTiles = arg0Size.x / THREADS_X;
	for( uint t = 0; t < completeTiles; t++, s0 += THREADS_X * arg0Strides.x, s1 += THREADS_X * arg1Strides.x )
	{
		const float v1 = arg1[ s1 ];
		uint rsi = s0;
		for( i = thread.y; i < height; i += THREADS_Y, rsi += arg0Strides.y * THREADS_Y )
			resTemp[ i ][ thread.x ] = mad( arg0[ rsi ], v1, resTemp[ i ][ thread.x ] );
		// No barrier here: see the ownership invariant above.
	}

	// E0 may not be a multiple of THREADS_X; only the first `rem` lanes have a tail element.
	const uint rem = arg0Size.x % THREADS_X;
	if( thread.x < rem )
	{
		const float tailV1 = arg1[ s1 ];
		for( i = thread.y; i < height; i += THREADS_Y, s0 += arg0Strides.y * THREADS_Y )
			resTemp[ i ][ thread.x ] = mad( arg0[ s0 ], tailV1, resTemp[ i ][ thread.x ] );
	}

	// From here on lanes read each other's cells: publish all partial sums first.
	GroupMemoryBarrierWithGroupSync();

	// Horizontal sums: reduce the [height][THREADS_X] shared array into its column 0.
	// Each step folds lanes [ upper, width ) onto lanes [ 0, width - upper ) with
	// upper = ceil( width / 2 ), so THREADS_X need not be a power of two.
	// The read range [ upper, width ) and the write range [ 0, width - upper ) never
	// overlap, so a single barrier per step suffices.
	for( uint width = THREADS_X; width > 1; )
	{
		const uint upper = ( width + 1 ) / 2;
		if( thread.x + upper < width )
		{
			for( uint j = thread.y; j < height; j += THREADS_Y )
				resTemp[ j ][ thread.x ] += resTemp[ j ][ thread.x + upper ];
		}
		GroupMemoryBarrierWithGroupSync();
		width = upper;
	}

	// Store column 0 to global memory. The loop covers configurations where
	// TILE_Y exceeds the number of threads in the group.
	const uint groupThreads = THREADS_X * THREADS_Y;
	uint rdi = hadd( group.yz * resultStrides.zw ) + ( group.x * TILE_Y + threadFlattenned ) * resultStrides.x;
	for( uint r = threadFlattenned; r < height; r += groupThreads, rdi += groupThreads * resultStrides.x )
		result[ rdi ] = resTemp[ r ][ 0 ];
}
|