From 1136acfc365f357d2df13a263714e8ae0614c4f9 Mon Sep 17 00:00:00 2001
From: yum
Date: Sun, 26 Feb 2023 19:42:33 -0800
Subject: Add retainDuration option to CaptureParams

This allows users to retain a suffix of the PCM buffer after a VAD
segmentation event, reducing some instances of words being lost at the
start of the next VAD window.
---
 Whisper/API/MfStructs.h                 |  3 +++
 Whisper/MF/AudioBuffer.h                | 15 +++++++++++++++
 Whisper/Whisper/ContextImpl.capture.cpp |  5 ++++-
 3 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/Whisper/API/MfStructs.h b/Whisper/API/MfStructs.h
index 39255de..c23d633 100644
--- a/Whisper/API/MfStructs.h
+++ b/Whisper/API/MfStructs.h
@@ -28,6 +28,9 @@ namespace Whisper
 		float maxDuration = 3.0f;
 		float dropStartSilence = 0.25f;
 		float pauseDuration = 0.333f;
+		// After VAD splits the audio, up to this many seconds from the end of the
+		// PCM buffer are kept as the start of the next transcription window.
+		float retainDuration = 0.25f;
 		// Flags for the audio capture
 		uint32_t flags = 0;
 	};
diff --git a/Whisper/MF/AudioBuffer.h b/Whisper/MF/AudioBuffer.h
index 77be1e0..63c4a8c 100644
--- a/Whisper/MF/AudioBuffer.h
+++ b/Whisper/MF/AudioBuffer.h
@@ -48,12 +48,27 @@ namespace Whisper
 		void dropFirst(size_t len)
 		{
 			assert(len <= mono.size());
+			if (len >= mono.size()) {
+				mono.clear();
+				return;
+			}
 			size_t remainder = mono.size() - len;
 			auto tmp = std::vector(remainder);
 			memcpy(tmp.data(), mono.data() + len, remainder);
 			mono = std::move(tmp);
 		}
 
+		// Keep only the last `len` elements of `mono`; no-op if it holds `len` or fewer.
+		void retainLast(size_t len)
+		{
+			if (len >= mono.size()) {
+				return;
+			}
+			// Erase the prefix in place: iterator arithmetic counts elements, not bytes.
+			const size_t prefix_len = mono.size() - len;
+			mono.erase(mono.begin(), mono.begin() + prefix_len);
+		}
+
 		void normalize()
 		{
 			const auto &min = *std::min_element(mono.begin(), mono.end());
diff --git a/Whisper/Whisper/ContextImpl.capture.cpp b/Whisper/Whisper/ContextImpl.capture.cpp
index bc88249..0100fcd 100644
--- a/Whisper/Whisper/ContextImpl.capture.cpp
+++ b/Whisper/Whisper/ContextImpl.capture.cpp
@@ -53,6 +53,7 @@ namespace
 	struct CaptureParams
 	{
 		uint32_t minDuration, maxDuration, dropStartSilence, pauseDuration;
+		uint32_t retainDuration;
 		uint32_t flags;
 
 		CaptureParams( const sCaptureParams& cp )
@@ -64,6 +65,8 @@ namespace
 			__m128i ints = _mm_cvtps_epi32( floats );
 			store16( &minDuration, ints );
 
+			retainDuration = static_cast<uint32_t>( std::round( cp.retainDuration * SAMPLE_RATE ) );	// read from cp, not the uninitialized member
+
 			flags = cp.flags;
 		}
 	};
@@ -142,7 +145,7 @@ namespace
 			buffer.pcm.normalize();
 			SubmitThreadpoolWork( work );
 			pcmStartTime = nextSampleTime;
-			pcm.clear();
+			pcm.retainLast(captureParams.retainDuration);
 			vad.clear();
 			return S_OK;
 		}
-- 
cgit v1.2.3