summary | refs | log | tree | commit | diff | stats
path: root/hi.py
diff options
context:
space:
mode:
Diffstat (limited to 'hi.py')
-rw-r--r-- hi.py 77
1 files changed, 52 insertions, 25 deletions
diff --git a/hi.py b/hi.py
index 7c68071..0129958 100644
--- a/hi.py
+++ b/hi.py
@@ -1,8 +1,11 @@
+import app_config
import argparse
from math import floor, ceil
import msvcrt
from pythonosc import udp_client
import sentencepiece as spm
+from shared_thread_data import SharedThreadData
+import stt
import sys
import threading
import time
@@ -22,10 +25,7 @@ def get_tokenizer():
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument("--block-width", type=int, default=2, help="Number of elements sent in one block")
- parser.add_argument("--num-blocks", type=int, default=40, help="Number of blocks in animator")
- parser.add_argument("--rows", type=int, default=8, help="Number of rows in the chatbox")
- parser.add_argument("--cols", type=int, default=20, help="Number of columns in the chatbox")
+ parser.add_argument("--config", type=str, help="Path to config file (YAML).", required=True)
return parser.parse_args()
def assert_equal(a, b):
@@ -244,27 +244,48 @@ def getOscClient(ip = "127.0.0.1", port = 9000):
class InputState:
def __init__(self):
+ self.page = 0
# Initialize the known state of the board to empty array. This will cause
# our paging logic to re-send everything the first time around.
self.blocks = []
self.visual_pointers = []
pass
-def handle_input(state: InputState, line: str, tokenizer, osc_client, args):
- line_wrapped = wrap_line(line, args.cols)
+def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg):
+ line_wrapped = wrap_line(line, cfg["cols"])
if TESTS_ENABLED:
for line in line_wrapped:
- assert_equal(len(line), args.cols)
+ assert_equal(len(line), cfg["cols"])
if LOG_LEVEL == 2:
print(f"Wrapped lines: {line_wrapped}")
+
+ # Get several blank lines whenever we roll over.
+ # It's better for the reader to have some continuity when the board pages
+ # over. If we simply replaced the entire screen, it would be harder to
+ # understand.
+ line_rollover = cfg["rows"] - 2
+ blank_line = ' ' * cfg["cols"]
+ # We show a full page, then only `line_rollover` additional lines per page.
+ end_ptr = cfg["rows"]
+ which_page = 0
+ while end_ptr < len(line_wrapped):
+ end_ptr += line_rollover
+ which_page += 1
+ if state.page != which_page:
+ state.blocks = []
+ state.visual_pointers = []
+ state.page = which_page
+ line_wrapped = line_wrapped[end_ptr-cfg["rows"]:]
+
+ # Get blocks and visual pointers.
blocks, visual_pointers = get_blocks(line_wrapped, tokenizer,
- args.block_width, args.num_blocks)
- # Wrap visual pointers. This is mostly done just to stay within our
- # limited 8-bit precision budget. The shader also has to wrap in order
- # to properly display things.
- visual_pointers = [ ptr % (args.rows * args.cols) for ptr in visual_pointers ]
+ cfg["block_width"], cfg["num_blocks"])
+
+ # Note that because we only send one page of data at a time, we don't have
+ # to worry about wrapping visual pointers! We will basically never run out
+ # of space.
indices, diff_blocks, diff_visual_pointers = calc_diff(state.blocks, state.visual_pointers, blocks, visual_pointers)
- indices = [idx % args.num_blocks for idx in indices]
+ indices = [idx % cfg["num_blocks"] for idx in indices]
# Send only one block at a time to make things snappier in interactive use
# case.
# TODO use a continuation (yield) instead of returning. Then we can be a
@@ -285,41 +306,46 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, args):
send_data(osc_client, [indices[0]], [diff_blocks[0]], [diff_visual_pointers[0]])
-class SharedThreadData:
- def __init__(self):
- self.word = ""
- self.word_lock = threading.Lock()
- self.exit_event = threading.Event()
-
-def osc_thread(shared_data: SharedThreadData, cli_args):
+def osc_thread(shared_data: SharedThreadData):
tokenizer = get_tokenizer()
osc_client = getOscClient()
# Prime the board
print("Priming the board")
input_state = InputState()
- handle_input(input_state, "", tokenizer, osc_client, cli_args)
+ handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg)
while not shared_data.exit_event.is_set():
word_copy = ""
with shared_data.word_lock:
word_copy = shared_data.word
- handle_input(input_state, word_copy, tokenizer, osc_client, cli_args)
+ handle_input(input_state, word_copy, tokenizer, osc_client, shared_data.cfg)
time.sleep(0.01)
if __name__ == "__main__":
cli_args = parse_args()
-
- shared_data = SharedThreadData()
+ cfg = app_config.getConfig(cli_args.config)
+ shared_data = SharedThreadData(cfg)
osc_thread = threading.Thread(
target=osc_thread,
- args=(shared_data, cli_args))
+ args=(shared_data,))
osc_thread.start()
+ transcribe_thread = threading.Thread(
+ target=stt.transcriptionThread,
+ args=(shared_data,))
+ transcribe_thread.start()
+
word_is_over = False
local_word = ""
while True:
char_bytes = msvcrt.getch()
+ if char_bytes == b'\x03': # ctrl+C
+ break
+
+ time.sleep(0.1)
+ continue
+
try:
char = char_bytes.decode('utf-8')
@@ -354,4 +380,5 @@ if __name__ == "__main__":
print(local_word + "_")
shared_data.exit_event.set()
osc_thread.join()
+ transcribe_thread.join()