summaryrefslogtreecommitdiffstats
path: root/Scripts
diff options
context:
space:
mode:
Diffstat (limited to 'Scripts')
-rw-r--r--Scripts/requirements.txt2
-rw-r--r--Scripts/transcribe.py95
2 files changed, 93 insertions, 4 deletions
diff --git a/Scripts/requirements.txt b/Scripts/requirements.txt
index c887808..5500a91 100644
--- a/Scripts/requirements.txt
+++ b/Scripts/requirements.txt
@@ -1,3 +1,4 @@
+ctranslate2
editdistance
faster-whisper@https://github.com/guillaumekln/faster-whisper/archive/358d373691c95205021bd4bbf28cde7ce4d10030.tar.gz
future==0.18.2
@@ -10,3 +11,4 @@ pyaudio
python-osc
playsound==1.2.2
pyyaml
+transformers>=4.21.0
diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py
index 7ba80dc..e113be1 100644
--- a/Scripts/transcribe.py
+++ b/Scripts/transcribe.py
@@ -8,6 +8,7 @@ from playsound import playsound
import argparse
import copy
+import ctranslate2
import generate_utils
import keybind_event_machine
import keyboard
@@ -22,6 +23,7 @@ import subprocess
import sys
import threading
import time
+import transformers
import wave
class AudioState:
@@ -263,8 +265,19 @@ def transcribeAudio(audio_state,
audio_state.text = string_matcher.matchStrings(audio_state.text,
text, window_size = 25)
+ # Translate if requested.
+ if audio_state.language_source and audio_state.language_target:
+ source = audio_state.tokenizer.convert_ids_to_tokens(audio_state.tokenizer.encode(copy.copy(audio_state.text)))
+ target_prefix = [audio_state.language_target]
+ results = audio_state.translator.translate_batch([source], target_prefix=[target_prefix])
+ target = results[0].hypotheses[0][1:]
+ translated = audio_state.tokenizer.decode(audio_state.tokenizer.convert_tokens_to_ids(target))
+ print(f"Translated text: {translated}")
+ else:
+ translated = copy.copy(audio_state.text)
+
# Apply filters to transcription
- filtered_text = audio_state.text
+ filtered_text = translated
if enable_uwu_filter:
uwu_proc = subprocess.Popen(["Resources/Uwu/Uwwwu.exe", filtered_text],
stdout=subprocess.PIPE,
@@ -493,7 +506,10 @@ def readControllerInput(audio_state, enable_local_beep: bool,
# whisper/__init__.py. Examples: tiny, base, small, medium.
def transcribeLoop(mic: str,
language: str,
+ language_source: str,
+ language_target: str,
model: str,
+ model_translation: str,
enable_local_beep: bool,
use_cpu: bool,
use_builtin: bool,
@@ -510,6 +526,63 @@ def transcribeLoop(mic: str,
audio_state.language = langcodes.find(language).language
audio_state.MAX_LENGTH_S = window_duration_s
+ lang_bits = language_target.split(" | ")
+ if len(lang_bits) == 2:
+ lang_code = lang_bits[1]
+ audio_state.language_target = lang_code
+ else:
+ audio_state.language_target = None
+ lang_bits = language_source.split(" | ")
+ if len(lang_bits) == 2:
+ lang_code = lang_bits[1]
+ audio_state.language_source = lang_code
+ else:
+ audio_state.language_source = None
+
+ if audio_state.language_source and audio_state.language_target:
+ print("Translation requested")
+
+ print("Installing torch and sentencepiece in virtual environment. "
+ "Nothing will print "
+ "for several minutes while these download (~2.4 GB).")
+ pip_proc = subprocess.Popen(
+ "Resources/Python/python.exe -m pip install sentencepiece torch".split(),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ pip_stdout, pip_stderr = pip_proc.communicate()
+ pip_stdout = pip_stdout.decode("utf-8")
+ pip_stderr = pip_stderr.decode("utf-8")
+ print(pip_stdout)
+ print(pip_stderr)
+ if pip_proc.returncode != 0:
+ print(f"Failed to set up for translation: `pip install torch` "
+ "exited with {pip_proc.returncode}")
+
+ output_dir = "Resources/" + model_translation
+ # Provided by ctranslate2 Python package
+ cmd = "ct2-transformers-converter.exe --model facebook/" + model_translation + " --output_dir " + output_dir
+
+ print(f"Fetching translation algorithm ({model_translation})")
+ if not os.path.exists(output_dir):
+ ct2_proc = subprocess.Popen(
+ cmd.split(),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ ct2_stdout, ct2_stderr = ct2_proc.communicate()
+ ct2_stdout = ct2_stdout.decode("utf-8")
+ ct2_stderr = ct2_stderr.decode("utf-8")
+ print(ct2_stdout)
+ print(ct2_stderr)
+ if ct2_proc.returncode != 0:
+ print(f"Failed to get NLLB model: ct2 process exited with "
+ "{ct2_proc.returncode}")
+ print(f"Using model at {output_dir}")
+
+ audio_state.translator = ctranslate2.Translator(output_dir)
+ audio_state.tokenizer = transformers.AutoTokenizer.from_pretrained(
+ "facebook/" + model_translation,
+ src_lang=audio_state.language_source)
+
print("Safe to start talking")
abspath = os.path.abspath(__file__)
@@ -588,9 +661,13 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--mic", type=str, help="Which mic to use. Options: index, focusrite. Default: index")
parser.add_argument("--language", type=str, help="Which language to use. Ex: english, japanese, chinese, french, german.")
- parser.add_argument("--model", type=str, help="Which AI model to use. \
- Options: tiny, tiny.en, base, base.en, small, small.en, \
- medium, medium.en, large-v1, large-v2")
+ parser.add_argument("--language_source", type=str, help="Which language to translate from. See kLangTargetChoices in Frame.cpp for valid choices")
+ parser.add_argument("--language_target", type=str, help="Which language to translate into. See kLangTargetChoices in Frame.cpp for valid choices")
+ parser.add_argument("--model", type=str, help="Which transcription model to use. " \
+ "Options: tiny, tiny.en, base, base.en, small, small.en, " \
+ "medium, medium.en, large-v1, large-v2")
+ parser.add_argument("--model_translation", type=str, help="Which translation model to use. " \
+ "Options: nllb-200-distilled-600M, nllb-200-distilled-1.3B.")
parser.add_argument("--bytes_per_char", type=str, help="The number of bytes to use to represent each character")
parser.add_argument("--chars_per_sync", type=str, help="The number of characters to send on each sync event")
parser.add_argument("--enable_local_beep", type=int, help="Whether to play a local auditory indicator when transcription starts/stops.")
@@ -615,9 +692,16 @@ if __name__ == "__main__":
if not args.language:
args.language = "english"
+ if not args.language_source or not args.language_target:
+ print("--language_source and --language_target required", file=sys.stderr)
+
if not args.model:
args.model = "base"
+ if not args.model_translation:
+ print("--model_translation required.", file=sys.stderr)
+ sys.exit(1)
+
if not args.bytes_per_char or not args.chars_per_sync:
print("--bytes_per_char and --chars_per_sync required", file=sys.stderr)
sys.exit(1)
@@ -685,7 +769,10 @@ if __name__ == "__main__":
transcribeLoop(args.mic,
args.language,
+ args.language_source,
+ args.language_target,
args.model,
+ args.model_translation,
args.enable_local_beep,
args.cpu, args.use_builtin,
args.enable_uwu_filter,