From 75069522ffc8863a356d95e509c81612a3703458 Mon Sep 17 00:00:00 2001 From: yum Date: Fri, 12 Jul 2024 14:48:36 -0700 Subject: Fix translation plugin Translation needs torch to convert the nllb model, but the latest version (2.3.1) has an embedded OMP dll which clashes with ctranslate2's dll. Using the last minor version instead (2.2.2) doesn't clash. Also propagate the device, quantization, and flash attention settings to the translator. If you're using GPU, this is a HUUUUGE performance uplift. Translation is basically instant. The bigger models are now feasible to use. --- Scripts/transcribe_v2.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 1bdc487..2a206fd 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -583,7 +583,8 @@ class TranslationPlugin(StreamingPlugin): self.language_target = lang_bits[1] print("Translation requested", file=sys.stderr) - if not install_in_venv(["torch", "sentencepiece"]): + # The ctranslate2 model converter needs torch. Grr. + if not install_in_venv(["torch==2.2.2"]): return output_dir = "Resources/" + cfg["model_translation"] @@ -608,7 +609,18 @@ class TranslationPlugin(StreamingPlugin): "{ct2_proc.returncode}", file=sys.stderr) print(f"Using model at {output_dir}", file=sys.stderr) - self.translator = ctranslate2.Translator(output_dir) + model_device = "cuda" + if cfg["use_cpu"]: + model_device = "cpu" + if cfg["use_flash_attention"]: + print(f"Flash attention disabled on CPU", file=sys.stderr) + cfg["use_flash_attention"] = False + + self.translator = ctranslate2.Translator(output_dir, + device = model_device, + device_index = cfg["gpu_idx"], + compute_type = cfg["compute_type"], + flash_attention = cfg["use_flash_attention"]) whisper_lang = cfg["language"] nllb_lang = lang_compat.whisper_to_nllb[whisper_lang] -- cgit v1.2.3