summaryrefslogtreecommitdiffstats
path: root/Whisper/CPU/mulMatImpl.h
blob: 8e0062f78161f37902b506b5201237e3b227c676 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#pragma once
// Matrix*matrix multiplication is the most expensive algorithm in the model, by far.
// For this reason, the code in this source file, and in the mulMat.kernel.hpp header, is optimized for performance. Readability suffers.
// The implementation is inspired by the following two articles:
// https://gist.github.com/nadavrot/5b35d44e8ba3dd718e595e40184d03f0
// https://link.springer.com/article/10.1007/s11227-022-05003-3
#include "ParallelForRunner.h"
#include "Tensor.h"

namespace CpuCompute
{
	// Non-template base shared by every matrix multiplication implementation; it exists to reduce binary size.
	// NOTE: data members are declared in a fixed order on purpose, the out-of-line constructor initializes them in that order.
	class MulMatBase : public iComputeRange
	{
	protected:
		// Payload of the output matrix
		float* const resultPointer;

		// Length of every dot product being computed; both source matrices have this width
		uint32_t length;

		// Last 3 strides of the output matrix, counted in elements.
		// The first stride is not stored because it is always 1: the output matrix is continuous.
		std::array<uint32_t, 3> resultStrides;

		// Dimensions of the output matrix
		std::array<uint32_t, 4> resultSize;

		// Payloads of the two source matrices
		const void* const pa;
		const void* const pb;

		// Strides of the two source matrices, counted in elements
		std::array<uint32_t, 4> stridesA, stridesB;

		// How many panels make up one layer of the output matrix.
		// The final panel may be incomplete, i.e. have a smaller height.
		// The thread-local buffer is nevertheless always a complete panel; the unused elements are zeros.
		uint32_t countPanels;

		// How many complete tiles fit in the length of a panel
		uint32_t completeTilesPerPanel;

		// How many remainder columns follow the last complete tile of a panel; may be 0
		uint8_t lastColumnsInPanel;

		// Panel height in AVX vectors, mirrors the panelHeightRegs template argument
		uint8_t panelHeightRegisters;

		// Tile width in floats, mirrors the tileWidthFloats template argument
		uint8_t tileWidth;

		// Pointer to the method which reshapes a panel of the source matrix into a thread-local buffer
		using pfnTransposePanel = HRESULT( MulMatBase::* )( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
		pfnTransposePanel pfnMakePanel;
		// Implements multithreading for this job, and also supplies the memory for the thread-local buffers
		ParallelForRunner& runner;

		// Count of FP16 values in the thread-local panel buffer
		uint32_t floatsPerPanel() const
		{
			// The panel height is stored in AVX vectors; every vector contributes 8 elements
			const uint32_t panelHeightElements = panelHeightRegisters * 8u;
			return length * panelHeightElements;
		}

		// Panel builder for the case when rows of the first matrix are continuous: transposes a horizontal panel
		HRESULT transposePanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
		HRESULT transposePanelAvx2( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
		// Panel builders for column major layout of the first matrix: copy a horizontal panel, no transpose required
		HRESULT copyPanelColumnMajor8( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
		HRESULT copyPanelColumnMajor16( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
		HRESULT copyPanelColumnMajor32( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;
		// Panel builder for irregular layout of the first matrix, when neither rows nor columns are at sequential addresses.
		// Not implemented yet.
		HRESULT gatherPanel( uint16_t* rdi, size_t i, size_t m2, size_t m3 ) const;

		const uint16_t* getPanelA( size_t i, size_t m2, size_t m3 ) const;
		// First element of the second source matrix within the specified layer
		const float* getLayerB( size_t m2, size_t m3 ) const;

		// First element of the output tile of the result matrix, for panel #i of the layer [ m2, m3 ]
		float* getPanelDest( size_t i, size_t m2, size_t m3 ) const
		{
			// Offset of the layer from the start of the output matrix, in elements
			const size_t layerOffset = m2 * resultStrides[ 1 ] + m3 * resultStrides[ 2 ];
			// Offset of the panel from the start of the layer, in elements
			const size_t panelOffset = i * panelHeightRegisters * 8;
			return resultPointer + layerOffset + panelOffset;
		}

		static const bool haveAvx2;
	public:
		MulMatBase( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor, uint8_t panelHeightRegs, uint8_t tileWidthFloats );
		HRESULT run( ParallelForRunner& pfor );
	};

	// The class which actually contains the kernel implementations.
	// Panel height ( in AVX registers ) and tile width ( in floats ) are compile-time template arguments.
	template<uint8_t panelHeightRegs, uint8_t tileWidthFloats>
	class MulMatImpl : public MulMatBase
	{
	public:
		// Passes both template arguments to the base class constructor, along with the tensors and the runner
		MulMatImpl( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor )
			: MulMatBase( result, a, b, pfor, panelHeightRegs, tileWidthFloats )
		{
		}

	private:
		// Override of a virtual method inherited through MulMatBase; defined outside of this header.
		// NOTE(review): presumably processes the panels in the range [ i, end ) - confirm against iComputeRange
		HRESULT __stdcall compute( size_t i, size_t end ) const noexcept override final;
	};
}