diff options
Diffstat (limited to 'Whisper/MF/PcmReader.cpp')
| -rw-r--r-- | Whisper/MF/PcmReader.cpp | 176 |
1 files changed, 162 insertions, 14 deletions
diff --git a/Whisper/MF/PcmReader.cpp b/Whisper/MF/PcmReader.cpp index ab92fc3..16f38c6 100644 --- a/Whisper/MF/PcmReader.cpp +++ b/Whisper/MF/PcmReader.cpp @@ -110,13 +110,168 @@ namespace static const HandlerMono s_mono; static const HandlerDownmixedStereo s_downmix; static const HandlerStereo s_stereo; + + __forceinline __m128i load( const GUID& guid ) + { + return _mm_loadu_si128( ( const __m128i* )( &guid ) ); + } + + // Find audio decoder MFT, query MF_MT_SUBTYPE attribute of the current input media type of that MFT + HRESULT getDecoderInputSubtype( IMFSourceReader* reader, __m128i& rdi ) + { + store16( &rdi, _mm_setzero_si128() ); + + CComPtr<IMFSourceReaderEx> readerEx; + CHECK( reader->QueryInterface( &readerEx ) ); + constexpr uint32_t stream = MF_SOURCE_READER_FIRST_AUDIO_STREAM; + const __m128i decGuid = load( MFT_CATEGORY_AUDIO_DECODER ); + alignas( 16 ) GUID category; + for( DWORD i = 0; true; i++ ) + { + CComPtr<IMFTransform> mft; + HRESULT hr = readerEx->GetTransformForStream( stream, i, &category, &mft ); + if( FAILED( hr ) ) + return hr; + const __m128i cat = _mm_load_si128( ( const __m128i* ) & category ); + if( !vectorEqual( decGuid, cat ) ) + continue; + + CComPtr<IMFMediaType> mt; + CHECK( mft->GetInputCurrentType( 0, &mt ) ); + CHECK( mt->GetGUID( MF_MT_SUBTYPE, (GUID*)&rdi ) ); + return S_OK; + } + } + + // S_OK when the reader has an MP3 decoder for the first audio stream, S_FALSE otherwise + HRESULT isMp3Decoder( IMFSourceReader* reader ) + { + __m128i subtype; + CHECK( getDecoderInputSubtype( reader, subtype ) ); + const bool res = vectorEqual( subtype, load( MFAudioFormat_MP3 ) ); + return res ? S_OK : S_FALSE; + } + + // Workaround for a Microsoft's bug in Media Foundation MP3 decoder: https://github.com/Const-me/Whisper/issues/4 + // Media Foundation is reporting incorrect media duration = 12.54. Windows Media Player does the same. + // Winamp and Media Player Classic are reporting 12:35, VLC reports 12:36. + HRESULT getPreciseDuration( IMFSourceReader* reader, size_t& length, bool mono, const iAudioReader* iar ) + { + size_t samples = 0; + + // Decode the complete stream, counting samples + while( true ) + { + DWORD dwFlags = 0; + CComPtr<IMFSample> sample; + + // Read the next sample + HRESULT hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample ); + if( FAILED( hr ) ) + { + logErrorHr( hr, u8"IMFSourceReader.ReadSample" ); + return hr; + } + + if( dwFlags & MF_SOURCE_READERF_CURRENTMEDIATYPECHANGED ) + { + // logError( u8"Media type changes ain’t supported by the library." ); + // return E_UNEXPECTED; + + // This happens for some video files at the very start of the reading, with Dolby AC3 audio track. + // Instead of failing the transcribe process, verify the important attributes (FP32 samples, sample rate, count of channels) haven’t changed. + CHECK( validateCurrentMediaType( reader, mono ? 1 : 2 ) ); + } + + if( dwFlags & MF_SOURCE_READERF_ENDOFSTREAM ) + break; + + if( !sample ) + { + // printf( "No sample\n" ); + continue; + } + + // Get a pointer to the audio data in the sample. + CComPtr<IMFMediaBuffer> buffer; + hr = sample->ConvertToContiguousBuffer( &buffer ); + if( FAILED( hr ) ) + return hr; + + const float* pAudioData = nullptr; + DWORD cbBuffer; + hr = buffer->Lock( (BYTE**)&pAudioData, nullptr, &cbBuffer ); + if( FAILED( hr ) ) + return hr; + + assert( 0 == ( cbBuffer % sizeof( float ) ) ); + const size_t countFloats = cbBuffer / sizeof( float ); + if( mono ) + samples += countFloats; + else + { + assert( 0 == countFloats % 2 ); + samples += countFloats / 2; + } + + // Unlock the buffer + hr = buffer->Unlock(); + if( FAILED( hr ) ) + return hr; + } + + // Rewind the stream to beginning + PROPVARIANT pv; + PropVariantInit( &pv ); + pv.vt = VT_I8; + pv.hVal.QuadPart = 0; + CHECK( reader->SetCurrentPosition( GUID_NULL, pv ) ); + + // Make the output value + length = samples / FFT_STEP; + + // Store the actual samples count in the reader + // This way the iAudioReader.getDuration() API returns correct value to the user + setPreciseSamplesCount( iar, samples ); + + return S_OK; + } + + HRESULT getDuration( IMFSourceReader* reader, size_t& length, bool mono, const iAudioReader* iar ) + { + HRESULT hr = isMp3Decoder( reader ); + if( SUCCEEDED( hr ) ) + { + if( S_OK == hr ) + { + return getPreciseDuration( reader, length, mono, iar ); + } + } + else + logWarningHr( hr, u8"isMp3Decoder" ); + + // Find out the length + int64_t durationTicks; + CHECK( getStreamDuration( reader, durationTicks ) ); + + // Convert length to chunks + // Seconds = Ticks / 10^7 + // Samples = Seconds * SAMPLE_RATE = Ticks * SAMPLE_RATE / 10^7 + // Chunks = Samples / FFT_STEP = Ticks * SAMPLE_RATE / ( FFT_STEP * 10^7 ), and we want that integer rounded down + constexpr __int64 mul = SAMPLE_RATE; + constexpr __int64 div = (__int64)FFT_STEP * 10'000'000; + length = (size_t)MFllMulDiv( durationTicks, mul, div, 0 ); + return S_OK; + } } -PcmReader::PcmReader( IMFSourceReader* reader, bool stereo ) +PcmReader::PcmReader( const iAudioReader* iar ) { - if( nullptr == reader ) + if( nullptr == iar ) throw E_POINTER; - this->reader = reader; + + check( iar->getReader( &reader ) ); + const bool stereo = iar->requestedStereo() == S_OK; // Set up media type, and figure out sample handler check( reader->SetStreamSelection( MF_SOURCE_READER_ALL_STREAMS, FALSE ) ); @@ -142,17 +297,10 @@ PcmReader::PcmReader( IMFSourceReader* reader, bool stereo ) check( createMediaType( !sourceMono, &mt ) ); check( reader->SetCurrentMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, nullptr, mt ) ); - // Find out the length - int64_t durationTicks; - check( getStreamDuration( reader, durationTicks ) ); - - // Convert length to chunks - // Seconds = Ticks / 10^7 - // Samples = Seconds * SAMPLE_RATE = Ticks * SAMPLE_RATE / 10^7 - // Chunks = Samples / FFT_STEP = Ticks * SAMPLE_RATE / ( FFT_STEP * 10^7 ), and we want that integer rounded down - constexpr __int64 mul = SAMPLE_RATE; - constexpr __int64 div = (__int64)FFT_STEP * 10'000'000; - m_length = (size_t)MFllMulDiv( durationTicks, mul, div, 0 ); + // Find out the length. + // Sadly, broken Microsoft's MP3 decoder MFT made this much harder than necessary: + // https://github.com/Const-me/Whisper/issues/4 + check( getDuration( reader, m_length, sourceMono, iar ) ); } HRESULT PcmReader::readNextSample() |
