diff options
| author | yum <yum.food.vr@gmail.com> | 2025-07-25 21:28:50 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2025-07-25 21:28:50 -0700 |
| commit | a7f9b7b5fb33bead6bcfb0ad6867b57f2ddc42af (patch) | |
| tree | 61d4870a019acb0e545d88e7661c8a4c7d90e499 /app/hallucination_filter.py | |
| parent | 5df013d26eb13ed4aef8d16aa14346e0f9be5111 (diff) | |
Experiment with hallucination reduction
- update cursorignore
- add hallucination filter training & inference code
- put logging into a central module
- segment metadata logging occurs before filtering
- segment metadata logging is on by default
- check in embedded python setup script
- include trained hallucination filter model
Diffstat (limited to 'app/hallucination_filter.py')
| -rw-r--r-- | app/hallucination_filter.py | 66 |
1 file changed, 66 insertions, 0 deletions
# app/hallucination_filter.py
# (new file, mode 100644, blob 9b24a85; reconstructed from the unified diff
#  "diff --git a/app/hallucination_filter.py b/app/hallucination_filter.py")
import io
import joblib
from logger import log, log_err
import numpy as np
import pandas as pd
from pathlib import Path
import sys
from typing import Optional


class HallucinationFilter:
    """Filter for detecting hallucinated segments in speech-to-text output.

    The filter wraps a classifier loaded from a joblib bundle. The bundle is
    indexed with the keys ``"model"``, ``"threshold"`` and ``"features"``.
    """

    def __init__(self, model_path: Optional[Path] = None):
        """
        Initialize the hallucination filter by loading its model bundle.

        Args:
            model_path: Optional path to the model file. If not provided,
                    uses ``Models/thankyou_filter_gb.pkl`` under the parent
                    of the directory that contains this file.

        Raises:
            Any exception raised by ``joblib.load`` (for example when the
            file is missing) propagates to the caller, as does the lookup
            error raised when the bundle lacks one of the expected keys
            (``KeyError`` if the bundle is a dict -- presumably it is).
        """
        # Classifier used for scoring; replaced by the loaded model below.
        self.model = None
        # Probability cut-off at or above which a segment is flagged.
        self.threshold = None
        # Column names used to label the feature row handed to the model.
        self.features = None

        # Get the project root directory (parent of the directory holding
        # this file).
        app_root = Path(__file__).resolve().parent
        project_root = app_root.parent

        # Use provided path or default
        if model_path is None:
            model_path = project_root / "Models" / "thankyou_filter_gb.pkl"

        # Load the model bundle. There is deliberately no try/except here:
        # a missing or corrupt model file fails construction loudly.
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code -- only point model_path at trusted files.
        log_err("Loading hallucination filter")
        bundle = joblib.load(model_path)
        self.model = bundle["model"]
        self.threshold = bundle["threshold"]
        self.features = bundle["features"]  # Extract feature names
        log_err(f"Loaded hallucination filter model from {model_path}")

    def is_thank_you_hallucination(self, segment) -> bool:
        """
        Check if a segment is likely a "Thank you" hallucination.
        Returns False if model is not available.

        Args:
            segment: A segment object with attributes avg_logprob, audio_len_s,
                    no_speech_prob, and compression_ratio.

        Returns:
            bool: True if the segment is likely a hallucination, False otherwise.
        """
        # Honor the documented contract: without a model nothing is flagged.
        if self.model is None:
            return False

        avg_logprob = segment.avg_logprob
        audio_len_s = segment.audio_len_s

        # Four raw features followed by two derived ones.
        # NOTE(review): the values are positional while the column names come
        # from the bundle, so this order must match the order of
        # self.features used at training time -- confirm against the
        # training code.
        feature_row = [
            avg_logprob,
            audio_len_s,
            segment.no_speech_prob,
            segment.compression_ratio,
            np.log1p(audio_len_s),
            avg_logprob * audio_len_s,
        ]

        # Create DataFrame with proper feature names
        X = pd.DataFrame([feature_row], columns=self.features)

        # Probability of the class in column 1 for the single row
        # (presumably the "hallucination" class -- TODO confirm).
        prob = self.model.predict_proba(X)[0, 1]

        # The comparison yields a numpy.bool_; convert so the return value
        # matches the annotated built-in bool.
        return bool(prob >= self.threshold)
