diff options
| author | yum <yum.food.vr@gmail.com> | 2025-07-25 21:28:50 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2025-07-25 21:28:50 -0700 |
| commit | a7f9b7b5fb33bead6bcfb0ad6867b57f2ddc42af (patch) | |
| tree | 61d4870a019acb0e545d88e7661c8a4c7d90e499 /app | |
| parent | 5df013d26eb13ed4aef8d16aa14346e0f9be5111 (diff) | |
Experiment with hallucination reduction
- update cursorignore
- add hallucination filter training & inference code
- put logging into a central module
- segment metadata logging occurs before filtering
- segment metadata logging is on by default
- check in embedded python setup script
- include trained hallucination filter model
Diffstat (limited to 'app')
| -rw-r--r-- | app/hallucination_filter.py | 66 | ||||
| -rw-r--r-- | app/hi.py | 65 | ||||
| -rw-r--r-- | app/logger.py | 12 | ||||
| -rw-r--r-- | app/requirements.txt | 3 | ||||
| -rw-r--r-- | app/stt.py | 267 |
5 files changed, 237 insertions, 176 deletions
diff --git a/app/hallucination_filter.py b/app/hallucination_filter.py new file mode 100644 index 0000000..9b24a85 --- /dev/null +++ b/app/hallucination_filter.py @@ -0,0 +1,66 @@ +import io +import joblib +from logger import log, log_err +import numpy as np +import pandas as pd +from pathlib import Path +import sys + + +class HallucinationFilter: + """Filter for detecting hallucinated segments in speech-to-text output.""" + + def __init__(self, model_path: Path = None): + """ + Initialize the hallucination filter. + + Args: + model_path: Optional path to the model file. If not provided, + uses the default path. + """ + self.model = None + self.threshold = None + self.features = None + + # Get the project root directory + app_root = Path(__file__).resolve().parent + project_root = app_root.parent + + # Use provided path or default + if model_path is None: + model_path = project_root / "Models" / "thankyou_filter_gb.pkl" + + # Try to load the model + log_err(f"Loading hallucination filter") + bundle = joblib.load(model_path) + self.model = bundle["model"] + self.threshold = bundle["threshold"] + self.features = bundle["features"] # Extract feature names + log_err(f"Loaded hallucination filter model from {model_path}") + + def is_thank_you_hallucination(self, segment) -> bool: + """ + Check if a segment is likely a "Thank you" hallucination. + Returns False if model is not available. + + Args: + segment: A segment object with attributes avg_logprob, audio_len_s, + no_speech_prob, and compression_ratio. + + Returns: + bool: True if the segment is likely a hallucination, False otherwise. + """ + # Create DataFrame with proper feature names + X = pd.DataFrame([[ + segment.avg_logprob, + segment.audio_len_s, + segment.no_speech_prob, + segment.compression_ratio, + np.log1p(segment.audio_len_s), + segment.avg_logprob * segment.audio_len_s + ]], columns=self.features) + + # Get probability + prob = self.model.predict_proba(X)[0, 1] + return prob >= self.threshold + @@ -2,9 +2,11 @@ import app_config import argparse import io import keybind_event_machine +from logger import log, log_err from math import floor, ceil import msvcrt import os +import pygame from pythonosc import udp_client import sentencepiece as spm import steamvr @@ -13,10 +15,6 @@ import stt import sys import threading import time -import pygame - -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') -sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') # Initialize pygame mixer pygame.mixer.init() @@ -31,10 +29,10 @@ PROJECT_ROOT = os.path.dirname(APP_ROOT) def get_tokenizer(): model_path = os.path.join(PROJECT_ROOT, "custom_unigram_tokenizer_65k", "unigram.model") - print(f"Loading SentencePiece tokenizer from: {model_path}") + log(f"Loading SentencePiece tokenizer from: {model_path}") sp = spm.SentencePieceProcessor() sp.load(model_path) - print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}") + log(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}") return sp def parse_args(): @@ -137,16 +135,16 @@ def wrap_line(line: str, cols): def get_blocks(lines, tokenizer, block_width, num_blocks): if LOG_LEVEL == 2: - print(f"Lines sent to tokenizer: {''.join(lines)}") + log(f"Lines sent to tokenizer: {''.join(lines)}") tokens = tokenizer.encode_as_ids(''.join(lines)) if LOG_LEVEL == 2: - print(f"Tokens: {tokens}") + log(f"Tokens: {tokens}") pieces = [] for tok in tokens: piece = tokenizer.id_to_piece(tok) pieces.append(piece) if LOG_LEVEL == 2: - print(f"Pieces: {pieces}") + log(f"Pieces: {pieces}") # Group tokens into blocks and pad with empty characters. # Also get visual pointers - the location where each block will be rendered. @@ -168,8 +166,8 @@ def get_blocks(lines, tokenizer, block_width, num_blocks): return (blocks, visual_pointers) blocks, visual_pointers = get_blocks() if LOG_LEVEL == 2: - print(f"Blocks: {blocks}") - print(f"Visual pointers: {visual_pointers}") + log(f"Blocks: {blocks}") + log(f"Visual pointers: {visual_pointers}") # Set all blocks up to the next `num_blocks` boundary to blank tokens. # This handles the edge case where a prior message wrote data there which @@ -183,8 +181,8 @@ def get_blocks(lines, tokenizer, block_width, num_blocks): return blocks, visual_pointers blocks, visual_pointers = pad_blocks(blocks, visual_pointers) if LOG_LEVEL == 2: - print(f"Blocks (padded): {blocks}") - print(f"Visual pointers (padded): {visual_pointers}") + log(f"Blocks (padded): {blocks}") + log(f"Visual pointers (padded): {visual_pointers}") return blocks, visual_pointers @@ -223,11 +221,10 @@ def send_data(osc_client, indices, blocks, visual_pointers): blocks_byte00, blocks_byte01 = split_blocks_by_byte(blocks) if LOG_LEVEL == 2: - print(f"Blocks (byte 00): {blocks_byte00}") - print(f"Blocks (byte 01): {blocks_byte01}") + log(f"Blocks (byte 00): {blocks_byte00}") + log(f"Blocks (byte 01): {blocks_byte01}") def send_osc(osc_client, addr, data): - #print(f"Sending {data} to {addr}") osc_client.send_message(addr, data) for i in range(0, len(blocks)): @@ -241,7 +238,7 @@ def send_data(osc_client, indices, blocks, visual_pointers): addr = "/avatar/parameters/" + vp_param send_osc(osc_client, addr, vp_float) if LOG_LEVEL == 2: - print(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}") + log(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}") for j in range(0, len(blocks[i])): byte00_float = (-127.5 + blocks_byte00[i][j]) / 127.5 byte01_float = (-127.5 + blocks_byte01[i][j]) / 127.5 @@ -271,7 +268,7 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg): for line in line_wrapped: assert_equal(len(line), cfg["cols"]) if LOG_LEVEL == 2: - print(f"Wrapped lines: {line_wrapped}") + log(f"Wrapped lines: {line_wrapped}") # Get several blank lines whenever we roll over. # It's better for the reader to have some continuity when the board pages @@ -312,7 +309,7 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg): state.blocks.append(diff_blocks[0]) state.visual_pointers.append(diff_visual_pointers[0]) elif indices[0] > len(state.blocks): - print(f"This should never happen!") + log(f"This should never happen!") sys.exit(1) else: state.blocks[indices[0]] = diff_blocks[0] @@ -345,7 +342,7 @@ def osc_thread(shared_data: SharedThreadData): continue addr = "/chatbox/input" if shared_data.cfg["enable_debug_mode"]: - print(f"Send {local_word}", flush=True) + log(f"Send {local_word}") osc_client.send_message(addr, (local_word, True, False)) last_change = time.time() remote_word = local_word @@ -354,7 +351,7 @@ def osc_thread(shared_data: SharedThreadData): tokenizer = get_tokenizer() # Prime the board - print("Priming the board") + log("Priming the board") input_state = InputState() handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg) @@ -424,7 +421,7 @@ def vrInputThread(shared_data: SharedThreadData): elif now - last_rising > 0.5: # Medium press - print("CLEARING", file=sys.stderr) + log_err("CLEARING") last_medium_press_end = now state = PAUSE_STATE play_sound_with_volume(waveform2, shared_data.cfg) @@ -439,25 +436,23 @@ def vrInputThread(shared_data: SharedThreadData): # Short hold if state == RECORD_STATE: - print("PAUSED", file=sys.stderr) + log_err("PAUSED") state = PAUSE_STATE shared_data.stream.pause(True) play_sound_with_volume(waveform1, shared_data.cfg) elif state == PAUSE_STATE: - print("RECORDING", file=sys.stderr) + log_err("RECORDING") state = RECORD_STATE if shared_data.cfg["reset_on_toggle"]: if shared_data.cfg["enable_debug_mode"]: - print("Toggle detected, dropping transcript (3)", - file=sys.stderr) + log_err("Toggle detected, dropping transcript (3)") shared_data.transcript = "" shared_data.preview = "" #audio_state.drop_transcription = True else: if shared_data.cfg["enable_debug_mode"]: - print("Toggle detected, committing preview text (3)", - file=sys.stderr) + log_err("Toggle detected, committing preview text (3)") #audio_state.text += audio_state.preview_text shared_data.stream.pause(False) @@ -502,7 +497,7 @@ def kbInputThread(shared_data: SharedThreadData): last_press_time = cur_press_time if event == EVENT_DOUBLE_PRESS: - print("CLEARING", file=sys.stderr) + log_err("CLEARING") state = PAUSE_STATE play_sound_with_volume(waveform2, shared_data.cfg) @@ -516,23 +511,21 @@ def kbInputThread(shared_data: SharedThreadData): # Short hold if state == RECORD_STATE: - print("PAUSED", file=sys.stderr) + log_err("PAUSED") state = PAUSE_STATE shared_data.stream.pause(True) play_sound_with_volume(waveform1, shared_data.cfg) elif state == PAUSE_STATE: - print("RECORDING", file=sys.stderr) + log_err("RECORDING") state = RECORD_STATE if shared_data.cfg["reset_on_toggle"]: if shared_data.cfg["enable_debug_mode"]: - print("Toggle detected, dropping transcript (2)", - file=sys.stderr) + log_err("Toggle detected, dropping transcript (2)") shared_data.transcript = "" shared_data.preview = "" else: if shared_data.cfg["enable_debug_mode"]: - print("Toggle detected, committing preview text (2)", - file=sys.stderr) + log_err("Toggle detected, committing preview text (2)") shared_data.stream.pause(False) play_sound_with_volume(waveform0, shared_data.cfg) @@ -545,7 +538,7 @@ def play_sound_with_volume(filepath, cfg): sound.set_volume(volume * 0.01) sound.play() except Exception as e: - print(f"Error playing sound {filepath}: {e}", file=sys.stderr) + log_err(f"Error playing sound {filepath}: {e}") if __name__ == "__main__": cli_args = parse_args() diff --git a/app/logger.py b/app/logger.py new file mode 100644 index 0000000..72a2134 --- /dev/null +++ b/app/logger.py @@ -0,0 +1,12 @@ +import sys +import io + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') + +def log(message): + print(message, file=sys.stdout, flush=True) + +def log_err(message): + print(message, file=sys.stderr, flush=True) + diff --git a/app/requirements.txt b/app/requirements.txt index c8d69df..dc294e5 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -3,10 +3,13 @@ hf-xet keyboard langcodes noisereduce +pandas pyaudio pygame pydub python-osc +scikit-learn sentencepiece silero-vad openvr +joblib @@ -2,15 +2,11 @@ from datetime import datetime from faster_whisper import WhisperModel import json import langcodes +from logger import log, log_err +import noisereduce as nr import numpy as np import os -import noisereduce as nr -try: - from profanity_filter import ProfanityFilter - PROFANITY_FILTER_AVAILABLE = True -except ImportError: - PROFANITY_FILTER_AVAILABLE = False - print("Warning: profanity_filter module not available", file=sys.stderr) +from profanity_filter import ProfanityFilter import pyaudio from pydub import AudioSegment from shared_thread_data import SharedThreadData @@ -19,6 +15,7 @@ import sys import time import typing import wave +from hallucination_filter import HallucinationFilter APP_ROOT = os.path.dirname(os.path.abspath(__file__)) PROJECT_ROOT = os.path.dirname(APP_ROOT) @@ -55,7 +52,7 @@ class MicStream(AudioStream): which_mic = cfg["microphone"] if cfg["enable_debug_mode"]: - print(f"Finding mic {which_mic}", file=sys.stderr) + log(f"Finding mic {which_mic}") self.dumpMicDevices() got_match = False @@ -70,8 +67,8 @@ class MicStream(AudioStream): target_str = "Microphone (Beyond)" else: if cfg["enable_debug_mode"]: - print(f"Mic {which_mic} requested, treating it as a numerical " + - "device ID", file=sys.stderr) + log(f"Mic {which_mic} requested, treating it as a numerical " + + "device ID") device_index = int(which_mic) got_match = True if not got_match: @@ -81,8 +78,7 @@ class MicStream(AudioStream): if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') if target_str in device_name: - print(f"Got matching mic: {device_name}", - file=sys.stderr) + log(f"Got matching mic: {device_name}") device_index = i got_match = True break @@ -91,10 +87,10 @@ class MicStream(AudioStream): info = self.p.get_device_info_by_host_api_device_index(0, device_index) if cfg["enable_debug_mode"]: - print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr) + log(f"Found mic {which_mic}: {info['name']}") self.sample_rate = int(info['defaultSampleRate']) if cfg["enable_debug_mode"]: - print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr) + log(f"Mic sample rate: {self.sample_rate}") self.stream = self.p.open( rate=self.sample_rate, @@ -119,7 +115,7 @@ class MicStream(AudioStream): for i in range(0, numdevices): if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') - print("Input Device id ", i, " - ", device_name) + log("Input Device id ", i, " - ", device_name) def onAudioFramesAvailable(self, frames, @@ -278,7 +274,7 @@ class BoostingAudioCollector(AudioCollectorFilter): frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) gain = min(self.target_dBFS - audio.dBFS, self.max_gain_dB) if self.cfg["enable_debug_mode"]: - print(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})", flush=True) + log(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})") audio = audio.apply_gain(gain) frames = np.array(audio.get_array_of_samples()) @@ -335,7 +331,6 @@ class AudioSegmenter: # Load Silero VAD model self.model = load_silero_vad() - self.vad_threshold = 0.3 self.min_silence_duration_ms = min_silence_ms self.max_speech_duration_s = max_speech_s self.min_speech_duration_ms = min_speech_duration_ms @@ -351,7 +346,6 @@ class AudioSegmenter: audio_array, self.model, sampling_rate=AudioStream.FPS, - threshold=self.vad_threshold, min_silence_duration_ms=self.min_silence_duration_ms, max_speech_duration_s=self.max_speech_duration_s, min_speech_duration_ms=self.min_speech_duration_ms, @@ -371,12 +365,9 @@ class AudioSegmenter: for i in range(len(segments)): s = segments[i] - #print(f"s: {s}") - #print(f"last_end: {last_end}") if last_end: delta_frames = s['start'] - last_end - #print(f"delta frames: {delta_frames}") if delta_frames > min_delta_frames: cutoff = s['start'] else: @@ -384,8 +375,6 @@ class AudioSegmenter: if i == len(segments) - 1: now = int(len(audio) / AudioStream.FRAME_SZ) - #print(f"now: {now}") - #print(f"min d: {min_delta_frames}") delta_frames = now - s['end'] if delta_frames > min_delta_frames: cutoff = now - int(min_delta_frames / 2) @@ -402,7 +391,8 @@ class Segment: wall_ts: float, avg_logprob: float, no_speech_prob: float, - compression_ratio: float): + compression_ratio: float, + audio_len_s: float): self.transcript = transcript # start_ts, end_ts are timestamps in seconds relative to `wall_ts`. self.start_ts = start_ts @@ -413,6 +403,7 @@ class Segment: self.avg_logprob = avg_logprob self.no_speech_prob = no_speech_prob self.compression_ratio = compression_ratio + self.audio_len_s = audio_len_s def __str__(self): ts = f"(ts: {self.start_ts}-{self.end_ts}) " @@ -423,7 +414,8 @@ class Segment: no_speech = f"(no_speech: {self.no_speech_prob}) " avg_logprob = f"(avg_logprob: {self.avg_logprob}) " - return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob + max_len_s = f"(max_len_s: {self.audio_len_s}) " + return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob + max_len_s def join_segments(a, b): if len(a) > 0 and a[-1] != ' ': @@ -431,20 +423,76 @@ def join_segments(a, b): else: return a + b + +class SegmentLogger: + def __init__(self, cfg: typing.Dict): + self.cfg = cfg + self.enabled = cfg.get("enable_segment_logging", False) + self.session_data = [] + self.log_file = None + + if self.enabled: + log_dir = os.path.join(PROJECT_ROOT, "logs") + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Create file + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.log_file = os.path.join(log_dir, f"session_debug_{timestamp}.json") + log(f"Segment logging enabled. Logging to: {self.log_file}") + + def log_segment(self, segment: Segment, commit_type: str = "commit"): + if not self.enabled: + return + + segment_data = { + "timestamp": datetime.now().isoformat(), + "type": commit_type, + "text": segment.transcript, + "start_ts": segment.start_ts, + "end_ts": segment.end_ts, + "wall_ts": segment.wall_ts, + "avg_logprob": segment.avg_logprob, + "no_speech_prob": segment.no_speech_prob, + "compression_ratio": segment.compression_ratio, + "duration": segment.end_ts - segment.start_ts, + "duration_sanity": segment.audio_len_s + } + + self.session_data.append(segment_data) + + # Write to file incrementally + try: + with open(self.log_file, 'w') as f: + json.dump({ + "session_start": self.session_data[0]["timestamp"] if self.session_data else None, + "segments": self.session_data + }, f, indent=2) + except Exception as e: + log_err(f"Error writing segment log: {e}") + + def close(self): + if self.enabled and self.session_data: + log(f"Session complete. Logged {len(self.session_data)} " + \ + "segments to {self.log_file}") + + class Whisper: def __init__(self, collector: AudioCollector, - cfg: typing.Dict): + cfg: typing.Dict, + segment_logger: SegmentLogger = None): self.collector = collector self.model = None self.cfg = cfg + self.hallucination_filter = HallucinationFilter() + self.segment_logger = segment_logger model_str = cfg["model"] model_root = os.path.join(PROJECT_ROOT, "Models", os.path.normpath(model_str)) if cfg["enable_debug_mode"]: - print(f"Model {cfg['model']} will be saved to {model_root}", - file=sys.stderr) + log(f"Model {cfg['model']} will be saved to {model_root}") model_device = "cuda" compute_type = cfg["compute_type"] @@ -455,7 +503,7 @@ class Whisper: already_downloaded = os.path.exists(model_root) if not already_downloaded: - print(f"Model {model_str} not already downloaded, downloading now...", flush=True) + log(f"Model {model_str} not already downloaded, downloading now...") self.model = WhisperModel(model_str, device = model_device, @@ -483,19 +531,20 @@ class Whisper: # Convert audio to float32 audio = np.frombuffer(frames, dtype=np.int16).flatten().astype(np.float32) / 32768.0 + audio_len_s = len(frames) / 16000.0 # Build context-aware prompt prompt = self._build_prompt() if self.cfg["enable_debug_mode"]: - print(f"Prompt: {prompt}", flush=True) + log(f"Prompt: {prompt}") t0 = time.time() segments, info = self.model.transcribe( audio, language = langcodes.find(self.cfg["language"]).language, - vad_filter = True, - temperature=0.0, + vad_filter = False, + #temperature=0.0, without_timestamps = False, initial_prompt=prompt, beam_size=self.cfg.get("beam_size", 5), @@ -504,44 +553,44 @@ class Whisper: ) res = [] for s in segments: - # Manual touchup. I see a decent number of hallucinations sneaking - # in with high `no_speech_prob` and modest `avg_logprob`. - if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5: - if self.cfg["enable_debug_mode"]: - print(f"Drop probable hallucination (case 1) " + - f"(text='{s.text}', " + - f"no_speech_prob={s.no_speech_prob}, " + - f"avg_logprob={s.avg_logprob})", file=sys.stderr) - continue - # Another touchup targeted at the vexatious "thanks for watching!" - # hallucination. This triggers a lot when listening to - # instrumental/electronic music. - if s.no_speech_prob > 0.15 and s.avg_logprob < -0.7: - if self.cfg["enable_debug_mode"]: - print(f"Drop probable hallucination (case 2) " + - f"(text='{s.text}', " + - f"no_speech_prob={s.no_speech_prob}, " + - f"avg_logprob={s.avg_logprob})", file=sys.stderr) - continue - if s.avg_logprob < -0.75: - if self.cfg["enable_debug_mode"]: - print(f"Drop probable hallucination (case 3) " + - f"(text='{s.text}', " + - f"no_speech_prob={s.no_speech_prob}, " + - f"avg_logprob={s.avg_logprob})", file=sys.stderr) - continue + # Log raw segment before filtering + if self.segment_logger: + # Create a temporary segment object for logging + raw_seg = Segment(s.text, s.start, s.end, + self.collector.begin(), + s.avg_logprob, s.no_speech_prob, s.compression_ratio, + audio_len_s) + self.segment_logger.log_segment(raw_seg, "raw") + # Sometimes the model reports a bum duration, breaking our filters. + # Cap the segment length above by the length of the audio in. + duration_s = min(s.end - s.start, audio_len_s) + if self.cfg["enable_debug_mode"]: - print(f"s get: {s}") + log(f"s get: {s}") if s.avg_logprob < -1.0: continue if s.compression_ratio > 2.4: continue - res.append(Segment(s.text, s.start, s.end, + + # Create segment object + seg = Segment(s.text, s.start, s.end, self.collector.begin(), - s.avg_logprob, s.no_speech_prob, s.compression_ratio)) + s.avg_logprob, s.no_speech_prob, s.compression_ratio, + audio_len_s) + + # Check with ML model for "Thank you" hallucinations + if self.hallucination_filter.is_thank_you_hallucination(seg): + if self.cfg["enable_debug_mode"]: + log(f"Drop probable hallucination (case 4) " + + f"(text='{s.text}', " + + f"no_speech_prob={s.no_speech_prob}, " + + f"avg_logprob={s.avg_logprob})") + continue + + res.append(seg) t1 = time.time() if self.cfg["enable_debug_mode"]: - print(f"Transcription latency (s): {t1 - t0}") + log(f"Transcription latency (s): {t1 - t0}") return res def _build_prompt(self) -> str: @@ -580,64 +629,13 @@ class TranscriptCommit: def saveAudio(audio: bytes, path: str, cfg: typing.Dict): with wave.open(path, 'wb') as wf: if cfg["enable_debug_mode"]: - print(f"Saving audio to {path}", file=sys.stderr) + log(f"Saving audio to {path}") wf.setnchannels(AudioStream.CHANNELS) wf.setsampwidth(AudioStream.FRAME_SZ) wf.setframerate(AudioStream.FPS) wf.writeframes(audio) -class SegmentLogger: - def __init__(self, cfg: typing.Dict): - self.cfg = cfg - self.enabled = cfg.get("enable_segment_logging", False) - self.session_data = [] - self.log_file = None - - if self.enabled: - log_dir = os.path.join(PROJECT_ROOT, "logs") - if not os.path.exists(log_dir): - os.makedirs(log_dir) - - # Create file - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - self.log_file = os.path.join(log_dir, f"session_debug_{timestamp}.json") - print(f"Segment logging enabled. Logging to: {self.log_file}", file=sys.stderr) - - def log_segment(self, segment: Segment, commit_type: str = "commit"): - if not self.enabled: - return - - segment_data = { - "timestamp": datetime.now().isoformat(), - "type": commit_type, - "text": segment.transcript, - "start_ts": segment.start_ts, - "end_ts": segment.end_ts, - "wall_ts": segment.wall_ts, - "avg_logprob": segment.avg_logprob, - "no_speech_prob": segment.no_speech_prob, - "compression_ratio": segment.compression_ratio, - "duration": segment.end_ts - segment.start_ts - } - - self.session_data.append(segment_data) - - # Write to file incrementally - try: - with open(self.log_file, 'w') as f: - json.dump({ - "session_start": self.session_data[0]["timestamp"] if self.session_data else None, - "segments": self.session_data - }, f, indent=2) - except Exception as e: - print(f"Error writing segment log: {e}", file=sys.stderr) - - def close(self): - if self.enabled and self.session_data: - print(f"Session complete. Logged {len(self.session_data)} segments to {self.log_file}", file=sys.stderr) - - class VadCommitter: def __init__(self, cfg: typing.Dict, @@ -680,16 +678,12 @@ class VadCommitter: if delta.strip(): self.whisper.update_context(delta.strip()) - if self.segment_logger: - for s in segments: - self.segment_logger.log_segment(s, "commit") - audio = self.collector.getAudio() if self.cfg["enable_debug_mode"]: for s in segments: - print(f"commit segment: {s}", file=sys.stderr) + log(f"commit segment: {s}") if len(delta) > 0: - print(f"delta get: {delta}", file=sys.stderr) + log(f"delta get: {delta}") if self.cfg["save_audio"] and len(delta) > 0: ts = datetime.fromtimestamp(self.collector.now() - latency_s) @@ -705,10 +699,6 @@ class VadCommitter: segments = self.whisper.transcribe(audio) preview = "".join(s.transcript for s in segments) - if self.segment_logger: - for s in segments: - self.segment_logger.log_segment(s, "preview") - if not has_audio: self.collector.keepLast(1.0) @@ -758,13 +748,13 @@ class ProfanityPlugin(StreamingPlugin): def __init__(self, cfg): self.cfg = cfg self.filter = None - if PROFANITY_FILTER_AVAILABLE and cfg["enable_profanity_filter"]: + if cfg["enable_profanity_filter"]: en_profanity_path = os.path.join(PROJECT_ROOT, "Third_Party/Profanity/en") try: self.filter = ProfanityFilter(en_profanity_path) self.filter.load() except Exception as e: - print(f"Warning: Could not load profanity filter: {e}", file=sys.stderr) + log_err(f"Warning: Could not load profanity filter: {e}") self.filter = None def transform(self, commit: TranscriptCommit) -> TranscriptCommit: @@ -812,12 +802,11 @@ def transcriptionThread(shared_data: SharedThreadData): shared_data.cfg) collector = NoiseReducingAudioCollector(collector, shared_data.cfg) #collector = NormalizingAudioCollector(collector) - whisper = Whisper(collector, shared_data.cfg) + segment_logger = SegmentLogger(shared_data.cfg) + whisper = Whisper(collector, shared_data.cfg, segment_logger) segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"], max_speech_s=shared_data.cfg["max_speech_duration_s"], min_speech_duration_ms=shared_data.cfg["min_speech_duration_ms"]) - - segment_logger = SegmentLogger(shared_data.cfg) committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter, segment_logger) plugins = [] @@ -838,7 +827,7 @@ def transcriptionThread(shared_data: SharedThreadData): shared_data.stream = stream shared_data.collector = collector - print(f"Ready to go!", flush=True) + log(f"Ready to go!") while not shared_data.exit_event.is_set(): time.sleep(shared_data.cfg["transcription_loop_delay_ms"] / 1000.0); @@ -862,8 +851,8 @@ def transcriptionThread(shared_data: SharedThreadData): silence_duration = commit.start_ts - last_commit_end_ts if silence_duration > shared_data.cfg["reset_after_silence_s"]: if shared_data.cfg["enable_debug_mode"]: - print(f"Resetting transcript after {silence_duration}-second " - "silence", file=sys.stderr) + log(f"Resetting transcript after {silence_duration}-second " + "silence") shared_data.transcript = "" shared_data.preview = "" whisper.recent_context = "" # Reset context too @@ -885,20 +874,18 @@ def transcriptionThread(shared_data: SharedThreadData): shared_data.preview) try: - print(f"Transcript: {shared_data.transcript}", flush=True) + log(f"Transcript: {shared_data.transcript}") except UnicodeEncodeError: - print("Failed to encode transcript - discarding delta", - file=sys.stderr) + log_err("Failed to encode transcript - discarding delta") continue try: - print(f"Preview: {shared_data.preview}", flush=True) + log(f"Preview: {shared_data.preview}") except UnicodeEncodeError: - print("Failed to encode preview - discarding", file=sys.stderr) + log_err("Failed to encode preview - discarding") if shared_data.cfg["enable_debug_mode"]: - print(f"commit latency: {commit.latency_s}", file=sys.stderr) - print(f"commit thresh: {commit.thresh_at_commit}", - file=sys.stderr) + log(f"commit latency: {commit.latency_s}") + log(f"commit thresh: {commit.thresh_at_commit}") if len(shared_data.transcript) > 0 and \ (not shared_data.transcript.endswith(' ')) and \ |
