diff options
Diffstat (limited to 'hi.py')
| -rw-r--r-- | hi.py | 77 |
1 files changed, 52 insertions, 25 deletions
@@ -1,8 +1,11 @@ +import app_config import argparse from math import floor, ceil import msvcrt from pythonosc import udp_client import sentencepiece as spm +from shared_thread_data import SharedThreadData +import stt import sys import threading import time @@ -22,10 +25,7 @@ def get_tokenizer(): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--block-width", type=int, default=2, help="Number of elements sent in one block") - parser.add_argument("--num-blocks", type=int, default=40, help="Number of blocks in animator") - parser.add_argument("--rows", type=int, default=8, help="Number of rows in the chatbox") - parser.add_argument("--cols", type=int, default=20, help="Number of columns in the chatbox") + parser.add_argument("--config", type=str, help="Path to config file (YAML).", required=True) return parser.parse_args() def assert_equal(a, b): @@ -244,27 +244,48 @@ def getOscClient(ip = "127.0.0.1", port = 9000): class InputState: def __init__(self): + self.page = 0 # Initialize the known state of the board to empty array. This will cause # our paging logic to re-send everything the first time around. self.blocks = [] self.visual_pointers = [] pass -def handle_input(state: InputState, line: str, tokenizer, osc_client, args): - line_wrapped = wrap_line(line, args.cols) +def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg): + line_wrapped = wrap_line(line, cfg["cols"]) if TESTS_ENABLED: for line in line_wrapped: - assert_equal(len(line), args.cols) + assert_equal(len(line), cfg["cols"]) if LOG_LEVEL == 2: print(f"Wrapped lines: {line_wrapped}") + + # Get several blank lines whenever we roll over. + # It's better for the reader to have some continuity when the board pages + # over. If we simply replaced the entire screen, it would be harder to + # understand. + line_rollover = cfg["rows"] - 2 + blank_line = ' ' * cfg["cols"] + # We show a full page, then only `line_rollover` additional lines per page. + end_ptr = cfg["rows"] + which_page = 0 + while end_ptr < len(line_wrapped): + end_ptr += line_rollover + which_page += 1 + if state.page != which_page: + state.blocks = [] + state.visual_pointers = [] + state.page = which_page + line_wrapped = line_wrapped[end_ptr-cfg["rows"]:] + + # Get blocks and visual pointers. blocks, visual_pointers = get_blocks(line_wrapped, tokenizer, - args.block_width, args.num_blocks) - # Wrap visual pointers. This is mostly done just to stay within our - # limited 8-bit precision budget. The shader also has to wrap in order - # to properly display things. - visual_pointers = [ ptr % (args.rows * args.cols) for ptr in visual_pointers ] + cfg["block_width"], cfg["num_blocks"]) + + # Note that because we only send one page of data at a time, we don't have + # to worry about wrapping visual pointers! We will basically never run out + # of space. indices, diff_blocks, diff_visual_pointers = calc_diff(state.blocks, state.visual_pointers, blocks, visual_pointers) - indices = [idx % args.num_blocks for idx in indices] + indices = [idx % cfg["num_blocks"] for idx in indices] # Send only one block at a time to make things snappier in interactive use # case. # TODO use a continuation (yield) instead of returning. Then we can be a @@ -285,41 +306,46 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, args): send_data(osc_client, [indices[0]], [diff_blocks[0]], [diff_visual_pointers[0]]) -class SharedThreadData: - def __init__(self): - self.word = "" - self.word_lock = threading.Lock() - self.exit_event = threading.Event() - -def osc_thread(shared_data: SharedThreadData, cli_args): +def osc_thread(shared_data: SharedThreadData): tokenizer = get_tokenizer() osc_client = getOscClient() # Prime the board print("Priming the board") input_state = InputState() - handle_input(input_state, "", tokenizer, osc_client, cli_args) + handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg) while not shared_data.exit_event.is_set(): word_copy = "" with shared_data.word_lock: word_copy = shared_data.word - handle_input(input_state, word_copy, tokenizer, osc_client, cli_args) + handle_input(input_state, word_copy, tokenizer, osc_client, shared_data.cfg) time.sleep(0.01) if __name__ == "__main__": cli_args = parse_args() - - shared_data = SharedThreadData() + cfg = app_config.getConfig(cli_args.config) + shared_data = SharedThreadData(cfg) osc_thread = threading.Thread( target=osc_thread, - args=(shared_data, cli_args)) + args=(shared_data,)) osc_thread.start() + transcribe_thread = threading.Thread( + target=stt.transcriptionThread, + args=(shared_data,)) + transcribe_thread.start() + word_is_over = False local_word = "" while True: char_bytes = msvcrt.getch() + if char_bytes == b'\x03': # ctrl+C + break + + time.sleep(0.1) + continue + try: char = char_bytes.decode('utf-8') @@ -354,4 +380,5 @@ if __name__ == "__main__": print(local_word + "_") shared_data.exit_event.set() osc_thread.join() + transcribe_thread.join() |
