summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- app/hallucination_filter.py 84
-rw-r--r-- app/stt.py 2
-rw-r--r-- config.yaml 4
-rw-r--r-- train_hallucination_filter.py 99
-rw-r--r-- ui/index.html 2
-rw-r--r-- ui/renderer.js 6
6 files changed, 113 insertions, 84 deletions
diff --git a/app/hallucination_filter.py b/app/hallucination_filter.py
index fa7b16a..1a80e62 100644
--- a/app/hallucination_filter.py
+++ b/app/hallucination_filter.py
@@ -11,36 +11,39 @@ import sys
def count_syllables(word):
"""Count syllables in a word using pronouncing library with regex fallback."""
phones = pronouncing.phones_for_word(word.lower())
+ if len(phones) == 0:
+ return 0
return pronouncing.syllable_count(phones[0])
+
def text_syllable_count(text):
"""Count total syllables in text."""
words = re.findall(r'\b\w+\b', text)
return sum(count_syllables(word) for word in words)
+
class HallucinationFilter:
"""Filter for detecting hallucinated segments in speech-to-text output."""
- def __init__(self, model_path: Path = None):
+ def __init__(self, cfg, model_path: Path = None):
"""
Initialize the hallucination filter.
Args:
model_path: Optional path to the model file. If not provided,
uses the default path.
"""
+ self.cfg = cfg
self.model = None
self.threshold = None
self.features = None
# Get the project root directory
app_root = Path(__file__).resolve().parent
project_root = app_root.parent
- # Use provided path or default
- if model_path is None:
- model_path = project_root / "Models" / "thankyou_filter_gb.pkl"
+ model_path = project_root / "Models" / "thankyou_filter_gb.pkl"
# Try to load the model
- log_err(f"Loading hallucination filter")
+ log(f"Loading hallucination filter")
bundle = joblib.load(model_path)
self.model = bundle["model"]
self.threshold = bundle["threshold"]
self.features = bundle["features"]
- log_err(f"Loaded hallucination filter model from {model_path}")
+ log(f"Loaded hallucination filter model from {model_path}")
def is_hallucination(self, segment) -> bool:
"""
Check if a segment is likely a hallucination.
@@ -51,25 +54,50 @@ class HallucinationFilter:
Returns:
bool: True if the segment is likely a hallucination, False otherwise.
"""
- # Calculate text-based features
- text = getattr(segment, 'text', '')
- duration = segment.audio_len_s
- raw_duration = segment.end_ts - segment.start_ts
- n_syllables = text_syllable_count(text)
- sps = n_syllables / duration
- raw_sps = n_syllables / raw_duration
- duration_ratio = raw_duration / duration
- X = pd.DataFrame([[
- segment.avg_logprob,
- segment.no_speech_prob,
- segment.compression_ratio,
- np.log1p(duration),
- np.log1p(sps),
- np.log1p(raw_duration),
- np.log1p(raw_sps),
- duration_ratio,
- segment.avg_logprob * duration
- ]], columns=self.features)
- # Get probability
- prob = self.model.predict_proba(X)[0, 1]
- return prob >= self.threshold
+ s = segment # Brevity
+
+ if s.no_speech_prob == 0:
+ # no_speech is not available. Use fancy classifier trained on my
+ # speech data.
+ text = s.transcript
+ duration = s.audio_len_s
+ raw_duration = s.end_ts - s.start_ts
+ n_syllables = text_syllable_count(text)
+ sps = n_syllables / duration
+ raw_sps = n_syllables / raw_duration
+ duration_ratio = raw_duration / duration
+ X = pd.DataFrame([[
+ s.avg_logprob,
+ s.no_speech_prob,
+ s.compression_ratio,
+ np.log1p(duration),
+ np.log1p(sps),
+ np.log1p(raw_duration),
+ np.log1p(raw_sps),
+ duration_ratio,
+ s.avg_logprob * duration
+ ]], columns=self.features)
+ # Get probability
+ prob = self.model.predict_proba(X)[0, 1]
+ return prob >= self.threshold
+
+ # If no_speech is set, use simpler filter.
+ if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5:
+ if self.cfg["enable_debug_mode"]:
+ print(f"Drop probable hallucination (case 1) " +
+ f"(text='{s.text}', " +
+ f"no_speech_prob={s.no_speech_prob}, " +
+ f"avg_logprob={s.avg_logprob})", file=sys.stderr)
+ return True
+ # Another touchup targeted at the vexatious "thanks for watching!"
+ # hallucination. This triggers a lot when listening to
+ # instrumental/electronic music.
+ if s.no_speech_prob > 0.15 and s.avg_logprob < -0.7:
+ if self.cfg["enable_debug_mode"]:
+ print(f"Drop probable hallucination (case 2) " +
+ f"(text='{s.text}', " +
+ f"no_speech_prob={s.no_speech_prob}, " +
+ f"avg_logprob={s.avg_logprob})", file=sys.stderr)
+ return True
+ return False
+
diff --git a/app/stt.py b/app/stt.py
index 9947bae..78da707 100644
--- a/app/stt.py
+++ b/app/stt.py
@@ -485,7 +485,7 @@ class Whisper:
self.collector = collector
self.model = None
self.cfg = cfg
- self.hallucination_filter = HallucinationFilter()
+ self.hallucination_filter = HallucinationFilter(cfg)
self.segment_logger = segment_logger
model_str = cfg["model"]
diff --git a/config.yaml b/config.yaml
index 9cec4a3..519457b 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,8 +1,8 @@
compute_type: float16
language: english
-model: turbo
+model: distil-large-v3.5
microphone: 4
-user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm.
+user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc.
keybind: ctrl+alt+x
button_hand: right
button_type: b
diff --git a/train_hallucination_filter.py b/train_hallucination_filter.py
index dc3ce36..9db867f 100644
--- a/train_hallucination_filter.py
+++ b/train_hallucination_filter.py
@@ -19,6 +19,8 @@ from sklearn.model_selection import StratifiedKFold, cross_val_predict
def count_syllables(word):
"""Count syllables in a word using pronouncing library with regex fallback."""
phones = pronouncing.phones_for_word(word.lower())
+ if len(phones) == 0:
+ return 0
return pronouncing.syllable_count(phones[0])
def text_syllable_count(text):
@@ -36,55 +38,52 @@ def load_segments(log_dir):
for file in files:
if not file.endswith('.json'):
continue
- try:
- with open(os.path.join(root, file), 'r') as f:
- data = json.load(f)
-
- for segment in data.get('segments', []):
- if 'duration_sanity' not in segment:
- continue
-
- # Extract all available features
- text = segment.get('text', '')
- duration = segment.get('duration_sanity', 0)
-
- # Calculate raw duration from timestamps
- start_ts = segment.get('start_ts', 0)
- end_ts = segment.get('end_ts', 0)
- raw_duration = end_ts - start_ts
-
- seg_data = {
- 'avg_logprob': segment.get('avg_logprob', 0),
- 'no_speech_prob': segment.get('no_speech_prob', 0),
- 'duration_sanity': duration,
- 'raw_duration': raw_duration,
- 'compression_ratio': segment.get('compression_ratio', 1),
- 'text': text
- }
-
- # Add speech rate features
- n_syllables = text_syllable_count(text)
- seg_data['sps'] = n_syllables / duration
- seg_data['log_sps'] = np.log1p(seg_data['sps'])
- seg_data['raw_sps'] = n_syllables / raw_duration
- seg_data['log_raw_sps'] = np.log1p(seg_data['raw_sps'])
-
- # Add derived features
- seg_data['log_duration'] = np.log1p(duration)
- seg_data['logprob_duration_interaction'] = seg_data['avg_logprob'] * duration
- seg_data['log_raw_duration'] = np.log1p(raw_duration)
- seg_data['duration_ratio'] = raw_duration / duration if duration > 0 else 1.0
-
- # Deduplicate: skip if this exact metadata already seen
- key = tuple(sorted(seg_data.items()))
- if key in seen:
- num_dupes += 1
- continue
- seen.add(key)
-
- segments.append(seg_data)
- except Exception as e:
- print(f"Error loading {file}: {e}")
+ with open(os.path.join(root, file), 'r') as f:
+ data = json.load(f)
+
+ for segment in data["segments"]:
+ if 'duration_sanity' not in segment:
+ continue
+
+ # Extract all available features
+ text = segment["text"]
+ duration = segment["duration_sanity"]
+
+ # Calculate raw duration from timestamps
+ start_ts = segment["start_ts"]
+ end_ts = segment["end_ts"]
+ raw_duration = end_ts - start_ts
+
+ seg_data = {
+ 'avg_logprob': segment["avg_logprob"],
+ 'no_speech_prob': segment["no_speech_prob"],
+ 'duration_sanity': duration,
+ 'raw_duration': raw_duration,
+ 'compression_ratio': segment["compression_ratio"],
+ 'text': text
+ }
+
+ # Add speech rate features
+ n_syllables = text_syllable_count(text)
+ seg_data['sps'] = n_syllables / duration
+ seg_data['log_sps'] = np.log1p(seg_data['sps'])
+ seg_data['raw_sps'] = n_syllables / raw_duration
+ seg_data['log_raw_sps'] = np.log1p(seg_data['raw_sps'])
+
+ # Add derived features
+ seg_data['log_duration'] = np.log1p(duration)
+ seg_data['logprob_duration_interaction'] = seg_data['avg_logprob'] * duration
+ seg_data['log_raw_duration'] = np.log1p(raw_duration)
+ seg_data['duration_ratio'] = raw_duration / duration if duration > 0 else 1.0
+
+ # Deduplicate: skip if this exact metadata already seen
+ key = tuple(sorted(seg_data.items()))
+ if key in seen:
+ num_dupes += 1
+ continue
+ seen.add(key)
+
+ segments.append(seg_data)
print(f"Skipped {num_dupes} duplicate segments")
return pd.DataFrame(segments)
@@ -112,7 +111,7 @@ def log_seed_data(seeds_df, seed_type, label_desc):
def main():
# Find logs directory
log_dir = None
- for pattern in ["ui/dist/logs", "logs", "ui/dist/*/logs", "ui/dist/*/*/logs", "ui/dist/*/*/*/logs"]:
+ for pattern in ["ui/dist/win-unpacked/resources/logs"]:
paths = list(Path(".").glob(pattern))
if paths:
log_dir = str(paths[0])
diff --git a/ui/index.html b/ui/index.html
index 70eaa68..1828b90 100644
--- a/ui/index.html
+++ b/ui/index.html
@@ -32,8 +32,8 @@
<option value="base">base</option>
<option value="small">small</option>
<option value="medium">medium</option>
+ <option value="distil-large-v3.5">distilled large 3.5</option>
<option value="large">large</option>
- <option value="turbo">turbo</option>
</select>
</div>
<div>
diff --git a/ui/renderer.js b/ui/renderer.js
index c2df4d2..05f1866 100644
--- a/ui/renderer.js
+++ b/ui/renderer.js
@@ -88,7 +88,8 @@ class LoadingOverlay {
this.overlay.classList.add('hidden');
// Restore original states of form inputs and buttons
const leftPanel = this.overlay.parentElement;
- const inputs = leftPanel.querySelectorAll('input, select, textarea, button');
+ const inputs = leftPanel.querySelectorAll(
+ 'input, select, textarea, button');
inputs.forEach(input => {
// Restore original disabled state
input.disabled = this.originalStates.get(input) || false;
@@ -107,7 +108,8 @@ function showStatus(message, type = 'info') {
statusEl.textContent = message;
// Remove all status classes
- const statusClasses = ['hidden', 'bg-green-100', 'bg-red-100', 'bg-blue-100', 'text-green-800', 'text-red-800', 'text-blue-800'];
+ const statusClasses = ['hidden', 'bg-green-100', 'bg-red-100',
+ 'bg-blue-100', 'text-green-800', 'text-red-800', 'text-blue-800'];
statusEl.classList.remove(...statusClasses);
// Add appropriate classes based on type