summaryrefslogtreecommitdiffstats
path: root/ComputeShaders/mulMatByRowTiled.hlsl
blob: 8bf62d77771d7877d9cdb9a557a153566f559900 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
// Matrix * row product, like [ E0, E1, E2, E3 ] * [ E0, 1, E2, E3 ] = [ E1, 1, E2, E3 ]
// Dispatch [ ( E1 + TILE_Y - 1 ) / TILE_Y, E2, E3 ] thread groups of this shader
// This is the second most expensive shader in the model, after the matrix*matrix product.
// It is heavily optimized, and as a result the readability isn't great.

// Tile configuration. Each default below is only declared when no preprocessor macro
// of the same name has been defined, so a macro definition takes precedence.

// Maximum count of arg0 rows (= output elements) handled by one thread group
#if !defined( TILE_Y )
static const uint TILE_Y = 64;
#endif

// Thread group width; each thread.x handles one column within a THREADS_X-wide slice of arg0
#if !defined( THREADS_X )
static const uint THREADS_X = 32;
#endif

// Thread group height; each thread.y handles rows thread.y, thread.y + THREADS_Y, ... of the tile
#if !defined( THREADS_Y )
static const uint THREADS_Y = 16;
#endif

// First operand, the matrix; read-only, bound to SRV slot t0. Indexed with element offsets built from arg0Strides.
Buffer<float> arg0: register( t0 );
// Second operand, the row; read-only, bound to SRV slot t1. Indexed with element offsets built from arg1Strides.
Buffer<float> arg1: register( t1 );
// Output tensor; bound to UAV slot u0. Indexed with element offsets built from resultStrides.
RWBuffer<float> result: register( u0 );

// Sizes and strides of the two arguments and of the result, one uint4 per constant register c0..c5.
// The strides are used directly as indices into the typed buffers, i.e. they are expressed in elements.
// Of the sizes, this shader only reads arg0Size.x and arg0Size.y; arg1Size and resultSize are not referenced here.
cbuffer Constants: register( b0 )
{
	// arg0 size; .x is the length of the dot products, .y is the count of rows
	uint4 arg0Size: packoffset( c0 );
	// arg0 strides; .xy address an element within a layer, .zw select the layer
	uint4 arg0Strides: packoffset( c1 );
	// arg1 size, not used by this shader
	uint4 arg1Size: packoffset( c2 );
	// arg1 strides; .x steps along the row, .zw select the layer
	uint4 arg1Strides: packoffset( c3 );
	// result size, not used by this shader
	uint4 resultSize: packoffset( c4 );
	// result strides; .x steps along the output column, .zw select the layer
	uint4 resultStrides: packoffset( c5 );
}

// Horizontal add: returns the sum of the two lanes of a uint2 vector
inline uint hadd( uint2 vec )
{
	const uint first = vec.x;
	const uint second = vec.y;
	return first + second;
}

// Count of FP32 accumulators we need in every thread of the shader:
// a thread handles the rows thread.y, thread.y + THREADS_Y, thread.y + 2 * THREADS_Y, ... of the tile.
// NOTE(review): this is an integer division; presumably TILE_Y is expected to be a multiple of THREADS_Y - confirm
static const uint heightScalars = TILE_Y / THREADS_Y;
// The local accumulators are float4 vectors, compute count of these vectors, rounding up
static const uint heightVectors = ( heightScalars + 3 ) / 4;

// Scratch storage for the final horizontal reduction in main(): one float4 slot per accumulator vector, per thread of the group
groupshared float4 reductionBuffer[ heightVectors ][ THREADS_Y ][ THREADS_X ];

// Shader entry point.
// A thread group computes up to TILE_Y consecutive elements of the output column, for the layer selected by group.yz:
// group.x is the index of the tile of rows, group.y and group.z are multiplied by the .z and .w strides of each tensor.
// Within the group, thread.x selects a column inside the current THREADS_X-wide slice of arg0,
// and thread.y selects the rows: thread.y, thread.y + THREADS_Y, thread.y + 2 * THREADS_Y, and so on.
[numthreads( THREADS_X, THREADS_Y, 1 )]
void main( uint3 group: SV_GroupID, uint3 thread : SV_GroupThreadID )
{
	uint i;
	// Even though the groupshared memory is inside GPU cores, it is still much slower than registers
	// For this reason, this shader accumulates numbers in local variables. Only uses groupshared buffer for the final reduction.
	// Lane j of acc[ i ] belongs to the row ( i * 4 + j ) * THREADS_Y + thread.y of the tile
	float4 acc[ heightVectors ];
	// Zero out the accumulators
	[unroll]
	for( i = 0; i < heightVectors; i++ )
		acc[ i ] = 0.0;

	// Count of rows to compute in this thread group; less than TILE_Y for the last, incomplete tile
	const uint height = min( TILE_Y, arg0Size.y - group.x * TILE_Y );

	uint s0 = hadd( group.yz * arg0Strides.zw );   //< arg0 layer for the thread group
	s0 += group.x * TILE_Y * arg0Strides.y;        //< arg0 first row for the thread group
	s0 += hadd( arg0Strides.xy * thread.xy );      //< arg0 load index for the thread

	uint s1 = hadd( group.yz * arg1Strides.zw );   //< arg1 layer for the thread group
	s1 += thread.x * arg1Strides.x;                //< arg1 load index for the thread

	// Count of complete THREADS_X-wide slices along the X dimension of arg0
	const uint completeTiles = arg0Size.x / THREADS_X;
	// Each iteration of that loop loads THREADS_X elements from arg1,
	// a block of [ THREADS_X, height ] elements from arg0,
	// and accumulates these dot products in the local variables
	// Both s0 and s1 advance by THREADS_X columns after every iteration
	for( uint t = 0; t < completeTiles; t++, s0 += THREADS_X * arg0Strides.x, s1 += THREADS_X * arg1Strides.x )
	{
		// Load one element from arg1 in this thread; across thread.x, that's THREADS_X elements
		const float v1 = arg1[ s1 ];

		// Running arg0 index which walks down the rows of this thread, while s0 stays at the first row
		uint rsi = s0;
		[unroll]
		for( i = 0; i < heightVectors; i++ )
		{
			float4 v0 = 0.0;
			// Load up to 4 elements from arg0 in this thread, one per row, the rows are THREADS_Y apart
			[unroll]
			for( uint j = 0; j < 4; j++, rsi += arg0Strides.y * THREADS_Y )
			{
				const uint y = ( i * 4 + j ) * THREADS_Y + thread.y;
				// Rows at or past `height` are not loaded, their lanes stay zero
				[branch]
				if( y < height )
					v0[ j ] = arg0[ rsi ];
			}
			// Multiply + accumulate; the scalar v1 applies to all 4 lanes
			acc[ i ] = mad( v0, v1, acc[ i ] );
		}
	}

	// Count of the remaining columns after the last complete slice
	const uint rem = arg0Size.x % THREADS_X;
	if( thread.x < rem )
	{
		// E0 ain't a multiple of THREADS_X, we have a remainder
		// Only the threads with thread.x < rem have a column to process here

		// Load one element from arg1 in this thread; across the active threads, that's `rem` elements
		const float v1 = arg1[ s1 ];

		[unroll]
		for( i = 0; i < heightVectors; i++ )
		{
			float4 v0 = 0.0;
			// Load up to 4 elements from arg0 in this thread, one per row
			// s0 is advanced in place because it's not used after this loop
			[unroll]
			for( uint j = 0; j < 4; j++, s0 += arg0Strides.y * THREADS_Y )
			{
				const uint y = ( i * 4 + j ) * THREADS_Y + thread.y;
				[branch]
				if( y < height )
					v0[ j ] = arg0[ s0 ];
			}
			// Multiply + accumulate
			acc[ i ] = mad( v0, v1, acc[ i ] );
		}
	}

	// Now we need horizontal sum of these accumulators, reducing [height][THREADS_X] of them into [height][1] column
	// First, store local variables into the shared memory.
	[ unroll ]
	for( i = 0; i < heightVectors; i++ )
		reductionBuffer[ i ][ thread.y ][ thread.x ] = acc[ i ];
	GroupMemoryBarrierWithGroupSync();

	// Run reduction using that shared memory buffer
	// Every step halves the count of active columns: a thread with thread.x < i adds the value stored by the thread at thread.x + i
	// The running sum is kept both in the local `acc` and in the buffer, so a thread never re-reads its own slot
	// NOTE(review): halving with "i /= 2" only covers every column when THREADS_X is a power of 2,
	// and the last step below reads column 1, which requires THREADS_X >= 2 - confirm when overriding THREADS_X
	for( i = THREADS_X / 2; i > 1; i /= 2 )
	{
		if( thread.x < i )
		{
			[unroll]
			for( uint iv = 0; iv < heightVectors; iv++ )
			{
				float4 that = reductionBuffer[ iv ][ thread.y ][ thread.x + i ];
				float4 tmp = acc[ iv ];
				tmp += that;
				reductionBuffer[ iv ][ thread.y ][ thread.x ] = tmp;
				acc[ iv ] = tmp;
			}
		}
		// The barrier is outside of the branch, all threads of the group reach it on every iteration
		GroupMemoryBarrierWithGroupSync();
	}

	// And finally, store that column to global memory.
	// Only running that code on the threads of the group with thread.x = 0, to save a few loads from the groupshared buffer
	// This allows to use registers instead, faster to access
	// There are no barriers after this point, the early return is fine
	if( thread.x != 0 )
		return;

	// Index of the first output element of this thread, and the step between consecutive rows of this thread
	uint rdi = hadd( group.yz * resultStrides.zw );
	rdi += ( group.x * TILE_Y + thread.y ) * resultStrides.x;
	const uint rdiInc = THREADS_Y * resultStrides.x;

	[unroll]
	for( i = 0; i < heightVectors; i++ )
	{
		// The previous loop had "i > 1" continue condition, it didn't complete the last step of the reduction
		// The following line is doing that last reduction step
		const float4 resultVec = acc[ i ] + reductionBuffer[ i ][ thread.y ][ 1 ];

		// Conditionally store these 4 floats to the output tensor, skipping the rows at or past `height`
		[unroll]
		for( uint j = 0; j < 4; j++, rdi += rdiInc )
		{
			const uint y = ( i * 4 + j ) * THREADS_Y + thread.y;
			[branch]
			if( y < height )
				result[ rdi ] = resultVec[ j ];
		}
	}
}