From b40ded2981d5b037cdab9b78ff1ea0f8f22658d3 Mon Sep 17 00:00:00 2001
From: yum
Date: Tue, 5 Sep 2023 20:51:20 -0700
Subject: Put OSC logic into its own thread

This logic is highly I/O-bound *and* latency-critical, so it makes sense to
put it into its own thread.

Also:
* AudioCollector's drop* methods now return the dropped audio. The committer
  includes that audio in its commits, and the transcription thread holds onto
  it. When the user segments their speech with a button press, the
  transcription thread sends the entire combined audio of all commits over to
  Whisper to be transcribed. This allows us to recover from errors introduced
  by segmentation.
* Remove unused animator params
* Fix issue where clearing the board doesn't completely reset STT state

TODO:
* Coalescing does not occur for in-place updates. It should.
--- Scripts/generate_utils.py | 6 ---- Scripts/libtastt.py | 2 -- Scripts/transcribe_v2.py | 91 ++++++++++++++++++++++++++++++++++++----------- 3 files changed, 71 insertions(+), 28 deletions(-) (limited to 'Scripts') diff --git a/Scripts/generate_utils.py b/Scripts/generate_utils.py index c486201..1e12103 100644 --- a/Scripts/generate_utils.py +++ b/Scripts/generate_utils.py @@ -30,12 +30,6 @@ config = Config() def getDummyParam(): return "TaSTT_Dummy" -def getHipToggleParam(): - return "TaSTT_Hip_Toggle" - -def getHandToggleParam(): - return "TaSTT_Hand_Toggle" - def getToggleParam(): return "TaSTT_Toggle" diff --git a/Scripts/libtastt.py b/Scripts/libtastt.py index b05a724..5e216ca 100644 --- a/Scripts/libtastt.py +++ b/Scripts/libtastt.py @@ -689,8 +689,6 @@ def generateFXController(anim: libunity.UnityAnimator) -> typing.Dict[int, libun anim.addParameter(generate_utils.getEnableParam(), bool) anim.addParameter(generate_utils.getDummyParam(), bool) - anim.addParameter(generate_utils.getHipToggleParam(), bool) - anim.addParameter(generate_utils.getHandToggleParam(), bool) anim.addParameter(generate_utils.getToggleParam(), bool) 
anim.addParameter(generate_utils.getClearBoardParam(), bool) anim.addParameter(generate_utils.getScaleParam(), float) diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 9812535..65a86e3 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -21,10 +21,14 @@ import threading import time import typing +TRANSCRIBE_REQ_RESET_COMMITS = 0 +TRANSCRIBE_REQ_WHOLE_BUFFER = 1 + class ThreadControl: def __init__(self, cfg): self.cfg = cfg self.run_app = True + self.transcribe_queue = [] class AudioStream(): FORMAT = pyaudio.paInt16 @@ -202,19 +206,23 @@ class AudioCollector: self.frames += frames return self.frames - def dropAudioPrefix(self, dur_s: float): + def dropAudioPrefix(self, dur_s: float) -> bytes: n_bytes = int(dur_s * self.stream.FPS) * self.stream.FRAME_SZ n_bytes = min(n_bytes, len(self.frames)) + cut_portion = self.frames[:n_bytes] self.frames = self.frames[n_bytes:] self.wall_ts = self.wall_ts + self.duration() + return cut_portion - def keepLast(self, dur_s: float): + def keepLast(self, dur_s: float) -> bytes: drop_len = max(0, self.duration() - dur_s) - self.dropAudioPrefix(drop_len) + return self.dropAudioPrefix(drop_len) def dropAudio(self): self.wall_ts += self.duration() + cut_portion = self.frames self.frames = b'' + return cut_portion def duration(self): return len(self.frames) / (self.stream.FPS * self.stream.FRAME_SZ) @@ -352,8 +360,9 @@ class Whisper: download_root = model_root, local_files_only = download_it) - def transcribe(self) -> typing.List[Segment]: - frames = self.collector.getAudio() + def transcribe(self, frames = None) -> typing.List[Segment]: + if frames is None: + frames = self.collector.getAudio() # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on # [-1, 1]. 
audio = np.frombuffer(frames, @@ -379,11 +388,13 @@ class TranscriptCommit: delta: str, preview: str, latency_s: int = None, - thresh_at_commit: int = None): + thresh_at_commit: int = None, + audio: bytes = None): self.delta = delta self.preview = preview self.latency_s = latency_s self.thresh_at_commit = thresh_at_commit + self.audio = audio # Commits audio when the transcription layer repeats the same transcript, # within some fuzzy match distance. @@ -391,7 +402,7 @@ class FuzzyRepeatCommitter: def __init__(self, collector: AudioCollector, whisper: Whisper, - last_n_must_match: int = 3, + last_n_must_match: int = 4, edit_thresh_min: float = 1, edit_thresh_grow_begin_s: float = 1.5, edit_thresh_grow_halflife_s: float = 0.5, @@ -464,10 +475,10 @@ class FuzzyRepeatCommitter: self.candidates = [] latency_s = self.collector.now() - (candidate.wall_ts + candidate.start_ts) # Measured to slightly improve performance in benchmark. - self.collector.dropAudioPrefix(candidate.end_ts + 0.10) + audio = self.collector.dropAudioPrefix(candidate.end_ts + 0.10) return TranscriptCommit(candidate.transcript, preview, latency_s, - thresh_at_commit = edit_thresh) + thresh_at_commit=edit_thresh, audio=audio) class OscPager: def __init__(self, cfg): @@ -495,10 +506,10 @@ class OscPager: def toggleBoard(self, state: bool): osc_ctrl.toggleBoard(self.osc_state.client, state) - def lockWorld(self, state): + def lockWorld(self, state: bool): osc_ctrl.lockWorld(self.osc_state.client, state) - def ellipsis(self, state): + def ellipsis(self, state: bool): osc_ctrl.ellipsis(self.osc_state.client, state) def evaluate(cfg, @@ -622,17 +633,43 @@ def optimize(cfg, return optimized_params def transcriptionThread(ctrl: ThreadControl): + commits = [] while ctrl.run_app: + op = None + while len(ctrl.transcribe_queue) > 0: + cur_op = ctrl.transcribe_queue[0] + ctrl.transcribe_queue = ctrl.transcribe_queue[1:] + if cur_op == TRANSCRIBE_REQ_RESET_COMMITS: + commits = [] + op = None + 
ctrl.transcribe_queue = [] + break + elif cur_op == TRANSCRIBE_REQ_WHOLE_BUFFER: + op = TRANSCRIBE_REQ_WHOLE_BUFFER + if op == TRANSCRIBE_REQ_WHOLE_BUFFER: + print("Retranscribing committed buffers", file=sys.stderr) + audio = b''.join(commit.audio for commit in commits) + segments = ctrl.whisper.transcribe(audio) + # TODO support concatenation + ctrl.transcript = "".join([s.transcript for s in segments]) + ctrl.preview = ctrl.transcript + commits = [] + commit = ctrl.committer.getDelta() - ctrl.pager.page(ctrl.transcript + commit.preview) - ctrl.transcript += commit.delta + preview = commit.preview + if False and len(preview) > 0: + preview = "[" + preview + "]" if len(commit.delta): - print(f"Transcript: {ctrl.transcript}") + print(f"Transcript: {ctrl.transcript}{preview}") if cfg["enable_debug_mode"]: print(f"commit latency: {commit.latency_s}", file=sys.stderr) print(f"commit thresh: {commit.thresh_at_commit}", file=sys.stderr) + commits.append(commit) + + ctrl.preview = ctrl.transcript + preview + ctrl.transcript += commit.delta def vrInputThread(ctrl: ThreadControl): RECORD_STATE = 0 @@ -662,9 +699,6 @@ def vrInputThread(ctrl: ThreadControl): if state == PAUSE_STATE: ctrl.stream.pause(False) ctrl.stream.getSamples() - ctrl.pager.clear() - if ctrl.cfg["reset_on_toggle"]: - ctrl.transcript = "" elif event.opcode == steamvr.EVENT_FALLING_EDGE: now = time.time() @@ -700,13 +734,16 @@ def vrInputThread(ctrl: ThreadControl): if not ctrl.cfg["use_builtin"]: ctrl.pager.toggleBoard(False) + # Flush the *entire* pipeline. 
ctrl.stream.pause(True) ctrl.stream.getSamples() + ctrl.collector.dropAudio() ctrl.pager.clear() else: # Short hold if state == RECORD_STATE: print("PAUSED") + ctrl.transcribe_queue.append(TRANSCRIBE_REQ_WHOLE_BUFFER) state = PAUSE_STATE if not ctrl.cfg["use_builtin"]: ctrl.pager.lockWorld(True) @@ -719,6 +756,7 @@ def vrInputThread(ctrl: ThreadControl): elif state == PAUSE_STATE: print("RECORDING", file=sys.stderr) state = RECORD_STATE + ctrl.transcribe_queue.append(TRANSCRIBE_REQ_RESET_COMMITS) if not ctrl.cfg["use_builtin"]: ctrl.pager.toggleBoard(True) ctrl.pager.lockWorld(False) @@ -728,6 +766,7 @@ def vrInputThread(ctrl: ThreadControl): print("Toggle detected, dropping transcript (3)", file=sys.stderr) ctrl.transcript = "" + ctrl.preview = "" #audio_state.drop_transcription = True else: if ctrl.cfg["enable_debug_mode"]: @@ -741,9 +780,13 @@ def vrInputThread(ctrl: ThreadControl): #audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_ON) pass -def kbInputThread( - thread_ctrl): - while thread_ctrl.run_app: +def kbInputThread(ctrl: ThreadControl): + while ctrl.run_app: + time.sleep(0.01) + +def oscThread(ctrl: ThreadControl): + while ctrl.run_app: + ctrl.pager.page(ctrl.preview) time.sleep(0.01) def run(cfg): @@ -764,6 +807,8 @@ def run(cfg): ctrl.committer = committer ctrl.pager = pager ctrl.transcript = "" + ctrl.transcribe_queue = [] + ctrl.preview = "" transcribe_audio_thd = threading.Thread(target=transcriptionThread, args=[ctrl]) transcribe_audio_thd.daemon = True @@ -777,6 +822,10 @@ def run(cfg): kb_input_thd.daemon = True kb_input_thd.start() + osc_thd = threading.Thread(target=oscThread, args=[ctrl]) + osc_thd.daemon = True + osc_thd.start() + for line in sys.stdin: if "exit" in line or "quit" in line: break @@ -784,6 +833,8 @@ def run(cfg): ctrl.run_app = False transcribe_audio_thd.join() vr_input_thd.join() + kb_input_thd.join() + osc_thd.join() if __name__ == "__main__": parser = argparse.ArgumentParser() -- cgit v1.2.3