diff options
| author | yum <yum.food.vr@gmail.com> | 2023-05-25 21:45:09 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-05-25 22:00:56 -0700 |
| commit | 84f09e1fdf15644d1ea5f955889581932e4f6a8e (patch) | |
| tree | 70894da7bc14c773f9755c1838cd87fe7f26b909 /Scripts/transcribe.py | |
| parent | eed2e8915d83796679c0b7a3ea9121d329ceddab (diff) | |
Add ability to translate into 200 languages
Use Meta's No Language Left Behind (NLLB) algorithm to provide
translation capabilities into 200 languages. Obviously most are very
untested.
This requires either 4.1 or 7.1 GB of RAM and significiantly increases
transcription latency.
Diffstat (limited to 'Scripts/transcribe.py')
| -rw-r--r-- | Scripts/transcribe.py | 95 |
1 files changed, 91 insertions, 4 deletions
diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py index 7ba80dc..e113be1 100644 --- a/Scripts/transcribe.py +++ b/Scripts/transcribe.py @@ -8,6 +8,7 @@ from playsound import playsound import argparse import copy +import ctranslate2 import generate_utils import keybind_event_machine import keyboard @@ -22,6 +23,7 @@ import subprocess import sys import threading import time +import transformers import wave class AudioState: @@ -263,8 +265,19 @@ def transcribeAudio(audio_state, audio_state.text = string_matcher.matchStrings(audio_state.text, text, window_size = 25) + # Translate if requested. + if audio_state.language_source and audio_state.language_target: + source = audio_state.tokenizer.convert_ids_to_tokens(audio_state.tokenizer.encode(copy.copy(audio_state.text))) + target_prefix = [audio_state.language_target] + results = audio_state.translator.translate_batch([source], target_prefix=[target_prefix]) + target = results[0].hypotheses[0][1:] + translated = audio_state.tokenizer.decode(audio_state.tokenizer.convert_tokens_to_ids(target)) + print(f"Translated text: {translated}") + else: + translated = copy.copy(audio_state.text) + # Apply filters to transcription - filtered_text = audio_state.text + filtered_text = translated if enable_uwu_filter: uwu_proc = subprocess.Popen(["Resources/Uwu/Uwwwu.exe", filtered_text], stdout=subprocess.PIPE, @@ -493,7 +506,10 @@ def readControllerInput(audio_state, enable_local_beep: bool, # whisper/__init__.py. Examples: tiny, base, small, medium. def transcribeLoop(mic: str, language: str, + language_source: str, + language_target: str, model: str, + model_translation: str, enable_local_beep: bool, use_cpu: bool, use_builtin: bool, @@ -510,6 +526,63 @@ def transcribeLoop(mic: str, audio_state.language = langcodes.find(language).language audio_state.MAX_LENGTH_S = window_duration_s + lang_bits = language_target.split(" | ") + if len(lang_bits) == 2: + lang_code = lang_bits[1] + audio_state.language_target = lang_code + else: + audio_state.language_target = None + lang_bits = language_source.split(" | ") + if len(lang_bits) == 2: + lang_code = lang_bits[1] + audio_state.language_source = lang_code + else: + audio_state.language_source = None + + if audio_state.language_source and audio_state.language_target: + print("Translation requested") + + print("Installing torch and sentencepiece in virtual environment. " + "Nothing will print " + "for several minutes while these download (~2.4 GB).") + pip_proc = subprocess.Popen( + "Resources/Python/python.exe -m pip install sentencepiece torch".split(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + pip_stdout, pip_stderr = pip_proc.communicate() + pip_stdout = pip_stdout.decode("utf-8") + pip_stderr = pip_stderr.decode("utf-8") + print(pip_stdout) + print(pip_stderr) + if pip_proc.returncode != 0: + print(f"Failed to set up for translation: `pip install torch` " + "exited with {pip_proc.returncode}") + + output_dir = "Resources/" + model_translation + # Provided by ctranslate2 Python package + cmd = "ct2-transformers-converter.exe --model facebook/" + model_translation + " --output_dir " + output_dir + + print(f"Fetching translation algorithm ({model_translation})") + if not os.path.exists(output_dir): + ct2_proc = subprocess.Popen( + cmd.split(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + ct2_stdout, ct2_stderr = ct2_proc.communicate() + ct2_stdout = ct2_stdout.decode("utf-8") + ct2_stderr = ct2_stderr.decode("utf-8") + print(ct2_stdout) + print(ct2_stderr) + if ct2_proc.returncode != 0: + print(f"Failed to get NLLB model: ct2 process exited with " + "{ct2_proc.returncode}") + print(f"Using model at {output_dir}") + + audio_state.translator = ctranslate2.Translator(output_dir) + audio_state.tokenizer = transformers.AutoTokenizer.from_pretrained( + "facebook/" + model_translation, + src_lang=audio_state.language_source) + print("Safe to start talking") abspath = os.path.abspath(__file__) @@ -588,9 +661,13 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--mic", type=str, help="Which mic to use. Options: index, focusrite. Default: index") parser.add_argument("--language", type=str, help="Which language to use. Ex: english, japanese, chinese, french, german.") - parser.add_argument("--model", type=str, help="Which AI model to use. \ - Options: tiny, tiny.en, base, base.en, small, small.en, \ - medium, medium.en, large-v1, large-v2") + parser.add_argument("--language_source", type=str, help="Which language to translate from. See kLangTargetChoices in Frame.cpp for valid choices") + parser.add_argument("--language_target", type=str, help="Which language to translate into. See kLangTargetChoices in Frame.cpp for valid choices") + parser.add_argument("--model", type=str, help="Which transcription model to use. " \ + "Options: tiny, tiny.en, base, base.en, small, small.en, " \ + "medium, medium.en, large-v1, large-v2") + parser.add_argument("--model_translation", type=str, help="Which translation model to use. " \ + "Options: nllb-200-distilled-600M, nllb-200-distilled-1.3B.") parser.add_argument("--bytes_per_char", type=str, help="The number of bytes to use to represent each character") parser.add_argument("--chars_per_sync", type=str, help="The number of characters to send on each sync event") parser.add_argument("--enable_local_beep", type=int, help="Whether to play a local auditory indicator when transcription starts/stops.") @@ -615,9 +692,16 @@ if __name__ == "__main__": if not args.language: args.language = "english" + if not args.language_source or not args.language_target: + print("--language_source and --language_target required", file=sys.stderr) + if not args.model: args.model = "base" + if not args.model_translation: + print("--model_translation required.", file=sys.stderr) + sys.exit(1) + if not args.bytes_per_char or not args.chars_per_sync: print("--bytes_per_char and --chars_per_sync required", file=sys.stderr) sys.exit(1) @@ -685,7 +769,10 @@ if __name__ == "__main__": transcribeLoop(args.mic, args.language, + args.language_source, + args.language_target, args.model, + args.model_translation, args.enable_local_beep, args.cpu, args.use_builtin, args.enable_uwu_filter, |
