diff options
| author | yum <yum.food.vr@gmail.com> | 2023-12-13 13:54:55 -0800 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-12-13 13:54:57 -0800 |
| commit | 921b92a69f36502dc5eefd14ba3487c1bb49bb9d (patch) | |
| tree | 06c2adf532d34440f3fa85fe62de519a357f8942 /Scripts | |
| parent | 859caec3d5c1b6aa9eee98571af3324b6ed1bd21 (diff) | |
Begin experimenting with flash-attention
Seems much faster than faster-whisper.
There are two issues:
* Requires NVIDIA 3000 series or higher.
* Incompatible with faster-whisper dependencies.
So it seems like we'll either need to toggle between two sets of
dependencies at runtime or have two environments.
Diffstat (limited to 'Scripts')
| -rw-r--r-- | Scripts/requirements.txt | 9 | ||||
| -rw-r--r-- | Scripts/transcribe_v2.py | 80 |
2 files changed, 31 insertions, 58 deletions
diff --git a/Scripts/requirements.txt b/Scripts/requirements.txt index 9224ba8..58db539 100644 --- a/Scripts/requirements.txt +++ b/Scripts/requirements.txt @@ -1,17 +1,20 @@ -ctranslate2 editdistance -faster-whisper@https://github.com/guillaumekln/faster-whisper/archive/78d57d73c5b4a76b32d1d5a415e4e7aea760295c.tar.gz +flash-attn==2.3.6 future==0.18.2 huggingface_hub==0.16.4 keyboard langcodes language-data openvr +onnxruntime +packaging pillow pyaudio pydub python-osc pyyaml sentence_splitter -transformers>=4.21.0 +transformers==4.35.2 +--extra-index-url https://download.pytorch.org/whl/cu121 +torch diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 491bc35..e1929c1 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -1,7 +1,6 @@ from browser_src import BrowserSource from datetime import datetime from emotes_v2 import EmotesState -from faster_whisper import WhisperModel from functools import partial from huggingface_hub import hf_hub_download from profanity_filter import ProfanityFilter @@ -11,7 +10,7 @@ from transcribe_pipeline import StreamingPlugin, TranscriptCommit import app_config import argparse -import ctranslate2 +#import ctranslate2 import editdistance import keybind_event_machine import keyboard @@ -404,38 +403,18 @@ class Whisper: collector: AudioCollector, cfg: typing.Dict): self.collector = collector - self.model = None self.cfg = cfg - abspath = os.path.abspath(__file__) - my_dir = os.path.dirname(abspath) - parent_dir = os.path.dirname(my_dir) + import torch + from transformers import pipeline - model_str = cfg["model"] - model_root = os.path.join(parent_dir, "Models", model_str) - print(f"Model {cfg['model']} will be saved to {model_root}", - file=sys.stderr) - - model_device = "cuda" - if cfg["use_cpu"]: - model_device = "cpu" - - download_it = os.path.exists(model_root) - if '/' in model_str: - hf_hub_download(repo_id=model_str, filename='model.bin', - local_dir=model_root) - hf_hub_download(repo_id=model_str, filename='vocabulary.json', - local_dir=model_root) - hf_hub_download(repo_id=model_str, filename='config.json', - local_dir=model_root) - if download_it: - model_str = model_root - self.model = WhisperModel(model_str, - device = model_device, - device_index = cfg["gpu_idx"], - compute_type = "float16", - download_root = model_root, - local_files_only = download_it) + self.pipe = pipeline( + "automatic-speech-recognition", + model="distil-whisper/distil-large-v2", + torch_dtype=torch.float16, + device="cuda", # TODO plumb + model_kwargs={"use_flash_attention_2": True}, # TODO only if cuda on + ) def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: @@ -446,31 +425,22 @@ class Whisper: dtype=np.int16).flatten().astype(np.float32) / 32768.0 t0 = time.time() - segments, info = self.model.transcribe( + res = self.pipe( audio, - language = langcodes.find(self.cfg["language"]).language, - vad_filter = True, - temperature=0.0, - without_timestamps = False) - res = [] - for s in segments: - # Manual touchup. I see a decent number of hallucinations sneaking - # in with high `no_speech_prob` and modest `avg_logprob`. - if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5: - continue - if cfg["enable_debug_mode"]: - print(f"s get: {s}") - if s.avg_logprob < -1.0: - continue - if s.compression_ratio > 2.4: - continue - res.append(Segment(s.text, s.start, s.end, - self.collector.begin(), - s.avg_logprob, s.no_speech_prob, s.compression_ratio)) + chunk_length_s=30, + batch_size=1) + + result = [Segment(res["text"], + 0, + 0, + self.collector.begin(), + 0, + 0, + 0)] + t1 = time.time() - if cfg["enable_debug_mode"]: - print(f"Transcription latency (s): {t1 - t0}") - return res + print(f"Transcription latency (s): {t1 - t0}") + return result def saveAudio(audio: bytes, path: str): with wave.open(path, 'wb') as wf: @@ -520,7 +490,7 @@ class VadCommitter: #saveAudio(commit_audio, filename) preview = "" - if self.cfg["enable_previews"] and has_audio: + if self.cfg["enable_previews"] and has_audio and not stable_cutoff: segments = self.whisper.transcribe(audio) preview = "".join(s.transcript for s in segments) |
