diff options
Diffstat (limited to 'Scripts/transcribe_v2.py')
| -rw-r--r-- | Scripts/transcribe_v2.py | 80 |
1 files changed, 25 insertions, 55 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 491bc35..e1929c1 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -1,7 +1,6 @@ from browser_src import BrowserSource from datetime import datetime from emotes_v2 import EmotesState -from faster_whisper import WhisperModel from functools import partial from huggingface_hub import hf_hub_download from profanity_filter import ProfanityFilter @@ -11,7 +10,7 @@ from transcribe_pipeline import StreamingPlugin, TranscriptCommit import app_config import argparse -import ctranslate2 +#import ctranslate2 import editdistance import keybind_event_machine import keyboard @@ -404,38 +403,18 @@ class Whisper: collector: AudioCollector, cfg: typing.Dict): self.collector = collector - self.model = None self.cfg = cfg - abspath = os.path.abspath(__file__) - my_dir = os.path.dirname(abspath) - parent_dir = os.path.dirname(my_dir) + import torch + from transformers import pipeline - model_str = cfg["model"] - model_root = os.path.join(parent_dir, "Models", model_str) - print(f"Model {cfg['model']} will be saved to {model_root}", - file=sys.stderr) - - model_device = "cuda" - if cfg["use_cpu"]: - model_device = "cpu" - - download_it = os.path.exists(model_root) - if '/' in model_str: - hf_hub_download(repo_id=model_str, filename='model.bin', - local_dir=model_root) - hf_hub_download(repo_id=model_str, filename='vocabulary.json', - local_dir=model_root) - hf_hub_download(repo_id=model_str, filename='config.json', - local_dir=model_root) - if download_it: - model_str = model_root - self.model = WhisperModel(model_str, - device = model_device, - device_index = cfg["gpu_idx"], - compute_type = "float16", - download_root = model_root, - local_files_only = download_it) + self.pipe = pipeline( + "automatic-speech-recognition", + model="distil-whisper/distil-large-v2", + torch_dtype=torch.float16, + device="cuda", # TODO plumb + model_kwargs={"use_flash_attention_2": True}, # TODO only if cuda on + ) def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: @@ -446,31 +425,22 @@ class Whisper: dtype=np.int16).flatten().astype(np.float32) / 32768.0 t0 = time.time() - segments, info = self.model.transcribe( + res = self.pipe( audio, - language = langcodes.find(self.cfg["language"]).language, - vad_filter = True, - temperature=0.0, - without_timestamps = False) - res = [] - for s in segments: - # Manual touchup. I see a decent number of hallucinations sneaking - # in with high `no_speech_prob` and modest `avg_logprob`. - if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5: - continue - if cfg["enable_debug_mode"]: - print(f"s get: {s}") - if s.avg_logprob < -1.0: - continue - if s.compression_ratio > 2.4: - continue - res.append(Segment(s.text, s.start, s.end, - self.collector.begin(), - s.avg_logprob, s.no_speech_prob, s.compression_ratio)) + chunk_length_s=30, + batch_size=1) + + result = [Segment(res["text"], + 0, + 0, + self.collector.begin(), + 0, + 0, + 0)] + t1 = time.time() - if cfg["enable_debug_mode"]: - print(f"Transcription latency (s): {t1 - t0}") - return res + print(f"Transcription latency (s): {t1 - t0}") + return result def saveAudio(audio: bytes, path: str): with wave.open(path, 'wb') as wf: @@ -520,7 +490,7 @@ class VadCommitter: #saveAudio(commit_audio, filename) preview = "" - if self.cfg["enable_previews"] and has_audio: + if self.cfg["enable_previews"] and has_audio and not stable_cutoff: segments = self.whisper.transcribe(audio) preview = "".join(s.transcript for s in segments) |
