diff options
| -rw-r--r-- | GUI/GUI/GUI/Config.cpp | 9 | ||||
| -rw-r--r-- | GUI/GUI/GUI/Config.h | 3 | ||||
| -rw-r--r-- | GUI/GUI/GUI/Frame.cpp | 302 | ||||
| -rw-r--r-- | GUI/GUI/GUI/Frame.h | 3 | ||||
| -rw-r--r-- | GUI/GUI/GUI/PythonWrapper.cpp | 5 | ||||
| -rw-r--r-- | Scripts/requirements.txt | 2 | ||||
| -rw-r--r-- | Scripts/transcribe.py | 95 |
7 files changed, 408 insertions, 11 deletions
diff --git a/GUI/GUI/GUI/Config.cpp b/GUI/GUI/GUI/Config.cpp index f0b5a1b..bf0c1c6 100644 --- a/GUI/GUI/GUI/Config.cpp +++ b/GUI/GUI/GUI/Config.cpp @@ -64,7 +64,10 @@ AppConfig::AppConfig(wxTextCtrl* out) microphone("index"),
language("english"),
+ language_source("Do not translate"),
+ language_target("Do not translate"),
model("base.en"),
+ model_translation("nllb-200-distilled-600M"),
button("left joystick"),
window_duration("15"),
@@ -112,7 +115,10 @@ bool AppConfig::Serialize(const std::filesystem::path& path) { cm.Set("microphone", microphone);
cm.Set("language", language);
+ cm.Set("language_source", language_source);
+ cm.Set("language_target", language_target);
cm.Set("model", model);
+ cm.Set("model_translation", model_translation);
cm.Set("button", button);
cm.Set("window_duration", window_duration);
@@ -173,7 +179,10 @@ bool AppConfig::Deserialize(const std::filesystem::path& path) { AppConfig c(out_);
cm.Get("microphone", c.microphone);
cm.Get("language", c.language);
+ cm.Get("language_source", c.language_source);
+ cm.Get("language_target", c.language_target);
cm.Get("model", c.model);
+ cm.Get("model_translation", c.model_translation);
cm.Get("button", c.button);
cm.Get("window_duration", c.window_duration);
diff --git a/GUI/GUI/GUI/Config.h b/GUI/GUI/GUI/Config.h index 01b0239..bc942d4 100644 --- a/GUI/GUI/GUI/Config.h +++ b/GUI/GUI/GUI/Config.h @@ -50,7 +50,10 @@ public: // Transcription-specific settings.
std::string microphone;
std::string language;
+ std::string language_source;
+ std::string language_target;
std::string model;
+ std::string model_translation;
std::string button;
std::string window_duration;
diff --git a/GUI/GUI/GUI/Frame.cpp b/GUI/GUI/GUI/Frame.cpp index 7e965df..8cebc2d 100644 --- a/GUI/GUI/GUI/Frame.cpp +++ b/GUI/GUI/GUI/Frame.cpp @@ -30,8 +30,11 @@ namespace { ID_PY_APP_MIC,
ID_PY_APP_MIC_PANEL,
ID_PY_APP_LANG,
+ ID_PY_APP_TRANSLATE_SOURCE,
+ ID_PY_APP_TRANSLATE_TARGET,
ID_PY_APP_LANG_PANEL,
ID_PY_APP_MODEL,
+ ID_PY_APP_MODEL_TRANSLATION,
ID_PY_APP_CHARS_PER_SYNC,
ID_PY_APP_BYTES_PER_CHAR,
ID_PY_APP_BUTTON,
@@ -231,6 +234,216 @@ namespace { const size_t kNumLangChoices = sizeof(kLangChoices) / sizeof(kLangChoices[0]);
constexpr int kLangDefault = 0; // english
+ const wxString kLangTargetChoices[] = {
+ "Do not translate",
+ "Acehnese(Arabic script) | ace_Arab",
+ "Acehnese(Latin script) | ace_Latn",
+ "Afrikaans | afr_Latn",
+ "Akan | aka_Latn",
+ "Amharic | amh_Ethi",
+ "Armenian | hye_Armn",
+ "Assamese | asm_Beng",
+ "Asturian | ast_Latn",
+ "Awadhi | awa_Deva",
+ "Ayacucho Quechua | quy_Latn",
+ "Balinese | ban_Latn",
+ "Bambara | bam_Latn",
+ "Banjar(Arabic script) | bjn_Arab",
+ "Banjar(Latin script) | bjn_Latn",
+ "Bashkir | bak_Cyrl",
+ "Basque | eus_Latn",
+ "Belarusian | bel_Cyrl",
+ "Bemba | bem_Latn",
+ "Bengali | ben_Beng",
+ "Bhojpuri | bho_Deva",
+ "Bosnian | bos_Latn",
+ "Buginese | bug_Latn",
+ "Bulgarian | bul_Cyrl",
+ "Burmese | mya_Mymr",
+ "Catalan | cat_Latn",
+ "Cebuano | ceb_Latn",
+ "Central Atlas Tamazight | tzm_Tfng",
+ "Central Aymara | ayr_Latn",
+ "Central Kanuri(Arabic script) | knc_Arab",
+ "Central Kanuri(Latin script) | knc_Latn",
+ "Central Kurdish | ckb_Arab",
+ "Chhattisgarhi | hne_Deva",
+ "Chinese(Simplified) | zho_Hans",
+ "Chinese(Traditional) | zho_Hant",
+ "Chokwe | cjk_Latn",
+ "Crimean Tatar | crh_Latn",
+ "Croatian | hrv_Latn",
+ "Czech | ces_Latn",
+ "Danish | dan_Latn",
+ "Dari | prs_Arab",
+ "Dutch | nld_Latn",
+ "Dyula | dyu_Latn",
+ "Dzongkha | dzo_Tibt",
+ "Eastern Panjabi | pan_Guru",
+ "Eastern Yiddish | ydd_Hebr",
+ "Egyptian Arabic | arz_Arab",
+ "English | eng_Latn",
+ "Esperanto | epo_Latn",
+ "Estonian | est_Latn",
+ "Ewe | ewe_Latn",
+ "Faroese | fao_Latn",
+ "Fijian | fij_Latn",
+ "Finnish | fin_Latn",
+ "Fon | fon_Latn",
+ "French | fra_Latn",
+ "Friulian | fur_Latn",
+ "Galician | glg_Latn",
+ "Ganda | lug_Latn",
+ "Georgian | kat_Geor",
+ "German | deu_Latn",
+ "Greek | ell_Grek",
+ "Guarani | grn_Latn",
+ "Gujarati | guj_Gujr",
+ "Haitian Creole | hat_Latn",
+ "Halh Mongolian | khk_Cyrl",
+ "Hausa | hau_Latn",
+ "Hebrew | heb_Hebr",
+ "Hindi | hin_Deva",
+ "Hungarian | hun_Latn",
+ "Icelandic | isl_Latn",
+ "Igbo | ibo_Latn",
+ "Ilocano | ilo_Latn",
+ "Indonesian | ind_Latn",
+ "Irish | gle_Latn",
+ "Italian | ita_Latn",
+ "Japanese | jpn_Jpan",
+ "Javanese | jav_Latn",
+ "Jingpho | kac_Latn",
+ "Kabiyè | kbp_Latn",
+ "Kabuverdianu | kea_Latn",
+ "Kabyle | kab_Latn",
+ "Kamba | kam_Latn",
+ "Kannada | kan_Knda",
+ "Kashmiri(Arabic script) | kas_Arab",
+ "Kashmiri(Devanagari script) | kas_Deva",
+ "Kazakh | kaz_Cyrl",
+ "Khmer | khm_Khmr",
+ "Kikongo | kon_Latn",
+ "Kikuyu | kik_Latn",
+ "Kimbundu | kmb_Latn",
+ "Kinyarwanda | kin_Latn",
+ "Korean | kor_Hang",
+ "Kyrgyz | kir_Cyrl",
+ "Lao | lao_Laoo",
+ "Latgalian | ltg_Latn",
+ "Ligurian | lij_Latn",
+ "Limburgish | lim_Latn",
+ "Lingala | lin_Latn",
+ "Lithuanian | lit_Latn",
+ "Lombard | lmo_Latn",
+ "Luba - Kasai | lua_Latn",
+ "Luo | luo_Latn",
+ "Luxembourgish | ltz_Latn",
+ "Macedonian | mkd_Cyrl",
+ "Magahi | mag_Deva",
+ "Maithili | mai_Deva",
+ "Malayalam | mal_Mlym",
+ "Maltese | mlt_Latn",
+ "Maori | mri_Latn",
+ "Marathi | mar_Deva",
+ "Meitei(Bengali script) | mni_Beng",
+ "Mesopotamian Arabic | acm_Arab",
+ "Minangkabau(Arabic script) | min_Arab",
+ "Minangkabau(Latin script) | min_Latn",
+ "Mizo | lus_Latn",
+ "Modern Standard Arabic | arb_Arab",
+ "Modern Standard Arabic(Romanized) | arb_Latn",
+ "Moroccan Arabic | ary_Arab",
+ "Mossi | mos_Latn",
+ "Najdi Arabic | ars_Arab",
+ "Nepali | npi_Deva",
+ "Nigerian Fulfulde | fuv_Latn",
+ "North Azerbaijani | azj_Latn",
+ "North Levantine Arabic | apc_Arab",
+ "Northern Kurdish | kmr_Latn",
+ "Northern Sotho | nso_Latn",
+ "Northern Uzbek | uzn_Latn",
+ "Norwegian Bokmål | nob_Latn",
+ "Norwegian Nynorsk | nno_Latn",
+ "Nuer | nus_Latn",
+ "Nyanja | nya_Latn",
+ "Occitan | oci_Latn",
+ "Odia | ory_Orya",
+ "Pangasinan | pag_Latn",
+ "Papiamento | pap_Latn",
+ "Plateau Malagasy | plt_Latn",
+ "Polish | pol_Latn",
+ "Portuguese | por_Latn",
+ "Romanian | ron_Latn",
+ "Rundi | run_Latn",
+ "Russian | rus_Cyrl",
+ "Samoan | smo_Latn",
+ "Sango | sag_Latn",
+ "Sanskrit | san_Deva",
+ "Santali | sat_Olck",
+ "Sardinian | srd_Latn",
+ "Scottish Gaelic | gla_Latn",
+ "Serbian | srp_Cyrl",
+ "Shan | shn_Mymr",
+ "Shona | sna_Latn",
+ "Sicilian | scn_Latn",
+ "Silesian | szl_Latn",
+ "Sindhi | snd_Arab",
+ "Sinhala | sin_Sinh",
+ "Slovak | slk_Latn",
+ "Slovenian | slv_Latn",
+ "Somali | som_Latn",
+ "South Azerbaijani | azb_Arab",
+ "South Levantine Arabic | ajp_Arab",
+ "Southern Pashto | pbt_Arab",
+ "Southern Sotho | sot_Latn",
+ "Southwestern Dinka | dik_Latn",
+ "Spanish | spa_Latn",
+ "Standard Latvian | lvs_Latn",
+ "Standard Malay | zsm_Latn",
+ "Standard Tibetan | bod_Tibt",
+ "Sundanese | sun_Latn",
+ "Swahili | swh_Latn",
+ "Swati | ssw_Latn",
+ "Swedish | swe_Latn",
+ "Ta'izzi - Adeni Arabic | acq_Arab",
+ "Tagalog | tgl_Latn",
+ "Tajik | tgk_Cyrl",
+ "Tamasheq(Latin script) | taq_Latn",
+ "Tamasheq(Tifinagh script) | taq_Tfng",
+ "Tamil | tam_Taml",
+ "Tatar | tat_Cyrl",
+ "Telugu | tel_Telu",
+ "Thai | tha_Thai",
+ "Tigrinya | tir_Ethi",
+ "Tok Pisin | tpi_Latn",
+ "Tosk Albanian | als_Latn",
+ "Tsonga | tso_Latn",
+ "Tswana | tsn_Latn",
+ "Tumbuka | tum_Latn",
+ "Tunisian Arabic | aeb_Arab",
+ "Turkish | tur_Latn",
+ "Turkmen | tuk_Latn",
+ "Twi | twi_Latn",
+ "Ukrainian | ukr_Cyrl",
+ "Umbundu | umb_Latn",
+ "Urdu | urd_Arab",
+ "Uyghur | uig_Arab",
+ "Venetian | vec_Latn",
+ "Vietnamese | vie_Latn",
+ "Waray | war_Latn",
+ "Welsh | cym_Latn",
+ "West Central Oromo | gaz_Latn",
+ "Western Persian | pes_Arab",
+ "Wolof | wol_Latn",
+ "Xhosa | xho_Latn",
+ "Yoruba | yor_Latn",
+ "Yue Chinese | yue_Hant",
+ "Zulu | zul_Latn",
+ };
+ const size_t kNumLangTargetChoices = sizeof(kLangTargetChoices) / sizeof(kLangTargetChoices[0]);
+ constexpr int kLangTargetDefault = 0; // do not translate
+
// lifted from whisper/__init__.py
const wxString kModelChoices[] = {
"tiny.en",
@@ -247,6 +460,13 @@ namespace { const size_t kNumModelChoices = sizeof(kModelChoices) / sizeof(kModelChoices[0]);
constexpr int kModelDefault = 2; // base.en
+ const wxString kModelTranslationChoices[] = {
+ "nllb-200-distilled-600M",
+ "nllb-200-distilled-1.3B",
+ };
+ const size_t kNumModelTranslationChoices = sizeof(kModelTranslationChoices) / sizeof(kModelTranslationChoices[0]);
+ constexpr int kModelTranslationDefault = 2; // base.en
+
const wxString kCharsPerSync[] = {
"5",
"6",
@@ -413,13 +633,28 @@ Frame::Frame() ID_PY_APP_LANG, wxDefaultPosition, wxDefaultSize,
kNumLangChoices, kLangChoices);
py_app_lang->SetToolTip("Select which language you will "
- "speak in. It will be transcribed into that language. "
- "If using a language with non-ASCII characters (i.e. "
- "not English), make sure you have 'bytes per char' "
- "set to 2. If using something other than English, "
+ "speak in. If using something other than English, "
"make sure you're not using a *.en model.");
py_app_lang_ = py_app_lang;
+ auto* py_app_translate_source = new wxChoice(py_app_config_panel_pairs,
+ ID_PY_APP_TRANSLATE_SOURCE, wxDefaultPosition, wxDefaultSize,
+ kNumLangTargetChoices, kLangTargetChoices);
+ py_app_translate_source->SetToolTip("Select which language to "
+ "translate from, in other words, the language you are transcribing into.");
+ py_app_translate_source_ = py_app_translate_source;
+
+ auto* py_app_translate_target = new wxChoice(py_app_config_panel_pairs,
+ ID_PY_APP_TRANSLATE_TARGET, wxDefaultPosition, wxDefaultSize,
+ kNumLangTargetChoices, kLangTargetChoices);
+ py_app_translate_target->SetToolTip("Select which "
+ "language to translate to. This is the language "
+ "that will appear in game. "
+ "If using a language with non-ASCII characters (i.e. "
+ "not English), make sure you have 'bytes per char' "
+ "set to 2.");
+ py_app_translate_target_ = py_app_translate_target;
+
auto* py_app_model = new wxChoice(
py_app_config_panel_pairs, ID_PY_APP_MODEL,
wxDefaultPosition, wxDefaultSize, kNumModelChoices,
@@ -432,6 +667,16 @@ Frame::Frame() "don't work for other languages.");
py_app_model_ = py_app_model;
+ auto* py_app_model_translation = new wxChoice(
+ py_app_config_panel_pairs, ID_PY_APP_MODEL_TRANSLATION,
+ wxDefaultPosition, wxDefaultSize, kNumModelTranslationChoices,
+ kModelTranslationChoices);
+ py_app_model_translation->SetToolTip("Select which "
+ "version of the translation model to use. 600M params "
+ "uses 4.1 GB of memory, while 1.3B uses ~7GB of "
+ "memory.");
+ py_app_model_translation_ = py_app_model_translation;
+
auto* py_app_chars_per_sync = new wxChoice(
py_app_config_panel_pairs, ID_PY_APP_CHARS_PER_SYNC,
wxDefaultPosition, wxDefaultSize, kNumCharsPerSync,
@@ -517,16 +762,31 @@ Frame::Frame() /*flags=*/wxEXPAND);
sizer->Add(new wxStaticText(py_app_config_panel_pairs,
- wxID_ANY, /*label=*/"Language:"));
+ wxID_ANY, /*label=*/"Spoken language:"));
sizer->Add(py_app_lang, /*proportion=*/0,
/*flags=*/wxEXPAND);
sizer->Add(new wxStaticText(py_app_config_panel_pairs,
- wxID_ANY, /*label=*/"Model:"));
+ wxID_ANY, /*label=*/"Translate from:"));
+ sizer->Add(py_app_translate_source, /*proportion=*/0,
+ /*flags=*/wxEXPAND);
+
+ sizer->Add(new wxStaticText(py_app_config_panel_pairs,
+ wxID_ANY, /*label=*/"Translate to:"));
+ sizer->Add(py_app_translate_target, /*proportion=*/0,
+ /*flags=*/wxEXPAND);
+
+ sizer->Add(new wxStaticText(py_app_config_panel_pairs,
+ wxID_ANY, /*label=*/"Transcription model:"));
sizer->Add(py_app_model, /*proportion=*/0,
/*flags=*/wxEXPAND);
sizer->Add(new wxStaticText(py_app_config_panel_pairs,
+ wxID_ANY, /*label=*/"Translation model:"));
+ sizer->Add(py_app_model_translation, /*proportion=*/0,
+ /*flags=*/wxEXPAND);
+
+ sizer->Add(new wxStaticText(py_app_config_panel_pairs,
wxID_ANY, /*label=*/"Characters per sync:"));
sizer->Add(py_app_chars_per_sync, /*proportion=*/0,
/*flags=*/wxEXPAND);
@@ -1084,11 +1344,26 @@ void Frame::ApplyConfigToInputFields() kNumLangChoices, app_c_->language, kLangDefault);
py_app_lang->SetSelection(lang_idx);
+ auto* py_app_translate_source = static_cast<wxChoice*>(FindWindowById(ID_PY_APP_TRANSLATE_SOURCE));
+ int translate_source_idx = GetDropdownChoiceIndex(kLangTargetChoices,
+ kNumLangTargetChoices, app_c_->language_source, kLangTargetDefault);
+ py_app_translate_source->SetSelection(translate_source_idx);
+
+ auto* py_app_translate_target = static_cast<wxChoice*>(FindWindowById(ID_PY_APP_TRANSLATE_TARGET));
+ int translate_target_idx = GetDropdownChoiceIndex(kLangTargetChoices,
+ kNumLangTargetChoices, app_c_->language_target, kLangTargetDefault);
+ py_app_translate_target->SetSelection(translate_target_idx);
+
auto* py_app_model = static_cast<wxChoice*>(FindWindowById(ID_PY_APP_MODEL));
int model_idx = GetDropdownChoiceIndex(kModelChoices,
kNumModelChoices, app_c_->model, kModelDefault);
py_app_model->SetSelection(model_idx);
+ auto* py_app_model_translation = static_cast<wxChoice*>(FindWindowById(ID_PY_APP_MODEL_TRANSLATION));
+ int model_translation_idx = GetDropdownChoiceIndex(kModelTranslationChoices,
+ kNumModelTranslationChoices, app_c_->model_translation, kModelTranslationDefault);
+ py_app_model_translation->SetSelection(model_translation_idx);
+
auto* py_app_chars_per_sync = static_cast<wxChoice*>(FindWindowById(ID_PY_APP_CHARS_PER_SYNC));
int chars_idx = GetDropdownChoiceIndex(kCharsPerSync,
kNumCharsPerSync, std::to_string(app_c_->chars_per_sync),
@@ -1693,10 +1968,22 @@ void Frame::OnAppStart(wxCommandEvent& event) { if (which_lang == wxNOT_FOUND) {
which_lang = kLangDefault;
}
+ int which_translate_source = py_app_translate_source_->GetSelection();
+ if (which_translate_source == wxNOT_FOUND) {
+ which_translate_source = kLangDefault;
+ }
+ int which_translate_target = py_app_translate_target_->GetSelection();
+ if (which_translate_target == wxNOT_FOUND) {
+ which_translate_target = kLangDefault;
+ }
int which_model = py_app_model_->GetSelection();
if (which_model == wxNOT_FOUND) {
which_model = kModelDefault;
}
+ int which_model_translation = py_app_model_translation_->GetSelection();
+ if (which_model_translation == wxNOT_FOUND) {
+ which_model_translation = kModelTranslationDefault;
+ }
int chars_per_sync_idx = py_app_chars_per_sync_->GetSelection();
if (chars_per_sync_idx == wxNOT_FOUND) {
chars_per_sync_idx = kCharsDefault;
@@ -1774,7 +2061,10 @@ void Frame::OnAppStart(wxCommandEvent& event) { app_c_->microphone = kMicChoices[which_mic].ToStdString();
app_c_->language = kLangChoices[which_lang].ToStdString();
+ app_c_->language_source = kLangTargetChoices[which_translate_source].ToStdString();
+ app_c_->language_target = kLangTargetChoices[which_translate_target].ToStdString();
app_c_->model = kModelChoices[which_model].ToStdString();
+ app_c_->model_translation = kModelTranslationChoices[which_model_translation].ToStdString();
app_c_->chars_per_sync = chars_per_sync;
app_c_->bytes_per_char = bytes_per_char;
app_c_->button = kButton[button_idx].ToStdString();
diff --git a/GUI/GUI/GUI/Frame.h b/GUI/GUI/GUI/Frame.h index f5a8cd9..24a0594 100644 --- a/GUI/GUI/GUI/Frame.h +++ b/GUI/GUI/GUI/Frame.h @@ -62,7 +62,10 @@ private: wxChoice* py_app_mic_;
wxChoice* py_app_lang_;
+ wxChoice* py_app_translate_source_;
+ wxChoice* py_app_translate_target_;
wxChoice* py_app_model_;
+ wxChoice* py_app_model_translation_;
// TODO(yum) figure out how to deduplicate these objects
wxChoice* py_app_chars_per_sync_;
wxChoice* py_app_bytes_per_char_;
diff --git a/GUI/GUI/GUI/PythonWrapper.cpp b/GUI/GUI/GUI/PythonWrapper.cpp index 5df1dfa..b95d6b3 100644 --- a/GUI/GUI/GUI/PythonWrapper.cpp +++ b/GUI/GUI/GUI/PythonWrapper.cpp @@ -462,8 +462,11 @@ std::future<bool> PythonWrapper::StartApp( "-u", // Unbuffered output "Resources/Scripts/transcribe.py", "--mic", config.microphone, - "--lang", config.language, + "--language", config.language, + "--language_source", Quote(config.language_source), + "--language_target", Quote(config.language_target), "--model", config.model, + "--model_translation", config.model_translation, "--chars_per_sync", std::to_string(config.chars_per_sync), "--bytes_per_char", std::to_string(config.bytes_per_char), "--button", Quote(config.button), diff --git a/Scripts/requirements.txt b/Scripts/requirements.txt index c887808..5500a91 100644 --- a/Scripts/requirements.txt +++ b/Scripts/requirements.txt @@ -1,3 +1,4 @@ +ctranslate2 editdistance faster-whisper@https://github.com/guillaumekln/faster-whisper/archive/358d373691c95205021bd4bbf28cde7ce4d10030.tar.gz future==0.18.2 @@ -10,3 +11,4 @@ pyaudio python-osc playsound==1.2.2 pyyaml +transformers>=4.21.0 diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py index 7ba80dc..e113be1 100644 --- a/Scripts/transcribe.py +++ b/Scripts/transcribe.py @@ -8,6 +8,7 @@ from playsound import playsound import argparse import copy +import ctranslate2 import generate_utils import keybind_event_machine import keyboard @@ -22,6 +23,7 @@ import subprocess import sys import threading import time +import transformers import wave class AudioState: @@ -263,8 +265,19 @@ def transcribeAudio(audio_state, audio_state.text = string_matcher.matchStrings(audio_state.text, text, window_size = 25) + # Translate if requested. + if audio_state.language_source and audio_state.language_target: + source = audio_state.tokenizer.convert_ids_to_tokens(audio_state.tokenizer.encode(copy.copy(audio_state.text))) + target_prefix = [audio_state.language_target] + results = audio_state.translator.translate_batch([source], target_prefix=[target_prefix]) + target = results[0].hypotheses[0][1:] + translated = audio_state.tokenizer.decode(audio_state.tokenizer.convert_tokens_to_ids(target)) + print(f"Translated text: {translated}") + else: + translated = copy.copy(audio_state.text) + # Apply filters to transcription - filtered_text = audio_state.text + filtered_text = translated if enable_uwu_filter: uwu_proc = subprocess.Popen(["Resources/Uwu/Uwwwu.exe", filtered_text], stdout=subprocess.PIPE, @@ -493,7 +506,10 @@ def readControllerInput(audio_state, enable_local_beep: bool, # whisper/__init__.py. Examples: tiny, base, small, medium. def transcribeLoop(mic: str, language: str, + language_source: str, + language_target: str, model: str, + model_translation: str, enable_local_beep: bool, use_cpu: bool, use_builtin: bool, @@ -510,6 +526,63 @@ def transcribeLoop(mic: str, audio_state.language = langcodes.find(language).language audio_state.MAX_LENGTH_S = window_duration_s + lang_bits = language_target.split(" | ") + if len(lang_bits) == 2: + lang_code = lang_bits[1] + audio_state.language_target = lang_code + else: + audio_state.language_target = None + lang_bits = language_source.split(" | ") + if len(lang_bits) == 2: + lang_code = lang_bits[1] + audio_state.language_source = lang_code + else: + audio_state.language_source = None + + if audio_state.language_source and audio_state.language_target: + print("Translation requested") + + print("Installing torch and sentencepiece in virtual environment. " + "Nothing will print " + "for several minutes while these download (~2.4 GB).") + pip_proc = subprocess.Popen( + "Resources/Python/python.exe -m pip install sentencepiece torch".split(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + pip_stdout, pip_stderr = pip_proc.communicate() + pip_stdout = pip_stdout.decode("utf-8") + pip_stderr = pip_stderr.decode("utf-8") + print(pip_stdout) + print(pip_stderr) + if pip_proc.returncode != 0: + print(f"Failed to set up for translation: `pip install torch` " + "exited with {pip_proc.returncode}") + + output_dir = "Resources/" + model_translation + # Provided by ctranslate2 Python package + cmd = "ct2-transformers-converter.exe --model facebook/" + model_translation + " --output_dir " + output_dir + + print(f"Fetching translation algorithm ({model_translation})") + if not os.path.exists(output_dir): + ct2_proc = subprocess.Popen( + cmd.split(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + ct2_stdout, ct2_stderr = ct2_proc.communicate() + ct2_stdout = ct2_stdout.decode("utf-8") + ct2_stderr = ct2_stderr.decode("utf-8") + print(ct2_stdout) + print(ct2_stderr) + if ct2_proc.returncode != 0: + print(f"Failed to get NLLB model: ct2 process exited with " + "{ct2_proc.returncode}") + print(f"Using model at {output_dir}") + + audio_state.translator = ctranslate2.Translator(output_dir) + audio_state.tokenizer = transformers.AutoTokenizer.from_pretrained( + "facebook/" + model_translation, + src_lang=audio_state.language_source) + print("Safe to start talking") abspath = os.path.abspath(__file__) @@ -588,9 +661,13 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--mic", type=str, help="Which mic to use. Options: index, focusrite. Default: index") parser.add_argument("--language", type=str, help="Which language to use. Ex: english, japanese, chinese, french, german.") - parser.add_argument("--model", type=str, help="Which AI model to use. \ - Options: tiny, tiny.en, base, base.en, small, small.en, \ - medium, medium.en, large-v1, large-v2") + parser.add_argument("--language_source", type=str, help="Which language to translate from. See kLangTargetChoices in Frame.cpp for valid choices") + parser.add_argument("--language_target", type=str, help="Which language to translate into. See kLangTargetChoices in Frame.cpp for valid choices") + parser.add_argument("--model", type=str, help="Which transcription model to use. " \ + "Options: tiny, tiny.en, base, base.en, small, small.en, " \ + "medium, medium.en, large-v1, large-v2") + parser.add_argument("--model_translation", type=str, help="Which translation model to use. " \ + "Options: nllb-200-distilled-600M, nllb-200-distilled-1.3B.") parser.add_argument("--bytes_per_char", type=str, help="The number of bytes to use to represent each character") parser.add_argument("--chars_per_sync", type=str, help="The number of characters to send on each sync event") parser.add_argument("--enable_local_beep", type=int, help="Whether to play a local auditory indicator when transcription starts/stops.") @@ -615,9 +692,16 @@ if __name__ == "__main__": if not args.language: args.language = "english" + if not args.language_source or not args.language_target: + print("--language_source and --language_target required", file=sys.stderr) + if not args.model: args.model = "base" + if not args.model_translation: + print("--model_translation required.", file=sys.stderr) + sys.exit(1) + if not args.bytes_per_char or not args.chars_per_sync: print("--bytes_per_char and --chars_per_sync required", file=sys.stderr) sys.exit(1) @@ -685,7 +769,10 @@ if __name__ == "__main__": transcribeLoop(args.mic, args.language, + args.language_source, + args.language_target, args.model, + args.model_translation, args.enable_local_beep, args.cpu, args.use_builtin, args.enable_uwu_filter, |
