diff options
| author | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
| commit | 8c4603c73675958efc960fbd4bb599a2909d106a (patch) | |
| tree | 714dc6fc9a1672d5fd7f89676b97e10959662abc /Whisper/CPU/mulMat.cpp | |
| parent | 990a8d0dbaefc996244097397259e92758b15cce (diff) | |
Source codes
Diffstat (limited to 'Whisper/CPU/mulMat.cpp')
| -rw-r--r-- | Whisper/CPU/mulMat.cpp | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/Whisper/CPU/mulMat.cpp b/Whisper/CPU/mulMat.cpp new file mode 100644 index 0000000..1b6ed25 --- /dev/null +++ b/Whisper/CPU/mulMat.cpp @@ -0,0 +1,54 @@ +#include "stdafx.h" +#include "mulMat.h" +#include "mulMatImpl.h" +using namespace CpuCompute; + +namespace +{ + template<uint8_t panelHeightRegs, uint8_t tileWidthFloats> + static HRESULT mulMatImpl( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor ) + { + MulMatImpl<panelHeightRegs, tileWidthFloats> impl{ result, a, b, pfor }; + return impl.run( pfor ); + } +} + +HRESULT CpuCompute::mulMat( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor ) +{ + if( a.type() != eDataType::FP16 ) + return E_NOTIMPL; + if( b.type() != eDataType::FP32 ) + return E_NOTIMPL; + + // return mulMatImpl<1, 1>( result, a, b, pfor ); + + if( b.ne[ 1 ] == 1 ) + { + // Multiplying by a single row + if( a.ne[ 1 ] >= 32 ) + return mulMatImpl<4, 1>( result, a, b, pfor ); + else + return mulMatImpl<1, 1>( result, a, b, pfor ); + } + else if( b.ne[ 1 ] == 2 ) + { + if( a.ne[ 1 ] >= 32 ) + return mulMatImpl<4, 2>( result, a, b, pfor ); + else + return mulMatImpl<1, 2>( result, a, b, pfor ); + } + else if( b.ne[ 1 ] == 3 ) + { + if( a.ne[ 1 ] >= 16 ) + return mulMatImpl<2, 3>( result, a, b, pfor ); + else + return mulMatImpl<1, 3>( result, a, b, pfor ); + } + else + { + if( a.ne[ 1 ] >= 16 ) + return mulMatImpl<2, 4>( result, a, b, pfor ); + else + return mulMatImpl<1, 4>( result, a, b, pfor ); + } +}
\ No newline at end of file |
