diff options
Diffstat (limited to 'Scripts/transcribe_v2.py')
| -rw-r--r-- | Scripts/transcribe_v2.py | 80 |
1 files changed, 55 insertions, 25 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index e1929c1..491bc35 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -1,6 +1,7 @@ from browser_src import BrowserSource from datetime import datetime from emotes_v2 import EmotesState +from faster_whisper import WhisperModel from functools import partial from huggingface_hub import hf_hub_download from profanity_filter import ProfanityFilter @@ -10,7 +11,7 @@ from transcribe_pipeline import StreamingPlugin, TranscriptCommit import app_config import argparse -#import ctranslate2 +import ctranslate2 import editdistance import keybind_event_machine import keyboard @@ -403,18 +404,38 @@ class Whisper: collector: AudioCollector, cfg: typing.Dict): self.collector = collector + self.model = None self.cfg = cfg - import torch - from transformers import pipeline + abspath = os.path.abspath(__file__) + my_dir = os.path.dirname(abspath) + parent_dir = os.path.dirname(my_dir) - self.pipe = pipeline( - "automatic-speech-recognition", - model="distil-whisper/distil-large-v2", - torch_dtype=torch.float16, - device="cuda", # TODO plumb - model_kwargs={"use_flash_attention_2": True}, # TODO only if cuda on - ) + model_str = cfg["model"] + model_root = os.path.join(parent_dir, "Models", model_str) + print(f"Model {cfg['model']} will be saved to {model_root}", + file=sys.stderr) + + model_device = "cuda" + if cfg["use_cpu"]: + model_device = "cpu" + + download_it = os.path.exists(model_root) + if '/' in model_str: + hf_hub_download(repo_id=model_str, filename='model.bin', + local_dir=model_root) + hf_hub_download(repo_id=model_str, filename='vocabulary.json', + local_dir=model_root) + hf_hub_download(repo_id=model_str, filename='config.json', + local_dir=model_root) + if download_it: + model_str = model_root + self.model = WhisperModel(model_str, + device = model_device, + device_index = cfg["gpu_idx"], + compute_type = "float16", + download_root = model_root, + local_files_only = download_it) def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: @@ -425,22 +446,31 @@ class Whisper: dtype=np.int16).flatten().astype(np.float32) / 32768.0 t0 = time.time() - res = self.pipe( + segments, info = self.model.transcribe( audio, - chunk_length_s=30, - batch_size=1) - - result = [Segment(res["text"], - 0, - 0, - self.collector.begin(), - 0, - 0, - 0)] - + language = langcodes.find(self.cfg["language"]).language, + vad_filter = True, + temperature=0.0, + without_timestamps = False) + res = [] + for s in segments: + # Manual touchup. I see a decent number of hallucinations sneaking + # in with high `no_speech_prob` and modest `avg_logprob`. + if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5: + continue + if cfg["enable_debug_mode"]: + print(f"s get: {s}") + if s.avg_logprob < -1.0: + continue + if s.compression_ratio > 2.4: + continue + res.append(Segment(s.text, s.start, s.end, + self.collector.begin(), + s.avg_logprob, s.no_speech_prob, s.compression_ratio)) t1 = time.time() - print(f"Transcription latency (s): {t1 - t0}") - return result + if cfg["enable_debug_mode"]: + print(f"Transcription latency (s): {t1 - t0}") + return res def saveAudio(audio: bytes, path: str): with wave.open(path, 'wb') as wf: @@ -490,7 +520,7 @@ class VadCommitter: #saveAudio(commit_audio, filename) preview = "" - if self.cfg["enable_previews"] and has_audio and not stable_cutoff: + if self.cfg["enable_previews"] and has_audio: segments = self.whisper.transcribe(audio) preview = "".join(s.transcript for s in segments) |
