diff options
| author | yum <yum.food.vr@gmail.com> | 2024-11-16 15:27:25 -0800 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2024-11-16 15:52:24 -0800 |
| commit | 953c03d21d6ef75d115a25ff83e2e3a706b685a6 (patch) | |
| tree | 0fc15d8be147a021b6e9aaaf50749a7aa9f191df /Scripts | |
| parent | 673d701ea471daebecb1fb0c1edf79a2017a78ac (diff) | |
Remove flash_attention toggle
Deprecated in the Python release of CTranslate2 as of 4.4.0:
https://github.com/OpenNMT/CTranslate2/blob/master/CHANGELOG.md#v440-2024-09-09
Diffstat (limited to 'Scripts')
| -rw-r--r-- | Scripts/transcribe_v2.py | 12 |
1 files changed, 2 insertions, 10 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 63d7f52..e024bae 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -423,9 +423,6 @@ class Whisper: model_device = "cuda" if cfg["use_cpu"]: model_device = "cpu" - if cfg["use_flash_attention"]: - print(f"Flash attention disabled on CPU", file=sys.stderr) - cfg["use_flash_attention"] = False already_downloaded = os.path.exists(model_root) @@ -434,8 +431,7 @@ class Whisper: device_index = cfg["gpu_idx"], compute_type = cfg["compute_type"], download_root = model_root, - local_files_only = already_downloaded, - flash_attention = cfg["use_flash_attention"]) + local_files_only = already_downloaded) def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: @@ -612,15 +608,11 @@ class TranslationPlugin(StreamingPlugin): model_device = "cuda" if cfg["use_cpu"]: model_device = "cpu" - if cfg["use_flash_attention"]: - print(f"Flash attention disabled on CPU", file=sys.stderr) - cfg["use_flash_attention"] = False self.translator = ctranslate2.Translator(output_dir, device = model_device, device_index = cfg["gpu_idx"], - compute_type = cfg["compute_type"], - flash_attention = cfg["use_flash_attention"]) + compute_type = cfg["compute_type"]) whisper_lang = cfg["language"] nllb_lang = lang_compat.whisper_to_nllb[whisper_lang] |
