diff options
| author | yum <yum.food.vr@gmail.com> | 2025-05-30 13:32:36 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2025-05-30 13:34:23 -0700 |
| commit | 7fb9c575aea4d318e9c14b82174d1b323171b62b (patch) | |
| tree | 8f924a32def3bdc963be40e67879887cbac68f08 | |
| parent | e1b3f638a1ea448de9691f69eb62ebf4c3944c9f (diff) | |
More stuff
- fix unicode output from python terminal
- fix cpu inference
- add filters
- add beam search params to UI
- DRY up config definition in UI
| m--------- | Third_Party/Profanity | 0 | ||||
| -rw-r--r-- | app/hi.py | 4 | ||||
| -rw-r--r-- | app/profanity_filter.py | 43 | ||||
| -rw-r--r-- | app/stt.py | 151 | ||||
| -rw-r--r-- | config.yaml | 20 | ||||
| -rw-r--r-- | ui/config-schema.js | 49 | ||||
| -rw-r--r-- | ui/index.html | 52 | ||||
| -rw-r--r-- | ui/index.js | 49 | ||||
| -rw-r--r-- | ui/renderer.js | 31 |
9 files changed, 311 insertions, 88 deletions
diff --git a/Third_Party/Profanity b/Third_Party/Profanity new file mode 160000 +Subproject 5faf2ba42d7b1c0977169ec3611df25a3c08eb1 @@ -1,5 +1,6 @@ import app_config import argparse +import io from math import floor, ceil import msvcrt import os @@ -11,6 +12,9 @@ import sys import threading import time +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') + TESTS_ENABLED = True # 0 = quiet, 1 = verbose, 2 = very verbose diff --git a/app/profanity_filter.py b/app/profanity_filter.py new file mode 100644 index 0000000..b8c84ed --- /dev/null +++ b/app/profanity_filter.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 + +class ProfanityFilter: + def __init__(self, en_path: str): + self.en_path = en_path + self.en_profanity = set() + + def load(self): + with open(self.en_path, 'r') as f: + for line in f: + self.en_profanity.add(line.strip()) + + def filter(self, line: str, language_code: str = "en") -> str: + filtered = "" + + if language_code not in {"en"}: + raise ValueError(f"Language code \"{language_code}\" is " + + "unsupported by the profanity filter") + + # Translation table converting vowels to asterisks. + vowel_to_asterisk = str.maketrans('aeiouAEIOU', '**********') + + result = [] + for word in line.split(): + word_clean = word.lower() + # Filter out non-alphabet characters from the word. + word_clean = ''.join([char for char in word_clean if char.isalpha()]) + if word_clean in self.en_profanity: + result.append(word.translate(vowel_to_asterisk)) + else: + result.append(word) + + return " ".join(result) + +if __name__ == "__main__": + en_path = "/mnt/d/vrc/TaSTT/GUI/Profanity/Profanity/en" + p = ProfanityFilter(en_path) + p.load() + assert(p.filter("fuck") == "f*ck") + assert(p.filter("fuck!") == "f*ck!") + assert(p.filter("fuck shit") == "f*ck sh*t") + assert(p.filter("fuck shit this should not be filtered") == "f*ck sh*t this should not be filtered") + assert(p.filter("ASS") == "*SS") @@ -3,6 +3,12 @@ from faster_whisper import WhisperModel import langcodes import numpy as np import os +try: + from profanity_filter import ProfanityFilter + PROFANITY_FILTER_AVAILABLE = True +except ImportError: + PROFANITY_FILTER_AVAILABLE = False + print("Warning: profanity_filter module not available", file=sys.stderr) import pyaudio from pydub import AudioSegment from shared_thread_data import SharedThreadData @@ -12,7 +18,6 @@ import time import typing import wave - APP_ROOT = os.path.dirname(os.path.abspath(__file__)) PROJECT_ROOT = os.path.dirname(APP_ROOT) @@ -297,21 +302,19 @@ class AudioSegmenter: max_speech_s=5): self.min_silence_ms = min_silence_ms self.max_speech_s = max_speech_s - + # Load Silero VAD model self.model = load_silero_vad() - + self.vad_threshold = 0.3 self.min_silence_duration_ms = min_silence_ms self.max_speech_duration_s = max_speech_s - - self.speech_pad_ms = 300 def segmentAudio(self, audio: bytes): # Convert audio bytes to numpy array expected by silero-vad audio_array = np.frombuffer(audio, dtype=np.int16).flatten().astype(np.float32) / 32768.0 - + # Get speech timestamps using silero-vad # Note: silero-vad expects sample rate of 16000 Hz which matches AudioStream.FPS speech_timestamps = get_speech_timestamps( @@ -323,7 +326,7 @@ class AudioSegmenter: max_speech_duration_s=self.max_speech_duration_s, return_seconds=False # We want frame indices, not seconds ) - + return speech_timestamps # Returns the stable cutoff (if any) and whether there are any segments. @@ -399,27 +402,25 @@ class Whisper: self.model = None self.cfg = cfg - abspath = os.path.abspath(__file__) - my_dir = os.path.dirname(abspath) - parent_dir = os.path.dirname(my_dir) - model_str = cfg["model"] - model_root = os.path.join(parent_dir, "Models", + model_root = os.path.join(PROJECT_ROOT, "Models", os.path.normpath(model_str)) if cfg["enable_debug_mode"]: print(f"Model {cfg['model']} will be saved to {model_root}", file=sys.stderr) model_device = "cuda" + compute_type = cfg["compute_type"] if cfg["use_cpu"]: model_device = "cpu" + compute_type = "int8" already_downloaded = os.path.exists(model_root) self.model = WhisperModel(model_str, device = model_device, device_index = cfg["gpu_idx"], - compute_type = cfg["compute_type"], + compute_type = compute_type, download_root = model_root, local_files_only = already_downloaded) @@ -436,14 +437,14 @@ class Whisper: def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: frames = self.collector.getAudio() - + # Convert audio to float32 audio = np.frombuffer(frames, dtype=np.int16).flatten().astype(np.float32) / 32768.0 # Build context-aware prompt prompt = self._build_prompt() - + t0 = time.time() segments, info = self.model.transcribe( audio, @@ -452,12 +453,9 @@ class Whisper: temperature=0.0, without_timestamps = False, initial_prompt=prompt, - beam_size=5, - best_of=5, - condition_on_previous_text=True, - compression_ratio_threshold=2.4, - log_prob_threshold=-1.0, - no_speech_threshold=0.6 + beam_size=self.cfg.get("beam_size", 5), + best_of=self.cfg.get("best_of", 5), + condition_on_previous_text=True ) res = [] for s in segments: @@ -562,21 +560,21 @@ class VadCommitter: latency_s = self.collector.now() - self.collector.begin() duration_s = stable_cutoff / AudioStream.FPS start_ts = self.collector.begin() - + # Get the filtered audio first, then extract the portion we need filtered_audio = self.collector.getAudio() commit_audio = filtered_audio[:stable_cutoff * AudioStream.FRAME_SZ] - + # Now drop the prefix from the collector self.collector.dropAudioPrefixByFrames(stable_cutoff) segments = self.whisper.transcribe(commit_audio) delta = ''.join(s.transcript for s in segments) - + # Update whisper's context with the committed text if delta.strip(): self.whisper.update_context(delta.strip()) - + audio = self.collector.getAudio() if self.cfg["enable_debug_mode"]: for s in segments: @@ -608,6 +606,88 @@ class VadCommitter: duration_s=duration_s, start_ts=start_ts) + +class StreamingPlugin: + def __init__(self): + pass + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + return commit + + def stop(self): + pass + + +class LowercasePlugin(StreamingPlugin): + def __init__(self, cfg): + self.cfg = cfg + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + if self.cfg["enable_lowercase_filter"]: + commit.delta = commit.delta.lower() + commit.preview = commit.preview.lower() + return commit + + +class UppercasePlugin(StreamingPlugin): + def __init__(self, cfg): + self.cfg = cfg + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + if self.cfg["enable_uppercase_filter"]: + commit.delta = commit.delta.upper() + commit.preview = commit.preview.upper() + return commit + + +class ProfanityPlugin(StreamingPlugin): + def __init__(self, cfg): + self.cfg = cfg + self.filter = None + if PROFANITY_FILTER_AVAILABLE and cfg["enable_profanity_filter"]: + en_profanity_path = os.path.join(PROJECT_ROOT, "Third_Party/Profanity/en") + try: + self.filter = ProfanityFilter(en_profanity_path) + self.filter.load() + except Exception as e: + print(f"Warning: Could not load profanity filter: {e}", file=sys.stderr) + self.filter = None + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + if self.cfg["enable_profanity_filter"] and self.filter: + commit.delta = self.filter.filter(commit.delta) + commit.preview = self.filter.filter(commit.preview) + return commit + + +class PresentationFilter: + def __init__(self): + pass + + def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]: + return transcript, preview + + def stop(self): + pass + + +class TrailingPeriodFilter(PresentationFilter): + def __init__(self, cfg): + self.cfg = cfg + + def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]: + if self.cfg["remove_trailing_period"]: + def _remove_trailing_period(s: str) -> str: + if len(s) > 0 and s[-1] == '.' and not s.endswith("..."): + s = s[0:len(s)-1] + return s + if len(preview) == 0: + transcript = _remove_trailing_period(transcript) + else: + preview = _remove_trailing_period(preview) + return transcript, preview + + def transcriptionThread(shared_data: SharedThreadData): last_stable_commit = None @@ -621,6 +701,17 @@ def transcriptionThread(shared_data: SharedThreadData): max_speech_s=shared_data.cfg["max_speech_duration_s"]) committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter) + plugins = [] + # plugins.append(TranslationPlugin(shared_data.cfg)) # Not implemented yet + plugins.append(UppercasePlugin(shared_data.cfg)) + plugins.append(LowercasePlugin(shared_data.cfg)) + plugins.append(ProfanityPlugin(shared_data.cfg)) + # plugins.append(UwuPlugin(shared_data.cfg)) # Not implemented yet + # plugins.append(BrowserSource(shared_data.cfg)) # Not implemented yet + + filters = [] + filters.append(TrailingPeriodFilter(shared_data.cfg)) + transcript = "" preview = "" @@ -633,6 +724,9 @@ def transcriptionThread(shared_data: SharedThreadData): commit = committer.getDelta() + for plugin in plugins: + commit = plugin.transform(commit) + if len(commit.delta) > 0 or len(commit.preview) > 0: # Avoid re-sending text after long pauses if shared_data.cfg["reset_after_silence_s"] > 0: @@ -664,6 +758,9 @@ def transcriptionThread(shared_data: SharedThreadData): transcript = join_segments(transcript, commit.delta) preview = commit.preview + for filt in filters: + transcript, preview = filt.transform(transcript, preview) + try: print(f"Transcript: {transcript}", flush=True) except UnicodeEncodeError: @@ -691,4 +788,8 @@ def transcriptionThread(shared_data: SharedThreadData): (not commit.delta.endswith(' ')) and \ (not commit.preview.startswith(' ')): commit.preview = ' ' + commit.preview + for plugin in plugins: + plugin.stop() + for filt in filters: + filt.stop() diff --git a/config.yaml b/config.yaml index 5eec7a2..fea03bb 100644 --- a/config.yaml +++ b/config.yaml @@ -1,18 +1,24 @@ compute_type: float16 -enable_debug_mode: 0 -enable_previews: 1 -user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. -save_audio: 0 language: english +model: turbo +microphone: 2 +user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm. gpu_idx: 0 max_speech_duration_s: 10 min_silence_duration_ms: 250 -microphone: 0 -model: turbo reset_after_silence_s: 15 transcription_loop_delay_ms: 100 -use_cpu: 0 block_width: 2 num_blocks: 40 rows: 10 cols: 24 +beam_size: 5 +best_of: 5 +enable_debug_mode: 0 +enable_previews: 1 +save_audio: 0 +use_cpu: 0 +enable_lowercase_filter: 0 +enable_uppercase_filter: 0 +enable_profanity_filter: 0 +remove_trailing_period: 0 diff --git a/ui/config-schema.js b/ui/config-schema.js new file mode 100644 index 0000000..b1108ff --- /dev/null +++ b/ui/config-schema.js @@ -0,0 +1,49 @@ +// Shared configuration schema with types and defaults +const CONFIG_SCHEMA = { + // String fields + compute_type: { type: 'select', default: 'float16' }, + language: { type: 'select', default: 'english' }, + model: { type: 'select', default: 'turbo' }, + microphone: { type: 'number', default: 0 }, + user_prompt: { type: 'text', default: 'Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm.' }, + + // Number fields + gpu_idx: { type: 'number', default: 0 }, + max_speech_duration_s: { type: 'number', default: 10 }, + min_silence_duration_ms: { type: 'number', default: 250 }, + reset_after_silence_s: { type: 'number', default: 15 }, + transcription_loop_delay_ms: { type: 'number', default: 100 }, + block_width: { type: 'number', default: 2 }, + num_blocks: { type: 'number', default: 40 }, + rows: { type: 'number', default: 10 }, + cols: { type: 'number', default: 24 }, + beam_size: { type: 'number', default: 5 }, + best_of: { type: 'number', default: 5 }, + + // Boolean fields (stored as 1/0) + enable_debug_mode: { type: 'boolean', default: 0 }, + enable_previews: { type: 'boolean', default: 1 }, + save_audio: { type: 'boolean', default: 0 }, + use_cpu: { type: 'boolean', default: 0 }, + enable_lowercase_filter: { type: 'boolean', default: 0 }, + enable_uppercase_filter: { type: 'boolean', default: 0 }, + enable_profanity_filter: { type: 'boolean', default: 0 }, + remove_trailing_period: { type: 'boolean', default: 0 } +}; + +// Helper to extract just the default values +function getDefaultConfig() { + const defaults = {}; + for (const [key, schema] of Object.entries(CONFIG_SCHEMA)) { + defaults[key] = schema.default; + } + return defaults; +} + +// Export for both CommonJS (main process) and ES modules (renderer) +if (typeof module !== 'undefined' && module.exports) { + module.exports = { CONFIG_SCHEMA, getDefaultConfig }; +} else { + window.CONFIG_SCHEMA = CONFIG_SCHEMA; + window.getDefaultConfig = getDefaultConfig; +}
\ No newline at end of file diff --git a/ui/index.html b/ui/index.html index 90f78c1..97da3d2 100644 --- a/ui/index.html +++ b/ui/index.html @@ -10,9 +10,9 @@ <div class="container-fluid px-6 py-6 h-screen flex flex-col"> <div class="flex flex-1 gap-6 overflow-hidden"> <!-- Left Panel: Configuration Form --> - <div class="max-w-96 relative flex flex-col"> + <div class="max-w-96 relative flex flex-col overflow-hidden rounded-lg"> <!-- Loading Overlay --> - <div id="loading-overlay" class="absolute inset-0 bg-white bg-opacity-75 backdrop-blur-sm z-50 hidden flex items-center justify-center rounded-lg"> + <div id="loading-overlay" class="absolute inset-0 bg-white bg-opacity-75 backdrop-blur-sm z-50 hidden flex items-center justify-center"> <div class="text-center p-6"> <div class="animate-spin rounded-full h-12 w-12 border-b-2 border-blue-600 mx-auto mb-4"></div> <p class="text-gray-700 font-medium"></p> @@ -126,7 +126,7 @@ <h2 class="section-title">Transcription Settings</h2> <div> <label for="user_prompt" class="form-label"> - Custom Prompt + Prompt <span class="text-gray-500 text-xs block mt-1" title="Whisper is given this prompt before transcribing. It helps guide the transcription style. For example, you could improve the spelling of your friends' names with: 'My friends' names are Saoirse, Azariah, and Caoimhe.'"> (Hover for details) @@ -136,6 +136,28 @@ class="form-input h-20 resize-none" placeholder="My friends' names are Saoirse, Azariah, and Caoimhe."></textarea> </div> + <div class="grid grid-cols-2 gap-4 mt-4"> + <div> + <label for="beam_size" class="form-label"> + Beam size + <span class="text-gray-500 text-xs block mt-1" + title="Number of beams for beam search. Higher values may improve accuracy but increase compute time."> + (Search width) + </span> + </label> + <input type="number" id="beam_size" min="1" max="20" value="5" class="form-input"> + </div> + <div> + <label for="best_of" class="form-label"> + Best of + <span class="text-gray-500 text-xs block mt-1" + title="Number of candidates to generate when sampling. The best one will be selected."> + (Sampling candidates) + </span> + </label> + <input type="number" id="best_of" min="1" max="20" value="5" class="form-input"> + </div> + </div> </section> <!-- Performance Settings --> @@ -166,6 +188,29 @@ </div> </section> + <!-- Text Filters --> + <section class="config-section"> + <h2 class="section-title">Text Filters</h2> + <div class="space-y-3"> + <label for="enable_lowercase_filter" class="checkbox-label"> + <input type="checkbox" id="enable_lowercase_filter" class="mr-2"> + <span class="checkbox-text">Convert to lowercase</span> + </label> + <label for="enable_uppercase_filter" class="checkbox-label"> + <input type="checkbox" id="enable_uppercase_filter" class="mr-2"> + <span class="checkbox-text">Convert to uppercase</span> + </label> + <label for="enable_profanity_filter" class="checkbox-label"> + <input type="checkbox" id="enable_profanity_filter" class="mr-2"> + <span class="checkbox-text">Filter profanity</span> + </label> + <label for="remove_trailing_period" class="checkbox-label"> + <input type="checkbox" id="remove_trailing_period" class="mr-2"> + <span class="checkbox-text">Remove trailing period</span> + </label> + </div> + </section> + <!-- Display Settings --> <section class="config-section"> <h2 class="section-title">Custom Chatbox Settings</h2> @@ -240,6 +285,7 @@ </div> </div> + <script src="config-schema.js"></script> <script src="renderer.js"></script> </body> </html> diff --git a/ui/index.js b/ui/index.js index 2420ece..7717c92 100644 --- a/ui/index.js +++ b/ui/index.js @@ -4,6 +4,7 @@ const fs = require('node:fs').promises; const yaml = require('js-yaml'); const { spawn } = require('child_process'); const https = require('https'); +const { CONFIG_SCHEMA, getDefaultConfig } = require('./config-schema.js'); const APP_ROOT = path.join(__dirname, '..'); const CONFIG_PATH = path.join(APP_ROOT, 'config.yaml'); @@ -82,6 +83,14 @@ function downloadFile(url, outputPath) { }); } +function shouldFilterMessage(message) { + // Filter out pydub ffmpeg/avconv warning. It does not actually matter. + if (message.includes("Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work")) { + return true; + } + return false; +} + // Helper function to setup process event handlers function setupProcessHandlers(process) { process.stdout.on('data', (data) => { @@ -91,7 +100,9 @@ function setupProcessHandlers(process) { process.stderr.on('data', (data) => { const text = data.toString(); - sendPythonOutput(text.trimEnd(), 'stderr'); + if (!shouldFilterMessage(text)) { + sendPythonOutput(text.trimEnd(), 'stderr'); + } }); process.on('error', (error) => { @@ -137,7 +148,10 @@ function executePythonCommand(args, options = {}) { pythonProcess.stderr.on('data', (data) => { const text = data.toString(); stderr += text; - sendPythonOutput(text.trimEnd(), 'stderr'); + // Filter out specific warning messages + if (!shouldFilterMessage(text)) { + sendPythonOutput(text.trimEnd(), 'stderr'); + } }); pythonProcess.on('error', (error) => { @@ -171,27 +185,8 @@ function createWindow () { mainWindow.loadFile('index.html'); } -// Default configuration based on user's current config.yaml -const DEFAULT_CONFIG = { - compute_type: 'float16', - enable_debug_mode: 0, - enable_previews: 1, - user_prompt: 'Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc.', - save_audio: 0, - language: 'english', - gpu_idx: 0, - max_speech_duration_s: 10, - min_silence_duration_ms: 250, - microphone: 0, - model: 'turbo', - reset_after_silence_s: 15, - transcription_loop_delay_ms: 100, - use_cpu: 0, - block_width: 2, - num_blocks: 40, - rows: 10, - cols: 24 -}; +// Replace the DEFAULT_CONFIG constant with: +const DEFAULT_CONFIG = getDefaultConfig(); // IPC handlers ipcMain.handle('load-config', async () => { @@ -521,12 +516,12 @@ ipcMain.handle('start-process', async () => { }); ipcMain.handle('stop-process', async () => { - if (!runningProcess) { - throw new Error('No process is running'); - } - return new Promise((resolve) => { let forcefullyKilled = false; + + if (!runningProcess) { + resolve({ success: true, forcefullyKilled }); + } // Set up a timeout to force kill after 10 seconds const killTimeout = setTimeout(() => { diff --git a/ui/renderer.js b/ui/renderer.js index 201eef6..133a79b 100644 --- a/ui/renderer.js +++ b/ui/renderer.js @@ -1,29 +1,5 @@ -// Configuration and form field mappings -const CONFIG_FIELDS = { - // String fields - compute_type: { type: 'select', default: 'float16' }, - language: { type: 'select', default: 'english' }, - model: { type: 'select', default: 'turbo' }, - microphone: { type: 'number', default: 0 }, - user_prompt: { type: 'text', default: '' }, - - // Number fields - gpu_idx: { type: 'number', default: 0 }, - max_speech_duration_s: { type: 'number', default: 10 }, - min_silence_duration_ms: { type: 'number', default: 250 }, - reset_after_silence_s: { type: 'number', default: 15 }, - transcription_loop_delay_ms: { type: 'number', default: 100 }, - block_width: { type: 'number', default: 2 }, - num_blocks: { type: 'number', default: 40 }, - rows: { type: 'number', default: 10 }, - cols: { type: 'number', default: 24 }, - - // Boolean fields (stored as 1/0) - enable_debug_mode: { type: 'boolean', default: 0 }, - enable_previews: { type: 'boolean', default: 1 }, - save_audio: { type: 'boolean', default: 0 }, - use_cpu: { type: 'boolean', default: 0 } -}; +// Import configuration schema +const CONFIG_FIELDS = window.CONFIG_SCHEMA; // Button management system class ButtonManager { @@ -35,6 +11,9 @@ class ButtonManager { resetVenv: document.getElementById('reset-venv'), refreshMicrophones: document.getElementById('refresh-microphones') }; + + // Initialize button states on construction + this.setProcessStopped(); } setState(buttonName, disabled) { |
