summary | refs | log | tree | commit | diff | stats
path: root/Scripts/transcribe_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'Scripts/transcribe_v2.py')
-rw-r--r-- Scripts/transcribe_v2.py 16
1 files changed, 14 insertions, 2 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 1bdc487..2a206fd 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -583,7 +583,8 @@ class TranslationPlugin(StreamingPlugin):
self.language_target = lang_bits[1]
print("Translation requested", file=sys.stderr)
- if not install_in_venv(["torch", "sentencepiece"]):
+ # The ctranslate2 model converter needs torch. Grr.
+ if not install_in_venv(["torch==2.2.2"]):
return
output_dir = "Resources/" + cfg["model_translation"]
@@ -608,7 +609,18 @@ class TranslationPlugin(StreamingPlugin):
"{ct2_proc.returncode}", file=sys.stderr)
print(f"Using model at {output_dir}", file=sys.stderr)
- self.translator = ctranslate2.Translator(output_dir)
+ model_device = "cuda"
+ if cfg["use_cpu"]:
+ model_device = "cpu"
+ if cfg["use_flash_attention"]:
+ print(f"Flash attention disabled on CPU", file=sys.stderr)
+ cfg["use_flash_attention"] = False
+
+ self.translator = ctranslate2.Translator(output_dir,
+ device = model_device,
+ device_index = cfg["gpu_idx"],
+ compute_type = cfg["compute_type"],
+ flash_attention = cfg["use_flash_attention"])
whisper_lang = cfg["language"]
nllb_lang = lang_compat.whisper_to_nllb[whisper_lang]