summaryrefslogtreecommitdiffstats
path: root/Whisper/CPU/mulMat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'Whisper/CPU/mulMat.cpp')
-rw-r--r--Whisper/CPU/mulMat.cpp54
1 files changed, 54 insertions, 0 deletions
diff --git a/Whisper/CPU/mulMat.cpp b/Whisper/CPU/mulMat.cpp
new file mode 100644
index 0000000..1b6ed25
--- /dev/null
+++ b/Whisper/CPU/mulMat.cpp
@@ -0,0 +1,54 @@
+#include "stdafx.h"
+#include "mulMat.h"
+#include "mulMatImpl.h"
+using namespace CpuCompute;
+
+namespace
+{
+ template<uint8_t panelHeightRegs, uint8_t tileWidthFloats>
+ static HRESULT mulMatImpl( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor )
+ {
+ MulMatImpl<panelHeightRegs, tileWidthFloats> impl{ result, a, b, pfor };
+ return impl.run( pfor );
+ }
+}
+
+HRESULT CpuCompute::mulMat( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor )
+{
+ if( a.type() != eDataType::FP16 )
+ return E_NOTIMPL;
+ if( b.type() != eDataType::FP32 )
+ return E_NOTIMPL;
+
+ // return mulMatImpl<1, 1>( result, a, b, pfor );
+
+ if( b.ne[ 1 ] == 1 )
+ {
+ // Multiplying by a single row
+ if( a.ne[ 1 ] >= 32 )
+ return mulMatImpl<4, 1>( result, a, b, pfor );
+ else
+ return mulMatImpl<1, 1>( result, a, b, pfor );
+ }
+ else if( b.ne[ 1 ] == 2 )
+ {
+ if( a.ne[ 1 ] >= 32 )
+ return mulMatImpl<4, 2>( result, a, b, pfor );
+ else
+ return mulMatImpl<1, 2>( result, a, b, pfor );
+ }
+ else if( b.ne[ 1 ] == 3 )
+ {
+ if( a.ne[ 1 ] >= 16 )
+ return mulMatImpl<2, 3>( result, a, b, pfor );
+ else
+ return mulMatImpl<1, 3>( result, a, b, pfor );
+ }
+ else
+ {
+ if( a.ne[ 1 ] >= 16 )
+ return mulMatImpl<2, 4>( result, a, b, pfor );
+ else
+ return mulMatImpl<1, 4>( result, a, b, pfor );
+ }
+} \ No newline at end of file