summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- GUI/GUI/GUI/Frame.cpp | 1
-rw-r--r-- Scripts/transcribe_v2.py | 20
2 files changed, 18 insertions, 3 deletions
diff --git a/GUI/GUI/GUI/Frame.cpp b/GUI/GUI/GUI/Frame.cpp
index 602bf6d..0cdd99e 100644
--- a/GUI/GUI/GUI/Frame.cpp
+++ b/GUI/GUI/GUI/Frame.cpp
@@ -458,6 +458,7 @@ namespace {
"base",
"small.en",
"small",
+ "yumfood/whisper_distil_medium_en_ct2",
"medium.en",
"medium",
"large-v1",
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 889e1cf..4ae17d1 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -3,6 +3,7 @@ from datetime import datetime
from emotes_v2 import EmotesState
from faster_whisper import WhisperModel
from functools import partial
+from huggingface_hub import hf_hub_download
from profanity_filter import ProfanityFilter
from pydub import AudioSegment
from sentence_splitter import split_text_into_sentences
@@ -410,7 +411,8 @@ class Whisper:
my_dir = os.path.dirname(abspath)
parent_dir = os.path.dirname(my_dir)
- model_root = os.path.join(parent_dir, "Models", cfg["model"])
+ model_str = cfg["model"]
+ model_root = os.path.join(parent_dir, "Models", model_str)
print(f"Model {cfg['model']} will be saved to {model_root}",
file=sys.stderr)
@@ -419,13 +421,21 @@ class Whisper:
model_device = "cpu"
download_it = os.path.exists(model_root)
- model_str = cfg["model"]
+ model_str = model_str
if download_it:
+ if '/' in model_str:
+ hf_hub_download(repo_id=model_str, filename='model.bin',
+ local_dir=model_root)
+ hf_hub_download(repo_id=model_str, filename='vocabulary.json',
+ local_dir=model_root)
+ hf_hub_download(repo_id=model_str, filename='config.json',
+ local_dir=model_root)
+
model_str = model_root
self.model = WhisperModel(model_str,
device = model_device,
device_index = cfg["gpu_idx"],
- compute_type = "int8",
+ compute_type = "float16",
download_root = model_root,
local_files_only = download_it)
@@ -437,6 +447,7 @@ class Whisper:
audio = np.frombuffer(frames,
dtype=np.int16).flatten().astype(np.float32) / 32768.0
+ t0 = time.time()
segments, info = self.model.transcribe(
audio,
language = langcodes.find(self.cfg["language"]).language,
@@ -458,6 +469,9 @@ class Whisper:
res.append(Segment(s.text, s.start, s.end,
self.collector.begin(),
s.avg_logprob, s.no_speech_prob, s.compression_ratio))
+ t1 = time.time()
+ if cfg["enable_debug_mode"]:
+ print(f"Transcription latency (s): {t1 - t0}")
return res
def saveAudio(audio: bytes, path: str):