summaryrefslogtreecommitdiffstats
path: root/Scripts/transcribe.py
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2023-05-25 21:45:09 -0700
committeryum <yum.food.vr@gmail.com>2023-05-25 22:00:56 -0700
commit84f09e1fdf15644d1ea5f955889581932e4f6a8e (patch)
tree70894da7bc14c773f9755c1838cd87fe7f26b909 /Scripts/transcribe.py
parenteed2e8915d83796679c0b7a3ea9121d329ceddab (diff)
Add ability to translate into 200 languages
Use Meta's No Language Left Behind (NLLB) algorithm to provide translation capabilities into 200 languages. Obviously most are very untested. This requires either 4.1 or 7.1 GB of RAM and significantly increases transcription latency.
Diffstat (limited to 'Scripts/transcribe.py')
-rw-r--r--Scripts/transcribe.py95
1 files changed, 91 insertions, 4 deletions
diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py
index 7ba80dc..e113be1 100644
--- a/Scripts/transcribe.py
+++ b/Scripts/transcribe.py
@@ -8,6 +8,7 @@ from playsound import playsound
import argparse
import copy
+import ctranslate2
import generate_utils
import keybind_event_machine
import keyboard
@@ -22,6 +23,7 @@ import subprocess
import sys
import threading
import time
+import transformers
import wave
class AudioState:
@@ -263,8 +265,19 @@ def transcribeAudio(audio_state,
audio_state.text = string_matcher.matchStrings(audio_state.text,
text, window_size = 25)
+ # Translate if requested.
+ if audio_state.language_source and audio_state.language_target:
+ source = audio_state.tokenizer.convert_ids_to_tokens(audio_state.tokenizer.encode(copy.copy(audio_state.text)))
+ target_prefix = [audio_state.language_target]
+ results = audio_state.translator.translate_batch([source], target_prefix=[target_prefix])
+ target = results[0].hypotheses[0][1:]
+ translated = audio_state.tokenizer.decode(audio_state.tokenizer.convert_tokens_to_ids(target))
+ print(f"Translated text: {translated}")
+ else:
+ translated = copy.copy(audio_state.text)
+
# Apply filters to transcription
- filtered_text = audio_state.text
+ filtered_text = translated
if enable_uwu_filter:
uwu_proc = subprocess.Popen(["Resources/Uwu/Uwwwu.exe", filtered_text],
stdout=subprocess.PIPE,
@@ -493,7 +506,10 @@ def readControllerInput(audio_state, enable_local_beep: bool,
# whisper/__init__.py. Examples: tiny, base, small, medium.
def transcribeLoop(mic: str,
language: str,
+ language_source: str,
+ language_target: str,
model: str,
+ model_translation: str,
enable_local_beep: bool,
use_cpu: bool,
use_builtin: bool,
@@ -510,6 +526,63 @@ def transcribeLoop(mic: str,
audio_state.language = langcodes.find(language).language
audio_state.MAX_LENGTH_S = window_duration_s
+ lang_bits = language_target.split(" | ")
+ if len(lang_bits) == 2:
+ lang_code = lang_bits[1]
+ audio_state.language_target = lang_code
+ else:
+ audio_state.language_target = None
+ lang_bits = language_source.split(" | ")
+ if len(lang_bits) == 2:
+ lang_code = lang_bits[1]
+ audio_state.language_source = lang_code
+ else:
+ audio_state.language_source = None
+
+ if audio_state.language_source and audio_state.language_target:
+ print("Translation requested")
+
+ print("Installing torch and sentencepiece in virtual environment. "
+ "Nothing will print "
+ "for several minutes while these download (~2.4 GB).")
+ pip_proc = subprocess.Popen(
+ "Resources/Python/python.exe -m pip install sentencepiece torch".split(),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ pip_stdout, pip_stderr = pip_proc.communicate()
+ pip_stdout = pip_stdout.decode("utf-8")
+ pip_stderr = pip_stderr.decode("utf-8")
+ print(pip_stdout)
+ print(pip_stderr)
+ if pip_proc.returncode != 0:
+ print(f"Failed to set up for translation: `pip install torch` "
+ "exited with {pip_proc.returncode}")
+
+ output_dir = "Resources/" + model_translation
+ # Provided by ctranslate2 Python package
+ cmd = "ct2-transformers-converter.exe --model facebook/" + model_translation + " --output_dir " + output_dir
+
+ print(f"Fetching translation algorithm ({model_translation})")
+ if not os.path.exists(output_dir):
+ ct2_proc = subprocess.Popen(
+ cmd.split(),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ ct2_stdout, ct2_stderr = ct2_proc.communicate()
+ ct2_stdout = ct2_stdout.decode("utf-8")
+ ct2_stderr = ct2_stderr.decode("utf-8")
+ print(ct2_stdout)
+ print(ct2_stderr)
+ if ct2_proc.returncode != 0:
+ print(f"Failed to get NLLB model: ct2 process exited with "
+ "{ct2_proc.returncode}")
+ print(f"Using model at {output_dir}")
+
+ audio_state.translator = ctranslate2.Translator(output_dir)
+ audio_state.tokenizer = transformers.AutoTokenizer.from_pretrained(
+ "facebook/" + model_translation,
+ src_lang=audio_state.language_source)
+
print("Safe to start talking")
abspath = os.path.abspath(__file__)
@@ -588,9 +661,13 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--mic", type=str, help="Which mic to use. Options: index, focusrite. Default: index")
parser.add_argument("--language", type=str, help="Which language to use. Ex: english, japanese, chinese, french, german.")
- parser.add_argument("--model", type=str, help="Which AI model to use. \
- Options: tiny, tiny.en, base, base.en, small, small.en, \
- medium, medium.en, large-v1, large-v2")
+ parser.add_argument("--language_source", type=str, help="Which language to translate from. See kLangTargetChoices in Frame.cpp for valid choices")
+ parser.add_argument("--language_target", type=str, help="Which language to translate into. See kLangTargetChoices in Frame.cpp for valid choices")
+ parser.add_argument("--model", type=str, help="Which transcription model to use. " \
+ "Options: tiny, tiny.en, base, base.en, small, small.en, " \
+ "medium, medium.en, large-v1, large-v2")
+ parser.add_argument("--model_translation", type=str, help="Which translation model to use. " \
+ "Options: nllb-200-distilled-600M, nllb-200-distilled-1.3B.")
parser.add_argument("--bytes_per_char", type=str, help="The number of bytes to use to represent each character")
parser.add_argument("--chars_per_sync", type=str, help="The number of characters to send on each sync event")
parser.add_argument("--enable_local_beep", type=int, help="Whether to play a local auditory indicator when transcription starts/stops.")
@@ -615,9 +692,16 @@ if __name__ == "__main__":
if not args.language:
args.language = "english"
+ if not args.language_source or not args.language_target:
+ print("--language_source and --language_target required", file=sys.stderr)
+
if not args.model:
args.model = "base"
+ if not args.model_translation:
+ print("--model_translation required.", file=sys.stderr)
+ sys.exit(1)
+
if not args.bytes_per_char or not args.chars_per_sync:
print("--bytes_per_char and --chars_per_sync required", file=sys.stderr)
sys.exit(1)
@@ -685,7 +769,10 @@ if __name__ == "__main__":
transcribeLoop(args.mic,
args.language,
+ args.language_source,
+ args.language_target,
args.model,
+ args.model_translation,
args.enable_local_beep,
args.cpu, args.use_builtin,
args.enable_uwu_filter,