diff options
Diffstat (limited to 'app')
| -rw-r--r-- | app/hi.py | 4 | ||||
| -rw-r--r-- | app/profanity_filter.py | 43 | ||||
| -rw-r--r-- | app/stt.py | 151 |
3 files changed, 173 insertions, 25 deletions
@@ -1,5 +1,6 @@ import app_config import argparse +import io from math import floor, ceil import msvcrt import os @@ -11,6 +12,9 @@ import sys import threading import time +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') + TESTS_ENABLED = True # 0 = quiet, 1 = verbose, 2 = very verbose diff --git a/app/profanity_filter.py b/app/profanity_filter.py new file mode 100644 index 0000000..b8c84ed --- /dev/null +++ b/app/profanity_filter.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 + +class ProfanityFilter: + def __init__(self, en_path: str): + self.en_path = en_path + self.en_profanity = set() + + def load(self): + with open(self.en_path, 'r') as f: + for line in f: + self.en_profanity.add(line.strip()) + + def filter(self, line: str, language_code: str = "en") -> str: + filtered = "" + + if language_code not in {"en"}: + raise ValueError(f"Language code \"{language_code}\" is " + + "unsupported by the profanity filter") + + # Translation table converting vowels to asterisks. + vowel_to_asterisk = str.maketrans('aeiouAEIOU', '**********') + + result = [] + for word in line.split(): + word_clean = word.lower() + # Filter out non-alphabet characters from the word. + word_clean = ''.join([char for char in word_clean if char.isalpha()]) + if word_clean in self.en_profanity: + result.append(word.translate(vowel_to_asterisk)) + else: + result.append(word) + + return " ".join(result) + +if __name__ == "__main__": + en_path = "/mnt/d/vrc/TaSTT/GUI/Profanity/Profanity/en" + p = ProfanityFilter(en_path) + p.load() + assert(p.filter("fuck") == "f*ck") + assert(p.filter("fuck!") == "f*ck!") + assert(p.filter("fuck shit") == "f*ck sh*t") + assert(p.filter("fuck shit this should not be filtered") == "f*ck sh*t this should not be filtered") + assert(p.filter("ASS") == "*SS") @@ -3,6 +3,12 @@ from faster_whisper import WhisperModel import langcodes import numpy as np import os +try: + from profanity_filter import ProfanityFilter + PROFANITY_FILTER_AVAILABLE = True +except ImportError: + PROFANITY_FILTER_AVAILABLE = False + print("Warning: profanity_filter module not available", file=sys.stderr) import pyaudio from pydub import AudioSegment from shared_thread_data import SharedThreadData @@ -12,7 +18,6 @@ import time import typing import wave - APP_ROOT = os.path.dirname(os.path.abspath(__file__)) PROJECT_ROOT = os.path.dirname(APP_ROOT) @@ -297,21 +302,19 @@ class AudioSegmenter: max_speech_s=5): self.min_silence_ms = min_silence_ms self.max_speech_s = max_speech_s - + # Load Silero VAD model self.model = load_silero_vad() - + self.vad_threshold = 0.3 self.min_silence_duration_ms = min_silence_ms self.max_speech_duration_s = max_speech_s - - self.speech_pad_ms = 300 def segmentAudio(self, audio: bytes): # Convert audio bytes to numpy array expected by silero-vad audio_array = np.frombuffer(audio, dtype=np.int16).flatten().astype(np.float32) / 32768.0 - + # Get speech timestamps using silero-vad # Note: silero-vad expects sample rate of 16000 Hz which matches AudioStream.FPS speech_timestamps = get_speech_timestamps( @@ -323,7 +326,7 @@ class AudioSegmenter: max_speech_duration_s=self.max_speech_duration_s, return_seconds=False # We want frame indices, not seconds ) - + return speech_timestamps # Returns the stable cutoff (if any) and whether there are any segments. @@ -399,27 +402,25 @@ class Whisper: self.model = None self.cfg = cfg - abspath = os.path.abspath(__file__) - my_dir = os.path.dirname(abspath) - parent_dir = os.path.dirname(my_dir) - model_str = cfg["model"] - model_root = os.path.join(parent_dir, "Models", + model_root = os.path.join(PROJECT_ROOT, "Models", os.path.normpath(model_str)) if cfg["enable_debug_mode"]: print(f"Model {cfg['model']} will be saved to {model_root}", file=sys.stderr) model_device = "cuda" + compute_type = cfg["compute_type"] if cfg["use_cpu"]: model_device = "cpu" + compute_type = "int8" already_downloaded = os.path.exists(model_root) self.model = WhisperModel(model_str, device = model_device, device_index = cfg["gpu_idx"], - compute_type = cfg["compute_type"], + compute_type = compute_type, download_root = model_root, local_files_only = already_downloaded) @@ -436,14 +437,14 @@ class Whisper: def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: frames = self.collector.getAudio() - + # Convert audio to float32 audio = np.frombuffer(frames, dtype=np.int16).flatten().astype(np.float32) / 32768.0 # Build context-aware prompt prompt = self._build_prompt() - + t0 = time.time() segments, info = self.model.transcribe( audio, @@ -452,12 +453,9 @@ class Whisper: temperature=0.0, without_timestamps = False, initial_prompt=prompt, - beam_size=5, - best_of=5, - condition_on_previous_text=True, - compression_ratio_threshold=2.4, - log_prob_threshold=-1.0, - no_speech_threshold=0.6 + beam_size=self.cfg.get("beam_size", 5), + best_of=self.cfg.get("best_of", 5), + condition_on_previous_text=True ) res = [] for s in segments: @@ -562,21 +560,21 @@ class VadCommitter: latency_s = self.collector.now() - self.collector.begin() duration_s = stable_cutoff / AudioStream.FPS start_ts = self.collector.begin() - + # Get the filtered audio first, then extract the portion we need filtered_audio = self.collector.getAudio() commit_audio = filtered_audio[:stable_cutoff * AudioStream.FRAME_SZ] - + # Now drop the prefix from the collector self.collector.dropAudioPrefixByFrames(stable_cutoff) segments = self.whisper.transcribe(commit_audio) delta = ''.join(s.transcript for s in segments) - + # Update whisper's context with the committed text if delta.strip(): self.whisper.update_context(delta.strip()) - + audio = self.collector.getAudio() if self.cfg["enable_debug_mode"]: for s in segments: @@ -608,6 +606,88 @@ class VadCommitter: duration_s=duration_s, start_ts=start_ts) + +class StreamingPlugin: + def __init__(self): + pass + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + return commit + + def stop(self): + pass + + +class LowercasePlugin(StreamingPlugin): + def __init__(self, cfg): + self.cfg = cfg + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + if self.cfg["enable_lowercase_filter"]: + commit.delta = commit.delta.lower() + commit.preview = commit.preview.lower() + return commit + + +class UppercasePlugin(StreamingPlugin): + def __init__(self, cfg): + self.cfg = cfg + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + if self.cfg["enable_uppercase_filter"]: + commit.delta = commit.delta.upper() + commit.preview = commit.preview.upper() + return commit + + +class ProfanityPlugin(StreamingPlugin): + def __init__(self, cfg): + self.cfg = cfg + self.filter = None + if PROFANITY_FILTER_AVAILABLE and cfg["enable_profanity_filter"]: + en_profanity_path = os.path.join(PROJECT_ROOT, "Third_Party/Profanity/en") + try: + self.filter = ProfanityFilter(en_profanity_path) + self.filter.load() + except Exception as e: + print(f"Warning: Could not load profanity filter: {e}", file=sys.stderr) + self.filter = None + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + if self.cfg["enable_profanity_filter"] and self.filter: + commit.delta = self.filter.filter(commit.delta) + commit.preview = self.filter.filter(commit.preview) + return commit + + +class PresentationFilter: + def __init__(self): + pass + + def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]: + return transcript, preview + + def stop(self): + pass + + +class TrailingPeriodFilter(PresentationFilter): + def __init__(self, cfg): + self.cfg = cfg + + def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]: + if self.cfg["remove_trailing_period"]: + def _remove_trailing_period(s: str) -> str: + if len(s) > 0 and s[-1] == '.' and not s.endswith("..."): + s = s[0:len(s)-1] + return s + if len(preview) == 0: + transcript = _remove_trailing_period(transcript) + else: + preview = _remove_trailing_period(preview) + return transcript, preview + + def transcriptionThread(shared_data: SharedThreadData): last_stable_commit = None @@ -621,6 +701,17 @@ def transcriptionThread(shared_data: SharedThreadData): max_speech_s=shared_data.cfg["max_speech_duration_s"]) committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter) + plugins = [] + # plugins.append(TranslationPlugin(shared_data.cfg)) # Not implemented yet + plugins.append(UppercasePlugin(shared_data.cfg)) + plugins.append(LowercasePlugin(shared_data.cfg)) + plugins.append(ProfanityPlugin(shared_data.cfg)) + # plugins.append(UwuPlugin(shared_data.cfg)) # Not implemented yet + # plugins.append(BrowserSource(shared_data.cfg)) # Not implemented yet + + filters = [] + filters.append(TrailingPeriodFilter(shared_data.cfg)) + transcript = "" preview = "" @@ -633,6 +724,9 @@ def transcriptionThread(shared_data: SharedThreadData): commit = committer.getDelta() + for plugin in plugins: + commit = plugin.transform(commit) + if len(commit.delta) > 0 or len(commit.preview) > 0: # Avoid re-sending text after long pauses if shared_data.cfg["reset_after_silence_s"] > 0: @@ -664,6 +758,9 @@ def transcriptionThread(shared_data: SharedThreadData): transcript = join_segments(transcript, commit.delta) preview = commit.preview + for filt in filters: + transcript, preview = filt.transform(transcript, preview) + try: print(f"Transcript: {transcript}", flush=True) except UnicodeEncodeError: @@ -691,4 +788,8 @@ def transcriptionThread(shared_data: SharedThreadData): (not commit.delta.endswith(' ')) and \ (not commit.preview.startswith(' ')): commit.preview = ' ' + commit.preview + for plugin in plugins: + plugin.stop() + for filt in filters: + filt.stop() |
