From 8c4603c73675958efc960fbd4bb599a2909d106a Mon Sep 17 00:00:00 2001 From: Konstantin Date: Mon, 16 Jan 2023 14:52:43 +0100 Subject: Source codes --- Whisper/CPU/mulMat.cpp | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 Whisper/CPU/mulMat.cpp (limited to 'Whisper/CPU/mulMat.cpp') diff --git a/Whisper/CPU/mulMat.cpp b/Whisper/CPU/mulMat.cpp new file mode 100644 index 0000000..1b6ed25 --- /dev/null +++ b/Whisper/CPU/mulMat.cpp @@ -0,0 +1,54 @@ +#include "stdafx.h" +#include "mulMat.h" +#include "mulMatImpl.h" +using namespace CpuCompute; + +namespace +{ + template + static HRESULT mulMatImpl( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor ) + { + MulMatImpl impl{ result, a, b, pfor }; + return impl.run( pfor ); + } +} + +HRESULT CpuCompute::mulMat( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor ) +{ + if( a.type() != eDataType::FP16 ) + return E_NOTIMPL; + if( b.type() != eDataType::FP32 ) + return E_NOTIMPL; + + // return mulMatImpl<1, 1>( result, a, b, pfor ); + + if( b.ne[ 1 ] == 1 ) + { + // Multiplying by a single row + if( a.ne[ 1 ] >= 32 ) + return mulMatImpl<4, 1>( result, a, b, pfor ); + else + return mulMatImpl<1, 1>( result, a, b, pfor ); + } + else if( b.ne[ 1 ] == 2 ) + { + if( a.ne[ 1 ] >= 32 ) + return mulMatImpl<4, 2>( result, a, b, pfor ); + else + return mulMatImpl<1, 2>( result, a, b, pfor ); + } + else if( b.ne[ 1 ] == 3 ) + { + if( a.ne[ 1 ] >= 16 ) + return mulMatImpl<2, 3>( result, a, b, pfor ); + else + return mulMatImpl<1, 3>( result, a, b, pfor ); + } + else + { + if( a.ne[ 1 ] >= 16 ) + return mulMatImpl<2, 4>( result, a, b, pfor ); + else + return mulMatImpl<1, 4>( result, a, b, pfor ); + } +} \ No newline at end of file -- cgit v1.2.3