From bdaeb1911297d7901a12e3ac51b38c3463789279 Mon Sep 17 00:00:00 2001 From: yum Date: Wed, 28 Jun 2023 21:24:56 -0700 Subject: Add profanity filter Add toggle to UI to enable a profanity filter. It replaces vowels in bad words with asterisks. Bugfix: filters now apply to OBS --- Scripts/transcribe.py | 40 +++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) (limited to 'Scripts') diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py index 28b6ca0..d937cb6 100644 --- a/Scripts/transcribe.py +++ b/Scripts/transcribe.py @@ -6,6 +6,7 @@ from faster_whisper import WhisperModel from functools import partial from math import ceil from playsound import playsound +from profanity_filter import ProfanityFilter from sentence_splitter import split_text_into_sentences import argparse @@ -67,6 +68,11 @@ class AudioState: # gets appended to `text`. self.commit_fuzz_threshold = 8 + # If set, profanity in transcriptions will have their vowels replaced + # with asterisks. Only works in English. + self.enable_profanity_filter = False + self.profanity_filter: ProfanityFilter = None + # List of: # List of tuples of: # Segment start time, end time, and text @@ -154,7 +160,7 @@ def onAudioFramesAvailable( return (frames, pyaudio.paContinue) -def getMicStream(which_mic): +def getMicStream(which_mic) -> AudioState: audio_state = AudioState() audio_state.p = pyaudio.PyAudio() @@ -346,10 +352,11 @@ def transcribeAudio(audio_state, audio_state.preview_text = audio_state.text + preview_text now = time.time() - print("Transcription ({} seconds): {}".format( - now - last_transcribe_time, - audio_state.preview_text)) - last_transcribe_time = now + if audio_state.enable_debug_mode: + print("Raw transcription ({} seconds): {}".format( + now - last_transcribe_time, + audio_state.preview_text)) + last_transcribe_time = now # Translate if requested. translated = audio_state.preview_text @@ -388,8 +395,16 @@ def transcribeAudio(audio_state, filtered_text = filtered_text.upper() if enable_lowercase_filter: filtered_text = filtered_text.lower() + if audio_state.enable_profanity_filter: + filtered_text = audio_state.profanity_filter.filter(filtered_text) audio_state.filtered_text = filtered_text + now = time.time() + print("Transcription ({} seconds): {}".format( + now - last_transcribe_time, + filtered_text)) + last_transcribe_time = now + if old_text != audio_state.preview_text: # We think the user said something, so reset the amount of # time we sleep between transcriptions to the minimum. @@ -618,6 +633,7 @@ def transcribeLoop(mic: str, remove_trailing_period: bool, enable_uppercase_filter: bool, enable_lowercase_filter: bool, + enable_profanity_filter: bool, enable_debug_mode: bool, button: str, estate: EmotesState, @@ -633,6 +649,13 @@ def transcribeLoop(mic: str, audio_state.reset_on_toggle = reset_on_toggle audio_state.commit_fuzz_threshold = commit_fuzz_threshold audio_state.enable_debug_mode = enable_debug_mode + audio_state.enable_profanity_filter = enable_profanity_filter + + # Set up profanity filter + en_profanity_path = os.path.abspath("Resources/Profanity/en") + audio_state.profanity_filter = ProfanityFilter(en_profanity_path) + if enable_profanity_filter: + audio_state.profanity_filter.load() lang_bits = language_target.split(" | ") if len(lang_bits) == 2: @@ -780,6 +803,7 @@ if __name__ == "__main__": parser.add_argument("--remove_trailing_period", type=int, help="If set to 1, trailing period will be removed.") parser.add_argument("--enable_uppercase_filter", type=int, help="If set to 1, transcriptions will be converted to UPPERCASE.") parser.add_argument("--enable_lowercase_filter", type=int, help="If set to 1, transcriptions will be converted to lowercase.") + parser.add_argument("--enable_profanity_filter", type=int, help="If set to 1, profanity in transcriptions will have their vowels replaced with asterisks. Only works in English.") parser.add_argument("--button", type=str, help="The controller button used to start/stop transcription. E.g. \"left joystick\"") parser.add_argument("--emotes_pickle", type=str, help="The path to emotes pickle. See emotes_v2.py for details.") parser.add_argument("--gpu_idx", type=str, help="The index of the GPU device to use. On single GPU systems, use 0.") @@ -870,6 +894,11 @@ if __name__ == "__main__": else: args.enable_lowercase_filter = False + if args.enable_profanity_filter == 1: + args.enable_profanity_filter = True + else: + args.enable_profanity_filter = False + if args.enable_debug_mode == 1: args.enable_debug_mode = True else: @@ -896,6 +925,7 @@ if __name__ == "__main__": args.remove_trailing_period, args.enable_uppercase_filter, args.enable_lowercase_filter, + args.enable_profanity_filter, args.enable_debug_mode, args.button, estate, window_duration_s, -- cgit v1.2.3