diff options
| author | Konstantin <const@const.me> | 2023-01-19 17:10:24 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-19 17:10:24 +0100 |
| commit | 9df2ee2ead4ce23d06351a6cdb4fea588f79e429 (patch) | |
| tree | d365bc24b192e3929801a4ede5b26e74a6c9e77f | |
| parent | 06643094c166b0e80fb8f5f506f5e9d42a90c2bf (diff) | |
Workaround for Microsoft’s bug in their MP3 decoder MFT
| -rw-r--r-- | Examples/WhisperDesktop/TranscribeDlg.cpp | 6 | ||||
| -rw-r--r-- | Readme.md | 2 | ||||
| -rw-r--r-- | Whisper/MF/MediaFoundation.cpp | 21 | ||||
| -rw-r--r-- | Whisper/MF/PcmReader.cpp | 176 | ||||
| -rw-r--r-- | Whisper/MF/PcmReader.h | 3 | ||||
| -rw-r--r-- | Whisper/MF/mfUtils.h | 3 | ||||
| -rw-r--r-- | Whisper/Whisper/ContextImpl.misc.cpp | 8 | ||||
| -rw-r--r-- | Whisper/Whisper/MelStreamer.cpp | 8 | ||||
| -rw-r--r-- | Whisper/Whisper/MelStreamer.h | 9 |
9 files changed, 204 insertions, 32 deletions
diff --git a/Examples/WhisperDesktop/TranscribeDlg.cpp b/Examples/WhisperDesktop/TranscribeDlg.cpp index cd23c93..b99e98a 100644 --- a/Examples/WhisperDesktop/TranscribeDlg.cpp +++ b/Examples/WhisperDesktop/TranscribeDlg.cpp @@ -298,7 +298,6 @@ HRESULT TranscribeDlg::transcribe() CComPtr<iAudioReader> reader; CHECK_EX( appState.mediaFoundation->openAudioFile( transcribeArgs.pathMedia, false, &reader ) ); - CHECK_EX( reader->getDuration( transcribeArgs.mediaDuration ) ); const eOutputFormat format = transcribeArgs.format; CAtlFile outputFile; @@ -324,6 +323,11 @@ HRESULT TranscribeDlg::transcribe() // Run the transcribe CHECK_EX( context->runStreamed( fullParams, progressSink, reader ) ); + // Once finished, query duration of the audio. + // The duration before the processing is sometimes different, by 20 seconds for the file in that issue: + // https://github.com/Const-me/Whisper/issues/4 + CHECK_EX( reader->getDuration( transcribeArgs.mediaDuration ) ); + context->timingsPrint(); if( format == eOutputFormat::None ) @@ -47,7 +47,7 @@ The implementation is based on the [2009 article](https://www.researchgate.net/p * Pre-built binaries available The only supported platform is 64-bit Windows.<br/> -Should work on Windows 8.0 or newer, but I have only tested on Windows 10.<br/> +Should work on Windows 8.1 or newer, but I have only tested on Windows 10.<br/> The library requires a Direct3D 11.0 capable GPU, which in 2023 simply means “any hardware GPU”. The most recent GPU without D3D 11.0 support was Intel [Sandy Bridge](https://en.wikipedia.org/wiki/Sandy_Bridge) from 2011. 
diff --git a/Whisper/MF/MediaFoundation.cpp b/Whisper/MF/MediaFoundation.cpp index 4a4f6a2..df6990c 100644 --- a/Whisper/MF/MediaFoundation.cpp +++ b/Whisper/MF/MediaFoundation.cpp @@ -7,6 +7,7 @@ #include <mfreadwrite.h> #include "mfUtils.h" #include "AudioCapture.h" +#include <mfapi.h> namespace Whisper { @@ -15,6 +16,7 @@ namespace Whisper CComPtr<IMFSourceReader> reader; bool wantStereo; CComPtr<iMediaFoundation> mediaFoundation; + mutable int64_t preciseSamplesCount = 0; HRESULT COMLIGHTCALL getReader( IMFSourceReader** pp ) const noexcept override final { @@ -31,7 +33,14 @@ namespace Whisper HRESULT COMLIGHTCALL getDuration( int64_t& rdi ) const noexcept override final { if( reader ) - return getStreamDuration( reader, rdi ); + { + if( 0 == preciseSamplesCount ) + return getStreamDuration( reader, rdi ); + else + { rdi = MFllMulDiv( preciseSamplesCount, 10'000'000, SAMPLE_RATE, 0 ); + return S_OK; + } + } return OLE_E_BLANK; } public: @@ -48,8 +57,18 @@ namespace Whisper logDebug16( L"Created source reader from the file \"%s\"", path ); return S_OK; } + void setPreciseSamplesCount( int64_t count ) const + { + preciseSamplesCount = count; + } }; + void setPreciseSamplesCount( const iAudioReader* ar, int64_t count ) + { + const AudioReader* r = static_cast<const AudioReader*>( ar ); + r->setPreciseSamplesCount( count ); + } + class MediaFoundation : public ComLight::ObjectRoot<iMediaFoundation> { MfStartupRaii raii; diff --git a/Whisper/MF/PcmReader.cpp b/Whisper/MF/PcmReader.cpp index ab92fc3..16f38c6 100644 --- a/Whisper/MF/PcmReader.cpp +++ b/Whisper/MF/PcmReader.cpp @@ -110,13 +110,168 @@ namespace static const HandlerMono s_mono; static const HandlerDownmixedStereo s_downmix; static const HandlerStereo s_stereo; + + __forceinline __m128i load( const GUID& guid ) + { + return _mm_loadu_si128( ( const __m128i* )( &guid ) ); + } + + // Find audio decoder MFT, query MF_MT_SUBTYPE attribute of the current input media type of that MFT + HRESULT 
getDecoderInputSubtype( IMFSourceReader* reader, __m128i& rdi ) + { + store16( &rdi, _mm_setzero_si128() ); + + CComPtr<IMFSourceReaderEx> readerEx; + CHECK( reader->QueryInterface( &readerEx ) ); + constexpr uint32_t stream = MF_SOURCE_READER_FIRST_AUDIO_STREAM; + const __m128i decGuid = load( MFT_CATEGORY_AUDIO_DECODER ); + alignas( 16 ) GUID category; + for( DWORD i = 0; true; i++ ) + { + CComPtr<IMFTransform> mft; + HRESULT hr = readerEx->GetTransformForStream( stream, i, &category, &mft ); + if( FAILED( hr ) ) + return hr; + const __m128i cat = _mm_load_si128( ( const __m128i* ) & category ); + if( !vectorEqual( decGuid, cat ) ) + continue; + + CComPtr<IMFMediaType> mt; + CHECK( mft->GetInputCurrentType( 0, &mt ) ); + CHECK( mt->GetGUID( MF_MT_SUBTYPE, (GUID*)&rdi ) ); + return S_OK; + } + } + + // S_OK when the reader has an MP3 decoder for the first audio stream, S_FALSE otherwise + HRESULT isMp3Decoder( IMFSourceReader* reader ) + { + __m128i subtype; + CHECK( getDecoderInputSubtype( reader, subtype ) ); + const bool res = vectorEqual( subtype, load( MFAudioFormat_MP3 ) ); + return res ? S_OK : S_FALSE; + } + + // Workaround for a Microsoft's bug in Media Foundation MP3 decoder: https://github.com/Const-me/Whisper/issues/4 + // Media Foundation is reporting incorrect media duration = 12.54. Windows Media Player does the same. + // Winamp and Media Player Classic are reporting 12:35, VLC reports 12:36. 
+ HRESULT getPreciseDuration( IMFSourceReader* reader, size_t& length, bool mono, const iAudioReader* iar ) + { + size_t samples = 0; + + // Decode the complete stream, counting samples + while( true ) + { + DWORD dwFlags = 0; + CComPtr<IMFSample> sample; + + // Read the next sample + HRESULT hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample ); + if( FAILED( hr ) ) + { + logErrorHr( hr, u8"IMFSourceReader.ReadSample" ); + return hr; + } + + if( dwFlags & MF_SOURCE_READERF_CURRENTMEDIATYPECHANGED ) + { + // logError( u8"Media type changes ain’t supported by the library." ); + // return E_UNEXPECTED; + + // This happens for some video files at the very start of the reading, with Dolby AC3 audio track. + // Instead of failing the transcribe process, verify the important attributes (FP32 samples, sample rate, count of channels) haven’t changed. + CHECK( validateCurrentMediaType( reader, mono ? 1 : 2 ) ); + } + + if( dwFlags & MF_SOURCE_READERF_ENDOFSTREAM ) + break; + + if( !sample ) + { + // printf( "No sample\n" ); + continue; + } + + // Get a pointer to the audio data in the sample. 
+ CComPtr<IMFMediaBuffer> buffer; + hr = sample->ConvertToContiguousBuffer( &buffer ); + if( FAILED( hr ) ) + return hr; + + const float* pAudioData = nullptr; + DWORD cbBuffer; + hr = buffer->Lock( (BYTE**)&pAudioData, nullptr, &cbBuffer ); + if( FAILED( hr ) ) + return hr; + + assert( 0 == ( cbBuffer % sizeof( float ) ) ); + const size_t countFloats = cbBuffer / sizeof( float ); + if( mono ) + samples += countFloats; + else + { + assert( 0 == countFloats % 2 ); + samples += countFloats / 2; + } + + // Unlock the buffer + hr = buffer->Unlock(); + if( FAILED( hr ) ) + return hr; + } + + // Rewind the stream to beginning + PROPVARIANT pv; + PropVariantInit( &pv ); + pv.vt = VT_I8; + pv.hVal.QuadPart = 0; + CHECK( reader->SetCurrentPosition( GUID_NULL, pv ) ); + + // Make the output value + length = samples / FFT_STEP; + + // Store the actual samples count in the reader + // This way the iAudioReader.getDuration() API returns correct value to the user + setPreciseSamplesCount( iar, samples ); + + return S_OK; + } + + HRESULT getDuration( IMFSourceReader* reader, size_t& length, bool mono, const iAudioReader* iar ) + { + HRESULT hr = isMp3Decoder( reader ); + if( SUCCEEDED( hr ) ) + { + if( S_OK == hr ) + { + return getPreciseDuration( reader, length, mono, iar ); + } + } + else + logWarningHr( hr, u8"isMp3Decoder" ); + + // Find out the length + int64_t durationTicks; + CHECK( getStreamDuration( reader, durationTicks ) ); + + // Convert length to chunks + // Seconds = Ticks / 10^7 + // Samples = Seconds * SAMPLE_RATE = Ticks * SAMPLE_RATE / 10^7 + // Chunks = Samples / FFT_STEP = Ticks * SAMPLE_RATE / ( FFT_STEP * 10^7 ), and we want that integer rounded down + constexpr __int64 mul = SAMPLE_RATE; + constexpr __int64 div = (__int64)FFT_STEP * 10'000'000; + length = (size_t)MFllMulDiv( durationTicks, mul, div, 0 ); + return S_OK; + } } -PcmReader::PcmReader( IMFSourceReader* reader, bool stereo ) +PcmReader::PcmReader( const iAudioReader* iar ) { - if( nullptr == 
reader ) + if( nullptr == iar ) throw E_POINTER; - this->reader = reader; + + check( iar->getReader( &reader ) ); + const bool stereo = iar->requestedStereo() == S_OK; // Set up media type, and figure out sample handler check( reader->SetStreamSelection( MF_SOURCE_READER_ALL_STREAMS, FALSE ) ); @@ -142,17 +297,10 @@ PcmReader::PcmReader( IMFSourceReader* reader, bool stereo ) check( createMediaType( !sourceMono, &mt ) ); check( reader->SetCurrentMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, nullptr, mt ) ); - // Find out the length - int64_t durationTicks; - check( getStreamDuration( reader, durationTicks ) ); - - // Convert length to chunks - // Seconds = Ticks / 10^7 - // Samples = Seconds * SAMPLE_RATE = Ticks * SAMPLE_RATE / 10^7 - // Chunks = Samples / FFT_STEP = Ticks * SAMPLE_RATE / ( FFT_STEP * 10^7 ), and we want that integer rounded down - constexpr __int64 mul = SAMPLE_RATE; - constexpr __int64 div = (__int64)FFT_STEP * 10'000'000; - m_length = (size_t)MFllMulDiv( durationTicks, mul, div, 0 ); + // Find out the length. + // Sadly, broken Microsoft's MP3 decoder MFT made this much harder than necessary: + // https://github.com/Const-me/Whisper/issues/4 + check( getDuration( reader, m_length, sourceMono, iar ) ); } HRESULT PcmReader::readNextSample() diff --git a/Whisper/MF/PcmReader.h b/Whisper/MF/PcmReader.h index 9e3757e..aebfc1f 100644 --- a/Whisper/MF/PcmReader.h +++ b/Whisper/MF/PcmReader.h @@ -3,6 +3,7 @@ #include <mfidl.h> #include <mfreadwrite.h> #include "AudioBuffer.h" +#include "../API/iMediaFoundation.cl.h" namespace Whisper { @@ -44,7 +45,7 @@ namespace Whisper public: - PcmReader( IMFSourceReader* source, bool stereo ); + PcmReader( const iAudioReader* reader ); // Count of chunks in the MEL spectrogram. // The PCM audio is generally slightly longer than that, due to the incomplete last chunk. 
diff --git a/Whisper/MF/mfUtils.h b/Whisper/MF/mfUtils.h index c889a92..f67114d 100644 --- a/Whisper/MF/mfUtils.h +++ b/Whisper/MF/mfUtils.h @@ -12,4 +12,7 @@ namespace Whisper HRESULT getStreamDuration( IMFSourceReader* reader, int64_t& duration ); HRESULT validateCurrentMediaType( IMFSourceReader* reader, uint32_t expectedChannels ); + + struct iAudioReader; + void setPreciseSamplesCount( const iAudioReader* ar, int64_t count ); }
\ No newline at end of file diff --git a/Whisper/Whisper/ContextImpl.misc.cpp b/Whisper/Whisper/ContextImpl.misc.cpp index 9ce5000..9a156fb 100644 --- a/Whisper/Whisper/ContextImpl.misc.cpp +++ b/Whisper/Whisper/ContextImpl.misc.cpp @@ -384,20 +384,16 @@ HRESULT COMLIGHTCALL ContextImpl::runStreamed( const sFullParams& params, const mediaTimeOffset = 0; auto profCompleteCpu = profiler.cpuBlock( eCpuBlock::RunComplete ); - CComPtr<IMFSourceReader> mfReader; - CHECK( reader->getReader( &mfReader ) ); - const bool stereo = reader->requestedStereo() == S_OK; - try { if( params.cpuThreads > 1 ) { - MelStreamerThread mel{ model.filters, profiler, mfReader, params.cpuThreads, stereo }; + MelStreamerThread mel{ model.filters, profiler, reader, params.cpuThreads }; return runFullImpl( params, progress, mel ); } else { - MelStreamerSimple mel{ model.filters, profiler, mfReader, stereo }; + MelStreamerSimple mel{ model.filters, profiler, reader }; return runFullImpl( params, progress, mel ); } } diff --git a/Whisper/Whisper/MelStreamer.cpp b/Whisper/Whisper/MelStreamer.cpp index 54268f2..a7c2472 100644 --- a/Whisper/Whisper/MelStreamer.cpp +++ b/Whisper/Whisper/MelStreamer.cpp @@ -3,8 +3,8 @@ #include "../Utils/parallelFor.h" using namespace Whisper; -MelStreamer::MelStreamer( const Filters& filters, ProfileCollection& prof, IMFSourceReader* source, bool stereo ) : - reader( source, stereo ), +MelStreamer::MelStreamer( const Filters& filters, ProfileCollection& prof, const iAudioReader* iar ) : + reader( iar ), melContext( filters ), profiler( prof ) { } @@ -231,8 +231,8 @@ HRESULT MelStreamerSimple::makeBuffer( size_t off, size_t len, const float** buf return S_OK; } -MelStreamerThread::MelStreamerThread( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, int countThreads, bool stereo ) : - MelStreamer( filters, profiler, source, stereo ), +MelStreamerThread::MelStreamerThread( const Filters& filters, ProfileCollection& profiler, const 
iAudioReader* iar, int countThreads ) : + MelStreamer( filters, profiler, iar ), workerThreads( countThreads ) { if( workerThreads > 1 ) diff --git a/Whisper/Whisper/MelStreamer.h b/Whisper/Whisper/MelStreamer.h index 152c1b6..2a891ca 100644 --- a/Whisper/Whisper/MelStreamer.h +++ b/Whisper/Whisper/MelStreamer.h @@ -6,6 +6,7 @@ #include <atlbase.h> #include "../Utils/parallelFor.h" #include "../Utils/ProfileCollection.h" +#include "../API/iMediaFoundation.cl.h" namespace Whisper { @@ -44,7 +45,7 @@ namespace Whisper size_t getLength() const noexcept override final { return reader.getLength(); } public: - MelStreamer( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, bool stereo ); + MelStreamer( const Filters& filters, ProfileCollection& profiler, const iAudioReader* reader ); }; // Single-threaded MEL streamer: runs these FFTs on-demand, from within makeBuffer() method @@ -53,8 +54,8 @@ namespace Whisper HRESULT makeBuffer( size_t offset, size_t length, const float** buffer, size_t& stride ) noexcept override final; public: - MelStreamerSimple( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, bool stereo ) : - MelStreamer( filters, profiler, source, stereo ) { } + MelStreamerSimple( const Filters& filters, ProfileCollection& profiler, const iAudioReader* reader ) : + MelStreamer( filters, profiler, reader ) { } }; // Multi threaded MEL streamers: runs FFT on a background thread ahead of time @@ -92,7 +93,7 @@ namespace Whisper public: - MelStreamerThread( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, int countThreads, bool stereo ); + MelStreamerThread( const Filters& filters, ProfileCollection& profiler, const iAudioReader* reader, int countThreads ); ~MelStreamerThread(); }; |
