summaryrefslogtreecommitdiffstats
path: root/app
diff options
context:
space:
mode:
Diffstat (limited to 'app')
-rw-r--r--app/hi.py4
-rw-r--r--app/profanity_filter.py43
-rw-r--r--app/stt.py151
3 files changed, 173 insertions, 25 deletions
diff --git a/app/hi.py b/app/hi.py
index e6877ff..bab0fd4 100644
--- a/app/hi.py
+++ b/app/hi.py
@@ -1,5 +1,6 @@
import app_config
import argparse
+import io
from math import floor, ceil
import msvcrt
import os
@@ -11,6 +12,9 @@ import sys
import threading
import time
+sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
+sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
+
TESTS_ENABLED = True
# 0 = quiet, 1 = verbose, 2 = very verbose
diff --git a/app/profanity_filter.py b/app/profanity_filter.py
new file mode 100644
index 0000000..b8c84ed
--- /dev/null
+++ b/app/profanity_filter.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+
class ProfanityFilter:
    """Masks the vowels of known profane words with asterisks.

    The word list is a plain-text file with one word per line. Only English
    ("en") is supported. Call load() after construction to read the list.
    """

    # Translation table converting vowels to asterisks. It never changes, so
    # it is built once instead of on every filter() call.
    _VOWEL_TO_ASTERISK = str.maketrans('aeiouAEIOU', '**********')

    def __init__(self, en_path: str):
        """Remember the path of the English word list; nothing is read yet."""
        self.en_path = en_path
        self.en_profanity = set()

    def load(self):
        """Read the word list at `en_path` into `en_profanity`.

        Entries are stripped and lowercased so they can match the lowercased
        lookup done by filter(). Blank lines are skipped.

        Raises:
            OSError: if the file cannot be opened or read.
            UnicodeDecodeError: if the file is not valid UTF-8.
        """
        # Pin the encoding instead of relying on the locale default;
        # 'utf-8-sig' additionally tolerates a leading byte-order mark.
        with open(self.en_path, 'r', encoding='utf-8-sig') as f:
            for line in f:
                entry = line.strip().lower()
                if entry:
                    self.en_profanity.add(entry)

    def filter(self, line: str, language_code: str = "en") -> str:
        """Return `line` with the vowels of every profane word replaced by '*'.

        A word is a maximal run of non-whitespace characters. It is looked up
        lowercased and with non-alphabetic characters removed, so "Word!" is
        matched by the list entry "word". All whitespace in `line` (leading,
        trailing, repeated, tabs, newlines) is preserved exactly.

        Raises:
            ValueError: if `language_code` is not "en".
        """
        if language_code not in {"en"}:
            raise ValueError(f"Language code \"{language_code}\" is " +
                    "unsupported by the profanity filter")

        # Imported here so the class does not depend on a module-level import.
        import re

        def _mask(match):
            word = match.group(0)
            # Filter out non-alphabet characters before the lookup.
            word_clean = ''.join(ch for ch in word.lower() if ch.isalpha())
            if word_clean in self.en_profanity:
                return word.translate(self._VOWEL_TO_ASTERISK)
            return word

        # Substituting in place (rather than split + join) keeps the original
        # spacing of the input intact.
        return re.sub(r'\S+', _mask, line)
+
if __name__ == "__main__":
    import sys

    # Self-test. The word list path may be given as the first command-line
    # argument; without one, the original hard-coded location is used.
    en_path = "/mnt/d/vrc/TaSTT/GUI/Profanity/Profanity/en"
    if len(sys.argv) > 1:
        en_path = sys.argv[1]

    p = ProfanityFilter(en_path)
    p.load()

    cases = [
        ("fuck", "f*ck"),
        ("fuck!", "f*ck!"),
        ("fuck shit", "f*ck sh*t"),
        ("fuck shit this should not be filtered",
         "f*ck sh*t this should not be filtered"),
        ("ASS", "*SS"),
    ]
    for text, expected in cases:
        actual = p.filter(text)
        # Raise explicitly instead of using `assert`, which is stripped when
        # Python runs with -O and would make the self-test pass vacuously.
        if actual != expected:
            raise AssertionError(
                f"filter({text!r}) returned {actual!r}, expected {expected!r}")
diff --git a/app/stt.py b/app/stt.py
index 7d76333..a3988e1 100644
--- a/app/stt.py
+++ b/app/stt.py
@@ -3,6 +3,12 @@ from faster_whisper import WhisperModel
import langcodes
import numpy as np
import os
# The profanity filter is optional: record whether it could be imported so
# later code can degrade gracefully instead of failing at import time.
try:
    from profanity_filter import ProfanityFilter
    PROFANITY_FILTER_AVAILABLE = True
except ImportError:
    # This handler runs in the middle of the import list, so the module-level
    # `import sys` is not guaranteed to have executed yet; import it here so
    # the warning cannot itself fail with a NameError.
    import sys
    PROFANITY_FILTER_AVAILABLE = False
    print("Warning: profanity_filter module not available", file=sys.stderr)
import pyaudio
from pydub import AudioSegment
from shared_thread_data import SharedThreadData
@@ -12,7 +18,6 @@ import time
import typing
import wave
-
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(APP_ROOT)
@@ -297,21 +302,19 @@ class AudioSegmenter:
max_speech_s=5):
self.min_silence_ms = min_silence_ms
self.max_speech_s = max_speech_s
-
+
# Load Silero VAD model
self.model = load_silero_vad()
-
+
self.vad_threshold = 0.3
self.min_silence_duration_ms = min_silence_ms
self.max_speech_duration_s = max_speech_s
-
- self.speech_pad_ms = 300
def segmentAudio(self, audio: bytes):
# Convert audio bytes to numpy array expected by silero-vad
audio_array = np.frombuffer(audio,
dtype=np.int16).flatten().astype(np.float32) / 32768.0
-
+
# Get speech timestamps using silero-vad
# Note: silero-vad expects sample rate of 16000 Hz which matches AudioStream.FPS
speech_timestamps = get_speech_timestamps(
@@ -323,7 +326,7 @@ class AudioSegmenter:
max_speech_duration_s=self.max_speech_duration_s,
return_seconds=False # We want frame indices, not seconds
)
-
+
return speech_timestamps
# Returns the stable cutoff (if any) and whether there are any segments.
@@ -399,27 +402,25 @@ class Whisper:
self.model = None
self.cfg = cfg
- abspath = os.path.abspath(__file__)
- my_dir = os.path.dirname(abspath)
- parent_dir = os.path.dirname(my_dir)
-
model_str = cfg["model"]
- model_root = os.path.join(parent_dir, "Models",
+ model_root = os.path.join(PROJECT_ROOT, "Models",
os.path.normpath(model_str))
if cfg["enable_debug_mode"]:
print(f"Model {cfg['model']} will be saved to {model_root}",
file=sys.stderr)
model_device = "cuda"
+ compute_type = cfg["compute_type"]
if cfg["use_cpu"]:
model_device = "cpu"
+ compute_type = "int8"
already_downloaded = os.path.exists(model_root)
self.model = WhisperModel(model_str,
device = model_device,
device_index = cfg["gpu_idx"],
- compute_type = cfg["compute_type"],
+ compute_type = compute_type,
download_root = model_root,
local_files_only = already_downloaded)
@@ -436,14 +437,14 @@ class Whisper:
def transcribe(self, frames: bytes = None) -> typing.List[Segment]:
if frames is None:
frames = self.collector.getAudio()
-
+
# Convert audio to float32
audio = np.frombuffer(frames,
dtype=np.int16).flatten().astype(np.float32) / 32768.0
# Build context-aware prompt
prompt = self._build_prompt()
-
+
t0 = time.time()
segments, info = self.model.transcribe(
audio,
@@ -452,12 +453,9 @@ class Whisper:
temperature=0.0,
without_timestamps = False,
initial_prompt=prompt,
- beam_size=5,
- best_of=5,
- condition_on_previous_text=True,
- compression_ratio_threshold=2.4,
- log_prob_threshold=-1.0,
- no_speech_threshold=0.6
+ beam_size=self.cfg.get("beam_size", 5),
+ best_of=self.cfg.get("best_of", 5),
+ condition_on_previous_text=True
)
res = []
for s in segments:
@@ -562,21 +560,21 @@ class VadCommitter:
latency_s = self.collector.now() - self.collector.begin()
duration_s = stable_cutoff / AudioStream.FPS
start_ts = self.collector.begin()
-
+
# Get the filtered audio first, then extract the portion we need
filtered_audio = self.collector.getAudio()
commit_audio = filtered_audio[:stable_cutoff * AudioStream.FRAME_SZ]
-
+
# Now drop the prefix from the collector
self.collector.dropAudioPrefixByFrames(stable_cutoff)
segments = self.whisper.transcribe(commit_audio)
delta = ''.join(s.transcript for s in segments)
-
+
# Update whisper's context with the committed text
if delta.strip():
self.whisper.update_context(delta.strip())
-
+
audio = self.collector.getAudio()
if self.cfg["enable_debug_mode"]:
for s in segments:
@@ -608,6 +606,88 @@ class VadCommitter:
duration_s=duration_s,
start_ts=start_ts)
+
class StreamingPlugin:
    """Base class for plugins that post-process a TranscriptCommit.

    The defaults are no-ops: transform() hands the commit back untouched and
    stop() does nothing. Subclasses override what they need.
    """

    def __init__(self):
        """No state to set up in the base class."""
        return None

    def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
        """Return `commit` unchanged."""
        return commit

    def stop(self):
        """Nothing to release in the base class."""
        return None
+
+
class LowercasePlugin(StreamingPlugin):
    """Lowercases a commit's text when `enable_lowercase_filter` is set."""

    def __init__(self, cfg):
        """Keep a reference to the configuration mapping."""
        self.cfg = cfg

    def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
        """Lowercase `delta` and `preview` in place and return the commit.

        The flag is read from the config on every call; when it is falsy the
        commit is returned untouched.
        """
        if not self.cfg["enable_lowercase_filter"]:
            return commit
        # Same order as before: delta first, then preview.
        for field in ("delta", "preview"):
            setattr(commit, field, getattr(commit, field).lower())
        return commit
+
+
class UppercasePlugin(StreamingPlugin):
    """Uppercases a commit's text when `enable_uppercase_filter` is set."""

    def __init__(self, cfg):
        """Keep a reference to the configuration mapping."""
        self.cfg = cfg

    def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
        """Uppercase `delta` and `preview` in place and return the commit.

        The flag is read from the config on every call; when it is falsy the
        commit is returned untouched.
        """
        enabled = self.cfg["enable_uppercase_filter"]
        if enabled:
            shouted_delta = commit.delta.upper()
            commit.delta = shouted_delta
            shouted_preview = commit.preview.upper()
            commit.preview = shouted_preview
        return commit
+
+
class ProfanityPlugin(StreamingPlugin):
    """Runs commit text through ProfanityFilter when the feature is enabled.

    The word list is loaded once, at construction. If the filter module is
    unavailable, the feature is disabled, or loading fails, `self.filter`
    stays None and transform() leaves commits untouched.
    """

    def __init__(self, cfg):
        """Store `cfg` and try to load the English word list."""
        self.cfg = cfg
        self.filter = None

        # Nothing to load when the module is missing or the feature is off.
        if not (PROFANITY_FILTER_AVAILABLE and cfg["enable_profanity_filter"]):
            return

        word_list_path = os.path.join(PROJECT_ROOT, "Third_Party/Profanity/en")
        try:
            candidate = ProfanityFilter(word_list_path)
            candidate.load()
        except Exception as e:
            # Best effort: warn and keep running without a filter.
            print(f"Warning: Could not load profanity filter: {e}", file=sys.stderr)
        else:
            self.filter = candidate

    def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
        """Filter `delta` and `preview` in place and return the commit.

        Does nothing unless the feature is enabled in the config and a filter
        was loaded successfully.
        """
        if not (self.cfg["enable_profanity_filter"] and self.filter):
            return commit
        commit.delta = self.filter.filter(commit.delta)
        commit.preview = self.filter.filter(commit.preview)
        return commit
+
+
class PresentationFilter:
    """Base class for filters applied to the (transcript, preview) pair.

    The defaults are no-ops: transform() returns both strings unchanged and
    stop() does nothing. Subclasses override what they need.
    """

    def __init__(self):
        """No state to set up in the base class."""
        return None

    def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]:
        """Return `(transcript, preview)` unchanged."""
        unchanged = (transcript, preview)
        return unchanged

    def stop(self):
        """Nothing to release in the base class."""
        return None
+
+
class TrailingPeriodFilter(PresentationFilter):
    """Drops a single trailing '.' from the text when `remove_trailing_period` is set.

    The period is removed from `preview` when it is non-empty, otherwise from
    `transcript`. Text ending in "..." is left alone.
    """

    def __init__(self, cfg):
        """Keep a reference to the configuration mapping."""
        self.cfg = cfg

    @staticmethod
    def _strip_final_period(text: str) -> str:
        """Return `text` minus one final '.', unless it ends with '...'."""
        if text.endswith('.') and not text.endswith("..."):
            return text[:-1]
        return text

    def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]:
        """Return `(transcript, preview)` with the trailing period removed.

        Only one of the two strings is ever altered: `preview` if it has any
        content, else `transcript`. With the option off, both pass through.
        """
        if not self.cfg["remove_trailing_period"]:
            return transcript, preview
        if preview:
            return transcript, self._strip_final_period(preview)
        return self._strip_final_period(transcript), preview
+
+
def transcriptionThread(shared_data: SharedThreadData):
last_stable_commit = None
@@ -621,6 +701,17 @@ def transcriptionThread(shared_data: SharedThreadData):
max_speech_s=shared_data.cfg["max_speech_duration_s"])
committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter)
+ plugins = []
+ # plugins.append(TranslationPlugin(shared_data.cfg)) # Not implemented yet
+ plugins.append(UppercasePlugin(shared_data.cfg))
+ plugins.append(LowercasePlugin(shared_data.cfg))
+ plugins.append(ProfanityPlugin(shared_data.cfg))
+ # plugins.append(UwuPlugin(shared_data.cfg)) # Not implemented yet
+ # plugins.append(BrowserSource(shared_data.cfg)) # Not implemented yet
+
+ filters = []
+ filters.append(TrailingPeriodFilter(shared_data.cfg))
+
transcript = ""
preview = ""
@@ -633,6 +724,9 @@ def transcriptionThread(shared_data: SharedThreadData):
commit = committer.getDelta()
+ for plugin in plugins:
+ commit = plugin.transform(commit)
+
if len(commit.delta) > 0 or len(commit.preview) > 0:
# Avoid re-sending text after long pauses
if shared_data.cfg["reset_after_silence_s"] > 0:
@@ -664,6 +758,9 @@ def transcriptionThread(shared_data: SharedThreadData):
transcript = join_segments(transcript, commit.delta)
preview = commit.preview
+ for filt in filters:
+ transcript, preview = filt.transform(transcript, preview)
+
try:
print(f"Transcript: {transcript}", flush=True)
except UnicodeEncodeError:
@@ -691,4 +788,8 @@ def transcriptionThread(shared_data: SharedThreadData):
(not commit.delta.endswith(' ')) and \
(not commit.preview.startswith(' ')):
commit.preview = ' ' + commit.preview
+ for plugin in plugins:
+ plugin.stop()
+ for filt in filters:
+ filt.stop()