summaryrefslogtreecommitdiffstats
path: root/app
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2025-07-25 21:28:50 -0700
committeryum <yum.food.vr@gmail.com>2025-07-25 21:28:50 -0700
commita7f9b7b5fb33bead6bcfb0ad6867b57f2ddc42af (patch)
tree61d4870a019acb0e545d88e7661c8a4c7d90e499 /app
parent5df013d26eb13ed4aef8d16aa14346e0f9be5111 (diff)
Experiment with hallucination reduction
- update cursorignore - add hallucination filter training & inference code - put logging into a central module - segment metadata logging occurs before filtering - segment metadata logging is on by default - check in embedded python setup script - include trained hallucination filter model
Diffstat (limited to 'app')
-rw-r--r--app/hallucination_filter.py66
-rw-r--r--app/hi.py65
-rw-r--r--app/logger.py12
-rw-r--r--app/requirements.txt3
-rw-r--r--app/stt.py267
5 files changed, 237 insertions, 176 deletions
diff --git a/app/hallucination_filter.py b/app/hallucination_filter.py
new file mode 100644
index 0000000..9b24a85
--- /dev/null
+++ b/app/hallucination_filter.py
@@ -0,0 +1,66 @@
+import io
+import joblib
+from logger import log, log_err
+import numpy as np
+import pandas as pd
+from pathlib import Path
+import sys
+
+
+class HallucinationFilter:
+ """Filter for detecting hallucinated segments in speech-to-text output."""
+
+ def __init__(self, model_path: Path = None):
+ """
+ Initialize the hallucination filter.
+
+ Args:
+ model_path: Optional path to the model file. If not provided,
+ uses the default path.
+ """
+ self.model = None
+ self.threshold = None
+ self.features = None
+
+ # Get the project root directory
+ app_root = Path(__file__).resolve().parent
+ project_root = app_root.parent
+
+ # Use provided path or default
+ if model_path is None:
+ model_path = project_root / "Models" / "thankyou_filter_gb.pkl"
+
+ # Try to load the model
+ log_err(f"Loading hallucination filter")
+ bundle = joblib.load(model_path)
+ self.model = bundle["model"]
+ self.threshold = bundle["threshold"]
+ self.features = bundle["features"] # Extract feature names
+ log_err(f"Loaded hallucination filter model from {model_path}")
+
+ def is_thank_you_hallucination(self, segment) -> bool:
+ """
+ Check if a segment is likely a "Thank you" hallucination.
+ Returns False if model is not available.
+
+ Args:
+ segment: A segment object with attributes avg_logprob, audio_len_s,
+ no_speech_prob, and compression_ratio.
+
+ Returns:
+ bool: True if the segment is likely a hallucination, False otherwise.
+ """
+ # Create DataFrame with proper feature names
+ X = pd.DataFrame([[
+ segment.avg_logprob,
+ segment.audio_len_s,
+ segment.no_speech_prob,
+ segment.compression_ratio,
+ np.log1p(segment.audio_len_s),
+ segment.avg_logprob * segment.audio_len_s
+ ]], columns=self.features)
+
+ # Get probability
+ prob = self.model.predict_proba(X)[0, 1]
+ return prob >= self.threshold
+
diff --git a/app/hi.py b/app/hi.py
index bb09418..7ea4616 100644
--- a/app/hi.py
+++ b/app/hi.py
@@ -2,9 +2,11 @@ import app_config
import argparse
import io
import keybind_event_machine
+from logger import log, log_err
from math import floor, ceil
import msvcrt
import os
+import pygame
from pythonosc import udp_client
import sentencepiece as spm
import steamvr
@@ -13,10 +15,6 @@ import stt
import sys
import threading
import time
-import pygame
-
-sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
-sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
# Initialize pygame mixer
pygame.mixer.init()
@@ -31,10 +29,10 @@ PROJECT_ROOT = os.path.dirname(APP_ROOT)
def get_tokenizer():
model_path = os.path.join(PROJECT_ROOT, "custom_unigram_tokenizer_65k", "unigram.model")
- print(f"Loading SentencePiece tokenizer from: {model_path}")
+ log(f"Loading SentencePiece tokenizer from: {model_path}")
sp = spm.SentencePieceProcessor()
sp.load(model_path)
- print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}")
+ log(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}")
return sp
def parse_args():
@@ -137,16 +135,16 @@ def wrap_line(line: str, cols):
def get_blocks(lines, tokenizer, block_width, num_blocks):
if LOG_LEVEL == 2:
- print(f"Lines sent to tokenizer: {''.join(lines)}")
+ log(f"Lines sent to tokenizer: {''.join(lines)}")
tokens = tokenizer.encode_as_ids(''.join(lines))
if LOG_LEVEL == 2:
- print(f"Tokens: {tokens}")
+ log(f"Tokens: {tokens}")
pieces = []
for tok in tokens:
piece = tokenizer.id_to_piece(tok)
pieces.append(piece)
if LOG_LEVEL == 2:
- print(f"Pieces: {pieces}")
+ log(f"Pieces: {pieces}")
# Group tokens into blocks and pad with empty characters.
# Also get visual pointers - the location where each block will be rendered.
@@ -168,8 +166,8 @@ def get_blocks(lines, tokenizer, block_width, num_blocks):
return (blocks, visual_pointers)
blocks, visual_pointers = get_blocks()
if LOG_LEVEL == 2:
- print(f"Blocks: {blocks}")
- print(f"Visual pointers: {visual_pointers}")
+ log(f"Blocks: {blocks}")
+ log(f"Visual pointers: {visual_pointers}")
# Set all blocks up to the next `num_blocks` boundary to blank tokens.
# This handles the edge case where a prior message wrote data there which
@@ -183,8 +181,8 @@ def get_blocks(lines, tokenizer, block_width, num_blocks):
return blocks, visual_pointers
blocks, visual_pointers = pad_blocks(blocks, visual_pointers)
if LOG_LEVEL == 2:
- print(f"Blocks (padded): {blocks}")
- print(f"Visual pointers (padded): {visual_pointers}")
+ log(f"Blocks (padded): {blocks}")
+ log(f"Visual pointers (padded): {visual_pointers}")
return blocks, visual_pointers
@@ -223,11 +221,10 @@ def send_data(osc_client, indices, blocks, visual_pointers):
blocks_byte00, blocks_byte01 = split_blocks_by_byte(blocks)
if LOG_LEVEL == 2:
- print(f"Blocks (byte 00): {blocks_byte00}")
- print(f"Blocks (byte 01): {blocks_byte01}")
+ log(f"Blocks (byte 00): {blocks_byte00}")
+ log(f"Blocks (byte 01): {blocks_byte01}")
def send_osc(osc_client, addr, data):
- #print(f"Sending {data} to {addr}")
osc_client.send_message(addr, data)
for i in range(0, len(blocks)):
@@ -241,7 +238,7 @@ def send_data(osc_client, indices, blocks, visual_pointers):
addr = "/avatar/parameters/" + vp_param
send_osc(osc_client, addr, vp_float)
if LOG_LEVEL == 2:
- print(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}")
+ log(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}")
for j in range(0, len(blocks[i])):
byte00_float = (-127.5 + blocks_byte00[i][j]) / 127.5
byte01_float = (-127.5 + blocks_byte01[i][j]) / 127.5
@@ -271,7 +268,7 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg):
for line in line_wrapped:
assert_equal(len(line), cfg["cols"])
if LOG_LEVEL == 2:
- print(f"Wrapped lines: {line_wrapped}")
+ log(f"Wrapped lines: {line_wrapped}")
# Get several blank lines whenever we roll over.
# It's better for the reader to have some continuity when the board pages
@@ -312,7 +309,7 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg):
state.blocks.append(diff_blocks[0])
state.visual_pointers.append(diff_visual_pointers[0])
elif indices[0] > len(state.blocks):
- print(f"This should never happen!")
+ log(f"This should never happen!")
sys.exit(1)
else:
state.blocks[indices[0]] = diff_blocks[0]
@@ -345,7 +342,7 @@ def osc_thread(shared_data: SharedThreadData):
continue
addr = "/chatbox/input"
if shared_data.cfg["enable_debug_mode"]:
- print(f"Send {local_word}", flush=True)
+ log(f"Send {local_word}")
osc_client.send_message(addr, (local_word, True, False))
last_change = time.time()
remote_word = local_word
@@ -354,7 +351,7 @@ def osc_thread(shared_data: SharedThreadData):
tokenizer = get_tokenizer()
# Prime the board
- print("Priming the board")
+ log("Priming the board")
input_state = InputState()
handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg)
@@ -424,7 +421,7 @@ def vrInputThread(shared_data: SharedThreadData):
elif now - last_rising > 0.5:
# Medium press
- print("CLEARING", file=sys.stderr)
+ log_err("CLEARING")
last_medium_press_end = now
state = PAUSE_STATE
play_sound_with_volume(waveform2, shared_data.cfg)
@@ -439,25 +436,23 @@ def vrInputThread(shared_data: SharedThreadData):
# Short hold
if state == RECORD_STATE:
- print("PAUSED", file=sys.stderr)
+ log_err("PAUSED")
state = PAUSE_STATE
shared_data.stream.pause(True)
play_sound_with_volume(waveform1, shared_data.cfg)
elif state == PAUSE_STATE:
- print("RECORDING", file=sys.stderr)
+ log_err("RECORDING")
state = RECORD_STATE
if shared_data.cfg["reset_on_toggle"]:
if shared_data.cfg["enable_debug_mode"]:
- print("Toggle detected, dropping transcript (3)",
- file=sys.stderr)
+ log_err("Toggle detected, dropping transcript (3)")
shared_data.transcript = ""
shared_data.preview = ""
#audio_state.drop_transcription = True
else:
if shared_data.cfg["enable_debug_mode"]:
- print("Toggle detected, committing preview text (3)",
- file=sys.stderr)
+ log_err("Toggle detected, committing preview text (3)")
#audio_state.text += audio_state.preview_text
shared_data.stream.pause(False)
@@ -502,7 +497,7 @@ def kbInputThread(shared_data: SharedThreadData):
last_press_time = cur_press_time
if event == EVENT_DOUBLE_PRESS:
- print("CLEARING", file=sys.stderr)
+ log_err("CLEARING")
state = PAUSE_STATE
play_sound_with_volume(waveform2, shared_data.cfg)
@@ -516,23 +511,21 @@ def kbInputThread(shared_data: SharedThreadData):
# Short hold
if state == RECORD_STATE:
- print("PAUSED", file=sys.stderr)
+ log_err("PAUSED")
state = PAUSE_STATE
shared_data.stream.pause(True)
play_sound_with_volume(waveform1, shared_data.cfg)
elif state == PAUSE_STATE:
- print("RECORDING", file=sys.stderr)
+ log_err("RECORDING")
state = RECORD_STATE
if shared_data.cfg["reset_on_toggle"]:
if shared_data.cfg["enable_debug_mode"]:
- print("Toggle detected, dropping transcript (2)",
- file=sys.stderr)
+ log_err("Toggle detected, dropping transcript (2)")
shared_data.transcript = ""
shared_data.preview = ""
else:
if shared_data.cfg["enable_debug_mode"]:
- print("Toggle detected, committing preview text (2)",
- file=sys.stderr)
+ log_err("Toggle detected, committing preview text (2)")
shared_data.stream.pause(False)
play_sound_with_volume(waveform0, shared_data.cfg)
@@ -545,7 +538,7 @@ def play_sound_with_volume(filepath, cfg):
sound.set_volume(volume * 0.01)
sound.play()
except Exception as e:
- print(f"Error playing sound {filepath}: {e}", file=sys.stderr)
+ log_err(f"Error playing sound {filepath}: {e}")
if __name__ == "__main__":
cli_args = parse_args()
diff --git a/app/logger.py b/app/logger.py
new file mode 100644
index 0000000..72a2134
--- /dev/null
+++ b/app/logger.py
@@ -0,0 +1,12 @@
+import sys
+import io
+
+sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
+sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
+
+def log(message):
+ print(message, file=sys.stdout, flush=True)
+
+def log_err(message):
+ print(message, file=sys.stderr, flush=True)
+
diff --git a/app/requirements.txt b/app/requirements.txt
index c8d69df..dc294e5 100644
--- a/app/requirements.txt
+++ b/app/requirements.txt
@@ -3,10 +3,13 @@ hf-xet
keyboard
langcodes
noisereduce
+pandas
pyaudio
pygame
pydub
python-osc
+scikit-learn
sentencepiece
silero-vad
openvr
+joblib
diff --git a/app/stt.py b/app/stt.py
index 4ec559b..8ab83a5 100644
--- a/app/stt.py
+++ b/app/stt.py
@@ -2,15 +2,11 @@ from datetime import datetime
from faster_whisper import WhisperModel
import json
import langcodes
+from logger import log, log_err
+import noisereduce as nr
import numpy as np
import os
-import noisereduce as nr
-try:
- from profanity_filter import ProfanityFilter
- PROFANITY_FILTER_AVAILABLE = True
-except ImportError:
- PROFANITY_FILTER_AVAILABLE = False
- print("Warning: profanity_filter module not available", file=sys.stderr)
+from profanity_filter import ProfanityFilter
import pyaudio
from pydub import AudioSegment
from shared_thread_data import SharedThreadData
@@ -19,6 +15,7 @@ import sys
import time
import typing
import wave
+from hallucination_filter import HallucinationFilter
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(APP_ROOT)
@@ -55,7 +52,7 @@ class MicStream(AudioStream):
which_mic = cfg["microphone"]
if cfg["enable_debug_mode"]:
- print(f"Finding mic {which_mic}", file=sys.stderr)
+ log(f"Finding mic {which_mic}")
self.dumpMicDevices()
got_match = False
@@ -70,8 +67,8 @@ class MicStream(AudioStream):
target_str = "Microphone (Beyond)"
else:
if cfg["enable_debug_mode"]:
- print(f"Mic {which_mic} requested, treating it as a numerical " +
- "device ID", file=sys.stderr)
+ log(f"Mic {which_mic} requested, treating it as a numerical " +
+ "device ID")
device_index = int(which_mic)
got_match = True
if not got_match:
@@ -81,8 +78,7 @@ class MicStream(AudioStream):
if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name')
if target_str in device_name:
- print(f"Got matching mic: {device_name}",
- file=sys.stderr)
+ log(f"Got matching mic: {device_name}")
device_index = i
got_match = True
break
@@ -91,10 +87,10 @@ class MicStream(AudioStream):
info = self.p.get_device_info_by_host_api_device_index(0, device_index)
if cfg["enable_debug_mode"]:
- print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr)
+ log(f"Found mic {which_mic}: {info['name']}")
self.sample_rate = int(info['defaultSampleRate'])
if cfg["enable_debug_mode"]:
- print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr)
+ log(f"Mic sample rate: {self.sample_rate}")
self.stream = self.p.open(
rate=self.sample_rate,
@@ -119,7 +115,7 @@ class MicStream(AudioStream):
for i in range(0, numdevices):
if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name')
- print("Input Device id ", i, " - ", device_name)
+ log("Input Device id ", i, " - ", device_name)
def onAudioFramesAvailable(self,
frames,
@@ -278,7 +274,7 @@ class BoostingAudioCollector(AudioCollectorFilter):
frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS)
gain = min(self.target_dBFS - audio.dBFS, self.max_gain_dB)
if self.cfg["enable_debug_mode"]:
- print(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})", flush=True)
+ log(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})")
audio = audio.apply_gain(gain)
frames = np.array(audio.get_array_of_samples())
@@ -335,7 +331,6 @@ class AudioSegmenter:
# Load Silero VAD model
self.model = load_silero_vad()
- self.vad_threshold = 0.3
self.min_silence_duration_ms = min_silence_ms
self.max_speech_duration_s = max_speech_s
self.min_speech_duration_ms = min_speech_duration_ms
@@ -351,7 +346,6 @@ class AudioSegmenter:
audio_array,
self.model,
sampling_rate=AudioStream.FPS,
- threshold=self.vad_threshold,
min_silence_duration_ms=self.min_silence_duration_ms,
max_speech_duration_s=self.max_speech_duration_s,
min_speech_duration_ms=self.min_speech_duration_ms,
@@ -371,12 +365,9 @@ class AudioSegmenter:
for i in range(len(segments)):
s = segments[i]
- #print(f"s: {s}")
- #print(f"last_end: {last_end}")
if last_end:
delta_frames = s['start'] - last_end
- #print(f"delta frames: {delta_frames}")
if delta_frames > min_delta_frames:
cutoff = s['start']
else:
@@ -384,8 +375,6 @@ class AudioSegmenter:
if i == len(segments) - 1:
now = int(len(audio) / AudioStream.FRAME_SZ)
- #print(f"now: {now}")
- #print(f"min d: {min_delta_frames}")
delta_frames = now - s['end']
if delta_frames > min_delta_frames:
cutoff = now - int(min_delta_frames / 2)
@@ -402,7 +391,8 @@ class Segment:
wall_ts: float,
avg_logprob: float,
no_speech_prob: float,
- compression_ratio: float):
+ compression_ratio: float,
+ audio_len_s: float):
self.transcript = transcript
# start_ts, end_ts are timestamps in seconds relative to `wall_ts`.
self.start_ts = start_ts
@@ -413,6 +403,7 @@ class Segment:
self.avg_logprob = avg_logprob
self.no_speech_prob = no_speech_prob
self.compression_ratio = compression_ratio
+ self.audio_len_s = audio_len_s
def __str__(self):
ts = f"(ts: {self.start_ts}-{self.end_ts}) "
@@ -423,7 +414,8 @@ class Segment:
no_speech = f"(no_speech: {self.no_speech_prob}) "
avg_logprob = f"(avg_logprob: {self.avg_logprob}) "
- return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob
+ max_len_s = f"(max_len_s: {self.audio_len_s}) "
+ return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob + max_len_s
def join_segments(a, b):
if len(a) > 0 and a[-1] != ' ':
@@ -431,20 +423,76 @@ def join_segments(a, b):
else:
return a + b
+
+class SegmentLogger:
+ def __init__(self, cfg: typing.Dict):
+ self.cfg = cfg
+ self.enabled = cfg.get("enable_segment_logging", False)
+ self.session_data = []
+ self.log_file = None
+
+ if self.enabled:
+ log_dir = os.path.join(PROJECT_ROOT, "logs")
+ if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+ # Create file
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ self.log_file = os.path.join(log_dir, f"session_debug_{timestamp}.json")
+ log(f"Segment logging enabled. Logging to: {self.log_file}")
+
+ def log_segment(self, segment: Segment, commit_type: str = "commit"):
+ if not self.enabled:
+ return
+
+ segment_data = {
+ "timestamp": datetime.now().isoformat(),
+ "type": commit_type,
+ "text": segment.transcript,
+ "start_ts": segment.start_ts,
+ "end_ts": segment.end_ts,
+ "wall_ts": segment.wall_ts,
+ "avg_logprob": segment.avg_logprob,
+ "no_speech_prob": segment.no_speech_prob,
+ "compression_ratio": segment.compression_ratio,
+ "duration": segment.end_ts - segment.start_ts,
+ "duration_sanity": segment.audio_len_s
+ }
+
+ self.session_data.append(segment_data)
+
+ # Write to file incrementally
+ try:
+ with open(self.log_file, 'w') as f:
+ json.dump({
+ "session_start": self.session_data[0]["timestamp"] if self.session_data else None,
+ "segments": self.session_data
+ }, f, indent=2)
+ except Exception as e:
+ log_err(f"Error writing segment log: {e}")
+
+ def close(self):
+ if self.enabled and self.session_data:
+ log(f"Session complete. Logged {len(self.session_data)} " + \
+ "segments to {self.log_file}")
+
+
class Whisper:
def __init__(self,
collector: AudioCollector,
- cfg: typing.Dict):
+ cfg: typing.Dict,
+ segment_logger: SegmentLogger = None):
self.collector = collector
self.model = None
self.cfg = cfg
+ self.hallucination_filter = HallucinationFilter()
+ self.segment_logger = segment_logger
model_str = cfg["model"]
model_root = os.path.join(PROJECT_ROOT, "Models",
os.path.normpath(model_str))
if cfg["enable_debug_mode"]:
- print(f"Model {cfg['model']} will be saved to {model_root}",
- file=sys.stderr)
+ log(f"Model {cfg['model']} will be saved to {model_root}")
model_device = "cuda"
compute_type = cfg["compute_type"]
@@ -455,7 +503,7 @@ class Whisper:
already_downloaded = os.path.exists(model_root)
if not already_downloaded:
- print(f"Model {model_str} not already downloaded, downloading now...", flush=True)
+ log(f"Model {model_str} not already downloaded, downloading now...")
self.model = WhisperModel(model_str,
device = model_device,
@@ -483,19 +531,20 @@ class Whisper:
# Convert audio to float32
audio = np.frombuffer(frames,
dtype=np.int16).flatten().astype(np.float32) / 32768.0
+ audio_len_s = len(frames) / 16000.0
# Build context-aware prompt
prompt = self._build_prompt()
if self.cfg["enable_debug_mode"]:
- print(f"Prompt: {prompt}", flush=True)
+ log(f"Prompt: {prompt}")
t0 = time.time()
segments, info = self.model.transcribe(
audio,
language = langcodes.find(self.cfg["language"]).language,
- vad_filter = True,
- temperature=0.0,
+ vad_filter = False,
+ #temperature=0.0,
without_timestamps = False,
initial_prompt=prompt,
beam_size=self.cfg.get("beam_size", 5),
@@ -504,44 +553,44 @@ class Whisper:
)
res = []
for s in segments:
- # Manual touchup. I see a decent number of hallucinations sneaking
- # in with high `no_speech_prob` and modest `avg_logprob`.
- if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5:
- if self.cfg["enable_debug_mode"]:
- print(f"Drop probable hallucination (case 1) " +
- f"(text='{s.text}', " +
- f"no_speech_prob={s.no_speech_prob}, " +
- f"avg_logprob={s.avg_logprob})", file=sys.stderr)
- continue
- # Another touchup targeted at the vexatious "thanks for watching!"
- # hallucination. This triggers a lot when listening to
- # instrumental/electronic music.
- if s.no_speech_prob > 0.15 and s.avg_logprob < -0.7:
- if self.cfg["enable_debug_mode"]:
- print(f"Drop probable hallucination (case 2) " +
- f"(text='{s.text}', " +
- f"no_speech_prob={s.no_speech_prob}, " +
- f"avg_logprob={s.avg_logprob})", file=sys.stderr)
- continue
- if s.avg_logprob < -0.75:
- if self.cfg["enable_debug_mode"]:
- print(f"Drop probable hallucination (case 3) " +
- f"(text='{s.text}', " +
- f"no_speech_prob={s.no_speech_prob}, " +
- f"avg_logprob={s.avg_logprob})", file=sys.stderr)
- continue
+ # Log raw segment before filtering
+ if self.segment_logger:
+ # Create a temporary segment object for logging
+ raw_seg = Segment(s.text, s.start, s.end,
+ self.collector.begin(),
+ s.avg_logprob, s.no_speech_prob, s.compression_ratio,
+ audio_len_s)
+ self.segment_logger.log_segment(raw_seg, "raw")
+ # Sometimes the model reports a bum duration, breaking our filters.
+ # Cap the segment length above by the length of the audio in.
+ duration_s = min(s.end - s.start, audio_len_s)
+
if self.cfg["enable_debug_mode"]:
- print(f"s get: {s}")
+ log(f"s get: {s}")
if s.avg_logprob < -1.0:
continue
if s.compression_ratio > 2.4:
continue
- res.append(Segment(s.text, s.start, s.end,
+
+ # Create segment object
+ seg = Segment(s.text, s.start, s.end,
self.collector.begin(),
- s.avg_logprob, s.no_speech_prob, s.compression_ratio))
+ s.avg_logprob, s.no_speech_prob, s.compression_ratio,
+ audio_len_s)
+
+ # Check with ML model for "Thank you" hallucinations
+ if self.hallucination_filter.is_thank_you_hallucination(seg):
+ if self.cfg["enable_debug_mode"]:
+ log(f"Drop probable hallucination (case 4) " +
+ f"(text='{s.text}', " +
+ f"no_speech_prob={s.no_speech_prob}, " +
+ f"avg_logprob={s.avg_logprob})")
+ continue
+
+ res.append(seg)
t1 = time.time()
if self.cfg["enable_debug_mode"]:
- print(f"Transcription latency (s): {t1 - t0}")
+ log(f"Transcription latency (s): {t1 - t0}")
return res
def _build_prompt(self) -> str:
@@ -580,64 +629,13 @@ class TranscriptCommit:
def saveAudio(audio: bytes, path: str, cfg: typing.Dict):
with wave.open(path, 'wb') as wf:
if cfg["enable_debug_mode"]:
- print(f"Saving audio to {path}", file=sys.stderr)
+ log(f"Saving audio to {path}")
wf.setnchannels(AudioStream.CHANNELS)
wf.setsampwidth(AudioStream.FRAME_SZ)
wf.setframerate(AudioStream.FPS)
wf.writeframes(audio)
-class SegmentLogger:
- def __init__(self, cfg: typing.Dict):
- self.cfg = cfg
- self.enabled = cfg.get("enable_segment_logging", False)
- self.session_data = []
- self.log_file = None
-
- if self.enabled:
- log_dir = os.path.join(PROJECT_ROOT, "logs")
- if not os.path.exists(log_dir):
- os.makedirs(log_dir)
-
- # Create file
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- self.log_file = os.path.join(log_dir, f"session_debug_{timestamp}.json")
- print(f"Segment logging enabled. Logging to: {self.log_file}", file=sys.stderr)
-
- def log_segment(self, segment: Segment, commit_type: str = "commit"):
- if not self.enabled:
- return
-
- segment_data = {
- "timestamp": datetime.now().isoformat(),
- "type": commit_type,
- "text": segment.transcript,
- "start_ts": segment.start_ts,
- "end_ts": segment.end_ts,
- "wall_ts": segment.wall_ts,
- "avg_logprob": segment.avg_logprob,
- "no_speech_prob": segment.no_speech_prob,
- "compression_ratio": segment.compression_ratio,
- "duration": segment.end_ts - segment.start_ts
- }
-
- self.session_data.append(segment_data)
-
- # Write to file incrementally
- try:
- with open(self.log_file, 'w') as f:
- json.dump({
- "session_start": self.session_data[0]["timestamp"] if self.session_data else None,
- "segments": self.session_data
- }, f, indent=2)
- except Exception as e:
- print(f"Error writing segment log: {e}", file=sys.stderr)
-
- def close(self):
- if self.enabled and self.session_data:
- print(f"Session complete. Logged {len(self.session_data)} segments to {self.log_file}", file=sys.stderr)
-
-
class VadCommitter:
def __init__(self,
cfg: typing.Dict,
@@ -680,16 +678,12 @@ class VadCommitter:
if delta.strip():
self.whisper.update_context(delta.strip())
- if self.segment_logger:
- for s in segments:
- self.segment_logger.log_segment(s, "commit")
-
audio = self.collector.getAudio()
if self.cfg["enable_debug_mode"]:
for s in segments:
- print(f"commit segment: {s}", file=sys.stderr)
+ log(f"commit segment: {s}")
if len(delta) > 0:
- print(f"delta get: {delta}", file=sys.stderr)
+ log(f"delta get: {delta}")
if self.cfg["save_audio"] and len(delta) > 0:
ts = datetime.fromtimestamp(self.collector.now() - latency_s)
@@ -705,10 +699,6 @@ class VadCommitter:
segments = self.whisper.transcribe(audio)
preview = "".join(s.transcript for s in segments)
- if self.segment_logger:
- for s in segments:
- self.segment_logger.log_segment(s, "preview")
-
if not has_audio:
self.collector.keepLast(1.0)
@@ -758,13 +748,13 @@ class ProfanityPlugin(StreamingPlugin):
def __init__(self, cfg):
self.cfg = cfg
self.filter = None
- if PROFANITY_FILTER_AVAILABLE and cfg["enable_profanity_filter"]:
+ if cfg["enable_profanity_filter"]:
en_profanity_path = os.path.join(PROJECT_ROOT, "Third_Party/Profanity/en")
try:
self.filter = ProfanityFilter(en_profanity_path)
self.filter.load()
except Exception as e:
- print(f"Warning: Could not load profanity filter: {e}", file=sys.stderr)
+ log_err(f"Warning: Could not load profanity filter: {e}")
self.filter = None
def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
@@ -812,12 +802,11 @@ def transcriptionThread(shared_data: SharedThreadData):
shared_data.cfg)
collector = NoiseReducingAudioCollector(collector, shared_data.cfg)
#collector = NormalizingAudioCollector(collector)
- whisper = Whisper(collector, shared_data.cfg)
+ segment_logger = SegmentLogger(shared_data.cfg)
+ whisper = Whisper(collector, shared_data.cfg, segment_logger)
segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"],
max_speech_s=shared_data.cfg["max_speech_duration_s"],
min_speech_duration_ms=shared_data.cfg["min_speech_duration_ms"])
-
- segment_logger = SegmentLogger(shared_data.cfg)
committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter, segment_logger)
plugins = []
@@ -838,7 +827,7 @@ def transcriptionThread(shared_data: SharedThreadData):
shared_data.stream = stream
shared_data.collector = collector
- print(f"Ready to go!", flush=True)
+ log(f"Ready to go!")
while not shared_data.exit_event.is_set():
time.sleep(shared_data.cfg["transcription_loop_delay_ms"] / 1000.0);
@@ -862,8 +851,8 @@ def transcriptionThread(shared_data: SharedThreadData):
silence_duration = commit.start_ts - last_commit_end_ts
if silence_duration > shared_data.cfg["reset_after_silence_s"]:
if shared_data.cfg["enable_debug_mode"]:
- print(f"Resetting transcript after {silence_duration}-second "
- "silence", file=sys.stderr)
+ log(f"Resetting transcript after {silence_duration}-second "
+ "silence")
shared_data.transcript = ""
shared_data.preview = ""
whisper.recent_context = "" # Reset context too
@@ -885,20 +874,18 @@ def transcriptionThread(shared_data: SharedThreadData):
shared_data.preview)
try:
- print(f"Transcript: {shared_data.transcript}", flush=True)
+ log(f"Transcript: {shared_data.transcript}")
except UnicodeEncodeError:
- print("Failed to encode transcript - discarding delta",
- file=sys.stderr)
+ log_err("Failed to encode transcript - discarding delta")
continue
try:
- print(f"Preview: {shared_data.preview}", flush=True)
+ log(f"Preview: {shared_data.preview}")
except UnicodeEncodeError:
- print("Failed to encode preview - discarding", file=sys.stderr)
+ log_err("Failed to encode preview - discarding")
if shared_data.cfg["enable_debug_mode"]:
- print(f"commit latency: {commit.latency_s}", file=sys.stderr)
- print(f"commit thresh: {commit.thresh_at_commit}",
- file=sys.stderr)
+ log(f"commit latency: {commit.latency_s}")
+ log(f"commit thresh: {commit.thresh_at_commit}")
if len(shared_data.transcript) > 0 and \
(not shared_data.transcript.endswith(' ')) and \