summaryrefslogtreecommitdiffstats
path: root/Whisper/MF
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-01-19 17:10:24 +0100
committerKonstantin <const@const.me>2023-01-19 17:10:24 +0100
commit9df2ee2ead4ce23d06351a6cdb4fea588f79e429 (patch)
treed365bc24b192e3929801a4ede5b26e74a6c9e77f /Whisper/MF
parent06643094c166b0e80fb8f5f506f5e9d42a90c2bf (diff)
Workaround for the Microsoft’s bug in their MP3 decoder MFT
Diffstat (limited to 'Whisper/MF')
-rw-r--r--Whisper/MF/MediaFoundation.cpp21
-rw-r--r--Whisper/MF/PcmReader.cpp176
-rw-r--r--Whisper/MF/PcmReader.h3
-rw-r--r--Whisper/MF/mfUtils.h3
4 files changed, 187 insertions, 16 deletions
diff --git a/Whisper/MF/MediaFoundation.cpp b/Whisper/MF/MediaFoundation.cpp
index 4a4f6a2..df6990c 100644
--- a/Whisper/MF/MediaFoundation.cpp
+++ b/Whisper/MF/MediaFoundation.cpp
@@ -7,6 +7,7 @@
#include <mfreadwrite.h>
#include "mfUtils.h"
#include "AudioCapture.h"
+#include <mfapi.h>
namespace Whisper
{
@@ -15,6 +16,7 @@ namespace Whisper
CComPtr<IMFSourceReader> reader;
bool wantStereo;
CComPtr<iMediaFoundation> mediaFoundation;
+ mutable int64_t preciseSamplesCount = 0;
HRESULT COMLIGHTCALL getReader( IMFSourceReader** pp ) const noexcept override final
{
@@ -31,7 +33,14 @@ namespace Whisper
HRESULT COMLIGHTCALL getDuration( int64_t& rdi ) const noexcept override final
{
if( reader )
- return getStreamDuration( reader, rdi );
+ {
+ if( 0 == preciseSamplesCount )
+ return getStreamDuration( reader, rdi );
+ else
+ { rdi = MFllMulDiv( preciseSamplesCount, 10'000'000, SAMPLE_RATE, 0 );
+ return S_OK;
+ }
+ }
return OLE_E_BLANK;
}
public:
@@ -48,8 +57,18 @@ namespace Whisper
logDebug16( L"Created source reader from the file \"%s\"", path );
return S_OK;
}
+ void setPreciseSamplesCount( int64_t count ) const
+ {
+ preciseSamplesCount = count;
+ }
};
+ void setPreciseSamplesCount( const iAudioReader* ar, int64_t count )
+ {
+ const AudioReader* r = static_cast<const AudioReader*>( ar );
+ r->setPreciseSamplesCount( count );
+ }
+
class MediaFoundation : public ComLight::ObjectRoot<iMediaFoundation>
{
MfStartupRaii raii;
diff --git a/Whisper/MF/PcmReader.cpp b/Whisper/MF/PcmReader.cpp
index ab92fc3..16f38c6 100644
--- a/Whisper/MF/PcmReader.cpp
+++ b/Whisper/MF/PcmReader.cpp
@@ -110,13 +110,168 @@ namespace
static const HandlerMono s_mono;
static const HandlerDownmixedStereo s_downmix;
static const HandlerStereo s_stereo;
+
+ __forceinline __m128i load( const GUID& guid )
+ {
+ return _mm_loadu_si128( ( const __m128i* )( &guid ) );
+ }
+
+ // Find audio decoder MFT, query MF_MT_SUBTYPE attribute of the current input media type of that MFT
+ HRESULT getDecoderInputSubtype( IMFSourceReader* reader, __m128i& rdi )
+ {
+ store16( &rdi, _mm_setzero_si128() );
+
+ CComPtr<IMFSourceReaderEx> readerEx;
+ CHECK( reader->QueryInterface( &readerEx ) );
+ constexpr uint32_t stream = MF_SOURCE_READER_FIRST_AUDIO_STREAM;
+ const __m128i decGuid = load( MFT_CATEGORY_AUDIO_DECODER );
+ alignas( 16 ) GUID category;
+ for( DWORD i = 0; true; i++ )
+ {
+ CComPtr<IMFTransform> mft;
+ HRESULT hr = readerEx->GetTransformForStream( stream, i, &category, &mft );
+ if( FAILED( hr ) )
+ return hr;
+ const __m128i cat = _mm_load_si128( ( const __m128i* ) & category );
+ if( !vectorEqual( decGuid, cat ) )
+ continue;
+
+ CComPtr<IMFMediaType> mt;
+ CHECK( mft->GetInputCurrentType( 0, &mt ) );
+ CHECK( mt->GetGUID( MF_MT_SUBTYPE, (GUID*)&rdi ) );
+ return S_OK;
+ }
+ }
+
+ // S_OK when the reader has an MP3 decoder for the first audio stream, S_FALSE otherwise
+ HRESULT isMp3Decoder( IMFSourceReader* reader )
+ {
+ __m128i subtype;
+ CHECK( getDecoderInputSubtype( reader, subtype ) );
+ const bool res = vectorEqual( subtype, load( MFAudioFormat_MP3 ) );
+ return res ? S_OK : S_FALSE;
+ }
+
+ // Workaround for a Microsoft's bug in Media Foundation MP3 decoder: https://github.com/Const-me/Whisper/issues/4
+ // Media Foundation is reporting incorrect media duration = 12.54. Windows Media Player does the same.
+ // Winamp and Media Player Classic are reporting 12:35, VLC reports 12:36.
+ HRESULT getPreciseDuration( IMFSourceReader* reader, size_t& length, bool mono, const iAudioReader* iar )
+ {
+ size_t samples = 0;
+
+ // Decode the complete stream, counting samples
+ while( true )
+ {
+ DWORD dwFlags = 0;
+ CComPtr<IMFSample> sample;
+
+ // Read the next sample
+ HRESULT hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample );
+ if( FAILED( hr ) )
+ {
+ logErrorHr( hr, u8"IMFSourceReader.ReadSample" );
+ return hr;
+ }
+
+ if( dwFlags & MF_SOURCE_READERF_CURRENTMEDIATYPECHANGED )
+ {
+ // logError( u8"Media type changes ain’t supported by the library." );
+ // return E_UNEXPECTED;
+
+ // This happens for some video files at the very start of the reading, with Dolby AC3 audio track.
+ // Instead of failing the transcribe process, verify the important attributes (FP32 samples, sample rate, count of channels) haven’t changed.
+ CHECK( validateCurrentMediaType( reader, mono ? 1 : 2 ) );
+ }
+
+ if( dwFlags & MF_SOURCE_READERF_ENDOFSTREAM )
+ break;
+
+ if( !sample )
+ {
+ // printf( "No sample\n" );
+ continue;
+ }
+
+ // Get a pointer to the audio data in the sample.
+ CComPtr<IMFMediaBuffer> buffer;
+ hr = sample->ConvertToContiguousBuffer( &buffer );
+ if( FAILED( hr ) )
+ return hr;
+
+ const float* pAudioData = nullptr;
+ DWORD cbBuffer;
+ hr = buffer->Lock( (BYTE**)&pAudioData, nullptr, &cbBuffer );
+ if( FAILED( hr ) )
+ return hr;
+
+ assert( 0 == ( cbBuffer % sizeof( float ) ) );
+ const size_t countFloats = cbBuffer / sizeof( float );
+ if( mono )
+ samples += countFloats;
+ else
+ {
+ assert( 0 == countFloats % 2 );
+ samples += countFloats / 2;
+ }
+
+ // Unlock the buffer
+ hr = buffer->Unlock();
+ if( FAILED( hr ) )
+ return hr;
+ }
+
+ // Rewind the stream to beginning
+ PROPVARIANT pv;
+ PropVariantInit( &pv );
+ pv.vt = VT_I8;
+ pv.hVal.QuadPart = 0;
+ CHECK( reader->SetCurrentPosition( GUID_NULL, pv ) );
+
+ // Make the output value
+ length = samples / FFT_STEP;
+
+ // Store the actual samples count in the reader
+ // This way the iAudioReader.getDuration() API returns correct value to the user
+ setPreciseSamplesCount( iar, samples );
+
+ return S_OK;
+ }
+
+ HRESULT getDuration( IMFSourceReader* reader, size_t& length, bool mono, const iAudioReader* iar )
+ {
+ HRESULT hr = isMp3Decoder( reader );
+ if( SUCCEEDED( hr ) )
+ {
+ if( S_OK == hr )
+ {
+ return getPreciseDuration( reader, length, mono, iar );
+ }
+ }
+ else
+ logWarningHr( hr, u8"isMp3Decoder" );
+
+ // Find out the length
+ int64_t durationTicks;
+ CHECK( getStreamDuration( reader, durationTicks ) );
+
+ // Convert length to chunks
+ // Seconds = Ticks / 10^7
+ // Samples = Seconds * SAMPLE_RATE = Ticks * SAMPLE_RATE / 10^7
+ // Chunks = Samples / FFT_STEP = Ticks * SAMPLE_RATE / ( FFT_STEP * 10^7 ), and we want that integer rounded down
+ constexpr __int64 mul = SAMPLE_RATE;
+ constexpr __int64 div = (__int64)FFT_STEP * 10'000'000;
+ length = (size_t)MFllMulDiv( durationTicks, mul, div, 0 );
+ return S_OK;
+ }
}
-PcmReader::PcmReader( IMFSourceReader* reader, bool stereo )
+PcmReader::PcmReader( const iAudioReader* iar )
{
- if( nullptr == reader )
+ if( nullptr == iar )
throw E_POINTER;
- this->reader = reader;
+
+ check( iar->getReader( &reader ) );
+ const bool stereo = iar->requestedStereo() == S_OK;
// Set up media type, and figure out sample handler
check( reader->SetStreamSelection( MF_SOURCE_READER_ALL_STREAMS, FALSE ) );
@@ -142,17 +297,10 @@ PcmReader::PcmReader( IMFSourceReader* reader, bool stereo )
check( createMediaType( !sourceMono, &mt ) );
check( reader->SetCurrentMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, nullptr, mt ) );
- // Find out the length
- int64_t durationTicks;
- check( getStreamDuration( reader, durationTicks ) );
-
- // Convert length to chunks
- // Seconds = Ticks / 10^7
- // Samples = Seconds * SAMPLE_RATE = Ticks * SAMPLE_RATE / 10^7
- // Chunks = Samples / FFT_STEP = Ticks * SAMPLE_RATE / ( FFT_STEP * 10^7 ), and we want that integer rounded down
- constexpr __int64 mul = SAMPLE_RATE;
- constexpr __int64 div = (__int64)FFT_STEP * 10'000'000;
- m_length = (size_t)MFllMulDiv( durationTicks, mul, div, 0 );
+ // Find out the length.
+ // Sadly, broken Microsoft's MP3 decoder MFT made this much harder than necessary:
+ // https://github.com/Const-me/Whisper/issues/4
+ check( getDuration( reader, m_length, sourceMono, iar ) );
}
HRESULT PcmReader::readNextSample()
diff --git a/Whisper/MF/PcmReader.h b/Whisper/MF/PcmReader.h
index 9e3757e..aebfc1f 100644
--- a/Whisper/MF/PcmReader.h
+++ b/Whisper/MF/PcmReader.h
@@ -3,6 +3,7 @@
#include <mfidl.h>
#include <mfreadwrite.h>
#include "AudioBuffer.h"
+#include "../API/iMediaFoundation.cl.h"
namespace Whisper
{
@@ -44,7 +45,7 @@ namespace Whisper
public:
- PcmReader( IMFSourceReader* source, bool stereo );
+ PcmReader( const iAudioReader* reader );
// Count of chunks in the MEL spectrogram.
// The PCM audio is generally slightly longer than that, due to the incomplete last chunk.
diff --git a/Whisper/MF/mfUtils.h b/Whisper/MF/mfUtils.h
index c889a92..f67114d 100644
--- a/Whisper/MF/mfUtils.h
+++ b/Whisper/MF/mfUtils.h
@@ -12,4 +12,7 @@ namespace Whisper
HRESULT getStreamDuration( IMFSourceReader* reader, int64_t& duration );
HRESULT validateCurrentMediaType( IMFSourceReader* reader, uint32_t expectedChannels );
+
+ struct iAudioReader;
+ void setPreciseSamplesCount( const iAudioReader* ar, int64_t count );
} \ No newline at end of file