summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: yum <yum.food.vr@gmail.com> 2023-09-07 22:04:16 -0700
committer: yum <yum.food.vr@gmail.com> 2023-09-07 22:04:16 -0700
commit: a82e43c16ff097a7c57ee87e67fa67e7f007b977 (patch)
tree: acb7c884a7a0f5037269a5e79a9bc79ed9f5372c
parent: b40ded2981d5b037cdab9b78ff1ea0f8f22658d3 (diff)
Switch to VadCommitter
FuzzyRepeatCommitter was approximating this behavior in its best-performing configuration, so switch to a VAD-based committer in earnest. VadCommitter simply commits audio once it detects a long enough gap in speech. That's it!
-rw-r--r-- GUI/package.ps1 | 6
-rw-r--r-- Scripts/transcribe.py | 2
-rw-r--r-- Scripts/transcribe_v2.py | 204
3 files changed, 103 insertions, 109 deletions
diff --git a/GUI/package.ps1 b/GUI/package.ps1
index ca4f4b4..c178a94 100644
--- a/GUI/package.ps1
+++ b/GUI/package.ps1
@@ -162,6 +162,10 @@ if (-Not (Test-Path curl)) {
popd > $null
}
+if (-Not (Test-Path "silero-vad")) {
+ git clone "https://github.com/snakers4/silero-vad"
+}
+
mkdir $install_dir > $null
mkdir $install_dir/Resources > $null
cp -Recurse ../Animations TaSTT/Resources/Animations
@@ -179,6 +183,8 @@ cp -Recurse ../UnityAssets TaSTT/Resources/UnityAssets
cp -Recurse ../BrowserSource TaSTT/Resources/BrowserSource
cp GUI/x64/$release/GUI.exe TaSTT/TaSTT.exe
mkdir TaSTT/Resources/Models
+cp "silero-vad/files/silero_vad.onnx" TaSTT/Resources/Models/
+cp "silero-vad/LICENSE" TaSTT/Resources/Models/silero_vad.onnx.LICENSE
mkdir TaSTT/Resources/Uwu
cp UwwwuPP/build/Src/Debug/Uwwwu.exe TaSTT/Resources/Uwu/
cp UwwwuPP/LICENSE TaSTT/Resources/Uwu/
diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py
index a529ad2..7098400 100644
--- a/Scripts/transcribe.py
+++ b/Scripts/transcribe.py
@@ -444,7 +444,7 @@ def transcribeAudio(audio_state):
# Apply filters to transcription
filtered_text = translated
if audio_state.cfg["enable_uwu_filter"]:
- uwu_proc = subprocess.Popen(["Resources/Uwu/Uwwwu.exe", filtered_text],
+ uwu_proc = subprocess.Popen(["Resources/Models/Uwwwu.exe", filtered_text],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
uwu_stdout, uwu_stderr = uwu_proc.communicate()
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 65a86e3..d957798 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -20,6 +20,7 @@ import sys
import threading
import time
import typing
+import vad
TRANSCRIBE_REQ_RESET_COMMITS = 0
TRANSCRIBE_REQ_WHOLE_BUFFER = 1
@@ -214,6 +215,14 @@ class AudioCollector:
self.wall_ts = self.wall_ts + self.duration()
return cut_portion
+    def dropAudioPrefixByFrames(self, dur_frames: int) -> bytes:
+        """Remove up to `dur_frames` frames from the front of the buffer.
+
+        The request is clamped to the range [0, buffered audio]. Returns the
+        bytes that were removed from `self.frames`.
+        """
+        # A negative count would otherwise slice from the tail of the buffer.
+        n_bytes = max(0, dur_frames) * self.stream.FRAME_SZ
+        n_bytes = min(n_bytes, len(self.frames))
+        cut_portion = self.frames[:n_bytes]
+        self.frames = self.frames[n_bytes:]
+        # Advance the timestamp by the audio that was cut. The previous code
+        # added self.duration() after trimming, i.e. the length of whatever
+        # was left in the buffer.
+        # NOTE(review): assumes wall_ts marks the first buffered frame and
+        # that AudioStream.FPS is the frame rate of self.stream - confirm.
+        # dropAudioPrefix updates wall_ts the same way and likely needs the
+        # same correction.
+        cut_frames = len(cut_portion) / self.stream.FRAME_SZ
+        self.wall_ts = self.wall_ts + cut_frames / AudioStream.FPS
+        return cut_portion
+
def keepLast(self, dur_s: float) -> bytes:
drop_len = max(0, self.duration() - dur_s)
return self.dropAudioPrefix(drop_len)
@@ -241,6 +250,8 @@ class AudioCollectorFilter:
return self.parent.getAudio()
def dropAudioPrefix(self, dur_s: float):
return self.parent.dropAudioPrefix(dur_s)
+    def dropAudioPrefixByFrames(self, dur_frames: int):
+        """Forward the call to self.parent and return its result."""
+        wrapped = self.parent
+        return wrapped.dropAudioPrefixByFrames(dur_frames)
def keepLast(self, dur_s):
return self.parent.keepLast(dur_s)
def dropAudio(self):
@@ -299,6 +310,43 @@ class CompressingAudioCollector(AudioCollectorFilter):
return frames
+class AudioSegmenter:
+    """Finds pauses in buffered audio using voice activity detection."""
+
+    def __init__(self):
+        pass
+
+    def segmentAudio(self, audio: bytes):
+        """Return the speech segments detected in `audio`.
+
+        `audio` is read as signed 16-bit samples and rescaled to float32 by
+        dividing by 32768 before being passed to vad.get_speech_timestamps().
+        The result is returned unchanged; getStableCutoff() reads it as a
+        sequence of mappings with 'start' and 'end' keys.
+        """
+        samples = np.frombuffer(audio, dtype=np.int16)
+        samples = samples.flatten().astype(np.float32) / 32768.0
+        return vad.get_speech_timestamps(samples)
+
+    def getStableCutoff(self, audio: bytes,
+                        min_delta_s = 1.5) -> typing.Optional[int]:
+        """Return a frame offset inside the most recent long pause, if any.
+
+        A pause qualifies when it is longer than `min_delta_s` seconds. It
+        may lie between two consecutive speech segments or between the last
+        segment and the end of `audio`. The cutoff is placed half of
+        `min_delta_s` before the point where the pause ends. When several
+        pauses qualify, the latest one wins.
+
+        Returns None when no speech was detected or no pause is long enough.
+
+        NOTE(review): segment 'start'/'end' values are compared directly
+        with frame counts derived from AudioStream.FPS and
+        AudioStream.FRAME_SZ, so they are presumed to use the same unit -
+        confirm against the vad module.
+        """
+        min_delta_frames = min_delta_s * AudioStream.FPS
+        half_gap = int(min_delta_frames / 2)
+
+        segments = self.segmentAudio(audio)
+        if not segments:
+            return None
+
+        cutoff = None
+        prev_end = None
+        for s in segments:
+            if (prev_end is not None
+                    and s['start'] - prev_end > min_delta_frames):
+                cutoff = s['start'] - half_gap
+            # Track the end of every segment. Previously this was recorded
+            # only for the first segment, so every later gap was measured
+            # from there and could put the cutoff in the middle of speech.
+            prev_end = s['end']
+
+        # Trailing silence after the last detected segment.
+        now = int(len(audio) / AudioStream.FRAME_SZ)
+        if now - prev_end > min_delta_frames:
+            cutoff = now - half_gap
+
+        return cutoff
+
# A segment of transcribed audio. `start_ts` and `end_ts` are floating point
# number of seconds since the beginning of audio data.
class Segment:
@@ -339,9 +387,10 @@ class Whisper:
self.cfg = cfg
abspath = os.path.abspath(__file__)
- dname = os.path.dirname(abspath)
+ my_dir = os.path.dirname(abspath)
+ parent_dir = os.path.dirname(my_dir)
- model_root = os.path.join(dname, "Models", cfg["model"])
+ model_root = os.path.join(parent_dir, "Models", cfg["model"])
print(f"Model {cfg['model']} will be saved to {model_root}",
file=sys.stderr)
@@ -360,7 +409,7 @@ class Whisper:
download_root = model_root,
local_files_only = download_it)
- def transcribe(self, frames = None) -> typing.List[Segment]:
+ def transcribe(self, frames: bytes = None) -> typing.List[Segment]:
if frames is None:
frames = self.collector.getAudio()
# Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on
@@ -396,89 +445,39 @@ class TranscriptCommit:
self.thresh_at_commit = thresh_at_commit
self.audio = audio
-# Commits audio when the transcription layer repeats the same transcript,
-# within some fuzzy match distance.
-class FuzzyRepeatCommitter:
+class VadCommitter:
def __init__(self,
collector: AudioCollector,
- whisper: Whisper,
- last_n_must_match: int = 4,
- edit_thresh_min: float = 1,
- edit_thresh_grow_begin_s: float = 1.5,
- edit_thresh_grow_halflife_s: float = 0.5,
- min_segment_age_s: float = 0.5):
+ whisper: Whisper):
self.collector = collector
self.whisper = whisper
- # List of candidate segments. Once these all match, we commit the
- # corresponding audio data.
- self.candidates = []
- self.last_n_must_match = last_n_must_match
- self.edit_thresh_min = edit_thresh_min
- self.edit_thresh_grow_begin_s = edit_thresh_grow_begin_s
- self.edit_thresh_grow_halflife_s = edit_thresh_grow_halflife_s
- self.min_segment_age_s = min_segment_age_s
+ self.segmenter = AudioSegmenter()
def getDelta(self) -> TranscriptCommit:
- segments = self.whisper.transcribe()
+ audio = self.collector.getAudio()
+ stable_cutoff = self.segmenter.getStableCutoff(audio)
+
+ delta = ""
+ commit_audio = None
+ latency_s = None
+ if stable_cutoff:
+ #print(f"stable cutoff get: {stable_cutoff}")
+ segments = self.whisper.transcribe(audio)
+ delta = ''.join(s.transcript for s in segments)
+ #print(f"delta get: {delta}")
+ commit_begin = self.collector.begin()
+ commit_audio = self.collector.dropAudioPrefixByFrames(stable_cutoff)
+ latency_s = self.collector.now() - commit_begin
+ audio = self.collector.getAudio()
+
+ segments = self.whisper.transcribe(audio)
preview = ''.join(s.transcript for s in segments)
- if len(segments) == 0:
- self.collector.keepLast(1.0)
- return TranscriptCommit("", preview, None)
-
- s = segments[0]
-
- if len(self.candidates) < self.last_n_must_match:
- if len(self.candidates) == 0:
- self.candidates.append(s)
- return TranscriptCommit("", preview, None)
- s0 = self.candidates[0]
- if s.wall_ts != s0.wall_ts:
- print("Frames dropped, committer resetting candidates",
- file=sys.stderr)
- self.candidates = []
- return TranscriptCommit("", preview, None)
- self.candidates.append(s)
- return TranscriptCommit("", preview, None)
-
- # Rule 1: last n segments must be within a certain edit distance of
- # each other. This edit distance starts low and increases exponentially
- # as the buffer size grows, thus allowing the check to get weaker under
- # compute pressure.
- edit_thresh = self.edit_thresh_min
- dt = self.collector.now() - (self.collector.begin() + s.start_ts)
- if dt > self.edit_thresh_grow_begin_s:
- dt -= self.edit_thresh_grow_begin_s
- edit_thresh = math.ceil(2**(dt /
- self.edit_thresh_grow_halflife_s))
-
- drop_candidates = 0
- for i in range(1, len(self.candidates)):
- prev = self.candidates[i-1]
- cur = self.candidates[i]
- dist = editdistance.eval(prev.transcript, cur.transcript)
- if dist > edit_thresh:
- drop_candidates = i
- if drop_candidates != 0:
- self.candidates = self.candidates[drop_candidates:]
- return TranscriptCommit("", preview, None)
-
- candidate = self.candidates[-1]
-
- # Rule 2: no committing segments that are fewer than the configured
- # number of seconds old.
- if self.collector.now() - (candidate.end_ts + candidate.wall_ts) < self.min_segment_age_s:
- self.candidates = []
- return TranscriptCommit("", preview, None)
-
- # Got a candidate! Commit it and return.
- self.candidates = []
- latency_s = self.collector.now() - (candidate.wall_ts + candidate.start_ts)
- # Measured to slightly improve performance in benchmark.
- audio = self.collector.dropAudioPrefix(candidate.end_ts + 0.10)
-
- return TranscriptCommit(candidate.transcript, preview, latency_s,
- thresh_at_commit=edit_thresh, audio=audio)
+ return TranscriptCommit(
+ delta,
+ preview,
+ latency_s,
+ audio=audio)
class OscPager:
def __init__(self, cfg):
@@ -514,36 +513,17 @@ class OscPager:
def evaluate(cfg,
audio_path: str,
- control_path: str,
- last_n_must_match: int = 3,
- edit_thresh_min: float = 1,
- edit_thresh_grow_begin_s: float = 1.5,
- edit_thresh_grow_halflife_s: float = 0.5,
- min_segment_age_s: float = 0.5
- ):
+ control_path: str):
stream = DiskStream(audio_path)
collector = AudioCollector(stream)
- #collector = LengthEnforcingAudioCollector(collector, 5.0)
- #collector = NormalizingAudioCollector(collector)
collector = CompressingAudioCollector(collector)
whisper = Whisper(collector, cfg)
- com = FuzzyRepeatCommitter(collector, whisper,
- last_n_must_match=last_n_must_match,
- edit_thresh_min=edit_thresh_min,
- edit_thresh_grow_begin_s=edit_thresh_grow_begin_s,
- edit_thresh_grow_halflife_s=edit_thresh_grow_halflife_s,
- min_segment_age_s=min_segment_age_s)
+ com = None
+ com = VadCommitter(collector, whisper)
transcript = ""
commits = []
- print(f"PARAMS")
- print(f"last_n_must_match: {last_n_must_match}")
- print(f"edit_thresh_min: {edit_thresh_min}")
- print(f"edit_thresh_grow_begin_s: {edit_thresh_grow_begin_s}")
- print(f"edit_thresh_grow_halflife_s: {edit_thresh_grow_halflife_s}")
- print(f"min_segment_age_s: {min_segment_age_s}")
-
while len(stream.frames) > 0:
commit = com.getDelta()
@@ -657,19 +637,18 @@ def transcriptionThread(ctrl: ThreadControl):
commit = ctrl.committer.getDelta()
- preview = commit.preview
- if False and len(preview) > 0:
- preview = "[" + preview + "]"
+ if True:
+ print(f"Transcript: {ctrl.transcript}{commit.delta}{commit.preview}")
- if len(commit.delta):
- print(f"Transcript: {ctrl.transcript}{preview}")
+ if False and len(commit.delta):
+ print(f"Transcript: {ctrl.transcript}{commit.delta}{commit.preview}")
if cfg["enable_debug_mode"]:
print(f"commit latency: {commit.latency_s}", file=sys.stderr)
print(f"commit thresh: {commit.thresh_at_commit}", file=sys.stderr)
commits.append(commit)
- ctrl.preview = ctrl.transcript + preview
ctrl.transcript += commit.delta
+ ctrl.preview = ctrl.transcript + commit.preview
def vrInputThread(ctrl: ThreadControl):
RECORD_STATE = 0
@@ -789,6 +768,15 @@ def oscThread(ctrl: ThreadControl):
ctrl.pager.page(ctrl.preview)
time.sleep(0.01)
+def dev_run(cfg):
+    """Debug loop: print the segmenter's cutoff for the microphone buffer.
+
+    Opens the microphone named by cfg["microphone"] and loops forever,
+    printing the cutoff that AudioSegmenter reports for the collected audio.
+    """
+    mic = MicStream(cfg["microphone"])
+    collector = AudioCollector(mic)
+    segmenter = AudioSegmenter()
+    while True:
+        cutoff = segmenter.getStableCutoff(collector.getAudio())
+        print(f"audio cutoff: {cutoff}")
+
def run(cfg):
stream = MicStream(cfg["microphone"])
@@ -797,7 +785,7 @@ def run(cfg):
#collector = NormalizingAudioCollector(collector)
collector = CompressingAudioCollector(collector)
whisper = Whisper(collector, cfg)
- committer = FuzzyRepeatCommitter(collector, whisper)
+ committer = VadCommitter(collector, whisper)
pager = OscPager(cfg)
ctrl = ThreadControl(cfg)
@@ -860,5 +848,5 @@ if __name__ == "__main__":
else:
#optimize(cfg, experiments)
run(cfg)
-
+ #dev_run(cfg)