diff options
| author | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
| commit | 8c4603c73675958efc960fbd4bb599a2909d106a (patch) | |
| tree | 714dc6fc9a1672d5fd7f89676b97e10959662abc /Examples/main | |
| parent | 990a8d0dbaefc996244097397259e92758b15cce (diff) | |
Source codes
Diffstat (limited to 'Examples/main')
| -rw-r--r-- | Examples/main/main.cpp | 315 | ||||
| -rw-r--r-- | Examples/main/main.vcxproj | 93 | ||||
| -rw-r--r-- | Examples/main/main.vcxproj.filters | 12 | ||||
| -rw-r--r-- | Examples/main/miscUtils.cpp | 48 | ||||
| -rw-r--r-- | Examples/main/miscUtils.h | 9 | ||||
| -rw-r--r-- | Examples/main/params.cpp | 101 | ||||
| -rw-r--r-- | Examples/main/params.h | 38 |
7 files changed, 616 insertions, 0 deletions
diff --git a/Examples/main/main.cpp b/Examples/main/main.cpp new file mode 100644 index 0000000..c9eacf2 --- /dev/null +++ b/Examples/main/main.cpp @@ -0,0 +1,315 @@ +#include "params.h" +#include "../../Whisper/API/iContext.cl.h" +#include "../../Whisper/API/iMediaFoundation.cl.h" +#include "../../ComLightLib/comLightClient.h" +#include "miscUtils.h" +#include <array> +#include <atomic> +using namespace Whisper; + +#define STREAM_AUDIO 1 + +static HRESULT loadWhisperModel( const wchar_t* path, iModel** pp ) +{ + using namespace Whisper; + constexpr eModelImplementation impl = eModelImplementation::GPU; + // constexpr eModelImplementation impl = eModelImplementation::Reference; + return Whisper::loadModel( path, impl, nullptr, pp ); +} + +namespace +{ + struct sPrintUserData + { + const whisper_params* params; + // const std::vector<std::vector<float>>* pcmf32s; + }; + + // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] + // Lowest is red, middle is yellow, highest is green. + static const std::array<const char*, 10> k_colors = + { + "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m", + "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m", + }; + + std::string to_timestamp( sTimeSpan ts, bool comma = false ) + { + sTimeSpanFields fields = ts; + uint32_t msec = fields.ticks / 10'000; + uint32_t hr = fields.days * 24 + fields.hours; + uint32_t min = fields.minutes; + uint32_t sec = fields.seconds; + + char buf[ 32 ]; + snprintf( buf, sizeof( buf ), "%02d:%02d:%02d%s%03d", hr, min, sec, comma ? 
"," : ".", msec ); + return std::string( buf ); + } + + static int colorIndex( const sToken& tok ) + { + const float p = tok.probability; + const float p3 = p * p * p; + int col = (int)( p3 * float( k_colors.size() ) ); + col = std::max( 0, std::min( (int)k_colors.size() - 1, col ) ); + return col; + } + + HRESULT __cdecl newSegmentCallback( iContext* context, uint32_t n_new, void* user_data ) noexcept + { + ComLight::CComPtr<iTranscribeResult> results; + CHECK( context->getResults( eResultFlags::Timestamps | eResultFlags::Tokens, &results ) ); + + sTranscribeLength length; + CHECK( results->getSize( length ) ); + + const whisper_params& params = *( (sPrintUserData*)user_data )->params; + // const std::vector<std::vector<float>>& pcmf32s = *( (sPrintUserData*)user_data )->pcmf32s; + + // print the last n_new segments + const uint32_t s0 = length.countSegments - n_new; + if( s0 == 0 ) + printf( "\n" ); + + const sSegment* const segments = results->getSegments(); + const sToken* const tokens = results->getTokens(); + + for( uint32_t i = s0; i < length.countSegments; i++ ) + { + const sSegment& seg = segments[ i ]; + + if( params.no_timestamps ) + { + if( params.print_colors ) + { + for( uint32_t j = 0; j < seg.countTokens; j++ ) + { + const sToken& tok = tokens[ seg.firstToken + j ]; + if( !params.print_special && ( tok.flags & eTokenFlags::Special ) ) + continue; + wprintf( L"%S%s%S", k_colors[ colorIndex( tok ) ], utf16( tok.text ).c_str(), "\033[0m" ); + } + } + else + wprintf( L"%s", utf16( seg.text ).c_str() ); + fflush( stdout ); + continue; + } + + std::string speaker = ""; +#if 0 + if( params.diarize && pcmf32s.size() == 2 ) + { + const size_t n_samples = pcmf32s[ 0 ].size(); + const int64_t is0 = SourceAudio::sampleFromTimestamp( seg.time.begin, n_samples ); + const int64_t is1 = SourceAudio::sampleFromTimestamp( seg.time.end, n_samples ); + + double energy0 = 0.0f; + double energy1 = 0.0f; + + for( int64_t j = is0; j < is1; j++ ) + { + energy0 += fabs( 
pcmf32s[ 0 ][ j ] ); + energy1 += fabs( pcmf32s[ 1 ][ j ] ); + } + + if( energy0 > 1.1 * energy1 ) + speaker = "(speaker 0)"; + else if( energy1 > 1.1 * energy0 ) + speaker = "(speaker 1)"; + else + speaker = "(speaker ?)"; + + //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str()); + } +#endif + + if( params.print_colors ) + { + printf( "[%s --> %s] ", to_timestamp( seg.time.begin ).c_str(), to_timestamp( seg.time.end ).c_str() ); + + for( uint32_t j = 0; j < seg.countTokens; j++ ) + { + const sToken& tok = tokens[ seg.firstToken + j ]; + if( !params.print_special && ( tok.flags & eTokenFlags::Special ) ) + continue; + wprintf( L"%S%S%s%S", speaker.c_str(), k_colors[ colorIndex( tok ) ], utf16( tok.text ).c_str(), "\033[0m" ); + } + printf( "\n" ); + } + else + wprintf( L"[%S --> %S] %S%s\n", to_timestamp( seg.time.begin ).c_str(), to_timestamp( seg.time.end ).c_str(), speaker.c_str(), utf16( seg.text ).c_str() ); + } + return S_OK; + } + + HRESULT __cdecl beginSegmentCallback( iContext* context, void* user_data ) noexcept + { + std::atomic_bool* flag = (std::atomic_bool*)user_data; + bool aborted = flag->load(); + return aborted ? 
S_FALSE : S_OK; + } + + HRESULT setupConsoleColors() + { + HANDLE h = GetStdHandle( STD_OUTPUT_HANDLE ); + if( h == INVALID_HANDLE_VALUE ) + return HRESULT_FROM_WIN32( GetLastError() ); + + DWORD mode = 0; + if( !GetConsoleMode( h, &mode ) ) + return HRESULT_FROM_WIN32( GetLastError() ); + if( 0 != ( mode & ENABLE_VIRTUAL_TERMINAL_PROCESSING ) ) + return S_FALSE; + + mode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; + if( !SetConsoleMode( h, mode ) ) + return HRESULT_FROM_WIN32( GetLastError() ); + return S_OK; + } +} + +int wmain( int argc, wchar_t* argv[] ) +{ + // Whisper::dbgCompareTraces( LR"(C:\Temp\2remove\Whisper\ref.bin)", LR"(C:\Temp\2remove\Whisper\gpu.bin )" ); return 0; + + // Tell logger to use the standard output stream for the messages + { + Whisper::sLoggerSetup logSetup; + logSetup.flags = eLoggerFlags::UseStandardError; + logSetup.level = eLogLevel::Debug; + Whisper::setupLogger( logSetup ); + } + + whisper_params params; + if( !params.parse( argc, argv ) ) + return 1; + + if( params.print_colors ) + { + if( FAILED( setupConsoleColors() ) ) + params.print_colors = false; + } + + if( params.fname_inp.empty() ) + { + fprintf( stderr, "error: no input files specified\n" ); + whisper_print_usage( argc, argv, params ); + return 2; + } + + if( Whisper::findLanguageKeyA( params.language.c_str() ) == UINT_MAX ) + { + fprintf( stderr, "error: unknown language '%s'\n", params.language.c_str() ); + whisper_print_usage( argc, argv, params ); + return 3; + } + + ComLight::CComPtr<iModel> model; + HRESULT hr = loadWhisperModel( params.model.c_str(), &model ); + if( FAILED( hr ) ) + { + printError( "failed to load the model", hr ); + return 4; + } + + ComLight::CComPtr<iContext> context; + hr = model->createContext( &context ); + if( FAILED( hr ) ) + { + printError( "failed to initialize whisper context", hr ); + return 5; + } + + ComLight::CComPtr<iMediaFoundation> mf; + hr = initMediaFoundation( &mf ); + if( FAILED( hr ) ) + { + printError( "failed to initialize 
Media Foundation runtime", hr ); + return 5; + } + + for( const std::wstring& fname : params.fname_inp ) + { + // print some info about the processing + { + if( model->isMultilingual() == S_FALSE ) + { + if( params.language != "en" || params.translate ) + { + params.language = "en"; + params.translate = false; + fprintf( stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__ ); + } + } + + /* + fwprintf( stderr, L"%S: processing '%s' (%zu samples, %.1f sec), %d threads, %d processors, lang = %S, task = %S, timestamps = %d ...\n", + __func__, fname.c_str(), audio.pcmf32.size(), audio.seconds(), + params.n_threads, params.n_processors, + params.language.c_str(), + params.translate ? "translate" : "transcribe", + params.no_timestamps ? 0 : 1 ); + */ + } + + // run the inference + Whisper::sFullParams wparams; + context->fullDefaultParams( eSamplingStrategy::Greedy, &wparams ); + + wparams.resetFlag( eFullParamsFlags::PrintRealtime | eFullParamsFlags::PrintProgress ); + wparams.setFlag( eFullParamsFlags::PrintTimestamps, !params.no_timestamps ); + wparams.setFlag( eFullParamsFlags::PrintSpecial, params.print_special ); + wparams.setFlag( eFullParamsFlags::Translate, params.translate ); + wparams.language = Whisper::makeLanguageKey( params.language.c_str() ); + wparams.cpuThreads = params.n_threads; + if( params.max_context != UINT_MAX ) + wparams.n_max_text_ctx = params.max_context; + wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.setFlag( eFullParamsFlags::TokenTimestamps, params.output_wts || params.max_len > 0 ); + wparams.thold_pt = params.word_thold; + wparams.max_len = params.output_wts && params.max_len == 0 ? 
60 : params.max_len; + + wparams.setFlag( eFullParamsFlags::SpeedupAudio, params.speed_up ); + // sPrintUserData user_data = { ¶ms, &audio.pcmf32s }; + sPrintUserData user_data = { ¶ms }; + + // this callback is called on each new segment + if( !wparams.flag( eFullParamsFlags::PrintRealtime ) ) + { + wparams.new_segment_callback = &newSegmentCallback; + wparams.new_segment_callback_user_data = &user_data; + } + + // example for abort mechanism + // in this example, we do not abort the processing, but we could if the flag is set to true + // the callback is called before every encoder run - if it returns false, the processing is aborted + std::atomic_bool is_aborted = false; + { + wparams.encoder_begin_callback = &beginSegmentCallback; + wparams.encoder_begin_callback_user_data = &is_aborted; + } + +#if STREAM_AUDIO + ComLight::CComPtr<iAudioReader> reader; + CHECK( mf->openAudioFile( fname.c_str(), params.diarize, &reader ) ); + sProgressSink progressSink{ nullptr, nullptr }; + hr = context->runStreamed( wparams, progressSink, reader ); +#else + ComLight::CComPtr<iAudioBuffer> buffer; + CHECK( mf->loadAudioFile( fname.c_str(), params.diarize, &buffer ) ); + hr = context->runFull( wparams, buffer ); +#endif + if( FAILED( hr ) ) + { + fwprintf( stderr, L"%s: failed to process audio\n", argv[ 0 ] ); + return 10; + } + } + + context->timingsPrint(); + context = nullptr; + return 0; +}
\ No newline at end of file diff --git a/Examples/main/main.vcxproj b/Examples/main/main.vcxproj new file mode 100644 index 0000000..4945b88 --- /dev/null +++ b/Examples/main/main.vcxproj @@ -0,0 +1,93 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup Label="ProjectConfigurations"> + <ProjectConfiguration Include="Debug|x64"> + <Configuration>Debug</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release|x64"> + <Configuration>Release</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + </ItemGroup> + <PropertyGroup Label="Globals"> + <VCProjectVersion>16.0</VCProjectVersion> + <Keyword>Win32Proj</Keyword> + <ProjectGuid>{4cca7042-eb15-4f7a-b77b-5cafd2df47b2}</ProjectGuid> + <RootNamespace>main</RootNamespace> + <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration"> + <ConfigurationType>Application</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v143</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration"> + <ConfigurationType>Application</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <PlatformToolset>v143</PlatformToolset> + <WholeProgramOptimization>true</WholeProgramOptimization> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" /> + <ImportGroup Label="ExtensionSettings"> + </ImportGroup> + <ImportGroup Label="Shared"> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <Import 
Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <PropertyGroup Label="UserMacros" /> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <ClCompile> + <WarningLevel>Level3</WarningLevel> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>NOMINMAX;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + <LanguageStandard>stdcpp20</LanguageStandard> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <ClCompile> + <WarningLevel>Level3</WarningLevel> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>NOMINMAX;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + <LanguageStandard>stdcpp20</LanguageStandard> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemGroup> + <ClCompile Include="main.cpp" /> + <ClCompile Include="miscUtils.cpp" /> + <ClCompile Include="params.cpp" /> + </ItemGroup> + <ItemGroup> + <ClInclude Include="miscUtils.h" /> + <ClInclude Include="params.h" /> + </ItemGroup> + <ItemGroup> + 
<ProjectReference Include="..\..\Whisper\Whisper.vcxproj"> + <Project>{701df8c8-e4a5-43ec-9c6b-747bbf4d8e71}</Project> + </ProjectReference> + </ItemGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> + <ImportGroup Label="ExtensionTargets"> + </ImportGroup> +</Project>
\ No newline at end of file diff --git a/Examples/main/main.vcxproj.filters b/Examples/main/main.vcxproj.filters new file mode 100644 index 0000000..94cd8a1 --- /dev/null +++ b/Examples/main/main.vcxproj.filters @@ -0,0 +1,12 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup> + <ClCompile Include="main.cpp" /> + <ClCompile Include="params.cpp" /> + <ClCompile Include="miscUtils.cpp" /> + </ItemGroup> + <ItemGroup> + <ClInclude Include="params.h" /> + <ClInclude Include="miscUtils.h" /> + </ItemGroup> +</Project>
\ No newline at end of file diff --git a/Examples/main/miscUtils.cpp b/Examples/main/miscUtils.cpp new file mode 100644 index 0000000..3ebda20 --- /dev/null +++ b/Examples/main/miscUtils.cpp @@ -0,0 +1,48 @@ +#include "miscUtils.h" +#define WIN32_LEAN_AND_MEAN +#include <windows.h> + +std::string utf8( const std::wstring& utf16 ) +{ + int count = WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), (int)utf16.length(), nullptr, 0, nullptr, nullptr ); + std::string str( count, 0 ); + WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), -1, &str[ 0 ], count, nullptr, nullptr ); + return str; +} + +std::wstring utf16( const std::string& u8 ) +{ + int count = MultiByteToWideChar( CP_UTF8, 0, u8.c_str(), (int)u8.length(), nullptr, 0 ); + std::wstring str( count, 0 ); + MultiByteToWideChar( CP_UTF8, 0, u8.c_str(), (int)u8.length(), &str[ 0 ], count ); + return str; +} + +namespace +{ + wchar_t* formatMessage( HRESULT hr ) + { + wchar_t* err; + if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM, + NULL, + hr, + MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ), + (LPTSTR)&err, + 0, + NULL ) ) + return err; + return nullptr; + } +} + +void printError( const char* what, HRESULT hr ) +{ + const wchar_t* err = formatMessage( hr ); + if( nullptr != err ) + { + fwprintf( stderr, L"%S: %s\n", what, err ); + LocalFree( (HLOCAL)err ); + } + else + fprintf( stderr, "%s: error code %i (0x%08X)\n", what, hr, hr ); +}
// File: Examples/main/miscUtils.h
// Small helpers shared by the source files of the console example.
#pragma once
#include <string>

// Same underlying type as the HRESULT declared by the Windows SDK; lets this header be used without <windows.h>
using HRESULT = long;

// Convert UTF-16 text to UTF-8
std::string utf8( const std::wstring& utf16 );

// Convert UTF-8 text to UTF-16
std::wstring utf16( const std::string& u8 );

// Print an error message for the status code to the standard error stream; `what` is the prefix of the message
void printError( const char* what, HRESULT hr );
\ No newline at end of file diff --git a/Examples/main/params.cpp b/Examples/main/params.cpp new file mode 100644 index 0000000..ff1cfdd --- /dev/null +++ b/Examples/main/params.cpp @@ -0,0 +1,101 @@ +#include "params.h" +#include <algorithm> +#include <thread> +#include "miscUtils.h" + +whisper_params::whisper_params() +{ +#ifdef _DEBUG + n_threads = 2; +#else + n_threads = std::min( 4u, std::thread::hardware_concurrency() ); +#endif +} + +namespace +{ + const char* cstr( bool b ) + { + return b ? "true" : "false"; + } +} + +void whisper_print_usage( int argc, wchar_t** argv, const whisper_params& params ) +{ + fprintf( stderr, "\n" ); + fprintf( stderr, "usage: %S [options] file0.wav file1.wav ...\n", argv[ 0 ] ); + fprintf( stderr, "\n" ); + fprintf( stderr, "options:\n" ); + fprintf( stderr, " -h, --help [default] show this help message and exit\n" ); + fprintf( stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads ); + fprintf( stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors ); + fprintf( stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms ); + fprintf( stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n ); + fprintf( stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms ); + fprintf( stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context ); + fprintf( stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len ); + fprintf( stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold ); + fprintf( stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", cstr( params.speed_up ) ); + fprintf( stderr, " -tr, --translate [%-7s] translate from source language to english\n", cstr( params.translate ) ); + 
fprintf( stderr, " -di, --diarize [%-7s] stereo audio diarization\n", cstr( params.diarize ) ); + fprintf( stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", cstr( params.output_txt ) ); + fprintf( stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", cstr( params.output_vtt ) ); + fprintf( stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", cstr( params.output_srt ) ); + fprintf( stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", cstr( params.output_wts ) ); + fprintf( stderr, " -ps, --print-special [%-7s] print special tokens\n", cstr( params.print_special ) ); + fprintf( stderr, " -nc, --no-colors [%-7s] do not print colors\n", cstr( !params.print_colors ) ); + fprintf( stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", cstr( params.no_timestamps ) ); + fprintf( stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str() ); + fprintf( stderr, " -m FNAME, --model FNAME [%-7S] model path\n", params.model.c_str() ); + fprintf( stderr, " -f FNAME, --file FNAME [%-7s] path of the input audio file\n", "" ); + fprintf( stderr, "\n" ); +} + +bool whisper_params::parse( int argc, wchar_t* argv[] ) +{ + for( int i = 1; i < argc; i++ ) + { + std::wstring arg = argv[ i ]; + + if( arg[ 0 ] != '-' ) + { + fname_inp.push_back( arg ); + continue; + } + + if( arg == L"-h" || arg == L"--help" ) + { + whisper_print_usage( argc, argv, *this ); + return false; + } + + else if( arg == L"-t" || arg == L"--threads" ) { n_threads = std::stoul( argv[ ++i ] ); } + else if( arg == L"-p" || arg == L"--processors" ) { n_processors = std::stoul( argv[ ++i ] ); } + else if( arg == L"-ot" || arg == L"--offset-t" ) { offset_t_ms = std::stoul( argv[ ++i ] ); } + else if( arg == L"-on" || arg == L"--offset-n" ) { offset_n = std::stoul( argv[ ++i ] ); } + else if( arg == L"-d" || arg == L"--duration" ) { duration_ms = std::stoul( argv[ ++i ] ); } + else if( arg == L"-mc" || 
arg == L"--max-context" ) { max_context = std::stoul( argv[ ++i ] ); } + else if( arg == L"-ml" || arg == L"--max-len" ) { max_len = std::stoul( argv[ ++i ] ); } + else if( arg == L"-wt" || arg == L"--word-thold" ) { word_thold = std::stof( argv[ ++i ] ); } + else if( arg == L"-su" || arg == L"--speed-up" ) { speed_up = true; } + else if( arg == L"-tr" || arg == L"--translate" ) { translate = true; } + else if( arg == L"-di" || arg == L"--diarize" ) { diarize = true; } + else if( arg == L"-otxt" || arg == L"--output-txt" ) { output_txt = true; } + else if( arg == L"-ovtt" || arg == L"--output-vtt" ) { output_vtt = true; } + else if( arg == L"-osrt" || arg == L"--output-srt" ) { output_srt = true; } + else if( arg == L"-owts" || arg == L"--output-words" ) { output_wts = true; } + else if( arg == L"-ps" || arg == L"--print-special" ) { print_special = true; } + else if( arg == L"-nc" || arg == L"--no-colors" ) { print_colors = false; } + else if( arg == L"-nt" || arg == L"--no-timestamps" ) { no_timestamps = true; } + else if( arg == L"-l" || arg == L"--language" ) { language = utf8( argv[ ++i ] ); } + else if( arg == L"-m" || arg == L"--model" ) { model = argv[ ++i ]; } + else if( arg == L"-f" || arg == L"--file" ) { fname_inp.push_back( argv[ ++i ] ); } + else + { + fprintf( stderr, "error: unknown argument: %S\n", arg.c_str() ); + whisper_print_usage( argc, argv, *this ); + return false; + } + } + return true; +}
// File: Examples/main/params.h
#pragma once
// uint32_t and UINT_MAX are used below; include their headers so this header compiles on its own
#include <climits>
#include <cstdint>
#include <string>
#include <vector>

// command-line parameters
struct whisper_params
{
	// Number of threads to use during computation; the default is set by the constructor
	uint32_t n_threads;
	// Number of processors to use during computation
	uint32_t n_processors = 1;
	// Time offset in milliseconds
	uint32_t offset_t_ms = 0;
	// Segment index offset
	uint32_t offset_n = 0;
	// Duration of audio to process in milliseconds
	uint32_t duration_ms = 0;
	// Maximum number of text context tokens to store; UINT_MAX when not specified
	uint32_t max_context = UINT_MAX;
	// Maximum segment length in characters
	uint32_t max_len = 0;

	// Word timestamp probability threshold
	float word_thold = 0.01f;

	// Speed up audio by x2 (reduced accuracy)
	bool speed_up = false;
	// Translate from source language to English
	bool translate = false;
	// Stereo audio diarization
	bool diarize = false;
	// Output result in a text file
	bool output_txt = false;
	// Output result in a vtt file
	bool output_vtt = false;
	// Output result in a srt file
	bool output_srt = false;
	// Output script for generating karaoke video
	bool output_wts = false;
	// Print special tokens
	bool print_special = false;
	// Print colors; switched off by the --no-colors option
	bool print_colors = true;
	// Do not print timestamps
	bool no_timestamps = false;

	// Spoken language
	std::string language = "en";
	// Model path
	std::wstring model = L"models/ggml-base.en.bin";
	// Paths of the input audio files
	std::vector<std::wstring> fname_inp;

	// Set the default value of n_threads
	whisper_params();

	// Parse the command line into this structure.
	// A false return value means the program should exit: the help was requested, or the arguments are invalid.
	bool parse( int argc, wchar_t* argv[] );
};

// Print the usage message to the standard error stream
void whisper_print_usage( int argc, wchar_t** argv, const whisper_params& params );
\ No newline at end of file |
