diff options
| author | yum <yum.food.vr@gmail.com> | 2025-07-23 17:41:49 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2025-07-23 17:41:49 -0700 |
| commit | 790c91d7ad515c3c0a22ca1341316265b8f0d779 (patch) | |
| tree | 28527bbcf87e8fab1d27eb76a1f5ea325b94d599 /app/stt.py | |
| parent | 73de7cb2d8fb964e7f76ab55420e9bc331bf7bea (diff) | |
bugfixes
* fix model acquisition
* fix local beepsnd
* fix volume control
Diffstat (limited to 'app/stt.py')
| -rw-r--r-- | app/stt.py | 62 |
1 files changed, 49 insertions, 13 deletions
@@ -3,6 +3,7 @@ from faster_whisper import WhisperModel import langcodes import numpy as np import os +import noisereduce as nr try: from profanity_filter import ProfanityFilter PROFANITY_FILTER_AVAILABLE = True @@ -260,9 +261,13 @@ class NormalizingAudioCollector(AudioCollectorFilter): return frames class BoostingAudioCollector(AudioCollectorFilter): - def __init__(self, parent: AudioCollector, target_dBFS: float, cfg: typing.Dict): + def __init__(self, parent: AudioCollector, + target_dBFS: float, + max_gain_dB: float, + cfg: typing.Dict): AudioCollectorFilter.__init__(self, parent) self.target_dBFS = target_dBFS + self.max_gain_dB = max_gain_dB self.cfg = cfg def getAudio(self) -> bytes: @@ -270,9 +275,10 @@ class BoostingAudioCollector(AudioCollectorFilter): audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ, frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) + gain = min(self.target_dBFS - audio.dBFS, self.max_gain_dB) if self.cfg["enable_debug_mode"]: - print(f"Boosting audio from {audio.dBFS}dB to {self.target_dBFS}dB", file=sys.stderr) - audio = audio.apply_gain(self.target_dBFS - audio.dBFS) + print(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})", flush=True) + audio = audio.apply_gain(gain) frames = np.array(audio.get_array_of_samples()) frames = np.int16(frames).tobytes() @@ -296,6 +302,26 @@ class CompressingAudioCollector(AudioCollectorFilter): return frames +class NoiseReducingAudioCollector(AudioCollectorFilter): + def __init__(self, parent: AudioCollector, cfg: typing.Dict): + AudioCollectorFilter.__init__(self, parent) + self.cfg = cfg + + def getAudio(self) -> bytes: + audio = self.parent.getAudio() + audio_array = np.frombuffer(audio, dtype=np.int16).astype(np.float32) + + reduced_audio = nr.reduce_noise( + y=audio_array, + sr=AudioStream.FPS, + ) + + # Convert back to int16 + reduced_audio = np.clip(reduced_audio, -32768, 32767) + frames = np.int16(reduced_audio).tobytes() + + return frames + class AudioSegmenter: def __init__(self, min_silence_ms=250, @@ -398,6 +424,12 @@ class Segment: avg_logprob = f"(avg_logprob: {self.avg_logprob}) " return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob +def join_segments(a, b): + if len(a) > 0 and a[-1] != ' ': + return a + ' ' + b + else: + return a + b + class Whisper: def __init__(self, collector: AudioCollector, @@ -421,6 +453,9 @@ class Whisper: already_downloaded = os.path.exists(model_root) + if not already_downloaded: + print(f"Model {model_str} not already downloaded, downloading now...", flush=True) + self.model = WhisperModel(model_str, device = model_device, device_index = cfg["gpu_idx"], @@ -433,10 +468,12 @@ class Whisper: def update_context(self, committed_text: str): """Update the context with recently committed text.""" - self.recent_context = (self.recent_context + " " + committed_text).strip() - # Keep only the last N characters to avoid prompt getting too long + self.recent_context = join_segments(self.recent_context, committed_text).strip() + # Drop half of the context window. if len(self.recent_context) > self.context_window_chars: - self.recent_context = self.recent_context[-self.context_window_chars:] + words = self.recent_context.split() + words = words[len(words)//2:] + self.recent_context = ' '.join(words) def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: @@ -449,6 +486,8 @@ class Whisper: # Build context-aware prompt prompt = self._build_prompt() + print(f"Prompt: {prompt}", flush=True) + t0 = time.time() segments, info = self.model.transcribe( audio, @@ -698,8 +737,10 @@ def transcriptionThread(shared_data: SharedThreadData): stream = MicStream(shared_data.cfg) collector = AudioCollector(stream) collector = CompressingAudioCollector(collector) - collector = BoostingAudioCollector(collector, -12.0, shared_data.cfg) - collector = NormalizingAudioCollector(collector) + collector = BoostingAudioCollector(collector, -24.0, 24.0, + shared_data.cfg) + collector = NoiseReducingAudioCollector(collector, shared_data.cfg) + #collector = NormalizingAudioCollector(collector) whisper = Whisper(collector, shared_data.cfg) segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"], max_speech_s=shared_data.cfg["max_speech_duration_s"], @@ -761,11 +802,6 @@ def transcriptionThread(shared_data: SharedThreadData): # breaking OSC pager. if len(shared_data.transcript) >= 1024: shared_data.transcript = shared_data.transcript[-512:] - def join_segments(a, b): - if len(a) > 0 and a[-1] != ' ': - return a + ' ' + b - else: - return a + b shared_data.transcript = \ join_segments(shared_data.transcript, commit.delta) shared_data.preview = commit.preview |
