diff options
Diffstat (limited to 'Scripts')
| -rw-r--r-- | Scripts/requirements.txt | 2 | ||||
| -rw-r--r-- | Scripts/transcribe.py | 95 |
2 files changed, 93 insertions, 4 deletions
diff --git a/Scripts/requirements.txt b/Scripts/requirements.txt index c887808..5500a91 100644 --- a/Scripts/requirements.txt +++ b/Scripts/requirements.txt @@ -1,3 +1,4 @@ +ctranslate2 editdistance faster-whisper@https://github.com/guillaumekln/faster-whisper/archive/358d373691c95205021bd4bbf28cde7ce4d10030.tar.gz future==0.18.2 @@ -10,3 +11,4 @@ pyaudio python-osc playsound==1.2.2 pyyaml +transformers>=4.21.0 diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py index 7ba80dc..e113be1 100644 --- a/Scripts/transcribe.py +++ b/Scripts/transcribe.py @@ -8,6 +8,7 @@ from playsound import playsound import argparse import copy +import ctranslate2 import generate_utils import keybind_event_machine import keyboard @@ -22,6 +23,7 @@ import subprocess import sys import threading import time +import transformers import wave class AudioState: @@ -263,8 +265,19 @@ def transcribeAudio(audio_state, audio_state.text = string_matcher.matchStrings(audio_state.text, text, window_size = 25) + # Translate if requested. + if audio_state.language_source and audio_state.language_target: + source = audio_state.tokenizer.convert_ids_to_tokens(audio_state.tokenizer.encode(copy.copy(audio_state.text))) + target_prefix = [audio_state.language_target] + results = audio_state.translator.translate_batch([source], target_prefix=[target_prefix]) + target = results[0].hypotheses[0][1:] + translated = audio_state.tokenizer.decode(audio_state.tokenizer.convert_tokens_to_ids(target)) + print(f"Translated text: {translated}") + else: + translated = copy.copy(audio_state.text) + # Apply filters to transcription - filtered_text = audio_state.text + filtered_text = translated if enable_uwu_filter: uwu_proc = subprocess.Popen(["Resources/Uwu/Uwwwu.exe", filtered_text], stdout=subprocess.PIPE, @@ -493,7 +506,10 @@ def readControllerInput(audio_state, enable_local_beep: bool, # whisper/__init__.py. Examples: tiny, base, small, medium. def transcribeLoop(mic: str, language: str, + language_source: str, + language_target: str, model: str, + model_translation: str, enable_local_beep: bool, use_cpu: bool, use_builtin: bool, @@ -510,6 +526,63 @@ def transcribeLoop(mic: str, audio_state.language = langcodes.find(language).language audio_state.MAX_LENGTH_S = window_duration_s + lang_bits = language_target.split(" | ") + if len(lang_bits) == 2: + lang_code = lang_bits[1] + audio_state.language_target = lang_code + else: + audio_state.language_target = None + lang_bits = language_source.split(" | ") + if len(lang_bits) == 2: + lang_code = lang_bits[1] + audio_state.language_source = lang_code + else: + audio_state.language_source = None + + if audio_state.language_source and audio_state.language_target: + print("Translation requested") + + print("Installing torch and sentencepiece in virtual environment. " + "Nothing will print " + "for several minutes while these download (~2.4 GB).") + pip_proc = subprocess.Popen( + "Resources/Python/python.exe -m pip install sentencepiece torch".split(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + pip_stdout, pip_stderr = pip_proc.communicate() + pip_stdout = pip_stdout.decode("utf-8") + pip_stderr = pip_stderr.decode("utf-8") + print(pip_stdout) + print(pip_stderr) + if pip_proc.returncode != 0: + print(f"Failed to set up for translation: `pip install torch` " + "exited with {pip_proc.returncode}") + + output_dir = "Resources/" + model_translation + # Provided by ctranslate2 Python package + cmd = "ct2-transformers-converter.exe --model facebook/" + model_translation + " --output_dir " + output_dir + + print(f"Fetching translation algorithm ({model_translation})") + if not os.path.exists(output_dir): + ct2_proc = subprocess.Popen( + cmd.split(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + ct2_stdout, ct2_stderr = ct2_proc.communicate() + ct2_stdout = ct2_stdout.decode("utf-8") + ct2_stderr = ct2_stderr.decode("utf-8") + print(ct2_stdout) + print(ct2_stderr) + if ct2_proc.returncode != 0: + print(f"Failed to get NLLB model: ct2 process exited with " + "{ct2_proc.returncode}") + print(f"Using model at {output_dir}") + + audio_state.translator = ctranslate2.Translator(output_dir) + audio_state.tokenizer = transformers.AutoTokenizer.from_pretrained( + "facebook/" + model_translation, + src_lang=audio_state.language_source) + print("Safe to start talking") abspath = os.path.abspath(__file__) @@ -588,9 +661,13 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--mic", type=str, help="Which mic to use. Options: index, focusrite. Default: index") parser.add_argument("--language", type=str, help="Which language to use. Ex: english, japanese, chinese, french, german.") - parser.add_argument("--model", type=str, help="Which AI model to use. \ - Options: tiny, tiny.en, base, base.en, small, small.en, \ - medium, medium.en, large-v1, large-v2") + parser.add_argument("--language_source", type=str, help="Which language to translate from. See kLangTargetChoices in Frame.cpp for valid choices") + parser.add_argument("--language_target", type=str, help="Which language to translate into. See kLangTargetChoices in Frame.cpp for valid choices") + parser.add_argument("--model", type=str, help="Which transcription model to use. " \ + "Options: tiny, tiny.en, base, base.en, small, small.en, " \ + "medium, medium.en, large-v1, large-v2") + parser.add_argument("--model_translation", type=str, help="Which translation model to use. " \ + "Options: nllb-200-distilled-600M, nllb-200-distilled-1.3B.") parser.add_argument("--bytes_per_char", type=str, help="The number of bytes to use to represent each character") parser.add_argument("--chars_per_sync", type=str, help="The number of characters to send on each sync event") parser.add_argument("--enable_local_beep", type=int, help="Whether to play a local auditory indicator when transcription starts/stops.") @@ -615,9 +692,16 @@ if __name__ == "__main__": if not args.language: args.language = "english" + if not args.language_source or not args.language_target: + print("--language_source and --language_target required", file=sys.stderr) + if not args.model: args.model = "base" + if not args.model_translation: + print("--model_translation required.", file=sys.stderr) + sys.exit(1) + if not args.bytes_per_char or not args.chars_per_sync: print("--bytes_per_char and --chars_per_sync required", file=sys.stderr) sys.exit(1) @@ -685,7 +769,10 @@ if __name__ == "__main__": transcribeLoop(args.mic, args.language, + args.language_source, + args.language_target, args.model, + args.model_translation, args.enable_local_beep, args.cpu, args.use_builtin, args.enable_uwu_filter, |
