From 8c4603c73675958efc960fbd4bb599a2909d106a Mon Sep 17 00:00:00 2001 From: Konstantin Date: Mon, 16 Jan 2023 14:52:43 +0100 Subject: Source codes --- Examples/main/main.cpp | 315 +++++++++++++++++++++++++++++++++++++ Examples/main/main.vcxproj | 93 +++++++++++ Examples/main/main.vcxproj.filters | 12 ++ Examples/main/miscUtils.cpp | 48 ++++++ Examples/main/miscUtils.h | 9 ++ Examples/main/params.cpp | 101 ++++++++++++ Examples/main/params.h | 38 +++++ 7 files changed, 616 insertions(+) create mode 100644 Examples/main/main.cpp create mode 100644 Examples/main/main.vcxproj create mode 100644 Examples/main/main.vcxproj.filters create mode 100644 Examples/main/miscUtils.cpp create mode 100644 Examples/main/miscUtils.h create mode 100644 Examples/main/params.cpp create mode 100644 Examples/main/params.h (limited to 'Examples/main') diff --git a/Examples/main/main.cpp b/Examples/main/main.cpp new file mode 100644 index 0000000..c9eacf2 --- /dev/null +++ b/Examples/main/main.cpp @@ -0,0 +1,315 @@ +#include "params.h" +#include "../../Whisper/API/iContext.cl.h" +#include "../../Whisper/API/iMediaFoundation.cl.h" +#include "../../ComLightLib/comLightClient.h" +#include "miscUtils.h" +#include +#include +using namespace Whisper; + +#define STREAM_AUDIO 1 + +static HRESULT loadWhisperModel( const wchar_t* path, iModel** pp ) +{ + using namespace Whisper; + constexpr eModelImplementation impl = eModelImplementation::GPU; + // constexpr eModelImplementation impl = eModelImplementation::Reference; + return Whisper::loadModel( path, impl, nullptr, pp ); +} + +namespace +{ + struct sPrintUserData + { + const whisper_params* params; + // const std::vector>* pcmf32s; + }; + + // Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] + // Lowest is red, middle is yellow, highest is green. + static const std::array k_colors = + { + "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m", + "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m", + }; + + std::string to_timestamp( sTimeSpan ts, bool comma = false ) + { + sTimeSpanFields fields = ts; + uint32_t msec = fields.ticks / 10'000; + uint32_t hr = fields.days * 24 + fields.hours; + uint32_t min = fields.minutes; + uint32_t sec = fields.seconds; + + char buf[ 32 ]; + snprintf( buf, sizeof( buf ), "%02d:%02d:%02d%s%03d", hr, min, sec, comma ? "," : ".", msec ); + return std::string( buf ); + } + + static int colorIndex( const sToken& tok ) + { + const float p = tok.probability; + const float p3 = p * p * p; + int col = (int)( p3 * float( k_colors.size() ) ); + col = std::max( 0, std::min( (int)k_colors.size() - 1, col ) ); + return col; + } + + HRESULT __cdecl newSegmentCallback( iContext* context, uint32_t n_new, void* user_data ) noexcept + { + ComLight::CComPtr results; + CHECK( context->getResults( eResultFlags::Timestamps | eResultFlags::Tokens, &results ) ); + + sTranscribeLength length; + CHECK( results->getSize( length ) ); + + const whisper_params& params = *( (sPrintUserData*)user_data )->params; + // const std::vector>& pcmf32s = *( (sPrintUserData*)user_data )->pcmf32s; + + // print the last n_new segments + const uint32_t s0 = length.countSegments - n_new; + if( s0 == 0 ) + printf( "\n" ); + + const sSegment* const segments = results->getSegments(); + const sToken* const tokens = results->getTokens(); + + for( uint32_t i = s0; i < length.countSegments; i++ ) + { + const sSegment& seg = segments[ i ]; + + if( params.no_timestamps ) + { + if( params.print_colors ) + { + for( uint32_t j = 0; j < seg.countTokens; j++ ) + { + const sToken& tok = tokens[ seg.firstToken + j ]; + if( !params.print_special && ( tok.flags & eTokenFlags::Special ) ) + continue; + wprintf( L"%S%s%S", k_colors[ colorIndex( tok ) ], utf16( tok.text ).c_str(), "\033[0m" ); + } + } + else + wprintf( L"%s", utf16( seg.text ).c_str() ); + fflush( stdout ); + continue; + } + + std::string speaker = ""; +#if 0 + if( params.diarize && pcmf32s.size() == 2 ) + { + const size_t n_samples = pcmf32s[ 0 ].size(); + const int64_t is0 = SourceAudio::sampleFromTimestamp( seg.time.begin, n_samples ); + const int64_t is1 = SourceAudio::sampleFromTimestamp( seg.time.end, n_samples ); + + double energy0 = 0.0f; + double energy1 = 0.0f; + + for( int64_t j = is0; j < is1; j++ ) + { + energy0 += fabs( pcmf32s[ 0 ][ j ] ); + energy1 += fabs( pcmf32s[ 1 ][ j ] ); + } + + if( energy0 > 1.1 * energy1 ) + speaker = "(speaker 0)"; + else if( energy1 > 1.1 * energy0 ) + speaker = "(speaker 1)"; + else + speaker = "(speaker ?)"; + + //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str()); + } +#endif + + if( params.print_colors ) + { + printf( "[%s --> %s] ", to_timestamp( seg.time.begin ).c_str(), to_timestamp( seg.time.end ).c_str() ); + + for( uint32_t j = 0; j < seg.countTokens; j++ ) + { + const sToken& tok = tokens[ seg.firstToken + j ]; + if( !params.print_special && ( tok.flags & eTokenFlags::Special ) ) + continue; + wprintf( L"%S%S%s%S", speaker.c_str(), k_colors[ colorIndex( tok ) ], utf16( tok.text ).c_str(), "\033[0m" ); + } + printf( "\n" ); + } + else + wprintf( L"[%S --> %S] %S%s\n", to_timestamp( seg.time.begin ).c_str(), to_timestamp( seg.time.end ).c_str(), speaker.c_str(), utf16( seg.text ).c_str() ); + } + return S_OK; + } + + HRESULT __cdecl beginSegmentCallback( iContext* context, void* user_data ) noexcept + { + std::atomic_bool* flag = (std::atomic_bool*)user_data; + bool aborted = flag->load(); + return aborted ? S_FALSE : S_OK; + } + + HRESULT setupConsoleColors() + { + HANDLE h = GetStdHandle( STD_OUTPUT_HANDLE ); + if( h == INVALID_HANDLE_VALUE ) + return HRESULT_FROM_WIN32( GetLastError() ); + + DWORD mode = 0; + if( !GetConsoleMode( h, &mode ) ) + return HRESULT_FROM_WIN32( GetLastError() ); + if( 0 != ( mode & ENABLE_VIRTUAL_TERMINAL_PROCESSING ) ) + return S_FALSE; + + mode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; + if( !SetConsoleMode( h, mode ) ) + return HRESULT_FROM_WIN32( GetLastError() ); + return S_OK; + } +} + +int wmain( int argc, wchar_t* argv[] ) +{ + // Whisper::dbgCompareTraces( LR"(C:\Temp\2remove\Whisper\ref.bin)", LR"(C:\Temp\2remove\Whisper\gpu.bin )" ); return 0; + + // Tell logger to use the standard output stream for the messages + { + Whisper::sLoggerSetup logSetup; + logSetup.flags = eLoggerFlags::UseStandardError; + logSetup.level = eLogLevel::Debug; + Whisper::setupLogger( logSetup ); + } + + whisper_params params; + if( !params.parse( argc, argv ) ) + return 1; + + if( params.print_colors ) + { + if( FAILED( setupConsoleColors() ) ) + params.print_colors = false; + } + + if( params.fname_inp.empty() ) + { + fprintf( stderr, "error: no input files specified\n" ); + whisper_print_usage( argc, argv, params ); + return 2; + } + + if( Whisper::findLanguageKeyA( params.language.c_str() ) == UINT_MAX ) + { + fprintf( stderr, "error: unknown language '%s'\n", params.language.c_str() ); + whisper_print_usage( argc, argv, params ); + return 3; + } + + ComLight::CComPtr model; + HRESULT hr = loadWhisperModel( params.model.c_str(), &model ); + if( FAILED( hr ) ) + { + printError( "failed to load the model", hr ); + return 4; + } + + ComLight::CComPtr context; + hr = model->createContext( &context ); + if( FAILED( hr ) ) + { + printError( "failed to initialize whisper context", hr ); + return 5; + } + + ComLight::CComPtr mf; + hr = initMediaFoundation( &mf ); + if( FAILED( hr ) ) + { + printError( "failed to initialize Media Foundation runtime", hr ); + return 5; + } + + for( const std::wstring& fname : params.fname_inp ) + { + // print some info about the processing + { + if( model->isMultilingual() == S_FALSE ) + { + if( params.language != "en" || params.translate ) + { + params.language = "en"; + params.translate = false; + fprintf( stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__ ); + } + } + + /* + fwprintf( stderr, L"%S: processing '%s' (%zu samples, %.1f sec), %d threads, %d processors, lang = %S, task = %S, timestamps = %d ...\n", + __func__, fname.c_str(), audio.pcmf32.size(), audio.seconds(), + params.n_threads, params.n_processors, + params.language.c_str(), + params.translate ? "translate" : "transcribe", + params.no_timestamps ? 0 : 1 ); + */ + } + + // run the inference + Whisper::sFullParams wparams; + context->fullDefaultParams( eSamplingStrategy::Greedy, &wparams ); + + wparams.resetFlag( eFullParamsFlags::PrintRealtime | eFullParamsFlags::PrintProgress ); + wparams.setFlag( eFullParamsFlags::PrintTimestamps, !params.no_timestamps ); + wparams.setFlag( eFullParamsFlags::PrintSpecial, params.print_special ); + wparams.setFlag( eFullParamsFlags::Translate, params.translate ); + wparams.language = Whisper::makeLanguageKey( params.language.c_str() ); + wparams.cpuThreads = params.n_threads; + if( params.max_context != UINT_MAX ) + wparams.n_max_text_ctx = params.max_context; + wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.setFlag( eFullParamsFlags::TokenTimestamps, params.output_wts || params.max_len > 0 ); + wparams.thold_pt = params.word_thold; + wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; + + wparams.setFlag( eFullParamsFlags::SpeedupAudio, params.speed_up ); + // sPrintUserData user_data = { ¶ms, &audio.pcmf32s }; + sPrintUserData user_data = { ¶ms }; + + // this callback is called on each new segment + if( !wparams.flag( eFullParamsFlags::PrintRealtime ) ) + { + wparams.new_segment_callback = &newSegmentCallback; + wparams.new_segment_callback_user_data = &user_data; + } + + // example for abort mechanism + // in this example, we do not abort the processing, but we could if the flag is set to true + // the callback is called before every encoder run - if it returns false, the processing is aborted + std::atomic_bool is_aborted = false; + { + wparams.encoder_begin_callback = &beginSegmentCallback; + wparams.encoder_begin_callback_user_data = &is_aborted; + } + +#if STREAM_AUDIO + ComLight::CComPtr reader; + CHECK( mf->openAudioFile( fname.c_str(), params.diarize, &reader ) ); + sProgressSink progressSink{ nullptr, nullptr }; + hr = context->runStreamed( wparams, progressSink, reader ); +#else + ComLight::CComPtr buffer; + CHECK( mf->loadAudioFile( fname.c_str(), params.diarize, &buffer ) ); + hr = context->runFull( wparams, buffer ); +#endif + if( FAILED( hr ) ) + { + fwprintf( stderr, L"%s: failed to process audio\n", argv[ 0 ] ); + return 10; + } + } + + context->timingsPrint(); + context = nullptr; + return 0; +} \ No newline at end of file diff --git a/Examples/main/main.vcxproj b/Examples/main/main.vcxproj new file mode 100644 index 0000000..4945b88 --- /dev/null +++ b/Examples/main/main.vcxproj @@ -0,0 +1,93 @@ + + + + + Debug + x64 + + + Release + x64 + + + + 16.0 + Win32Proj + {4cca7042-eb15-4f7a-b77b-5cafd2df47b2} + main + 10.0 + + + + Application + true + v143 + Unicode + + + Application + false + v143 + true + Unicode + + + + + + + + + + + + + + + + Level3 + true + NOMINMAX;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + stdcpp20 + + + Console + true + + + + + Level3 + true + true + true + NOMINMAX;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + stdcpp20 + + + Console + true + true + true + + + + + + + + + + + + + + {701df8c8-e4a5-43ec-9c6b-747bbf4d8e71} + + + + + + \ No newline at end of file diff --git a/Examples/main/main.vcxproj.filters b/Examples/main/main.vcxproj.filters new file mode 100644 index 0000000..94cd8a1 --- /dev/null +++ b/Examples/main/main.vcxproj.filters @@ -0,0 +1,12 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/Examples/main/miscUtils.cpp b/Examples/main/miscUtils.cpp new file mode 100644 index 0000000..3ebda20 --- /dev/null +++ b/Examples/main/miscUtils.cpp @@ -0,0 +1,48 @@ +#include "miscUtils.h" +#define WIN32_LEAN_AND_MEAN +#include + +std::string utf8( const std::wstring& utf16 ) +{ + int count = WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), (int)utf16.length(), nullptr, 0, nullptr, nullptr ); + std::string str( count, 0 ); + WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), -1, &str[ 0 ], count, nullptr, nullptr ); + return str; +} + +std::wstring utf16( const std::string& u8 ) +{ + int count = MultiByteToWideChar( CP_UTF8, 0, u8.c_str(), (int)u8.length(), nullptr, 0 ); + std::wstring str( count, 0 ); + MultiByteToWideChar( CP_UTF8, 0, u8.c_str(), (int)u8.length(), &str[ 0 ], count ); + return str; +} + +namespace +{ + wchar_t* formatMessage( HRESULT hr ) + { + wchar_t* err; + if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM, + NULL, + hr, + MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ), + (LPTSTR)&err, + 0, + NULL ) ) + return err; + return nullptr; + } +} + +void printError( const char* what, HRESULT hr ) +{ + const wchar_t* err = formatMessage( hr ); + if( nullptr != err ) + { + fwprintf( stderr, L"%S: %s\n", what, err ); + LocalFree( (HLOCAL)err ); + } + else + fprintf( stderr, "%s: error code %i (0x%08X)\n", what, hr, hr ); +} \ No newline at end of file diff --git a/Examples/main/miscUtils.h b/Examples/main/miscUtils.h new file mode 100644 index 0000000..52770a6 --- /dev/null +++ b/Examples/main/miscUtils.h @@ -0,0 +1,9 @@ +#pragma once +#include + +std::string utf8( const std::wstring& utf16 ); + +std::wstring utf16( const std::string& u8 ); + +using HRESULT = long; +void printError( const char* what, HRESULT hr ); \ No newline at end of file diff --git a/Examples/main/params.cpp b/Examples/main/params.cpp new file mode 100644 index 0000000..ff1cfdd --- /dev/null +++ b/Examples/main/params.cpp @@ -0,0 +1,101 @@ +#include "params.h" +#include +#include +#include "miscUtils.h" + +whisper_params::whisper_params() +{ +#ifdef _DEBUG + n_threads = 2; +#else + n_threads = std::min( 4u, std::thread::hardware_concurrency() ); +#endif +} + +namespace +{ + const char* cstr( bool b ) + { + return b ? "true" : "false"; + } +} + +void whisper_print_usage( int argc, wchar_t** argv, const whisper_params& params ) +{ + fprintf( stderr, "\n" ); + fprintf( stderr, "usage: %S [options] file0.wav file1.wav ...\n", argv[ 0 ] ); + fprintf( stderr, "\n" ); + fprintf( stderr, "options:\n" ); + fprintf( stderr, " -h, --help [default] show this help message and exit\n" ); + fprintf( stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads ); + fprintf( stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors ); + fprintf( stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms ); + fprintf( stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n ); + fprintf( stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms ); + fprintf( stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context ); + fprintf( stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len ); + fprintf( stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold ); + fprintf( stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", cstr( params.speed_up ) ); + fprintf( stderr, " -tr, --translate [%-7s] translate from source language to english\n", cstr( params.translate ) ); + fprintf( stderr, " -di, --diarize [%-7s] stereo audio diarization\n", cstr( params.diarize ) ); + fprintf( stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", cstr( params.output_txt ) ); + fprintf( stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", cstr( params.output_vtt ) ); + fprintf( stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", cstr( params.output_srt ) ); + fprintf( stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", cstr( params.output_wts ) ); + fprintf( stderr, " -ps, --print-special [%-7s] print special tokens\n", cstr( params.print_special ) ); + fprintf( stderr, " -nc, --no-colors [%-7s] do not print colors\n", cstr( !params.print_colors ) ); + fprintf( stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", cstr( params.no_timestamps ) ); + fprintf( stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str() ); + fprintf( stderr, " -m FNAME, --model FNAME [%-7S] model path\n", params.model.c_str() ); + fprintf( stderr, " -f FNAME, --file FNAME [%-7s] path of the input audio file\n", "" ); + fprintf( stderr, "\n" ); +} + +bool whisper_params::parse( int argc, wchar_t* argv[] ) +{ + for( int i = 1; i < argc; i++ ) + { + std::wstring arg = argv[ i ]; + + if( arg[ 0 ] != '-' ) + { + fname_inp.push_back( arg ); + continue; + } + + if( arg == L"-h" || arg == L"--help" ) + { + whisper_print_usage( argc, argv, *this ); + return false; + } + + else if( arg == L"-t" || arg == L"--threads" ) { n_threads = std::stoul( argv[ ++i ] ); } + else if( arg == L"-p" || arg == L"--processors" ) { n_processors = std::stoul( argv[ ++i ] ); } + else if( arg == L"-ot" || arg == L"--offset-t" ) { offset_t_ms = std::stoul( argv[ ++i ] ); } + else if( arg == L"-on" || arg == L"--offset-n" ) { offset_n = std::stoul( argv[ ++i ] ); } + else if( arg == L"-d" || arg == L"--duration" ) { duration_ms = std::stoul( argv[ ++i ] ); } + else if( arg == L"-mc" || arg == L"--max-context" ) { max_context = std::stoul( argv[ ++i ] ); } + else if( arg == L"-ml" || arg == L"--max-len" ) { max_len = std::stoul( argv[ ++i ] ); } + else if( arg == L"-wt" || arg == L"--word-thold" ) { word_thold = std::stof( argv[ ++i ] ); } + else if( arg == L"-su" || arg == L"--speed-up" ) { speed_up = true; } + else if( arg == L"-tr" || arg == L"--translate" ) { translate = true; } + else if( arg == L"-di" || arg == L"--diarize" ) { diarize = true; } + else if( arg == L"-otxt" || arg == L"--output-txt" ) { output_txt = true; } + else if( arg == L"-ovtt" || arg == L"--output-vtt" ) { output_vtt = true; } + else if( arg == L"-osrt" || arg == L"--output-srt" ) { output_srt = true; } + else if( arg == L"-owts" || arg == L"--output-words" ) { output_wts = true; } + else if( arg == L"-ps" || arg == L"--print-special" ) { print_special = true; } + else if( arg == L"-nc" || arg == L"--no-colors" ) { print_colors = false; } + else if( arg == L"-nt" || arg == L"--no-timestamps" ) { no_timestamps = true; } + else if( arg == L"-l" || arg == L"--language" ) { language = utf8( argv[ ++i ] ); } + else if( arg == L"-m" || arg == L"--model" ) { model = argv[ ++i ]; } + else if( arg == L"-f" || arg == L"--file" ) { fname_inp.push_back( argv[ ++i ] ); } + else + { + fprintf( stderr, "error: unknown argument: %S\n", arg.c_str() ); + whisper_print_usage( argc, argv, *this ); + return false; + } + } + return true; +} \ No newline at end of file diff --git a/Examples/main/params.h b/Examples/main/params.h new file mode 100644 index 0000000..9eb2b04 --- /dev/null +++ b/Examples/main/params.h @@ -0,0 +1,38 @@ +#pragma once +#include +#include + +// command-line parameters +struct whisper_params +{ + uint32_t n_threads; + uint32_t n_processors = 1; + uint32_t offset_t_ms = 0; + uint32_t offset_n = 0; + uint32_t duration_ms = 0; + uint32_t max_context = UINT_MAX; + uint32_t max_len = 0; + + float word_thold = 0.01f; + + bool speed_up = false; + bool translate = false; + bool diarize = false; + bool output_txt = false; + bool output_vtt = false; + bool output_srt = false; + bool output_wts = false; + bool print_special = false; + bool print_colors = true; + bool no_timestamps = false; + + std::string language = "en"; + std::wstring model = L"models/ggml-base.en.bin"; + std::vector fname_inp; + + whisper_params(); + + bool parse( int argc, wchar_t* argv[] ); +}; + +void whisper_print_usage( int argc, wchar_t** argv, const whisper_params& params ); \ No newline at end of file -- cgit v1.2.3