diff options
| field | value | date |
|---|---|---|
| author | yum <yum.food.vr@gmail.com> | 2023-09-03 18:27:45 -0700 |
| committer | yum <yum.food.vr@gmail.com> | 2023-09-03 18:27:45 -0700 |
| commit | ae5db8b21e7db2ab9941cca47a5d57352d3bb1fa (patch) | |
| tree | c2c4ea74158c9b9c76db519819957f1317af59c4 /Scripts/transcribe_v2.py | |
| parent | 2a4c6051acd8140bde6c1abad62bd613673de4b4 (diff) | |
Add threads to transcribe_v2.py
The script now runs four threads:
* Main thread
* Transcription (mic -> collector -> whisper -> committer -> pager)
* VR input
* Keyboard input
Also:
* add OscPager class to encapsulate all OSC interactions.
* bump `last_n_must_match` from 2 to 3 to reduce hallucinations
Diffstat (limited to 'Scripts/transcribe_v2.py')
| -rw-r--r-- | Scripts/transcribe_v2.py | 232 |
1 file changed, 213 insertions, 19 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 1904526..3f924dd 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -4,6 +4,7 @@ from functools import partial from pydub import AudioSegment from whisper.normalizers import EnglishTextNormalizer from scipy.optimize import minimize +from emotes_v2 import EmotesState import app_config import argparse @@ -12,10 +13,19 @@ import langcodes import math import numpy as np import os +import osc_ctrl import pyaudio +import steamvr +import sys +import threading import time import typing +class ThreadControl: + def __init__(self, cfg): + self.cfg = cfg + self.run_app = True + class AudioStream(): FORMAT = pyaudio.paInt16 # Size of each frame (audio sample), in bytes. If you change FORMAT, make @@ -73,6 +83,8 @@ class MicStream(AudioStream): # as possible. Whenever downstream layers want data, we collapse the # list into a single array of data (a bytes object). self.chunks = [] + # If set, incoming frames are simply discarded. + self.paused = False print(f"Finding mic {which_mic}") self.dumpMicDevices() @@ -123,6 +135,9 @@ class MicStream(AudioStream): AudioStream.__init__(self) + def pause(self, state: bool = True): + self.paused = state + def dumpMicDevices(self): info = self.p.get_host_api_info_by_index(0) numdevices = info.get('deviceCount') @@ -138,6 +153,15 @@ class MicStream(AudioStream): frame_count, time_info, status_flags): + if self.paused: + # Don't literally pause, just start returning silence. This allows + # the `min_segment_age_s` check to work while paused. + n_frames = int(frame_count * AudioStream.FPS / + float(self.sample_rate)) + self.chunks.append(np.zeros(n_frames, + dtype=np.int16).tobytes()) + return (frames, pyaudio.paContinue) + decimated = b'' # In pyaudio, a `frame` is a single sample of audio data. 
frame_len = AudioStream.FRAME_SZ @@ -367,7 +391,7 @@ class FuzzyRepeatCommitter: def __init__(self, collector: AudioCollector, whisper: Whisper, - last_n_must_match: int = 2, + last_n_must_match: int = 3, edit_thresh_min: float = 1, edit_thresh_grow_begin_s: float = 1.5, edit_thresh_grow_halflife_s: float = 0.5, @@ -443,10 +467,42 @@ class FuzzyRepeatCommitter: return TranscriptCommit(candidate.transcript, preview, latency_s, thresh_at_commit = edit_thresh) +class OscPager: + def __init__(self, cfg): + self.osc_state = osc_ctrl.OscState(cfg["chars_per_sync"], + cfg["rows"], + cfg["cols"], + cfg["bytes_per_char"]) + self.cfg = cfg + self.next_sync_window = time.time() + + def page(self, text): + now = time.time() + if now < self.next_sync_window: + return + if self.cfg["use_builtin"]: + osc_ctrl.pageMessageBuiltin(self.osc_state, text) + self.next_sync_window = now + 1.5 + else: + osc_ctrl.pageMessage(self.osc_state, text, EmotesState()) + self.next_sync_window = now + osc_ctrl.SYNC_DELAY_S + + def clear(self): + osc_ctrl.clear(self.osc_state) + + def toggleBoard(self, state: bool): + osc_ctrl.toggleBoard(self.osc_state.client, state) + + def lockWorld(self, state): + osc_ctrl.lockWorld(self.osc_state.client, state) + + def ellipsis(self, state): + osc_ctrl.ellipsis(self.osc_state.client, state) + def evaluate(cfg, audio_path: str, control_path: str, - last_n_must_match: int = 2, + last_n_must_match: int = 3, edit_thresh_min: float = 1, edit_thresh_grow_begin_s: float = 1.5, edit_thresh_grow_halflife_s: float = 0.5, @@ -486,6 +542,7 @@ def evaluate(cfg, commits.append(commit) transcript += commit.delta + preview = commit.preview if False and len(commit.delta): print(f"transcript: {transcript}") @@ -562,6 +619,130 @@ def optimize(cfg, return optimized_params +def transcriptionThread(ctrl: ThreadControl): + while ctrl.run_app: + commit = ctrl.committer.getDelta() + + ctrl.pager.page(ctrl.transcript + commit.preview) + ctrl.transcript += commit.delta + + if 
len(commit.delta): + print(f"{ctrl.transcript}") + if cfg["enable_debug_mode"]: + print(f"commit latency: {commit.latency_s}", file=sys.stderr) + print(f"commit thresh: {commit.thresh_at_commit}", file=sys.stderr) + +def vrInputThread(ctrl: ThreadControl): + RECORD_STATE = 0 + PAUSE_STATE = 1 + state = PAUSE_STATE + + hand_id = ctrl.cfg["button"].split()[0] + button_id = ctrl.cfg["button"].split()[1] + + # Rough description of state machine: + # Single short press: toggle transcription + # Medium press: dismiss custom chatbox + # Long press: update chatbox in place + # Medium press + long press: type transcription + + last_rising = time.time() + last_medium_press_end = 0 + + button_generator = steamvr.pollButtonPress(hand=hand_id, button=button_id) + while ctrl.run_app: + time.sleep(0.01) + event = next(button_generator) + + if event.opcode == steamvr.EVENT_RISING_EDGE: + last_rising = time.time() + + if state == PAUSE_STATE: + ctrl.stream.pause(False) + ctrl.stream.getSamples() + ctrl.pager.clear() + if ctrl.cfg["reset_on_toggle"]: + ctrl.transcript = "" + + elif event.opcode == steamvr.EVENT_FALLING_EDGE: + now = time.time() + if now - last_rising > 1.5: + # Long press: treat as the end of transcription. 
+ state = PAUSE_STATE + if not ctrl.cfg["use_builtin"]: + ctrl.pager.lockWorld(True) + + ctrl.stream.pause(True) + + if last_rising - last_medium_press_end < 1.0: + # Type transcription + if ctrl.cfg["enable_local_beep"]: + #audio_state.audio_events.append(audio_state.AUDIO_EVENT_UPDATE) + pass + #keyboard.write(audio_state.filtered_text) + else: + if ctrl.cfg["enable_local_beep"]: + #audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_OFF) + pass + + elif now - last_rising > 0.5: + # Medium press + print("CLEARING") + last_medium_press_end = now + state = PAUSE_STATE + + if ctrl.cfg["enable_local_beep"]: + #audio_state.audio_events.append(audio_state.AUDIO_EVENT_DISMISS) + pass + + if not ctrl.cfg["use_builtin"]: + ctrl.pager.toggleBoard(False) + + ctrl.stream.pause(True) + ctrl.stream.getSamples() + ctrl.pager.clear() + else: + # Short hold + if state == RECORD_STATE: + print("PAUSED") + state = PAUSE_STATE + if not ctrl.cfg["use_builtin"]: + ctrl.pager.lockWorld(True) + + ctrl.stream.pause(True) + + if ctrl.cfg["enable_local_beep"]: + #audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_OFF) + pass + elif state == PAUSE_STATE: + print("RECORDING") + state = RECORD_STATE + if not ctrl.cfg["use_builtin"]: + ctrl.pager.toggleBoard(True) + ctrl.pager.lockWorld(False) + ctrl.pager.ellipsis(True) + if ctrl.cfg["reset_on_toggle"]: + if ctrl.cfg["enable_debug_mode"]: + print("Toggle detected, dropping transcript (3)") + ctrl.transcript = "" + #audio_state.drop_transcription = True + else: + if ctrl.cfg["enable_debug_mode"]: + print("Toggle detected, committing preview text (3)") + #audio_state.text += audio_state.preview_text + + ctrl.stream.pause(False) + ctrl.pager.clear() + + if ctrl.cfg["enable_local_beep"]: + #audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_ON) + pass + +def kbInputThread( + thread_ctrl): + while thread_ctrl.run_app: + time.sleep(0.01) + def run(cfg): stream = MicStream(cfg["microphone"]) @@ -569,24 +750,37 @@ 
def run(cfg): #collector = LengthEnforcingAudioCollector(collector, 5.0) #collector = NormalizingAudioCollector(collector) collector = CompressingAudioCollector(collector) - whisper = Whisper(collector, cfg) - com = FuzzyRepeatCommitter(collector, whisper) - transcript = "" - commits = [] - - while True: - commit = com.getDelta() - - if len(commit.delta) > 0: - commits.append(commit) - - transcript += commit.delta - - if True and len(commit.delta): - print(f"{transcript}") - print(f"commit latency: {commit.latency_s}", file=sys.stderr) - print(f"commit thresh: {commit.thresh_at_commit}", file=sys.stderr) + committer = FuzzyRepeatCommitter(collector, whisper) + pager = OscPager(cfg) + + ctrl = ThreadControl(cfg) + ctrl.stream = stream + ctrl.collector = collector + ctrl.whisper = whisper + ctrl.committer = committer + ctrl.pager = pager + ctrl.transcript = "" + + transcribe_audio_thd = threading.Thread(target=transcriptionThread, args=[ctrl]) + transcribe_audio_thd.daemon = True + transcribe_audio_thd.start() + + vr_input_thd = threading.Thread(target=vrInputThread, args=[ctrl]) + vr_input_thd.daemon = True + vr_input_thd.start() + + kb_input_thd = threading.Thread(target=kbInputThread, args=[ctrl]) + kb_input_thd.daemon = True + kb_input_thd.start() + + for line in sys.stdin: + if "exit" in line or "quit" in line: + break + + ctrl.run_app = False + transcribe_audio_thd.join() + vr_input_thd.join() if __name__ == "__main__": parser = argparse.ArgumentParser() |
