diff options
| author | yum <yum.food.vr@gmail.com> | 2023-09-07 22:04:16 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-09-07 22:04:16 -0700 |
| commit | a82e43c16ff097a7c57ee87e67fa67e7f007b977 (patch) | |
| tree | acb7c884a7a0f5037269a5e79a9bc79ed9f5372c | |
| parent | b40ded2981d5b037cdab9b78ff1ea0f8f22658d3 (diff) | |
Switch to VadCommitter
In its best-performing configuration, FuzzyRepeatCommitter was already
approximating VAD-based committing, so switch to VadCommitter in earnest.
VadCommitter simply commits audio once it detects a long enough gap in
speech. That's it!
| -rw-r--r-- | GUI/package.ps1 | 6 | ||||
| -rw-r--r-- | Scripts/transcribe.py | 2 | ||||
| -rw-r--r-- | Scripts/transcribe_v2.py | 204 |
3 files changed, 103 insertions, 109 deletions
diff --git a/GUI/package.ps1 b/GUI/package.ps1 index ca4f4b4..c178a94 100644 --- a/GUI/package.ps1 +++ b/GUI/package.ps1 @@ -162,6 +162,10 @@ if (-Not (Test-Path curl)) { popd > $null
}
+if (-Not (Test-Path "silero-vad")) {
+ git clone "https://github.com/snakers4/silero-vad"
+}
+
mkdir $install_dir > $null
mkdir $install_dir/Resources > $null
cp -Recurse ../Animations TaSTT/Resources/Animations
@@ -179,6 +183,8 @@ cp -Recurse ../UnityAssets TaSTT/Resources/UnityAssets cp -Recurse ../BrowserSource TaSTT/Resources/BrowserSource
cp GUI/x64/$release/GUI.exe TaSTT/TaSTT.exe
mkdir TaSTT/Resources/Models
+cp "silero-vad/files/silero_vad.onnx" TaSTT/Resources/Models/
+cp "silero-vad/LICENSE" TaSTT/Resources/Models/silero_vad.onnx.LICENSE
mkdir TaSTT/Resources/Uwu
cp UwwwuPP/build/Src/Debug/Uwwwu.exe TaSTT/Resources/Uwu/
cp UwwwuPP/LICENSE TaSTT/Resources/Uwu/
diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py index a529ad2..7098400 100644 --- a/Scripts/transcribe.py +++ b/Scripts/transcribe.py @@ -444,7 +444,7 @@ def transcribeAudio(audio_state): # Apply filters to transcription filtered_text = translated if audio_state.cfg["enable_uwu_filter"]: - uwu_proc = subprocess.Popen(["Resources/Uwu/Uwwwu.exe", filtered_text], + uwu_proc = subprocess.Popen(["Resources/Models/Uwwwu.exe", filtered_text], stdout=subprocess.PIPE, stderr=subprocess.PIPE) uwu_stdout, uwu_stderr = uwu_proc.communicate() diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 65a86e3..d957798 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -20,6 +20,7 @@ import sys import threading import time import typing +import vad TRANSCRIBE_REQ_RESET_COMMITS = 0 TRANSCRIBE_REQ_WHOLE_BUFFER = 1 @@ -214,6 +215,14 @@ class AudioCollector: self.wall_ts = self.wall_ts + self.duration() return cut_portion + def dropAudioPrefixByFrames(self, dur_frames: int) -> bytes: + n_bytes = dur_frames * self.stream.FRAME_SZ + n_bytes = min(n_bytes, len(self.frames)) + cut_portion = self.frames[:n_bytes] + self.frames = self.frames[n_bytes:] + self.wall_ts = self.wall_ts + self.duration() + return cut_portion + def keepLast(self, dur_s: float) -> bytes: drop_len = max(0, self.duration() - dur_s) return self.dropAudioPrefix(drop_len) @@ -241,6 +250,8 @@ class AudioCollectorFilter: return self.parent.getAudio() def dropAudioPrefix(self, dur_s: float): return self.parent.dropAudioPrefix(dur_s) + def dropAudioPrefixByFrames(self, dur_frames: int): + return self.parent.dropAudioPrefixByFrames(dur_frames) def keepLast(self, dur_s): return self.parent.keepLast(dur_s) def dropAudio(self): @@ -299,6 +310,43 @@ class CompressingAudioCollector(AudioCollectorFilter): return frames +class AudioSegmenter: + def __init__(self): + pass + + def segmentAudio(self, audio: bytes): + audio = np.frombuffer(audio, + 
dtype=np.int16).flatten().astype(np.float32) / 32768.0 + return vad.get_speech_timestamps(audio) + + def getStableCutoff(self, audio: bytes, min_delta_s = 1.5) -> int: + min_delta_frames = min_delta_s * AudioStream.FPS + cutoff = None + + last_end = None + segments = self.segmentAudio(audio) + for i in range(len(segments)): + s = segments[i] + #print(f"s: {s}") + #print(f"last_end: {last_end}") + + if last_end: + delta_frames = s['start'] - last_end + #print(f"delta frames: {delta_frames}") + if delta_frames > min_delta_frames: + cutoff = s['start'] - int(min_delta_frames / 2) + else: + last_end = s['end'] + + if i == len(segments) - 1: + now = int(len(audio) / AudioStream.FRAME_SZ) + #print(f"now: {now}") + #print(f"min d: {min_delta_frames}") + if now - s['end'] > min_delta_frames: + cutoff = now - int(min_delta_frames / 2) + + return cutoff + # A segment of transcribed audio. `start_ts` and `end_ts` are floating point # number of seconds since the beginning of audio data. class Segment: @@ -339,9 +387,10 @@ class Whisper: self.cfg = cfg abspath = os.path.abspath(__file__) - dname = os.path.dirname(abspath) + my_dir = os.path.dirname(abspath) + parent_dir = os.path.dirname(my_dir) - model_root = os.path.join(dname, "Models", cfg["model"]) + model_root = os.path.join(parent_dir, "Models", cfg["model"]) print(f"Model {cfg['model']} will be saved to {model_root}", file=sys.stderr) @@ -360,7 +409,7 @@ class Whisper: download_root = model_root, local_files_only = download_it) - def transcribe(self, frames = None) -> typing.List[Segment]: + def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: frames = self.collector.getAudio() # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on @@ -396,89 +445,39 @@ class TranscriptCommit: self.thresh_at_commit = thresh_at_commit self.audio = audio -# Commits audio when the transcription layer repeats the same transcript, -# within some fuzzy match distance. 
-class FuzzyRepeatCommitter: +class VadCommitter: def __init__(self, collector: AudioCollector, - whisper: Whisper, - last_n_must_match: int = 4, - edit_thresh_min: float = 1, - edit_thresh_grow_begin_s: float = 1.5, - edit_thresh_grow_halflife_s: float = 0.5, - min_segment_age_s: float = 0.5): + whisper: Whisper): self.collector = collector self.whisper = whisper - # List of candidate segments. Once these all match, we commit the - # corresponding audio data. - self.candidates = [] - self.last_n_must_match = last_n_must_match - self.edit_thresh_min = edit_thresh_min - self.edit_thresh_grow_begin_s = edit_thresh_grow_begin_s - self.edit_thresh_grow_halflife_s = edit_thresh_grow_halflife_s - self.min_segment_age_s = min_segment_age_s + self.segmenter = AudioSegmenter() def getDelta(self) -> TranscriptCommit: - segments = self.whisper.transcribe() + audio = self.collector.getAudio() + stable_cutoff = self.segmenter.getStableCutoff(audio) + + delta = "" + commit_audio = None + latency_s = None + if stable_cutoff: + #print(f"stable cutoff get: {stable_cutoff}") + segments = self.whisper.transcribe(audio) + delta = ''.join(s.transcript for s in segments) + #print(f"delta get: {delta}") + commit_begin = self.collector.begin() + commit_audio = self.collector.dropAudioPrefixByFrames(stable_cutoff) + latency_s = self.collector.now() - commit_begin + audio = self.collector.getAudio() + + segments = self.whisper.transcribe(audio) preview = ''.join(s.transcript for s in segments) - if len(segments) == 0: - self.collector.keepLast(1.0) - return TranscriptCommit("", preview, None) - - s = segments[0] - - if len(self.candidates) < self.last_n_must_match: - if len(self.candidates) == 0: - self.candidates.append(s) - return TranscriptCommit("", preview, None) - s0 = self.candidates[0] - if s.wall_ts != s0.wall_ts: - print("Frames dropped, committer resetting candidates", - file=sys.stderr) - self.candidates = [] - return TranscriptCommit("", preview, None) - 
self.candidates.append(s) - return TranscriptCommit("", preview, None) - - # Rule 1: last n segments must be within a certain edit distance of - # each other. This edit distance starts low and increases exponentially - # as the buffer size grows, thus allowing the check to get weaker under - # compute pressure. - edit_thresh = self.edit_thresh_min - dt = self.collector.now() - (self.collector.begin() + s.start_ts) - if dt > self.edit_thresh_grow_begin_s: - dt -= self.edit_thresh_grow_begin_s - edit_thresh = math.ceil(2**(dt / - self.edit_thresh_grow_halflife_s)) - - drop_candidates = 0 - for i in range(1, len(self.candidates)): - prev = self.candidates[i-1] - cur = self.candidates[i] - dist = editdistance.eval(prev.transcript, cur.transcript) - if dist > edit_thresh: - drop_candidates = i - if drop_candidates != 0: - self.candidates = self.candidates[drop_candidates:] - return TranscriptCommit("", preview, None) - - candidate = self.candidates[-1] - - # Rule 2: no committing segments that are fewer than the configured - # number of seconds old. - if self.collector.now() - (candidate.end_ts + candidate.wall_ts) < self.min_segment_age_s: - self.candidates = [] - return TranscriptCommit("", preview, None) - - # Got a candidate! Commit it and return. - self.candidates = [] - latency_s = self.collector.now() - (candidate.wall_ts + candidate.start_ts) - # Measured to slightly improve performance in benchmark. 
- audio = self.collector.dropAudioPrefix(candidate.end_ts + 0.10) - - return TranscriptCommit(candidate.transcript, preview, latency_s, - thresh_at_commit=edit_thresh, audio=audio) + return TranscriptCommit( + delta, + preview, + latency_s, + audio=audio) class OscPager: def __init__(self, cfg): @@ -514,36 +513,17 @@ class OscPager: def evaluate(cfg, audio_path: str, - control_path: str, - last_n_must_match: int = 3, - edit_thresh_min: float = 1, - edit_thresh_grow_begin_s: float = 1.5, - edit_thresh_grow_halflife_s: float = 0.5, - min_segment_age_s: float = 0.5 - ): + control_path: str): stream = DiskStream(audio_path) collector = AudioCollector(stream) - #collector = LengthEnforcingAudioCollector(collector, 5.0) - #collector = NormalizingAudioCollector(collector) collector = CompressingAudioCollector(collector) whisper = Whisper(collector, cfg) - com = FuzzyRepeatCommitter(collector, whisper, - last_n_must_match=last_n_must_match, - edit_thresh_min=edit_thresh_min, - edit_thresh_grow_begin_s=edit_thresh_grow_begin_s, - edit_thresh_grow_halflife_s=edit_thresh_grow_halflife_s, - min_segment_age_s=min_segment_age_s) + com = None + com = VadCommitter(collector, whisper) transcript = "" commits = [] - print(f"PARAMS") - print(f"last_n_must_match: {last_n_must_match}") - print(f"edit_thresh_min: {edit_thresh_min}") - print(f"edit_thresh_grow_begin_s: {edit_thresh_grow_begin_s}") - print(f"edit_thresh_grow_halflife_s: {edit_thresh_grow_halflife_s}") - print(f"min_segment_age_s: {min_segment_age_s}") - while len(stream.frames) > 0: commit = com.getDelta() @@ -657,19 +637,18 @@ def transcriptionThread(ctrl: ThreadControl): commit = ctrl.committer.getDelta() - preview = commit.preview - if False and len(preview) > 0: - preview = "[" + preview + "]" + if True: + print(f"Transcript: {ctrl.transcript}{commit.delta}{commit.preview}") - if len(commit.delta): - print(f"Transcript: {ctrl.transcript}{preview}") + if False and len(commit.delta): + print(f"Transcript: 
{ctrl.transcript}{commit.delta}{commit.preview}") if cfg["enable_debug_mode"]: print(f"commit latency: {commit.latency_s}", file=sys.stderr) print(f"commit thresh: {commit.thresh_at_commit}", file=sys.stderr) commits.append(commit) - ctrl.preview = ctrl.transcript + preview ctrl.transcript += commit.delta + ctrl.preview = ctrl.transcript + commit.preview def vrInputThread(ctrl: ThreadControl): RECORD_STATE = 0 @@ -789,6 +768,15 @@ def oscThread(ctrl: ThreadControl): ctrl.pager.page(ctrl.preview) time.sleep(0.01) +def dev_run(cfg): + stream = MicStream(cfg["microphone"]) + collector = AudioCollector(stream) + segmenter = AudioSegmenter() + while True: + audio = collector.getAudio() + cutoff = segmenter.getStableCutoff(audio) + print(f"audio cutoff: {cutoff}") + def run(cfg): stream = MicStream(cfg["microphone"]) @@ -797,7 +785,7 @@ def run(cfg): #collector = NormalizingAudioCollector(collector) collector = CompressingAudioCollector(collector) whisper = Whisper(collector, cfg) - committer = FuzzyRepeatCommitter(collector, whisper) + committer = VadCommitter(collector, whisper) pager = OscPager(cfg) ctrl = ThreadControl(cfg) @@ -860,5 +848,5 @@ if __name__ == "__main__": else: #optimize(cfg, experiments) run(cfg) - + #dev_run(cfg) |
