summaryrefslogtreecommitdiffstats
path: root/hi.py
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2025-05-29 15:03:06 -0700
committeryum <yum.food.vr@gmail.com>2025-05-29 15:03:06 -0700
commit0ebc79354ace812731a5c9a0a670cecd1ea941d7 (patch)
tree10a83d8761f365a320919d219b4a6f653db31c4d /hi.py
parentf8e95c0b85288a10f435e0edabf43defa0c303ac (diff)
Move core app logic into folder
Diffstat (limited to 'hi.py')
-rw-r--r--hi.py384
1 files changed, 0 insertions, 384 deletions
diff --git a/hi.py b/hi.py
deleted file mode 100644
index 0129958..0000000
--- a/hi.py
+++ /dev/null
@@ -1,384 +0,0 @@
-import app_config
-import argparse
-from math import floor, ceil
-import msvcrt
-from pythonosc import udp_client
-import sentencepiece as spm
-from shared_thread_data import SharedThreadData
-import stt
-import sys
-import threading
-import time
-
-TESTS_ENABLED = True
-
-# 0 = quiet, 1 = verbose, 2 = very verbose
-LOG_LEVEL = 0
-
-def get_tokenizer():
- model_path = "./custom_unigram_tokenizer_65k/unigram.model"
- print(f"Loading SentencePiece tokenizer from: {model_path}")
- sp = spm.SentencePieceProcessor()
- sp.load(model_path)
- print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}")
- return sp
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument("--config", type=str, help="Path to config file (YAML).", required=True)
- return parser.parse_args()
-
-def assert_equal(a, b):
- err_msg = f"{a} != {b}"
- assert a == b, err_msg
-
-# Turn a whitespace-delimited string into a list of strings no longer than
-# `cols`.
-# Preferentially breaks strings at whitespace boundaries. Preserves whitespace
-# between words, except if that whitespace comes between lines. Breaks words
-# longer than `cols` with a hyphen.
-def wrap_line(line: str, cols):
- # First, split line into alternating chunks of words and whitespace.
- def get_sequences(line):
- is_space = False
- sequences = []
- seq_start = 0
- seq_end = -1
- for i in range(0, len(line)):
- if line[i].isspace():
- if is_space:
- seq_end = i
- continue
- # We were looking at text, now we see whitespace.
- seq = line[seq_start:seq_end+1]
- if len(seq) > 0:
- sequences.append(seq)
- seq_start = i
- seq_end = i
- is_space = True
- else:
- if not is_space:
- seq_end = i
- continue
- # We were looking at whitespace, now we see text.
- seq = line[seq_start:seq_end+1]
- if len(seq) > 0:
- sequences.append(seq)
- seq_start = i
- seq_end = i
- is_space = False
- sequences.append(line[seq_start:seq_end+1])
- return sequences
- if TESTS_ENABLED:
- assert_equal(get_sequences("foo"), ["foo"])
- assert_equal(get_sequences("foo bar"), ["foo", " ", "bar"])
- assert_equal(get_sequences(" foo bar"), [" ", "foo", " ", "bar"])
- assert_equal(get_sequences(" foo bar"), [" ", "foo", " ", "bar"])
-
- # Next, greedily construct lines out of those sequences.
- # Whitespace gets treated specially. If it would push us over the limit, we
- # end the line and drop the whitespace.
- sequences = get_sequences(line)
- def coalesce_sequences(sequences, cols):
- cur_line = ""
- lines = []
- for seq in sequences:
- if len(cur_line) + len(seq) <= cols:
- cur_line += seq
- continue
- if seq.isspace():
- lines.append(cur_line)
- cur_line = ""
- continue
- if len(cur_line) > 0:
- lines.append(cur_line)
- # Edge case: text sequence is longer than a line.
- while len(seq) > cols:
- seq_prefix = seq[0:cols-1] + "-"
- seq = seq[cols-1:]
- lines.append(seq_prefix)
- cur_line = seq
- if len(cur_line) > 0:
- lines.append(cur_line)
- return lines
- if TESTS_ENABLED:
- assert_equal(coalesce_sequences(get_sequences("foo bar"), 3), ["foo", "bar"])
- assert_equal(coalesce_sequences(get_sequences("foo bar"), 4), ["foo ", "bar"])
- assert_equal(coalesce_sequences(get_sequences("foo bar"), 4), ["foo", "bar"])
- assert_equal(coalesce_sequences(get_sequences("foobar"), 3), ["fo-", "ob-", "ar"])
- assert_equal(coalesce_sequences(get_sequences("f obar"), 3), ["f ", "ob-", "ar"])
-
- lines = coalesce_sequences(sequences, cols)
-
- # Next, pad each line with whitespace.
- def pad_lines(lines, cols):
- for i in range(0, len(lines)):
- lines[i] += ' ' * (cols - len(lines[i]))
- return lines
- if TESTS_ENABLED:
- assert_equal(pad_lines(["foo", "ba"], 4), ["foo ", "ba "])
- assert_equal(pad_lines(["foo"], 2), ["foo"])
-
- return pad_lines(lines, cols)
-
-def get_blocks(lines, tokenizer, block_width, num_blocks):
- if LOG_LEVEL == 2:
- print(f"Lines sent to tokenizer: {''.join(lines)}")
- tokens = tokenizer.encode_as_ids(''.join(lines))
- if LOG_LEVEL == 2:
- print(f"Tokens: {tokens}")
- pieces = []
- for tok in tokens:
- piece = tokenizer.id_to_piece(tok)
- pieces.append(piece)
- if LOG_LEVEL == 2:
- print(f"Pieces: {pieces}")
-
- # Group tokens into blocks and pad with empty characters.
- # Also get visual pointers - the location where each block will be rendered.
- def get_blocks():
- blocks = []
- visual_pointer = 0
- visual_pointers = []
- for i in range(0, ceil(len(tokens) / block_width)):
- visual_pointers.append(visual_pointer)
- block = []
- for j in range(0, block_width):
- if i*block_width + j >= len(tokens):
- # Pad block with empty characters. 65535 is a special token.
- block += [65535] * (block_width - len(block))
- break
- block.append(tokens[i*block_width+j])
- visual_pointer += len(pieces[i*block_width+j])
- blocks.append(block)
- return (blocks, visual_pointers)
- blocks, visual_pointers = get_blocks()
- if LOG_LEVEL == 2:
- print(f"Blocks: {blocks}")
- print(f"Visual pointers: {visual_pointers}")
-
- # Set all blocks up to the next `num_blocks` boundary to blank tokens.
- # This handles the edge case where a prior message wrote data there which
- # is covering up our new data.
- def pad_blocks(blocks, visual_pointers):
- cur_num_blocks = len(blocks)
- num_pad_blocks = num_blocks - cur_num_blocks
- for i in range(0, num_pad_blocks):
- blocks.append([65535] * block_width)
- visual_pointers.append(255)
- return blocks, visual_pointers
- blocks, visual_pointers = pad_blocks(blocks, visual_pointers)
- if LOG_LEVEL == 2:
- print(f"Blocks (padded): {blocks}")
- print(f"Visual pointers (padded): {visual_pointers}")
-
- return blocks, visual_pointers
-
-def calc_diff(prev_blocks, prev_visual_pointers, cur_blocks,
- cur_visual_pointers):
- diff_indices = []
- diff_blocks = []
- diff_visual_pointers = []
-
- for i in range(0, len(cur_blocks)):
- if i >= len(prev_blocks):
- diff_blocks.append(cur_blocks[i])
- diff_visual_pointers.append(cur_visual_pointers[i])
- diff_indices.append(i)
- continue
- if prev_blocks[i] != cur_blocks[i] or prev_visual_pointers[i] != cur_visual_pointers[i]:
- diff_blocks.append(cur_blocks[i])
- diff_visual_pointers.append(cur_visual_pointers[i])
- diff_indices.append(i)
-
- return diff_indices, diff_blocks, diff_visual_pointers
-
-def send_data(osc_client, indices, blocks, visual_pointers):
- def split_blocks_by_byte(blocks):
- blocks_byte00 = []
- blocks_byte01 = []
- for block in blocks:
- block_byte00 = []
- block_byte01 = []
- for datum in block:
- block_byte00.append((datum >> 0) & 0xFF)
- block_byte01.append((datum >> 8) & 0xFF)
- blocks_byte00.append(block_byte00)
- blocks_byte01.append(block_byte01)
- return blocks_byte00, blocks_byte01
-
- blocks_byte00, blocks_byte01 = split_blocks_by_byte(blocks)
- if LOG_LEVEL == 2:
- print(f"Blocks (byte 00): {blocks_byte00}")
- print(f"Blocks (byte 01): {blocks_byte01}")
-
- def send_osc(osc_client, addr, data):
- #print(f"Sending {data} to {addr}")
- osc_client.send_message(addr, data)
-
- for i in range(0, len(blocks)):
- lp_int = indices[i]
- lp_param = "_Unigram_Letter_Grid_OSC_Pointer"
- addr = "/avatar/parameters/" + lp_param
- send_osc(osc_client, addr, lp_int)
-
- vp_float = (-127.5 + visual_pointers[i]) / 127.5
- vp_param = f"_Unigram_Letter_Grid_OSC_Visual_Pointer"
- addr = "/avatar/parameters/" + vp_param
- send_osc(osc_client, addr, vp_float)
- if LOG_LEVEL == 2:
- print(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}")
- for j in range(0, len(blocks[i])):
- byte00_float = (-127.5 + blocks_byte00[i][j]) / 127.5
- byte01_float = (-127.5 + blocks_byte01[i][j]) / 127.5
- byte00_param = f"_Unigram_Letter_Grid_OSC_Datum{j:02}_Byte00"
- byte01_param = f"_Unigram_Letter_Grid_OSC_Datum{j:02}_Byte01"
- addr = "/avatar/parameters/" + byte00_param
- send_osc(osc_client, addr, byte00_float)
- addr = "/avatar/parameters/" + byte01_param
- send_osc(osc_client, addr, byte01_float)
- time.sleep(0.34)
-
-def getOscClient(ip = "127.0.0.1", port = 9000):
- return udp_client.SimpleUDPClient(ip, port)
-
-class InputState:
- def __init__(self):
- self.page = 0
- # Initialize the known state of the board to empty array. This will cause
- # our paging logic to re-send everything the first time around.
- self.blocks = []
- self.visual_pointers = []
- pass
-
-def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg):
- line_wrapped = wrap_line(line, cfg["cols"])
- if TESTS_ENABLED:
- for line in line_wrapped:
- assert_equal(len(line), cfg["cols"])
- if LOG_LEVEL == 2:
- print(f"Wrapped lines: {line_wrapped}")
-
- # Get several blank lines whenever we roll over.
- # It's better for the reader to have some continuity when the board pages
- # over. If we simply replaced the entire screen, it would be harder to
- # understand.
- line_rollover = cfg["rows"] - 2
- blank_line = ' ' * cfg["cols"]
- # We show a full page, then only `line_rollover` additional lines per page.
- end_ptr = cfg["rows"]
- which_page = 0
- while end_ptr < len(line_wrapped):
- end_ptr += line_rollover
- which_page += 1
- if state.page != which_page:
- state.blocks = []
- state.visual_pointers = []
- state.page = which_page
- line_wrapped = line_wrapped[end_ptr-cfg["rows"]:]
-
- # Get blocks and visual pointers.
- blocks, visual_pointers = get_blocks(line_wrapped, tokenizer,
- cfg["block_width"], cfg["num_blocks"])
-
- # Note that because we only send one page of data at a time, we don't have
- # to worry about wrapping visual pointers! We will basically never run out
- # of space.
- indices, diff_blocks, diff_visual_pointers = calc_diff(state.blocks, state.visual_pointers, blocks, visual_pointers)
- indices = [idx % cfg["num_blocks"] for idx in indices]
- # Send only one block at a time to make things snappier in interactive use
- # case.
- # TODO use a continuation (yield) instead of returning. Then we can be a
- # little lighter on the cpu. Measurements show that this script is
- # already very light but we're clearly wasting a lot of work by
- # re-tokenizing the entire input every time we send a block.
- if len(indices) == 0:
- return
- if indices[0] == len(state.blocks):
- state.blocks.append(diff_blocks[0])
- state.visual_pointers.append(diff_visual_pointers[0])
- elif indices[0] > len(state.blocks):
- print(f"This should never happen!")
- sys.exit(1)
- else:
- state.blocks[indices[0]] = diff_blocks[0]
- state.visual_pointers[indices[0]] = diff_visual_pointers[0]
-
- send_data(osc_client, [indices[0]], [diff_blocks[0]], [diff_visual_pointers[0]])
-
-def osc_thread(shared_data: SharedThreadData):
- tokenizer = get_tokenizer()
- osc_client = getOscClient()
-
- # Prime the board
- print("Priming the board")
- input_state = InputState()
- handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg)
-
- while not shared_data.exit_event.is_set():
- word_copy = ""
- with shared_data.word_lock:
- word_copy = shared_data.word
- handle_input(input_state, word_copy, tokenizer, osc_client, shared_data.cfg)
- time.sleep(0.01)
-
-if __name__ == "__main__":
- cli_args = parse_args()
- cfg = app_config.getConfig(cli_args.config)
- shared_data = SharedThreadData(cfg)
- osc_thread = threading.Thread(
- target=osc_thread,
- args=(shared_data,))
- osc_thread.start()
-
- transcribe_thread = threading.Thread(
- target=stt.transcriptionThread,
- args=(shared_data,))
- transcribe_thread.start()
-
- word_is_over = False
- local_word = ""
- while True:
- char_bytes = msvcrt.getch()
- if char_bytes == b'\x03': # ctrl+C
- break
-
- time.sleep(0.1)
- continue
-
-
- try:
- char = char_bytes.decode('utf-8')
- if char == '\r' or char == '\n':
- word_is_over = True
- continue
- except UnicodeDecodeError:
- print(f"Unsupported character: {char_bytes}")
- if char_bytes == b'\x00' or char_bytes == b'\xe0':
- special_char = msvcrt.getch()
- continue
-
- if char_bytes == b'\x03': # ctrl+C
- break
- elif char_bytes == b'\x08': # backspace
- with shared_data.word_lock:
- shared_data.word = shared_data.word[:-1]
- local_word = shared_data.word
- elif char_bytes == b'\x0c': # ctrl+L
- with shared_data.word_lock:
- shared_data.word = ""
- local_word = shared_data.word
- elif word_is_over:
- with shared_data.word_lock:
- shared_data.word = char
- local_word = shared_data.word
- word_is_over = False
- else:
- with shared_data.word_lock:
- shared_data.word += char
- local_word = shared_data.word
- print(local_word + "_")
- shared_data.exit_event.set()
- osc_thread.join()
- transcribe_thread.join()
-