summary refs log tree commit diff stats
path: root/Whisper/CPU/mulMatImpl.cpp
diff options
context:
space:
mode:
author Konstantin <const@const.me> 2023-01-16 14:52:43 +0100
committer Konstantin <const@const.me> 2023-01-16 14:52:43 +0100
commit 8c4603c73675958efc960fbd4bb599a2909d106a (patch)
tree 714dc6fc9a1672d5fd7f89676b97e10959662abc /Whisper/CPU/mulMatImpl.cpp
parent 990a8d0dbaefc996244097397259e92758b15cce (diff)
Source codes
Diffstat (limited to 'Whisper/CPU/mulMatImpl.cpp')
-rw-r--r-- Whisper/CPU/mulMatImpl.cpp 213
1 files changed, 213 insertions, 0 deletions
diff --git a/Whisper/CPU/mulMatImpl.cpp b/Whisper/CPU/mulMatImpl.cpp
new file mode 100644
index 0000000..fc50b03
--- /dev/null
+++ b/Whisper/CPU/mulMatImpl.cpp
@@ -0,0 +1,213 @@
+#include "stdafx.h"
+#include <intrin.h>
+#include "mulMatImpl.h"
+#include "mulMat.kernel.hpp"
+
+#define DBG_TRACK_TEMPLATE_INSTANTIATION 0 // Set to 1 to log each distinct ( panelHeightRegs, tileWidthFloats ) pair the first time a MulMatBase is constructed with it
+
+#if DBG_TRACK_TEMPLATE_INSTANTIATION
+#include <unordered_set>
+static std::unordered_set<uint16_t> g_mulMatTemplates; // Pairs already logged, packed as ( panelHeightRegs << 8 ) | tileWidthFloats
+#endif
+
+namespace
+{
+ using namespace CpuCompute;
+
+ // Query the CPU for AVX2 support.
+ // Returns true when CPUID.(EAX=07H, ECX=0):EBX bit 5 is set, false when the bit is clear or when the CPU does not implement leaf 7 at all.
+ bool checkAvx2Support()
+ {
+     int cpuInfo[ 4 ];
+
+     // Leaf 0 reports the highest supported basic leaf in EAX.
+     // Asking for a leaf above that maximum returns data of some other leaf, so bit 5 of EBX would be meaningless.
+     __cpuid( cpuInfo, 0 );
+     if( cpuInfo[ 0 ] < 7 )
+         return false;
+
+     // Leaf 7 is indexed by a sub-leaf in ECX; request sub-leaf 0 explicitly
+     __cpuidex( cpuInfo, 7, 0 );
+     return ( cpuInfo[ 1 ] & ( 1 << 5 ) ) != 0;
+ }
+
+ // a / b, rounded up to the next integer; b must not be zero.
+ // Computed as the quotient plus a carry for a non-zero remainder: the result is exact for every a,
+ // including values close to UINT32_MAX where the classic `( a + ( b - 1 ) ) / b` wraps around.
+ inline uint32_t divRoundUp( uint32_t a, uint32_t b )
+ {
+     assert( b != 0 );
+     const uint32_t quotient = a / b;
+     return ( 0 != a % b ) ? quotient + 1 : quotient;
+ }
+}
+
+const bool MulMatBase::haveAvx2 = checkAvx2Support(); // Initialized from the CPUID query above; the constructor reads it to pick the AVX2 flavour of the panel transpose
+
+// Capture everything a worker thread needs to multiply `a` by `b` into `result`.
+//   result, a, b     destination and source tensors; only their data pointers, sizes and strides are kept
+//   pfor             runner stored in the `runner` field
+//   panelHeightRegs  height of one panel of A, in groups of 8 elements
+//   tileWidthFloats  width of one output tile, in elements
+// Throws E_NOTIMPL when A has unit stride along dimension 1 and panelHeightRegs is not 1, 2 or 4.
+MulMatBase::MulMatBase( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor, uint8_t panelHeightRegs, uint8_t tileWidthFloats ) :
+ resultPointer( result.fp32() ),
+ pa( a.data() ),
+ pb( b.data() ),
+ runner( pfor )
+{
+    // Sizes and strides of the three tensors
+    length = a.ne[ 0 ];
+    for( size_t i = 0; i < 3; i++ )
+        resultStrides[ i ] = result.nb[ i + 1 ];
+    store( resultSize, result.sizeVec() );
+    store( stridesA, a.stridesVec() );
+    store( stridesB, b.stridesVec() );
+
+    // Dimension 0 of the result is cut into panels of panelHeightRegs * 8 elements, the last one possibly incomplete.
+    // Dimension 1 is cut into tiles of tileWidthFloats elements, plus a narrower remainder when the division is inexact.
+    const auto secondDimension = resultSize[ 1 ];
+    panelHeightRegisters = panelHeightRegs;
+    tileWidth = tileWidthFloats;
+    countPanels = divRoundUp( resultSize[ 0 ], panelHeightRegs * 8 );
+    completeTilesPerPanel = secondDimension / tileWidthFloats;
+    lastColumnsInPanel = (uint8_t)( secondDimension % tileWidthFloats );
+
+    // Select the member function which rearranges one panel of A into the layout the kernel consumes,
+    // and remember it in pfnMakePanel; the choice depends on which dimension of A has unit stride.
+    if( a.nb[ 0 ] == 1 )
+    {
+        if( !haveAvx2 )
+            pfnMakePanel = &MulMatBase::transposePanel;
+        else
+            pfnMakePanel = &MulMatBase::transposePanelAvx2;
+    }
+    else if( a.nb[ 1 ] != 1 )
+    {
+        // Neither of the two dimensions is dense.
+        // This version is generic and very simple, unlikely to have weird bugs:
+        // when debugging, assign it unconditionally after this if / else chain.
+        pfnMakePanel = &MulMatBase::gatherPanel;
+    }
+    else if( panelHeightRegs == 1 )
+        pfnMakePanel = &MulMatBase::copyPanelColumnMajor8;
+    else if( panelHeightRegs == 2 )
+        pfnMakePanel = &MulMatBase::copyPanelColumnMajor16;
+    else if( panelHeightRegs == 4 )
+        pfnMakePanel = &MulMatBase::copyPanelColumnMajor32;
+    else
+        throw E_NOTIMPL;
+
+#if DBG_TRACK_TEMPLATE_INSTANTIATION
+    // Print every distinct ( panelHeightRegs, tileWidthFloats ) combination once
+    const uint16_t key = (uint16_t)( ( (uint16_t)panelHeightRegs << 8 ) | tileWidthFloats );
+    const bool firstTime = g_mulMatTemplates.emplace( key ).second;
+    if( firstTime )
+        logDebug( u8"MulMatImpl<panelHeightRegs = %i, tileWidthFloats = %i>", (int)panelHeightRegs, (int)tileWidthFloats );
+#endif
+}
+
+// Run the complete multiplication on the supplied runner, and return whatever status parallelFor() reports.
+// The job count is ( panels per layer ) * ( result size along dimension 2 ) * ( result size along dimension 3 ).
+HRESULT MulMatBase::run( ParallelForRunner& pfor )
+{
+    size_t countJobs = countPanels;
+    countJobs *= resultSize[ 2 ];
+    countJobs *= resultSize[ 3 ];
+    return pfor.parallelFor( *this, countJobs );
+}
+
+// Pointer to the first element of the layer of B at coordinates ( m2, m3 ) along dimensions 2 and 3.
+// The strides are applied to a float pointer, i.e. they count elements and not bytes.
+const float* MulMatBase::getLayerB( size_t m2, size_t m3 ) const
+{
+    const float* const origin = (const float*)this->pb;
+    return origin + m2 * stridesB[ 2 ] + m3 * stridesB[ 3 ];
+}
+
+// This method is the main one, it is called by the thread pool. It processes the jobs [ i, end ): every job index selects one panel of one 2D layer of the result, see the decomposition at the top of the loop. Returns S_OK after the last job.
+template<uint8_t panelHeightRegs, uint8_t tileWidthFloats>
+HRESULT __stdcall MulMatImpl<panelHeightRegs, tileWidthFloats>::compute( size_t i, size_t end ) const noexcept
+{
+ // Allocate a thread-local buffer for the transposed panel; the elements are 16-bit, the `* 2` presumably converts an element count into bytes - TODO confirm against floatsPerPanel() and threadLocalBuffer()
+ constexpr size_t panelHeightFloats = panelHeightRegs * 8;
+ uint16_t* const panel = (uint16_t*)runner.threadLocalBuffer( floatsPerPanel() * 2 );
+ const size_t resultStride = resultStrides[ 0 ];
+
+ // Load a few numbers from this class into local variables, while upcasting from DWORD into size_t; these locals shadow the class fields of the same name for the rest of the method
+ const size_t length = this->length;
+ const std::array<size_t, 2> stridesB{ this->stridesB[ 0 ], this->stridesB[ 1 ] };
+
+ // This outer loop iterates over the panels assigned to the current thread
+ // For example, matrix A of size [ 1024, 1024 ] may be split into panels of size [ 1024, 16 ]
+ // Each iteration of that loop computes matrix product of that panel, with the complete matrix B
+ for( ; i < end; i++ )
+ {
+ const size_t iPanel = i % countPanels; // index of the panel within its layer
+ size_t j = i / countPanels; // index of the layer; j is reused as the tile counter further down
+ const size_t m2 = j % (size_t)resultSize[ 2 ]; // layer coordinate along dimension 2
+ const size_t m3 = j / (size_t)resultSize[ 2 ]; // layer coordinate along dimension 3
+
+ CHECK( ( this->*pfnMakePanel )( panel, iPanel, m2, m3 ) ); // the method was selected by the constructor; CHECK presumably returns a failed HRESULT to the caller - confirm against the macro
+ // We got a column-major panel in the thread local buffer, of size [ length, panelHeightRegs * 8 ]
+ // Hopefully, these buffers should all fit at least in L3 cache
+ // The longest matrix I saw in the debugger had 4096 elements, with panelHeightRegs = 4 that's 256 KB of data in the panel
+ const float* pb = getLayerB( m2, m3 );
+ float* rdi = getPanelDest( iPanel, m2, m3 );
+
+ const size_t storeWidth = std::min( panelHeightFloats, (size_t)resultSize[ 0 ] - iPanel * panelHeightFloats ); // the last panel of a layer may be incomplete, this is the count of valid elements along dimension 0
+ std::array<__m256, panelHeightRegs> vecPanel;
+#if 1 // tiled version; the #else branch below is a debug-only reference
+ ResultTile<panelHeightRegs, tileWidthFloats> tile;
+
+ // This loop iterates over tiles within the panel.
+ // Each iteration of the loop computes an output tile of the result matrix. The loop header also advances pb and rdi by one tile, so after the loop they point at the leftover columns.
+ for( j = 0; j < completeTilesPerPanel; j++, pb += tileWidthFloats * stridesB[ 1 ], rdi += resultStride * tileWidthFloats )
+ {
+ setZero( tile.arr );
+ const uint16_t* rsiA = panel;
+ const uint16_t* const rsiAEnd = panel + length * panelHeightFloats;
+ const float* rsiB = pb;
+ // This loop runs for `length` iterations, iterates over the first dimensions of both matrices, accumulating these dot products we're after
+ for( ; rsiA < rsiAEnd; rsiA += panelHeightFloats, rsiB += stridesB[ 0 ] )
+ {
+ loadPanel( rsiA, vecPanel );
+ tile.kernel( vecPanel, rsiB, stridesB[ 1 ] );
+ }
+ tile.store( rdi, storeWidth, tileWidthFloats, resultStride );
+ }
+
+ if( 0 != lastColumnsInPanel ) // leftover columns which do not fill a complete tile: same loop as above, with the partial kernel and a narrower store
+ {
+ setZero( tile.arr );
+ const uint16_t* rsiA = panel;
+ const uint16_t* rsiAEnd = panel + length * panelHeightFloats;
+ const float* rsiB = pb;
+ for( ; rsiA < rsiAEnd; rsiA += panelHeightFloats, rsiB += stridesB[ 0 ] )
+ {
+ loadPanel( rsiA, vecPanel );
+ tile.kernelPartial( vecPanel, rsiB, stridesB[ 1 ], lastColumnsInPanel );
+ }
+ tile.store( rdi, storeWidth, lastColumnsInPanel, resultStride );
+ }
+#else
+ // This version bypasses horizontal tiling, instead implements a brute force algorithm to multiply the current panel by the complete B matrix
+ // Not terribly efficient, only implemented for debugging purposes
+ const size_t resHeight = resultSize[ 1 ];
+ std::array<__m256, panelHeightRegs> tile;
+ for( size_t j = 0; j < resHeight; j++, pb += stridesB[ 1 ], rdi += resultStride )
+ {
+ setZero( tile );
+
+ const uint16_t* rsiA = panel;
+ const uint16_t* const rsiAEnd = panel + length * panelHeightFloats;
+ const float* rsiB = pb;
+ for( size_t k = 0; k < length; k++, rsiA += panelHeightFloats, rsiB += stridesB[ 0 ] )
+ {
+ loadPanel( rsiA, vecPanel );
+ const __m256 b = _mm256_broadcast_ss( rsiB );
+ for( size_t r = 0; r < panelHeightRegs; r++ )
+ tile[ r ] = _mm256_fmadd_ps( vecPanel[ r ], b, tile[ r ] );
+ }
+
+ alignas( 32 ) std::array<float, panelHeightFloats> arr;
+ for( size_t k = 0; k < panelHeightRegs; k++ )
+ _mm256_store_ps( &arr[ k * 8 ], tile[ k ] );
+ memcpy( rdi, arr.data(), storeWidth * 4 ); // only the valid part of the panel is copied, 4 bytes per float
+ }
+#endif
+ }
+ return S_OK;
+}
+
+// Instantiate the templates we need, as <panelHeightRegs, tileWidthFloats>. compute() is defined in this source file, so other combinations are presumably unavailable to the rest of the project until they are listed here - confirm against mulMatImpl.h
+template class MulMatImpl<4, 1>;
+template class MulMatImpl<1, 1>;
+template class MulMatImpl<4, 2>;
+template class MulMatImpl<1, 2>;
+template class MulMatImpl<2, 3>;
+template class MulMatImpl<1, 3>;
+template class MulMatImpl<2, 4>;
+template class MulMatImpl<1, 4>; \ No newline at end of file