diff options
Diffstat (limited to 'Scripts/transcribe_v2.py')
| -rw-r--r-- | Scripts/transcribe_v2.py | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 1bdc487..2a206fd 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -583,7 +583,8 @@ class TranslationPlugin(StreamingPlugin): self.language_target = lang_bits[1] print("Translation requested", file=sys.stderr) - if not install_in_venv(["torch", "sentencepiece"]): + # The ctranslate2 model converter needs torch. Grr. + if not install_in_venv(["torch==2.2.2"]): return output_dir = "Resources/" + cfg["model_translation"] @@ -608,7 +609,18 @@ class TranslationPlugin(StreamingPlugin): "{ct2_proc.returncode}", file=sys.stderr) print(f"Using model at {output_dir}", file=sys.stderr) - self.translator = ctranslate2.Translator(output_dir) + model_device = "cuda" + if cfg["use_cpu"]: + model_device = "cpu" + if cfg["use_flash_attention"]: + print(f"Flash attention disabled on CPU", file=sys.stderr) + cfg["use_flash_attention"] = False + + self.translator = ctranslate2.Translator(output_dir, + device = model_device, + device_index = cfg["gpu_idx"], + compute_type = cfg["compute_type"], + flash_attention = cfg["use_flash_attention"]) whisper_lang = cfg["language"] nllb_lang = lang_compat.whisper_to_nllb[whisper_lang] |
