summaryrefslogtreecommitdiffstats
path: root/app
diff options
context:
space:
mode:
Diffstat (limited to 'app')
-rw-r--r--app/hi.py308
-rw-r--r--app/keybind_event_machine.py21
-rw-r--r--app/requirements.txt3
-rw-r--r--app/shared_thread_data.py7
-rw-r--r--app/steamvr.py87
-rw-r--r--app/stt.py143
6 files changed, 451 insertions, 118 deletions
diff --git a/app/hi.py b/app/hi.py
index bab0fd4..1297b37 100644
--- a/app/hi.py
+++ b/app/hi.py
@@ -1,25 +1,34 @@
import app_config
import argparse
import io
+import keybind_event_machine
from math import floor, ceil
import msvcrt
import os
from pythonosc import udp_client
import sentencepiece as spm
+import steamvr
from shared_thread_data import SharedThreadData
import stt
import sys
import threading
import time
+import pygame
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
+# Initialize pygame mixer
+pygame.mixer.init()
+
TESTS_ENABLED = True
# 0 = quiet, 1 = verbose, 2 = very verbose
LOG_LEVEL = 0
+# Global volume control (0.0 to 1.0)
+VOLUME = 0.3
+
APP_ROOT = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(APP_ROOT)
@@ -315,79 +324,276 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg):
send_data(osc_client, [indices[0]], [diff_blocks[0]], [diff_visual_pointers[0]])
def osc_thread(shared_data: SharedThreadData):
- tokenizer = get_tokenizer()
osc_client = getOscClient()
- # Prime the board
- print("Priming the board")
- input_state = InputState()
- handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg)
+ def join_segments(a, b):
+ if len(a) > 0 and a[-1] != ' ':
+ return a + ' ' + b
+ else:
+ return a + b
+
+ if shared_data.cfg["use_builtin"]:
+ last_change = time.time()
+ remote_word = ""
+ while not shared_data.exit_event.is_set():
+ time.sleep(0.1)
+ local_word = ""
+ with shared_data.word_lock:
+ local_word = join_segments(shared_data.transcript,
+ shared_data.preview)
+ local_word = local_word[-140:]
+ if local_word == remote_word:
+ continue
+ if time.time() - last_change < 1.5:
+ continue
+ addr = "/chatbox/input"
+ print(f"Send {local_word}", flush=True)
+ osc_client.send_message(addr, (local_word, True, False))
+ last_change = time.time()
+ remote_word = local_word
+ else:
+ # Custom chatbox
+ tokenizer = get_tokenizer()
+
+ # Prime the board
+ print("Priming the board")
+ input_state = InputState()
+ handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg)
+
+ while not shared_data.exit_event.is_set():
+ word_copy = ""
+ with shared_data.word_lock:
+ word_copy = shared_data.word
+ handle_input(input_state, word_copy, tokenizer, osc_client, shared_data.cfg)
+ time.sleep(0.01)
+
+
+def vrInputThread(shared_data: SharedThreadData):
+ RECORD_STATE = 0
+ PAUSE_STATE = 1
+ state = PAUSE_STATE
+
+ hand_id = shared_data.cfg["button_hand"]
+ button_id = shared_data.cfg["button_type"]
+ # Rough description of state machine:
+ # Single short press: toggle transcription
+ # Medium press: dismiss custom chatbox
+ # Long press: update chatbox in place
+ # Medium press + long press: type transcription
+
+ last_rising = time.time()
+ last_medium_press_end = 0
+
+ waveform0 = os.path.join(PROJECT_ROOT, "Sounds/Noise_On_Quiet.wav")
+ waveform1 = os.path.join(PROJECT_ROOT, "Sounds/Noise_Off_Quiet.wav")
+ waveform2 = os.path.join(PROJECT_ROOT, "Sounds/Dismiss_Noise_Quiet.wav")
+ waveform3 = os.path.join(PROJECT_ROOT, "Sounds/KB_Noise_Off_Quiet.wav")
+
+ button_generator = steamvr.pollButtonPress(hand=hand_id, button=button_id,
+ shared_data=shared_data)
while not shared_data.exit_event.is_set():
- word_copy = ""
+ time.sleep(0.01)
+ try:
+ event = next(button_generator)
+ except StopIteration:
+ break
+
with shared_data.word_lock:
- word_copy = shared_data.word
- handle_input(input_state, word_copy, tokenizer, osc_client, shared_data.cfg)
+ if not shared_data.stream or not shared_data.collector:
+ continue
+
+ if event.opcode == steamvr.EVENT_RISING_EDGE:
+ last_rising = time.time()
+
+ if state == PAUSE_STATE:
+ shared_data.stream.pause(False)
+ shared_data.stream.getSamples()
+
+ elif event.opcode == steamvr.EVENT_FALLING_EDGE:
+ now = time.time()
+ if now - last_rising > 1.5:
+ # Long press: treat as the end of transcription.
+ state = PAUSE_STATE
+
+ shared_data.stream.pause(True)
+
+ if last_rising - last_medium_press_end < 1.0:
+ # Type transcription
+ if shared_data.cfg["enable_local_beep"]:
+ play_sound_with_volume(waveform3)
+ else:
+ if shared_data.cfg["enable_local_beep"]:
+ play_sound_with_volume(waveform1)
+
+ elif now - last_rising > 0.5:
+ # Medium press
+ print("CLEARING", file=sys.stderr)
+ last_medium_press_end = now
+ state = PAUSE_STATE
+
+ if shared_data.cfg["enable_local_beep"]:
+ play_sound_with_volume(waveform2)
+
+ # Flush the *entire* pipeline.
+ shared_data.stream.pause(True)
+ shared_data.stream.getSamples()
+ shared_data.collector.dropAudio()
+ shared_data.transcript = ""
+ shared_data.preview = ""
+ continue
+
+ # Short hold
+ if state == RECORD_STATE:
+ print("PAUSED", file=sys.stderr)
+ state = PAUSE_STATE
+
+ shared_data.stream.pause(True)
+
+ if shared_data.cfg["enable_local_beep"]:
+ play_sound_with_volume(waveform1)
+ elif state == PAUSE_STATE:
+ print("RECORDING", file=sys.stderr)
+ state = RECORD_STATE
+ if shared_data.cfg["reset_on_toggle"]:
+ if shared_data.cfg["enable_debug_mode"]:
+ print("Toggle detected, dropping transcript (3)",
+ file=sys.stderr)
+ shared_data.transcript = ""
+ shared_data.preview = ""
+ #audio_state.drop_transcription = True
+ else:
+ if shared_data.cfg["enable_debug_mode"]:
+ print("Toggle detected, committing preview text (3)",
+ file=sys.stderr)
+ #audio_state.text += audio_state.preview_text
+
+ shared_data.stream.pause(False)
+
+ if shared_data.cfg["enable_local_beep"]:
+ play_sound_with_volume(waveform0)
+
+
+def kbInputThread(shared_data: SharedThreadData):
+ machine = keybind_event_machine.KeybindEventMachine(shared_data.cfg["keybind"])
+ last_press_time = 0
+
+ # double pressing the keybind
+ double_press_timeout = 0.5
+
+ RECORD_STATE = 0
+ PAUSE_STATE = 1
+ state = PAUSE_STATE
+
+ waveform0 = os.path.join(PROJECT_ROOT, "Sounds/Noise_On_Quiet.wav")
+ waveform1 = os.path.join(PROJECT_ROOT, "Sounds/Noise_Off_Quiet.wav")
+ waveform2 = os.path.join(PROJECT_ROOT, "Sounds/Dismiss_Noise_Quiet.wav")
+ waveform3 = os.path.join(PROJECT_ROOT, "Sounds/KB_Noise_Off_Quiet.wav")
+
+ while not shared_data.exit_event.is_set():
time.sleep(0.01)
+ cur_press_time = machine.getNextPressTime()
+ if cur_press_time == 0:
+ continue
+
+ with shared_data.word_lock:
+ if not shared_data.stream or not shared_data.collector:
+ continue
+
+ EVENT_SINGLE_PRESS = 0
+ EVENT_DOUBLE_PRESS = 1
+ if last_press_time == 0:
+ event = EVENT_SINGLE_PRESS
+ elif cur_press_time - last_press_time < double_press_timeout:
+ event = EVENT_DOUBLE_PRESS
+ else:
+ event = EVENT_SINGLE_PRESS
+ last_press_time = cur_press_time
+
+ if event == EVENT_DOUBLE_PRESS:
+ print("CLEARING", file=sys.stderr)
+ state = PAUSE_STATE
+
+ if shared_data.cfg["enable_local_beep"]:
+ play_sound_with_volume(waveform2)
+
+ # Flush the *entire* pipeline.
+ shared_data.stream.pause(True)
+ shared_data.stream.getSamples()
+ shared_data.collector.dropAudio()
+ shared_data.transcript = ""
+ shared_data.preview = ""
+ continue
+
+ # Short hold
+ if state == RECORD_STATE:
+ print("PAUSED", file=sys.stderr)
+ state = PAUSE_STATE
+
+ shared_data.stream.pause(True)
+
+ if shared_data.cfg["enable_local_beep"]:
+ play_sound_with_volume(waveform1)
+ elif state == PAUSE_STATE:
+ print("RECORDING", file=sys.stderr)
+ state = RECORD_STATE
+ if shared_data.cfg["reset_on_toggle"]:
+ if shared_data.cfg["enable_debug_mode"]:
+ print("Toggle detected, dropping transcript (2)",
+ file=sys.stderr)
+ shared_data.transcript = ""
+ shared_data.preview = ""
+ else:
+ if shared_data.cfg["enable_debug_mode"]:
+ print("Toggle detected, committing preview text (2)",
+ file=sys.stderr)
+ #audio_state.text += audio_state.preview_text
+
+ shared_data.stream.pause(False)
+
+ if shared_data.cfg["enable_local_beep"]:
+ play_sound_with_volume(waveform0)
+
+def play_sound_with_volume(filepath):
+ """Play a WAV file with adjusted volume"""
+ volume = VOLUME
+
+ try:
+ sound = pygame.mixer.Sound(filepath)
+ sound.set_volume(volume)
+ sound.play()
+ except Exception as e:
+ print(f"Error playing sound {filepath}: {e}", file=sys.stderr)
+
if __name__ == "__main__":
cli_args = parse_args()
cfg = app_config.getConfig(cli_args.config)
shared_data = SharedThreadData(cfg)
- if False:
- osc_thread = threading.Thread(
- target=osc_thread,
- args=(shared_data,))
- osc_thread.start()
+ osc_thread = threading.Thread(
+ target=osc_thread,
+ args=(shared_data,))
+ osc_thread.start()
transcribe_thread = threading.Thread(
target=stt.transcriptionThread,
args=(shared_data,))
transcribe_thread.start()
+ vr_input_thd = threading.Thread(target=vrInputThread, args=(shared_data,))
+ vr_input_thd.start()
+
+ kb_input_thd = threading.Thread(target=kbInputThread, args=(shared_data,))
+ kb_input_thd.start()
+
word_is_over = False
local_word = ""
while True:
- char_bytes = msvcrt.getch()
- if char_bytes == b'\x03': # ctrl+C
- break
-
time.sleep(0.1)
continue
-
- try:
- char = char_bytes.decode('utf-8')
- if char == '\r' or char == '\n':
- word_is_over = True
- continue
- except UnicodeDecodeError:
- print(f"Unsupported character: {char_bytes}")
- if char_bytes == b'\x00' or char_bytes == b'\xe0':
- special_char = msvcrt.getch()
- continue
-
- if char_bytes == b'\x03': # ctrl+C
- break
- elif char_bytes == b'\x08': # backspace
- with shared_data.word_lock:
- shared_data.word = shared_data.word[:-1]
- local_word = shared_data.word
- elif char_bytes == b'\x0c': # ctrl+L
- with shared_data.word_lock:
- shared_data.word = ""
- local_word = shared_data.word
- elif word_is_over:
- with shared_data.word_lock:
- shared_data.word = char
- local_word = shared_data.word
- word_is_over = False
- else:
- with shared_data.word_lock:
- shared_data.word += char
- local_word = shared_data.word
- print(local_word + "_")
shared_data.exit_event.set()
- if False:
- osc_thread.join()
+ osc_thread.join()
transcribe_thread.join()
+ vr_input_thd.join()
+ kb_input_thd.join()
diff --git a/app/keybind_event_machine.py b/app/keybind_event_machine.py
new file mode 100644
index 0000000..3ce6794
--- /dev/null
+++ b/app/keybind_event_machine.py
@@ -0,0 +1,21 @@
+import keyboard
+import time
+
+class KeybindEventMachine:
+ def __init__(self, keybind: str):
+ self.keybind = keybind
+ self.events = []
+ keyboard.add_hotkey(keybind, self.onPress)
+
+ def onPress(self) -> None:
+ self.events.append(time.time())
+
+ # Returns the timestamp when the keybind was pressed, or 0 if no keypresses
+ # are queued.
+ def getNextPressTime(self) -> int:
+ if len(self.events) == 0:
+ return 0
+ ret = self.events[0]
+ self.events = self.events[1:]
+ return ret
+
diff --git a/app/requirements.txt b/app/requirements.txt
index f8b7069..e68a16c 100644
--- a/app/requirements.txt
+++ b/app/requirements.txt
@@ -1,8 +1,11 @@
faster-whisper
hf-xet
+keyboard
langcodes
pyaudio
+pygame
pydub
python-osc
sentencepiece
silero-vad
+openvr
diff --git a/app/shared_thread_data.py b/app/shared_thread_data.py
index ba0a419..40885e8 100644
--- a/app/shared_thread_data.py
+++ b/app/shared_thread_data.py
@@ -2,7 +2,12 @@ import threading
class SharedThreadData:
def __init__(self, cfg):
- self.word = ""
+ self.transcript = ""
+ self.preview = ""
+
+ self.stream = None
+ self.collector = None
+
self.word_lock = threading.Lock()
self.exit_event = threading.Event()
self.cfg = cfg
diff --git a/app/steamvr.py b/app/steamvr.py
new file mode 100644
index 0000000..64f34f5
--- /dev/null
+++ b/app/steamvr.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+
+import openvr as vr
+import sys
+import time
+
+EVENT_NONE = 0
+EVENT_RISING_EDGE = 1
+EVENT_FALLING_EDGE = 2
+
+class InputEvent:
+ def __init__(self,
+ opcode: int):
+ self.opcode = opcode
+
+# Checks if the given button on the given controller is pressed.
+def pollButtonPress(
+ hand: str = "right",
+ button: str = "b",
+ shared_data = None # SharedThreadData object
+ ) -> int:
+ hands = {}
+ hands["left"] = vr.TrackedControllerRole_LeftHand
+ hands["right"] = vr.TrackedControllerRole_RightHand
+
+ buttons = {}
+ buttons["a"] = vr.k_EButton_IndexController_A
+ buttons["b"] = vr.k_EButton_IndexController_B
+ buttons["thumbstick"] = vr.k_EButton_Axis0
+
+ system = None
+ first = True
+ while not shared_data.exit_event.is_set() and not system:
+ try:
+ system = vr.init(vr.VRApplication_Background)
+ except Exception as e:
+ if first:
+ print(f"Failed to start steamVR input thread: {repr(e)}", file=sys.stderr)
+ first = False
+ time.sleep(1)
+ last_packet = 0
+ event_high = False
+
+ while not shared_data.exit_event.is_set():
+ time.sleep(0.01)
+
+ lh_idx = system.getTrackedDeviceIndexForControllerRole(hands[hand])
+ #print("left hand device idx: {}".format(lh_idx))
+
+ got_state, state = system.getControllerState(lh_idx)
+ if not got_state:
+ continue
+
+ if state.unPacketNum == last_packet:
+ continue
+
+ # Clicking joysticks and moving joysticks fire the same events. To
+ # differentiate movement from clicking, we create a dead zone: if the event
+ # fires while the stick isn't moved far from center, we assume it's a
+ # click, not movement.
+ dead_zone_radius = 0.7
+
+ button_mask = (1 << buttons[button])
+ ret = EVENT_NONE
+ if (state.ulButtonPressed & button_mask) != 0 and\
+ (state.rAxis[0].x**2 + state.rAxis[0].y**2 < dead_zone_radius**2):
+ #print("button pressed: %016x" % state.ulButtonPressed)
+ #for i in range(0, 5):
+ # print("axis {} x: {} y: {}".format(i, state.rAxis[i].x, state.rAxis[i].y))
+ if not event_high:
+ yield InputEvent(EVENT_RISING_EDGE)
+ event_high = True
+ elif event_high:
+ event_high = False
+ yield InputEvent(EVENT_FALLING_EDGE)
+
+if __name__ == "__main__":
+ gen = pollButtonPress()
+ while True:
+ time.sleep(0.1)
+
+ event = pollButtonPress(session_state)
+ if event == EVENT_RISING_EDGE:
+ print("rising edge")
+ elif event == EVENT_FALLING_EDGE:
+ print("falling edge")
+
diff --git a/app/stt.py b/app/stt.py
index a3988e1..c1f4836 100644
--- a/app/stt.py
+++ b/app/stt.py
@@ -299,9 +299,11 @@ class CompressingAudioCollector(AudioCollectorFilter):
class AudioSegmenter:
def __init__(self,
min_silence_ms=250,
- max_speech_s=5):
+ max_speech_s=5,
+ min_speech_duration_ms=100):
self.min_silence_ms = min_silence_ms
self.max_speech_s = max_speech_s
+ self.min_speech_duration_ms = min_speech_duration_ms
# Load Silero VAD model
self.model = load_silero_vad()
@@ -309,6 +311,7 @@ class AudioSegmenter:
self.vad_threshold = 0.3
self.min_silence_duration_ms = min_silence_ms
self.max_speech_duration_s = max_speech_s
+ self.min_speech_duration_ms = min_speech_duration_ms
def segmentAudio(self, audio: bytes):
# Convert audio bytes to numpy array expected by silero-vad
@@ -324,6 +327,7 @@ class AudioSegmenter:
threshold=self.vad_threshold,
min_silence_duration_ms=self.min_silence_duration_ms,
max_speech_duration_s=self.max_speech_duration_s,
+ min_speech_duration_ms=self.min_speech_duration_ms,
return_seconds=False # We want frame indices, not seconds
)
@@ -698,7 +702,8 @@ def transcriptionThread(shared_data: SharedThreadData):
collector = NormalizingAudioCollector(collector)
whisper = Whisper(collector, shared_data.cfg)
segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"],
- max_speech_s=shared_data.cfg["max_speech_duration_s"])
+ max_speech_s=shared_data.cfg["max_speech_duration_s"],
+ min_speech_duration_ms=shared_data.cfg["min_speech_duration_ms"])
committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter)
plugins = []
@@ -715,6 +720,10 @@ def transcriptionThread(shared_data: SharedThreadData):
transcript = ""
preview = ""
+ with shared_data.word_lock:
+ shared_data.stream = stream
+ shared_data.collector = collector
+
print(f"Ready to go!", flush=True)
while not shared_data.exit_event.is_set():
@@ -724,70 +733,72 @@ def transcriptionThread(shared_data: SharedThreadData):
commit = committer.getDelta()
- for plugin in plugins:
- commit = plugin.transform(commit)
-
- if len(commit.delta) > 0 or len(commit.preview) > 0:
- # Avoid re-sending text after long pauses
- if shared_data.cfg["reset_after_silence_s"] > 0:
- silence_duration = 0
- if last_stable_commit:
- last_commit_end_ts = \
- last_stable_commit.start_ts + \
- last_stable_commit.duration_s
- silence_duration = commit.start_ts - last_commit_end_ts
- if silence_duration > shared_data.cfg["reset_after_silence_s"]:
- if shared_data.cfg["enable_debug_mode"]:
- print(f"Resetting transcript after {silence_duration}-second "
- "silence", file=sys.stderr)
- transcript = ""
- preview = ""
- whisper.recent_context = "" # Reset context too
- if commit.delta:
- last_stable_commit = commit
-
- # Hard-cap displayed transcript length at 4k characters to prevent
- # runaway memory use in UI. Keep the full transcript to avoid
- # breaking OSC pager.
- transcript = transcript[-4096:]
- def join_segments(a, b):
- if len(a) > 0 and a[-1] != ' ':
- return a + ' ' + b
- else:
- return a + b
- transcript = join_segments(transcript, commit.delta)
- preview = commit.preview
-
- for filt in filters:
- transcript, preview = filt.transform(transcript, preview)
-
- try:
- print(f"Transcript: {transcript}", flush=True)
- except UnicodeEncodeError:
- print("Failed to encode transcript - discarding delta",
- file=sys.stderr)
- continue
- try:
- print(f"Preview: {preview}", flush=True)
- except UnicodeEncodeError:
- print("Failed to encode preview - discarding", file=sys.stderr)
-
- with shared_data.word_lock:
- shared_data.word = join_segments(transcript, preview)
-
- if shared_data.cfg["enable_debug_mode"]:
- print(f"commit latency: {commit.latency_s}", file=sys.stderr)
- print(f"commit thresh: {commit.thresh_at_commit}",
- file=sys.stderr)
-
- if len(transcript) > 0 and \
- (not transcript.endswith(' ')) and \
- (not commit.delta.startswith(' ')):
- commit.delta = ' ' + commit.delta
- if len(commit.delta) > 0 and \
- (not commit.delta.endswith(' ')) and \
- (not commit.preview.startswith(' ')):
- commit.preview = ' ' + commit.preview
+ with shared_data.word_lock:
+ for plugin in plugins:
+ commit = plugin.transform(commit)
+
+ if len(commit.delta) > 0 or len(commit.preview) > 0:
+ # Avoid re-sending text after long pauses
+ if shared_data.cfg["reset_after_silence_s"] > 0:
+ silence_duration = 0
+ if last_stable_commit:
+ last_commit_end_ts = \
+ last_stable_commit.start_ts + \
+ last_stable_commit.duration_s
+ silence_duration = commit.start_ts - last_commit_end_ts
+ if silence_duration > shared_data.cfg["reset_after_silence_s"]:
+ if shared_data.cfg["enable_debug_mode"]:
+ print(f"Resetting transcript after {silence_duration}-second "
+ "silence", file=sys.stderr)
+ shared_data.transcript = ""
+ shared_data.preview = ""
+ whisper.recent_context = "" # Reset context too
+ if commit.delta:
+ last_stable_commit = commit
+
+ # Hard-cap displayed transcript length to prevent
+ # runaway memory use in UI. Keep the full transcript to avoid
+ # breaking OSC pager.
+ if len(shared_data.transcript) >= 1024:
+ shared_data.transcript = shared_data.transcript[-512:]
+ def join_segments(a, b):
+ if len(a) > 0 and a[-1] != ' ':
+ return a + ' ' + b
+ else:
+ return a + b
+ shared_data.transcript = \
+ join_segments(shared_data.transcript, commit.delta)
+ shared_data.preview = commit.preview
+
+ for filt in filters:
+ shared_data.transcript, shared_data.preview = \
+ filt.transform(shared_data.transcript,
+ shared_data.preview)
+
+ try:
+ print(f"Transcript: {shared_data.transcript}", flush=True)
+ except UnicodeEncodeError:
+ print("Failed to encode transcript - discarding delta",
+ file=sys.stderr)
+ continue
+ try:
+ print(f"Preview: {shared_data.preview}", flush=True)
+ except UnicodeEncodeError:
+ print("Failed to encode preview - discarding", file=sys.stderr)
+
+ if shared_data.cfg["enable_debug_mode"]:
+ print(f"commit latency: {commit.latency_s}", file=sys.stderr)
+ print(f"commit thresh: {commit.thresh_at_commit}",
+ file=sys.stderr)
+
+ if len(shared_data.transcript) > 0 and \
+ (not shared_data.transcript.endswith(' ')) and \
+ (not commit.delta.startswith(' ')):
+ commit.delta = ' ' + commit.delta
+ if len(commit.delta) > 0 and \
+ (not commit.delta.endswith(' ')) and \
+ (not commit.preview.startswith(' ')):
+ commit.preview = ' ' + commit.preview
for plugin in plugins:
plugin.stop()
for filt in filters: