summaryrefslogtreecommitdiffstats
path: root/Scripts
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2023-06-28 21:24:56 -0700
committeryum <yum.food.vr@gmail.com>2023-06-28 21:24:56 -0700
commitbdaeb1911297d7901a12e3ac51b38c3463789279 (patch)
treeb6151c80100db635ca0c4479d8b1afde838b579e /Scripts
parentff7eb3c212195af71cd0ce4a3cd0c9a081d6ebda (diff)
Add profanity filter
Add toggle to UI to enable a profanity filter. It replaces vowels in bad words with asterisks. Bugfix: filters now apply to OBS
Diffstat (limited to 'Scripts')
-rw-r--r--Scripts/transcribe.py40
1 files changed, 35 insertions, 5 deletions
diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py
index 28b6ca0..d937cb6 100644
--- a/Scripts/transcribe.py
+++ b/Scripts/transcribe.py
@@ -6,6 +6,7 @@ from faster_whisper import WhisperModel
from functools import partial
from math import ceil
from playsound import playsound
+from profanity_filter import ProfanityFilter
from sentence_splitter import split_text_into_sentences
import argparse
@@ -67,6 +68,11 @@ class AudioState:
# gets appended to `text`.
self.commit_fuzz_threshold = 8
+ # If set, profanity in transcriptions will have their vowels replaced
+ # with asterisks. Only works in English.
+ self.enable_profanity_filter = False
+ self.profanity_filter: ProfanityFilter = None
+
# List of:
# List of tuples of:
# Segment start time, end time, and text
@@ -154,7 +160,7 @@ def onAudioFramesAvailable(
return (frames, pyaudio.paContinue)
-def getMicStream(which_mic):
+def getMicStream(which_mic) -> AudioState:
audio_state = AudioState()
audio_state.p = pyaudio.PyAudio()
@@ -346,10 +352,11 @@ def transcribeAudio(audio_state,
audio_state.preview_text = audio_state.text + preview_text
now = time.time()
- print("Transcription ({} seconds): {}".format(
- now - last_transcribe_time,
- audio_state.preview_text))
- last_transcribe_time = now
+ if audio_state.enable_debug_mode:
+ print("Raw transcription ({} seconds): {}".format(
+ now - last_transcribe_time,
+ audio_state.preview_text))
+ last_transcribe_time = now
# Translate if requested.
translated = audio_state.preview_text
@@ -388,8 +395,16 @@ def transcribeAudio(audio_state,
filtered_text = filtered_text.upper()
if enable_lowercase_filter:
filtered_text = filtered_text.lower()
+ if audio_state.enable_profanity_filter:
+ filtered_text = audio_state.profanity_filter.filter(filtered_text)
audio_state.filtered_text = filtered_text
+ now = time.time()
+ print("Transcription ({} seconds): {}".format(
+ now - last_transcribe_time,
+ filtered_text))
+ last_transcribe_time = now
+
if old_text != audio_state.preview_text:
# We think the user said something, so reset the amount of
# time we sleep between transcriptions to the minimum.
@@ -618,6 +633,7 @@ def transcribeLoop(mic: str,
remove_trailing_period: bool,
enable_uppercase_filter: bool,
enable_lowercase_filter: bool,
+ enable_profanity_filter: bool,
enable_debug_mode: bool,
button: str,
estate: EmotesState,
@@ -633,6 +649,13 @@ def transcribeLoop(mic: str,
audio_state.reset_on_toggle = reset_on_toggle
audio_state.commit_fuzz_threshold = commit_fuzz_threshold
audio_state.enable_debug_mode = enable_debug_mode
+ audio_state.enable_profanity_filter = enable_profanity_filter
+
+ # Set up profanity filter
+ en_profanity_path = os.path.abspath("Resources/Profanity/en")
+ audio_state.profanity_filter = ProfanityFilter(en_profanity_path)
+ if enable_profanity_filter:
+ audio_state.profanity_filter.load()
lang_bits = language_target.split(" | ")
if len(lang_bits) == 2:
@@ -780,6 +803,7 @@ if __name__ == "__main__":
parser.add_argument("--remove_trailing_period", type=int, help="If set to 1, trailing period will be removed.")
parser.add_argument("--enable_uppercase_filter", type=int, help="If set to 1, transcriptions will be converted to UPPERCASE.")
parser.add_argument("--enable_lowercase_filter", type=int, help="If set to 1, transcriptions will be converted to lowercase.")
+ parser.add_argument("--enable_profanity_filter", type=int, help="If set to 1, profanity in transcriptions will have their vowels replaced with asterisks. Only works in English.")
parser.add_argument("--button", type=str, help="The controller button used to start/stop transcription. E.g. \"left joystick\"")
parser.add_argument("--emotes_pickle", type=str, help="The path to emotes pickle. See emotes_v2.py for details.")
parser.add_argument("--gpu_idx", type=str, help="The index of the GPU device to use. On single GPU systems, use 0.")
@@ -870,6 +894,11 @@ if __name__ == "__main__":
else:
args.enable_lowercase_filter = False
+ if args.enable_profanity_filter == 1:
+ args.enable_profanity_filter = True
+ else:
+ args.enable_profanity_filter = False
+
if args.enable_debug_mode == 1:
args.enable_debug_mode = True
else:
@@ -896,6 +925,7 @@ if __name__ == "__main__":
args.remove_trailing_period,
args.enable_uppercase_filter,
args.enable_lowercase_filter,
+ args.enable_profanity_filter,
args.enable_debug_mode,
args.button,
estate, window_duration_s,