summaryrefslogtreecommitdiffstats
path: root/Scripts
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2024-07-12 14:48:36 -0700
committeryum <yum.food.vr@gmail.com>2024-07-12 14:48:36 -0700
commit75069522ffc8863a356d95e509c81612a3703458 (patch)
tree8abebad0dc35facf668a5a1777b7a4bcd1b5085c /Scripts
parentfef509ce622641079c08c5c9e602315e8d207868 (diff)
Fix translation plugin
Translation needs torch to convert the nllb model, but the latest version (2.3.1) has an embedded OMP dll which clashes with ctranslate2's dll. Using the last minor version instead (2.2.2) doesn't clash. Also propagate the device, quantization, and flash attention settings to the translator. If you're using GPU, this is a HUUUUGE performance uplift. Translation is basically instant. The bigger models are now feasible to use.
Diffstat (limited to 'Scripts')
-rw-r--r--Scripts/transcribe_v2.py16
1 files changed, 14 insertions, 2 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 1bdc487..2a206fd 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -583,7 +583,8 @@ class TranslationPlugin(StreamingPlugin):
self.language_target = lang_bits[1]
print("Translation requested", file=sys.stderr)
- if not install_in_venv(["torch", "sentencepiece"]):
+ # The ctranslate2 model converter needs torch. Grr.
+ if not install_in_venv(["torch==2.2.2"]):
return
output_dir = "Resources/" + cfg["model_translation"]
@@ -608,7 +609,18 @@ class TranslationPlugin(StreamingPlugin):
"{ct2_proc.returncode}", file=sys.stderr)
print(f"Using model at {output_dir}", file=sys.stderr)
- self.translator = ctranslate2.Translator(output_dir)
+ model_device = "cuda"
+ if cfg["use_cpu"]:
+ model_device = "cpu"
+ if cfg["use_flash_attention"]:
+ print(f"Flash attention disabled on CPU", file=sys.stderr)
+ cfg["use_flash_attention"] = False
+
+ self.translator = ctranslate2.Translator(output_dir,
+ device = model_device,
+ device_index = cfg["gpu_idx"],
+ compute_type = cfg["compute_type"],
+ flash_attention = cfg["use_flash_attention"])
whisper_lang = cfg["language"]
nllb_lang = lang_compat.whisper_to_nllb[whisper_lang]