summaryrefslogtreecommitdiffstats
path: root/ComputeShaders/mulMatTiledEx.hlsl
blob: c59e3ea5110e48552b86672fac5483cb32dfb45e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
// This compute shader implements yet another version of matrix*matrix product
// For optimal VRAM access pattern, it requires both arguments to be reshaped into a sequence of horizontal column major panels.
// The panel height is TILE_SIZE, and the last panel of the matrix needs to be padded with zeros; see matReshapePanels.hlsl shader for the reshaping.
// So far, it's only used when running on AMD GPUs.
#ifndef TILE_SIZE
static const uint TILE_SIZE = 32;
#endif
#ifndef TILE_HEIGHT
static const uint TILE_HEIGHT = 64;
#endif
#ifndef THREADS_Y
static const uint THREADS_Y = 8;
#endif
// The above values have a following constraint: TILE_SIZE = THREADS_Y * N * 4 where N is an integer

#ifndef STREAM_SECOND_MATRIX
#define STREAM_SECOND_MATRIX 1
#endif

// First tensor, reshaped into dense column major horizontal panels of size [ width, TILE_SIZE ]
Buffer<float> arg0: register( t0 );
// Second tensor, reshaped into dense column major horizontal panels of size [ width, TILE_SIZE ]
Buffer<float> arg1: register( t1 );
// FP32 output tensor, row major and continuous
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 arg0Size: packoffset( c0 );
	uint arg0panel: packoffset( c1.y );
	uint2 arg0LayerStrides: packoffset( c1.z );

	// uint4 arg1Size: packoffset( c2 );
	uint arg1panel: packoffset( c3.y );
	uint2 arg1LayerStrides: packoffset( c3.z );

	uint4 resultSize: packoffset( c4 );
	uint4 resultStrides: packoffset( c5 );
}

// A smaller tile loaded from the first source matrix
groupshared float tile0[ TILE_HEIGHT ][ TILE_SIZE ];
#if !STREAM_SECOND_MATRIX
// A smaller tile loaded from the second source matrix
groupshared float tile1[ TILE_HEIGHT ][ TILE_SIZE ];
#endif

// Count of FP32 accumulators we need in every thread of the shader
static const uint heightScalars = TILE_SIZE / THREADS_Y;
// The local accumulators are float4 vectors, compute count of these vectors
static const uint heightVectors = ( heightScalars + 3 ) / 4;

#if STREAM_SECOND_MATRIX
void multiplyTiles( const uint3 thread, uint rsi, const uint h, inout float4 acc[ heightVectors ] )
{
	uint4 rsi4 = rsi + uint4( 0, THREADS_Y, THREADS_Y * 2, THREADS_Y * 3 );
	[unroll]
	for( uint iv = 0; iv < heightVectors; iv++, rsi4 += THREADS_Y * 4 )
	{
		float4 r = 0.0;
		uint4 rsiRow = rsi4;
		for( uint j = 0; j < h; j++, rsiRow += TILE_SIZE )
		{
			const float a = tile0[ j ][ thread.x ];
			float4 b = 0.0;
			[unroll]
			for( uint k = 0; k < 4; k++ )
			{
				b[ k ] = arg1[ rsiRow[ k ] ];
			}
			r = mad( a, b, r );
		}
		acc[ iv ] += r;
	}
}
#else
void multiplyTiles( const uint3 thread, inout float4 acc[ heightVectors ] )
{
	[unroll]
	for( uint i = 0; i < heightVectors; i++ )
	{
		float4 r = 0.0;
		for( uint j = 0; j < TILE_HEIGHT; j++ )
		{
			const float a = tile0[ j ][ thread.x ];
			float4 b;
			[unroll]
			for( uint k = 0; k < 4; k++ )
			{
				const uint row = ( i * 4 + k ) * THREADS_Y + thread.y;
				b[ k ] = tile1[ j ][ row ];
			}
			r = mad( a, b, r );
		}
		acc[ i ] += r;
	}
}
#endif

void storeTile( const uint3 thread, const uint4 pos, const uint2 size, in float4 acc[ heightVectors ] )
{
	if( thread.x >= size.x )
		return;

	const uint4 prod4 = pos * resultStrides;
	const uint2 prod2 = prod4.xy + prod4.zw;
	uint rdi = prod2.x + prod2.y;
	rdi += resultStrides.y * thread.y;
	rdi += resultStrides.x * thread.x;

	const uint4 offsets = THREADS_Y * uint4( 0, 1, 2, 3 );	//< a compile-time constant vector
	uint4 rdi4 = resultStrides.y * offsets + rdi;

	[unroll]
	for( uint iv = 0; iv < heightVectors; iv++, rdi4 += resultStrides.y * THREADS_Y * 4 )
	{
		const float4 source = acc[ iv ];
		[unroll]
		for( uint k = 0; k < 4; k++ )
		{
			const uint i = ( iv * 4 + k ) * THREADS_Y + thread.y;
			if( i < size.y )
				result[ rdi4[ k ] ] = source[ k ];
		}
	}
}

[numthreads( TILE_SIZE, THREADS_Y, 1 )]
void main( const uint3 group: SV_GroupID, const uint3 thread : SV_GroupThreadID )
{
	uint i;
	// Zero all shared buffers
	for( i = thread.y; i < TILE_HEIGHT; i += THREADS_Y )
	{
		tile0[ i ][ thread.x ] = 0.0;
#if !STREAM_SECOND_MATRIX
		tile1[ i ][ thread.x ] = 0.0;
#endif
	}
	// Despite inside GPU cores, the shared memory is still much slower than registers
	// For this reason, this shader accumulates numbers in local variables. Only uses groupshared memory for tiles of the argument matrices.
	float4 acc[ heightVectors ];
	// Zero out the accumulators
	[unroll]
	for( i = 0; i < heightVectors; i++ )
		acc[ i ] = 0.0;

	const uint2 layer = uint2( group.z % resultSize.z, group.z / resultSize.z );

	uint rsi0 = group.x * arg0panel + layer.x * arg0LayerStrides.x + layer.y * arg0LayerStrides.y;
	uint rsi1 = group.y * arg1panel + layer.x * arg1LayerStrides.x + layer.y * arg1LayerStrides.y;

	const uint threadOffset = thread.y * TILE_SIZE + thread.x;
	rsi0 += threadOffset;
#if STREAM_SECOND_MATRIX
	rsi1 += thread.y;
#else
	rsi1 += threadOffset;
#endif

	const uint completeTiles = arg0Size.x / TILE_HEIGHT;
	for( i = 0; i < completeTiles; i++ )
	{
		// Load [ TILE_SIZE, TILE_HEIGHT ] block from both source tensors into these groupshared buffers
		for( uint j = thread.y; j < TILE_HEIGHT; j += THREADS_Y )
		{
			tile0[ j ][ thread.x ] = arg0[ rsi0 ];
			rsi0 += THREADS_Y * TILE_SIZE;
#if !STREAM_SECOND_MATRIX
			tile1[ j ][ thread.x ] = arg1[ rsi1 ];
			rsi1 += THREADS_Y * TILE_SIZE;
#endif
		}

		// Wait for all threads in the group to complete these loads
		GroupMemoryBarrierWithGroupSync();

#if STREAM_SECOND_MATRIX
		multiplyTiles( thread, rsi1, TILE_HEIGHT, acc );
		rsi1 += TILE_HEIGHT * TILE_SIZE;
#else
		// Multiply + accumulate the elements collected in the groupshared buffers
		multiplyTiles( thread, acc );
#endif
		GroupMemoryBarrierWithGroupSync();
	}

	const uint rem = arg0Size.x % TILE_HEIGHT;
	if( rem != 0 )
	{
		// Load [ TILE_SIZE, rem ] block from both source tensors, and zero out the padding elements
		for( uint j = thread.y; j < TILE_HEIGHT; j += THREADS_Y )
		{
			[branch]
			if( j < rem )
			{
				tile0[ j ][ thread.x ] = arg0[ rsi0 ];
				rsi0 += THREADS_Y * TILE_SIZE;
#if !STREAM_SECOND_MATRIX
				tile1[ j ][ thread.x ] = arg1[ rsi1 ];
				rsi1 += THREADS_Y * TILE_SIZE;
#endif
			}
			else
			{
				tile0[ j ][ thread.x ] = 0.0;
#if !STREAM_SECOND_MATRIX
				tile1[ j ][ thread.x ] = 0.0;
#endif
			}
		}

		// Wait for all threads in the group to complete these loads
		GroupMemoryBarrierWithGroupSync();

		// Multiply + accumulate the elements collected in the groupshared buffers
#if STREAM_SECOND_MATRIX
		multiplyTiles( thread, rsi1, rem, acc );
#else
		multiplyTiles( thread, acc );
#endif
		GroupMemoryBarrierWithGroupSync();
	}

	const uint2 resultPos = group.xy * TILE_SIZE;
	const uint2 outputSize = min( TILE_SIZE, resultSize.xy - resultPos );
	storeTile( thread, uint4( resultPos, layer ), outputSize, acc );
}