summaryrefslogtreecommitdiffstats
path: root/Scripts
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2024-01-08 18:59:27 -0800
committeryum <yum.food.vr@gmail.com>2024-01-08 18:59:27 -0800
commit33db3dcc23a45cae611bcf839c33d6615ccbf59e (patch)
tree7c4c9bc5b8c1ab83a4d1f7f009a9065f545f2074 /Scripts
parentf7e1cf963efc6e4e564b41445cfd328c3baa0f0a (diff)
Revert "Begin experimenting with flash-attention"
This reverts commit 921b92a69f36502dc5eefd14ba3487c1bb49bb9d.
Diffstat (limited to 'Scripts')
-rw-r--r--Scripts/requirements.txt9
-rw-r--r--Scripts/transcribe_v2.py80
2 files changed, 58 insertions, 31 deletions
diff --git a/Scripts/requirements.txt b/Scripts/requirements.txt
index 58db539..9224ba8 100644
--- a/Scripts/requirements.txt
+++ b/Scripts/requirements.txt
@@ -1,20 +1,17 @@
+ctranslate2
editdistance
-flash-attn==2.3.6
+faster-whisper@https://github.com/guillaumekln/faster-whisper/archive/78d57d73c5b4a76b32d1d5a415e4e7aea760295c.tar.gz
future==0.18.2
huggingface_hub==0.16.4
keyboard
langcodes
language-data
openvr
-onnxruntime
-packaging
pillow
pyaudio
pydub
python-osc
pyyaml
sentence_splitter
-transformers==4.35.2
---extra-index-url https://download.pytorch.org/whl/cu121
-torch
+transformers>=4.21.0
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index e1929c1..491bc35 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -1,6 +1,7 @@
from browser_src import BrowserSource
from datetime import datetime
from emotes_v2 import EmotesState
+from faster_whisper import WhisperModel
from functools import partial
from huggingface_hub import hf_hub_download
from profanity_filter import ProfanityFilter
@@ -10,7 +11,7 @@ from transcribe_pipeline import StreamingPlugin, TranscriptCommit
import app_config
import argparse
-#import ctranslate2
+import ctranslate2
import editdistance
import keybind_event_machine
import keyboard
@@ -403,18 +404,38 @@ class Whisper:
collector: AudioCollector,
cfg: typing.Dict):
self.collector = collector
+ self.model = None
self.cfg = cfg
- import torch
- from transformers import pipeline
+ abspath = os.path.abspath(__file__)
+ my_dir = os.path.dirname(abspath)
+ parent_dir = os.path.dirname(my_dir)
- self.pipe = pipeline(
- "automatic-speech-recognition",
- model="distil-whisper/distil-large-v2",
- torch_dtype=torch.float16,
- device="cuda", # TODO plumb
- model_kwargs={"use_flash_attention_2": True}, # TODO only if cuda on
- )
+ model_str = cfg["model"]
+ model_root = os.path.join(parent_dir, "Models", model_str)
+ print(f"Model {cfg['model']} will be saved to {model_root}",
+ file=sys.stderr)
+
+ model_device = "cuda"
+ if cfg["use_cpu"]:
+ model_device = "cpu"
+
+ download_it = os.path.exists(model_root)
+ if '/' in model_str:
+ hf_hub_download(repo_id=model_str, filename='model.bin',
+ local_dir=model_root)
+ hf_hub_download(repo_id=model_str, filename='vocabulary.json',
+ local_dir=model_root)
+ hf_hub_download(repo_id=model_str, filename='config.json',
+ local_dir=model_root)
+ if download_it:
+ model_str = model_root
+ self.model = WhisperModel(model_str,
+ device = model_device,
+ device_index = cfg["gpu_idx"],
+ compute_type = "float16",
+ download_root = model_root,
+ local_files_only = download_it)
def transcribe(self, frames: bytes = None) -> typing.List[Segment]:
if frames is None:
@@ -425,22 +446,31 @@ class Whisper:
dtype=np.int16).flatten().astype(np.float32) / 32768.0
t0 = time.time()
- res = self.pipe(
+ segments, info = self.model.transcribe(
audio,
- chunk_length_s=30,
- batch_size=1)
-
- result = [Segment(res["text"],
- 0,
- 0,
- self.collector.begin(),
- 0,
- 0,
- 0)]
-
+ language = langcodes.find(self.cfg["language"]).language,
+ vad_filter = True,
+ temperature=0.0,
+ without_timestamps = False)
+ res = []
+ for s in segments:
+ # Manual touchup. I see a decent number of hallucinations sneaking
+ # in with high `no_speech_prob` and modest `avg_logprob`.
+ if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5:
+ continue
+ if cfg["enable_debug_mode"]:
+ print(f"s get: {s}")
+ if s.avg_logprob < -1.0:
+ continue
+ if s.compression_ratio > 2.4:
+ continue
+ res.append(Segment(s.text, s.start, s.end,
+ self.collector.begin(),
+ s.avg_logprob, s.no_speech_prob, s.compression_ratio))
t1 = time.time()
- print(f"Transcription latency (s): {t1 - t0}")
- return result
+ if cfg["enable_debug_mode"]:
+ print(f"Transcription latency (s): {t1 - t0}")
+ return res
def saveAudio(audio: bytes, path: str):
with wave.open(path, 'wb') as wf:
@@ -490,7 +520,7 @@ class VadCommitter:
#saveAudio(commit_audio, filename)
preview = ""
- if self.cfg["enable_previews"] and has_audio and not stable_cutoff:
+ if self.cfg["enable_previews"] and has_audio:
segments = self.whisper.transcribe(audio)
preview = "".join(s.transcript for s in segments)