summaryrefslogtreecommitdiffstats
path: root/Scripts
diff options
context:
space:
mode:
Diffstat (limited to 'Scripts')
-rw-r--r--Scripts/lang_compat.py58
-rw-r--r--Scripts/requirements.txt1
-rw-r--r--Scripts/transcribe.py78
3 files changed, 103 insertions, 34 deletions
diff --git a/Scripts/lang_compat.py b/Scripts/lang_compat.py
new file mode 100644
index 0000000..af35921
--- /dev/null
+++ b/Scripts/lang_compat.py
@@ -0,0 +1,58 @@
+# This file provides mappings between language codes used by different
+# third-party libraries.
+
+# Whisper to NLLB.
+whisper_to_nllb = {
+ "catalan": "cat_Ltn", # catalan
+ "czech": "ces_Latn", # czech
+ "danish": "dan_Latn", # danish
+ "dutch": "nld_Latn", # dutch
+ "english": "eng_Latn", # english
+ "finnish": "fin_Latn", # finnish
+ "french": "fra_Latn", # french
+ "german": "deu_Latn", # german
+ "greek": "ell_Grek", # greek
+ "hungarian": "hun_Latn", # hungarian
+ "icelandic": "isl_Latn", # icelandic
+ "italian": "ita_Latn", # italian
+ "latvian": "lvs_Latn", # latvian
+ "lithuanian": "lit_Latn", # lithuanian
+ "norwegian": "nob_Latn", # norwegian (bokmal)
+ "polish": "pol_Latn", # polish
+ "portugese": "por_Latn", # portugese
+ "romanian": "ron_Latn", # romanian
+ "russian": "rus_Cyrl", # russian
+ "slovak": "slk_Latn", # slovak
+ "slovene": "slv_Latn", # slovene
+ "spanish": "spa_Latn", # spanish
+ "swedish": "swe_Latn", # swedish
+ "turkish": "tur_Latn", # turkish
+ }
+
+# NLLB to sentence_splitter (SS).
+nllb_to_ss = {
+ "cat_Ltn": "ca", # catalan
+ "ces_Latn": "cs", # czech
+ "dan_Latn": "da", # danish
+ "nld_Latn": "nl", # dutch
+ "eng_Latn": "en", # english
+ "fin_Latn": "fi", # finnish
+ "fra_Latn": "fr", # french
+ "deu_Latn": "de", # german
+ "ell_Grek": "el", # greek
+ "hun_Latn": "hu", # hungarian
+ "isl_Latn": "is", # icelandic
+ "ita_Latn": "it", # italian
+ "lvs_Latn": "lv", # latvian
+ "lit_Latn": "lt", # lithuanian
+ "nob_Latn": "no", # norwegian (bokmal)
+ "pol_Latn": "pl", # polish
+ "por_Latn": "pt", # portugese
+ "ron_Latn": "ro", # romanian
+ "rus_Cyrl": "ru", # russian
+ "slk_Latn": "sk", # slovak
+ "slv_Latn": "sl", # slovene
+ "spa_Latn": "es", # spanish
+ "swe_Latn": "sv", # swedish
+ "tur_Latn": "tr", # turkish
+ }
diff --git a/Scripts/requirements.txt b/Scripts/requirements.txt
index 5500a91..647e942 100644
--- a/Scripts/requirements.txt
+++ b/Scripts/requirements.txt
@@ -11,4 +11,5 @@ pyaudio
python-osc
playsound==1.2.2
pyyaml
+sentence_splitter
transformers>=4.21.0
diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py
index e113be1..fe06631 100644
--- a/Scripts/transcribe.py
+++ b/Scripts/transcribe.py
@@ -5,6 +5,7 @@ from emotes_v2 import EmotesState
from faster_whisper import WhisperModel
from functools import partial
from playsound import playsound
+from sentence_splitter import split_text_into_sentences
import argparse
import copy
@@ -12,6 +13,7 @@ import ctranslate2
import generate_utils
import keybind_event_machine
import keyboard
+import lang_compat
import langcodes
import numpy as np
import os
@@ -71,7 +73,7 @@ class AudioState:
# The language the user is speaking in. Default is English but user may set
# this to whatever they want.
- self.language = "en"
+ self.language = "english"
self.audio_paused = False
@@ -257,6 +259,8 @@ def transcribeAudio(audio_state,
if audio_state.drop_transcription:
audio_state.drop_transcription = False
+ audio_state.text = ""
+ audio_state.filtered_text = ""
print("drop transcription ({} seconds)".format(time.time() - last_transcribe_time))
last_transcribe_time = time.time()
continue
@@ -265,16 +269,30 @@ def transcribeAudio(audio_state,
audio_state.text = string_matcher.matchStrings(audio_state.text,
text, window_size = 25)
+ now = time.time()
+ print("Transcription ({} seconds): {}".format(
+ now - last_transcribe_time,
+ audio_state.text))
+ last_transcribe_time = now
+
# Translate if requested.
- if audio_state.language_source and audio_state.language_target:
- source = audio_state.tokenizer.convert_ids_to_tokens(audio_state.tokenizer.encode(copy.copy(audio_state.text)))
- target_prefix = [audio_state.language_target]
- results = audio_state.translator.translate_batch([source], target_prefix=[target_prefix])
- target = results[0].hypotheses[0][1:]
- translated = audio_state.tokenizer.decode(audio_state.tokenizer.convert_tokens_to_ids(target))
- print(f"Translated text: {translated}")
- else:
- translated = copy.copy(audio_state.text)
+ translated = audio_state.text
+ if audio_state.language_target:
+ whisper_lang = audio_state.whisper_language
+ nllb_lang = lang_compat.whisper_to_nllb[whisper_lang]
+ ss_lang = lang_compat.nllb_to_ss[nllb_lang]
+ sentences = split_text_into_sentences(translated, language=ss_lang)
+
+ translated_sentences = []
+ for sentence in sentences:
+ source = audio_state.tokenizer.convert_ids_to_tokens(audio_state.tokenizer.encode(sentence))
+ target_prefix = [audio_state.language_target]
+ results = audio_state.translator.translate_batch([source], target_prefix=[target_prefix])
+ target = results[0].hypotheses[0][1:]
+ translated_sentence = audio_state.tokenizer.decode(audio_state.tokenizer.convert_tokens_to_ids(target))
+ translated_sentences.append(translated_sentence)
+ translated = " ".join(translated_sentences)
+ print(f"Translation: {translated}")
# Apply filters to transcription
filtered_text = translated
@@ -296,12 +314,6 @@ def transcribeAudio(audio_state,
filtered_text = filtered_text.lower()
audio_state.filtered_text = filtered_text
- now = time.time()
- print("Transcription ({} seconds): {}".format(
- now - last_transcribe_time,
- audio_state.text))
- last_transcribe_time = now
-
if old_text != audio_state.text:
# We think the user said something, so reset the amount of
# time we sleep between transcriptions to the minimum.
@@ -358,10 +370,10 @@ def readKeyboardInput(audio_state, enable_local_beep: bool,
osc_ctrl.toggleBoard(audio_state.osc_state.client, False)
#playsound(os.path.abspath("../Sounds/Noise_Off_Quiet.wav"))
- resetAudioLocked(audio_state)
- resetDisplayLocked(audio_state)
audio_state.drop_transcription = True
audio_state.audio_paused = True
+ resetAudioLocked(audio_state)
+ resetDisplayLocked(audio_state)
continue
# Short hold
@@ -383,12 +395,12 @@ def readKeyboardInput(audio_state, enable_local_beep: bool,
osc_ctrl.indicateSpeech(audio_state.osc_state.client, True)
osc_ctrl.toggleBoard(audio_state.osc_state.client, True)
osc_ctrl.lockWorld(audio_state.osc_state.client, False)
- resetAudioLocked(audio_state)
- resetDisplayLocked(audio_state)
-
audio_state.drop_transcription = True
audio_state.audio_paused = False
+ resetAudioLocked(audio_state)
+ resetDisplayLocked(audio_state)
+
if enable_local_beep == 1:
playsound(os.path.abspath("Resources/Sounds/Noise_On_Quiet.wav"),
block=False)
@@ -506,7 +518,6 @@ def readControllerInput(audio_state, enable_local_beep: bool,
# whisper/__init__.py. Examples: tiny, base, small, medium.
def transcribeLoop(mic: str,
language: str,
- language_source: str,
language_target: str,
model: str,
model_translation: str,
@@ -523,6 +534,7 @@ def transcribeLoop(mic: str,
gpu_idx: int,
keyboard_hotkey: str):
audio_state = getMicStream(mic)
+ audio_state.whisper_language = language
audio_state.language = langcodes.find(language).language
audio_state.MAX_LENGTH_S = window_duration_s
@@ -532,14 +544,8 @@ def transcribeLoop(mic: str,
audio_state.language_target = lang_code
else:
audio_state.language_target = None
- lang_bits = language_source.split(" | ")
- if len(lang_bits) == 2:
- lang_code = lang_bits[1]
- audio_state.language_source = lang_code
- else:
- audio_state.language_source = None
- if audio_state.language_source and audio_state.language_target:
+ if audio_state.language_target:
print("Translation requested")
print("Installing torch and sentencepiece in virtual environment. "
@@ -579,9 +585,15 @@ def transcribeLoop(mic: str,
print(f"Using model at {output_dir}")
audio_state.translator = ctranslate2.Translator(output_dir)
+
+ whisper_lang = audio_state.whisper_language
+ nllb_lang = lang_compat.whisper_to_nllb[whisper_lang]
+
audio_state.tokenizer = transformers.AutoTokenizer.from_pretrained(
"facebook/" + model_translation,
- src_lang=audio_state.language_source)
+ src_lang=nllb_lang)
+
+ print(f"Translation ready to go")
print("Safe to start talking")
@@ -661,7 +673,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--mic", type=str, help="Which mic to use. Options: index, focusrite. Default: index")
parser.add_argument("--language", type=str, help="Which language to use. Ex: english, japanese, chinese, french, german.")
- parser.add_argument("--language_source", type=str, help="Which language to translate from. See kLangTargetChoices in Frame.cpp for valid choices")
parser.add_argument("--language_target", type=str, help="Which language to translate into. See kLangTargetChoices in Frame.cpp for valid choices")
parser.add_argument("--model", type=str, help="Which transcription model to use. " \
"Options: tiny, tiny.en, base, base.en, small, small.en, " \
@@ -692,8 +703,8 @@ if __name__ == "__main__":
if not args.language:
args.language = "english"
- if not args.language_source or not args.language_target:
- print("--language_source and --language_target required", file=sys.stderr)
+ if not args.language_target:
+ print("--language_target required", file=sys.stderr)
if not args.model:
args.model = "base"
@@ -769,7 +780,6 @@ if __name__ == "__main__":
transcribeLoop(args.mic,
args.language,
- args.language_source,
args.language_target,
args.model,
args.model_translation,