summary refs log tree commit diff stats
path: root/app
diff options
context:
space:
mode:
author yum <yum.food.vr@gmail.com> 2025-07-25 23:07:03 -0700
committer yum <yum.food.vr@gmail.com> 2025-07-25 23:07:03 -0700
commit 6815848fb8ed06b59b6d7e57096143f1f840e7db (patch)
tree 17591313c9f540b08830cf7790e195ef06c6a3e8 /app
parent a7f9b7b5fb33bead6bcfb0ad6867b57f2ddc42af (diff)
Work more on hallucination filtering (v1.0.0-beta03)
Diffstat (limited to 'app')
-rw-r--r-- app/hallucination_filter.py 47
-rw-r--r-- app/requirements.txt 1
-rw-r--r-- app/stt.py 5
3 files changed, 32 insertions, 21 deletions
diff --git a/app/hallucination_filter.py b/app/hallucination_filter.py
index 9b24a85..fa7b16a 100644
--- a/app/hallucination_filter.py
+++ b/app/hallucination_filter.py
@@ -4,16 +4,23 @@ from logger import log, log_err
import numpy as np
import pandas as pd
from pathlib import Path
+import pronouncing
+import re
import sys
-
+def count_syllables(word):
+ """Count syllables in a word using pronouncing library with regex fallback."""
+ phones = pronouncing.phones_for_word(word.lower())
+ return pronouncing.syllable_count(phones[0])
+def text_syllable_count(text):
+ """Count total syllables in text."""
+ words = re.findall(r'\b\w+\b', text)
+ return sum(count_syllables(word) for word in words)
class HallucinationFilter:
"""Filter for detecting hallucinated segments in speech-to-text output."""
-
def __init__(self, model_path: Path = None):
"""
Initialize the hallucination filter.
-
Args:
model_path: Optional path to the model file. If not provided,
uses the default path.
@@ -21,46 +28,48 @@ class HallucinationFilter:
self.model = None
self.threshold = None
self.features = None
-
# Get the project root directory
app_root = Path(__file__).resolve().parent
project_root = app_root.parent
-
# Use provided path or default
if model_path is None:
model_path = project_root / "Models" / "thankyou_filter_gb.pkl"
-
# Try to load the model
log_err(f"Loading hallucination filter")
bundle = joblib.load(model_path)
self.model = bundle["model"]
self.threshold = bundle["threshold"]
- self.features = bundle["features"] # Extract feature names
+ self.features = bundle["features"]
log_err(f"Loaded hallucination filter model from {model_path}")
-
- def is_thank_you_hallucination(self, segment) -> bool:
+ def is_hallucination(self, segment) -> bool:
"""
- Check if a segment is likely a "Thank you" hallucination.
+ Check if a segment is likely a hallucination.
Returns False if model is not available.
-
Args:
segment: A segment object with attributes avg_logprob, audio_len_s,
- no_speech_prob, and compression_ratio.
-
+ no_speech_prob, compression_ratio, text, start, and end.
Returns:
bool: True if the segment is likely a hallucination, False otherwise.
"""
- # Create DataFrame with proper feature names
+ # Calculate text-based features
+ text = getattr(segment, 'text', '')
+ duration = segment.audio_len_s
+ raw_duration = segment.end_ts - segment.start_ts
+ n_syllables = text_syllable_count(text)
+ sps = n_syllables / duration
+ raw_sps = n_syllables / raw_duration
+ duration_ratio = raw_duration / duration
X = pd.DataFrame([[
segment.avg_logprob,
- segment.audio_len_s,
segment.no_speech_prob,
segment.compression_ratio,
- np.log1p(segment.audio_len_s),
- segment.avg_logprob * segment.audio_len_s
+ np.log1p(duration),
+ np.log1p(sps),
+ np.log1p(raw_duration),
+ np.log1p(raw_sps),
+ duration_ratio,
+ segment.avg_logprob * duration
]], columns=self.features)
-
# Get probability
prob = self.model.predict_proba(X)[0, 1]
return prob >= self.threshold
-
diff --git a/app/requirements.txt b/app/requirements.txt
index dc294e5..d5bc10f 100644
--- a/app/requirements.txt
+++ b/app/requirements.txt
@@ -4,6 +4,7 @@ keyboard
langcodes
noisereduce
pandas
+pronouncing
pyaudio
pygame
pydub
diff --git a/app/stt.py b/app/stt.py
index 8ab83a5..9947bae 100644
--- a/app/stt.py
+++ b/app/stt.py
@@ -578,8 +578,9 @@ class Whisper:
s.avg_logprob, s.no_speech_prob, s.compression_ratio,
audio_len_s)
- # Check with ML model for "Thank you" hallucinations
- if self.hallucination_filter.is_thank_you_hallucination(seg):
+ # Apply hallucination filter. This is a statistical model trained
+ # on personal speech data.
+ if self.hallucination_filter.is_hallucination(seg):
if self.cfg["enable_debug_mode"]:
log(f"Drop probable hallucination (case 4) " +
f"(text='{s.text}', " +