diff options
| author | yum <yum.food.vr@gmail.com> | 2023-05-25 21:45:09 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-05-25 22:00:56 -0700 |
| commit | 84f09e1fdf15644d1ea5f955889581932e4f6a8e (patch) | |
| tree | 70894da7bc14c773f9755c1838cd87fe7f26b909 | |
| parent | eed2e8915d83796679c0b7a3ea9121d329ceddab (diff) | |
Add ability to translate into 200 languages
Use Meta's No Language Left Behind (NLLB) algorithm to provide
translation capabilities into 200 languages. Obviously most are very
untested.
This requires either 4.1 or 7.1 GB of RAM and significantly increases
transcription latency.
| -rw-r--r-- | GUI/GUI/GUI/Config.cpp | 9 | ||||
| -rw-r--r-- | GUI/GUI/GUI/Config.h | 3 | ||||
| -rw-r--r-- | GUI/GUI/GUI/Frame.cpp | 302 | ||||
| -rw-r--r-- | GUI/GUI/GUI/Frame.h | 3 | ||||
| -rw-r--r-- | GUI/GUI/GUI/PythonWrapper.cpp | 5 | ||||
| -rw-r--r-- | Scripts/requirements.txt | 2 | ||||
| -rw-r--r-- | Scripts/transcribe.py | 95 |
7 files changed, 408 insertions, 11 deletions
diff --git a/GUI/GUI/GUI/Config.cpp b/GUI/GUI/GUI/Config.cpp index f0b5a1b..bf0c1c6 100644 --- a/GUI/GUI/GUI/Config.cpp +++ b/GUI/GUI/GUI/Config.cpp @@ -64,7 +64,10 @@ AppConfig::AppConfig(wxTextCtrl* out) microphone("index"),
language("english"),
+ language_source("Do not translate"),
+ language_target("Do not translate"),
model("base.en"),
+ model_translation("nllb-200-distilled-600M"),
button("left joystick"),
window_duration("15"),
@@ -112,7 +115,10 @@ bool AppConfig::Serialize(const std::filesystem::path& path) { cm.Set("microphone", microphone);
cm.Set("language", language);
+ cm.Set("language_source", language_source);
+ cm.Set("language_target", language_target);
cm.Set("model", model);
+ cm.Set("model_translation", model_translation);
cm.Set("button", button);
cm.Set("window_duration", window_duration);
@@ -173,7 +179,10 @@ bool AppConfig::Deserialize(const std::filesystem::path& path) { AppConfig c(out_);
cm.Get("microphone", c.microphone);
cm.Get("language", c.language);
+ cm.Get("language_source", c.language_source);
+ cm.Get("language_target", c.language_target);
cm.Get("model", c.model);
+ cm.Get("model_translation", c.model_translation);
cm.Get("button", c.button);
cm.Get("window_duration", c.window_duration);
diff --git a/GUI/GUI/GUI/Config.h b/GUI/GUI/GUI/Config.h index 01b0239..bc942d4 100644 --- a/GUI/GUI/GUI/Config.h +++ b/GUI/GUI/GUI/Config.h @@ -50,7 +50,10 @@ public: // Transcription-specific settings.
std::string microphone;
std::string language;
+ std::string language_source;
+ std::string language_target;
std::string model;
+ std::string model_translation;
std::string button;
std::string window_duration;
diff --git a/GUI/GUI/GUI/Frame.cpp b/GUI/GUI/GUI/Frame.cpp index 7e965df..8cebc2d 100644 --- a/GUI/GUI/GUI/Frame.cpp +++ b/GUI/GUI/GUI/Frame.cpp @@ -30,8 +30,11 @@ namespace { ID_PY_APP_MIC,
ID_PY_APP_MIC_PANEL,
ID_PY_APP_LANG,
+ ID_PY_APP_TRANSLATE_SOURCE,
+ ID_PY_APP_TRANSLATE_TARGET,
ID_PY_APP_LANG_PANEL,
ID_PY_APP_MODEL,
+ ID_PY_APP_MODEL_TRANSLATION,
ID_PY_APP_CHARS_PER_SYNC,
ID_PY_APP_BYTES_PER_CHAR,
ID_PY_APP_BUTTON,
@@ -231,6 +234,216 @@ namespace { const size_t kNumLangChoices = sizeof(kLangChoices) / sizeof(kLangChoices[0]);
constexpr int kLangDefault = 0; // english
+ const wxString kLangTargetChoices[] = { // Labels shared by the "Translate from" and "Translate to" dropdowns. Format: "<display name> | <NLLB language code>"; transcribe.py splits each label on " | " and uses the second part as the code.
+ "Do not translate", // No " | " separator, so transcribe.py resolves it to no language code and skips translation.
+ "Acehnese(Arabic script) | ace_Arab",
+ "Acehnese(Latin script) | ace_Latn",
+ "Afrikaans | afr_Latn",
+ "Akan | aka_Latn",
+ "Amharic | amh_Ethi",
+ "Armenian | hye_Armn",
+ "Assamese | asm_Beng",
+ "Asturian | ast_Latn",
+ "Awadhi | awa_Deva",
+ "Ayacucho Quechua | quy_Latn",
+ "Balinese | ban_Latn",
+ "Bambara | bam_Latn",
+ "Banjar(Arabic script) | bjn_Arab",
+ "Banjar(Latin script) | bjn_Latn",
+ "Bashkir | bak_Cyrl",
+ "Basque | eus_Latn",
+ "Belarusian | bel_Cyrl",
+ "Bemba | bem_Latn",
+ "Bengali | ben_Beng",
+ "Bhojpuri | bho_Deva",
+ "Bosnian | bos_Latn",
+ "Buginese | bug_Latn",
+ "Bulgarian | bul_Cyrl",
+ "Burmese | mya_Mymr",
+ "Catalan | cat_Latn",
+ "Cebuano | ceb_Latn",
+ "Central Atlas Tamazight | tzm_Tfng",
+ "Central Aymara | ayr_Latn",
+ "Central Kanuri(Arabic script) | knc_Arab",
+ "Central Kanuri(Latin script) | knc_Latn",
+ "Central Kurdish | ckb_Arab",
+ "Chhattisgarhi | hne_Deva",
+ "Chinese(Simplified) | zho_Hans",
+ "Chinese(Traditional) | zho_Hant",
+ "Chokwe | cjk_Latn",
+ "Crimean Tatar | crh_Latn",
+ "Croatian | hrv_Latn",
+ "Czech | ces_Latn",
+ "Danish | dan_Latn",
+ "Dari | prs_Arab",
+ "Dutch | nld_Latn",
+ "Dyula | dyu_Latn",
+ "Dzongkha | dzo_Tibt",
+ "Eastern Panjabi | pan_Guru",
+ "Eastern Yiddish | ydd_Hebr",
+ "Egyptian Arabic | arz_Arab",
+ "English | eng_Latn",
+ "Esperanto | epo_Latn",
+ "Estonian | est_Latn",
+ "Ewe | ewe_Latn",
+ "Faroese | fao_Latn",
+ "Fijian | fij_Latn",
+ "Finnish | fin_Latn",
+ "Fon | fon_Latn",
+ "French | fra_Latn",
+ "Friulian | fur_Latn",
+ "Galician | glg_Latn",
+ "Ganda | lug_Latn",
+ "Georgian | kat_Geor",
+ "German | deu_Latn",
+ "Greek | ell_Grek",
+ "Guarani | grn_Latn",
+ "Gujarati | guj_Gujr",
+ "Haitian Creole | hat_Latn",
+ "Halh Mongolian | khk_Cyrl",
+ "Hausa | hau_Latn",
+ "Hebrew | heb_Hebr",
+ "Hindi | hin_Deva",
+ "Hungarian | hun_Latn",
+ "Icelandic | isl_Latn",
+ "Igbo | ibo_Latn",
+ "Ilocano | ilo_Latn",
+ "Indonesian | ind_Latn",
+ "Irish | gle_Latn",
+ "Italian | ita_Latn",
+ "Japanese | jpn_Jpan",
+ "Javanese | jav_Latn",
+ "Jingpho | kac_Latn",
+ "Kabiyè | kbp_Latn",
+ "Kabuverdianu | kea_Latn",
+ "Kabyle | kab_Latn",
+ "Kamba | kam_Latn",
+ "Kannada | kan_Knda",
+ "Kashmiri(Arabic script) | kas_Arab",
+ "Kashmiri(Devanagari script) | kas_Deva",
+ "Kazakh | kaz_Cyrl",
+ "Khmer | khm_Khmr",
+ "Kikongo | kon_Latn",
+ "Kikuyu | kik_Latn",
+ "Kimbundu | kmb_Latn",
+ "Kinyarwanda | kin_Latn",
+ "Korean | kor_Hang",
+ "Kyrgyz | kir_Cyrl",
+ "Lao | lao_Laoo",
+ "Latgalian | ltg_Latn",
+ "Ligurian | lij_Latn",
+ "Limburgish | lim_Latn",
+ "Lingala | lin_Latn",
+ "Lithuanian | lit_Latn",
+ "Lombard | lmo_Latn",
+ "Luba - Kasai | lua_Latn",
+ "Luo | luo_Latn",
+ "Luxembourgish | ltz_Latn",
+ "Macedonian | mkd_Cyrl",
+ "Magahi | mag_Deva",
+ "Maithili | mai_Deva",
+ "Malayalam | mal_Mlym",
+ "Maltese | mlt_Latn",
+ "Maori | mri_Latn",
+ "Marathi | mar_Deva",
+ "Meitei(Bengali script) | mni_Beng",
+ "Mesopotamian Arabic | acm_Arab",
+ "Minangkabau(Arabic script) | min_Arab",
+ "Minangkabau(Latin script) | min_Latn",
+ "Mizo | lus_Latn",
+ "Modern Standard Arabic | arb_Arab",
+ "Modern Standard Arabic(Romanized) | arb_Latn",
+ "Moroccan Arabic | ary_Arab",
+ "Mossi | mos_Latn",
+ "Najdi Arabic | ars_Arab",
+ "Nepali | npi_Deva",
+ "Nigerian Fulfulde | fuv_Latn",
+ "North Azerbaijani | azj_Latn",
+ "North Levantine Arabic | apc_Arab",
+ "Northern Kurdish | kmr_Latn",
+ "Northern Sotho | nso_Latn",
+ "Northern Uzbek | uzn_Latn",
+ "Norwegian Bokmål | nob_Latn",
+ "Norwegian Nynorsk | nno_Latn",
+ "Nuer | nus_Latn",
+ "Nyanja | nya_Latn",
+ "Occitan | oci_Latn",
+ "Odia | ory_Orya",
+ "Pangasinan | pag_Latn",
+ "Papiamento | pap_Latn",
+ "Plateau Malagasy | plt_Latn",
+ "Polish | pol_Latn",
+ "Portuguese | por_Latn",
+ "Romanian | ron_Latn",
+ "Rundi | run_Latn",
+ "Russian | rus_Cyrl",
+ "Samoan | smo_Latn",
+ "Sango | sag_Latn",
+ "Sanskrit | san_Deva",
+ "Santali | sat_Olck",
+ "Sardinian | srd_Latn",
+ "Scottish Gaelic | gla_Latn",
+ "Serbian | srp_Cyrl",
+ "Shan | shn_Mymr",
+ "Shona | sna_Latn",
+ "Sicilian | scn_Latn",
+ "Silesian | szl_Latn",
+ "Sindhi | snd_Arab",
+ "Sinhala | sin_Sinh",
+ "Slovak | slk_Latn",
+ "Slovenian | slv_Latn",
+ "Somali | som_Latn",
+ "South Azerbaijani | azb_Arab",
+ "South Levantine Arabic | ajp_Arab",
+ "Southern Pashto | pbt_Arab",
+ "Southern Sotho | sot_Latn",
+ "Southwestern Dinka | dik_Latn",
+ "Spanish | spa_Latn",
+ "Standard Latvian | lvs_Latn",
+ "Standard Malay | zsm_Latn",
+ "Standard Tibetan | bod_Tibt",
+ "Sundanese | sun_Latn",
+ "Swahili | swh_Latn",
+ "Swati | ssw_Latn",
+ "Swedish | swe_Latn",
+ "Ta'izzi - Adeni Arabic | acq_Arab",
+ "Tagalog | tgl_Latn",
+ "Tajik | tgk_Cyrl",
+ "Tamasheq(Latin script) | taq_Latn",
+ "Tamasheq(Tifinagh script) | taq_Tfng",
+ "Tamil | tam_Taml",
+ "Tatar | tat_Cyrl",
+ "Telugu | tel_Telu",
+ "Thai | tha_Thai",
+ "Tigrinya | tir_Ethi",
+ "Tok Pisin | tpi_Latn",
+ "Tosk Albanian | als_Latn",
+ "Tsonga | tso_Latn",
+ "Tswana | tsn_Latn",
+ "Tumbuka | tum_Latn",
+ "Tunisian Arabic | aeb_Arab",
+ "Turkish | tur_Latn",
+ "Turkmen | tuk_Latn",
+ "Twi | twi_Latn",
+ "Ukrainian | ukr_Cyrl",
+ "Umbundu | umb_Latn",
+ "Urdu | urd_Arab",
+ "Uyghur | uig_Arab",
+ "Venetian | vec_Latn",
+ "Vietnamese | vie_Latn",
+ "Waray | war_Latn",
+ "Welsh | cym_Latn",
+ "West Central Oromo | gaz_Latn",
+ "Western Persian | pes_Arab",
+ "Wolof | wol_Latn",
+ "Xhosa | xho_Latn",
+ "Yoruba | yor_Latn",
+ "Yue Chinese | yue_Hant",
+ "Zulu | zul_Latn",
+ };
+ const size_t kNumLangTargetChoices = sizeof(kLangTargetChoices) / sizeof(kLangTargetChoices[0]); // Entry count, derived from the array so it cannot drift.
+ constexpr int kLangTargetDefault = 0; // Index of "Do not translate"; fallback selection for both translation dropdowns.
+
+
// lifted from whisper/__init__.py
const wxString kModelChoices[] = {
"tiny.en",
@@ -247,6 +460,13 @@ namespace { const size_t kNumModelChoices = sizeof(kModelChoices) / sizeof(kModelChoices[0]);
constexpr int kModelDefault = 2; // base.en
+ const wxString kModelTranslationChoices[] = { // NLLB checkpoints offered in the "Translation model" dropdown.
+ "nllb-200-distilled-600M",
+ "nllb-200-distilled-1.3B",
+ };
+ const size_t kNumModelTranslationChoices = sizeof(kModelTranslationChoices) / sizeof(kModelTranslationChoices[0]); // Entry count, derived from the array.
+ constexpr int kModelTranslationDefault = 0; // nllb-200-distilled-600M (matches AppConfig's default). Must be a valid index: it is used as a subscript into kModelTranslationChoices and as a wxChoice selection, so the previous value 2 was out of bounds.
+
const wxString kCharsPerSync[] = {
"5",
"6",
@@ -413,13 +633,28 @@ Frame::Frame() ID_PY_APP_LANG, wxDefaultPosition, wxDefaultSize,
kNumLangChoices, kLangChoices);
py_app_lang->SetToolTip("Select which language you will "
- "speak in. It will be transcribed into that language. "
- "If using a language with non-ASCII characters (i.e. "
- "not English), make sure you have 'bytes per char' "
- "set to 2. If using something other than English, "
+ "speak in. If using something other than English, "
"make sure you're not using a *.en model.");
py_app_lang_ = py_app_lang;
+ auto* py_app_translate_source = new wxChoice(py_app_config_panel_pairs,
+ ID_PY_APP_TRANSLATE_SOURCE, wxDefaultPosition, wxDefaultSize,
+ kNumLangTargetChoices, kLangTargetChoices);
+ py_app_translate_source->SetToolTip("Select which language to "
+ "translate from, in other words, the language you are transcribing into.");
+ py_app_translate_source_ = py_app_translate_source;
+
+ auto* py_app_translate_target = new wxChoice(py_app_config_panel_pairs,
+ ID_PY_APP_TRANSLATE_TARGET, wxDefaultPosition, wxDefaultSize,
+ kNumLangTargetChoices, kLangTargetChoices);
+ py_app_translate_target->SetToolTip("Select which "
+ "language to translate to. This is the language "
+ "that will appear in game. "
+ "If using a language with non-ASCII characters (i.e. "
+ "not English), make sure you have 'bytes per char' "
+ "set to 2.");
+ py_app_translate_target_ = py_app_translate_target;
+
auto* py_app_model = new wxChoice(
py_app_config_panel_pairs, ID_PY_APP_MODEL,
wxDefaultPosition, wxDefaultSize, kNumModelChoices,
@@ -432,6 +667,16 @@ Frame::Frame() "don't work for other languages.");
py_app_model_ = py_app_model;
+ auto* py_app_model_translation = new wxChoice(
+ py_app_config_panel_pairs, ID_PY_APP_MODEL_TRANSLATION,
+ wxDefaultPosition, wxDefaultSize, kNumModelTranslationChoices,
+ kModelTranslationChoices);
+ py_app_model_translation->SetToolTip("Select which "
+ "version of the translation model to use. 600M params "
+ "uses 4.1 GB of memory, while 1.3B uses ~7GB of "
+ "memory.");
+ py_app_model_translation_ = py_app_model_translation;
+
auto* py_app_chars_per_sync = new wxChoice(
py_app_config_panel_pairs, ID_PY_APP_CHARS_PER_SYNC,
wxDefaultPosition, wxDefaultSize, kNumCharsPerSync,
@@ -517,16 +762,31 @@ Frame::Frame() /*flags=*/wxEXPAND);
sizer->Add(new wxStaticText(py_app_config_panel_pairs,
- wxID_ANY, /*label=*/"Language:"));
+ wxID_ANY, /*label=*/"Spoken language:"));
sizer->Add(py_app_lang, /*proportion=*/0,
/*flags=*/wxEXPAND);
sizer->Add(new wxStaticText(py_app_config_panel_pairs,
- wxID_ANY, /*label=*/"Model:"));
+ wxID_ANY, /*label=*/"Translate from:"));
+ sizer->Add(py_app_translate_source, /*proportion=*/0,
+ /*flags=*/wxEXPAND);
+
+ sizer->Add(new wxStaticText(py_app_config_panel_pairs,
+ wxID_ANY, /*label=*/"Translate to:"));
+ sizer->Add(py_app_translate_target, /*proportion=*/0,
+ /*flags=*/wxEXPAND);
+
+ sizer->Add(new wxStaticText(py_app_config_panel_pairs,
+ wxID_ANY, /*label=*/"Transcription model:"));
sizer->Add(py_app_model, /*proportion=*/0,
/*flags=*/wxEXPAND);
sizer->Add(new wxStaticText(py_app_config_panel_pairs,
+ wxID_ANY, /*label=*/"Translation model:"));
+ sizer->Add(py_app_model_translation, /*proportion=*/0,
+ /*flags=*/wxEXPAND);
+
+ sizer->Add(new wxStaticText(py_app_config_panel_pairs,
wxID_ANY, /*label=*/"Characters per sync:"));
sizer->Add(py_app_chars_per_sync, /*proportion=*/0,
/*flags=*/wxEXPAND);
@@ -1084,11 +1344,26 @@ void Frame::ApplyConfigToInputFields() kNumLangChoices, app_c_->language, kLangDefault);
py_app_lang->SetSelection(lang_idx);
+ auto* py_app_translate_source = static_cast<wxChoice*>(FindWindowById(ID_PY_APP_TRANSLATE_SOURCE));
+ int translate_source_idx = GetDropdownChoiceIndex(kLangTargetChoices,
+ kNumLangTargetChoices, app_c_->language_source, kLangTargetDefault);
+ py_app_translate_source->SetSelection(translate_source_idx);
+
+ auto* py_app_translate_target = static_cast<wxChoice*>(FindWindowById(ID_PY_APP_TRANSLATE_TARGET));
+ int translate_target_idx = GetDropdownChoiceIndex(kLangTargetChoices,
+ kNumLangTargetChoices, app_c_->language_target, kLangTargetDefault);
+ py_app_translate_target->SetSelection(translate_target_idx);
+
auto* py_app_model = static_cast<wxChoice*>(FindWindowById(ID_PY_APP_MODEL));
int model_idx = GetDropdownChoiceIndex(kModelChoices,
kNumModelChoices, app_c_->model, kModelDefault);
py_app_model->SetSelection(model_idx);
+ auto* py_app_model_translation = static_cast<wxChoice*>(FindWindowById(ID_PY_APP_MODEL_TRANSLATION));
+ int model_translation_idx = GetDropdownChoiceIndex(kModelTranslationChoices,
+ kNumModelTranslationChoices, app_c_->model_translation, kModelTranslationDefault);
+ py_app_model_translation->SetSelection(model_translation_idx);
+
auto* py_app_chars_per_sync = static_cast<wxChoice*>(FindWindowById(ID_PY_APP_CHARS_PER_SYNC));
int chars_idx = GetDropdownChoiceIndex(kCharsPerSync,
kNumCharsPerSync, std::to_string(app_c_->chars_per_sync),
@@ -1693,10 +1968,22 @@ void Frame::OnAppStart(wxCommandEvent& event) { if (which_lang == wxNOT_FOUND) {
which_lang = kLangDefault;
}
+ int which_translate_source = py_app_translate_source_->GetSelection(); // Index into kLangTargetChoices.
+ if (which_translate_source == wxNOT_FOUND) {
+ which_translate_source = kLangTargetDefault; // Fall back to the default of the table this index addresses, not kLangChoices' default.
+ }
+ int which_translate_target = py_app_translate_target_->GetSelection(); // Index into kLangTargetChoices.
+ if (which_translate_target == wxNOT_FOUND) {
+ which_translate_target = kLangTargetDefault; // "Do not translate" when nothing is selected.
+ }
int which_model = py_app_model_->GetSelection();
if (which_model == wxNOT_FOUND) {
which_model = kModelDefault;
}
+ int which_model_translation = py_app_model_translation_->GetSelection();
+ if (which_model_translation == wxNOT_FOUND) {
+ which_model_translation = kModelTranslationDefault;
+ }
int chars_per_sync_idx = py_app_chars_per_sync_->GetSelection();
if (chars_per_sync_idx == wxNOT_FOUND) {
chars_per_sync_idx = kCharsDefault;
@@ -1774,7 +2061,10 @@ void Frame::OnAppStart(wxCommandEvent& event) { app_c_->microphone = kMicChoices[which_mic].ToStdString();
app_c_->language = kLangChoices[which_lang].ToStdString();
+ app_c_->language_source = kLangTargetChoices[which_translate_source].ToStdString();
+ app_c_->language_target = kLangTargetChoices[which_translate_target].ToStdString();
app_c_->model = kModelChoices[which_model].ToStdString();
+ app_c_->model_translation = kModelTranslationChoices[which_model_translation].ToStdString();
app_c_->chars_per_sync = chars_per_sync;
app_c_->bytes_per_char = bytes_per_char;
app_c_->button = kButton[button_idx].ToStdString();
diff --git a/GUI/GUI/GUI/Frame.h b/GUI/GUI/GUI/Frame.h index f5a8cd9..24a0594 100644 --- a/GUI/GUI/GUI/Frame.h +++ b/GUI/GUI/GUI/Frame.h @@ -62,7 +62,10 @@ private: wxChoice* py_app_mic_;
wxChoice* py_app_lang_;
+ wxChoice* py_app_translate_source_;
+ wxChoice* py_app_translate_target_;
wxChoice* py_app_model_;
+ wxChoice* py_app_model_translation_;
// TODO(yum) figure out how to deduplicate these objects
wxChoice* py_app_chars_per_sync_;
wxChoice* py_app_bytes_per_char_;
diff --git a/GUI/GUI/GUI/PythonWrapper.cpp b/GUI/GUI/GUI/PythonWrapper.cpp index 5df1dfa..b95d6b3 100644 --- a/GUI/GUI/GUI/PythonWrapper.cpp +++ b/GUI/GUI/GUI/PythonWrapper.cpp @@ -462,8 +462,11 @@ std::future<bool> PythonWrapper::StartApp( "-u", // Unbuffered output "Resources/Scripts/transcribe.py", "--mic", config.microphone, - "--lang", config.language, + "--language", config.language, + "--language_source", Quote(config.language_source), + "--language_target", Quote(config.language_target), "--model", config.model, + "--model_translation", config.model_translation, "--chars_per_sync", std::to_string(config.chars_per_sync), "--bytes_per_char", std::to_string(config.bytes_per_char), "--button", Quote(config.button), diff --git a/Scripts/requirements.txt b/Scripts/requirements.txt index c887808..5500a91 100644 --- a/Scripts/requirements.txt +++ b/Scripts/requirements.txt @@ -1,3 +1,4 @@ +ctranslate2 editdistance faster-whisper@https://github.com/guillaumekln/faster-whisper/archive/358d373691c95205021bd4bbf28cde7ce4d10030.tar.gz future==0.18.2 @@ -10,3 +11,4 @@ pyaudio python-osc playsound==1.2.2 pyyaml +transformers>=4.21.0 diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py index 7ba80dc..e113be1 100644 --- a/Scripts/transcribe.py +++ b/Scripts/transcribe.py @@ -8,6 +8,7 @@ from playsound import playsound import argparse import copy +import ctranslate2 import generate_utils import keybind_event_machine import keyboard @@ -22,6 +23,7 @@ import subprocess import sys import threading import time +import transformers import wave class AudioState: @@ -263,8 +265,19 @@ def transcribeAudio(audio_state, audio_state.text = string_matcher.matchStrings(audio_state.text, text, window_size = 25) + # Translate if requested. 
+ if audio_state.language_source and audio_state.language_target: + source = audio_state.tokenizer.convert_ids_to_tokens(audio_state.tokenizer.encode(copy.copy(audio_state.text))) + target_prefix = [audio_state.language_target] + results = audio_state.translator.translate_batch([source], target_prefix=[target_prefix]) + target = results[0].hypotheses[0][1:] + translated = audio_state.tokenizer.decode(audio_state.tokenizer.convert_tokens_to_ids(target)) + print(f"Translated text: {translated}") + else: + translated = copy.copy(audio_state.text) + # Apply filters to transcription - filtered_text = audio_state.text + filtered_text = translated if enable_uwu_filter: uwu_proc = subprocess.Popen(["Resources/Uwu/Uwwwu.exe", filtered_text], stdout=subprocess.PIPE, @@ -493,7 +506,10 @@ def readControllerInput(audio_state, enable_local_beep: bool, # whisper/__init__.py. Examples: tiny, base, small, medium. def transcribeLoop(mic: str, language: str, + language_source: str, + language_target: str, model: str, + model_translation: str, enable_local_beep: bool, use_cpu: bool, use_builtin: bool, @@ -510,6 +526,63 @@ def transcribeLoop(mic: str, audio_state.language = langcodes.find(language).language audio_state.MAX_LENGTH_S = window_duration_s + lang_bits = language_target.split(" | ") + if len(lang_bits) == 2: + lang_code = lang_bits[1] + audio_state.language_target = lang_code + else: + audio_state.language_target = None + lang_bits = language_source.split(" | ") + if len(lang_bits) == 2: + lang_code = lang_bits[1] + audio_state.language_source = lang_code + else: + audio_state.language_source = None + + if audio_state.language_source and audio_state.language_target: + print("Translation requested") + + print("Installing torch and sentencepiece in virtual environment. 
" + "Nothing will print " + "for several minutes while these download (~2.4 GB).") + pip_proc = subprocess.Popen( + "Resources/Python/python.exe -m pip install sentencepiece torch".split(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + pip_stdout, pip_stderr = pip_proc.communicate() + pip_stdout = pip_stdout.decode("utf-8") + pip_stderr = pip_stderr.decode("utf-8") + print(pip_stdout) + print(pip_stderr) + if pip_proc.returncode != 0: + print(f"Failed to set up for translation: `pip install torch` " + "exited with {pip_proc.returncode}") + + output_dir = "Resources/" + model_translation + # Provided by ctranslate2 Python package + cmd = "ct2-transformers-converter.exe --model facebook/" + model_translation + " --output_dir " + output_dir + + print(f"Fetching translation algorithm ({model_translation})") + if not os.path.exists(output_dir): + ct2_proc = subprocess.Popen( + cmd.split(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + ct2_stdout, ct2_stderr = ct2_proc.communicate() + ct2_stdout = ct2_stdout.decode("utf-8") + ct2_stderr = ct2_stderr.decode("utf-8") + print(ct2_stdout) + print(ct2_stderr) + if ct2_proc.returncode != 0: + print(f"Failed to get NLLB model: ct2 process exited with " + "{ct2_proc.returncode}") + print(f"Using model at {output_dir}") + + audio_state.translator = ctranslate2.Translator(output_dir) + audio_state.tokenizer = transformers.AutoTokenizer.from_pretrained( + "facebook/" + model_translation, + src_lang=audio_state.language_source) + print("Safe to start talking") abspath = os.path.abspath(__file__) @@ -588,9 +661,13 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--mic", type=str, help="Which mic to use. Options: index, focusrite. Default: index") parser.add_argument("--language", type=str, help="Which language to use. Ex: english, japanese, chinese, french, german.") - parser.add_argument("--model", type=str, help="Which AI model to use. 
\ - Options: tiny, tiny.en, base, base.en, small, small.en, \ - medium, medium.en, large-v1, large-v2") + parser.add_argument("--language_source", type=str, help="Which language to translate from. See kLangTargetChoices in Frame.cpp for valid choices") + parser.add_argument("--language_target", type=str, help="Which language to translate into. See kLangTargetChoices in Frame.cpp for valid choices") + parser.add_argument("--model", type=str, help="Which transcription model to use. " \ + "Options: tiny, tiny.en, base, base.en, small, small.en, " \ + "medium, medium.en, large-v1, large-v2") + parser.add_argument("--model_translation", type=str, help="Which translation model to use. " \ + "Options: nllb-200-distilled-600M, nllb-200-distilled-1.3B.") parser.add_argument("--bytes_per_char", type=str, help="The number of bytes to use to represent each character") parser.add_argument("--chars_per_sync", type=str, help="The number of characters to send on each sync event") parser.add_argument("--enable_local_beep", type=int, help="Whether to play a local auditory indicator when transcription starts/stops.") @@ -615,9 +692,16 @@ if __name__ == "__main__": if not args.language: args.language = "english" + if not args.language_source or not args.language_target: + print("--language_source and --language_target required", file=sys.stderr) + if not args.model: args.model = "base" + if not args.model_translation: + print("--model_translation required.", file=sys.stderr) + sys.exit(1) + if not args.bytes_per_char or not args.chars_per_sync: print("--bytes_per_char and --chars_per_sync required", file=sys.stderr) sys.exit(1) @@ -685,7 +769,10 @@ if __name__ == "__main__": transcribeLoop(args.mic, args.language, + args.language_source, + args.language_target, args.model, + args.model_translation, args.enable_local_beep, args.cpu, args.use_builtin, args.enable_uwu_filter, |
