summaryrefslogtreecommitdiffstats
path: root/Scripts/transcribe_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'Scripts/transcribe_v2.py')
-rw-r--r--Scripts/transcribe_v2.py232
1 files changed, 213 insertions, 19 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 1904526..3f924dd 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -4,6 +4,7 @@ from functools import partial
from pydub import AudioSegment
from whisper.normalizers import EnglishTextNormalizer
from scipy.optimize import minimize
+from emotes_v2 import EmotesState
import app_config
import argparse
@@ -12,10 +13,19 @@ import langcodes
import math
import numpy as np
import os
+import osc_ctrl
import pyaudio
+import steamvr
+import sys
+import threading
import time
import typing
+class ThreadControl:
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self.run_app = True
+
class AudioStream():
FORMAT = pyaudio.paInt16
# Size of each frame (audio sample), in bytes. If you change FORMAT, make
@@ -73,6 +83,8 @@ class MicStream(AudioStream):
# as possible. Whenever downstream layers want data, we collapse the
# list into a single array of data (a bytes object).
self.chunks = []
+ # If set, incoming frames are simply discarded.
+ self.paused = False
print(f"Finding mic {which_mic}")
self.dumpMicDevices()
@@ -123,6 +135,9 @@ class MicStream(AudioStream):
AudioStream.__init__(self)
+ def pause(self, state: bool = True):
+ self.paused = state
+
def dumpMicDevices(self):
info = self.p.get_host_api_info_by_index(0)
numdevices = info.get('deviceCount')
@@ -138,6 +153,15 @@ class MicStream(AudioStream):
frame_count,
time_info,
status_flags):
+ if self.paused:
+ # Don't literally pause, just start returning silence. This allows
+ # the `min_segment_age_s` check to work while paused.
+ n_frames = int(frame_count * AudioStream.FPS /
+ float(self.sample_rate))
+ self.chunks.append(np.zeros(n_frames,
+ dtype=np.int16).tobytes())
+ return (frames, pyaudio.paContinue)
+
decimated = b''
# In pyaudio, a `frame` is a single sample of audio data.
frame_len = AudioStream.FRAME_SZ
@@ -367,7 +391,7 @@ class FuzzyRepeatCommitter:
def __init__(self,
collector: AudioCollector,
whisper: Whisper,
- last_n_must_match: int = 2,
+ last_n_must_match: int = 3,
edit_thresh_min: float = 1,
edit_thresh_grow_begin_s: float = 1.5,
edit_thresh_grow_halflife_s: float = 0.5,
@@ -443,10 +467,42 @@ class FuzzyRepeatCommitter:
return TranscriptCommit(candidate.transcript, preview, latency_s,
thresh_at_commit = edit_thresh)
+class OscPager:
+ def __init__(self, cfg):
+ self.osc_state = osc_ctrl.OscState(cfg["chars_per_sync"],
+ cfg["rows"],
+ cfg["cols"],
+ cfg["bytes_per_char"])
+ self.cfg = cfg
+ self.next_sync_window = time.time()
+
+ def page(self, text):
+ now = time.time()
+ if now < self.next_sync_window:
+ return
+ if self.cfg["use_builtin"]:
+ osc_ctrl.pageMessageBuiltin(self.osc_state, text)
+ self.next_sync_window = now + 1.5
+ else:
+ osc_ctrl.pageMessage(self.osc_state, text, EmotesState())
+ self.next_sync_window = now + osc_ctrl.SYNC_DELAY_S
+
+ def clear(self):
+ osc_ctrl.clear(self.osc_state)
+
+ def toggleBoard(self, state: bool):
+ osc_ctrl.toggleBoard(self.osc_state.client, state)
+
+ def lockWorld(self, state):
+ osc_ctrl.lockWorld(self.osc_state.client, state)
+
+ def ellipsis(self, state):
+ osc_ctrl.ellipsis(self.osc_state.client, state)
+
def evaluate(cfg,
audio_path: str,
control_path: str,
- last_n_must_match: int = 2,
+ last_n_must_match: int = 3,
edit_thresh_min: float = 1,
edit_thresh_grow_begin_s: float = 1.5,
edit_thresh_grow_halflife_s: float = 0.5,
@@ -486,6 +542,7 @@ def evaluate(cfg,
commits.append(commit)
transcript += commit.delta
+ preview = commit.preview
if False and len(commit.delta):
print(f"transcript: {transcript}")
@@ -562,6 +619,130 @@ def optimize(cfg,
return optimized_params
+def transcriptionThread(ctrl: ThreadControl):
+ while ctrl.run_app:
+ commit = ctrl.committer.getDelta()
+
+ ctrl.pager.page(ctrl.transcript + commit.preview)
+ ctrl.transcript += commit.delta
+
+ if len(commit.delta):
+ print(f"{ctrl.transcript}")
+ if cfg["enable_debug_mode"]:
+ print(f"commit latency: {commit.latency_s}", file=sys.stderr)
+ print(f"commit thresh: {commit.thresh_at_commit}", file=sys.stderr)
+
+def vrInputThread(ctrl: ThreadControl):
+ RECORD_STATE = 0
+ PAUSE_STATE = 1
+ state = PAUSE_STATE
+
+ hand_id = ctrl.cfg["button"].split()[0]
+ button_id = ctrl.cfg["button"].split()[1]
+
+ # Rough description of state machine:
+ # Single short press: toggle transcription
+ # Medium press: dismiss custom chatbox
+ # Long press: update chatbox in place
+ # Medium press + long press: type transcription
+
+ last_rising = time.time()
+ last_medium_press_end = 0
+
+ button_generator = steamvr.pollButtonPress(hand=hand_id, button=button_id)
+ while ctrl.run_app:
+ time.sleep(0.01)
+ event = next(button_generator)
+
+ if event.opcode == steamvr.EVENT_RISING_EDGE:
+ last_rising = time.time()
+
+ if state == PAUSE_STATE:
+ ctrl.stream.pause(False)
+ ctrl.stream.getSamples()
+ ctrl.pager.clear()
+ if ctrl.cfg["reset_on_toggle"]:
+ ctrl.transcript = ""
+
+ elif event.opcode == steamvr.EVENT_FALLING_EDGE:
+ now = time.time()
+ if now - last_rising > 1.5:
+ # Long press: treat as the end of transcription.
+ state = PAUSE_STATE
+ if not ctrl.cfg["use_builtin"]:
+ ctrl.pager.lockWorld(True)
+
+ ctrl.stream.pause(True)
+
+ if last_rising - last_medium_press_end < 1.0:
+ # Type transcription
+ if ctrl.cfg["enable_local_beep"]:
+ #audio_state.audio_events.append(audio_state.AUDIO_EVENT_UPDATE)
+ pass
+ #keyboard.write(audio_state.filtered_text)
+ else:
+ if ctrl.cfg["enable_local_beep"]:
+ #audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_OFF)
+ pass
+
+ elif now - last_rising > 0.5:
+ # Medium press
+ print("CLEARING")
+ last_medium_press_end = now
+ state = PAUSE_STATE
+
+ if ctrl.cfg["enable_local_beep"]:
+ #audio_state.audio_events.append(audio_state.AUDIO_EVENT_DISMISS)
+ pass
+
+ if not ctrl.cfg["use_builtin"]:
+ ctrl.pager.toggleBoard(False)
+
+ ctrl.stream.pause(True)
+ ctrl.stream.getSamples()
+ ctrl.pager.clear()
+ else:
+ # Short hold
+ if state == RECORD_STATE:
+ print("PAUSED")
+ state = PAUSE_STATE
+ if not ctrl.cfg["use_builtin"]:
+ ctrl.pager.lockWorld(True)
+
+ ctrl.stream.pause(True)
+
+ if ctrl.cfg["enable_local_beep"]:
+ #audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_OFF)
+ pass
+ elif state == PAUSE_STATE:
+ print("RECORDING")
+ state = RECORD_STATE
+ if not ctrl.cfg["use_builtin"]:
+ ctrl.pager.toggleBoard(True)
+ ctrl.pager.lockWorld(False)
+ ctrl.pager.ellipsis(True)
+ if ctrl.cfg["reset_on_toggle"]:
+ if ctrl.cfg["enable_debug_mode"]:
+ print("Toggle detected, dropping transcript (3)")
+ ctrl.transcript = ""
+ #audio_state.drop_transcription = True
+ else:
+ if ctrl.cfg["enable_debug_mode"]:
+ print("Toggle detected, committing preview text (3)")
+ #audio_state.text += audio_state.preview_text
+
+ ctrl.stream.pause(False)
+ ctrl.pager.clear()
+
+ if ctrl.cfg["enable_local_beep"]:
+ #audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_ON)
+ pass
+
+def kbInputThread(
+ thread_ctrl):
+ while thread_ctrl.run_app:
+ time.sleep(0.01)
+
def run(cfg):
stream = MicStream(cfg["microphone"])
@@ -569,24 +750,37 @@ def run(cfg):
#collector = LengthEnforcingAudioCollector(collector, 5.0)
#collector = NormalizingAudioCollector(collector)
collector = CompressingAudioCollector(collector)
-
whisper = Whisper(collector, cfg)
- com = FuzzyRepeatCommitter(collector, whisper)
- transcript = ""
- commits = []
-
- while True:
- commit = com.getDelta()
-
- if len(commit.delta) > 0:
- commits.append(commit)
-
- transcript += commit.delta
-
- if True and len(commit.delta):
- print(f"{transcript}")
- print(f"commit latency: {commit.latency_s}", file=sys.stderr)
- print(f"commit thresh: {commit.thresh_at_commit}", file=sys.stderr)
+ committer = FuzzyRepeatCommitter(collector, whisper)
+ pager = OscPager(cfg)
+
+ ctrl = ThreadControl(cfg)
+ ctrl.stream = stream
+ ctrl.collector = collector
+ ctrl.whisper = whisper
+ ctrl.committer = committer
+ ctrl.pager = pager
+ ctrl.transcript = ""
+
+ transcribe_audio_thd = threading.Thread(target=transcriptionThread, args=[ctrl])
+ transcribe_audio_thd.daemon = True
+ transcribe_audio_thd.start()
+
+ vr_input_thd = threading.Thread(target=vrInputThread, args=[ctrl])
+ vr_input_thd.daemon = True
+ vr_input_thd.start()
+
+ kb_input_thd = threading.Thread(target=kbInputThread, args=[ctrl])
+ kb_input_thd.daemon = True
+ kb_input_thd.start()
+
+ for line in sys.stdin:
+ if "exit" in line or "quit" in line:
+ break
+
+ ctrl.run_app = False
+ transcribe_audio_thd.join()
+ vr_input_thd.join()
if __name__ == "__main__":
parser = argparse.ArgumentParser()