diff options
| author | yum <yum.food.vr@gmail.com> | 2025-09-03 16:07:40 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2025-09-03 16:07:40 -0700 |
| commit | bce085367e546e5801a41f55f2d5e84e12cc8607 (patch) | |
| tree | 2a2232b237be570d594c31a0226737ea9e53bed5 /train_hallucination_filter.py | |
| parent | 6815848fb8ed06b59b6d7e57096143f1f840e7db (diff) | |
Drop turbo; use old logic when no_speech ts available
Diffstat (limited to 'train_hallucination_filter.py')
| -rw-r--r-- | train_hallucination_filter.py | 99 |
1 files changed, 49 insertions, 50 deletions
diff --git a/train_hallucination_filter.py b/train_hallucination_filter.py index dc3ce36..9db867f 100644 --- a/train_hallucination_filter.py +++ b/train_hallucination_filter.py @@ -19,6 +19,8 @@ from sklearn.model_selection import StratifiedKFold, cross_val_predict def count_syllables(word): """Count syllables in a word using pronouncing library with regex fallback.""" phones = pronouncing.phones_for_word(word.lower()) + if len(phones) == 0: + return 0 return pronouncing.syllable_count(phones[0]) def text_syllable_count(text): @@ -36,55 +38,52 @@ def load_segments(log_dir): for file in files: if not file.endswith('.json'): continue - try: - with open(os.path.join(root, file), 'r') as f: - data = json.load(f) - - for segment in data.get('segments', []): - if 'duration_sanity' not in segment: - continue - - # Extract all available features - text = segment.get('text', '') - duration = segment.get('duration_sanity', 0) - - # Calculate raw duration from timestamps - start_ts = segment.get('start_ts', 0) - end_ts = segment.get('end_ts', 0) - raw_duration = end_ts - start_ts - - seg_data = { - 'avg_logprob': segment.get('avg_logprob', 0), - 'no_speech_prob': segment.get('no_speech_prob', 0), - 'duration_sanity': duration, - 'raw_duration': raw_duration, - 'compression_ratio': segment.get('compression_ratio', 1), - 'text': text - } - - # Add speech rate features - n_syllables = text_syllable_count(text) - seg_data['sps'] = n_syllables / duration - seg_data['log_sps'] = np.log1p(seg_data['sps']) - seg_data['raw_sps'] = n_syllables / raw_duration - seg_data['log_raw_sps'] = np.log1p(seg_data['raw_sps']) - - # Add derived features - seg_data['log_duration'] = np.log1p(duration) - seg_data['logprob_duration_interaction'] = seg_data['avg_logprob'] * duration - seg_data['log_raw_duration'] = np.log1p(raw_duration) - seg_data['duration_ratio'] = raw_duration / duration if duration > 0 else 1.0 - - # Deduplicate: skip if this exact metadata already seen - key = tuple(sorted(seg_data.items())) - if key in seen: - num_dupes += 1 - continue - seen.add(key) - - segments.append(seg_data) - except Exception as e: - print(f"Error loading {file}: {e}") + with open(os.path.join(root, file), 'r') as f: + data = json.load(f) + + for segment in data["segments"]: + if 'duration_sanity' not in segment: + continue + + # Extract all available features + text = segment["text"] + duration = segment["duration_sanity"] + + # Calculate raw duration from timestamps + start_ts = segment["start_ts"] + end_ts = segment["end_ts"] + raw_duration = end_ts - start_ts + + seg_data = { + 'avg_logprob': segment["avg_logprob"], + 'no_speech_prob': segment["no_speech_prob"], + 'duration_sanity': duration, + 'raw_duration': raw_duration, + 'compression_ratio': segment["compression_ratio"], + 'text': text + } + + # Add speech rate features + n_syllables = text_syllable_count(text) + seg_data['sps'] = n_syllables / duration + seg_data['log_sps'] = np.log1p(seg_data['sps']) + seg_data['raw_sps'] = n_syllables / raw_duration + seg_data['log_raw_sps'] = np.log1p(seg_data['raw_sps']) + + # Add derived features + seg_data['log_duration'] = np.log1p(duration) + seg_data['logprob_duration_interaction'] = seg_data['avg_logprob'] * duration + seg_data['log_raw_duration'] = np.log1p(raw_duration) + seg_data['duration_ratio'] = raw_duration / duration if duration > 0 else 1.0 + + # Deduplicate: skip if this exact metadata already seen + key = tuple(sorted(seg_data.items())) + if key in seen: + num_dupes += 1 + continue + seen.add(key) + + segments.append(seg_data) print(f"Skipped {num_dupes} duplicate segments") return pd.DataFrame(segments) @@ -112,7 +111,7 @@ def log_seed_data(seeds_df, seed_type, label_desc): def main(): # Find logs directory log_dir = None - for pattern in ["ui/dist/logs", "logs", "ui/dist/*/logs", "ui/dist/*/*/logs", "ui/dist/*/*/*/logs"]: + for pattern in ["ui/dist/win-unpacked/resources/logs"]: paths = list(Path(".").glob(pattern)) if paths: log_dir = str(paths[0]) |
