diff options
| -rw-r--r-- | app/hallucination_filter.py | 84 | ||||
| -rw-r--r-- | app/stt.py | 2 | ||||
| -rw-r--r-- | config.yaml | 4 | ||||
| -rw-r--r-- | train_hallucination_filter.py | 99 | ||||
| -rw-r--r-- | ui/index.html | 2 | ||||
| -rw-r--r-- | ui/renderer.js | 6 |
6 files changed, 113 insertions, 84 deletions
diff --git a/app/hallucination_filter.py b/app/hallucination_filter.py index fa7b16a..1a80e62 100644 --- a/app/hallucination_filter.py +++ b/app/hallucination_filter.py @@ -11,36 +11,39 @@ import sys def count_syllables(word): """Count syllables in a word using pronouncing library with regex fallback.""" phones = pronouncing.phones_for_word(word.lower()) + if len(phones) == 0: + return 0 return pronouncing.syllable_count(phones[0]) + def text_syllable_count(text): """Count total syllables in text.""" words = re.findall(r'\b\w+\b', text) return sum(count_syllables(word) for word in words) + class HallucinationFilter: """Filter for detecting hallucinated segments in speech-to-text output.""" - def __init__(self, model_path: Path = None): + def __init__(self, cfg, model_path: Path = None): """ Initialize the hallucination filter. Args: model_path: Optional path to the model file. If not provided, uses the default path. """ + self.cfg = cfg self.model = None self.threshold = None self.features = None # Get the project root directory app_root = Path(__file__).resolve().parent project_root = app_root.parent - # Use provided path or default - if model_path is None: - model_path = project_root / "Models" / "thankyou_filter_gb.pkl" + model_path = project_root / "Models" / "thankyou_filter_gb.pkl" # Try to load the model - log_err(f"Loading hallucination filter") + log(f"Loading hallucination filter") bundle = joblib.load(model_path) self.model = bundle["model"] self.threshold = bundle["threshold"] self.features = bundle["features"] - log_err(f"Loaded hallucination filter model from {model_path}") + log(f"Loaded hallucination filter model from {model_path}") def is_hallucination(self, segment) -> bool: """ Check if a segment is likely a hallucination. @@ -51,25 +54,50 @@ class HallucinationFilter: Returns: bool: True if the segment is likely a hallucination, False otherwise. """ - # Calculate text-based features - text = getattr(segment, 'text', '') - duration = segment.audio_len_s - raw_duration = segment.end_ts - segment.start_ts - n_syllables = text_syllable_count(text) - sps = n_syllables / duration - raw_sps = n_syllables / raw_duration - duration_ratio = raw_duration / duration - X = pd.DataFrame([[ - segment.avg_logprob, - segment.no_speech_prob, - segment.compression_ratio, - np.log1p(duration), - np.log1p(sps), - np.log1p(raw_duration), - np.log1p(raw_sps), - duration_ratio, - segment.avg_logprob * duration - ]], columns=self.features) - # Get probability - prob = self.model.predict_proba(X)[0, 1] - return prob >= self.threshold + s = segment # Brevity + + if s.no_speech_prob == 0: + # no_speech is not available. Use fancy classifier trained on my + # speech data. + text = s.transcript + duration = s.audio_len_s + raw_duration = s.end_ts - s.start_ts + n_syllables = text_syllable_count(text) + sps = n_syllables / duration + raw_sps = n_syllables / raw_duration + duration_ratio = raw_duration / duration + X = pd.DataFrame([[ + s.avg_logprob, + s.no_speech_prob, + s.compression_ratio, + np.log1p(duration), + np.log1p(sps), + np.log1p(raw_duration), + np.log1p(raw_sps), + duration_ratio, + s.avg_logprob * duration + ]], columns=self.features) + # Get probability + prob = self.model.predict_proba(X)[0, 1] + return prob >= self.threshold + + # If no_speech is set, use simpler filter. + if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5: + if self.cfg["enable_debug_mode"]: + print(f"Drop probable hallucination (case 1) " + + f"(text='{s.text}', " + + f"no_speech_prob={s.no_speech_prob}, " + + f"avg_logprob={s.avg_logprob})", file=sys.stderr) + return True + # Another touchup targeted at the vexatious "thanks for watching!" + # hallucination. This triggers a lot when listening to + # instrumental/electronic music. + if s.no_speech_prob > 0.15 and s.avg_logprob < -0.7: + if self.cfg["enable_debug_mode"]: + print(f"Drop probable hallucination (case 2) " + + f"(text='{s.text}', " + + f"no_speech_prob={s.no_speech_prob}, " + + f"avg_logprob={s.avg_logprob})", file=sys.stderr) + return True + return False + @@ -485,7 +485,7 @@ class Whisper: self.collector = collector self.model = None self.cfg = cfg - self.hallucination_filter = HallucinationFilter() + self.hallucination_filter = HallucinationFilter(cfg) self.segment_logger = segment_logger model_str = cfg["model"] diff --git a/config.yaml b/config.yaml index 9cec4a3..519457b 100644 --- a/config.yaml +++ b/config.yaml @@ -1,8 +1,8 @@ compute_type: float16 language: english -model: turbo +model: distil-large-v3.5 microphone: 4 -user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm. +user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. keybind: ctrl+alt+x button_hand: right button_type: b diff --git a/train_hallucination_filter.py b/train_hallucination_filter.py index dc3ce36..9db867f 100644 --- a/train_hallucination_filter.py +++ b/train_hallucination_filter.py @@ -19,6 +19,8 @@ from sklearn.model_selection import StratifiedKFold, cross_val_predict def count_syllables(word): """Count syllables in a word using pronouncing library with regex fallback.""" phones = pronouncing.phones_for_word(word.lower()) + if len(phones) == 0: + return 0 return pronouncing.syllable_count(phones[0]) def text_syllable_count(text): @@ -36,55 +38,52 @@ def load_segments(log_dir): for file in files: if not file.endswith('.json'): continue - try: - with open(os.path.join(root, file), 'r') as f: - data = json.load(f) - - for segment in data.get('segments', []): - if 'duration_sanity' not in segment: - continue - - # Extract all available features - text = segment.get('text', '') - duration = segment.get('duration_sanity', 0) - - # Calculate raw duration from timestamps - start_ts = segment.get('start_ts', 0) - end_ts = segment.get('end_ts', 0) - raw_duration = end_ts - start_ts - - seg_data = { - 'avg_logprob': segment.get('avg_logprob', 0), - 'no_speech_prob': segment.get('no_speech_prob', 0), - 'duration_sanity': duration, - 'raw_duration': raw_duration, - 'compression_ratio': segment.get('compression_ratio', 1), - 'text': text - } - - # Add speech rate features - n_syllables = text_syllable_count(text) - seg_data['sps'] = n_syllables / duration - seg_data['log_sps'] = np.log1p(seg_data['sps']) - seg_data['raw_sps'] = n_syllables / raw_duration - seg_data['log_raw_sps'] = np.log1p(seg_data['raw_sps']) - - # Add derived features - seg_data['log_duration'] = np.log1p(duration) - seg_data['logprob_duration_interaction'] = seg_data['avg_logprob'] * duration - seg_data['log_raw_duration'] = np.log1p(raw_duration) - seg_data['duration_ratio'] = raw_duration / duration if duration > 0 else 1.0 - - # Deduplicate: skip if this exact metadata already seen - key = tuple(sorted(seg_data.items())) - if key in seen: - num_dupes += 1 - continue - seen.add(key) - - segments.append(seg_data) - except Exception as e: - print(f"Error loading {file}: {e}") + with open(os.path.join(root, file), 'r') as f: + data = json.load(f) + + for segment in data["segments"]: + if 'duration_sanity' not in segment: + continue + + # Extract all available features + text = segment["text"] + duration = segment["duration_sanity"] + + # Calculate raw duration from timestamps + start_ts = segment["start_ts"] + end_ts = segment["end_ts"] + raw_duration = end_ts - start_ts + + seg_data = { + 'avg_logprob': segment["avg_logprob"], + 'no_speech_prob': segment["no_speech_prob"], + 'duration_sanity': duration, + 'raw_duration': raw_duration, + 'compression_ratio': segment["compression_ratio"], + 'text': text + } + + # Add speech rate features + n_syllables = text_syllable_count(text) + seg_data['sps'] = n_syllables / duration + seg_data['log_sps'] = np.log1p(seg_data['sps']) + seg_data['raw_sps'] = n_syllables / raw_duration + seg_data['log_raw_sps'] = np.log1p(seg_data['raw_sps']) + + # Add derived features + seg_data['log_duration'] = np.log1p(duration) + seg_data['logprob_duration_interaction'] = seg_data['avg_logprob'] * duration + seg_data['log_raw_duration'] = np.log1p(raw_duration) + seg_data['duration_ratio'] = raw_duration / duration if duration > 0 else 1.0 + + # Deduplicate: skip if this exact metadata already seen + key = tuple(sorted(seg_data.items())) + if key in seen: + num_dupes += 1 + continue + seen.add(key) + + segments.append(seg_data) print(f"Skipped {num_dupes} duplicate segments") return pd.DataFrame(segments) @@ -112,7 +111,7 @@ def log_seed_data(seeds_df, seed_type, label_desc): def main(): # Find logs directory log_dir = None - for pattern in ["ui/dist/logs", "logs", "ui/dist/*/logs", "ui/dist/*/*/logs", "ui/dist/*/*/*/logs"]: + for pattern in ["ui/dist/win-unpacked/resources/logs"]: paths = list(Path(".").glob(pattern)) if paths: log_dir = str(paths[0]) diff --git a/ui/index.html b/ui/index.html index 70eaa68..1828b90 100644 --- a/ui/index.html +++ b/ui/index.html @@ -32,8 +32,8 @@ <option value="base">base</option> <option value="small">small</option> <option value="medium">medium</option> + <option value="distil-large-v3.5">distilled large 3.5</option> <option value="large">large</option> - <option value="turbo">turbo</option> </select> </div> <div> diff --git a/ui/renderer.js b/ui/renderer.js index c2df4d2..05f1866 100644 --- a/ui/renderer.js +++ b/ui/renderer.js @@ -88,7 +88,8 @@ class LoadingOverlay { this.overlay.classList.add('hidden'); // Restore original states of form inputs and buttons const leftPanel = this.overlay.parentElement; - const inputs = leftPanel.querySelectorAll('input, select, textarea, button'); + const inputs = leftPanel.querySelectorAll( + 'input, select, textarea, button'); inputs.forEach(input => { // Restore original disabled state input.disabled = this.originalStates.get(input) || false; @@ -107,7 +108,8 @@ function showStatus(message, type = 'info') { statusEl.textContent = message; // Remove all status classes - const statusClasses = ['hidden', 'bg-green-100', 'bg-red-100', 'bg-blue-100', 'text-green-800', 'text-red-800', 'text-blue-800']; + const statusClasses = ['hidden', 'bg-green-100', 'bg-red-100', + 'bg-blue-100', 'text-green-800', 'text-red-800', 'text-blue-800']; statusEl.classList.remove(...statusClasses); // Add appropriate classes based on type |
