diff options
| author | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
| commit | 8c4603c73675958efc960fbd4bb599a2909d106a (patch) | |
| tree | 714dc6fc9a1672d5fd7f89676b97e10959662abc /Examples/WhisperDesktop/TranscribeDlg.cpp | |
| parent | 990a8d0dbaefc996244097397259e92758b15cce (diff) | |
Source codes
Diffstat (limited to 'Examples/WhisperDesktop/TranscribeDlg.cpp')
| -rw-r--r-- | Examples/WhisperDesktop/TranscribeDlg.cpp | 493 |
1 files changed, 493 insertions, 0 deletions
diff --git a/Examples/WhisperDesktop/TranscribeDlg.cpp b/Examples/WhisperDesktop/TranscribeDlg.cpp new file mode 100644 index 0000000..14bec05 --- /dev/null +++ b/Examples/WhisperDesktop/TranscribeDlg.cpp @@ -0,0 +1,493 @@ +#include "stdafx.h" +#include "TranscribeDlg.h" +#include "Utils/logger.h" + +HRESULT TranscribeDlg::show() +{ + auto res = DoModal( nullptr ); + if( res == -1 ) + return HRESULT_FROM_WIN32( GetLastError() ); + switch( res ) + { + case IDC_BACK: + return SCREEN_MODEL; + case IDC_CAPTURE: + return SCREEN_CAPTURE; + } + return S_OK; +} + +constexpr int progressMaxInteger = 1024 * 8; + +static const LPCTSTR regValInput = L"sourceMedia"; +static const LPCTSTR regValOutFormat = L"resultFormat"; +static const LPCTSTR regValOutPath = L"resultPath"; + +LRESULT TranscribeDlg::OnInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled ) +{ + // First DDX call, hooks up variables to controls. + DoDataExchange( false ); + printModelDescription(); + languageSelector.initialize( m_hWnd, IDC_LANGUAGE, appState ); + cbConsole.initialize( m_hWnd, IDC_CONSOLE, appState ); + cbTranslate.initialize( m_hWnd, IDC_TRANSLATE, appState ); + populateOutputFormats(); + + pendingState.initialize( + { + languageSelector, + sourceMediaPath, GetDlgItem( IDC_BROWSE_MEDIA ), + transcribeOutFormat, + transcribeOutputPath, GetDlgItem( IDC_BROWSE_RESULT ), + GetDlgItem( IDC_TRANSCRIBE ), + GetDlgItem( IDCANCEL ), + GetDlgItem( IDC_BACK ), + GetDlgItem( IDC_CAPTURE ) + }, + { + progressBar, GetDlgItem( IDC_PENDING_TEXT ) + } ); + + HRESULT hr = work.create( this ); + if( FAILED( hr ) ) + { + reportError( m_hWnd, L"CreateThreadpoolWork failed", nullptr, hr ); + EndDialog( IDCANCEL ); + } + + progressBar.SetRange32( 0, progressMaxInteger ); + progressBar.SetStep( 1 ); + + sourceMediaPath.SetWindowText( appState.stringLoad( regValInput ) ); + transcribeOutFormat.SetCurSel( (int)appState.dwordLoad( regValOutFormat, 0 ) ); + transcribeOutputPath.SetWindowText( appState.stringLoad( regValOutPath ) ); + BOOL unused; + OnOutFormatChange( 0, 0, nullptr, unused ); + + appState.lastScreenSave( SCREEN_TRANSCRIBE ); + appState.setupIcon( this ); + ATLVERIFY( CenterWindow() ); + return 0; +} + +void TranscribeDlg::printModelDescription() +{ + CString text; + if( S_OK == appState.model->isMultilingual() ) + text = L"Multilingual"; + else + text = L"Single-language"; + text += L" model \""; + LPCTSTR path = appState.source.path; + path = ::PathFindFileName( path ); + text += path; + text += L"\", "; + const int64_t cb = appState.source.sizeInBytes; + if( cb < 1 << 30 ) + { + constexpr double mul = 1.0 / ( 1 << 20 ); + double mb = (double)cb * mul; + text.AppendFormat( L"%.1f MB", mb ); + } + else + { + constexpr double mul = 1.0 / ( 1 << 30 ); + double gb = (double)cb * mul; + text.AppendFormat( L"%.2f GB", gb ); + } + text += L" on disk, "; + text += implString( appState.source.impl ); + text += L" implementation"; + + modelDesc.SetWindowText( text ); +} + +void TranscribeDlg::populateOutputFormats() +{ + transcribeOutFormat.AddString( L"None" ); + transcribeOutFormat.AddString( L"Text File" ); + transcribeOutFormat.AddString( L"SubRip subtitles" ); + transcribeOutFormat.AddString( L"WebVTT subtitles" ); +} + +enum struct TranscribeDlg::eOutputFormat : uint8_t +{ + None = 0, + Text = 1, + SubRip = 2, + WebVTT = 3 +}; + +LRESULT TranscribeDlg::OnOutFormatChange( UINT, INT, HWND, BOOL& bHandled ) +{ + const BOOL enabled = transcribeOutFormat.GetCurSel() != 0; + transcribeOutputPath.EnableWindow( enabled ); + transcribeOutputBrowse.EnableWindow( enabled ); + return 0; +} + +void TranscribeDlg::onBrowseMedia() +{ + LPCTSTR title = L"Input audio file to transcribe"; + LPCTSTR filters = L"Multimedia Files\0*.wav;*.wave;*.mp3;*.wma;*.mp4;*.mpeg4;*.mkv\0\0"; + + CString path; + sourceMediaPath.GetWindowText( path ); + if( getOpenFileName( m_hWnd, title, filters, path ) ) + sourceMediaPath.SetWindowText( path ); +} + +static const LPCTSTR outputFilters = L"Text files (*.txt)\0*.txt\0SubRip subtitles (*.srt)\0*.srt\0WebVTT subtitles (*.vtt)\0*.vtt\0\0"; +static const std::array<LPCTSTR, 3> outputExtensions = +{ + L".txt", L".srt", L".vtt" +}; + +void TranscribeDlg::onBrowseOutput() +{ + const DWORD origFilterIndex = (DWORD)transcribeOutFormat.GetCurSel() - 1; + + LPCTSTR title = L"Output Text File"; + CString path; + transcribeOutputPath.GetWindowText( path ); + DWORD filterIndex = origFilterIndex; + if( !getSaveFileName( m_hWnd, title, outputFilters, path, &filterIndex ) ) + return; + + LPCTSTR ext = PathFindExtension( path ); + if( 0 == *ext && filterIndex < outputExtensions.size() ) + { + wchar_t* const buffer = path.GetBufferSetLength( path.GetLength() + 5 ); + PathRenameExtension( buffer, outputExtensions[ filterIndex ] ); + path.ReleaseBuffer(); + } + + transcribeOutputPath.SetWindowText( path ); + if( filterIndex != origFilterIndex ) + transcribeOutFormat.SetCurSel( filterIndex + 1 ); +} + +void TranscribeDlg::setPending( bool nowPending ) +{ + pendingState.setPending( nowPending ); +} + +void TranscribeDlg::transcribeError( LPCTSTR text, HRESULT hr ) +{ + reportError( m_hWnd, text, L"Unable to transcribe audio", hr ); +} + +void TranscribeDlg::onTranscribe() +{ + // Validate input + sourceMediaPath.GetWindowText( transcribeArgs.pathMedia ); + if( transcribeArgs.pathMedia.GetLength() <= 0 ) + { + transcribeError( L"Please select an input audio file" ); + return; + } + + if( !PathFileExists( transcribeArgs.pathMedia ) ) + { + transcribeError( L"Input audio file does not exist", HRESULT_FROM_WIN32( ERROR_FILE_NOT_FOUND ) ); + return; + } + + transcribeArgs.language = languageSelector.selectedLanguage(); + transcribeArgs.translate = cbTranslate.checked(); + if( isInvalidTranslate( m_hWnd, transcribeArgs.language, transcribeArgs.translate ) ) + return; + + transcribeArgs.format = (eOutputFormat)(uint8_t)transcribeOutFormat.GetCurSel(); + if( transcribeArgs.format != eOutputFormat::None ) + { + transcribeOutputPath.GetWindowText( transcribeArgs.pathOutput ); + if( transcribeArgs.pathOutput.GetLength() <= 0 ) + { + transcribeError( L"Please select an output text file" ); + return; + } + appState.stringStore( regValOutPath, transcribeArgs.pathOutput ); + } + else + cbConsole.ensureChecked(); + + appState.dwordStore( regValOutFormat, (uint32_t)(int)transcribeArgs.format ); + languageSelector.saveSelection( appState ); + cbTranslate.saveSelection( appState ); + appState.stringStore( regValInput, transcribeArgs.pathMedia ); + + setPending( true ); + + work.post(); +} + +void __stdcall TranscribeDlg::poolCallback() noexcept +{ + HRESULT hr = transcribe(); + PostMessage( WM_CALLBACK_STATUS, (WPARAM)hr ); +} + +static void printTime( CString& rdi, int64_t ticks ) +{ + const Whisper::sTimeSpan ts{ (uint64_t)ticks }; + const Whisper::sTimeSpanFields fields = ts; + + if( fields.days != 0 ) + { + rdi.AppendFormat( L"%i days, %i hours", fields.days, (int)fields.hours ); + return; + } + if( ( fields.hours | fields.minutes ) != 0 ) + { + rdi.AppendFormat( L"%02d:%02d:%02d", (int)fields.hours, (int)fields.minutes, (int)fields.seconds ); + return; + } + rdi.AppendFormat( L"%.3f seconds", (double)ticks / 1E7 ); +} + +LRESULT TranscribeDlg::onCallbackStatus( UINT, WPARAM wParam, LPARAM, BOOL& bHandled ) +{ + setPending( false ); + const HRESULT hr = (HRESULT)wParam; + if( FAILED( hr ) ) + { + LPCTSTR failMessage = L"Transcribe failed"; + + if( transcribeArgs.errorMessage.GetLength() > 0 ) + { + CString tmp = failMessage; + tmp += L"\n"; + tmp += transcribeArgs.errorMessage; + transcribeError( tmp, hr ); + } + else + transcribeError( failMessage, hr ); + + return 0; + } + + const int64_t elapsed = ( GetTickCount64() - transcribeArgs.startTime ) * 10'000; + const int64_t media = transcribeArgs.mediaDuration; + CString message = L"Transcribed the audio\nMedia duration: "; + printTime( message, media ); + message += L"\nProcessing time: "; + printTime( message, elapsed ); + message += L"\nRelative processing speed: "; + double mul = (double)media / (double)elapsed; + message.AppendFormat( L"%g", mul ); + + MessageBox( message, L"Transcribe Completed", MB_OK | MB_ICONINFORMATION ); + return 0; +} + +void TranscribeDlg::getThreadError() +{ + getLastError( transcribeArgs.errorMessage ); +} + +#define CHECK_EX( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) { getThreadError(); return __hr; } } + +HRESULT TranscribeDlg::transcribe() +{ + transcribeArgs.startTime = GetTickCount64(); + clearLastError(); + transcribeArgs.errorMessage = L""; + + using namespace Whisper; + CComPtr<iAudioReader> reader; + + CHECK_EX( appState.mediaFoundation->openAudioFile( transcribeArgs.pathMedia, false, &reader ) ); + CHECK_EX( reader->getDuration( transcribeArgs.mediaDuration ) ); + + const eOutputFormat format = transcribeArgs.format; + CAtlFile outputFile; + if( format != eOutputFormat::None ) + CHECK( outputFile.Create( transcribeArgs.pathOutput, GENERIC_WRITE, 0, CREATE_ALWAYS ) ); + + transcribeArgs.resultFlags = eResultFlags::Timestamps | eResultFlags::Tokens; + + CComPtr<iContext> context; + CHECK_EX( appState.model->createContext( &context ) ); + + sFullParams fullParams; + CHECK_EX( context->fullDefaultParams( eSamplingStrategy::Greedy, &fullParams ) ); + fullParams.language = transcribeArgs.language; + fullParams.setFlag( eFullParamsFlags::Translate, transcribeArgs.translate ); + fullParams.resetFlag( eFullParamsFlags::PrintRealtime ); + + fullParams.new_segment_callback_user_data = this; + fullParams.new_segment_callback = &newSegmentCallbackStatic; + + // Setup the progress indication sink + sProgressSink progressSink{ &progressCallbackStatic, this }; + // Run the transcribe + CHECK_EX( context->runStreamed( fullParams, progressSink, reader ) ); + + context->timingsPrint(); + + if( format == eOutputFormat::None ) + return S_OK; + + CComPtr<iTranscribeResult> result; + CHECK_EX( context->getResults( transcribeArgs.resultFlags, &result ) ); + + sTranscribeLength len; + CHECK_EX( result->getSize( len ) ); + const sSegment* const segments = result->getSegments(); + + switch( format ) + { + case eOutputFormat::Text: + return writeTextFile( segments, len.countSegments, outputFile ); + case eOutputFormat::SubRip: + return writeSubRip( segments, len.countSegments, outputFile ); + case eOutputFormat::WebVTT: + return writeWebVTT( segments, len.countSegments, outputFile ); + default: + return E_FAIL; + } +} + +#undef CHECK_EX + +inline HRESULT TranscribeDlg::progressCallback( double p ) noexcept +{ + constexpr double mul = progressMaxInteger; + int pos = lround( mul * p ); + progressBar.PostMessage( PBM_SETPOS, pos, 0 ); + return S_OK; +} + +HRESULT __cdecl TranscribeDlg::progressCallbackStatic( double p, Whisper::iContext* ctx, void* pv ) noexcept +{ + TranscribeDlg* dlg = (TranscribeDlg*)pv; + return dlg->progressCallback( p ); +} + +namespace +{ + HRESULT write( CAtlFile& file, const CStringA& line ) + { + if( line.GetLength() > 0 ) + CHECK( file.Write( cstr( line ), (DWORD)line.GetLength() ) ); + return S_OK; + } + + void printTime( CStringA& rdi, Whisper::sTimeSpan time, bool comma ) + { + Whisper::sTimeSpanFields fields = time; + const char separator = comma ? ',' : '.'; + rdi.AppendFormat( "%02d:%02d:%02d%c%03d", + (int)fields.hours, + (int)fields.minutes, + (int)fields.seconds, + separator, + fields.ticks / 10'000 ); + } + + const char* skipBlank( const char* rsi ) + { + while( true ) + { + const char c = *rsi; + if( c == ' ' || c == '\t' ) + { + rsi++; + continue; + } + return rsi; + } + } +} + +using Whisper::sSegment; + + +HRESULT TranscribeDlg::writeTextFile( const sSegment* const segments, const size_t length, CAtlFile& file ) +{ + using namespace Whisper; + CHECK( writeUtf8Bom( file ) ); + CStringA line; + for( size_t i = 0; i < length; i++ ) + { + line = skipBlank( segments[ i ].text ); + line += "\r\n"; + CHECK( write( file, line ) ); + } + return S_OK; +} + +HRESULT TranscribeDlg::writeSubRip( const sSegment* const segments, const size_t length, CAtlFile& file ) +{ + CHECK( writeUtf8Bom( file ) ); + CStringA line; + for( size_t i = 0; i < length; i++ ) + { + const sSegment& seg = segments[ i ]; + + line.Format( "%zu\r\n", i + 1 ); + printTime( line, seg.time.begin, true ); + line += " --> "; + printTime( line, seg.time.end, true ); + line += "\r\n"; + line += skipBlank( seg.text ); + line += "\r\n\r\n"; + CHECK( write( file, line ) ); + } + return S_OK; +} + +HRESULT TranscribeDlg::writeWebVTT( const sSegment* const segments, const size_t length, CAtlFile& file ) +{ + CHECK( writeUtf8Bom( file ) ); + CStringA line; + line = "WEBVTT\r\n\r\n"; + CHECK( write( file, line ) ); + + for( size_t i = 0; i < length; i++ ) + { + const sSegment& seg = segments[ i ]; + line = ""; + + printTime( line, seg.time.begin, false ); + line += " --> "; + printTime( line, seg.time.end, false ); + line += "\r\n"; + line += skipBlank( seg.text ); + line += "\r\n\r\n"; + CHECK( write( file, line ) ); + } + return S_OK; +} + +inline HRESULT TranscribeDlg::newSegmentCallback( Whisper::iContext* ctx, uint32_t n_new ) +{ + using namespace Whisper; + CComPtr<iTranscribeResult> result; + CHECK( ctx->getResults( transcribeArgs.resultFlags, &result ) ); + return logNewSegments( result, n_new ); +} + +HRESULT __cdecl TranscribeDlg::newSegmentCallbackStatic( Whisper::iContext* ctx, uint32_t n_new, void* user_data ) noexcept +{ + TranscribeDlg* dlg = (TranscribeDlg*)user_data; + return dlg->newSegmentCallback( ctx, n_new ); +} + +void TranscribeDlg::onWmClose() +{ + if( GetDlgItem( IDCANCEL ).IsWindowEnabled() ) + { + EndDialog( IDCANCEL ); + return; + } + + constexpr UINT flags = MB_YESNO | MB_ICONQUESTION | MB_DEFBUTTON2; + const int res = this->MessageBox( L"Transcribe is in progress.\nDo you want to quit anyway?", L"Confirm exit", flags ); + if( res != IDYES ) + return; + + // TODO: instead of ExitProcess(), implement another callback in the DLL API, for proper cancellation of the background task + ExitProcess( 1 ); +}
\ No newline at end of file |
