summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Examples/WhisperDesktop/TranscribeDlg.cpp6
-rw-r--r--Readme.md2
-rw-r--r--Whisper/MF/MediaFoundation.cpp21
-rw-r--r--Whisper/MF/PcmReader.cpp176
-rw-r--r--Whisper/MF/PcmReader.h3
-rw-r--r--Whisper/MF/mfUtils.h3
-rw-r--r--Whisper/Whisper/ContextImpl.misc.cpp8
-rw-r--r--Whisper/Whisper/MelStreamer.cpp8
-rw-r--r--Whisper/Whisper/MelStreamer.h9
9 files changed, 204 insertions, 32 deletions
diff --git a/Examples/WhisperDesktop/TranscribeDlg.cpp b/Examples/WhisperDesktop/TranscribeDlg.cpp
index cd23c93..b99e98a 100644
--- a/Examples/WhisperDesktop/TranscribeDlg.cpp
+++ b/Examples/WhisperDesktop/TranscribeDlg.cpp
@@ -298,7 +298,6 @@ HRESULT TranscribeDlg::transcribe()
CComPtr<iAudioReader> reader;
CHECK_EX( appState.mediaFoundation->openAudioFile( transcribeArgs.pathMedia, false, &reader ) );
- CHECK_EX( reader->getDuration( transcribeArgs.mediaDuration ) );
const eOutputFormat format = transcribeArgs.format;
CAtlFile outputFile;
@@ -324,6 +323,11 @@ HRESULT TranscribeDlg::transcribe()
// Run the transcribe
CHECK_EX( context->runStreamed( fullParams, progressSink, reader ) );
+ // Once finished, query duration of the audio.
+ // The duration reported before the processing is sometimes wrong, e.g. off by 20 seconds for the file in this issue:
+ // https://github.com/Const-me/Whisper/issues/4
+ CHECK_EX( reader->getDuration( transcribeArgs.mediaDuration ) );
+
context->timingsPrint();
if( format == eOutputFormat::None )
diff --git a/Readme.md b/Readme.md
index cf8d25d..eea33cb 100644
--- a/Readme.md
+++ b/Readme.md
@@ -47,7 +47,7 @@ The implementation is based on the [2009 article](https://www.researchgate.net/p
* Pre-built binaries available
The only supported platform is 64-bit Windows.<br/>
-Should work on Windows 8.0 or newer, but I have only tested on Windows 10.<br/>
+Should work on Windows 8.1 or newer, but I have only tested on Windows 10.<br/>
The library requires a Direct3D 11.0 capable GPU, which in 2023 simply means “any hardware GPU”.
The most recent GPU without D3D 11.0 support was Intel [Sandy Bridge](https://en.wikipedia.org/wiki/Sandy_Bridge) from 2011.
diff --git a/Whisper/MF/MediaFoundation.cpp b/Whisper/MF/MediaFoundation.cpp
index 4a4f6a2..df6990c 100644
--- a/Whisper/MF/MediaFoundation.cpp
+++ b/Whisper/MF/MediaFoundation.cpp
@@ -7,6 +7,7 @@
#include <mfreadwrite.h>
#include "mfUtils.h"
#include "AudioCapture.h"
+#include <mfapi.h>
namespace Whisper
{
@@ -15,6 +16,7 @@ namespace Whisper
CComPtr<IMFSourceReader> reader;
bool wantStereo;
CComPtr<iMediaFoundation> mediaFoundation;
+ mutable int64_t preciseSamplesCount = 0;
HRESULT COMLIGHTCALL getReader( IMFSourceReader** pp ) const noexcept override final
{
@@ -31,7 +33,14 @@ namespace Whisper
HRESULT COMLIGHTCALL getDuration( int64_t& rdi ) const noexcept override final
{
if( reader )
- return getStreamDuration( reader, rdi );
+ {
+ if( 0 == preciseSamplesCount ) // No exact sample count has been stored (the field is initialized to 0): use the duration reported by getStreamDuration
+ return getStreamDuration( reader, rdi );
+ else
+ { rdi = MFllMulDiv( preciseSamplesCount, 10'000'000, SAMPLE_RATE, 0 ); // samples * 10^7 / SAMPLE_RATE: the stored sample count converted to 10^-7 second ticks
+ return S_OK;
+ }
+ }
return OLE_E_BLANK;
}
public:
@@ -48,8 +57,18 @@ namespace Whisper
logDebug16( L"Created source reader from the file \"%s\"", path );
return S_OK;
}
+ void setPreciseSamplesCount( int64_t count ) const // Can be const because preciseSamplesCount is declared `mutable`
+ {
+ preciseSamplesCount = count; // Once non-zero, getDuration() derives the duration from this count instead of calling getStreamDuration
+ }
};
+ void setPreciseSamplesCount( const iAudioReader* ar, int64_t count ) // Free-function wrapper, declared in mfUtils.h, for callers that only hold the iAudioReader interface
+ {
+ const AudioReader* r = static_cast<const AudioReader*>( ar ); // NOTE(review): unchecked downcast; assumes every iAudioReader passed here is an AudioReader, and that ar is not null — confirm against callers
+ r->setPreciseSamplesCount( count );
+ }
+
class MediaFoundation : public ComLight::ObjectRoot<iMediaFoundation>
{
MfStartupRaii raii;
diff --git a/Whisper/MF/PcmReader.cpp b/Whisper/MF/PcmReader.cpp
index ab92fc3..16f38c6 100644
--- a/Whisper/MF/PcmReader.cpp
+++ b/Whisper/MF/PcmReader.cpp
@@ -110,13 +110,168 @@ namespace
static const HandlerMono s_mono;
static const HandlerDownmixedStereo s_downmix;
static const HandlerStereo s_stereo;
+
+ __forceinline __m128i load( const GUID& guid ) // Loads a GUID into an SSE register so it can be compared as a single vector
+ {
+ return _mm_loadu_si128( ( const __m128i* )( &guid ) ); // Unaligned load: the GUID does not need 16-byte alignment
+ }
+
+ // Find the audio decoder MFT of the first audio stream, and query the MF_MT_SUBTYPE attribute of the current input media type of that MFT. rdi is zeroed first, then receives the subtype GUID on success. When the transforms run out before a decoder is found, the failed HRESULT of GetTransformForStream is returned.
+ HRESULT getDecoderInputSubtype( IMFSourceReader* reader, __m128i& rdi )
+ {
+ store16( &rdi, _mm_setzero_si128() );
+
+ CComPtr<IMFSourceReaderEx> readerEx;
+ CHECK( reader->QueryInterface( &readerEx ) );
+ constexpr uint32_t stream = MF_SOURCE_READER_FIRST_AUDIO_STREAM;
+ const __m128i decGuid = load( MFT_CATEGORY_AUDIO_DECODER );
+ alignas( 16 ) GUID category;
+ for( DWORD i = 0; true; i++ ) // No explicit upper bound: the loop only ends through one of the return statements below
+ {
+ CComPtr<IMFTransform> mft;
+ HRESULT hr = readerEx->GetTransformForStream( stream, i, &category, &mft );
+ if( FAILED( hr ) )
+ return hr;
+ const __m128i cat = _mm_load_si128( ( const __m128i* ) & category ); // Aligned load is fine here: `category` is declared alignas( 16 )
+ if( !vectorEqual( decGuid, cat ) )
+ continue; // Not in the audio decoder category, try the next transform index
+
+ CComPtr<IMFMediaType> mt;
+ CHECK( mft->GetInputCurrentType( 0, &mt ) );
+ CHECK( mt->GetGUID( MF_MT_SUBTYPE, (GUID*)&rdi ) ); // Writes the 16-byte GUID straight into the caller's vector
+ return S_OK;
+ }
+ }
+
+ // S_OK when the reader has an MP3 decoder for the first audio stream, S_FALSE when the decoder's input subtype is anything else. A failure of getDecoderInputSubtype goes through CHECK (presumably returned to the caller — confirm the macro). NOTE(review): that includes the "no decoder transform found" case, which getDuration() then logs as a warning; consider mapping it to S_FALSE.
+ HRESULT isMp3Decoder( IMFSourceReader* reader )
+ {
+ __m128i subtype;
+ CHECK( getDecoderInputSubtype( reader, subtype ) );
+ const bool res = vectorEqual( subtype, load( MFAudioFormat_MP3 ) );
+ return res ? S_OK : S_FALSE;
+ }
+
+ // Workaround for a bug in Microsoft's Media Foundation MP3 decoder: https://github.com/Const-me/Whisper/issues/4
+ // For the file in that issue, Media Foundation reports an incorrect media duration of 12:54. Windows Media Player does the same.
+ // Winamp and Media Player Classic report 12:35, VLC reports 12:36.
+ //
+ // Decodes the complete first audio stream only to count the samples, then rewinds the reader to the beginning.
+ //   reader  the source reader; sample sizes are interpreted as arrays of floats
+ //   length  receives samples / FFT_STEP, rounded down
+ //   mono    true when every float is one sample, false when two floats (2 channels) make one sample
+ //   iar     receives the exact samples count through setPreciseSamplesCount()
+ // Returns S_OK, or the first failed HRESULT. On failure the reader is not rewound, and neither output is written.
+ HRESULT getPreciseDuration( IMFSourceReader* reader, size_t& length, bool mono, const iAudioReader* iar )
+ {
+     size_t samples = 0;
+
+     // Decode the complete stream, counting samples
+     while( true )
+     {
+         DWORD dwFlags = 0;
+         CComPtr<IMFSample> sample;
+
+         // Read the next sample
+         HRESULT hr = reader->ReadSample( (DWORD)MF_SOURCE_READER_FIRST_AUDIO_STREAM, 0, nullptr, &dwFlags, nullptr, &sample );
+         if( FAILED( hr ) )
+         {
+             logErrorHr( hr, u8"IMFSourceReader.ReadSample" );
+             return hr;
+         }
+
+         if( dwFlags & MF_SOURCE_READERF_CURRENTMEDIATYPECHANGED )
+         {
+             // This happens for some video files at the very start of the reading, with Dolby AC3 audio track.
+             // Instead of failing the transcribe process, verify the important attributes
+             // (FP32 samples, sample rate, count of channels) have not changed.
+             CHECK( validateCurrentMediaType( reader, mono ? 1 : 2 ) );
+         }
+
+         if( dwFlags & MF_SOURCE_READERF_ENDOFSTREAM )
+             break;
+
+         // ReadSample may succeed without delivering a sample; nothing to count then
+         if( !sample )
+             continue;
+
+         // Only the size of the decoded audio matters here, the samples themselves are never read.
+         // GetTotalLength sums the valid bytes of all buffers in the sample, so there is no need
+         // to merge them into a contiguous buffer and lock it.
+         DWORD cbSample = 0;
+         hr = sample->GetTotalLength( &cbSample );
+         if( FAILED( hr ) )
+             return hr;
+
+         assert( 0 == ( cbSample % sizeof( float ) ) );
+         const size_t countFloats = cbSample / sizeof( float );
+         if( mono )
+             samples += countFloats;
+         else
+         {
+             // Two interleaved channels: a pair of floats is a single sample
+             assert( 0 == countFloats % 2 );
+             samples += countFloats / 2;
+         }
+     }
+
+     // Rewind the stream to beginning
+     PROPVARIANT pv;
+     PropVariantInit( &pv );
+     pv.vt = VT_I8;
+     pv.hVal.QuadPart = 0;
+     CHECK( reader->SetCurrentPosition( GUID_NULL, pv ) );
+
+     // Make the output value
+     length = samples / FFT_STEP;
+
+     // Store the actual samples count in the reader
+     // This way the iAudioReader.getDuration() API returns correct value to the user
+     setPreciseSamplesCount( iar, static_cast<int64_t>( samples ) );
+
+     return S_OK;
+ }
+
+ // Computes the length of the first audio stream, expressed in FFT_STEP-sized chunks.
+ // MP3 streams are measured by decoding them completely, because the duration reported for them is unreliable.
+ // Everything else, including the case when the decoder could not be identified, uses the reported stream duration.
+ HRESULT getDuration( IMFSourceReader* reader, size_t& length, bool mono, const iAudioReader* iar )
+ {
+     const HRESULT mp3Status = isMp3Decoder( reader );
+     if( S_OK == mp3Status )
+         return getPreciseDuration( reader, length, mono, iar );
+
+     // Being unable to identify the decoder is not fatal, only worth a warning
+     if( FAILED( mp3Status ) )
+         logWarningHr( mp3Status, u8"isMp3Decoder" );
+
+     // Ask Media Foundation for the duration of the stream
+     int64_t ticks;
+     CHECK( getStreamDuration( reader, ticks ) );
+
+     // Convert the ticks into chunks, rounding down:
+     //   seconds = ticks / 10^7
+     //   samples = seconds * SAMPLE_RATE
+     //   chunks  = samples / FFT_STEP = ticks * SAMPLE_RATE / ( FFT_STEP * 10^7 )
+     constexpr __int64 numerator = SAMPLE_RATE;
+     constexpr __int64 denominator = (__int64)FFT_STEP * 10'000'000;
+     length = (size_t)MFllMulDiv( ticks, numerator, denominator, 0 );
+     return S_OK;
+ }
}
-PcmReader::PcmReader( IMFSourceReader* reader, bool stereo )
+PcmReader::PcmReader( const iAudioReader* iar )
{
- if( nullptr == reader )
+ if( nullptr == iar )
throw E_POINTER;
- this->reader = reader;
+
+ check( iar->getReader( &reader ) );
+ const bool stereo = iar->requestedStereo() == S_OK;
// Set up media type, and figure out sample handler
check( reader->SetStreamSelection( MF_SOURCE_READER_ALL_STREAMS, FALSE ) );
@@ -142,17 +297,10 @@ PcmReader::PcmReader( IMFSourceReader* reader, bool stereo )
check( createMediaType( !sourceMono, &mt ) );
check( reader->SetCurrentMediaType( MF_SOURCE_READER_FIRST_AUDIO_STREAM, nullptr, mt ) );
- // Find out the length
- int64_t durationTicks;
- check( getStreamDuration( reader, durationTicks ) );
-
- // Convert length to chunks
- // Seconds = Ticks / 10^7
- // Samples = Seconds * SAMPLE_RATE = Ticks * SAMPLE_RATE / 10^7
- // Chunks = Samples / FFT_STEP = Ticks * SAMPLE_RATE / ( FFT_STEP * 10^7 ), and we want that integer rounded down
- constexpr __int64 mul = SAMPLE_RATE;
- constexpr __int64 div = (__int64)FFT_STEP * 10'000'000;
- m_length = (size_t)MFllMulDiv( durationTicks, mul, div, 0 );
+ // Find out the length.
+ // Sadly, Microsoft's broken MP3 decoder MFT made this much harder than necessary:
+ // https://github.com/Const-me/Whisper/issues/4
+ check( getDuration( reader, m_length, sourceMono, iar ) );
}
HRESULT PcmReader::readNextSample()
diff --git a/Whisper/MF/PcmReader.h b/Whisper/MF/PcmReader.h
index 9e3757e..aebfc1f 100644
--- a/Whisper/MF/PcmReader.h
+++ b/Whisper/MF/PcmReader.h
@@ -3,6 +3,7 @@
#include <mfidl.h>
#include <mfreadwrite.h>
#include "AudioBuffer.h"
+#include "../API/iMediaFoundation.cl.h"
namespace Whisper
{
@@ -44,7 +45,7 @@ namespace Whisper
public:
- PcmReader( IMFSourceReader* source, bool stereo );
+ PcmReader( const iAudioReader* reader );
// Count of chunks in the MEL spectrogram.
// The PCM audio is generally slightly longer than that, due to the incomplete last chunk.
diff --git a/Whisper/MF/mfUtils.h b/Whisper/MF/mfUtils.h
index c889a92..f67114d 100644
--- a/Whisper/MF/mfUtils.h
+++ b/Whisper/MF/mfUtils.h
@@ -12,4 +12,7 @@ namespace Whisper
HRESULT getStreamDuration( IMFSourceReader* reader, int64_t& duration );
HRESULT validateCurrentMediaType( IMFSourceReader* reader, uint32_t expectedChannels );
+
+ struct iAudioReader;
+ void setPreciseSamplesCount( const iAudioReader* ar, int64_t count );
} \ No newline at end of file
diff --git a/Whisper/Whisper/ContextImpl.misc.cpp b/Whisper/Whisper/ContextImpl.misc.cpp
index 9ce5000..9a156fb 100644
--- a/Whisper/Whisper/ContextImpl.misc.cpp
+++ b/Whisper/Whisper/ContextImpl.misc.cpp
@@ -384,20 +384,16 @@ HRESULT COMLIGHTCALL ContextImpl::runStreamed( const sFullParams& params, const
mediaTimeOffset = 0;
auto profCompleteCpu = profiler.cpuBlock( eCpuBlock::RunComplete );
- CComPtr<IMFSourceReader> mfReader;
- CHECK( reader->getReader( &mfReader ) );
- const bool stereo = reader->requestedStereo() == S_OK;
-
try
{
if( params.cpuThreads > 1 )
{
- MelStreamerThread mel{ model.filters, profiler, mfReader, params.cpuThreads, stereo };
+ MelStreamerThread mel{ model.filters, profiler, reader, params.cpuThreads };
return runFullImpl( params, progress, mel );
}
else
{
- MelStreamerSimple mel{ model.filters, profiler, mfReader, stereo };
+ MelStreamerSimple mel{ model.filters, profiler, reader };
return runFullImpl( params, progress, mel );
}
}
diff --git a/Whisper/Whisper/MelStreamer.cpp b/Whisper/Whisper/MelStreamer.cpp
index 54268f2..a7c2472 100644
--- a/Whisper/Whisper/MelStreamer.cpp
+++ b/Whisper/Whisper/MelStreamer.cpp
@@ -3,8 +3,8 @@
#include "../Utils/parallelFor.h"
using namespace Whisper;
-MelStreamer::MelStreamer( const Filters& filters, ProfileCollection& prof, IMFSourceReader* source, bool stereo ) :
- reader( source, stereo ),
+MelStreamer::MelStreamer( const Filters& filters, ProfileCollection& prof, const iAudioReader* iar ) : // Takes the iAudioReader interface instead of the ( IMFSourceReader, stereo flag ) pair
+ reader( iar ), // The `reader` member is built directly from the iAudioReader; presumably a PcmReader, whose constructor accepts this type — confirm in MelStreamer.h
melContext( filters ),
profiler( prof )
{ }
@@ -231,8 +231,8 @@ HRESULT MelStreamerSimple::makeBuffer( size_t off, size_t len, const float** buf
return S_OK;
}
-MelStreamerThread::MelStreamerThread( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, int countThreads, bool stereo ) :
- MelStreamer( filters, profiler, source, stereo ),
+MelStreamerThread::MelStreamerThread( const Filters& filters, ProfileCollection& profiler, const iAudioReader* iar, int countThreads ) :
+ MelStreamer( filters, profiler, iar ),
workerThreads( countThreads )
{
if( workerThreads > 1 )
diff --git a/Whisper/Whisper/MelStreamer.h b/Whisper/Whisper/MelStreamer.h
index 152c1b6..2a891ca 100644
--- a/Whisper/Whisper/MelStreamer.h
+++ b/Whisper/Whisper/MelStreamer.h
@@ -6,6 +6,7 @@
#include <atlbase.h>
#include "../Utils/parallelFor.h"
#include "../Utils/ProfileCollection.h"
+#include "../API/iMediaFoundation.cl.h"
namespace Whisper
{
@@ -44,7 +45,7 @@ namespace Whisper
size_t getLength() const noexcept override final { return reader.getLength(); }
public:
- MelStreamer( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, bool stereo );
+ MelStreamer( const Filters& filters, ProfileCollection& profiler, const iAudioReader* reader );
};
// Single-threaded MEL streamer: runs these FFTs on-demand, from within makeBuffer() method
@@ -53,8 +54,8 @@ namespace Whisper
HRESULT makeBuffer( size_t offset, size_t length, const float** buffer, size_t& stride ) noexcept override final;
public:
- MelStreamerSimple( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, bool stereo ) :
- MelStreamer( filters, profiler, source, stereo ) { }
+ MelStreamerSimple( const Filters& filters, ProfileCollection& profiler, const iAudioReader* reader ) :
+ MelStreamer( filters, profiler, reader ) { } // Only forwards the three arguments to the MelStreamer constructor; the body is empty
};
// Multi threaded MEL streamers: runs FFT on a background thread ahead of time
@@ -92,7 +93,7 @@ namespace Whisper
public:
- MelStreamerThread( const Filters& filters, ProfileCollection& profiler, IMFSourceReader* source, int countThreads, bool stereo );
+ MelStreamerThread( const Filters& filters, ProfileCollection& profiler, const iAudioReader* reader, int countThreads );
~MelStreamerThread();
};