summaryrefslogtreecommitdiffstats
path: root/WhisperNet/Library.cs
blob: ef10666920d937274d304668f0e78320e8fbb234 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
using ComLight;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics.X86;
using Whisper.Internal;

namespace Whisper
{
	/// <summary>Factory methods implemented by the C++ DLL</summary>
	/// <remarks>The static constructor checks the host environment (Windows, 64-bit process, x64 architecture, SSE 4.1).
	/// Because the checks throw from a type initializer, the runtime reports a failed check to the caller
	/// as <see cref="TypeInitializationException" /> wrapping the <see cref="ApplicationException" />.</remarks>
	public static class Library
	{
		/// <summary>Verify the platform requirements, then call <c>NativeLogger.startup()</c>.</summary>
		static Library()
		{
			// Ordered from the broadest requirement (OS) to the narrowest (a specific instruction set)
			if( Environment.OSVersion.Platform != PlatformID.Win32NT )
				throw new ApplicationException( "This library requires Windows OS" );
			if( !Environment.Is64BitProcess )
				throw new ApplicationException( "This library only works in 64-bit processes" );
			if( RuntimeInformation.ProcessArchitecture != Architecture.X64 )
				throw new ApplicationException( "This library requires a processor with AMD64 instruction set" );
			if( !Sse41.IsSupported )
				throw new ApplicationException( "This library requires a CPU with SSE 4.1 support" );
			NativeLogger.startup();
		}

		/// <summary>File name of the native DLL which exports the functions imported by this class</summary>
		const string dll = "Whisper.dll";

		// PreserveSig = false: the runtime converts a failed HRESULT into a .NET exception
		[DllImport( dll, CallingConvention = RuntimeClass.defaultCallingConvention, PreserveSig = false )]
		internal static extern void setupLogger( [In] ref sLoggerSetup setup );

		// PreserveSig = true: the status code is returned as-is, and the managed wrapper passes it to NativeLogger.throwForHR
		[DllImport( dll, CallingConvention = RuntimeClass.defaultCallingConvention, PreserveSig = true )]
		static extern int loadModel( [MarshalAs( UnmanagedType.LPWStr )] string path, eModelImplementation impl, eGpuModelFlags flags,
			[In] ref sLoadModelCallbacks callbacks,
			[MarshalAs( UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof( Marshaler<iModel> ) )] out iModel model );

		/// <summary>Load Whisper model from GGML file on disk</summary>
		/// <param name="path">Path to the model file; must not be <c>null</c></param>
		/// <param name="flags">Flags forwarded to the native loader</param>
		/// <param name="impl">Model implementation forwarded to the native loader</param>
		/// <returns>The model object created by the native library</returns>
		/// <exception cref="ArgumentNullException"><paramref name="path" /> is <c>null</c></exception>
		/// <remarks>Models are large, depending on user’s disk speed this might take a while, and this function blocks the calling thread.<br/>
		/// Consider <see cref="loadModelAsync" /> instead.</remarks>
		/// <seealso href="https://huggingface.co/datasets/ggerganov/whisper.cpp" />
		public static iModel loadModel( string path, eGpuModelFlags flags = eGpuModelFlags.None, eModelImplementation impl = eModelImplementation.GPU )
		{
			// A null string marshals to the native code as a NULL pointer; reject it on the managed side instead
			if( path is null )
				throw new ArgumentNullException( nameof( path ) );

			// The synchronous version passes a default-initialized callbacks structure: no cancellation token, no progress delegate
			sLoadModelCallbacks callbacks = default;
			NativeLogger.prologue();
			int hr = loadModel( path, impl, flags, ref callbacks, out iModel model );
			NativeLogger.throwForHR( hr );
			return model;
		}

		/// <summary>Load Whisper model on a background thread, with optional progress reporting and cancellation</summary>
		/// <param name="path">Path to the model file; must not be <c>null</c></param>
		/// <param name="cancelToken">Cancellation token, passed to the callbacks structure consumed by the native loader</param>
		/// <param name="flags">Flags forwarded to the native loader</param>
		/// <param name="pfnProgress">Optional progress delegate, passed to the callbacks structure consumed by the native loader</param>
		/// <param name="impl">Model implementation forwarded to the native loader</param>
		/// <returns>A task which completes with the loaded model, or faults with the exception thrown while loading</returns>
		/// <exception cref="ArgumentNullException"><paramref name="path" /> is <c>null</c>; thrown synchronously, before any work is queued</exception>
		public static Task<iModel> loadModelAsync( string path, CancellationToken cancelToken, eGpuModelFlags flags = eGpuModelFlags.None, Action<double>? pfnProgress = null, eModelImplementation impl = eModelImplementation.GPU )
		{
			if( path is null )
				throw new ArgumentNullException( nameof( path ) );

			// RunContinuationsAsynchronously: without this flag, continuations of the awaiting code
			// run inline inside SetResult / SetException, on the pool thread which performed the load
			TaskCompletionSource<iModel> tcs = new TaskCompletionSource<iModel>( TaskCreationOptions.RunContinuationsAsynchronously );

			WaitCallback wcb = delegate ( object? state )
			{
				try
				{
					// The structure is a local of this delegate, so it stays in scope for the complete duration of the native call
					sLoadModelCallbacks callbacks = new sLoadModelCallbacks( cancelToken, pfnProgress );

					NativeLogger.prologue();
					int hr = loadModel( path, impl, flags, ref callbacks, out iModel model );
					NativeLogger.throwForHR( hr );

					tcs.SetResult( model );
				}
				catch( Exception ex )
				{
					// NOTE(review): every failure, including one caused by the cancelled token, completes the task as Faulted, not Canceled.
					// Confirm what NativeLogger.throwForHR throws on cancellation before changing this.
					tcs.SetException( ex );
				}
			};

			ThreadPool.QueueUserWorkItem( wcb );
			return tcs.Task;
		}

		[DllImport( dll, CallingConvention = RuntimeClass.defaultCallingConvention, PreserveSig = true )]
		static extern int initMediaFoundation( [MarshalAs( UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof( Marshaler<iMediaFoundation> ) )] out iMediaFoundation mf );

		/// <summary>Initialize Media Foundation runtime</summary>
		/// <returns>The Media Foundation object created by the native library</returns>
		public static iMediaFoundation initMediaFoundation()
		{
			NativeLogger.prologue();
			int hr = initMediaFoundation( out iMediaFoundation mf );
			NativeLogger.throwForHR( hr );
			return mf;
		}

		// The .NET runtime uses UTF-16 for the strings, so we only need the Unicode version of this function.
		// The native DLL exports both Unicode and ASCII versions.
		[DllImport( dll, CallingConvention = RuntimeClass.defaultCallingConvention, PreserveSig = true )]
		static extern uint findLanguageKeyW( [MarshalAs( UnmanagedType.LPWStr )] string lang );

		/// <summary>Try to resolve language code string like <c>"en"</c>, <c>"pl"</c> or <c>"uk"</c> into the strongly-typed enum.</summary>
		/// <param name="lang">Language code to resolve; must not be <c>null</c></param>
		/// <returns>The language, or <c>null</c> when the native library does not recognize the code</returns>
		/// <exception cref="ArgumentNullException"><paramref name="lang" /> is <c>null</c></exception>
		/// <remarks>The function is case-sensitive: <c>"EN"</c> or <c>"UK"</c> are not resolved.</remarks>
		public static eLanguage? languageFromCode( string lang )
		{
			// A null string marshals to the native code as a NULL pointer; reject it on the managed side instead
			if( lang is null )
				throw new ArgumentNullException( nameof( lang ) );

			uint key = findLanguageKeyW( lang );
			// uint.MaxValue is the "not found" marker returned by the native function
			if( key != uint.MaxValue )
				return (eLanguage)key;
			return null;
		}

		/// <summary>Set up delegate to receive log messages from the C++ library</summary>
		/// <param name="lvl">Log level, forwarded to <c>NativeLogger.setup</c></param>
		/// <param name="flags">Logger flags, forwarded to <c>NativeLogger.setup</c></param>
		/// <param name="pfn">Optional delegate to receive the messages, forwarded to <c>NativeLogger.setup</c></param>
		public static void setLogSink( eLogLevel lvl, eLoggerFlags flags = eLoggerFlags.SkipFormatMessage, pfnLogMessage? pfn = null )
		{
			NativeLogger.setup( lvl, flags, pfn );
		}
	}
}