diff options
| author | yum <yum.food.vr@gmail.com> | 2024-07-12 14:48:36 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2024-07-12 14:48:36 -0700 |
| commit | 75069522ffc8863a356d95e509c81612a3703458 (patch) | |
| tree | 8abebad0dc35facf668a5a1777b7a4bcd1b5085c | |
| parent | fef509ce622641079c08c5c9e602315e8d207868 (diff) | |
Fix translation plugin
Translation needs torch to convert the nllb model, but the latest
version (2.3.1) has an embedded OMP dll which clashes with ctranslate2's
dll. Using the last minor version instead (2.2.2) doesn't clash.
Also propagate the device, quantization, and flash attention settings
to the translator. If you're using GPU, this is a HUUUUGE performance
uplift. Translation is basically instant. The bigger models are now
feasible to use.
| -rw-r--r-- | Scripts/transcribe_v2.py | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 1bdc487..2a206fd 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -583,7 +583,8 @@ class TranslationPlugin(StreamingPlugin): self.language_target = lang_bits[1] print("Translation requested", file=sys.stderr) - if not install_in_venv(["torch", "sentencepiece"]): + # The ctranslate2 model converter needs torch. Grr. + if not install_in_venv(["torch==2.2.2"]): return output_dir = "Resources/" + cfg["model_translation"] @@ -608,7 +609,18 @@ class TranslationPlugin(StreamingPlugin): "{ct2_proc.returncode}", file=sys.stderr) print(f"Using model at {output_dir}", file=sys.stderr) - self.translator = ctranslate2.Translator(output_dir) + model_device = "cuda" + if cfg["use_cpu"]: + model_device = "cpu" + if cfg["use_flash_attention"]: + print(f"Flash attention disabled on CPU", file=sys.stderr) + cfg["use_flash_attention"] = False + + self.translator = ctranslate2.Translator(output_dir, + device = model_device, + device_index = cfg["gpu_idx"], + compute_type = cfg["compute_type"], + flash_attention = cfg["use_flash_attention"]) whisper_lang = cfg["language"] nllb_lang = lang_compat.whisper_to_nllb[whisper_lang] |
