summaryrefslogtreecommitdiffstats
path: root/Whisper
diff options
context:
space:
mode:
author Konstantin <const@const.me> 2023-01-28 18:09:15 +0100
committer Konstantin <const@const.me> 2023-01-28 18:09:15 +0100
commit 9253de193022e78cc4f91f4f1f7e14ba099e6388 (patch)
tree 41cf07bb396f892b362d4fffff20ab6711eec9c8 /Whisper
parent cfd20a0f796ab6cc046b080bb7af8967cb7c361b (diff)
Diarize feature, initial version
Diffstat (limited to 'Whisper')
-rw-r--r-- Whisper/Whisper.vcxproj 1
-rw-r--r-- Whisper/Whisper.vcxproj.filters 1
-rw-r--r-- Whisper/Whisper/ContextImpl.cpp 16
-rw-r--r-- Whisper/Whisper/ContextImpl.diarize.cpp 97
-rw-r--r-- Whisper/Whisper/ContextImpl.h 3
-rw-r--r-- Whisper/Whisper/ContextImpl.misc.cpp 5
-rw-r--r-- Whisper/Whisper/MelStreamer.cpp 41
-rw-r--r-- Whisper/Whisper/MelStreamer.h 2
-rw-r--r-- Whisper/Whisper/Spectrogram.cpp 5
-rw-r--r-- Whisper/Whisper/Spectrogram.h 2
-rw-r--r-- Whisper/Whisper/iSpectrogram.h 8
11 files changed, 176 insertions, 5 deletions
diff --git a/Whisper/Whisper.vcxproj b/Whisper/Whisper.vcxproj
index 237db29..0a52e46 100644
--- a/Whisper/Whisper.vcxproj
+++ b/Whisper/Whisper.vcxproj
@@ -139,6 +139,7 @@
<ClCompile Include="Utils\Logger.cpp" />
<ClCompile Include="MF\AudioCapture.cpp" />
<ClCompile Include="Utils\miscUtils.cpp" />
+ <ClCompile Include="Whisper\ContextImpl.diarize.cpp" />
<ClCompile Include="Whisper\voiceActivityDetection.cpp" />
<ClCompile Include="Whisper\ContextImpl.capture.cpp" />
<ClCompile Include="Whisper\MelStreamer.cpp" />
diff --git a/Whisper/Whisper.vcxproj.filters b/Whisper/Whisper.vcxproj.filters
index 0ffd30c..37a9366 100644
--- a/Whisper/Whisper.vcxproj.filters
+++ b/Whisper/Whisper.vcxproj.filters
@@ -84,6 +84,7 @@
<ClCompile Include="CPU\mulMatImpl.panel.cpp" />
<ClCompile Include="ML\Reshaper.cpp" />
<ClCompile Include="Utils\DelayExecution.cpp" />
+ <ClCompile Include="Whisper\ContextImpl.diarize.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="source\ggml.h" />
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp
index 59d644b..eccda35 100644
--- a/Whisper/Whisper/ContextImpl.cpp
+++ b/Whisper/Whisper/ContextImpl.cpp
@@ -189,6 +189,21 @@ static std::string to_timestamp( int64_t t, bool comma = false )
return std::string( buf );
}
+// Scope guard which publishes a spectrogram in ContextImpl::currentSpectrogram for the lifetime of the guard.
+// detectSpeaker() reads that field to reach the audio data, and fails with OLE_E_BLANK while it is nullptr.
+class ContextImpl::CurrentSpectrogramRaii
+{
+	ContextImpl* const ctx;
+	// The pointer found in the context when the guard was created; the destructor puts it back.
+	// Restoring instead of writing nullptr keeps an outer guard's pointer intact when the guards are nested.
+	iSpectrogram* const previous;
+public:
+	// Remember the current value of c->currentSpectrogram, then replace it with the address of `mel`.
+	// Both `c` and `mel` need to outlive the guard.
+	CurrentSpectrogramRaii( ContextImpl* c, iSpectrogram& mel ) :
+		ctx( c ), previous( c->currentSpectrogram )
+	{
+		c->currentSpectrogram = &mel;
+	}
+	// Restore the remembered pointer; the field is initialized with nullptr, so that's the value for a non-nested guard
+	~CurrentSpectrogramRaii()
+	{
+		ctx->currentSpectrogram = previous;
+	}
+	// A copy would undo the same change twice, possibly in the wrong order
+	CurrentSpectrogramRaii( const CurrentSpectrogramRaii& ) = delete;
+	CurrentSpectrogramRaii& operator=( const CurrentSpectrogramRaii& ) = delete;
+};
+
HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const sProgressSink& progress, iSpectrogram& mel )
{
// Ported from whisper_full() function
@@ -199,6 +214,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
return E_NOTIMPL;
}
+ CurrentSpectrogramRaii _cs( this, mel );
const int seek_start = params.offset_ms / 10;
const int seek_end = seek_start + ( params.duration_ms == 0 ? (int)mel.getLength() : params.duration_ms / 10 );
diff --git a/Whisper/Whisper/ContextImpl.diarize.cpp b/Whisper/Whisper/ContextImpl.diarize.cpp
new file mode 100644
index 0000000..90acb04
--- /dev/null
+++ b/Whisper/Whisper/ContextImpl.diarize.cpp
@@ -0,0 +1,97 @@
+#include "stdafx.h"
+#include "ContextImpl.h"
+using namespace Whisper;
+
+namespace
+{
+	// Convert an absolute timestamp, expressed in 100-nanosecond ticks, into an index of a 10 ms Whisper chunk
+	// counted from mediaTimeOffset, i.e. from the start of the iSpectrogram buffer.
+	// The integer division truncates towards zero; that is rounding down as long as the relative time is not negative
+	inline int64_t chunkOffset( int64_t time, int64_t mediaTimeOffset )
+	{
+		constexpr int64_t ticksPerSecond = 10'000'000;
+		constexpr int64_t chunksPerSecond = 100;
+		const int64_t relativeTicks = time - mediaTimeOffset;
+		return ( relativeTicks * chunksPerSecond ) / ticksPerSecond;
+	}
+
+	// Accumulate fabsf( pcm ) separately for each channel over the whole buffer.
+	// The left / right totals are returned in the lower 2 lanes of the SSE vector; an empty buffer gives zeros
+	inline __m128 __vectorcall computeChannelsEnergy( const std::vector<StereoSample>& sourceVector )
+	{
+		const size_t count = sourceVector.size();
+		// Every unaligned 16-byte load consumes 2 stereo samples
+		const size_t pairs = count / 2;
+		const float* source = (const float*)sourceVector.data();
+
+		// -0.0f has only the sign bit set, ANDNOT with that value clears the sign = absolute value
+		const __m128 signBits = _mm_set1_ps( -0.0f );
+		__m128 sum = _mm_setzero_ps();
+		for( size_t i = 0; i < pairs; i++, source += 4 )
+		{
+			const __m128 twoSamples = _mm_loadu_ps( source );
+			sum = _mm_add_ps( sum, _mm_andnot_ps( signBits, twoSamples ) );
+		}
+
+		if( 0 != ( count & 1 ) )
+		{
+			// Odd count: load the last sample as a single 64-bit value, the upper 2 lanes are zero
+			const __m128 lastSample = _mm_castpd_ps( _mm_load_sd( (const double*)source ) );
+			sum = _mm_add_ps( sum, _mm_andnot_ps( signBits, lastSample ) );
+		}
+
+		// Lanes x and y are the totals of the even samples, z and w of the odd ones: return sum.xy + sum.zw
+		return _mm_add_ps( sum, _mm_movehl_ps( sum, sum ) );
+	}
+
+	// Compare the two channel energies in the lower 2 lanes of `ev`, and convert the outcome into eSpeakerChannel.
+	// Bit 0 of the result is set when lane 0 is more than 1.1 times lane 1, bit 1 when lane 1 is more than 1.1 times lane 0.
+	// NOTE(review): the cast assumes the numeric values of eSpeakerChannel match these 2 bits, confirm against the enum
+	inline eSpeakerChannel produceResult( const __m128 ev )
+	{
+		// The original whisper.cpp code was equivalent to the following:
+		// if( energy0 > 1.1 * energy1 ) speaker = "(speaker 0)"; else if( energy1 > 1.1 * energy0 ) speaker = "(speaker 1)"; else speaker = "(speaker ?)";
+
+		// Swap the lower 2 lanes, so every channel faces the opposite one
+		const __m128 swapped = _mm_shuffle_ps( ev, ev, _MM_SHUFFLE( 3, 2, 0, 1 ) );
+		// Scale the opposite channel by the magic number
+		const __m128 threshold = _mm_mul_ps( swapped, _mm_set1_ps( 1.1f ) );
+		// Test ev > threshold in all lanes, and keep the bits of the lower 2 lanes
+		const __m128 isLouder = _mm_cmpgt_ps( ev, threshold );
+		const uint32_t mask = (uint32_t)_mm_movemask_ps( isLouder ) & 0b11;
+
+		// Both bits set would mean ( ( left > right * 1.1 ) && ( right > left * 1.1 ) )
+		assert( mask != 0b11 );
+
+		return (eSpeakerChannel)mask;
+	}
+}
+
+// Find out which stereo channel carries more energy within the specified time interval.
+// Fails with OLE_E_BLANK unless a spectrogram is currently published in currentSpectrogram.
+// Sets the result to Unsure for an empty or inverted interval, and to NoStereoData when the spectrogram reports OLE_E_BLANK.
+// Failures of copyStereoPcm() are passed through CHECK(), `result` is not assigned in these cases.
+HRESULT COMLIGHTCALL ContextImpl::detectSpeaker( const sTimeInterval& time, eSpeakerChannel& result ) const noexcept
+{
+	// The PCM data is only reachable through the current spectrogram
+	if( nullptr == currentSpectrogram )
+	{
+		logError( u8"Because the audio is streamed, iContext.detectSpeaker() method only works when called from the callbacks" );
+		return OLE_E_BLANK;
+	}
+
+	// Translate both ends of the interval into 10 ms chunks relative to the start of the buffer
+	const int64_t firstChunk = chunkOffset( (int64_t)time.begin.ticks, mediaTimeOffset );
+	const int64_t endChunk = chunkOffset( (int64_t)time.end.ticks, mediaTimeOffset );
+	const int64_t chunksCount = endChunk - firstChunk;
+	if( chunksCount <= 0 )
+	{
+		// Nothing to measure
+		result = eSpeakerChannel::Unsure;
+		return S_OK;
+	}
+
+	// Copy the slice of the stereo PCM data into the reusable buffer
+	// NOTE(review): a negative firstChunk wraps into a huge size_t in the cast below; confirm callers never pass a time before mediaTimeOffset
+	HRESULT hr = currentSpectrogram->copyStereoPcm( (size_t)firstChunk, (size_t)chunksCount, diarizeBuffer );
+	if( OLE_E_BLANK == hr )
+	{
+		result = eSpeakerChannel::NoStereoData;
+		return S_OK;
+	}
+	CHECK( hr );
+
+	// Measure both channels, and compare the numbers
+	result = produceResult( computeChannelsEnergy( diarizeBuffer ) );
+	return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/Whisper/ContextImpl.h b/Whisper/Whisper/ContextImpl.h
index 448efd5..c404b8f 100644
--- a/Whisper/Whisper/ContextImpl.h
+++ b/Whisper/Whisper/ContextImpl.h
@@ -15,6 +15,8 @@ namespace Whisper
DirectCompute::WhisperContext context;
Spectrogram spectrogram;
int64_t mediaTimeOffset = 0;
+ iSpectrogram* currentSpectrogram = nullptr;
+ class CurrentSpectrogramRaii;
ProfileCollection profiler;
HRESULT COMLIGHTCALL getModel( iModel** pp ) override final;
@@ -68,6 +70,7 @@ namespace Whisper
int defaultThreadsCount() const;
__m128i getMemoryUse() const;
+ mutable std::vector<StereoSample> diarizeBuffer;
public:
diff --git a/Whisper/Whisper/ContextImpl.misc.cpp b/Whisper/Whisper/ContextImpl.misc.cpp
index 9eb4c04..9a156fb 100644
--- a/Whisper/Whisper/ContextImpl.misc.cpp
+++ b/Whisper/Whisper/ContextImpl.misc.cpp
@@ -401,9 +401,4 @@ HRESULT COMLIGHTCALL ContextImpl::runStreamed( const sFullParams& params, const
{
return hr;
}
-}
-
-HRESULT COMLIGHTCALL ContextImpl::detectSpeaker( const sTimeInterval& time, eSpeakerChannel& result ) const noexcept
-{
- return E_NOTIMPL;
} \ No newline at end of file
diff --git a/Whisper/Whisper/MelStreamer.cpp b/Whisper/Whisper/MelStreamer.cpp
index a7c2472..bb38846 100644
--- a/Whisper/Whisper/MelStreamer.cpp
+++ b/Whisper/Whisper/MelStreamer.cpp
@@ -490,4 +490,45 @@ MelStreamerThread::~MelStreamerThread()
if( res == WAIT_OBJECT_0 )
return;
// TODO: log a warning
+}
+
+// Copy `length` chunks of stereo PCM into `buffer`, starting at chunk `offset` of the stream.
+// Each entry of the queue supplies FFT_STEP stereo samples, and the buffer is resized to length * FFT_STEP samples.
+// When the slice extends past the end of the queue, the missing tail of the buffer is filled with zeros.
+// Returns OLE_E_BLANK when the stereo queue is empty, E_UNEXPECTED when the offset is before the first queued chunk,
+// E_BOUNDS when the offset is past the last queued chunk, E_OUTOFMEMORY when the buffer can't be resized, S_OK otherwise
+HRESULT MelStreamer::copyStereoPcm( size_t offset, size_t length, std::vector<StereoSample>& buffer ) const
+{
+	if( queuePcmStereo.empty() )
+		return OLE_E_BLANK;
+
+	if( offset < streamStartOffset )
+	{
+		logError( u8"MelStreamer doesn't support backwards seek" );
+		return E_UNEXPECTED;
+	}
+
+	// Offset relative to the first chunk on the queue
+	const size_t off = offset - streamStartOffset;
+	if( off >= queuePcmStereo.size() )
+		return E_BOUNDS;
+
+	// Resize the output buffer
+	try
+	{
+		buffer.resize( length * FFT_STEP );
+	}
+	catch( const std::bad_alloc& )
+	{
+		return E_OUTOFMEMORY;
+	}
+	StereoSample* rdi = buffer.data();
+
+	// memcpy and memset count bytes, while the destination pointer is advanced in samples
+	const size_t bytesPerChunk = sizeof( StereoSample ) * FFT_STEP;
+
+	// Copy PCM chunks from the queue
+	const size_t lengthToCopy = std::min( length, queuePcmStereo.size() - off );
+	for( size_t i = 0; i < lengthToCopy; i++, rdi += FFT_STEP )
+	{
+		const float* rsi = queuePcmStereo[ i + off ].stereo.data();
+		memcpy( rdi, rsi, bytesPerChunk );
+	}
+
+	// If needed, write zeros to the tail.
+	// The caller reuses the buffer, so resize() alone doesn't clear that memory: the complete tail must be zeroed here
+	const size_t missingChunks = length - lengthToCopy;
+	if( 0 != missingChunks )
+		memset( rdi, 0, missingChunks * bytesPerChunk );
+	return S_OK;
+} \ No newline at end of file
diff --git a/Whisper/Whisper/MelStreamer.h b/Whisper/Whisper/MelStreamer.h
index 2a891ca..387622c 100644
--- a/Whisper/Whisper/MelStreamer.h
+++ b/Whisper/Whisper/MelStreamer.h
@@ -44,6 +44,8 @@ namespace Whisper
size_t getLength() const noexcept override final { return reader.getLength(); }
+ HRESULT copyStereoPcm( size_t offset, size_t length, std::vector<StereoSample>& buffer ) const override final;
+
public:
MelStreamer( const Filters& filters, ProfileCollection& profiler, const iAudioReader* reader );
};
diff --git a/Whisper/Whisper/Spectrogram.cpp b/Whisper/Whisper/Spectrogram.cpp
index 400f9b1..76130bd 100644
--- a/Whisper/Whisper/Spectrogram.cpp
+++ b/Whisper/Whisper/Spectrogram.cpp
@@ -121,4 +121,9 @@ void Whisper::computeSignalEnergy( std::vector<float>& result, const iAudioBuffe
sum += fabsf( samples[ i + j ] );
result[ i ] = sum / ( 2 * hw + 1 );
}
+}
+
+HRESULT Spectrogram::copyStereoPcm( size_t offset, size_t length, std::vector<StereoSample>& buffer ) const // Stub of the stereo PCM accessor: all 3 arguments are ignored, and the buffer is left untouched
+{
+ return E_NOTIMPL; // Unconditional: this implementation never supplies stereo PCM
+} \ No newline at end of file
diff --git a/Whisper/Whisper/Spectrogram.h b/Whisper/Whisper/Spectrogram.h
index 04e2c06..28019ee 100644
--- a/Whisper/Whisper/Spectrogram.h
+++ b/Whisper/Whisper/Spectrogram.h
@@ -24,6 +24,8 @@ namespace Whisper
class MelContext;
+ HRESULT copyStereoPcm( size_t offset, size_t length, std::vector<StereoSample>& buffer ) const override final;
+
public:
size_t getLength() const noexcept override final
{
diff --git a/Whisper/Whisper/iSpectrogram.h b/Whisper/Whisper/iSpectrogram.h
index 2e3199d..35af009 100644
--- a/Whisper/Whisper/iSpectrogram.h
+++ b/Whisper/Whisper/iSpectrogram.h
@@ -3,6 +3,11 @@
namespace Whisper
{
+	// A single frame of stereo PCM audio: two float samples, the left channel first.
+	// The structure is aligned by 8 bytes, which equals its size when float takes 4 bytes
+	struct alignas( 8 ) StereoSample
+	{
+		float left;
+		float right;
+	};
+
__interface iSpectrogram
{
// Make a buffer with length * N_MEL floats, starting at the specified offset
@@ -11,6 +16,9 @@ namespace Whisper
// Apparently, the length unit is 160 input samples = 10 milliseconds of audio
size_t getLength() const;
+
+ // If the source data is stereo, copy the specified slice of the data into the provided vector
+ HRESULT copyStereoPcm( size_t offset, size_t length, std::vector<StereoSample>& buffer ) const;
};
// RAII class to deal with iSpectrogram's makeBuffer method.