summaryrefslogtreecommitdiffstats
path: root/Whisper/CPU/DecoderTensors.cpp
diff options
context:
space:
mode:
author Konstantin <const@const.me> 2023-01-16 14:52:43 +0100
committer Konstantin <const@const.me> 2023-01-16 14:52:43 +0100
commit 8c4603c73675958efc960fbd4bb599a2909d106a (patch)
tree 714dc6fc9a1672d5fd7f89676b97e10959662abc /Whisper/CPU/DecoderTensors.cpp
parent 990a8d0dbaefc996244097397259e92758b15cce (diff)
Source codes
Diffstat (limited to 'Whisper/CPU/DecoderTensors.cpp')
-rw-r--r-- Whisper/CPU/DecoderTensors.cpp 68
1 files changed, 68 insertions, 0 deletions
diff --git a/Whisper/CPU/DecoderTensors.cpp b/Whisper/CPU/DecoderTensors.cpp
new file mode 100644
index 0000000..22de476
--- /dev/null
+++ b/Whisper/CPU/DecoderTensors.cpp
@@ -0,0 +1,68 @@
+#include "stdafx.h"
+#include "DecoderTensors.h"
+using namespace CpuCompute;
+
+#if TENSOR_GGML_COMPAT
+namespace
+{
+ // Helper which fills a caller-owned vector with ggml_tensor structures, one per add() call,
+ // and hands out pointers to the elements of that vector.
+ // The vector is resized exactly once, in the constructor, so the pointers returned by add()
+ // stay valid for as long as the caller does not resize the vector afterwards.
+ class CompatContext
+ {
+     // Destination storage, owned by the caller
+     std::vector<ggml_tensor>& vec;
+     // Index of the next unused element of `vec`
+     size_t index = 0;
+
+ public:
+     // Resize `dest` to the count of tensors expected for a model with the specified count of layers
+     CompatContext( std::vector<ggml_tensor>& dest, size_t layers ) :
+         vec( dest )
+     {
+         // Both constants must stay in sync with the add() / add2() calls made by the user of this class;
+         // isComplete() is there to verify that after the fact
+         constexpr size_t tensorsPerLayer = 21;
+         constexpr size_t tensorsOutsideLayers = 4;
+         vec.resize( tensorsPerLayer * layers + tensorsOutsideLayers );
+     }
+
+     // Store the result of rsi.ggml() in the next unused element, and write the address of that element into `res`
+     void add( const Tensor& rsi, ggml_tensor*& res )
+     {
+         // std::vector::operator[] is unchecked: catch an overflow before the write, not after it
+         assert( index < vec.size() );
+         ggml_tensor& slot = vec[ index++ ];
+         slot = rsi.ggml();
+         res = &slot;
+     }
+
+     // Add both tensors of the pair: `w` takes the first element, `b` the one right after
+     void add2( const TensorPair& rsi, ggml_tensor*& w, ggml_tensor*& b )
+     {
+         add( rsi.w, w );
+         add( rsi.b, b );
+     }
+
+     // True when the count of add() calls so far is equal to the count of elements allocated by the constructor
+     bool isComplete() const
+     {
+         return vec.size() == index;
+     }
+ };
+}
+
+// Rebuild the `ggml` vector with one ggml_tensor per tensor of the decoder,
+// and point the ggml_tensor* fields of this object and of every layer to the elements of that vector.
+void DecoderTensors::makeCompatTensors()
+{
+    const size_t countLayers = layers.size();
+    CompatContext compat( ggml, countLayers );
+
+    // 4 tensors which do not belong to any layer
+    compat.add( positionalEmbedding, d_pe );
+    compat.add( tokenEmbedding, d_te );
+    compat.add2( ln, d_ln_w, d_ln_b );
+
+    // 21 tensors for every layer: ten weight + bias pairs, and the attention key which only has the weight
+    for( auto& layer : layers )
+    {
+        // Self attention
+        compat.add2( layer.attnLn0, layer.attn_ln_0_w, layer.attn_ln_0_b );
+        compat.add2( layer.attnLn1, layer.attn_ln_1_w, layer.attn_ln_1_b );
+        compat.add2( layer.attnQuery, layer.attn_q_w, layer.attn_q_b );
+        compat.add( layer.attnKey, layer.attn_k_w );
+        compat.add2( layer.attnValue, layer.attn_v_w, layer.attn_v_b );
+
+        // Cross attention
+        compat.add2( layer.crossAttnLn0, layer.cross_attn_ln_0_w, layer.cross_attn_ln_0_b );
+        compat.add2( layer.crossAttnLn1, layer.cross_attn_ln_1_w, layer.cross_attn_ln_1_b );
+        compat.add2( layer.crossAttnQuery, layer.cross_attn_q_w, layer.cross_attn_q_b );
+
+        // MLP
+        compat.add2( layer.mlpLn, layer.mlp_ln_w, layer.mlp_ln_b );
+        compat.add2( layer.mlp0, layer.mlp_0_w, layer.mlp_0_b );
+        compat.add2( layer.mlp1, layer.mlp_1_w, layer.mlp_1_b );
+    }
+
+    // Every element allocated by the context must have been consumed by the calls above
+    assert( compat.isComplete() );
+}
+#endif
\ No newline at end of file