summaryrefslogtreecommitdiffstats
path: root/WhisperNet/Context.cs
blob: 3170b893123884ee06aa16c1499b66a8a75f1070 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
using System.Diagnostics;
using Whisper.Internal;
using Whisper.Internals;

namespace Whisper
{
	/// <summary>Stateful context, contains methods to transcribe audio</summary>
	/// <summary>Stateful context, contains methods to transcribe audio</summary>
	/// <remarks>The <c>run*</c> methods temporarily store callbacks and prompt tokens in per-instance state for the duration of the call.
	/// NOTE(review): nothing here synchronizes concurrent calls on the same instance — presumably one call at a time; confirm.</remarks>
	public sealed class Context: IDisposable
	{
		// The native context; only assigned in the constructor
		readonly iContext context;
		// Caching the results object here saves time spent in ComLight library creating these callable proxies over and over again for the same underlying C++ object
		readonly iTranscribeResult transcribeResult;
		// Parameters passed to the native code on every run; the public part is exposed by the `parameters` property
		sFullParams fullParams;
		// Progress sink for streamed transcription; its `pfn` is only set while the streaming runFull overload executes
		sProgressSink progressSink;
		bool disposed = false;
		// Method-group delegates created once in the constructor, so that runImpl does not allocate a new delegate on every call
		readonly Action<object> pfnBuffer, pfnStream;

		internal Context( Internal.iContext context )
		{
			this.context = context;
			transcribeResult = context.getResults( eResultFlags.None );
			fullParams = context.fullDefaultParams( eSamplingStrategy.Greedy );
			pfnBuffer = processBuffer;
			pfnStream = processStream;
			progressSink = default;
		}

		/// <summary>Release the native context.</summary>
		/// <remarks>Calling this more than once is a no-op. After disposal, the other methods of this class throw <see cref="ObjectDisposedException" />.</remarks>
		void IDisposable.Dispose()
		{
			if( disposed )
				return;
			disposed = true;
			context?.Dispose();
			GC.SuppressFinalize( this );
		}

		/// <summary>Throw <see cref="ObjectDisposedException" /> when this object has already been disposed, instead of calling into a released native context</summary>
		void throwIfDisposed()
		{
			if( disposed )
				throw new ObjectDisposedException( nameof( Context ) );
		}

		/// <summary>Adjustable parameters</summary>
		/// <remarks>Returns a reference into the parameters structure which is passed to the native code by the <c>runFull</c> and <c>runCapture</c> methods.</remarks>
		public ref Parameters parameters => ref fullParams.publicParams;

		void processBuffer( object buffer )
		{
			context.runFull( ref fullParams, (iAudioBuffer)buffer );
		}
		void processStream( object reader )
		{
			context.runStreamed( ref fullParams, ref progressSink, (iAudioReader)reader );
		}

		/// <summary>Route the native new-segment and encoder-begin callbacks of <see cref="fullParams" /> to the supplied callbacks object</summary>
		void installCallbacks( Callbacks callbacks )
		{
			// TODO [very low, performance]: the following code creates 2 new GC-allocated objects on each call.
			// Possible to optimize by caching these function pointers in static readonly fields, and use another [ThreadStatic] field for the callbacks object
			fullParams.newSegmentCallback = delegate ( IntPtr ctx, int countNew, IntPtr userData )
			{
				return callbacks.newSegment( this, countNew );
			};

			fullParams.encoderBeginCallback = delegate ( IntPtr ctx, IntPtr userData )
			{
				return callbacks.encoderBegin( this );
			};
		}

		/// <summary>Clear the callbacks and the prompt tokens which were stored in <see cref="fullParams" /> for the duration of a single run</summary>
		void resetRunState()
		{
			// Reset these delegates.
			// Otherwise, this class will retain the callbacks object preventing it from being garbage collected.
			fullParams.newSegmentCallback = null;
			fullParams.encoderBeginCallback = null;

			// The prompt tokens pointer is only valid while the source span is pinned by the `fixed` statement in runImpl
			fullParams.prompt_tokens = IntPtr.Zero;
			fullParams.prompt_n_tokens = 0;
		}

		/// <summary>Shared implementation of the runFull overloads</summary>
		/// <param name="source">The audio source, forwarded to <paramref name="pfn" /></param>
		/// <param name="callbacks">Optional callbacks, installed for the duration of the call</param>
		/// <param name="promptTokens">Optional prompt tokens; when not empty, they are pinned and passed to the native code for the duration of the call</param>
		/// <param name="pfn">Either <see cref="pfnBuffer" /> or <see cref="pfnStream" /></param>
		void runImpl( object source, Callbacks? callbacks, ReadOnlySpan<int> promptTokens, Action<object> pfn )
		{
			throwIfDisposed();
			try
			{
				if( null != callbacks )
					installCallbacks( callbacks );

				if( promptTokens.IsEmpty )
				{
					pfn( source );
					return;
				}
				unsafe
				{
					// Pin the tokens so the raw pointer stored in fullParams stays valid while the native code runs
					fixed( int* tokens = promptTokens )
					{
						fullParams.prompt_tokens = (IntPtr)tokens;
						fullParams.prompt_n_tokens = promptTokens.Length;
						pfn( source );
					}
				}
			}
			finally
			{
				resetRunState();
			}
		}

		/// <summary>Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text</summary>
		/// <param name="buffer">Audio buffer to transcribe</param>
		/// <param name="callbacks">Optional callbacks receiving new-segment and encoder-begin notifications</param>
		/// <param name="promptTokens">Optional prompt tokens, pass an empty span for none</param>
		/// <exception cref="ObjectDisposedException">This context has been disposed</exception>
		public void runFull( iAudioBuffer buffer, Callbacks? callbacks, ReadOnlySpan<int> promptTokens )
		{
			runImpl( buffer, callbacks, promptTokens, pfnBuffer );
		}
		/// <summary>Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text</summary>
		public void runFull( iAudioBuffer buffer, Callbacks? callbacks = null ) =>
			runFull( buffer, callbacks, ReadOnlySpan<int>.Empty );
		/// <summary>Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text</summary>
		public void runFull( iAudioBuffer buffer, Callbacks? callbacks, int[]? promptTokens ) =>
			runFull( buffer, callbacks, promptTokens ?? ReadOnlySpan<int>.Empty );

		/// <summary>Run the entire model, streaming audio from the provided reader object</summary>
		/// <param name="reader">Source of the audio</param>
		/// <param name="callbacks">Optional callbacks receiving new-segment and encoder-begin notifications</param>
		/// <param name="pfnProgress">Optional progress callback; an exception it throws is reported to the native code as that exception's HRESULT</param>
		/// <param name="promptTokens">Optional prompt tokens, pass an empty span for none</param>
		/// <exception cref="ObjectDisposedException">This context has been disposed</exception>
		public void runFull( iAudioReader reader, Callbacks? callbacks, Action<double>? pfnProgress, ReadOnlySpan<int> promptTokens )
		{
			// Check before installing the progress delegate, so a disposed context leaves no state behind
			throwIfDisposed();
			if( null != pfnProgress )
			{
				progressSink.pfn = delegate ( double value, IntPtr ctx, IntPtr pv )
				{
					// Convert exceptions into HRESULT codes instead of letting them unwind through the native caller
					try
					{
						pfnProgress.Invoke( value );
						return 0;
					}
					catch( Exception ex )
					{
						return ex.HResult;
					}
				};
			}
			try
			{
				runImpl( reader, callbacks, promptTokens, pfnStream );
			}
			finally
			{
				progressSink.pfn = null;
			}
		}

		/// <summary>Run the entire model, streaming audio from the provided reader object</summary>
		public void runFull( iAudioReader reader, Action<double>? pfnProgress = null, Callbacks? callbacks = null ) =>
			runFull( reader, callbacks, pfnProgress, ReadOnlySpan<int>.Empty );

		/// <summary>Run the entire model, streaming audio from the provided reader object</summary>
		public void runFull( iAudioReader reader, Callbacks? callbacks, Action<double>? pfnProgress, int[]? promptTokens ) =>
			runFull( reader, callbacks, pfnProgress, promptTokens ?? ReadOnlySpan<int>.Empty );

		/// <summary>Get text results out of the context</summary>
		/// <param name="flags">Result flags; <see cref="eResultFlags.NewObject" /> is not supported</param>
		/// <exception cref="ArgumentException"><paramref name="flags" /> contains <see cref="eResultFlags.NewObject" /></exception>
		/// <exception cref="ObjectDisposedException">This context has been disposed</exception>
		public TranscribeResult results( eResultFlags flags = eResultFlags.None )
		{
			throwIfDisposed();
			if( flags.HasFlag( eResultFlags.NewObject ) )
				throw new ArgumentException( "eResultFlags.NewObject is not supported by this method", nameof( flags ) );

			iTranscribeResult res = context.getResults( flags );
			Debug.Assert( ReferenceEquals( res, transcribeResult ) );
			return new TranscribeResult( res );
		}

		/// <summary>Print timing data</summary>
		/// <exception cref="ObjectDisposedException">This context has been disposed</exception>
		public void timingsPrint()
		{
			throwIfDisposed();
			context.timingsPrint();
		}

		/// <summary>Reset timing data</summary>
		/// <exception cref="ObjectDisposedException">This context has been disposed</exception>
		public void timingsReset()
		{
			throwIfDisposed();
			context.timingsReset();
		}

		/// <summary>Continuously process audio from microphone or a similar capture device</summary>
		/// <remarks>It’s recommended to call this method on a background thread.</remarks>
		/// <param name="capture">The capture source</param>
		/// <param name="callbacks">Optional callbacks receiving new-segment and encoder-begin notifications</param>
		/// <param name="captureCallbacks">Optional cancellation and status callbacks for the capture loop</param>
		/// <exception cref="ObjectDisposedException">This context has been disposed</exception>
		public void runCapture( iAudioCapture capture, Callbacks? callbacks, CaptureCallbacks? captureCallbacks )
		{
			throwIfDisposed();
			try
			{
				if( null != callbacks )
					installCallbacks( callbacks );

				sCaptureCallbacks cc = default;
				if( captureCallbacks != null )
				{
					cc.shouldCancel = captureCallbacks.cancel( this );
					cc.captureStatus = captureCallbacks.status( this );
				}
				context.runCapture( ref fullParams, ref cc, capture );
			}
			finally
			{
				resetRunState();
			}
		}

		/// <summary>Try to detect speaker by comparing channels of the stereo PCM data</summary>
		/// <remarks>
		/// <para>The feature requires stereo PCM data.<br/>Pass <c>stereo=true</c> to <see cref="iMediaFoundation.loadAudioFile" /> or <see cref="iMediaFoundation.openAudioFile"/> methods,<br/>
		/// or <see cref="eCaptureFlags.Stereo" /> to <see cref="iMediaFoundation.openCaptureDevice" /> method.</para>
		/// <para>It seems to work fine with <a href="https://www.bluemic.com/en-us/products/yeti/">Blue Yeti</a> microphone,
		/// after switched the microphone to Stereo pattern.<br/> With recorded sounds however, the performance varies depending on the recording.</para>
		/// </remarks>
		/// <exception cref="ObjectDisposedException">This context has been disposed</exception>
		public eSpeakerChannel detectSpeaker( sTimeInterval interval )
		{
			throwIfDisposed();
			return context.detectSpeaker( ref interval );
		}
	}
}