diff options
| author | yum <yum.food.vr@gmail.com> | 2025-07-25 21:28:50 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2025-07-25 21:28:50 -0700 |
| commit | a7f9b7b5fb33bead6bcfb0ad6867b57f2ddc42af (patch) | |
| tree | 61d4870a019acb0e545d88e7661c8a4c7d90e499 | |
| parent | 5df013d26eb13ed4aef8d16aa14346e0f9be5111 (diff) | |
Experiment with hallucination reduction
- update cursorignore
- add hallucination filter training & inference code
- put logging into a central module
- segment metadata logging occurs before filtering
- segment metadata logging is on by default
- check in embedded python setup script
- include trained hallucination filter model
| -rw-r--r-- | .cursorignore | 4 | ||||
| -rw-r--r-- | app/hallucination_filter.py | 66 | ||||
| -rw-r--r-- | app/hi.py | 65 | ||||
| -rw-r--r-- | app/logger.py | 12 | ||||
| -rw-r--r-- | app/requirements.txt | 3 | ||||
| -rw-r--r-- | app/stt.py | 267 | ||||
| -rw-r--r-- | train_hallucination_filter.py | 291 | ||||
| -rw-r--r-- | ui/build_scripts/setup-embedded-python.js | 104 | ||||
| -rw-r--r-- | ui/config-schema.js | 4 | ||||
| -rw-r--r-- | ui/package.json | 5 |
10 files changed, 642 insertions, 179 deletions
diff --git a/.cursorignore b/.cursorignore index a8f4624..bb76706 100644 --- a/.cursorignore +++ b/.cursorignore @@ -1,2 +1,4 @@ **/node_modules -**/site-packages
\ No newline at end of file +**/site-packages +venv +ui/dist
\ No newline at end of file diff --git a/app/hallucination_filter.py b/app/hallucination_filter.py new file mode 100644 index 0000000..9b24a85 --- /dev/null +++ b/app/hallucination_filter.py @@ -0,0 +1,66 @@ +import io +import joblib +from logger import log, log_err +import numpy as np +import pandas as pd +from pathlib import Path +import sys + + +class HallucinationFilter: + """Filter for detecting hallucinated segments in speech-to-text output.""" + + def __init__(self, model_path: Path = None): + """ + Initialize the hallucination filter. + + Args: + model_path: Optional path to the model file. If not provided, + uses the default path. + """ + self.model = None + self.threshold = None + self.features = None + + # Get the project root directory + app_root = Path(__file__).resolve().parent + project_root = app_root.parent + + # Use provided path or default + if model_path is None: + model_path = project_root / "Models" / "thankyou_filter_gb.pkl" + + # Try to load the model + log_err(f"Loading hallucination filter") + bundle = joblib.load(model_path) + self.model = bundle["model"] + self.threshold = bundle["threshold"] + self.features = bundle["features"] # Extract feature names + log_err(f"Loaded hallucination filter model from {model_path}") + + def is_thank_you_hallucination(self, segment) -> bool: + """ + Check if a segment is likely a "Thank you" hallucination. + Returns False if model is not available. + + Args: + segment: A segment object with attributes avg_logprob, audio_len_s, + no_speech_prob, and compression_ratio. + + Returns: + bool: True if the segment is likely a hallucination, False otherwise. 
+ """ + # Create DataFrame with proper feature names + X = pd.DataFrame([[ + segment.avg_logprob, + segment.audio_len_s, + segment.no_speech_prob, + segment.compression_ratio, + np.log1p(segment.audio_len_s), + segment.avg_logprob * segment.audio_len_s + ]], columns=self.features) + + # Get probability + prob = self.model.predict_proba(X)[0, 1] + return prob >= self.threshold + @@ -2,9 +2,11 @@ import app_config import argparse import io import keybind_event_machine +from logger import log, log_err from math import floor, ceil import msvcrt import os +import pygame from pythonosc import udp_client import sentencepiece as spm import steamvr @@ -13,10 +15,6 @@ import stt import sys import threading import time -import pygame - -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') -sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') # Initialize pygame mixer pygame.mixer.init() @@ -31,10 +29,10 @@ PROJECT_ROOT = os.path.dirname(APP_ROOT) def get_tokenizer(): model_path = os.path.join(PROJECT_ROOT, "custom_unigram_tokenizer_65k", "unigram.model") - print(f"Loading SentencePiece tokenizer from: {model_path}") + log(f"Loading SentencePiece tokenizer from: {model_path}") sp = spm.SentencePieceProcessor() sp.load(model_path) - print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}") + log(f"Successfully loaded SentencePiece model. 
Vocab size: {sp.get_piece_size()}") return sp def parse_args(): @@ -137,16 +135,16 @@ def wrap_line(line: str, cols): def get_blocks(lines, tokenizer, block_width, num_blocks): if LOG_LEVEL == 2: - print(f"Lines sent to tokenizer: {''.join(lines)}") + log(f"Lines sent to tokenizer: {''.join(lines)}") tokens = tokenizer.encode_as_ids(''.join(lines)) if LOG_LEVEL == 2: - print(f"Tokens: {tokens}") + log(f"Tokens: {tokens}") pieces = [] for tok in tokens: piece = tokenizer.id_to_piece(tok) pieces.append(piece) if LOG_LEVEL == 2: - print(f"Pieces: {pieces}") + log(f"Pieces: {pieces}") # Group tokens into blocks and pad with empty characters. # Also get visual pointers - the location where each block will be rendered. @@ -168,8 +166,8 @@ def get_blocks(lines, tokenizer, block_width, num_blocks): return (blocks, visual_pointers) blocks, visual_pointers = get_blocks() if LOG_LEVEL == 2: - print(f"Blocks: {blocks}") - print(f"Visual pointers: {visual_pointers}") + log(f"Blocks: {blocks}") + log(f"Visual pointers: {visual_pointers}") # Set all blocks up to the next `num_blocks` boundary to blank tokens. 
# This handles the edge case where a prior message wrote data there which @@ -183,8 +181,8 @@ def get_blocks(lines, tokenizer, block_width, num_blocks): return blocks, visual_pointers blocks, visual_pointers = pad_blocks(blocks, visual_pointers) if LOG_LEVEL == 2: - print(f"Blocks (padded): {blocks}") - print(f"Visual pointers (padded): {visual_pointers}") + log(f"Blocks (padded): {blocks}") + log(f"Visual pointers (padded): {visual_pointers}") return blocks, visual_pointers @@ -223,11 +221,10 @@ def send_data(osc_client, indices, blocks, visual_pointers): blocks_byte00, blocks_byte01 = split_blocks_by_byte(blocks) if LOG_LEVEL == 2: - print(f"Blocks (byte 00): {blocks_byte00}") - print(f"Blocks (byte 01): {blocks_byte01}") + log(f"Blocks (byte 00): {blocks_byte00}") + log(f"Blocks (byte 01): {blocks_byte01}") def send_osc(osc_client, addr, data): - #print(f"Sending {data} to {addr}") osc_client.send_message(addr, data) for i in range(0, len(blocks)): @@ -241,7 +238,7 @@ def send_data(osc_client, indices, blocks, visual_pointers): addr = "/avatar/parameters/" + vp_param send_osc(osc_client, addr, vp_float) if LOG_LEVEL == 2: - print(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}") + log(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}") for j in range(0, len(blocks[i])): byte00_float = (-127.5 + blocks_byte00[i][j]) / 127.5 byte01_float = (-127.5 + blocks_byte01[i][j]) / 127.5 @@ -271,7 +268,7 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg): for line in line_wrapped: assert_equal(len(line), cfg["cols"]) if LOG_LEVEL == 2: - print(f"Wrapped lines: {line_wrapped}") + log(f"Wrapped lines: {line_wrapped}") # Get several blank lines whenever we roll over. 
# It's better for the reader to have some continuity when the board pages @@ -312,7 +309,7 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg): state.blocks.append(diff_blocks[0]) state.visual_pointers.append(diff_visual_pointers[0]) elif indices[0] > len(state.blocks): - print(f"This should never happen!") + log(f"This should never happen!") sys.exit(1) else: state.blocks[indices[0]] = diff_blocks[0] @@ -345,7 +342,7 @@ def osc_thread(shared_data: SharedThreadData): continue addr = "/chatbox/input" if shared_data.cfg["enable_debug_mode"]: - print(f"Send {local_word}", flush=True) + log(f"Send {local_word}") osc_client.send_message(addr, (local_word, True, False)) last_change = time.time() remote_word = local_word @@ -354,7 +351,7 @@ def osc_thread(shared_data: SharedThreadData): tokenizer = get_tokenizer() # Prime the board - print("Priming the board") + log("Priming the board") input_state = InputState() handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg) @@ -424,7 +421,7 @@ def vrInputThread(shared_data: SharedThreadData): elif now - last_rising > 0.5: # Medium press - print("CLEARING", file=sys.stderr) + log_err("CLEARING") last_medium_press_end = now state = PAUSE_STATE play_sound_with_volume(waveform2, shared_data.cfg) @@ -439,25 +436,23 @@ def vrInputThread(shared_data: SharedThreadData): # Short hold if state == RECORD_STATE: - print("PAUSED", file=sys.stderr) + log_err("PAUSED") state = PAUSE_STATE shared_data.stream.pause(True) play_sound_with_volume(waveform1, shared_data.cfg) elif state == PAUSE_STATE: - print("RECORDING", file=sys.stderr) + log_err("RECORDING") state = RECORD_STATE if shared_data.cfg["reset_on_toggle"]: if shared_data.cfg["enable_debug_mode"]: - print("Toggle detected, dropping transcript (3)", - file=sys.stderr) + log_err("Toggle detected, dropping transcript (3)") shared_data.transcript = "" shared_data.preview = "" #audio_state.drop_transcription = True else: if 
shared_data.cfg["enable_debug_mode"]: - print("Toggle detected, committing preview text (3)", - file=sys.stderr) + log_err("Toggle detected, committing preview text (3)") #audio_state.text += audio_state.preview_text shared_data.stream.pause(False) @@ -502,7 +497,7 @@ def kbInputThread(shared_data: SharedThreadData): last_press_time = cur_press_time if event == EVENT_DOUBLE_PRESS: - print("CLEARING", file=sys.stderr) + log_err("CLEARING") state = PAUSE_STATE play_sound_with_volume(waveform2, shared_data.cfg) @@ -516,23 +511,21 @@ def kbInputThread(shared_data: SharedThreadData): # Short hold if state == RECORD_STATE: - print("PAUSED", file=sys.stderr) + log_err("PAUSED") state = PAUSE_STATE shared_data.stream.pause(True) play_sound_with_volume(waveform1, shared_data.cfg) elif state == PAUSE_STATE: - print("RECORDING", file=sys.stderr) + log_err("RECORDING") state = RECORD_STATE if shared_data.cfg["reset_on_toggle"]: if shared_data.cfg["enable_debug_mode"]: - print("Toggle detected, dropping transcript (2)", - file=sys.stderr) + log_err("Toggle detected, dropping transcript (2)") shared_data.transcript = "" shared_data.preview = "" else: if shared_data.cfg["enable_debug_mode"]: - print("Toggle detected, committing preview text (2)", - file=sys.stderr) + log_err("Toggle detected, committing preview text (2)") shared_data.stream.pause(False) play_sound_with_volume(waveform0, shared_data.cfg) @@ -545,7 +538,7 @@ def play_sound_with_volume(filepath, cfg): sound.set_volume(volume * 0.01) sound.play() except Exception as e: - print(f"Error playing sound {filepath}: {e}", file=sys.stderr) + log_err(f"Error playing sound {filepath}: {e}") if __name__ == "__main__": cli_args = parse_args() diff --git a/app/logger.py b/app/logger.py new file mode 100644 index 0000000..72a2134 --- /dev/null +++ b/app/logger.py @@ -0,0 +1,12 @@ +import sys +import io + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +sys.stderr = io.TextIOWrapper(sys.stderr.buffer, 
encoding='utf-8') + +def log(message): + print(message, file=sys.stdout, flush=True) + +def log_err(message): + print(message, file=sys.stderr, flush=True) + diff --git a/app/requirements.txt b/app/requirements.txt index c8d69df..dc294e5 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -3,10 +3,13 @@ hf-xet keyboard langcodes noisereduce +pandas pyaudio pygame pydub python-osc +scikit-learn sentencepiece silero-vad openvr +joblib @@ -2,15 +2,11 @@ from datetime import datetime from faster_whisper import WhisperModel import json import langcodes +from logger import log, log_err +import noisereduce as nr import numpy as np import os -import noisereduce as nr -try: - from profanity_filter import ProfanityFilter - PROFANITY_FILTER_AVAILABLE = True -except ImportError: - PROFANITY_FILTER_AVAILABLE = False - print("Warning: profanity_filter module not available", file=sys.stderr) +from profanity_filter import ProfanityFilter import pyaudio from pydub import AudioSegment from shared_thread_data import SharedThreadData @@ -19,6 +15,7 @@ import sys import time import typing import wave +from hallucination_filter import HallucinationFilter APP_ROOT = os.path.dirname(os.path.abspath(__file__)) PROJECT_ROOT = os.path.dirname(APP_ROOT) @@ -55,7 +52,7 @@ class MicStream(AudioStream): which_mic = cfg["microphone"] if cfg["enable_debug_mode"]: - print(f"Finding mic {which_mic}", file=sys.stderr) + log(f"Finding mic {which_mic}") self.dumpMicDevices() got_match = False @@ -70,8 +67,8 @@ class MicStream(AudioStream): target_str = "Microphone (Beyond)" else: if cfg["enable_debug_mode"]: - print(f"Mic {which_mic} requested, treating it as a numerical " + - "device ID", file=sys.stderr) + log(f"Mic {which_mic} requested, treating it as a numerical " + + "device ID") device_index = int(which_mic) got_match = True if not got_match: @@ -81,8 +78,7 @@ class MicStream(AudioStream): if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: 
device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') if target_str in device_name: - print(f"Got matching mic: {device_name}", - file=sys.stderr) + log(f"Got matching mic: {device_name}") device_index = i got_match = True break @@ -91,10 +87,10 @@ class MicStream(AudioStream): info = self.p.get_device_info_by_host_api_device_index(0, device_index) if cfg["enable_debug_mode"]: - print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr) + log(f"Found mic {which_mic}: {info['name']}") self.sample_rate = int(info['defaultSampleRate']) if cfg["enable_debug_mode"]: - print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr) + log(f"Mic sample rate: {self.sample_rate}") self.stream = self.p.open( rate=self.sample_rate, @@ -119,7 +115,7 @@ class MicStream(AudioStream): for i in range(0, numdevices): if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') - print("Input Device id ", i, " - ", device_name) + log("Input Device id ", i, " - ", device_name) def onAudioFramesAvailable(self, frames, @@ -278,7 +274,7 @@ class BoostingAudioCollector(AudioCollectorFilter): frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) gain = min(self.target_dBFS - audio.dBFS, self.max_gain_dB) if self.cfg["enable_debug_mode"]: - print(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})", flush=True) + log(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})") audio = audio.apply_gain(gain) frames = np.array(audio.get_array_of_samples()) @@ -335,7 +331,6 @@ class AudioSegmenter: # Load Silero VAD model self.model = load_silero_vad() - self.vad_threshold = 0.3 self.min_silence_duration_ms = min_silence_ms self.max_speech_duration_s = max_speech_s self.min_speech_duration_ms = min_speech_duration_ms @@ -351,7 +346,6 @@ class AudioSegmenter: audio_array, self.model, sampling_rate=AudioStream.FPS, - 
threshold=self.vad_threshold, min_silence_duration_ms=self.min_silence_duration_ms, max_speech_duration_s=self.max_speech_duration_s, min_speech_duration_ms=self.min_speech_duration_ms, @@ -371,12 +365,9 @@ class AudioSegmenter: for i in range(len(segments)): s = segments[i] - #print(f"s: {s}") - #print(f"last_end: {last_end}") if last_end: delta_frames = s['start'] - last_end - #print(f"delta frames: {delta_frames}") if delta_frames > min_delta_frames: cutoff = s['start'] else: @@ -384,8 +375,6 @@ class AudioSegmenter: if i == len(segments) - 1: now = int(len(audio) / AudioStream.FRAME_SZ) - #print(f"now: {now}") - #print(f"min d: {min_delta_frames}") delta_frames = now - s['end'] if delta_frames > min_delta_frames: cutoff = now - int(min_delta_frames / 2) @@ -402,7 +391,8 @@ class Segment: wall_ts: float, avg_logprob: float, no_speech_prob: float, - compression_ratio: float): + compression_ratio: float, + audio_len_s: float): self.transcript = transcript # start_ts, end_ts are timestamps in seconds relative to `wall_ts`. 
self.start_ts = start_ts @@ -413,6 +403,7 @@ class Segment: self.avg_logprob = avg_logprob self.no_speech_prob = no_speech_prob self.compression_ratio = compression_ratio + self.audio_len_s = audio_len_s def __str__(self): ts = f"(ts: {self.start_ts}-{self.end_ts}) " @@ -423,7 +414,8 @@ class Segment: no_speech = f"(no_speech: {self.no_speech_prob}) " avg_logprob = f"(avg_logprob: {self.avg_logprob}) " - return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob + max_len_s = f"(max_len_s: {self.audio_len_s}) " + return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob + max_len_s def join_segments(a, b): if len(a) > 0 and a[-1] != ' ': @@ -431,20 +423,76 @@ def join_segments(a, b): else: return a + b + +class SegmentLogger: + def __init__(self, cfg: typing.Dict): + self.cfg = cfg + self.enabled = cfg.get("enable_segment_logging", False) + self.session_data = [] + self.log_file = None + + if self.enabled: + log_dir = os.path.join(PROJECT_ROOT, "logs") + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Create file + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.log_file = os.path.join(log_dir, f"session_debug_{timestamp}.json") + log(f"Segment logging enabled. 
Logging to: {self.log_file}") + + def log_segment(self, segment: Segment, commit_type: str = "commit"): + if not self.enabled: + return + + segment_data = { + "timestamp": datetime.now().isoformat(), + "type": commit_type, + "text": segment.transcript, + "start_ts": segment.start_ts, + "end_ts": segment.end_ts, + "wall_ts": segment.wall_ts, + "avg_logprob": segment.avg_logprob, + "no_speech_prob": segment.no_speech_prob, + "compression_ratio": segment.compression_ratio, + "duration": segment.end_ts - segment.start_ts, + "duration_sanity": segment.audio_len_s + } + + self.session_data.append(segment_data) + + # Write to file incrementally + try: + with open(self.log_file, 'w') as f: + json.dump({ + "session_start": self.session_data[0]["timestamp"] if self.session_data else None, + "segments": self.session_data + }, f, indent=2) + except Exception as e: + log_err(f"Error writing segment log: {e}") + + def close(self): + if self.enabled and self.session_data: + log(f"Session complete. Logged {len(self.session_data)} " + \ + "segments to {self.log_file}") + + class Whisper: def __init__(self, collector: AudioCollector, - cfg: typing.Dict): + cfg: typing.Dict, + segment_logger: SegmentLogger = None): self.collector = collector self.model = None self.cfg = cfg + self.hallucination_filter = HallucinationFilter() + self.segment_logger = segment_logger model_str = cfg["model"] model_root = os.path.join(PROJECT_ROOT, "Models", os.path.normpath(model_str)) if cfg["enable_debug_mode"]: - print(f"Model {cfg['model']} will be saved to {model_root}", - file=sys.stderr) + log(f"Model {cfg['model']} will be saved to {model_root}") model_device = "cuda" compute_type = cfg["compute_type"] @@ -455,7 +503,7 @@ class Whisper: already_downloaded = os.path.exists(model_root) if not already_downloaded: - print(f"Model {model_str} not already downloaded, downloading now...", flush=True) + log(f"Model {model_str} not already downloaded, downloading now...") self.model = 
WhisperModel(model_str, device = model_device, @@ -483,19 +531,20 @@ class Whisper: # Convert audio to float32 audio = np.frombuffer(frames, dtype=np.int16).flatten().astype(np.float32) / 32768.0 + audio_len_s = len(frames) / 16000.0 # Build context-aware prompt prompt = self._build_prompt() if self.cfg["enable_debug_mode"]: - print(f"Prompt: {prompt}", flush=True) + log(f"Prompt: {prompt}") t0 = time.time() segments, info = self.model.transcribe( audio, language = langcodes.find(self.cfg["language"]).language, - vad_filter = True, - temperature=0.0, + vad_filter = False, + #temperature=0.0, without_timestamps = False, initial_prompt=prompt, beam_size=self.cfg.get("beam_size", 5), @@ -504,44 +553,44 @@ class Whisper: ) res = [] for s in segments: - # Manual touchup. I see a decent number of hallucinations sneaking - # in with high `no_speech_prob` and modest `avg_logprob`. - if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5: - if self.cfg["enable_debug_mode"]: - print(f"Drop probable hallucination (case 1) " + - f"(text='{s.text}', " + - f"no_speech_prob={s.no_speech_prob}, " + - f"avg_logprob={s.avg_logprob})", file=sys.stderr) - continue - # Another touchup targeted at the vexatious "thanks for watching!" - # hallucination. This triggers a lot when listening to - # instrumental/electronic music. 
- if s.no_speech_prob > 0.15 and s.avg_logprob < -0.7: - if self.cfg["enable_debug_mode"]: - print(f"Drop probable hallucination (case 2) " + - f"(text='{s.text}', " + - f"no_speech_prob={s.no_speech_prob}, " + - f"avg_logprob={s.avg_logprob})", file=sys.stderr) - continue - if s.avg_logprob < -0.75: - if self.cfg["enable_debug_mode"]: - print(f"Drop probable hallucination (case 3) " + - f"(text='{s.text}', " + - f"no_speech_prob={s.no_speech_prob}, " + - f"avg_logprob={s.avg_logprob})", file=sys.stderr) - continue + # Log raw segment before filtering + if self.segment_logger: + # Create a temporary segment object for logging + raw_seg = Segment(s.text, s.start, s.end, + self.collector.begin(), + s.avg_logprob, s.no_speech_prob, s.compression_ratio, + audio_len_s) + self.segment_logger.log_segment(raw_seg, "raw") + # Sometimes the model reports a bum duration, breaking our filters. + # Cap the segment length above by the length of the audio in. + duration_s = min(s.end - s.start, audio_len_s) + if self.cfg["enable_debug_mode"]: - print(f"s get: {s}") + log(f"s get: {s}") if s.avg_logprob < -1.0: continue if s.compression_ratio > 2.4: continue - res.append(Segment(s.text, s.start, s.end, + + # Create segment object + seg = Segment(s.text, s.start, s.end, self.collector.begin(), - s.avg_logprob, s.no_speech_prob, s.compression_ratio)) + s.avg_logprob, s.no_speech_prob, s.compression_ratio, + audio_len_s) + + # Check with ML model for "Thank you" hallucinations + if self.hallucination_filter.is_thank_you_hallucination(seg): + if self.cfg["enable_debug_mode"]: + log(f"Drop probable hallucination (case 4) " + + f"(text='{s.text}', " + + f"no_speech_prob={s.no_speech_prob}, " + + f"avg_logprob={s.avg_logprob})") + continue + + res.append(seg) t1 = time.time() if self.cfg["enable_debug_mode"]: - print(f"Transcription latency (s): {t1 - t0}") + log(f"Transcription latency (s): {t1 - t0}") return res def _build_prompt(self) -> str: @@ -580,64 +629,13 @@ class 
TranscriptCommit: def saveAudio(audio: bytes, path: str, cfg: typing.Dict): with wave.open(path, 'wb') as wf: if cfg["enable_debug_mode"]: - print(f"Saving audio to {path}", file=sys.stderr) + log(f"Saving audio to {path}") wf.setnchannels(AudioStream.CHANNELS) wf.setsampwidth(AudioStream.FRAME_SZ) wf.setframerate(AudioStream.FPS) wf.writeframes(audio) -class SegmentLogger: - def __init__(self, cfg: typing.Dict): - self.cfg = cfg - self.enabled = cfg.get("enable_segment_logging", False) - self.session_data = [] - self.log_file = None - - if self.enabled: - log_dir = os.path.join(PROJECT_ROOT, "logs") - if not os.path.exists(log_dir): - os.makedirs(log_dir) - - # Create file - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - self.log_file = os.path.join(log_dir, f"session_debug_{timestamp}.json") - print(f"Segment logging enabled. Logging to: {self.log_file}", file=sys.stderr) - - def log_segment(self, segment: Segment, commit_type: str = "commit"): - if not self.enabled: - return - - segment_data = { - "timestamp": datetime.now().isoformat(), - "type": commit_type, - "text": segment.transcript, - "start_ts": segment.start_ts, - "end_ts": segment.end_ts, - "wall_ts": segment.wall_ts, - "avg_logprob": segment.avg_logprob, - "no_speech_prob": segment.no_speech_prob, - "compression_ratio": segment.compression_ratio, - "duration": segment.end_ts - segment.start_ts - } - - self.session_data.append(segment_data) - - # Write to file incrementally - try: - with open(self.log_file, 'w') as f: - json.dump({ - "session_start": self.session_data[0]["timestamp"] if self.session_data else None, - "segments": self.session_data - }, f, indent=2) - except Exception as e: - print(f"Error writing segment log: {e}", file=sys.stderr) - - def close(self): - if self.enabled and self.session_data: - print(f"Session complete. 
Logged {len(self.session_data)} segments to {self.log_file}", file=sys.stderr) - - class VadCommitter: def __init__(self, cfg: typing.Dict, @@ -680,16 +678,12 @@ class VadCommitter: if delta.strip(): self.whisper.update_context(delta.strip()) - if self.segment_logger: - for s in segments: - self.segment_logger.log_segment(s, "commit") - audio = self.collector.getAudio() if self.cfg["enable_debug_mode"]: for s in segments: - print(f"commit segment: {s}", file=sys.stderr) + log(f"commit segment: {s}") if len(delta) > 0: - print(f"delta get: {delta}", file=sys.stderr) + log(f"delta get: {delta}") if self.cfg["save_audio"] and len(delta) > 0: ts = datetime.fromtimestamp(self.collector.now() - latency_s) @@ -705,10 +699,6 @@ class VadCommitter: segments = self.whisper.transcribe(audio) preview = "".join(s.transcript for s in segments) - if self.segment_logger: - for s in segments: - self.segment_logger.log_segment(s, "preview") - if not has_audio: self.collector.keepLast(1.0) @@ -758,13 +748,13 @@ class ProfanityPlugin(StreamingPlugin): def __init__(self, cfg): self.cfg = cfg self.filter = None - if PROFANITY_FILTER_AVAILABLE and cfg["enable_profanity_filter"]: + if cfg["enable_profanity_filter"]: en_profanity_path = os.path.join(PROJECT_ROOT, "Third_Party/Profanity/en") try: self.filter = ProfanityFilter(en_profanity_path) self.filter.load() except Exception as e: - print(f"Warning: Could not load profanity filter: {e}", file=sys.stderr) + log_err(f"Warning: Could not load profanity filter: {e}") self.filter = None def transform(self, commit: TranscriptCommit) -> TranscriptCommit: @@ -812,12 +802,11 @@ def transcriptionThread(shared_data: SharedThreadData): shared_data.cfg) collector = NoiseReducingAudioCollector(collector, shared_data.cfg) #collector = NormalizingAudioCollector(collector) - whisper = Whisper(collector, shared_data.cfg) + segment_logger = SegmentLogger(shared_data.cfg) + whisper = Whisper(collector, shared_data.cfg, segment_logger) segmenter = 
AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"], max_speech_s=shared_data.cfg["max_speech_duration_s"], min_speech_duration_ms=shared_data.cfg["min_speech_duration_ms"]) - - segment_logger = SegmentLogger(shared_data.cfg) committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter, segment_logger) plugins = [] @@ -838,7 +827,7 @@ def transcriptionThread(shared_data: SharedThreadData): shared_data.stream = stream shared_data.collector = collector - print(f"Ready to go!", flush=True) + log(f"Ready to go!") while not shared_data.exit_event.is_set(): time.sleep(shared_data.cfg["transcription_loop_delay_ms"] / 1000.0); @@ -862,8 +851,8 @@ def transcriptionThread(shared_data: SharedThreadData): silence_duration = commit.start_ts - last_commit_end_ts if silence_duration > shared_data.cfg["reset_after_silence_s"]: if shared_data.cfg["enable_debug_mode"]: - print(f"Resetting transcript after {silence_duration}-second " - "silence", file=sys.stderr) + log(f"Resetting transcript after {silence_duration}-second " + "silence") shared_data.transcript = "" shared_data.preview = "" whisper.recent_context = "" # Reset context too @@ -885,20 +874,18 @@ def transcriptionThread(shared_data: SharedThreadData): shared_data.preview) try: - print(f"Transcript: {shared_data.transcript}", flush=True) + log(f"Transcript: {shared_data.transcript}") except UnicodeEncodeError: - print("Failed to encode transcript - discarding delta", - file=sys.stderr) + log_err("Failed to encode transcript - discarding delta") continue try: - print(f"Preview: {shared_data.preview}", flush=True) + log(f"Preview: {shared_data.preview}") except UnicodeEncodeError: - print("Failed to encode preview - discarding", file=sys.stderr) + log_err("Failed to encode preview - discarding") if shared_data.cfg["enable_debug_mode"]: - print(f"commit latency: {commit.latency_s}", file=sys.stderr) - print(f"commit thresh: {commit.thresh_at_commit}", - file=sys.stderr) + log(f"commit latency: 
{commit.latency_s}") + log(f"commit thresh: {commit.thresh_at_commit}") if len(shared_data.transcript) > 0 and \ (not shared_data.transcript.endswith(' ')) and \ diff --git a/train_hallucination_filter.py b/train_hallucination_filter.py new file mode 100644 index 0000000..446e893 --- /dev/null +++ b/train_hallucination_filter.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 +import json +import os +import re +from pathlib import Path +import numpy as np +import pandas as pd +from sklearn.ensemble import GradientBoostingClassifier +from sklearn.neighbors import KNeighborsClassifier +from sklearn.preprocessing import StandardScaler +from sklearn.pipeline import Pipeline +from sklearn.model_selection import train_test_split +from sklearn.metrics import classification_report, confusion_matrix +import joblib +import warnings +warnings.filterwarnings('ignore') + +try: + import pronouncing + HAS_PRONOUNCING = True +except ImportError: + HAS_PRONOUNCING = False + print("Warning: pronouncing library not found. 
def count_syllables(word):
    """Return the number of syllables in a single word (always at least 1).

    When the module-level ``HAS_PRONOUNCING`` flag is set and the
    ``pronouncing`` library knows the word, the count of its first listed
    pronunciation is used. Otherwise runs of vowel letters are counted.

    Args:
        word: The word to analyse.

    Returns:
        int: Syllable count.
    """
    if HAS_PRONOUNCING:
        phones = pronouncing.phones_for_word(word.lower())
        if phones:
            return pronouncing.syllable_count(phones[0])

    # Fallback heuristic: each run of consecutive vowels (incl. 'y') counts once.
    vowel_groups = re.findall(r'[aeiouy]+', word, re.IGNORECASE)
    return max(1, len(vowel_groups))


def text_syllable_count(text):
    """Return the total syllable count over every word in ``text``.

    Words are the ``\\b\\w+\\b`` regex matches; empty text yields 0.
    """
    words = re.findall(r'\b\w+\b', text)
    return sum(count_syllables(word) for word in words)


def _segment_features(segment):
    """Build the feature row for one logged segment.

    Args:
        segment: One entry of a log file's ``segments`` list (a dict).

    Returns:
        dict | None: The feature row, or None when the segment cannot be used:
        it has no ``duration_sanity`` key, or either duration is not positive.
    """
    if 'duration_sanity' not in segment:
        return None

    text = segment.get('text', '')
    duration = segment.get('duration_sanity', 0)
    # Raw duration comes from the timestamps rather than the logged duration.
    raw_duration = segment.get('end_ts', 0) - segment.get('start_ts', 0)

    # Both durations are divisors below and inputs to log1p; a zero would raise
    # ZeroDivisionError and a negative one would yield NaN/-inf features.
    if duration <= 0 or raw_duration <= 0:
        return None

    avg_logprob = segment.get('avg_logprob', 0)
    n_syllables = text_syllable_count(text)
    sps = n_syllables / duration          # syllables per second
    raw_sps = n_syllables / raw_duration  # same, on timestamp-based duration

    return {
        'avg_logprob': avg_logprob,
        'no_speech_prob': segment.get('no_speech_prob', 0),
        'duration_sanity': duration,
        'raw_duration': raw_duration,
        'compression_ratio': segment.get('compression_ratio', 1),
        'text': text,
        # Speech rate features
        'n_words': len(re.findall(r'\b\w+\b', text)),
        'n_syllables': n_syllables,
        'n_chars': len(text),
        'sps': sps,
        'raw_sps': raw_sps,
        # Derived features
        'log_duration': np.log1p(duration),
        'logprob_duration_interaction': avg_logprob * duration,
        'log_sps': np.log1p(sps),
        'log_raw_duration': np.log1p(raw_duration),
        'duration_ratio': raw_duration / duration,
        'raw_log_sps': np.log1p(raw_sps),
    }


def load_segments(log_dir):
    """Load segment feature rows from every ``.json`` file under ``log_dir``.

    The directory is walked recursively. Loading is best-effort: unreadable or
    unparsable files and malformed or unusable segments are skipped, so one
    bad segment no longer discards the remainder of its file.

    Args:
        log_dir: Directory containing the JSON segment logs.

    Returns:
        pandas.DataFrame: One row per usable segment (empty if none).
    """
    rows = []

    for root, _dirs, files in os.walk(log_dir):
        for file in files:
            if not file.endswith('.json'):
                continue

            try:
                with open(os.path.join(root, file), 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers JSONDecodeError and UnicodeDecodeError.
                print(f"Error loading {file}: {e}")
                continue

            if not isinstance(data, dict):
                print(f"Error loading {file}: top-level JSON value is not an object")
                continue

            for segment in data.get('segments') or []:
                try:
                    row = _segment_features(segment)
                except (TypeError, AttributeError) as e:
                    # Non-dict segment or non-numeric field.
                    print(f"Error loading {file}: skipping malformed segment ({e})")
                    continue
                if row is not None:
                    rows.append(row)

    return pd.DataFrame(rows)


def main():
    """Train the hallucination filter from logged segments and save it.

    Steps:
      1. Locate a logs directory and load segment features.
      2. Seed-label segments with conservative heuristics.
      3. Fit a scaled k-NN on the seeds and propagate labels to all segments.
      4. Fit a gradient boosting classifier on the propagated labels, report
         its metrics, and dump the model bundle under ``Models/``.

    Returns early (after printing a message) when there are no logs, no
    segments, or too few labels to train on.
    """
    # Find logs directory
    log_dir = None
    for pattern in ["ui/dist/logs", "logs", "ui/dist/*/logs", "ui/dist/*/*/logs", "ui/dist/*/*/*/logs"]:
        # Sorted so the pick is deterministic when several directories match.
        paths = sorted(Path(".").glob(pattern))
        if paths:
            log_dir = str(paths[0])
            break

    if not log_dir:
        print("Could not find logs directory.")
        return

    # Load data
    print("Loading segments from logs...")
    df = load_segments(log_dir)

    if len(df) == 0:
        print("No segments found in logs!")
        return

    print(f"Loaded {len(df)} segments")

    # Print speech rate statistics
    print("\nSpeech rate statistics:")
    print(f"Syllables per second: mean={df['sps'].mean():.2f}, std={df['sps'].std():.2f}, max={df['sps'].max():.2f}")
    print(f"Raw syllables per second: mean={df['raw_sps'].mean():.2f}, std={df['raw_sps'].std():.2f}, max={df['raw_sps'].max():.2f}")
    print(f"Duration ratio (raw/sanity): mean={df['duration_ratio'].mean():.2f}, std={df['duration_ratio'].std():.2f}")

    # Step 1: Apply heuristic rules for seed labeling
    print("\nApplying heuristic rules for seed labeling...")

    # Conservative positive seeds (likely hallucinations)
    h_pos = (
        (df['avg_logprob'] < -0.8)           # This low of a logprob is almost always a hallucination
        | (df['compression_ratio'] > 2.3)    # High compressibility is usually a hallucination
        | (df['sps'] > 6)                    # No one speaks this fast
        | (df['sps'] < 0.5)                  # No one speaks this slow
    )

    # Conservative negative seeds (likely valid)
    h_neg = (
        (df['avg_logprob'] > -0.5)           # reasonably confident decode
        & (df['compression_ratio'] < 1.2)
        & (df['sps'] < 6)
        & (df['sps'] > 0.5)
    )

    # Create seed labels (NaN for unlabeled)
    df['seed_label'] = np.where(h_pos, 1,
                                np.where(h_neg, 0, np.nan))

    n_pos_seeds = (df['seed_label'] == 1).sum()
    n_neg_seeds = (df['seed_label'] == 0).sum()
    n_unlabeled = df['seed_label'].isna().sum()

    print("Seed labeling results:")
    print(f"  Positive seeds (hallucinations): {n_pos_seeds} ({n_pos_seeds/len(df):.1%})")
    print(f"  Negative seeds (valid): {n_neg_seeds} ({n_neg_seeds/len(df):.1%})")
    print(f"  Unlabeled: {n_unlabeled} ({n_unlabeled/len(df):.1%})")

    if n_pos_seeds == 0 or n_neg_seeds == 0:
        print("Warning: Not enough seed labels. Adjusting thresholds might help.")
        return

    # Show examples of positive seeds
    pos_seeds = df[df['seed_label'] == 1].head(5)
    if len(pos_seeds) > 0:
        print("\nExample positive seeds (likely hallucinations):")
        for _, seg in pos_seeds.iterrows():
            print(f"  SPS={seg['sps']:.1f}, logprob={seg['avg_logprob']:.2f}, text='{seg['text'][:50]}...'")

    # Define features
    features = ['avg_logprob', 'duration_sanity', 'no_speech_prob',
                'compression_ratio', 'log_duration', 'logprob_duration_interaction',
                'sps', 'log_sps', 'raw_duration', 'log_raw_duration',
                'duration_ratio', 'raw_log_sps']

    X = df[features].values

    # Step 2: Train kNN on seed labels
    print("\nTraining k-NN classifier on seed labels...")

    labeled_mask = df['seed_label'].notna()
    X_seed = X[labeled_mask]
    y_seed = df.loc[labeled_mask, 'seed_label'].values.astype(int)

    # Create pipeline with scaling (important for kNN)
    knn_pipeline = Pipeline([
        ('scale', StandardScaler()),
        ('knn', KNeighborsClassifier(
            # k may not exceed the number of fitted samples, or predict fails.
            n_neighbors=min(15, len(X_seed)),
            weights='distance'  # closer neighbors weigh more
        ))
    ])

    knn_pipeline.fit(X_seed, y_seed)

    # Step 3: Predict on all data
    print("Propagating labels to unlabeled segments...")

    # Get probabilities for all segments
    proba = knn_pipeline.predict_proba(X)[:, 1]  # probability of being hallucination
    df['knn_score'] = proba

    # Choose threshold - use 95th percentile of negative seeds
    neg_seed_scores = proba[df['seed_label'] == 0]
    threshold = max(0.05, np.percentile(neg_seed_scores, 95))

    print(f"\nChosen threshold: {threshold:.3f}")
    print("Based on 95th percentile of negative seed scores")

    # Apply threshold
    df['is_hallucination'] = (proba >= threshold).astype(int)

    # Print results
    n_hallucinations = df['is_hallucination'].sum()
    print(f"\nDetected hallucinations: {n_hallucinations} ({n_hallucinations/len(df):.1%})")

    # Step 4: Train final gradient boosting model on kNN labels
    print("\nTraining final Gradient Boosting classifier...")

    X_final = df[features]
    y_final = df['is_hallucination']

    # A stratified split and the classifier both need each class represented
    # (at least twice for stratification); bail out instead of crashing.
    class_counts = y_final.value_counts()
    if len(class_counts) < 2 or class_counts.min() < 2:
        print("Warning: Not enough examples of each class to train the final model.")
        return

    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X_final, y_final, test_size=0.3, stratify=y_final, random_state=42
    )

    # Train model
    model = GradientBoostingClassifier(
        n_estimators=80,
        max_depth=3,
        learning_rate=0.05,
        random_state=42
    )
    model.fit(X_train, y_train)

    # Evaluate
    y_pred = model.predict(X_test)

    print("\nFinal Model Performance:")
    print(classification_report(y_test, y_pred))

    # Confusion matrix; fixing the labels guarantees a 2x2 result even when
    # one class is absent from the test split or the predictions.
    tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=[0, 1]).ravel()
    tpr = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0

    print(f"\nDetection rate (TPR): {tpr:.1%}")
    print(f"False positive rate (FPR): {fpr:.1%}")

    # Feature importance
    print("\nFeature Importance:")
    for feat, imp in sorted(zip(features, model.feature_importances_),
                            key=lambda x: x[1], reverse=True):
        print(f"  {feat}: {imp:.3f}")

    # Show example detections
    hallucination_examples = df[df['is_hallucination'] == 1].head(10)
    if len(hallucination_examples) > 0:
        print("\nExample detected hallucinations:")
        for _, seg in hallucination_examples.iterrows():
            print(f"  Score={seg['knn_score']:.3f}, SPS={seg['sps']:.1f}, text='{seg['text'][:60]}...'")

    # Save model
    model_dir = Path("Models")
    model_dir.mkdir(exist_ok=True)

    model_bundle = {
        "model": model,
        "threshold": 0.5,  # Using standard threshold since we trained on binary labels
        "features": features,
        # NOTE(review): these values differ from the seed-labeling thresholds
        # applied above (-0.8 / 2.3 / 6) — confirm which set consumers expect.
        "heuristic_thresholds": {
            "avg_logprob_high": -1.0,
            "compression_ratio_high": 2.4,
            "sps_high": 9.0
        }
    }

    # NOTE(review): app/hallucination_filter.py defaults to loading
    # Models/thankyou_filter_gb.pkl; confirm this output name is intended.
    output_path = model_dir / "hallucination_filter_gb.pkl"
    joblib.dump(model_bundle, output_path)
    print(f"\nModel saved to: {output_path}")


if __name__ == "__main__":
    main()
/**
 * Download `url` to the file `dest`, following HTTP redirects.
 *
 * The destination file is only created once a 200 response arrives, so a
 * redirect or an error response never leaves a stray or bogus file behind.
 * On a stream failure the partial file is removed before rejecting.
 *
 * @param {string} url - HTTPS URL to fetch.
 * @param {string} dest - Path of the file to write.
 * @param {number} [redirectsLeft=5] - Remaining redirects to follow.
 * @returns {Promise<void>} Resolves once the file is fully written and closed.
 */
async function downloadFile(url, dest, redirectsLeft = 5) {
  console.log(`Downloading ${url}...`);

  return new Promise((resolve, reject) => {
    https.get(url, (response) => {
      const { statusCode, headers } = response;

      if ([301, 302, 303, 307, 308].includes(statusCode) && headers.location) {
        // Handle redirect: drain this response and retry at the new location.
        response.resume();
        if (redirectsLeft <= 0) {
          reject(new Error(`Too many redirects while downloading ${url}`));
          return;
        }
        // Location may be relative; resolve it against the current URL.
        const nextUrl = new URL(headers.location, url).toString();
        downloadFile(nextUrl, dest, redirectsLeft - 1).then(resolve, reject);
        return;
      }

      if (statusCode !== 200) {
        response.resume();
        reject(new Error(`Failed to download ${url}: HTTP ${statusCode}`));
        return;
      }

      const file = fs.createWriteStream(dest);
      const fail = (err) => {
        file.destroy();
        fs.unlink(dest, () => reject(err)); // Delete the partial file on error
      };

      response.on('error', fail);
      file.on('error', fail);
      file.on('finish', () => {
        file.close(() => {
          console.log(`Downloaded to ${dest}`);
          resolve();
        });
      });
      response.pipe(file);
    }).on('error', (err) => {
      fs.unlink(dest, () => {}); // Delete the file on error
      reject(err);
    });
  });
}

/**
 * Create a fresh embedded Python under `pythonPath`: wipe and recreate the
 * Python and dll directories, download (if not already present) and extract
 * the embeddable zip, extend `python310._pth`, then bootstrap pip.
 *
 * Exits the process with code 1 if the pip installation fails.
 *
 * @returns {Promise<void>}
 */
async function setupEmbeddedPython() {
  console.log('Setting up embedded Python...');

  // Delete existing directories
  if (fs.existsSync(pythonPath)) {
    fs.rmSync(pythonPath, { recursive: true, force: true });
    console.log('Deleted existing Python directory');
  }
  if (fs.existsSync(dllPath)) {
    fs.rmSync(dllPath, { recursive: true, force: true });
    console.log('Deleted existing dll directory');
  }

  // Create directories
  fs.mkdirSync(pythonPath, { recursive: true });
  fs.mkdirSync(dllPath, { recursive: true });
  console.log('Created Python and dll directories');

  // Download Python (an existing zip in the project root is reused)
  const pythonZip = path.join(projectRoot, 'python-3.10.11-embed-amd64.zip');
  if (!fs.existsSync(pythonZip)) {
    await downloadFile(PYTHON_URL, pythonZip);
  }

  // Extract Python
  console.log('Extracting Python...');
  await extract(pythonZip, { dir: pythonPath });
  console.log('Python extracted successfully');

  // Update python310._pth to include the app directory and enable site packages
  const pthFile = path.join(pythonPath, 'python310._pth');
  const pthContent = fs.readFileSync(pthFile, 'utf8');
  fs.writeFileSync(pthFile, pthContent + '\n../app\nimport site\n');
  console.log('Updated python310._pth');

  // Download get-pip.py
  const getPipPath = path.join(pythonPath, 'get-pip.py');
  await downloadFile(PIP_URL, getPipPath);

  // Install pip
  console.log('Installing pip...');
  try {
    execSync(`"${path.join(pythonPath, 'python.exe')}" "${getPipPath}"`, {
      stdio: 'inherit',
      cwd: pythonPath
    });
    console.log('pip installed successfully');
  } catch (error) {
    console.error('Failed to install pip:', error);
    process.exit(1);
  }

  // Clean up
  fs.unlinkSync(getPipPath);

  console.log('Embedded Python setup complete!');
}

// Run the setup
+setupEmbeddedPython().catch(err => { + console.error('Setup failed:', err); + process.exit(1); +}); + diff --git a/ui/config-schema.js b/ui/config-schema.js index fb90f3f..39b74b6 100644 --- a/ui/config-schema.js +++ b/ui/config-schema.js @@ -29,7 +29,7 @@ const CONFIG_SCHEMA = { enable_debug_mode: { type: 'boolean', default: 0 }, enable_previews: { type: 'boolean', default: 1 }, save_audio: { type: 'boolean', default: 0 }, - enable_segment_logging: { type: 'boolean', default: 0 }, + enable_segment_logging: { type: 'boolean', default: 1 }, use_cpu: { type: 'boolean', default: 0 }, enable_lowercase_filter: { type: 'boolean', default: 0 }, enable_uppercase_filter: { type: 'boolean', default: 0 }, @@ -54,4 +54,4 @@ if (typeof module !== 'undefined' && module.exports) { } else { window.CONFIG_SCHEMA = CONFIG_SCHEMA; window.getDefaultConfig = getDefaultConfig; -}
\ No newline at end of file +} diff --git a/ui/package.json b/ui/package.json index ce22dee..a3647dc 100644 --- a/ui/package.json +++ b/ui/package.json @@ -63,6 +63,11 @@ "from": "../Sounds", "to": "Sounds", "filter": ["*.wav"] + }, + { + "from": "../Models", + "to": "Models", + "filter": ["**/*.pkl"] } ], "win": { |
