summaryrefslogtreecommitdiffstats
path: root/Whisper/CPU/DecoderTensors.h
blob: 2efa519a112abeb3132c9e47597512b3c67a44fe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#pragma once
#include <vector>
#include "Tensor.h"
#include "LargeBuffer.h"
#if TENSOR_GGML_COMPAT
#include "../source/ggml.h"
#endif

namespace CpuCompute
{
	// A set of tensors for one decoder layer.
	// The comment above each field is the `decoder.blocks.*` tensor name that field corresponds to.
	// NOTE(review): fields declared as TensorPair have both a `_w` and a `_b` pointer in the GGML compatibility
	// section below, while the plain Tensor field (attnKey) only has `_w` - presumably weight + bias versus
	// weight only; confirm against Tensor.h.
	struct LayerDecoder
	{
		// decoder.blocks.*.attn_ln
		TensorPair attnLn0;
		// decoder.blocks.*.attn.out
		TensorPair attnLn1;
		// decoder.blocks.*.attn.query
		TensorPair attnQuery;
		// decoder.blocks.*.attn.key
		Tensor attnKey;
		// decoder.blocks.*.attn.value
		TensorPair attnValue;
		// decoder.blocks.*.cross_attn_ln
		TensorPair crossAttnLn0;
		// decoder.blocks.*.cross_attn.out
		TensorPair crossAttnLn1;
		// decoder.blocks.*.cross_attn.query
		TensorPair crossAttnQuery;

		// The cross-attention key and value tensors are intentionally left commented out in this struct.
		// decoder.blocks.*.cross_attn.key
		// Tensor crossAttnKey;
		// decoder.blocks.*.cross_attn.value
		// TensorPair crossAttnValue;

		// decoder.blocks.*.mlp_ln
		TensorPair mlpLn;
		// decoder.blocks.*.mlp.0
		TensorPair mlp0;
		// decoder.blocks.*.mlp.2
		TensorPair mlp1;

#if TENSOR_GGML_COMPAT
		// GGML compatibility pointers, only compiled when TENSOR_GGML_COMPAT is enabled.
		// They default to nullptr so that an instance which has not been wired up yet
		// never exposes an indeterminate pointer value.

		// decoder.blocks.*.attn_ln
		ggml_tensor* attn_ln_0_w = nullptr;
		ggml_tensor* attn_ln_0_b = nullptr;

		// decoder.blocks.*.attn.out
		ggml_tensor* attn_ln_1_w = nullptr;
		ggml_tensor* attn_ln_1_b = nullptr;

		// decoder.blocks.*.attn.query
		ggml_tensor* attn_q_w = nullptr;
		ggml_tensor* attn_q_b = nullptr;

		// decoder.blocks.*.attn.key
		ggml_tensor* attn_k_w = nullptr;

		// decoder.blocks.*.attn.value
		ggml_tensor* attn_v_w = nullptr;
		ggml_tensor* attn_v_b = nullptr;

		// decoder.blocks.*.cross_attn_ln
		ggml_tensor* cross_attn_ln_0_w = nullptr;
		ggml_tensor* cross_attn_ln_0_b = nullptr;

		// decoder.blocks.*.cross_attn.out
		ggml_tensor* cross_attn_ln_1_w = nullptr;
		ggml_tensor* cross_attn_ln_1_b = nullptr;

		// decoder.blocks.*.cross_attn.query
		ggml_tensor* cross_attn_q_w = nullptr;
		ggml_tensor* cross_attn_q_b = nullptr;

		// decoder.blocks.*.mlp_ln
		ggml_tensor* mlp_ln_w = nullptr;
		ggml_tensor* mlp_ln_b = nullptr;

		// decoder.blocks.*.mlp.0
		ggml_tensor* mlp_0_w = nullptr;
		ggml_tensor* mlp_0_b = nullptr;

		// decoder.blocks.*.mlp.2
		ggml_tensor* mlp_1_w = nullptr;
		ggml_tensor* mlp_1_b = nullptr;
#endif
	};

	// The complete set of decoder tensors: the top-level tensors plus a vector of per-layer tensors.
	// The struct also holds the LargeBuffer handed over through setMemoryBuffer().
	struct DecoderTensors
	{
		// decoder.positional_embedding
		Tensor positionalEmbedding;

		// decoder.token_embedding
		Tensor tokenEmbedding;

		// decoder.ln
		TensorPair ln;
		// A vector of layers
		std::vector<LayerDecoder> layers;

		// Take ownership of the supplied buffer by moving it into the private `memory` field.
		// When TENSOR_GGML_COMPAT is enabled, makeCompatTensors() is called right after the move.
		// NOTE(review): makeCompatTensors() is declared without noexcept; because this method is noexcept,
		// an exception escaping from it would call std::terminate. The definition is not visible in this
		// header, so confirm it cannot throw or that terminating is acceptable in compat builds.
		void setMemoryBuffer( LargeBuffer&& mem ) noexcept
		{
			memory = std::move( mem );
#if TENSOR_GGML_COMPAT
			makeCompatTensors();
#endif
		}

#if TENSOR_GGML_COMPAT
		// Defined elsewhere; called by setMemoryBuffer() after the buffer has been stored.
		// Presumably fills in the ggml_tensor* compatibility pointers - confirm against the definition.
		void makeCompatTensors();

		// GGML compatibility pointers. They default to nullptr so they are never indeterminate
		// before makeCompatTensors() has run.

		// decoder.positional_embedding
		ggml_tensor* d_pe = nullptr; // DD

		// decoder.token_embedding
		ggml_tensor* d_te = nullptr; // DD

		// decoder.ln
		ggml_tensor* d_ln_w = nullptr; // DD
		ggml_tensor* d_ln_b = nullptr; // DD
#endif

	private:
		// A smart pointer which owns the memory for all the above tensors
		LargeBuffer memory;
#if TENSOR_GGML_COMPAT
		// Storage of ggml_tensor structures.
		// Presumably the backing store the compatibility pointers refer to - TODO confirm in makeCompatTensors().
		std::vector<ggml_tensor> ggml;
#endif
	};
}