summaryrefslogtreecommitdiffstats
path: root/Whisper/CPU/DecoderTensors.cpp
blob: 22de476bbdf3ad923143b9c09563e651df383f98 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include "stdafx.h"
#include "DecoderTensors.h"
using namespace CpuCompute;

#if TENSOR_GGML_COMPAT
namespace
{
	class CompatContext
	{
		std::vector<ggml_tensor>& vec;
		size_t index;

	public:
		CompatContext( std::vector<ggml_tensor>& dest, size_t layers ) :
			vec( dest )
		{
			constexpr size_t tensorsPerLayer = 21;
			const size_t count = tensorsPerLayer * layers + 4;
			vec.resize( count );
			index = 0;
		}

		void add( const Tensor& rsi, ggml_tensor*& res )
		{
			ggml_tensor& ten = vec[ index ];
			index++;
			ten = rsi.ggml();
			res = &ten;
		}

		void add2( const TensorPair& rsi, ggml_tensor*& w, ggml_tensor*& b )
		{
			add( rsi.w, w );
			add( rsi.b, b );
		}

		bool isComplete() const
		{
			return index == vec.size();
		}
	};
}

void DecoderTensors::makeCompatTensors()
{
	CompatContext ctx( ggml, layers.size() );

	ctx.add( positionalEmbedding, d_pe );
	ctx.add( tokenEmbedding, d_te );
	ctx.add2( ln, d_ln_w, d_ln_b );

	for( auto& i : layers )
	{
		ctx.add2( i.attnLn0, i.attn_ln_0_w, i.attn_ln_0_b );
		ctx.add2( i.attnLn1, i.attn_ln_1_w, i.attn_ln_1_b );
		ctx.add2( i.attnQuery, i.attn_q_w, i.attn_q_b );
		ctx.add( i.attnKey, i.attn_k_w );
		ctx.add2( i.attnValue, i.attn_v_w, i.attn_v_b );
		ctx.add2( i.crossAttnLn0, i.cross_attn_ln_0_w, i.cross_attn_ln_0_b );
		ctx.add2( i.crossAttnLn1, i.cross_attn_ln_1_w, i.cross_attn_ln_1_b );
		ctx.add2( i.crossAttnQuery, i.cross_attn_q_w, i.cross_attn_q_b );
		ctx.add2( i.mlpLn, i.mlp_ln_w, i.mlp_ln_b );
		ctx.add2( i.mlp0, i.mlp_0_w, i.mlp_0_b );
		ctx.add2( i.mlp1, i.mlp_1_w, i.mlp_1_b );
	}
	assert( ctx.isComplete() );
}
#endif