diff options
Diffstat (limited to 'Scripts/transcribe_v2.py')
| -rw-r--r-- | Scripts/transcribe_v2.py | 18 |
1 file changed, 6 insertions, 12 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 32652df..32deb42 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -13,6 +13,7 @@ import app_config
 import argparse
 import ctranslate2
 import editdistance
+import glob
 import keybind_event_machine
 import keyboard
 import langcodes
@@ -414,7 +415,8 @@ class Whisper:
         parent_dir = os.path.dirname(my_dir)
 
         model_str = cfg["model"]
-        model_root = os.path.join(parent_dir, "Models", model_str)
+        model_root = os.path.join(parent_dir, "Models",
+                os.path.normpath(model_str))
         print(f"Model {cfg['model']} will be saved to {model_root}",
                 file=sys.stderr)
 
@@ -423,22 +425,14 @@ class Whisper:
             model_device = "cpu"
 
         already_downloaded = os.path.exists(model_root)
-        if '/' in model_str:
-            hf_hub_download(repo_id=model_str, filename='model.bin',
-                    local_dir=model_root)
-            hf_hub_download(repo_id=model_str, filename='vocabulary.json',
-                    local_dir=model_root)
-            hf_hub_download(repo_id=model_str, filename='config.json',
-                    local_dir=model_root)
-            already_downloaded = True
-        if already_downloaded:
-            model_str = model_root
+
         self.model = WhisperModel(model_str,
                 device = model_device,
                 device_index = cfg["gpu_idx"],
                 compute_type = cfg["compute_type"],
                 download_root = model_root,
-                local_files_only = already_downloaded)
+                local_files_only = already_downloaded,
+                flash_attention = True)
 
     def transcribe(self, frames: bytes = None) -> typing.List[Segment]:
         if frames is None:
