diff options
| author | yum <yum.food.vr@gmail.com> | 2025-07-25 23:07:03 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2025-07-25 23:07:03 -0700 |
| commit | 6815848fb8ed06b59b6d7e57096143f1f840e7db (patch) | |
| tree | 17591313c9f540b08830cf7790e195ef06c6a3e8 /app | |
| parent | a7f9b7b5fb33bead6bcfb0ad6867b57f2ddc42af (diff) | |
Work more on hallucination filteringv1.0.0-beta03
Diffstat (limited to 'app')
| -rw-r--r-- | app/hallucination_filter.py | 47 | ||||
| -rw-r--r-- | app/requirements.txt | 1 | ||||
| -rw-r--r-- | app/stt.py | 5 |
3 files changed, 32 insertions, 21 deletions
diff --git a/app/hallucination_filter.py b/app/hallucination_filter.py index 9b24a85..fa7b16a 100644 --- a/app/hallucination_filter.py +++ b/app/hallucination_filter.py @@ -4,16 +4,23 @@ from logger import log, log_err import numpy as np import pandas as pd from pathlib import Path +import pronouncing +import re import sys - +def count_syllables(word): + """Count syllables in a word using pronouncing library with regex fallback.""" + phones = pronouncing.phones_for_word(word.lower()) + return pronouncing.syllable_count(phones[0]) +def text_syllable_count(text): + """Count total syllables in text.""" + words = re.findall(r'\b\w+\b', text) + return sum(count_syllables(word) for word in words) class HallucinationFilter: """Filter for detecting hallucinated segments in speech-to-text output.""" - def __init__(self, model_path: Path = None): """ Initialize the hallucination filter. - Args: model_path: Optional path to the model file. If not provided, uses the default path. @@ -21,46 +28,48 @@ class HallucinationFilter: self.model = None self.threshold = None self.features = None - # Get the project root directory app_root = Path(__file__).resolve().parent project_root = app_root.parent - # Use provided path or default if model_path is None: model_path = project_root / "Models" / "thankyou_filter_gb.pkl" - # Try to load the model log_err(f"Loading hallucination filter") bundle = joblib.load(model_path) self.model = bundle["model"] self.threshold = bundle["threshold"] - self.features = bundle["features"] # Extract feature names + self.features = bundle["features"] log_err(f"Loaded hallucination filter model from {model_path}") - - def is_thank_you_hallucination(self, segment) -> bool: + def is_hallucination(self, segment) -> bool: """ - Check if a segment is likely a "Thank you" hallucination. + Check if a segment is likely a hallucination. Returns False if model is not available. - Args: segment: A segment object with attributes avg_logprob, audio_len_s, - no_speech_prob, and compression_ratio. - + no_speech_prob, compression_ratio, text, start, and end. Returns: bool: True if the segment is likely a hallucination, False otherwise. """ - # Create DataFrame with proper feature names + # Calculate text-based features + text = getattr(segment, 'text', '') + duration = segment.audio_len_s + raw_duration = segment.end_ts - segment.start_ts + n_syllables = text_syllable_count(text) + sps = n_syllables / duration + raw_sps = n_syllables / raw_duration + duration_ratio = raw_duration / duration X = pd.DataFrame([[ segment.avg_logprob, - segment.audio_len_s, segment.no_speech_prob, segment.compression_ratio, - np.log1p(segment.audio_len_s), - segment.avg_logprob * segment.audio_len_s + np.log1p(duration), + np.log1p(sps), + np.log1p(raw_duration), + np.log1p(raw_sps), + duration_ratio, + segment.avg_logprob * duration ]], columns=self.features) - # Get probability prob = self.model.predict_proba(X)[0, 1] return prob >= self.threshold - diff --git a/app/requirements.txt b/app/requirements.txt index dc294e5..d5bc10f 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -4,6 +4,7 @@ keyboard langcodes noisereduce pandas +pronouncing pyaudio pygame pydub @@ -578,8 +578,9 @@ class Whisper: s.avg_logprob, s.no_speech_prob, s.compression_ratio, audio_len_s) - # Check with ML model for "Thank you" hallucinations - if self.hallucination_filter.is_thank_you_hallucination(seg): + # Apply hallucination filter. This is a statistical model trained + # on personal speech data. + if self.hallucination_filter.is_hallucination(seg): if self.cfg["enable_debug_mode"]: log(f"Drop probable hallucination (case 4) " + f"(text='{s.text}', " + |
