diff options
| author | yum <yum.food.vr@gmail.com> | 2024-06-09 15:54:30 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2024-06-09 15:54:30 -0700 |
| commit | f2b21dd5afebd6b76b5835168f7d1bd3bec21f5d (patch) | |
| tree | 9a72c443088efdd7fdac31ab3ea69b756e75f7df /Scripts | |
| parent | 72b9fb8337cfb7bddc58f74b8977e4a2283e6728 (diff) | |
Add checkbox for flash-attention
Pre-3000 series GPUs don't support it. Oops!
Diffstat (limited to 'Scripts')
| -rw-r--r-- | Scripts/transcribe_v2.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 32deb42..2f37945 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -426,13 +426,15 @@ class Whisper: already_downloaded = os.path.exists(model_root) + print(f"Use flash attention {cfg['use_flash_attention']}") + self.model = WhisperModel(model_str, device = model_device, device_index = cfg["gpu_idx"], compute_type = cfg["compute_type"], download_root = model_root, local_files_only = already_downloaded, - flash_attention = True) + flash_attention = cfg["use_flash_attention"]) def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: |
