summaryrefslogtreecommitdiffstats
path: root/Scripts/transcribe_v2.py
diff options
context:
space:
mode:
author yum <yum.food.vr@gmail.com> 2024-06-09 16:43:34 -0700
committer yum <yum.food.vr@gmail.com> 2024-06-09 16:43:34 -0700
commit 4fec36c3cc00bd649dfb3c9d7e9079b5c8685a0e (patch)
tree 48c765fb558f10dcfe7cec5ec1511020ed41e8f0 /Scripts/transcribe_v2.py
parent f2b21dd5afebd6b76b5835168f7d1bd3bec21f5d (diff)
Bump CUDNN to v8.9.7 (tag: v0.19.1)
Also disable flash-attention when CPU mode is selected
Diffstat (limited to 'Scripts/transcribe_v2.py')
-rw-r--r-- Scripts/transcribe_v2.py 5
1 files changed, 3 insertions, 2 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 2f37945..1bdc487 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -423,11 +423,12 @@ class Whisper:
model_device = "cuda"
if cfg["use_cpu"]:
model_device = "cpu"
+ if cfg["use_flash_attention"]:
+ print(f"Flash attention disabled on CPU", file=sys.stderr)
+ cfg["use_flash_attention"] = False
already_downloaded = os.path.exists(model_root)
- print(f"Use flash attention {cfg['use_flash_attention']}")
-
self.model = WhisperModel(model_str,
device = model_device,
device_index = cfg["gpu_idx"],