summaryrefslogtreecommitdiffstats
path: root/Whisper/API/sFullParams.h
blob: 42d48a4d5a42027282ab575b701ca022ff8856b5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#pragma once
#include <stdint.h>
#include <assert.h>

namespace Whisper
{
	// Decoding strategies that can be selected through sFullParams::strategy.
	// Stored as an int-sized enum; the numeric values are spelled out explicitly.
	enum struct eSamplingStrategy : int
	{
		// Greedy decoding: at every step pick the single most probable token
		Greedy = 0,
		// Beam search decoding. TODO: not implemented yet!
		BeamSearch = 1,
	};

	// Callback type used for sFullParams::new_segment_callback.
	// The second argument is a count of new segments (NOTE(review): inferred from the name, confirm at the call site);
	// the last argument is the opaque pointer from sFullParams::new_segment_callback_user_data (presumably).
	using pfnNewSegment = HRESULT( __cdecl* )( iContext* context, uint32_t newSegmentsCount, void* userData ) noexcept;

	// Callback type used for sFullParams::encoder_begin_callback.
	// Return S_OK to proceed, or S_FALSE to stop the process and return S_OK from runFull / runStreamed method
	using pfnEncoderBegin = HRESULT( __cdecl* )( iContext* context, void* userData ) noexcept;

	// Bit flags stored in sFullParams::flags. Values are distinct bits, so they can be
	// combined with the operators below.
	enum struct eFullParamsFlags : uint32_t
	{
		Translate = 1,
		NoContext = 2,
		SingleSegment = 4,
		PrintSpecial = 8,
		PrintProgress = 0x10,
		PrintRealtime = 0x20,
		PrintTimestamps = 0x40,

		// Experimental
		TokenTimestamps = 0x100,
		SpeedupAudio = 0x200,
	};

	// Bitwise OR of two flag sets.
	// constexpr so that flag combinations can be formed in constant expressions.
	constexpr eFullParamsFlags operator | ( eFullParamsFlags a, eFullParamsFlags b ) noexcept
	{
		return static_cast<eFullParamsFlags>( static_cast<uint32_t>( a ) | static_cast<uint32_t>( b ) );
	}

	// In-place bitwise OR: adds the bits of `b` to `a`.
	// Returns a reference to `a`, following the usual compound-assignment convention
	// (the previous version returned void, which prevented chaining).
	constexpr eFullParamsFlags& operator |= ( eFullParamsFlags& a, eFullParamsFlags b ) noexcept
	{
		a = a | b;
		return a;
	}

	// Parameter block for a transcription run. Field order and types define the layout
	// of this struct, so they must not be reordered.
	struct sFullParams
	{
		eSamplingStrategy strategy;
		// Number of CPU threads to use
		int cpuThreads;
		int n_max_text_ctx;     // NOTE(review): presumably the maximum count of text-context tokens — confirm
		int offset_ms;          // start offset in ms
		int duration_ms;        // audio duration to process in ms
		eFullParamsFlags flags; // bit set of eFullParamsFlags; query/modify with flag() / setFlag() / resetFlag()
		uint32_t language;      // NOTE(review): presumably a key built with makeLanguageKey() — confirm

		// [EXPERIMENTAL] token-level timestamps
		float thold_pt;         // timestamp token probability threshold (~0.01)
		float thold_ptsum;      // timestamp token sum probability threshold (~0.01)
		int   max_len;          // max segment length in characters
		int   max_tokens;       // max tokens per segment (0 = no limit)

		// Settings for eSamplingStrategy::Greedy
		struct
		{
			int n_past;
		} greedy;

		// Settings for eSamplingStrategy::BeamSearch
		struct
		{
			int n_past;
			int beam_width;
			int n_best;
		} beam_search;

		// [EXPERIMENTAL] speed-up techniques
		int  audio_ctx;         // overwrite the audio context size (0 = use default)

		// Tokens passed to the model as the initial prompt;
		// they are prepended to any text context left over from a previous call
		const whisper_token* prompt_tokens;
		int prompt_n_tokens;

		// Optional new-segment callback and the user pointer that goes with it
		pfnNewSegment new_segment_callback;
		void* new_segment_callback_user_data;

		// Optional encoder-begin callback and the user pointer that goes with it
		pfnEncoderBegin encoder_begin_callback;
		void* encoder_begin_callback_user_data;

		// Helpers for the `flags` field: enum struct values do not support bitwise
		// operations directly, so these convert through the uint32_t underlying type.

		// True when any bit of `f` is set in `flags`
		inline bool flag( eFullParamsFlags f ) const
		{
			const uint32_t mask = (uint32_t)f;
			return ( (uint32_t)flags & mask ) != 0;
		}
		// Clears every bit of `bit` in `flags`
		inline void resetFlag( eFullParamsFlags bit )
		{
			setFlag( bit, false );
		}
		// Sets (set == true) or clears (set == false) every bit of `bit` in `flags`
		inline void setFlag( eFullParamsFlags bit, bool set = true )
		{
			const uint32_t mask = (uint32_t)bit;
			const uint32_t current = (uint32_t)flags;
			flags = (eFullParamsFlags)( set ? ( current | mask ) : ( current & ~mask ) );
		}
	};

	// Start and end timestamps of one segment.
	// NOTE(review): the time unit is not visible in this header — confirm at the producer.
	struct sSegmentTime
	{
		int64_t begin;
		int64_t end;
	};

	// Packs a language code of at most 4 characters (e.g. "en") into a uint32_t,
	// little-endian: the first character goes into the lowest byte, unused bytes stay zero.
	// @param code  NUL-terminated string of at most 4 characters; must not be null.
	// @return      the packed key; "" yields 0.
	// constexpr so keys can also be computed at compile time.
	constexpr uint32_t makeLanguageKey( const char* code )
	{
		assert( code != nullptr );
		uint32_t res = 0;
		for( uint32_t i = 0; i < 4; i++ )
		{
			const uint8_t c = (uint8_t)code[ i ];
			if( c == 0 )
				return res;
			res |= (uint32_t)c << ( i * 8 );
		}
		// Four non-NUL characters were consumed: the code must end right here.
		// Same condition as the former `strlen( code ) <= 4` check, but it needs no <string.h>
		// (which this header never included) and does not walk the rest of a long string.
		assert( code[ 4 ] == '\0' );
		return res;
	}

	// Progress-reporting callback type.
	// NOTE(review): the double is presumably a completion fraction — confirm the range at the call site.
	using pfnReportProgress = HRESULT( __stdcall* )( double progress, iContext* context, void* userData ) noexcept;

	// A progress callback bundled with its opaque user pointer
	struct sProgressSink
	{
		// Callback to invoke
		pfnReportProgress pfn;
		// Opaque pointer; presumably handed back to `pfn` as its last argument — confirm
		void* pv;
	};
}