summaryrefslogtreecommitdiffstats
path: root/Whisper/CPU/mulMatImpl.panel.cpp
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-01-16 14:52:43 +0100
committerKonstantin <const@const.me>2023-01-16 14:52:43 +0100
commit8c4603c73675958efc960fbd4bb599a2909d106a (patch)
tree714dc6fc9a1672d5fd7f89676b97e10959662abc /Whisper/CPU/mulMatImpl.panel.cpp
parent990a8d0dbaefc996244097397259e92758b15cce (diff)
Source codes
Diffstat (limited to 'Whisper/CPU/mulMatImpl.panel.cpp')
-rw-r--r--Whisper/CPU/mulMatImpl.panel.cpp274
1 files changed, 274 insertions, 0 deletions
diff --git a/Whisper/CPU/mulMatImpl.panel.cpp b/Whisper/CPU/mulMatImpl.panel.cpp
new file mode 100644
index 0000000..f3baf21
--- /dev/null
+++ b/Whisper/CPU/mulMatImpl.panel.cpp
@@ -0,0 +1,274 @@
+#include "stdafx.h"
+#include <intrin.h>
+#include "mulMatImpl.h"
+#include "mulMatUtils.hpp"
+using namespace CpuCompute;
+
+// We want to keep code size reasonable, that's why these panel reshaping methods are in the base class
+HRESULT MulMatBase::transposePanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
+{
+ assert( stridesA[ 0 ] == 1 );
+
+ const size_t heightFloats = (size_t)panelHeightRegisters * 8;
+ i *= heightFloats;
+
+ const uint16_t* rsi = (const uint16_t*)pa;
+ rsi += m3 * stridesA[ 3 ];
+ rsi += m2 * stridesA[ 2 ];
+ rsi += i * stridesA[ 1 ];
+
+ const size_t resultStride = heightFloats;
+
+ if( i + heightFloats <= resultSize[ 0 ] )
+ {
+ // A complete panel
+ for( size_t i = 0; i < panelHeightRegisters; i++ )
+ {
+ transpose8( rdi, length, rsi, stridesA[ 1 ], resultStride );
+ // Advance by 8 floats in the output buffer
+ rdi += 8;
+ // Advance by 8 rows in the source matrix
+ rsi += 8 * stridesA[ 1 ];
+ }
+ }
+ else
+ {
+ // A partial panel, at the bottom of the first argument matrix
+ const size_t remainder = resultSize[ 0 ] - i;
+ assert( remainder > 0 && remainder < heightFloats );
+ zeroAlignedMemory( rdi, resultStride * length * sizeof( uint16_t ) );
+
+ const size_t completePanels = remainder / 8;
+ for( size_t i = 0; i < completePanels; i++ )
+ {
+ transpose8( rdi, length, rsi, stridesA[ 1 ], resultStride );
+ rdi += 8;
+ rsi += 8 * stridesA[ 1 ];
+ }
+ const size_t lastPanel = remainder % 8;
+ if( 0 != lastPanel )
+ transpose8Partial( rdi, length, lastPanel, rsi, stridesA[ 1 ], resultStride );
+ }
+ return S_OK;
+}
+
+inline const uint16_t* MulMatBase::getPanelA( size_t i, size_t m2, size_t m3 ) const
+{
+ const uint16_t* rsi = (const uint16_t*)pa;
+ rsi += m3 * stridesA[ 3 ];
+ rsi += m2 * stridesA[ 2 ];
+ rsi += i * stridesA[ 1 ];
+ return rsi;
+}
+
+HRESULT MulMatBase::copyPanelColumnMajor8( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
+{
+ assert( stridesA[ 1 ] == 1 );
+ assert( panelHeightRegisters == 1 );
+
+ constexpr size_t heightFloats = 8;
+ i *= heightFloats;
+ const uint16_t* rsi = getPanelA( i, m2, m3 );
+
+ constexpr size_t resultStride = heightFloats;
+
+ if( i + heightFloats <= resultSize[ 0 ] )
+ {
+ // A complete panel, height = 8 elements
+ copyColumnMajor( rdi, length, rsi, stridesA[ 0 ], resultStride );
+ }
+ else
+ {
+ // A partial panel, at the bottom of the first argument matrix
+ const size_t remainder = resultSize[ 0 ] - i;
+ assert( remainder > 0 && remainder < heightFloats );
+ copyColumnMajorPartial( rdi, length, remainder, rsi, stridesA[ 0 ], resultStride );
+ }
+ return S_OK;
+}
+
+__forceinline __m128i load8Partial( const uint16_t* x, size_t len )
+{
+ assert( len > 0 && len < 8 );
+ __m128i ix = _mm_setzero_si128();
+ switch( len )
+ {
+ case 1: // load 2 bytes
+ ix = _mm_cvtsi32_si128( *x );
+ break;
+ case 2: // load 4 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ break;
+ case 3: // load 6 bytes
+ ix = _mm_cvtsi32_si128( *(const int*)x );
+ ix = _mm_insert_epi16( ix, x[ 2 ], 2 );
+ break;
+ case 4: // load 8 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ break;
+ case 5: // load 10 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ ix = _mm_insert_epi16( ix, x[ 4 ], 4 );
+ break;
+ case 6: // load 12 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ break;
+ case 7: // load 14 bytes
+ ix = _mm_cvtsi64_si128( *(const int64_t*)x );
+ ix = _mm_insert_epi32( ix, *(const int*)( x + 4 ), 2 );
+ ix = _mm_insert_epi16( ix, x[ 6 ], 6 );
+ break;
+ }
+ return ix;
+}
+
+__forceinline __m256i load16Partial( const uint16_t* rsi, size_t len )
+{
+ assert( len > 0 && len < 16 );
+
+ if( len < 8 )
+ {
+ __m128i low = load8Partial( rsi, len );
+ return _mm256_setr_m128i( low, _mm_setzero_si128() );
+ }
+ else if( len > 8 )
+ {
+ __m128i low = load16( (const int*)rsi );
+ __m128i high = load8Partial( rsi + 8, len - 8 );
+ return _mm256_setr_m128i( low, high );
+ }
+ else
+ {
+ __m128i low = load16( (const int*)rsi );
+ return _mm256_setr_m128i( low, _mm_setzero_si128() );
+ }
+}
+
+HRESULT MulMatBase::copyPanelColumnMajor16( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
+{
+ assert( stridesA[ 1 ] == 1 );
+ assert( panelHeightRegisters == 2 );
+
+ constexpr size_t heightFloats = 16;
+ i *= heightFloats;
+
+ const uint16_t* rsi = getPanelA( i, m2, m3 );
+ uint16_t* const rdiEnd = rdi + 16 * length;
+
+ if( i + heightFloats <= resultSize[ 0 ] )
+ {
+ // A complete panel, height = 16 elements
+ for( ; rdi < rdiEnd; rdi += 16, rsi += stridesA[ 0 ] )
+ {
+ __m256i v = _mm256_loadu_si256( ( const __m256i* )rsi );
+ _mm256_store_si256( ( __m256i* )rdi, v );
+ }
+ }
+ else
+ {
+ // A partial panel, at the bottom of the first argument matrix
+ const size_t remainder = resultSize[ 0 ] - i;
+ assert( remainder > 0 && remainder < heightFloats );
+
+ for( ; rdi < rdiEnd; rdi += 16, rsi += stridesA[ 0 ] )
+ {
+ __m256i v = load16Partial( rsi, remainder );
+ _mm256_store_si256( ( __m256i* )rdi, v );
+ }
+ }
+ return S_OK;
+}
+
+HRESULT MulMatBase::copyPanelColumnMajor32( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
+{
+ assert( stridesA[ 1 ] == 1 );
+ assert( panelHeightRegisters == 4 );
+
+ constexpr size_t heightFloats = 32;
+ i *= heightFloats;
+
+ const uint16_t* rsi = getPanelA( i, m2, m3 );
+ uint16_t* const rdiEnd = rdi + 32 * length;
+
+ if( i + heightFloats <= resultSize[ 0 ] )
+ {
+ // A complete panel, height = 32 elements
+ for( ; rdi < rdiEnd; rdi += 32, rsi += stridesA[ 0 ] )
+ {
+ __m256i v = _mm256_loadu_si256( ( const __m256i* )rsi );
+ _mm256_store_si256( ( __m256i* )rdi, v );
+ v = _mm256_loadu_si256( ( const __m256i* )( rsi + 16 ) );
+ _mm256_store_si256( ( __m256i* )( rdi + 16 ), v );
+ }
+ }
+ else
+ {
+ // A partial panel, at the bottom of the first argument matrix
+ const size_t remainder = resultSize[ 0 ] - i;
+ assert( remainder > 0 && remainder < heightFloats );
+
+ // _mm256_setzero_si256 probably compiles into vpxor, that's AVX2, we don't want that here
+ const __m256 zero = _mm256_setzero_ps();
+
+ for( ; rdi < rdiEnd; rdi += 32, rsi += stridesA[ 0 ] )
+ {
+ if( remainder < 16 )
+ {
+ __m256i v = load16Partial( rsi, remainder );
+ _mm256_store_si256( ( __m256i* )rdi, v );
+ _mm256_store_ps( (float*)( rdi + 16 ), zero );
+ }
+ else if( remainder > 16 )
+ {
+ __m256i v = _mm256_loadu_si256( ( const __m256i* )rsi );
+ _mm256_store_si256( ( __m256i* )rdi, v );
+ v = load16Partial( rsi + 16, remainder - 16 );
+ _mm256_store_si256( ( __m256i* )( rdi + 16 ), v );
+ }
+ else
+ {
+ __m256i v = _mm256_loadu_si256( ( const __m256i* )rsi );
+ _mm256_store_si256( ( __m256i* )rdi, v );
+ _mm256_store_ps( (float*)( rdi + 16 ), zero );
+ }
+ }
+ }
+ return S_OK;
+}
+
+HRESULT MulMatBase::gatherPanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const
+{
+ // BTW, I never saw this method called.
+ const size_t heightFloats = (size_t)panelHeightRegisters * 8;
+ const size_t length = this->length;
+
+ zeroAlignedMemory( rdi, length * heightFloats * sizeof( uint16_t ) );
+
+ const size_t height = std::min( heightFloats, resultSize[ 0 ] - i );
+ const size_t strideElement = stridesA[ 0 ];
+ const size_t strideRow = stridesA[ 1 ];
+ const uint16_t* rsi = getPanelA( i * heightFloats, m2, m3 );
+
+ if( strideElement < strideRow )
+ {
+ for( size_t r = 0; r < height; r++, rsi += strideRow, rdi++ )
+ {
+ const uint16_t* sourceRow = rsi;
+ uint16_t* destRow = rdi;
+ for( size_t c = 0; c < length; c++, sourceRow += strideElement, destRow += heightFloats )
+ *destRow = *sourceRow;
+ }
+ }
+ else
+ {
+ for( size_t c = 0; c < length; c++, rsi += strideElement, rdi += heightFloats )
+ {
+ const uint16_t* sourceCol = rsi;
+ uint16_t* destCol = rdi;
+ for( size_t r = 0; r < height; r++, sourceCol += strideRow, destCol++ )
+ *destCol = *sourceCol;
+ }
+ }
+ return S_OK;
+} \ No newline at end of file