summaryrefslogtreecommitdiffstats
path: root/Whisper/CPU/mulMatImpl.panel.cpp
blob: f3baf2170b282130ed04d9efb00eb4b6c375ba21 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
#include "stdafx.h"
#include <cstring>
#include <intrin.h>
#include "mulMatImpl.h"
#include "mulMatUtils.hpp"
using namespace CpuCompute;

// We want to keep code size reasonable, that's why these panel reshaping methods are in the base class

// Reshape panel number i of the first argument matrix (index m2 along dimension 2, m3 along dimension 3)
// into the destination buffer rdi, transposing it in blocks of 8 source rows with transpose8 / transpose8Partial.
// Requires stridesA[ 0 ] == 1, i.e. consecutive elements of a source row are adjacent in memory.
// The destination receives `length` groups of panelHeightRegisters * 8 16-bit elements.
// For the last panel of the matrix (fewer rows than the panel height), the destination is zero-filled
// before the remaining rows are copied.
HRESULT MulMatBase::transposePanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
{
	assert( stridesA[ 0 ] == 1 );

	const size_t heightFloats = (size_t)panelHeightRegisters * 8;
	// The parameter is a panel index; from here on `i` is the index of the panel's first row
	i *= heightFloats;

	const uint16_t* rsi = (const uint16_t*)pa;
	rsi += m3 * stridesA[ 3 ];
	rsi += m2 * stridesA[ 2 ];
	rsi += i * stridesA[ 1 ];

	// Distance in elements between consecutive output columns
	const size_t resultStride = heightFloats;

	if( i + heightFloats <= resultSize[ 0 ] )
	{
		// A complete panel: panelHeightRegisters blocks of 8 rows each
		for( size_t block = 0; block < panelHeightRegisters; block++ )
		{
			transpose8( rdi, length, rsi, stridesA[ 1 ], resultStride );
			// Advance by 8 elements in the output buffer
			rdi += 8;
			// Advance by 8 rows in the source matrix
			rsi += 8 * stridesA[ 1 ];
		}
		return S_OK;
	}

	// A partial panel, at the bottom of the first argument matrix
	const size_t remainder = resultSize[ 0 ] - i;
	assert( remainder > 0 && remainder < heightFloats );
	// Zero the whole destination first, the rows missing from the source must read as zeros
	zeroAlignedMemory( rdi, resultStride * length * sizeof( uint16_t ) );

	const size_t completeBlocks = remainder / 8;
	for( size_t block = 0; block < completeBlocks; block++ )
	{
		transpose8( rdi, length, rsi, stridesA[ 1 ], resultStride );
		rdi += 8;
		rsi += 8 * stridesA[ 1 ];
	}

	// Fewer than 8 rows left in the final block
	const size_t lastBlockRows = remainder % 8;
	if( 0 != lastBlockRows )
		transpose8Partial( rdi, length, lastBlockRows, rsi, stridesA[ 1 ], resultStride );
	return S_OK;
}

// Pointer to the first element of row i of the first argument matrix, at index m2 along dimension 2
// and m3 along dimension 3. All offsets are counted in 16-bit elements using stridesA[ 1 .. 3 ].
inline const uint16_t* MulMatBase::getPanelA( size_t i, size_t m2, size_t m3 ) const
{
	const size_t elementOffset = m3 * stridesA[ 3 ] + m2 * stridesA[ 2 ] + i * stridesA[ 1 ];
	const uint16_t* const base = (const uint16_t*)pa;
	return base + elementOffset;
}

// Copy panel number i (8 rows) of the first argument matrix, index m2 along dimension 2 and m3 along
// dimension 3, into rdi using the copyColumnMajor helpers with a destination stride of 8 elements.
// Requires stridesA[ 1 ] == 1 and panelHeightRegisters == 1.
// When fewer than 8 rows remain at the bottom of the matrix, copyColumnMajorPartial handles the short panel.
HRESULT MulMatBase::copyPanelColumnMajor8( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
{
	assert( stridesA[ 1 ] == 1 );
	assert( panelHeightRegisters == 1 );

	// Panel height in elements; also the distance between consecutive destination columns
	constexpr size_t panelRows = 8;
	const size_t firstRow = i * panelRows;
	const uint16_t* const source = getPanelA( firstRow, m2, m3 );

	const bool completePanel = firstRow + panelRows <= resultSize[ 0 ];
	if( !completePanel )
	{
		// The bottom of the matrix cuts this panel short
		const size_t rowsLeft = resultSize[ 0 ] - firstRow;
		assert( rowsLeft > 0 && rowsLeft < panelRows );
		copyColumnMajorPartial( rdi, length, rowsLeft, source, stridesA[ 0 ], panelRows );
		return S_OK;
	}

	copyColumnMajor( rdi, length, source, stridesA[ 0 ], panelRows );
	return S_OK;
}

// Load `len` (1..7) consecutive 16-bit values from x into the low lanes of a 128-bit vector;
// the remaining lanes are zero. Never reads past x[ len - 1 ], so it is safe on the tail of a buffer.
// The source is only guaranteed to be 2-byte aligned. Wider scalar reads therefore go through
// std::memcpy / _mm_loadl_epi64 instead of dereferencing `(const int*)x` or `(const int64_t*)x`,
// which would violate C++ alignment and strict aliasing rules.
// Only SSE2 instructions are used, and nothing here requires a 64-bit build.
__forceinline __m128i load8Partial( const uint16_t* x, size_t len )
{
	assert( len > 0 && len < 8 );
	__m128i ix = _mm_setzero_si128();
	int bits32 = 0;
	switch( len )
	{
	case 1: // load 2 bytes
		ix = _mm_cvtsi32_si128( x[ 0 ] );
		break;
	case 2: // load 4 bytes
		std::memcpy( &bits32, x, sizeof( bits32 ) );
		ix = _mm_cvtsi32_si128( bits32 );
		break;
	case 3: // load 6 bytes
		std::memcpy( &bits32, x, sizeof( bits32 ) );
		ix = _mm_cvtsi32_si128( bits32 );
		ix = _mm_insert_epi16( ix, x[ 2 ], 2 );
		break;
	case 4: // load 8 bytes; movq has no alignment requirement and zeroes the upper half
		ix = _mm_loadl_epi64( reinterpret_cast<const __m128i*>( x ) );
		break;
	case 5: // load 10 bytes
		ix = _mm_loadl_epi64( reinterpret_cast<const __m128i*>( x ) );
		ix = _mm_insert_epi16( ix, x[ 4 ], 4 );
		break;
	case 6: // load 12 bytes: elements [ 0..3 ] in the low half, [ 4..5 ] at the bottom of the high half
		ix = _mm_loadl_epi64( reinterpret_cast<const __m128i*>( x ) );
		std::memcpy( &bits32, x + 4, sizeof( bits32 ) );
		ix = _mm_unpacklo_epi64( ix, _mm_cvtsi32_si128( bits32 ) );
		break;
	case 7: // load 14 bytes
		ix = _mm_loadl_epi64( reinterpret_cast<const __m128i*>( x ) );
		std::memcpy( &bits32, x + 4, sizeof( bits32 ) );
		ix = _mm_unpacklo_epi64( ix, _mm_cvtsi32_si128( bits32 ) );
		ix = _mm_insert_epi16( ix, x[ 6 ], 6 );
		break;
	}
	return ix;
}

// Load `len` (1..15) consecutive 16-bit values from rsi into a 256-bit vector, zero-filling the lanes past `len`.
// Elements 0..7 go into the low 128-bit half, elements 8..14 into the high half.
// NOTE(review): relies on the project helper load16 reading exactly 8 elements (16 bytes) from rsi.
__forceinline __m256i load16Partial( const uint16_t* rsi, size_t len )
{
	assert( len > 0 && len < 16 );

	// Everything fits into the low half
	if( len < 8 )
		return _mm256_setr_m128i( load8Partial( rsi, len ), _mm_setzero_si128() );

	// The low half is complete
	const __m128i lowHalf = load16( (const int*)rsi );
	const size_t tailCount = len - 8;
	const __m128i highHalf = ( 0 == tailCount ) ? _mm_setzero_si128() : load8Partial( rsi + 8, tailCount );
	return _mm256_setr_m128i( lowHalf, highHalf );
}

// Copy panel number i (16 rows) of the first argument matrix, index m2 along dimension 2 and m3 along
// dimension 3, into rdi. Requires stridesA[ 1 ] == 1 and panelHeightRegisters == 2.
// Each of the `length` source columns (stridesA[ 0 ] elements apart) becomes 16 consecutive destination elements.
// The destination uses aligned 32-byte stores, so rdi must be 32-byte aligned.
// A panel cut short by the bottom of the matrix is padded with zeros by load16Partial.
HRESULT MulMatBase::copyPanelColumnMajor16( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
{
	assert( stridesA[ 1 ] == 1 );
	assert( panelHeightRegisters == 2 );

	constexpr size_t panelRows = 16;
	const size_t firstRow = i * panelRows;
	const uint16_t* source = getPanelA( firstRow, m2, m3 );
	const size_t columns = length;

	if( firstRow + panelRows <= resultSize[ 0 ] )
	{
		// Complete panel: one unaligned load and one aligned store per column
		for( size_t c = 0; c < columns; c++, rdi += panelRows, source += stridesA[ 0 ] )
		{
			const __m256i column = _mm256_loadu_si256( reinterpret_cast<const __m256i*>( source ) );
			_mm256_store_si256( reinterpret_cast<__m256i*>( rdi ), column );
		}
		return S_OK;
	}

	// Partial panel at the bottom of the first argument matrix
	const size_t rowsLeft = resultSize[ 0 ] - firstRow;
	assert( rowsLeft > 0 && rowsLeft < panelRows );
	for( size_t c = 0; c < columns; c++, rdi += panelRows, source += stridesA[ 0 ] )
	{
		const __m256i column = load16Partial( source, rowsLeft );
		_mm256_store_si256( reinterpret_cast<__m256i*>( rdi ), column );
	}
	return S_OK;
}

// Copy panel number i (32 rows) of the first argument matrix, index m2 along dimension 2 and m3 along
// dimension 3, into rdi. Requires stridesA[ 1 ] == 1 and panelHeightRegisters == 4.
// Each of the `length` source columns (stridesA[ 0 ] elements apart) becomes 32 consecutive destination elements,
// written as two aligned 32-byte stores; rdi must be 32-byte aligned.
// A panel cut short by the bottom of the matrix is padded with zeros.
HRESULT MulMatBase::copyPanelColumnMajor32( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
{
	assert( stridesA[ 1 ] == 1 );
	assert( panelHeightRegisters == 4 );

	constexpr size_t panelRows = 32;
	const size_t firstRow = i * panelRows;
	const uint16_t* source = getPanelA( firstRow, m2, m3 );
	const size_t columns = length;

	if( firstRow + panelRows <= resultSize[ 0 ] )
	{
		// Complete panel: rows 0-15 and rows 16-31 of every column
		for( size_t c = 0; c < columns; c++, rdi += panelRows, source += stridesA[ 0 ] )
		{
			const __m256i top = _mm256_loadu_si256( reinterpret_cast<const __m256i*>( source ) );
			_mm256_store_si256( reinterpret_cast<__m256i*>( rdi ), top );
			const __m256i bottom = _mm256_loadu_si256( reinterpret_cast<const __m256i*>( source + 16 ) );
			_mm256_store_si256( reinterpret_cast<__m256i*>( rdi + 16 ), bottom );
		}
		return S_OK;
	}

	// Partial panel at the bottom of the first argument matrix
	const size_t remainder = resultSize[ 0 ] - firstRow;
	assert( remainder > 0 && remainder < panelRows );

	// _mm256_setzero_si256 probably compiles into vpxor, that's AVX2, we don't want that here
	const __m256 zero = _mm256_setzero_ps();

	for( size_t c = 0; c < columns; c++, rdi += panelRows, source += stridesA[ 0 ] )
	{
		// Rows 0-15: a full load when at least 16 rows remain, otherwise a zero-padded partial load
		const __m256i top = ( remainder >= 16 )
			? _mm256_loadu_si256( reinterpret_cast<const __m256i*>( source ) )
			: load16Partial( source, remainder );
		_mm256_store_si256( reinterpret_cast<__m256i*>( rdi ), top );

		// Rows 16-31: only present when more than 16 rows remain
		if( remainder > 16 )
			_mm256_store_si256( reinterpret_cast<__m256i*>( rdi + 16 ), load16Partial( source + 16, remainder - 16 ) );
		else
			_mm256_store_ps( reinterpret_cast<float*>( rdi + 16 ), zero );
	}
	return S_OK;
}

// Generic fallback for arbitrary strides: copy panel number i of the first argument matrix
// (index m2 along dimension 2, m3 along dimension 3) into rdi one element at a time.
// Element (row r, column c) of the panel lands at rdi[ c * heightFloats + r ], where
// heightFloats = panelHeightRegisters * 8 and c runs over `length` columns.
// The destination is zero-filled first, so rows past the bottom of the matrix read as zeros.
HRESULT MulMatBase::gatherPanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
{
	// BTW, I never saw this method called.
	const size_t heightFloats = (size_t)panelHeightRegisters * 8;
	const size_t length = this->length;

	zeroAlignedMemory( rdi, length * heightFloats * sizeof( uint16_t ) );

	// `i` is a panel index: convert it into the index of the panel's first row.
	// The row count must be derived from that row index; subtracting the panel index instead
	// made the last, partial panel read rows past the bottom of the matrix.
	const size_t firstRow = i * heightFloats;
	assert( firstRow < resultSize[ 0 ] );
	const size_t height = std::min( heightFloats, resultSize[ 0 ] - firstRow );
	const size_t strideElement = stridesA[ 0 ];
	const size_t strideRow = stridesA[ 1 ];
	const uint16_t* rsi = getPanelA( firstRow, m2, m3 );

	// Both branches write the same result; the inner loop walks whichever source stride is smaller
	if( strideElement < strideRow )
	{
		// Row by row: inner loop steps along a source row
		for( size_t r = 0; r < height; r++, rsi += strideRow, rdi++ )
		{
			const uint16_t* sourceRow = rsi;
			uint16_t* destRow = rdi;
			for( size_t c = 0; c < length; c++, sourceRow += strideElement, destRow += heightFloats )
				*destRow = *sourceRow;
		}
	}
	else
	{
		// Column by column: inner loop steps down a source column
		for( size_t c = 0; c < length; c++, rsi += strideElement, rdi += heightFloats )
		{
			const uint16_t* sourceCol = rsi;
			uint16_t* destCol = rdi;
			for( size_t r = 0; r < height; r++, sourceCol += strideRow, destCol++ )
				*destCol = *sourceCol;
		}
	}
	return S_OK;
}