summary | refs | log | tree | commit | diff | stats
path: root/Scripts/transcribe_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'Scripts/transcribe_v2.py')
-rw-r--r-- Scripts/transcribe_v2.py 80
1 files changed, 25 insertions, 55 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 491bc35..e1929c1 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -1,7 +1,6 @@
from browser_src import BrowserSource
from datetime import datetime
from emotes_v2 import EmotesState
-from faster_whisper import WhisperModel
from functools import partial
from huggingface_hub import hf_hub_download
from profanity_filter import ProfanityFilter
@@ -11,7 +10,7 @@ from transcribe_pipeline import StreamingPlugin, TranscriptCommit
import app_config
import argparse
-import ctranslate2
+#import ctranslate2
import editdistance
import keybind_event_machine
import keyboard
@@ -404,38 +403,18 @@ class Whisper:
collector: AudioCollector,
cfg: typing.Dict):
self.collector = collector
- self.model = None
self.cfg = cfg
- abspath = os.path.abspath(__file__)
- my_dir = os.path.dirname(abspath)
- parent_dir = os.path.dirname(my_dir)
+ import torch
+ from transformers import pipeline
- model_str = cfg["model"]
- model_root = os.path.join(parent_dir, "Models", model_str)
- print(f"Model {cfg['model']} will be saved to {model_root}",
- file=sys.stderr)
-
- model_device = "cuda"
- if cfg["use_cpu"]:
- model_device = "cpu"
-
- download_it = os.path.exists(model_root)
- if '/' in model_str:
- hf_hub_download(repo_id=model_str, filename='model.bin',
- local_dir=model_root)
- hf_hub_download(repo_id=model_str, filename='vocabulary.json',
- local_dir=model_root)
- hf_hub_download(repo_id=model_str, filename='config.json',
- local_dir=model_root)
- if download_it:
- model_str = model_root
- self.model = WhisperModel(model_str,
- device = model_device,
- device_index = cfg["gpu_idx"],
- compute_type = "float16",
- download_root = model_root,
- local_files_only = download_it)
+ self.pipe = pipeline(
+ "automatic-speech-recognition",
+ model="distil-whisper/distil-large-v2",
+ torch_dtype=torch.float16,
+ device="cuda", # TODO plumb
+ model_kwargs={"use_flash_attention_2": True}, # TODO only if cuda on
+ )
def transcribe(self, frames: bytes = None) -> typing.List[Segment]:
if frames is None:
@@ -446,31 +425,22 @@ class Whisper:
dtype=np.int16).flatten().astype(np.float32) / 32768.0
t0 = time.time()
- segments, info = self.model.transcribe(
+ res = self.pipe(
audio,
- language = langcodes.find(self.cfg["language"]).language,
- vad_filter = True,
- temperature=0.0,
- without_timestamps = False)
- res = []
- for s in segments:
- # Manual touchup. I see a decent number of hallucinations sneaking
- # in with high `no_speech_prob` and modest `avg_logprob`.
- if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5:
- continue
- if cfg["enable_debug_mode"]:
- print(f"s get: {s}")
- if s.avg_logprob < -1.0:
- continue
- if s.compression_ratio > 2.4:
- continue
- res.append(Segment(s.text, s.start, s.end,
- self.collector.begin(),
- s.avg_logprob, s.no_speech_prob, s.compression_ratio))
+ chunk_length_s=30,
+ batch_size=1)
+
+ result = [Segment(res["text"],
+ 0,
+ 0,
+ self.collector.begin(),
+ 0,
+ 0,
+ 0)]
+
t1 = time.time()
- if cfg["enable_debug_mode"]:
- print(f"Transcription latency (s): {t1 - t0}")
- return res
+ print(f"Transcription latency (s): {t1 - t0}")
+ return result
def saveAudio(audio: bytes, path: str):
with wave.open(path, 'wb') as wf:
@@ -520,7 +490,7 @@ class VadCommitter:
#saveAudio(commit_audio, filename)
preview = ""
- if self.cfg["enable_previews"] and has_audio:
+ if self.cfg["enable_previews"] and has_audio and not stable_cutoff:
segments = self.whisper.transcribe(audio)
preview = "".join(s.transcript for s in segments)