diff options
Diffstat (limited to 'app/stt.py')
| -rw-r--r-- | app/stt.py | 267 |
1 files changed, 127 insertions, 140 deletions
@@ -2,15 +2,11 @@ from datetime import datetime from faster_whisper import WhisperModel import json import langcodes +from logger import log, log_err +import noisereduce as nr import numpy as np import os -import noisereduce as nr -try: - from profanity_filter import ProfanityFilter - PROFANITY_FILTER_AVAILABLE = True -except ImportError: - PROFANITY_FILTER_AVAILABLE = False - print("Warning: profanity_filter module not available", file=sys.stderr) +from profanity_filter import ProfanityFilter import pyaudio from pydub import AudioSegment from shared_thread_data import SharedThreadData @@ -19,6 +15,7 @@ import sys import time import typing import wave +from hallucination_filter import HallucinationFilter APP_ROOT = os.path.dirname(os.path.abspath(__file__)) PROJECT_ROOT = os.path.dirname(APP_ROOT) @@ -55,7 +52,7 @@ class MicStream(AudioStream): which_mic = cfg["microphone"] if cfg["enable_debug_mode"]: - print(f"Finding mic {which_mic}", file=sys.stderr) + log(f"Finding mic {which_mic}") self.dumpMicDevices() got_match = False @@ -70,8 +67,8 @@ class MicStream(AudioStream): target_str = "Microphone (Beyond)" else: if cfg["enable_debug_mode"]: - print(f"Mic {which_mic} requested, treating it as a numerical " + - "device ID", file=sys.stderr) + log(f"Mic {which_mic} requested, treating it as a numerical " + + "device ID") device_index = int(which_mic) got_match = True if not got_match: @@ -81,8 +78,7 @@ class MicStream(AudioStream): if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') if target_str in device_name: - print(f"Got matching mic: {device_name}", - file=sys.stderr) + log(f"Got matching mic: {device_name}") device_index = i got_match = True break @@ -91,10 +87,10 @@ class MicStream(AudioStream): info = self.p.get_device_info_by_host_api_device_index(0, device_index) if cfg["enable_debug_mode"]: - print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr) + log(f"Found mic {which_mic}: {info['name']}") self.sample_rate = int(info['defaultSampleRate']) if cfg["enable_debug_mode"]: - print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr) + log(f"Mic sample rate: {self.sample_rate}") self.stream = self.p.open( rate=self.sample_rate, @@ -119,7 +115,7 @@ class MicStream(AudioStream): for i in range(0, numdevices): if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') - print("Input Device id ", i, " - ", device_name) + log("Input Device id ", i, " - ", device_name) def onAudioFramesAvailable(self, frames, @@ -278,7 +274,7 @@ class BoostingAudioCollector(AudioCollectorFilter): frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) gain = min(self.target_dBFS - audio.dBFS, self.max_gain_dB) if self.cfg["enable_debug_mode"]: - print(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})", flush=True) + log(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})") audio = audio.apply_gain(gain) frames = np.array(audio.get_array_of_samples()) @@ -335,7 +331,6 @@ class AudioSegmenter: # Load Silero VAD model self.model = load_silero_vad() - self.vad_threshold = 0.3 self.min_silence_duration_ms = min_silence_ms self.max_speech_duration_s = max_speech_s self.min_speech_duration_ms = min_speech_duration_ms @@ -351,7 +346,6 @@ class AudioSegmenter: audio_array, self.model, sampling_rate=AudioStream.FPS, - threshold=self.vad_threshold, min_silence_duration_ms=self.min_silence_duration_ms, max_speech_duration_s=self.max_speech_duration_s, min_speech_duration_ms=self.min_speech_duration_ms, @@ -371,12 +365,9 @@ class AudioSegmenter: for i in range(len(segments)): s = segments[i] - #print(f"s: {s}") - #print(f"last_end: {last_end}") if last_end: delta_frames = s['start'] - last_end - #print(f"delta frames: {delta_frames}") if delta_frames > min_delta_frames: cutoff = s['start'] else: @@ -384,8 +375,6 @@ class AudioSegmenter: if i == len(segments) - 1: now = int(len(audio) / AudioStream.FRAME_SZ) - #print(f"now: {now}") - #print(f"min d: {min_delta_frames}") delta_frames = now - s['end'] if delta_frames > min_delta_frames: cutoff = now - int(min_delta_frames / 2) @@ -402,7 +391,8 @@ class Segment: wall_ts: float, avg_logprob: float, no_speech_prob: float, - compression_ratio: float): + compression_ratio: float, + audio_len_s: float): self.transcript = transcript # start_ts, end_ts are timestamps in seconds relative to `wall_ts`. self.start_ts = start_ts @@ -413,6 +403,7 @@ class Segment: self.avg_logprob = avg_logprob self.no_speech_prob = no_speech_prob self.compression_ratio = compression_ratio + self.audio_len_s = audio_len_s def __str__(self): ts = f"(ts: {self.start_ts}-{self.end_ts}) " @@ -423,7 +414,8 @@ class Segment: no_speech = f"(no_speech: {self.no_speech_prob}) " avg_logprob = f"(avg_logprob: {self.avg_logprob}) " - return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob + max_len_s = f"(max_len_s: {self.audio_len_s}) " + return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob + max_len_s def join_segments(a, b): if len(a) > 0 and a[-1] != ' ': @@ -431,20 +423,76 @@ def join_segments(a, b): else: return a + b + +class SegmentLogger: + def __init__(self, cfg: typing.Dict): + self.cfg = cfg + self.enabled = cfg.get("enable_segment_logging", False) + self.session_data = [] + self.log_file = None + + if self.enabled: + log_dir = os.path.join(PROJECT_ROOT, "logs") + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Create file + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.log_file = os.path.join(log_dir, f"session_debug_{timestamp}.json") + log(f"Segment logging enabled. Logging to: {self.log_file}") + + def log_segment(self, segment: Segment, commit_type: str = "commit"): + if not self.enabled: + return + + segment_data = { + "timestamp": datetime.now().isoformat(), + "type": commit_type, + "text": segment.transcript, + "start_ts": segment.start_ts, + "end_ts": segment.end_ts, + "wall_ts": segment.wall_ts, + "avg_logprob": segment.avg_logprob, + "no_speech_prob": segment.no_speech_prob, + "compression_ratio": segment.compression_ratio, + "duration": segment.end_ts - segment.start_ts, + "duration_sanity": segment.audio_len_s + } + + self.session_data.append(segment_data) + + # Write to file incrementally + try: + with open(self.log_file, 'w') as f: + json.dump({ + "session_start": self.session_data[0]["timestamp"] if self.session_data else None, + "segments": self.session_data + }, f, indent=2) + except Exception as e: + log_err(f"Error writing segment log: {e}") + + def close(self): + if self.enabled and self.session_data: + log(f"Session complete. Logged {len(self.session_data)} " + \ + "segments to {self.log_file}") + + class Whisper: def __init__(self, collector: AudioCollector, - cfg: typing.Dict): + cfg: typing.Dict, + segment_logger: SegmentLogger = None): self.collector = collector self.model = None self.cfg = cfg + self.hallucination_filter = HallucinationFilter() + self.segment_logger = segment_logger model_str = cfg["model"] model_root = os.path.join(PROJECT_ROOT, "Models", os.path.normpath(model_str)) if cfg["enable_debug_mode"]: - print(f"Model {cfg['model']} will be saved to {model_root}", - file=sys.stderr) + log(f"Model {cfg['model']} will be saved to {model_root}") model_device = "cuda" compute_type = cfg["compute_type"] @@ -455,7 +503,7 @@ class Whisper: already_downloaded = os.path.exists(model_root) if not already_downloaded: - print(f"Model {model_str} not already downloaded, downloading now...", flush=True) + log(f"Model {model_str} not already downloaded, downloading now...") self.model = WhisperModel(model_str, device = model_device, @@ -483,19 +531,20 @@ class Whisper: # Convert audio to float32 audio = np.frombuffer(frames, dtype=np.int16).flatten().astype(np.float32) / 32768.0 + audio_len_s = len(frames) / 16000.0 # Build context-aware prompt prompt = self._build_prompt() if self.cfg["enable_debug_mode"]: - print(f"Prompt: {prompt}", flush=True) + log(f"Prompt: {prompt}") t0 = time.time() segments, info = self.model.transcribe( audio, language = langcodes.find(self.cfg["language"]).language, - vad_filter = True, - temperature=0.0, + vad_filter = False, + #temperature=0.0, without_timestamps = False, initial_prompt=prompt, beam_size=self.cfg.get("beam_size", 5), @@ -504,44 +553,44 @@ class Whisper: ) res = [] for s in segments: - # Manual touchup. I see a decent number of hallucinations sneaking - # in with high `no_speech_prob` and modest `avg_logprob`. - if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5: - if self.cfg["enable_debug_mode"]: - print(f"Drop probable hallucination (case 1) " + - f"(text='{s.text}', " + - f"no_speech_prob={s.no_speech_prob}, " + - f"avg_logprob={s.avg_logprob})", file=sys.stderr) - continue - # Another touchup targeted at the vexatious "thanks for watching!" - # hallucination. This triggers a lot when listening to - # instrumental/electronic music. - if s.no_speech_prob > 0.15 and s.avg_logprob < -0.7: - if self.cfg["enable_debug_mode"]: - print(f"Drop probable hallucination (case 2) " + - f"(text='{s.text}', " + - f"no_speech_prob={s.no_speech_prob}, " + - f"avg_logprob={s.avg_logprob})", file=sys.stderr) - continue - if s.avg_logprob < -0.75: - if self.cfg["enable_debug_mode"]: - print(f"Drop probable hallucination (case 3) " + - f"(text='{s.text}', " + - f"no_speech_prob={s.no_speech_prob}, " + - f"avg_logprob={s.avg_logprob})", file=sys.stderr) - continue + # Log raw segment before filtering + if self.segment_logger: + # Create a temporary segment object for logging + raw_seg = Segment(s.text, s.start, s.end, + self.collector.begin(), + s.avg_logprob, s.no_speech_prob, s.compression_ratio, + audio_len_s) + self.segment_logger.log_segment(raw_seg, "raw") + # Sometimes the model reports a bum duration, breaking our filters. + # Cap the segment length above by the length of the audio in. + duration_s = min(s.end - s.start, audio_len_s) + if self.cfg["enable_debug_mode"]: - print(f"s get: {s}") + log(f"s get: {s}") if s.avg_logprob < -1.0: continue if s.compression_ratio > 2.4: continue - res.append(Segment(s.text, s.start, s.end, + + # Create segment object + seg = Segment(s.text, s.start, s.end, self.collector.begin(), - s.avg_logprob, s.no_speech_prob, s.compression_ratio)) + s.avg_logprob, s.no_speech_prob, s.compression_ratio, + audio_len_s) + + # Check with ML model for "Thank you" hallucinations + if self.hallucination_filter.is_thank_you_hallucination(seg): + if self.cfg["enable_debug_mode"]: + log(f"Drop probable hallucination (case 4) " + + f"(text='{s.text}', " + + f"no_speech_prob={s.no_speech_prob}, " + + f"avg_logprob={s.avg_logprob})") + continue + + res.append(seg) t1 = time.time() if self.cfg["enable_debug_mode"]: - print(f"Transcription latency (s): {t1 - t0}") + log(f"Transcription latency (s): {t1 - t0}") return res def _build_prompt(self) -> str: @@ -580,64 +629,13 @@ class TranscriptCommit: def saveAudio(audio: bytes, path: str, cfg: typing.Dict): with wave.open(path, 'wb') as wf: if cfg["enable_debug_mode"]: - print(f"Saving audio to {path}", file=sys.stderr) + log(f"Saving audio to {path}") wf.setnchannels(AudioStream.CHANNELS) wf.setsampwidth(AudioStream.FRAME_SZ) wf.setframerate(AudioStream.FPS) wf.writeframes(audio) -class SegmentLogger: - def __init__(self, cfg: typing.Dict): - self.cfg = cfg - self.enabled = cfg.get("enable_segment_logging", False) - self.session_data = [] - self.log_file = None - - if self.enabled: - log_dir = os.path.join(PROJECT_ROOT, "logs") - if not os.path.exists(log_dir): - os.makedirs(log_dir) - - # Create file - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - self.log_file = os.path.join(log_dir, f"session_debug_{timestamp}.json") - print(f"Segment logging enabled. Logging to: {self.log_file}", file=sys.stderr) - - def log_segment(self, segment: Segment, commit_type: str = "commit"): - if not self.enabled: - return - - segment_data = { - "timestamp": datetime.now().isoformat(), - "type": commit_type, - "text": segment.transcript, - "start_ts": segment.start_ts, - "end_ts": segment.end_ts, - "wall_ts": segment.wall_ts, - "avg_logprob": segment.avg_logprob, - "no_speech_prob": segment.no_speech_prob, - "compression_ratio": segment.compression_ratio, - "duration": segment.end_ts - segment.start_ts - } - - self.session_data.append(segment_data) - - # Write to file incrementally - try: - with open(self.log_file, 'w') as f: - json.dump({ - "session_start": self.session_data[0]["timestamp"] if self.session_data else None, - "segments": self.session_data - }, f, indent=2) - except Exception as e: - print(f"Error writing segment log: {e}", file=sys.stderr) - - def close(self): - if self.enabled and self.session_data: - print(f"Session complete. Logged {len(self.session_data)} segments to {self.log_file}", file=sys.stderr) - - class VadCommitter: def __init__(self, cfg: typing.Dict, @@ -680,16 +678,12 @@ class VadCommitter: if delta.strip(): self.whisper.update_context(delta.strip()) - if self.segment_logger: - for s in segments: - self.segment_logger.log_segment(s, "commit") - audio = self.collector.getAudio() if self.cfg["enable_debug_mode"]: for s in segments: - print(f"commit segment: {s}", file=sys.stderr) + log(f"commit segment: {s}") if len(delta) > 0: - print(f"delta get: {delta}", file=sys.stderr) + log(f"delta get: {delta}") if self.cfg["save_audio"] and len(delta) > 0: ts = datetime.fromtimestamp(self.collector.now() - latency_s) @@ -705,10 +699,6 @@ class VadCommitter: segments = self.whisper.transcribe(audio) preview = "".join(s.transcript for s in segments) - if self.segment_logger: - for s in segments: - self.segment_logger.log_segment(s, "preview") - if not has_audio: self.collector.keepLast(1.0) @@ -758,13 +748,13 @@ class ProfanityPlugin(StreamingPlugin): def __init__(self, cfg): self.cfg = cfg self.filter = None - if PROFANITY_FILTER_AVAILABLE and cfg["enable_profanity_filter"]: + if cfg["enable_profanity_filter"]: en_profanity_path = os.path.join(PROJECT_ROOT, "Third_Party/Profanity/en") try: self.filter = ProfanityFilter(en_profanity_path) self.filter.load() except Exception as e: - print(f"Warning: Could not load profanity filter: {e}", file=sys.stderr) + log_err(f"Warning: Could not load profanity filter: {e}") self.filter = None def transform(self, commit: TranscriptCommit) -> TranscriptCommit: @@ -812,12 +802,11 @@ def transcriptionThread(shared_data: SharedThreadData): shared_data.cfg) collector = NoiseReducingAudioCollector(collector, shared_data.cfg) #collector = NormalizingAudioCollector(collector) - whisper = Whisper(collector, shared_data.cfg) + segment_logger = SegmentLogger(shared_data.cfg) + whisper = Whisper(collector, shared_data.cfg, segment_logger) segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"], max_speech_s=shared_data.cfg["max_speech_duration_s"], min_speech_duration_ms=shared_data.cfg["min_speech_duration_ms"]) - - segment_logger = SegmentLogger(shared_data.cfg) committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter, segment_logger) plugins = [] @@ -838,7 +827,7 @@ def transcriptionThread(shared_data: SharedThreadData): shared_data.stream = stream shared_data.collector = collector - print(f"Ready to go!", flush=True) + log(f"Ready to go!") while not shared_data.exit_event.is_set(): time.sleep(shared_data.cfg["transcription_loop_delay_ms"] / 1000.0); @@ -862,8 +851,8 @@ def transcriptionThread(shared_data: SharedThreadData): silence_duration = commit.start_ts - last_commit_end_ts if silence_duration > shared_data.cfg["reset_after_silence_s"]: if shared_data.cfg["enable_debug_mode"]: - print(f"Resetting transcript after {silence_duration}-second " - "silence", file=sys.stderr) + log(f"Resetting transcript after {silence_duration}-second " + "silence") shared_data.transcript = "" shared_data.preview = "" whisper.recent_context = "" # Reset context too @@ -885,20 +874,18 @@ def transcriptionThread(shared_data: SharedThreadData): shared_data.preview) try: - print(f"Transcript: {shared_data.transcript}", flush=True) + log(f"Transcript: {shared_data.transcript}") except UnicodeEncodeError: - print("Failed to encode transcript - discarding delta", - file=sys.stderr) + log_err("Failed to encode transcript - discarding delta") continue try: - print(f"Preview: {shared_data.preview}", flush=True) + log(f"Preview: {shared_data.preview}") except UnicodeEncodeError: - print("Failed to encode preview - discarding", file=sys.stderr) + log_err("Failed to encode preview - discarding") if shared_data.cfg["enable_debug_mode"]: - print(f"commit latency: {commit.latency_s}", file=sys.stderr) - print(f"commit thresh: {commit.thresh_at_commit}", - file=sys.stderr) + log(f"commit latency: {commit.latency_s}") + log(f"commit thresh: {commit.thresh_at_commit}") if len(shared_data.transcript) > 0 and \ (not shared_data.transcript.endswith(' ')) and \ |
