diff options
Diffstat (limited to 'Scripts/transcribe_v2.py')
| -rw-r--r-- | Scripts/transcribe_v2.py | 91 |
1 files changed, 71 insertions, 20 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 9812535..65a86e3 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -21,10 +21,14 @@ import threading import time import typing +TRANSCRIBE_REQ_RESET_COMMITS = 0 +TRANSCRIBE_REQ_WHOLE_BUFFER = 1 + class ThreadControl: def __init__(self, cfg): self.cfg = cfg self.run_app = True + self.transcribe_queue = [] class AudioStream(): FORMAT = pyaudio.paInt16 @@ -202,19 +206,23 @@ class AudioCollector: self.frames += frames return self.frames - def dropAudioPrefix(self, dur_s: float): + def dropAudioPrefix(self, dur_s: float) -> bytes: n_bytes = int(dur_s * self.stream.FPS) * self.stream.FRAME_SZ n_bytes = min(n_bytes, len(self.frames)) + cut_portion = self.frames[:n_bytes] self.frames = self.frames[n_bytes:] self.wall_ts = self.wall_ts + self.duration() + return cut_portion - def keepLast(self, dur_s: float): + def keepLast(self, dur_s: float) -> bytes: drop_len = max(0, self.duration() - dur_s) - self.dropAudioPrefix(drop_len) + return self.dropAudioPrefix(drop_len) def dropAudio(self): self.wall_ts += self.duration() + cut_portion = self.frames self.frames = b'' + return cut_portion def duration(self): return len(self.frames) / (self.stream.FPS * self.stream.FRAME_SZ) @@ -352,8 +360,9 @@ class Whisper: download_root = model_root, local_files_only = download_it) - def transcribe(self) -> typing.List[Segment]: - frames = self.collector.getAudio() + def transcribe(self, frames = None) -> typing.List[Segment]: + if frames is None: + frames = self.collector.getAudio() # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on # [-1, 1]. audio = np.frombuffer(frames, @@ -379,11 +388,13 @@ class TranscriptCommit: delta: str, preview: str, latency_s: int = None, - thresh_at_commit: int = None): + thresh_at_commit: int = None, + audio: bytes = None): self.delta = delta self.preview = preview self.latency_s = latency_s self.thresh_at_commit = thresh_at_commit + self.audio = audio # Commits audio when the transcription layer repeats the same transcript, # within some fuzzy match distance. @@ -391,7 +402,7 @@ class FuzzyRepeatCommitter: def __init__(self, collector: AudioCollector, whisper: Whisper, - last_n_must_match: int = 3, + last_n_must_match: int = 4, edit_thresh_min: float = 1, edit_thresh_grow_begin_s: float = 1.5, edit_thresh_grow_halflife_s: float = 0.5, @@ -464,10 +475,10 @@ class FuzzyRepeatCommitter: self.candidates = [] latency_s = self.collector.now() - (candidate.wall_ts + candidate.start_ts) # Measured to slightly improve performance in benchmark. - self.collector.dropAudioPrefix(candidate.end_ts + 0.10) + audio = self.collector.dropAudioPrefix(candidate.end_ts + 0.10) return TranscriptCommit(candidate.transcript, preview, latency_s, - thresh_at_commit = edit_thresh) + thresh_at_commit=edit_thresh, audio=audio) class OscPager: def __init__(self, cfg): @@ -495,10 +506,10 @@ class OscPager: def toggleBoard(self, state: bool): osc_ctrl.toggleBoard(self.osc_state.client, state) - def lockWorld(self, state): + def lockWorld(self, state: bool): osc_ctrl.lockWorld(self.osc_state.client, state) - def ellipsis(self, state): + def ellipsis(self, state: bool): osc_ctrl.ellipsis(self.osc_state.client, state) def evaluate(cfg, @@ -622,17 +633,43 @@ def optimize(cfg, return optimized_params def transcriptionThread(ctrl: ThreadControl): + commits = [] while ctrl.run_app: + op = None + while len(ctrl.transcribe_queue) > 0: + cur_op = ctrl.transcribe_queue[0] + ctrl.transcribe_queue = ctrl.transcribe_queue[1:] + if cur_op == TRANSCRIBE_REQ_RESET_COMMITS: + commits = [] + op = None + ctrl.transcribe_queue = [] + break + elif cur_op == TRANSCRIBE_REQ_WHOLE_BUFFER: + op = TRANSCRIBE_REQ_WHOLE_BUFFER + if op == TRANSCRIBE_REQ_WHOLE_BUFFER: + print("Retranscribing committed buffers", file=sys.stderr) + audio = b''.join(commit.audio for commit in commits) + segments = ctrl.whisper.transcribe(audio) + # TODO support concatenation + ctrl.transcript = "".join([s.transcript for s in segments]) + ctrl.preview = ctrl.transcript + commits = [] + commit = ctrl.committer.getDelta() - ctrl.pager.page(ctrl.transcript + commit.preview) - ctrl.transcript += commit.delta + preview = commit.preview + if False and len(preview) > 0: + preview = "[" + preview + "]" if len(commit.delta): - print(f"Transcript: {ctrl.transcript}") + print(f"Transcript: {ctrl.transcript}{preview}") if cfg["enable_debug_mode"]: print(f"commit latency: {commit.latency_s}", file=sys.stderr) print(f"commit thresh: {commit.thresh_at_commit}", file=sys.stderr) + commits.append(commit) + + ctrl.preview = ctrl.transcript + preview + ctrl.transcript += commit.delta def vrInputThread(ctrl: ThreadControl): RECORD_STATE = 0 @@ -662,9 +699,6 @@ def vrInputThread(ctrl: ThreadControl): if state == PAUSE_STATE: ctrl.stream.pause(False) ctrl.stream.getSamples() - ctrl.pager.clear() - if ctrl.cfg["reset_on_toggle"]: - ctrl.transcript = "" elif event.opcode == steamvr.EVENT_FALLING_EDGE: now = time.time() @@ -700,13 +734,16 @@ def vrInputThread(ctrl: ThreadControl): if not ctrl.cfg["use_builtin"]: ctrl.pager.toggleBoard(False) + # Flush the *entire* pipeline. ctrl.stream.pause(True) ctrl.stream.getSamples() + ctrl.collector.dropAudio() ctrl.pager.clear() else: # Short hold if state == RECORD_STATE: print("PAUSED") + ctrl.transcribe_queue.append(TRANSCRIBE_REQ_WHOLE_BUFFER) state = PAUSE_STATE if not ctrl.cfg["use_builtin"]: ctrl.pager.lockWorld(True) @@ -719,6 +756,7 @@ def vrInputThread(ctrl: ThreadControl): elif state == PAUSE_STATE: print("RECORDING", file=sys.stderr) state = RECORD_STATE + ctrl.transcribe_queue.append(TRANSCRIBE_REQ_RESET_COMMITS) if not ctrl.cfg["use_builtin"]: ctrl.pager.toggleBoard(True) ctrl.pager.lockWorld(False) @@ -728,6 +766,7 @@ def vrInputThread(ctrl: ThreadControl): print("Toggle detected, dropping transcript (3)", file=sys.stderr) ctrl.transcript = "" + ctrl.preview = "" #audio_state.drop_transcription = True else: if ctrl.cfg["enable_debug_mode"]: @@ -741,9 +780,13 @@ def vrInputThread(ctrl: ThreadControl): #audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_ON) pass -def kbInputThread( - thread_ctrl): - while thread_ctrl.run_app: +def kbInputThread(ctrl: ThreadControl): + while ctrl.run_app: + time.sleep(0.01) + +def oscThread(ctrl: ThreadControl): + while ctrl.run_app: + ctrl.pager.page(ctrl.preview) time.sleep(0.01) def run(cfg): @@ -764,6 +807,8 @@ def run(cfg): ctrl.committer = committer ctrl.pager = pager ctrl.transcript = "" + ctrl.transcribe_queue = [] + ctrl.preview = "" transcribe_audio_thd = threading.Thread(target=transcriptionThread, args=[ctrl]) transcribe_audio_thd.daemon = True @@ -777,6 +822,10 @@ def run(cfg): kb_input_thd.daemon = True kb_input_thd.start() + osc_thd = threading.Thread(target=oscThread, args=[ctrl]) + osc_thd.daemon = True + osc_thd.start() + for line in sys.stdin: if "exit" in line or "quit" in line: break @@ -784,6 +833,8 @@ def run(cfg): ctrl.run_app = False transcribe_audio_thd.join() vr_input_thd.join() + kb_input_thd.join() + osc_thd.join() if __name__ == "__main__": parser = argparse.ArgumentParser() |
