summary refs log tree commit diff stats
path: root/app/hi.py
diff options
context:
space:
mode:
Diffstat (limited to 'app/hi.py')
-rw-r--r-- app/hi.py 65
1 files changed, 29 insertions, 36 deletions
diff --git a/app/hi.py b/app/hi.py
index bb09418..7ea4616 100644
--- a/app/hi.py
+++ b/app/hi.py
@@ -2,9 +2,11 @@ import app_config
import argparse
import io
import keybind_event_machine
+from logger import log, log_err
from math import floor, ceil
import msvcrt
import os
+import pygame
from pythonosc import udp_client
import sentencepiece as spm
import steamvr
@@ -13,10 +15,6 @@ import stt
import sys
import threading
import time
-import pygame
-
-sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
-sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
# Initialize pygame mixer
pygame.mixer.init()
@@ -31,10 +29,10 @@ PROJECT_ROOT = os.path.dirname(APP_ROOT)
def get_tokenizer():
model_path = os.path.join(PROJECT_ROOT, "custom_unigram_tokenizer_65k", "unigram.model")
- print(f"Loading SentencePiece tokenizer from: {model_path}")
+ log(f"Loading SentencePiece tokenizer from: {model_path}")
sp = spm.SentencePieceProcessor()
sp.load(model_path)
- print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}")
+ log(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}")
return sp
def parse_args():
@@ -137,16 +135,16 @@ def wrap_line(line: str, cols):
def get_blocks(lines, tokenizer, block_width, num_blocks):
if LOG_LEVEL == 2:
- print(f"Lines sent to tokenizer: {''.join(lines)}")
+ log(f"Lines sent to tokenizer: {''.join(lines)}")
tokens = tokenizer.encode_as_ids(''.join(lines))
if LOG_LEVEL == 2:
- print(f"Tokens: {tokens}")
+ log(f"Tokens: {tokens}")
pieces = []
for tok in tokens:
piece = tokenizer.id_to_piece(tok)
pieces.append(piece)
if LOG_LEVEL == 2:
- print(f"Pieces: {pieces}")
+ log(f"Pieces: {pieces}")
# Group tokens into blocks and pad with empty characters.
# Also get visual pointers - the location where each block will be rendered.
@@ -168,8 +166,8 @@ def get_blocks(lines, tokenizer, block_width, num_blocks):
return (blocks, visual_pointers)
blocks, visual_pointers = get_blocks()
if LOG_LEVEL == 2:
- print(f"Blocks: {blocks}")
- print(f"Visual pointers: {visual_pointers}")
+ log(f"Blocks: {blocks}")
+ log(f"Visual pointers: {visual_pointers}")
# Set all blocks up to the next `num_blocks` boundary to blank tokens.
# This handles the edge case where a prior message wrote data there which
@@ -183,8 +181,8 @@ def get_blocks(lines, tokenizer, block_width, num_blocks):
return blocks, visual_pointers
blocks, visual_pointers = pad_blocks(blocks, visual_pointers)
if LOG_LEVEL == 2:
- print(f"Blocks (padded): {blocks}")
- print(f"Visual pointers (padded): {visual_pointers}")
+ log(f"Blocks (padded): {blocks}")
+ log(f"Visual pointers (padded): {visual_pointers}")
return blocks, visual_pointers
@@ -223,11 +221,10 @@ def send_data(osc_client, indices, blocks, visual_pointers):
blocks_byte00, blocks_byte01 = split_blocks_by_byte(blocks)
if LOG_LEVEL == 2:
- print(f"Blocks (byte 00): {blocks_byte00}")
- print(f"Blocks (byte 01): {blocks_byte01}")
+ log(f"Blocks (byte 00): {blocks_byte00}")
+ log(f"Blocks (byte 01): {blocks_byte01}")
def send_osc(osc_client, addr, data):
- #print(f"Sending {data} to {addr}")
osc_client.send_message(addr, data)
for i in range(0, len(blocks)):
@@ -241,7 +238,7 @@ def send_data(osc_client, indices, blocks, visual_pointers):
addr = "/avatar/parameters/" + vp_param
send_osc(osc_client, addr, vp_float)
if LOG_LEVEL == 2:
- print(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}")
+ log(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}")
for j in range(0, len(blocks[i])):
byte00_float = (-127.5 + blocks_byte00[i][j]) / 127.5
byte01_float = (-127.5 + blocks_byte01[i][j]) / 127.5
@@ -271,7 +268,7 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg):
for line in line_wrapped:
assert_equal(len(line), cfg["cols"])
if LOG_LEVEL == 2:
- print(f"Wrapped lines: {line_wrapped}")
+ log(f"Wrapped lines: {line_wrapped}")
# Get several blank lines whenever we roll over.
# It's better for the reader to have some continuity when the board pages
@@ -312,7 +309,7 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg):
state.blocks.append(diff_blocks[0])
state.visual_pointers.append(diff_visual_pointers[0])
elif indices[0] > len(state.blocks):
- print(f"This should never happen!")
+ log(f"This should never happen!")
sys.exit(1)
else:
state.blocks[indices[0]] = diff_blocks[0]
@@ -345,7 +342,7 @@ def osc_thread(shared_data: SharedThreadData):
continue
addr = "/chatbox/input"
if shared_data.cfg["enable_debug_mode"]:
- print(f"Send {local_word}", flush=True)
+ log(f"Send {local_word}")
osc_client.send_message(addr, (local_word, True, False))
last_change = time.time()
remote_word = local_word
@@ -354,7 +351,7 @@ def osc_thread(shared_data: SharedThreadData):
tokenizer = get_tokenizer()
# Prime the board
- print("Priming the board")
+ log("Priming the board")
input_state = InputState()
handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg)
@@ -424,7 +421,7 @@ def vrInputThread(shared_data: SharedThreadData):
elif now - last_rising > 0.5:
# Medium press
- print("CLEARING", file=sys.stderr)
+ log_err("CLEARING")
last_medium_press_end = now
state = PAUSE_STATE
play_sound_with_volume(waveform2, shared_data.cfg)
@@ -439,25 +436,23 @@ def vrInputThread(shared_data: SharedThreadData):
# Short hold
if state == RECORD_STATE:
- print("PAUSED", file=sys.stderr)
+ log_err("PAUSED")
state = PAUSE_STATE
shared_data.stream.pause(True)
play_sound_with_volume(waveform1, shared_data.cfg)
elif state == PAUSE_STATE:
- print("RECORDING", file=sys.stderr)
+ log_err("RECORDING")
state = RECORD_STATE
if shared_data.cfg["reset_on_toggle"]:
if shared_data.cfg["enable_debug_mode"]:
- print("Toggle detected, dropping transcript (3)",
- file=sys.stderr)
+ log_err("Toggle detected, dropping transcript (3)")
shared_data.transcript = ""
shared_data.preview = ""
#audio_state.drop_transcription = True
else:
if shared_data.cfg["enable_debug_mode"]:
- print("Toggle detected, committing preview text (3)",
- file=sys.stderr)
+ log_err("Toggle detected, committing preview text (3)")
#audio_state.text += audio_state.preview_text
shared_data.stream.pause(False)
@@ -502,7 +497,7 @@ def kbInputThread(shared_data: SharedThreadData):
last_press_time = cur_press_time
if event == EVENT_DOUBLE_PRESS:
- print("CLEARING", file=sys.stderr)
+ log_err("CLEARING")
state = PAUSE_STATE
play_sound_with_volume(waveform2, shared_data.cfg)
@@ -516,23 +511,21 @@ def kbInputThread(shared_data: SharedThreadData):
# Short hold
if state == RECORD_STATE:
- print("PAUSED", file=sys.stderr)
+ log_err("PAUSED")
state = PAUSE_STATE
shared_data.stream.pause(True)
play_sound_with_volume(waveform1, shared_data.cfg)
elif state == PAUSE_STATE:
- print("RECORDING", file=sys.stderr)
+ log_err("RECORDING")
state = RECORD_STATE
if shared_data.cfg["reset_on_toggle"]:
if shared_data.cfg["enable_debug_mode"]:
- print("Toggle detected, dropping transcript (2)",
- file=sys.stderr)
+ log_err("Toggle detected, dropping transcript (2)")
shared_data.transcript = ""
shared_data.preview = ""
else:
if shared_data.cfg["enable_debug_mode"]:
- print("Toggle detected, committing preview text (2)",
- file=sys.stderr)
+ log_err("Toggle detected, committing preview text (2)")
shared_data.stream.pause(False)
play_sound_with_volume(waveform0, shared_data.cfg)
@@ -545,7 +538,7 @@ def play_sound_with_volume(filepath, cfg):
sound.set_volume(volume * 0.01)
sound.play()
except Exception as e:
- print(f"Error playing sound {filepath}: {e}", file=sys.stderr)
+ log_err(f"Error playing sound {filepath}: {e}")
if __name__ == "__main__":
cli_args = parse_args()