summaryrefslogtreecommitdiffstats
path: root/Examples/main
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-01-16 14:52:43 +0100
committerKonstantin <const@const.me>2023-01-16 14:52:43 +0100
commit8c4603c73675958efc960fbd4bb599a2909d106a (patch)
tree714dc6fc9a1672d5fd7f89676b97e10959662abc /Examples/main
parent990a8d0dbaefc996244097397259e92758b15cce (diff)
Source codes
Diffstat (limited to 'Examples/main')
-rw-r--r--Examples/main/main.cpp315
-rw-r--r--Examples/main/main.vcxproj93
-rw-r--r--Examples/main/main.vcxproj.filters12
-rw-r--r--Examples/main/miscUtils.cpp48
-rw-r--r--Examples/main/miscUtils.h9
-rw-r--r--Examples/main/params.cpp101
-rw-r--r--Examples/main/params.h38
7 files changed, 616 insertions, 0 deletions
diff --git a/Examples/main/main.cpp b/Examples/main/main.cpp
new file mode 100644
index 0000000..c9eacf2
--- /dev/null
+++ b/Examples/main/main.cpp
@@ -0,0 +1,315 @@
+#include "params.h"
+#include "../../Whisper/API/iContext.cl.h"
+#include "../../Whisper/API/iMediaFoundation.cl.h"
+#include "../../ComLightLib/comLightClient.h"
+#include "miscUtils.h"
+#include <array>
+#include <atomic>
+using namespace Whisper;
+
+#define STREAM_AUDIO 1
+
+// Loads the Whisper model stored at `path` and writes the resulting model interface to `pp`.
+// The implementation is selected at compile time; the alternative is kept below as a comment.
+static HRESULT loadWhisperModel( const wchar_t* path, iModel** pp )
+{
+    constexpr Whisper::eModelImplementation implementation = Whisper::eModelImplementation::GPU;
+    // constexpr Whisper::eModelImplementation implementation = Whisper::eModelImplementation::Reference;
+
+    // The third argument is left as nullptr
+    return Whisper::loadModel( path, implementation, nullptr, pp );
+}
+
+namespace
+{
+ struct sPrintUserData // Payload passed to newSegmentCallback through its user_data pointer
+ {
+ const whisper_params* params; // Options consulted while printing; wmain stores the address of its parsed parameters here
+ // const std::vector<std::vector<float>>* pcmf32s;
+ };
+
+ // Terminal escape sequences, one per probability bucket: [0.0, 0.1), [0.1, 0.2), ..., [0.9, 1.0].
+ // The first entry is red, the middle ones are yellow, the last one is green.
+ // The enclosing anonymous namespace already gives the table internal linkage.
+ const std::array<const char*, 10> k_colors =
+ {
+     "\033[38;5;196m",
+     "\033[38;5;202m",
+     "\033[38;5;208m",
+     "\033[38;5;214m",
+     "\033[38;5;220m",
+     "\033[38;5;226m",
+     "\033[38;5;190m",
+     "\033[38;5;154m",
+     "\033[38;5;118m",
+     "\033[38;5;82m",
+ };
+
+ // Formats a time span as "HH:MM:SS.mmm", or "HH:MM:SS,mmm" when `comma` is true.
+ // Whole days are folded into the hour count, so the hour field may exceed 23.
+ std::string to_timestamp( sTimeSpan ts, bool comma = false )
+ {
+     sTimeSpanFields fields = ts;
+     // NOTE(review): dividing by 10'000 presumably turns 100-nanosecond ticks into milliseconds — confirm against sTimeSpanFields
+     const uint32_t msec = fields.ticks / 10'000;
+     const uint32_t hr = fields.days * 24 + fields.hours;
+     const uint32_t min = fields.minutes;
+     const uint32_t sec = fields.seconds;
+
+     char buf[ 32 ];
+     // All four values are unsigned, so they are printed with %u instead of %d
+     snprintf( buf, sizeof( buf ), "%02u:%02u:%02u%s%03u", hr, min, sec, comma ? "," : ".", msec );
+     return std::string( buf );
+ }
+
+ // Maps the probability of a token to an index into k_colors.
+ // The probability is cubed before it is scaled by the table size, and the result is clamped to a valid index.
+ static int colorIndex( const sToken& tok )
+ {
+     const int lastIndex = (int)k_colors.size() - 1;
+
+     const float prob = tok.probability;
+     const float cubed = prob * prob * prob;
+     const int scaled = (int)( cubed * float( k_colors.size() ) );
+
+     if( scaled > lastIndex )
+         return lastIndex;
+     if( scaled < 0 )
+         return 0;
+     return scaled;
+ }
+
+ // Prints the tokens of one segment, each wrapped in the color escape selected by colorIndex.
+ // Tokens flagged as special are skipped unless params.print_special is set.
+ static void printColoredTokens( const whisper_params& params, const sSegment& seg, const sToken* tokens )
+ {
+     for( uint32_t j = 0; j < seg.countTokens; j++ )
+     {
+         const sToken& tok = tokens[ seg.firstToken + j ];
+         if( !params.print_special && ( tok.flags & eTokenFlags::Special ) )
+             continue;
+         wprintf( L"%S%s%S", k_colors[ colorIndex( tok ) ], utf16( tok.text ).c_str(), "\033[0m" );
+     }
+ }
+
+ // New-segment callback: prints the last `n_new` segments of the current results to the standard output.
+ // `user_data` must point to a sPrintUserData structure.
+ // Returns S_OK after printing; the CHECK macro is expected to return early with the failed HRESULT
+ // when the results cannot be retrieved (its definition is outside of this file).
+ HRESULT __cdecl newSegmentCallback( iContext* context, uint32_t n_new, void* user_data ) noexcept
+ {
+     ComLight::CComPtr<iTranscribeResult> results;
+     CHECK( context->getResults( eResultFlags::Timestamps | eResultFlags::Tokens, &results ) );
+
+     sTranscribeLength length;
+     CHECK( results->getSize( length ) );
+
+     const whisper_params& params = *( (sPrintUserData*)user_data )->params;
+     // const std::vector<std::vector<float>>& pcmf32s = *( (sPrintUserData*)user_data )->pcmf32s;
+
+     // print the last n_new segments
+     const uint32_t s0 = length.countSegments - n_new;
+     if( s0 == 0 )
+         printf( "\n" );
+
+     const sSegment* const segments = results->getSegments();
+     const sToken* const tokens = results->getTokens();
+
+     for( uint32_t i = s0; i < length.countSegments; i++ )
+     {
+         const sSegment& seg = segments[ i ];
+
+         if( params.no_timestamps )
+         {
+             // Without timestamps the segments are printed back to back, with no line breaks
+             if( params.print_colors )
+                 printColoredTokens( params, seg, tokens );
+             else
+                 wprintf( L"%s", utf16( seg.text ).c_str() );
+             fflush( stdout );
+             continue;
+         }
+
+         std::string speaker = "";
+#if 0
+         if( params.diarize && pcmf32s.size() == 2 )
+         {
+             const size_t n_samples = pcmf32s[ 0 ].size();
+             const int64_t is0 = SourceAudio::sampleFromTimestamp( seg.time.begin, n_samples );
+             const int64_t is1 = SourceAudio::sampleFromTimestamp( seg.time.end, n_samples );
+
+             double energy0 = 0.0f;
+             double energy1 = 0.0f;
+
+             for( int64_t j = is0; j < is1; j++ )
+             {
+                 energy0 += fabs( pcmf32s[ 0 ][ j ] );
+                 energy1 += fabs( pcmf32s[ 1 ][ j ] );
+             }
+
+             if( energy0 > 1.1 * energy1 )
+                 speaker = "(speaker 0)";
+             else if( energy1 > 1.1 * energy0 )
+                 speaker = "(speaker 1)";
+             else
+                 speaker = "(speaker ?)";
+
+             //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
+         }
+#endif
+
+         if( params.print_colors )
+         {
+             // The speaker label belongs to the segment, so it is printed once after the time stamps,
+             // same as in the branch without colors, instead of being repeated in front of every token
+             printf( "[%s --> %s] %s", to_timestamp( seg.time.begin ).c_str(), to_timestamp( seg.time.end ).c_str(), speaker.c_str() );
+             printColoredTokens( params, seg, tokens );
+             printf( "\n" );
+         }
+         else
+             wprintf( L"[%S --> %S] %S%s\n", to_timestamp( seg.time.begin ).c_str(), to_timestamp( seg.time.end ).c_str(), speaker.c_str(), utf16( seg.text ).c_str() );
+     }
+     return S_OK;
+ }
+
+ // Encoder-begin callback. `user_data` points to a std::atomic_bool abort flag.
+ // Yields S_FALSE when the flag is set, and S_OK otherwise.
+ HRESULT __cdecl beginSegmentCallback( iContext* context, void* user_data ) noexcept
+ {
+     const std::atomic_bool* const abortFlag = (std::atomic_bool*)user_data;
+     if( abortFlag->load() )
+         return S_FALSE;
+     return S_OK;
+ }
+
+ // Switches on ENABLE_VIRTUAL_TERMINAL_PROCESSING for the standard output handle.
+ // Returns S_FALSE when the flag was set already, S_OK when this call has set it,
+ // or an HRESULT made from GetLastError() when one of the console calls fails.
+ HRESULT setupConsoleColors()
+ {
+     const HANDLE hOutput = GetStdHandle( STD_OUTPUT_HANDLE );
+     if( INVALID_HANDLE_VALUE == hOutput )
+         return HRESULT_FROM_WIN32( GetLastError() );
+
+     DWORD consoleMode = 0;
+     if( !GetConsoleMode( hOutput, &consoleMode ) )
+         return HRESULT_FROM_WIN32( GetLastError() );
+
+     const bool alreadyEnabled = ( consoleMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING ) != 0;
+     if( alreadyEnabled )
+         return S_FALSE;
+
+     if( SetConsoleMode( hOutput, consoleMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING ) )
+         return S_OK;
+     return HRESULT_FROM_WIN32( GetLastError() );
+ }
+}
+
+// Entry point: parses the command line, loads the model, then transcribes every input file in turn.
+// Exit codes: 0 on success, 1 to 3 for command-line problems, 4 and 5 for initialization failures,
+// 6 when an audio file cannot be opened, 10 when the processing of a file fails.
+int wmain( int argc, wchar_t* argv[] )
+{
+    // Whisper::dbgCompareTraces( LR"(C:\Temp\2remove\Whisper\ref.bin)", LR"(C:\Temp\2remove\Whisper\gpu.bin )" ); return 0;
+
+    // Tell logger to use the standard error stream for the messages
+    {
+        Whisper::sLoggerSetup logSetup;
+        logSetup.flags = eLoggerFlags::UseStandardError;
+        logSetup.level = eLogLevel::Debug;
+        Whisper::setupLogger( logSetup );
+    }
+
+    whisper_params params;
+    if( !params.parse( argc, argv ) )
+        return 1;
+
+    if( params.print_colors )
+    {
+        // Fall back to plain output when the console cannot be switched to virtual terminal mode
+        if( FAILED( setupConsoleColors() ) )
+            params.print_colors = false;
+    }
+
+    if( params.fname_inp.empty() )
+    {
+        fprintf( stderr, "error: no input files specified\n" );
+        whisper_print_usage( argc, argv, params );
+        return 2;
+    }
+
+    if( Whisper::findLanguageKeyA( params.language.c_str() ) == UINT_MAX )
+    {
+        fprintf( stderr, "error: unknown language '%s'\n", params.language.c_str() );
+        whisper_print_usage( argc, argv, params );
+        return 3;
+    }
+
+    ComLight::CComPtr<iModel> model;
+    HRESULT hr = loadWhisperModel( params.model.c_str(), &model );
+    if( FAILED( hr ) )
+    {
+        printError( "failed to load the model", hr );
+        return 4;
+    }
+
+    ComLight::CComPtr<iContext> context;
+    hr = model->createContext( &context );
+    if( FAILED( hr ) )
+    {
+        printError( "failed to initialize whisper context", hr );
+        return 5;
+    }
+
+    ComLight::CComPtr<iMediaFoundation> mf;
+    hr = initMediaFoundation( &mf );
+    if( FAILED( hr ) )
+    {
+        printError( "failed to initialize Media Foundation runtime", hr );
+        return 5;
+    }
+
+    for( const std::wstring& fname : params.fname_inp )
+    {
+        // print some info about the processing
+        {
+            if( model->isMultilingual() == S_FALSE )
+            {
+                if( params.language != "en" || params.translate )
+                {
+                    params.language = "en";
+                    params.translate = false;
+                    fprintf( stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__ );
+                }
+            }
+
+            /*
+            fwprintf( stderr, L"%S: processing '%s' (%zu samples, %.1f sec), %d threads, %d processors, lang = %S, task = %S, timestamps = %d ...\n",
+                __func__, fname.c_str(), audio.pcmf32.size(), audio.seconds(),
+                params.n_threads, params.n_processors,
+                params.language.c_str(),
+                params.translate ? "translate" : "transcribe",
+                params.no_timestamps ? 0 : 1 );
+            */
+        }
+
+        // run the inference
+        Whisper::sFullParams wparams;
+        context->fullDefaultParams( eSamplingStrategy::Greedy, &wparams );
+
+        wparams.resetFlag( eFullParamsFlags::PrintRealtime | eFullParamsFlags::PrintProgress );
+        wparams.setFlag( eFullParamsFlags::PrintTimestamps, !params.no_timestamps );
+        wparams.setFlag( eFullParamsFlags::PrintSpecial, params.print_special );
+        wparams.setFlag( eFullParamsFlags::Translate, params.translate );
+        wparams.language = Whisper::makeLanguageKey( params.language.c_str() );
+        wparams.cpuThreads = params.n_threads;
+        // UINT_MAX means the option was not specified, keep the default of the library
+        if( params.max_context != UINT_MAX )
+            wparams.n_max_text_ctx = params.max_context;
+        wparams.offset_ms = params.offset_t_ms;
+        wparams.duration_ms = params.duration_ms;
+
+        wparams.setFlag( eFullParamsFlags::TokenTimestamps, params.output_wts || params.max_len > 0 );
+        wparams.thold_pt = params.word_thold;
+        wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+
+        wparams.setFlag( eFullParamsFlags::SpeedupAudio, params.speed_up );
+        // sPrintUserData user_data = { &params, &audio.pcmf32s };
+        sPrintUserData user_data = { &params };
+
+        // this callback is called on each new segment
+        if( !wparams.flag( eFullParamsFlags::PrintRealtime ) )
+        {
+            wparams.new_segment_callback = &newSegmentCallback;
+            wparams.new_segment_callback_user_data = &user_data;
+        }
+
+        // example for abort mechanism
+        // in this example, we do not abort the processing, but we could if the flag is set to true
+        // the callback is called before every encoder run - if it returns false, the processing is aborted
+        std::atomic_bool is_aborted = false;
+        {
+            wparams.encoder_begin_callback = &beginSegmentCallback;
+            wparams.encoder_begin_callback_user_data = &is_aborted;
+        }
+
+#if STREAM_AUDIO
+        ComLight::CComPtr<iAudioReader> reader;
+        // Report the failure, instead of silently returning the raw HRESULT as the exit code of the process
+        hr = mf->openAudioFile( fname.c_str(), params.diarize, &reader );
+        if( FAILED( hr ) )
+        {
+            printError( "failed to open the audio file", hr );
+            return 6;
+        }
+        sProgressSink progressSink{ nullptr, nullptr };
+        hr = context->runStreamed( wparams, progressSink, reader );
+#else
+        ComLight::CComPtr<iAudioBuffer> buffer;
+        // Report the failure, instead of silently returning the raw HRESULT as the exit code of the process
+        hr = mf->loadAudioFile( fname.c_str(), params.diarize, &buffer );
+        if( FAILED( hr ) )
+        {
+            printError( "failed to load the audio file", hr );
+            return 6;
+        }
+        hr = context->runFull( wparams, buffer );
+#endif
+        if( FAILED( hr ) )
+        {
+            fwprintf( stderr, L"%s: failed to process audio\n", argv[ 0 ] );
+            return 10;
+        }
+    }
+
+    context->timingsPrint();
+    context = nullptr;
+    return 0;
+} \ No newline at end of file
diff --git a/Examples/main/main.vcxproj b/Examples/main/main.vcxproj
new file mode 100644
index 0000000..4945b88
--- /dev/null
+++ b/Examples/main/main.vcxproj
@@ -0,0 +1,93 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup Label="ProjectConfigurations">
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ </ItemGroup>
+ <PropertyGroup Label="Globals">
+ <VCProjectVersion>16.0</VCProjectVersion>
+ <Keyword>Win32Proj</Keyword>
+ <ProjectGuid>{4cca7042-eb15-4f7a-b77b-5cafd2df47b2}</ProjectGuid>
+ <RootNamespace>main</RootNamespace>
+ <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
+ <ImportGroup Label="ExtensionSettings">
+ </ImportGroup>
+ <ImportGroup Label="Shared">
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <PropertyGroup Label="UserMacros" />
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NOMINMAX;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>NOMINMAX;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemGroup>
+ <ClCompile Include="main.cpp" />
+ <ClCompile Include="miscUtils.cpp" />
+ <ClCompile Include="params.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="miscUtils.h" />
+ <ClInclude Include="params.h" />
+ </ItemGroup>
+ <ItemGroup>
+ <ProjectReference Include="..\..\Whisper\Whisper.vcxproj">
+ <Project>{701df8c8-e4a5-43ec-9c6b-747bbf4d8e71}</Project>
+ </ProjectReference>
+ </ItemGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
+ <ImportGroup Label="ExtensionTargets">
+ </ImportGroup>
+</Project> \ No newline at end of file
diff --git a/Examples/main/main.vcxproj.filters b/Examples/main/main.vcxproj.filters
new file mode 100644
index 0000000..94cd8a1
--- /dev/null
+++ b/Examples/main/main.vcxproj.filters
@@ -0,0 +1,12 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup>
+ <ClCompile Include="main.cpp" />
+ <ClCompile Include="params.cpp" />
+ <ClCompile Include="miscUtils.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="params.h" />
+ <ClInclude Include="miscUtils.h" />
+ </ItemGroup>
+</Project> \ No newline at end of file
diff --git a/Examples/main/miscUtils.cpp b/Examples/main/miscUtils.cpp
new file mode 100644
index 0000000..3ebda20
--- /dev/null
+++ b/Examples/main/miscUtils.cpp
@@ -0,0 +1,48 @@
+#include "miscUtils.h"
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+
+// Converts a UTF-16 string to UTF-8.
+// Returns an empty string when the input is empty, or when the conversion fails.
+std::string utf8( const std::wstring& utf16 )
+{
+    const int length = (int)utf16.length();
+    if( length <= 0 )
+        return std::string();
+
+    // The first call only measures the output. The source length is passed explicitly,
+    // so the terminating null character is not counted.
+    const int count = WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), length, nullptr, 0, nullptr, nullptr );
+    if( count <= 0 )
+        return std::string();
+
+    std::string str( (size_t)count, '\0' );
+    // The converting call must use the same source length as the measuring one.
+    // With -1 the API also converts the terminating null, needs count + 1 bytes, and fails on this buffer.
+    const int written = WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), length, &str[ 0 ], count, nullptr, nullptr );
+    if( written <= 0 )
+        return std::string();
+    str.resize( (size_t)written );
+    return str;
+}
+
+// Converts a UTF-8 string to UTF-16 with two MultiByteToWideChar calls: the first one measures, the second one converts.
+std::wstring utf16( const std::string& u8 )
+{
+    const char* const source = u8.c_str();
+    const int sourceBytes = (int)u8.length();
+
+    // With a null output buffer of size zero, the API returns the required count of wide characters
+    const int wideLength = MultiByteToWideChar( CP_UTF8, 0, source, sourceBytes, nullptr, 0 );
+
+    std::wstring result( wideLength, 0 );
+    MultiByteToWideChar( CP_UTF8, 0, source, sourceBytes, &result[ 0 ], wideLength );
+    return result;
+}
+
+namespace
+{
+    // Asks the system for the message text which corresponds to `hr`.
+    // Returns a buffer allocated by FormatMessage, which the caller releases with LocalFree,
+    // or nullptr when the system has no message for the code.
+    wchar_t* formatMessage( HRESULT hr )
+    {
+        wchar_t* err = nullptr;
+        // FORMAT_MESSAGE_IGNORE_INSERTS: no argument array is supplied, insert sequences like %1 must be left as is
+        constexpr DWORD flags = FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS;
+        // The wide version is called explicitly because the result is a wchar_t*, regardless of the UNICODE macro.
+        // With FORMAT_MESSAGE_ALLOCATE_BUFFER the buffer argument is really a pointer to the pointer which receives the allocation.
+        const DWORD length = FormatMessageW( flags,
+            nullptr,
+            (DWORD)hr,
+            MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ),
+            (LPWSTR)&err,
+            0,
+            nullptr );
+        if( 0 != length )
+            return err;
+        return nullptr;
+    }
+}
+
+void printError( const char* what, HRESULT hr )
+{
+ const wchar_t* err = formatMessage( hr );
+ if( nullptr != err )
+ {
+ fwprintf( stderr, L"%S: %s\n", what, err );
+ LocalFree( (HLOCAL)err );
+ }
+ else
+ fprintf( stderr, "%s: error code %i (0x%08X)\n", what, hr, hr );
+} \ No newline at end of file
diff --git a/Examples/main/miscUtils.h b/Examples/main/miscUtils.h
new file mode 100644
index 0000000..52770a6
--- /dev/null
+++ b/Examples/main/miscUtils.h
@@ -0,0 +1,9 @@
+#pragma once
+#include <string>
+
+std::string utf8( const std::wstring& utf16 );
+
+std::wstring utf16( const std::string& u8 );
+
+using HRESULT = long;
+void printError( const char* what, HRESULT hr ); \ No newline at end of file
diff --git a/Examples/main/params.cpp b/Examples/main/params.cpp
new file mode 100644
index 0000000..ff1cfdd
--- /dev/null
+++ b/Examples/main/params.cpp
@@ -0,0 +1,101 @@
+#include "params.h"
+#include <algorithm>
+#include <stdexcept>
+#include <thread>
+#include "miscUtils.h"
+
+// Sets the default count of threads: 2 in debug builds, otherwise the hardware concurrency limited to the [ 1, 4 ] range.
+whisper_params::whisper_params()
+{
+#ifdef _DEBUG
+    n_threads = 2;
+#else
+    // hardware_concurrency() returns 0 when the value is not computable; never default to zero threads
+    n_threads = std::max( 1u, std::min( 4u, std::thread::hardware_concurrency() ) );
+#endif
+}
+
+namespace
+{
+    // Text form of a boolean value: "true" or "false"
+    const char* cstr( bool b )
+    {
+        if( b )
+            return "true";
+        return "false";
+    }
+}
+
+void whisper_print_usage( int argc, wchar_t** argv, const whisper_params& params )
+{
+ fprintf( stderr, "\n" );
+ fprintf( stderr, "usage: %S [options] file0.wav file1.wav ...\n", argv[ 0 ] );
+ fprintf( stderr, "\n" );
+ fprintf( stderr, "options:\n" );
+ fprintf( stderr, " -h, --help [default] show this help message and exit\n" );
+ fprintf( stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads );
+ fprintf( stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors );
+ fprintf( stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms );
+ fprintf( stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n );
+ fprintf( stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms );
+ fprintf( stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context );
+ fprintf( stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len );
+ fprintf( stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold );
+ fprintf( stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", cstr( params.speed_up ) );
+ fprintf( stderr, " -tr, --translate [%-7s] translate from source language to english\n", cstr( params.translate ) );
+ fprintf( stderr, " -di, --diarize [%-7s] stereo audio diarization\n", cstr( params.diarize ) );
+ fprintf( stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", cstr( params.output_txt ) );
+ fprintf( stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", cstr( params.output_vtt ) );
+ fprintf( stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", cstr( params.output_srt ) );
+ fprintf( stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", cstr( params.output_wts ) );
+ fprintf( stderr, " -ps, --print-special [%-7s] print special tokens\n", cstr( params.print_special ) );
+ fprintf( stderr, " -nc, --no-colors [%-7s] do not print colors\n", cstr( !params.print_colors ) );
+ fprintf( stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", cstr( params.no_timestamps ) );
+ fprintf( stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str() );
+ fprintf( stderr, " -m FNAME, --model FNAME [%-7S] model path\n", params.model.c_str() );
+ fprintf( stderr, " -f FNAME, --file FNAME [%-7s] path of the input audio file\n", "" );
+ fprintf( stderr, "\n" );
+}
+
+// Parses the command line into this structure.
+// Arguments which do not start with '-' are collected as input file names.
+// Returns false after printing the usage text when the help was requested, when an argument is unknown,
+// or when the value of an option is missing or malformed. Returns true otherwise.
+bool whisper_params::parse( int argc, wchar_t* argv[] )
+{
+    for( int i = 1; i < argc; i++ )
+    {
+        const std::wstring arg = argv[ i ];
+
+        if( arg[ 0 ] != '-' )
+        {
+            fname_inp.push_back( arg );
+            continue;
+        }
+
+        if( arg == L"-h" || arg == L"--help" )
+        {
+            whisper_print_usage( argc, argv, *this );
+            return false;
+        }
+
+        // Consumes the argument which follows the current option.
+        // Throws when the option is the last argument, instead of reading past the end of argv.
+        const auto value = [ & ]() -> const wchar_t*
+        {
+            if( i + 1 >= argc )
+                throw std::invalid_argument( "missing option value" );
+            return argv[ ++i ];
+        };
+
+        try
+        {
+            if( arg == L"-t" || arg == L"--threads" ) { n_threads = std::stoul( value() ); }
+            else if( arg == L"-p" || arg == L"--processors" ) { n_processors = std::stoul( value() ); }
+            else if( arg == L"-ot" || arg == L"--offset-t" ) { offset_t_ms = std::stoul( value() ); }
+            else if( arg == L"-on" || arg == L"--offset-n" ) { offset_n = std::stoul( value() ); }
+            else if( arg == L"-d" || arg == L"--duration" ) { duration_ms = std::stoul( value() ); }
+            else if( arg == L"-mc" || arg == L"--max-context" ) { max_context = std::stoul( value() ); }
+            else if( arg == L"-ml" || arg == L"--max-len" ) { max_len = std::stoul( value() ); }
+            else if( arg == L"-wt" || arg == L"--word-thold" ) { word_thold = std::stof( value() ); }
+            else if( arg == L"-su" || arg == L"--speed-up" ) { speed_up = true; }
+            else if( arg == L"-tr" || arg == L"--translate" ) { translate = true; }
+            else if( arg == L"-di" || arg == L"--diarize" ) { diarize = true; }
+            else if( arg == L"-otxt" || arg == L"--output-txt" ) { output_txt = true; }
+            else if( arg == L"-ovtt" || arg == L"--output-vtt" ) { output_vtt = true; }
+            else if( arg == L"-osrt" || arg == L"--output-srt" ) { output_srt = true; }
+            else if( arg == L"-owts" || arg == L"--output-words" ) { output_wts = true; }
+            else if( arg == L"-ps" || arg == L"--print-special" ) { print_special = true; }
+            else if( arg == L"-nc" || arg == L"--no-colors" ) { print_colors = false; }
+            else if( arg == L"-nt" || arg == L"--no-timestamps" ) { no_timestamps = true; }
+            else if( arg == L"-l" || arg == L"--language" ) { language = utf8( value() ); }
+            else if( arg == L"-m" || arg == L"--model" ) { model = value(); }
+            else if( arg == L"-f" || arg == L"--file" ) { fname_inp.push_back( value() ); }
+            else
+            {
+                fprintf( stderr, "error: unknown argument: %S\n", arg.c_str() );
+                whisper_print_usage( argc, argv, *this );
+                return false;
+            }
+        }
+        catch( const std::logic_error& )
+        {
+            // std::stoul and std::stof throw std::invalid_argument or std::out_of_range, both derive from std::logic_error
+            fprintf( stderr, "error: missing or invalid value for argument: %S\n", arg.c_str() );
+            whisper_print_usage( argc, argv, *this );
+            return false;
+        }
+    }
+    return true;
+} \ No newline at end of file
diff --git a/Examples/main/params.h b/Examples/main/params.h
new file mode 100644
index 0000000..9eb2b04
--- /dev/null
+++ b/Examples/main/params.h
@@ -0,0 +1,38 @@
+#pragma once
+#include <vector>
+#include <string>
+
+// command-line parameters
+struct whisper_params
+{
+    // Number of threads to use during computation; the constructor assigns the default
+    uint32_t n_threads;
+    // Number of processors to use during computation
+    uint32_t n_processors{ 1 };
+    // Time offset in milliseconds
+    uint32_t offset_t_ms{ 0 };
+    // Segment index offset
+    uint32_t offset_n{ 0 };
+    // Duration of audio to process, in milliseconds
+    uint32_t duration_ms{ 0 };
+    // Maximum number of text context tokens to store
+    uint32_t max_context{ UINT_MAX };
+    // Maximum segment length in characters
+    uint32_t max_len{ 0 };
+
+    // Word timestamp probability threshold
+    float word_thold{ 0.01f };
+
+    // Speed up audio by x2, with reduced accuracy
+    bool speed_up{ false };
+    // Translate from the source language to English
+    bool translate{ false };
+    // Stereo audio diarization
+    bool diarize{ false };
+    // Output the result in a text file
+    bool output_txt{ false };
+    // Output the result in a vtt file
+    bool output_vtt{ false };
+    // Output the result in a srt file
+    bool output_srt{ false };
+    // Output a script for generating karaoke video
+    bool output_wts{ false };
+    // Print special tokens
+    bool print_special{ false };
+    // Print colors; the --no-colors option clears the flag
+    bool print_colors{ true };
+    // Do not print timestamps
+    bool no_timestamps{ false };
+
+    // Spoken language
+    std::string language{ "en" };
+    // Path of the model
+    std::wstring model{ L"models/ggml-base.en.bin" };
+    // Paths of the input audio files
+    std::vector<std::wstring> fname_inp;
+
+    whisper_params();
+
+    // Fills the fields from the command line; returns false when the program should exit
+    bool parse( int argc, wchar_t* argv[] );
+};
+
+void whisper_print_usage( int argc, wchar_t** argv, const whisper_params& params ); \ No newline at end of file