summary | refs | log | tree | commit | diff | stats
path: root/train_hallucination_filter.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_hallucination_filter.py')
-rw-r--r-- train_hallucination_filter.py | 99
1 file changed, 49 insertions, 50 deletions
diff --git a/train_hallucination_filter.py b/train_hallucination_filter.py
index dc3ce36..9db867f 100644
--- a/train_hallucination_filter.py
+++ b/train_hallucination_filter.py
@@ -19,6 +19,8 @@ from sklearn.model_selection import StratifiedKFold, cross_val_predict
def count_syllables(word):
"""Count syllables in a word using pronouncing library with regex fallback."""
phones = pronouncing.phones_for_word(word.lower())
+ if len(phones) == 0:
+ return 0
return pronouncing.syllable_count(phones[0])
def text_syllable_count(text):
@@ -36,55 +38,52 @@ def load_segments(log_dir):
for file in files:
if not file.endswith('.json'):
continue
- try:
- with open(os.path.join(root, file), 'r') as f:
- data = json.load(f)
-
- for segment in data.get('segments', []):
- if 'duration_sanity' not in segment:
- continue
-
- # Extract all available features
- text = segment.get('text', '')
- duration = segment.get('duration_sanity', 0)
-
- # Calculate raw duration from timestamps
- start_ts = segment.get('start_ts', 0)
- end_ts = segment.get('end_ts', 0)
- raw_duration = end_ts - start_ts
-
- seg_data = {
- 'avg_logprob': segment.get('avg_logprob', 0),
- 'no_speech_prob': segment.get('no_speech_prob', 0),
- 'duration_sanity': duration,
- 'raw_duration': raw_duration,
- 'compression_ratio': segment.get('compression_ratio', 1),
- 'text': text
- }
-
- # Add speech rate features
- n_syllables = text_syllable_count(text)
- seg_data['sps'] = n_syllables / duration
- seg_data['log_sps'] = np.log1p(seg_data['sps'])
- seg_data['raw_sps'] = n_syllables / raw_duration
- seg_data['log_raw_sps'] = np.log1p(seg_data['raw_sps'])
-
- # Add derived features
- seg_data['log_duration'] = np.log1p(duration)
- seg_data['logprob_duration_interaction'] = seg_data['avg_logprob'] * duration
- seg_data['log_raw_duration'] = np.log1p(raw_duration)
- seg_data['duration_ratio'] = raw_duration / duration if duration > 0 else 1.0
-
- # Deduplicate: skip if this exact metadata already seen
- key = tuple(sorted(seg_data.items()))
- if key in seen:
- num_dupes += 1
- continue
- seen.add(key)
-
- segments.append(seg_data)
- except Exception as e:
- print(f"Error loading {file}: {e}")
+ with open(os.path.join(root, file), 'r') as f:
+ data = json.load(f)
+
+ for segment in data["segments"]:
+ if 'duration_sanity' not in segment:
+ continue
+
+ # Extract all available features
+ text = segment["text"]
+ duration = segment["duration_sanity"]
+
+ # Calculate raw duration from timestamps
+ start_ts = segment["start_ts"]
+ end_ts = segment["end_ts"]
+ raw_duration = end_ts - start_ts
+
+ seg_data = {
+ 'avg_logprob': segment["avg_logprob"],
+ 'no_speech_prob': segment["no_speech_prob"],
+ 'duration_sanity': duration,
+ 'raw_duration': raw_duration,
+ 'compression_ratio': segment["compression_ratio"],
+ 'text': text
+ }
+
+ # Add speech rate features
+ n_syllables = text_syllable_count(text)
+ seg_data['sps'] = n_syllables / duration
+ seg_data['log_sps'] = np.log1p(seg_data['sps'])
+ seg_data['raw_sps'] = n_syllables / raw_duration
+ seg_data['log_raw_sps'] = np.log1p(seg_data['raw_sps'])
+
+ # Add derived features
+ seg_data['log_duration'] = np.log1p(duration)
+ seg_data['logprob_duration_interaction'] = seg_data['avg_logprob'] * duration
+ seg_data['log_raw_duration'] = np.log1p(raw_duration)
+ seg_data['duration_ratio'] = raw_duration / duration if duration > 0 else 1.0
+
+ # Deduplicate: skip if this exact metadata already seen
+ key = tuple(sorted(seg_data.items()))
+ if key in seen:
+ num_dupes += 1
+ continue
+ seen.add(key)
+
+ segments.append(seg_data)
print(f"Skipped {num_dupes} duplicate segments")
return pd.DataFrame(segments)
@@ -112,7 +111,7 @@ def log_seed_data(seeds_df, seed_type, label_desc):
def main():
# Find logs directory
log_dir = None
- for pattern in ["ui/dist/logs", "logs", "ui/dist/*/logs", "ui/dist/*/*/logs", "ui/dist/*/*/*/logs"]:
+ for pattern in ["ui/dist/win-unpacked/resources/logs"]:
paths = list(Path(".").glob(pattern))
if paths:
log_dir = str(paths[0])