summaryrefslogtreecommitdiffstats
path: root/app/stt.py
diff options
context:
space:
mode:
Diffstat (limited to 'app/stt.py')
-rw-r--r--app/stt.py267
1 files changed, 127 insertions, 140 deletions
diff --git a/app/stt.py b/app/stt.py
index 4ec559b..8ab83a5 100644
--- a/app/stt.py
+++ b/app/stt.py
@@ -2,15 +2,11 @@ from datetime import datetime
from faster_whisper import WhisperModel
import json
import langcodes
+from logger import log, log_err
+import noisereduce as nr
import numpy as np
import os
-import noisereduce as nr
-try:
- from profanity_filter import ProfanityFilter
- PROFANITY_FILTER_AVAILABLE = True
-except ImportError:
- PROFANITY_FILTER_AVAILABLE = False
- print("Warning: profanity_filter module not available", file=sys.stderr)
+from profanity_filter import ProfanityFilter
import pyaudio
from pydub import AudioSegment
from shared_thread_data import SharedThreadData
@@ -19,6 +15,7 @@ import sys
import time
import typing
import wave
+from hallucination_filter import HallucinationFilter
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(APP_ROOT)
@@ -55,7 +52,7 @@ class MicStream(AudioStream):
which_mic = cfg["microphone"]
if cfg["enable_debug_mode"]:
- print(f"Finding mic {which_mic}", file=sys.stderr)
+ log(f"Finding mic {which_mic}")
self.dumpMicDevices()
got_match = False
@@ -70,8 +67,8 @@ class MicStream(AudioStream):
target_str = "Microphone (Beyond)"
else:
if cfg["enable_debug_mode"]:
- print(f"Mic {which_mic} requested, treating it as a numerical " +
- "device ID", file=sys.stderr)
+ log(f"Mic {which_mic} requested, treating it as a numerical " +
+ "device ID")
device_index = int(which_mic)
got_match = True
if not got_match:
@@ -81,8 +78,7 @@ class MicStream(AudioStream):
if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name')
if target_str in device_name:
- print(f"Got matching mic: {device_name}",
- file=sys.stderr)
+ log(f"Got matching mic: {device_name}")
device_index = i
got_match = True
break
@@ -91,10 +87,10 @@ class MicStream(AudioStream):
info = self.p.get_device_info_by_host_api_device_index(0, device_index)
if cfg["enable_debug_mode"]:
- print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr)
+ log(f"Found mic {which_mic}: {info['name']}")
self.sample_rate = int(info['defaultSampleRate'])
if cfg["enable_debug_mode"]:
- print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr)
+ log(f"Mic sample rate: {self.sample_rate}")
self.stream = self.p.open(
rate=self.sample_rate,
@@ -119,7 +115,7 @@ class MicStream(AudioStream):
for i in range(0, numdevices):
if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name')
- print("Input Device id ", i, " - ", device_name)
+ log("Input Device id ", i, " - ", device_name)
def onAudioFramesAvailable(self,
frames,
@@ -278,7 +274,7 @@ class BoostingAudioCollector(AudioCollectorFilter):
frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS)
gain = min(self.target_dBFS - audio.dBFS, self.max_gain_dB)
if self.cfg["enable_debug_mode"]:
- print(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})", flush=True)
+ log(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})")
audio = audio.apply_gain(gain)
frames = np.array(audio.get_array_of_samples())
@@ -335,7 +331,6 @@ class AudioSegmenter:
# Load Silero VAD model
self.model = load_silero_vad()
- self.vad_threshold = 0.3
self.min_silence_duration_ms = min_silence_ms
self.max_speech_duration_s = max_speech_s
self.min_speech_duration_ms = min_speech_duration_ms
@@ -351,7 +346,6 @@ class AudioSegmenter:
audio_array,
self.model,
sampling_rate=AudioStream.FPS,
- threshold=self.vad_threshold,
min_silence_duration_ms=self.min_silence_duration_ms,
max_speech_duration_s=self.max_speech_duration_s,
min_speech_duration_ms=self.min_speech_duration_ms,
@@ -371,12 +365,9 @@ class AudioSegmenter:
for i in range(len(segments)):
s = segments[i]
- #print(f"s: {s}")
- #print(f"last_end: {last_end}")
if last_end:
delta_frames = s['start'] - last_end
- #print(f"delta frames: {delta_frames}")
if delta_frames > min_delta_frames:
cutoff = s['start']
else:
@@ -384,8 +375,6 @@ class AudioSegmenter:
if i == len(segments) - 1:
now = int(len(audio) / AudioStream.FRAME_SZ)
- #print(f"now: {now}")
- #print(f"min d: {min_delta_frames}")
delta_frames = now - s['end']
if delta_frames > min_delta_frames:
cutoff = now - int(min_delta_frames / 2)
@@ -402,7 +391,8 @@ class Segment:
wall_ts: float,
avg_logprob: float,
no_speech_prob: float,
- compression_ratio: float):
+ compression_ratio: float,
+ audio_len_s: float):
self.transcript = transcript
# start_ts, end_ts are timestamps in seconds relative to `wall_ts`.
self.start_ts = start_ts
@@ -413,6 +403,7 @@ class Segment:
self.avg_logprob = avg_logprob
self.no_speech_prob = no_speech_prob
self.compression_ratio = compression_ratio
+ self.audio_len_s = audio_len_s
def __str__(self):
ts = f"(ts: {self.start_ts}-{self.end_ts}) "
@@ -423,7 +414,8 @@ class Segment:
no_speech = f"(no_speech: {self.no_speech_prob}) "
avg_logprob = f"(avg_logprob: {self.avg_logprob}) "
- return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob
+ max_len_s = f"(max_len_s: {self.audio_len_s}) "
+ return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob + max_len_s
def join_segments(a, b):
if len(a) > 0 and a[-1] != ' ':
@@ -431,20 +423,76 @@ def join_segments(a, b):
else:
return a + b
+
+class SegmentLogger:
+ def __init__(self, cfg: typing.Dict):
+ self.cfg = cfg
+ self.enabled = cfg.get("enable_segment_logging", False)
+ self.session_data = []
+ self.log_file = None
+
+ if self.enabled:
+ log_dir = os.path.join(PROJECT_ROOT, "logs")
+ if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+ # Create file
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ self.log_file = os.path.join(log_dir, f"session_debug_{timestamp}.json")
+ log(f"Segment logging enabled. Logging to: {self.log_file}")
+
+ def log_segment(self, segment: Segment, commit_type: str = "commit"):
+ if not self.enabled:
+ return
+
+ segment_data = {
+ "timestamp": datetime.now().isoformat(),
+ "type": commit_type,
+ "text": segment.transcript,
+ "start_ts": segment.start_ts,
+ "end_ts": segment.end_ts,
+ "wall_ts": segment.wall_ts,
+ "avg_logprob": segment.avg_logprob,
+ "no_speech_prob": segment.no_speech_prob,
+ "compression_ratio": segment.compression_ratio,
+ "duration": segment.end_ts - segment.start_ts,
+ "duration_sanity": segment.audio_len_s
+ }
+
+ self.session_data.append(segment_data)
+
+ # Write to file incrementally
+ try:
+ with open(self.log_file, 'w') as f:
+ json.dump({
+ "session_start": self.session_data[0]["timestamp"] if self.session_data else None,
+ "segments": self.session_data
+ }, f, indent=2)
+ except Exception as e:
+ log_err(f"Error writing segment log: {e}")
+
+ def close(self):
+ if self.enabled and self.session_data:
+ log(f"Session complete. Logged {len(self.session_data)} " + \
+ "segments to {self.log_file}")
+
+
class Whisper:
def __init__(self,
collector: AudioCollector,
- cfg: typing.Dict):
+ cfg: typing.Dict,
+ segment_logger: SegmentLogger = None):
self.collector = collector
self.model = None
self.cfg = cfg
+ self.hallucination_filter = HallucinationFilter()
+ self.segment_logger = segment_logger
model_str = cfg["model"]
model_root = os.path.join(PROJECT_ROOT, "Models",
os.path.normpath(model_str))
if cfg["enable_debug_mode"]:
- print(f"Model {cfg['model']} will be saved to {model_root}",
- file=sys.stderr)
+ log(f"Model {cfg['model']} will be saved to {model_root}")
model_device = "cuda"
compute_type = cfg["compute_type"]
@@ -455,7 +503,7 @@ class Whisper:
already_downloaded = os.path.exists(model_root)
if not already_downloaded:
- print(f"Model {model_str} not already downloaded, downloading now...", flush=True)
+ log(f"Model {model_str} not already downloaded, downloading now...")
self.model = WhisperModel(model_str,
device = model_device,
@@ -483,19 +531,20 @@ class Whisper:
# Convert audio to float32
audio = np.frombuffer(frames,
dtype=np.int16).flatten().astype(np.float32) / 32768.0
+ audio_len_s = len(frames) / 16000.0
# Build context-aware prompt
prompt = self._build_prompt()
if self.cfg["enable_debug_mode"]:
- print(f"Prompt: {prompt}", flush=True)
+ log(f"Prompt: {prompt}")
t0 = time.time()
segments, info = self.model.transcribe(
audio,
language = langcodes.find(self.cfg["language"]).language,
- vad_filter = True,
- temperature=0.0,
+ vad_filter = False,
+ #temperature=0.0,
without_timestamps = False,
initial_prompt=prompt,
beam_size=self.cfg.get("beam_size", 5),
@@ -504,44 +553,44 @@ class Whisper:
)
res = []
for s in segments:
- # Manual touchup. I see a decent number of hallucinations sneaking
- # in with high `no_speech_prob` and modest `avg_logprob`.
- if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5:
- if self.cfg["enable_debug_mode"]:
- print(f"Drop probable hallucination (case 1) " +
- f"(text='{s.text}', " +
- f"no_speech_prob={s.no_speech_prob}, " +
- f"avg_logprob={s.avg_logprob})", file=sys.stderr)
- continue
- # Another touchup targeted at the vexatious "thanks for watching!"
- # hallucination. This triggers a lot when listening to
- # instrumental/electronic music.
- if s.no_speech_prob > 0.15 and s.avg_logprob < -0.7:
- if self.cfg["enable_debug_mode"]:
- print(f"Drop probable hallucination (case 2) " +
- f"(text='{s.text}', " +
- f"no_speech_prob={s.no_speech_prob}, " +
- f"avg_logprob={s.avg_logprob})", file=sys.stderr)
- continue
- if s.avg_logprob < -0.75:
- if self.cfg["enable_debug_mode"]:
- print(f"Drop probable hallucination (case 3) " +
- f"(text='{s.text}', " +
- f"no_speech_prob={s.no_speech_prob}, " +
- f"avg_logprob={s.avg_logprob})", file=sys.stderr)
- continue
+ # Log raw segment before filtering
+ if self.segment_logger:
+ # Create a temporary segment object for logging
+ raw_seg = Segment(s.text, s.start, s.end,
+ self.collector.begin(),
+ s.avg_logprob, s.no_speech_prob, s.compression_ratio,
+ audio_len_s)
+ self.segment_logger.log_segment(raw_seg, "raw")
+ # Sometimes the model reports a bum duration, breaking our filters.
+ # Cap the segment length above by the length of the audio in.
+ duration_s = min(s.end - s.start, audio_len_s)
+
if self.cfg["enable_debug_mode"]:
- print(f"s get: {s}")
+ log(f"s get: {s}")
if s.avg_logprob < -1.0:
continue
if s.compression_ratio > 2.4:
continue
- res.append(Segment(s.text, s.start, s.end,
+
+ # Create segment object
+ seg = Segment(s.text, s.start, s.end,
self.collector.begin(),
- s.avg_logprob, s.no_speech_prob, s.compression_ratio))
+ s.avg_logprob, s.no_speech_prob, s.compression_ratio,
+ audio_len_s)
+
+ # Check with ML model for "Thank you" hallucinations
+ if self.hallucination_filter.is_thank_you_hallucination(seg):
+ if self.cfg["enable_debug_mode"]:
+ log(f"Drop probable hallucination (case 4) " +
+ f"(text='{s.text}', " +
+ f"no_speech_prob={s.no_speech_prob}, " +
+ f"avg_logprob={s.avg_logprob})")
+ continue
+
+ res.append(seg)
t1 = time.time()
if self.cfg["enable_debug_mode"]:
- print(f"Transcription latency (s): {t1 - t0}")
+ log(f"Transcription latency (s): {t1 - t0}")
return res
def _build_prompt(self) -> str:
@@ -580,64 +629,13 @@ class TranscriptCommit:
def saveAudio(audio: bytes, path: str, cfg: typing.Dict):
with wave.open(path, 'wb') as wf:
if cfg["enable_debug_mode"]:
- print(f"Saving audio to {path}", file=sys.stderr)
+ log(f"Saving audio to {path}")
wf.setnchannels(AudioStream.CHANNELS)
wf.setsampwidth(AudioStream.FRAME_SZ)
wf.setframerate(AudioStream.FPS)
wf.writeframes(audio)
-class SegmentLogger:
- def __init__(self, cfg: typing.Dict):
- self.cfg = cfg
- self.enabled = cfg.get("enable_segment_logging", False)
- self.session_data = []
- self.log_file = None
-
- if self.enabled:
- log_dir = os.path.join(PROJECT_ROOT, "logs")
- if not os.path.exists(log_dir):
- os.makedirs(log_dir)
-
- # Create file
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- self.log_file = os.path.join(log_dir, f"session_debug_{timestamp}.json")
- print(f"Segment logging enabled. Logging to: {self.log_file}", file=sys.stderr)
-
- def log_segment(self, segment: Segment, commit_type: str = "commit"):
- if not self.enabled:
- return
-
- segment_data = {
- "timestamp": datetime.now().isoformat(),
- "type": commit_type,
- "text": segment.transcript,
- "start_ts": segment.start_ts,
- "end_ts": segment.end_ts,
- "wall_ts": segment.wall_ts,
- "avg_logprob": segment.avg_logprob,
- "no_speech_prob": segment.no_speech_prob,
- "compression_ratio": segment.compression_ratio,
- "duration": segment.end_ts - segment.start_ts
- }
-
- self.session_data.append(segment_data)
-
- # Write to file incrementally
- try:
- with open(self.log_file, 'w') as f:
- json.dump({
- "session_start": self.session_data[0]["timestamp"] if self.session_data else None,
- "segments": self.session_data
- }, f, indent=2)
- except Exception as e:
- print(f"Error writing segment log: {e}", file=sys.stderr)
-
- def close(self):
- if self.enabled and self.session_data:
- print(f"Session complete. Logged {len(self.session_data)} segments to {self.log_file}", file=sys.stderr)
-
-
class VadCommitter:
def __init__(self,
cfg: typing.Dict,
@@ -680,16 +678,12 @@ class VadCommitter:
if delta.strip():
self.whisper.update_context(delta.strip())
- if self.segment_logger:
- for s in segments:
- self.segment_logger.log_segment(s, "commit")
-
audio = self.collector.getAudio()
if self.cfg["enable_debug_mode"]:
for s in segments:
- print(f"commit segment: {s}", file=sys.stderr)
+ log(f"commit segment: {s}")
if len(delta) > 0:
- print(f"delta get: {delta}", file=sys.stderr)
+ log(f"delta get: {delta}")
if self.cfg["save_audio"] and len(delta) > 0:
ts = datetime.fromtimestamp(self.collector.now() - latency_s)
@@ -705,10 +699,6 @@ class VadCommitter:
segments = self.whisper.transcribe(audio)
preview = "".join(s.transcript for s in segments)
- if self.segment_logger:
- for s in segments:
- self.segment_logger.log_segment(s, "preview")
-
if not has_audio:
self.collector.keepLast(1.0)
@@ -758,13 +748,13 @@ class ProfanityPlugin(StreamingPlugin):
def __init__(self, cfg):
self.cfg = cfg
self.filter = None
- if PROFANITY_FILTER_AVAILABLE and cfg["enable_profanity_filter"]:
+ if cfg["enable_profanity_filter"]:
en_profanity_path = os.path.join(PROJECT_ROOT, "Third_Party/Profanity/en")
try:
self.filter = ProfanityFilter(en_profanity_path)
self.filter.load()
except Exception as e:
- print(f"Warning: Could not load profanity filter: {e}", file=sys.stderr)
+ log_err(f"Warning: Could not load profanity filter: {e}")
self.filter = None
def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
@@ -812,12 +802,11 @@ def transcriptionThread(shared_data: SharedThreadData):
shared_data.cfg)
collector = NoiseReducingAudioCollector(collector, shared_data.cfg)
#collector = NormalizingAudioCollector(collector)
- whisper = Whisper(collector, shared_data.cfg)
+ segment_logger = SegmentLogger(shared_data.cfg)
+ whisper = Whisper(collector, shared_data.cfg, segment_logger)
segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"],
max_speech_s=shared_data.cfg["max_speech_duration_s"],
min_speech_duration_ms=shared_data.cfg["min_speech_duration_ms"])
-
- segment_logger = SegmentLogger(shared_data.cfg)
committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter, segment_logger)
plugins = []
@@ -838,7 +827,7 @@ def transcriptionThread(shared_data: SharedThreadData):
shared_data.stream = stream
shared_data.collector = collector
- print(f"Ready to go!", flush=True)
+ log(f"Ready to go!")
while not shared_data.exit_event.is_set():
time.sleep(shared_data.cfg["transcription_loop_delay_ms"] / 1000.0);
@@ -862,8 +851,8 @@ def transcriptionThread(shared_data: SharedThreadData):
silence_duration = commit.start_ts - last_commit_end_ts
if silence_duration > shared_data.cfg["reset_after_silence_s"]:
if shared_data.cfg["enable_debug_mode"]:
- print(f"Resetting transcript after {silence_duration}-second "
- "silence", file=sys.stderr)
+ log(f"Resetting transcript after {silence_duration}-second "
+ "silence")
shared_data.transcript = ""
shared_data.preview = ""
whisper.recent_context = "" # Reset context too
@@ -885,20 +874,18 @@ def transcriptionThread(shared_data: SharedThreadData):
shared_data.preview)
try:
- print(f"Transcript: {shared_data.transcript}", flush=True)
+ log(f"Transcript: {shared_data.transcript}")
except UnicodeEncodeError:
- print("Failed to encode transcript - discarding delta",
- file=sys.stderr)
+ log_err("Failed to encode transcript - discarding delta")
continue
try:
- print(f"Preview: {shared_data.preview}", flush=True)
+ log(f"Preview: {shared_data.preview}")
except UnicodeEncodeError:
- print("Failed to encode preview - discarding", file=sys.stderr)
+ log_err("Failed to encode preview - discarding")
if shared_data.cfg["enable_debug_mode"]:
- print(f"commit latency: {commit.latency_s}", file=sys.stderr)
- print(f"commit thresh: {commit.thresh_at_commit}",
- file=sys.stderr)
+ log(f"commit latency: {commit.latency_s}")
+ log(f"commit thresh: {commit.thresh_at_commit}")
if len(shared_data.transcript) > 0 and \
(not shared_data.transcript.endswith(' ')) and \