diff options
| -rw-r--r-- | osc_ctrl.py | 40 | ||||
| -rw-r--r-- | steamvr.py | 18 | ||||
| -rw-r--r-- | string_matcher.py | 14 | ||||
| -rw-r--r-- | transcribe.py | 123 |
4 files changed, 116 insertions, 79 deletions
diff --git a/osc_ctrl.py b/osc_ctrl.py index 5fadd6f..2e7ef39 100644 --- a/osc_ctrl.py +++ b/osc_ctrl.py @@ -243,6 +243,15 @@ def sendMessageLazy(client, msg, tx_state): if cell_msg == last_cell_msg: continue + # Skip cells on previous pages. This mitigates a bug where updating the + # earlier part of a transcription causes that text to overwrite text + # from a later part of the transcription. + page = floor(cell / (2 ** generate_utils.INDEX_BITS)) + last_cell = (len(tx_state.last_msg_encoded) / NUM_LAYERS) - 1 + last_page = floor(last_cell / (2 ** generate_utils.INDEX_BITS)) + if page < last_page: + continue + if cell_msg == [state.encoding[' ']] * NUM_LAYERS: if empty_cells_sent >= tx_state.empty_cells_to_send_per_call: return False @@ -261,31 +270,6 @@ def sendMessageLazy(client, msg, tx_state): #resizeBoard(len(lines), tx_state, shrink_only=True) return True -def sendMessage(client, msg, page_sleep_s): - lines = splitMessage(msg) - msg = encodeMessage(lines) - msg_len = len(msg) - - print("Encoded message: {}".format(msg)) - - #openBoard() - - n_cells = ceil(msg_len / NUM_LAYERS) - print("n_cells: {}".format(n_cells)) - for cell in range(0, n_cells): - if cell > 0 and cell % (2 ** generate_utils.INDEX_BITS) == 0: - print("Sleeping before sending next page") - time.sleep(page_sleep_s) - - cell_begin = cell * NUM_LAYERS - cell_end = (cell + 1) * NUM_LAYERS - cell_msg = msg[cell_begin:cell_end] - print("Send cell {}".format(cell)) - sendMessageCellDiscrete(client, cell_msg, cell) - - #closeBoard() - #clear() - def sendRawMessage(client, msg): n_cells = ceil(len(msg) / NUM_LAYERS) for cell in range(0, n_cells): @@ -295,7 +279,7 @@ def sendRawMessage(client, msg): #print("Send cell {}".format(cell)) sendMessageCellDiscrete(client, cell_msg, cell) -def clear(client): +def clear(client, tx_state): disable(client) addr="/avatar/parameters/" + generate_utils.getClearBoardParam() @@ -306,6 +290,8 @@ def clear(client): addr="/avatar/parameters/" + generate_utils.getClearBoardParam() client.send_message(addr, False) + tx_state.last_message_encoded = [] + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-i", default="127.0.0.1", help="OSC server IP") @@ -325,5 +311,5 @@ if __name__ == "__main__": for line in fileinput.input(): while not sendMessageLazy(client, line, tx_state): continue - clear() + clear(client, tx_state) @@ -16,9 +16,6 @@ class SessionState: # Whether the configured input event is high or low. self.event_high = False -def getState() -> SessionState: - return SessionState() - # Checks if the given button on the given controller is pressed. # Defaults to joystick click / left hand. # Returns three values: @@ -40,12 +37,21 @@ def pollButtonPress( if state.unPacketNum == session_state.last_packet: return EVENT_NONE - #print("button pressed: %016x" % state.ulButtonPressed) + # Clicking joysticks and moving joysticks fire the same events. To + # differentiate movement from clicking, we create a dead zone: if the event + # fires while the stick isn't moved far from center, we assume it's a + # click, not movement. + dead_zone_radius = 0.5 + # This is the ID of event for the joystick being clicked. joy_click = vr.k_EButton_Axis0 joy_click_mask = (1 << joy_click) ret = EVENT_NONE - if (state.ulButtonPressed & joy_click_mask) != 0: + if (state.ulButtonPressed & joy_click_mask) != 0 and\ + (state.rAxis[0].x**2 + state.rAxis[0].y**2 < dead_zone_radius**2): + #print("button pressed: %016x" % state.ulButtonPressed) + #for i in range(0, 5): + # print("axis {} x: {} y: {}".format(i, state.rAxis[i].x, state.rAxis[i].y)) if not session_state.event_high: ret = EVENT_RISING_EDGE session_state.event_high = True @@ -65,5 +71,3 @@ if __name__ == "__main__": elif event == EVENT_FALLING_EDGE: print("falling edge") - - diff --git a/string_matcher.py b/string_matcher.py index cf11133..1c6868e 100644 --- a/string_matcher.py +++ b/string_matcher.py @@ -78,13 +78,13 @@ def matchStrings(old_text: str, new_text: str, window_size = 3) -> str: for j in range(0, 1 + len(new_text) - window_size): new_slice = new_text[j:j + window_size] cur_d = editdistance.eval(old_slice, new_slice) - if cur_d <= best_match_d: + if cur_d < best_match_d: best_match_i = i best_match_j = j best_match_d = cur_d if DEBUG: - print("optimum at old '{}'/{} new '{}'/{} d={}".format( + print("optimum at old '{}' i={} new '{}' j={} d={}".format( old_slice, i, new_slice, j, cur_d)) old_prefix = old_text[0:best_match_i] @@ -128,7 +128,7 @@ if __name__ == "__main__": in1 = "Okay, what about now? Looks like it sort of works. Key word being sort of." in2 = "okay what about now looks like it sort of works key word being sort of looks" bad_out = "Okay, what about now? Looks like it sort of works. Key word being sort of works key word being sort of looks" - good_out = "Okay, what about now? Looks like it sort of works. Key word being sort of looks" + good_out = "Okay what about now looks like it sort of works key word being sort of looks" assert(matchStrings(in1, in2) == good_out) in1 = "This repository can take" @@ -137,10 +137,16 @@ if __name__ == "__main__": good_out = "This repository contains the code for" assert(matchStrings(in1, in2) == good_out) + in1 = "See something." + in2 = "See something. Say something." + bad_out = in1 + good_out = in2 + assert(matchStrings(in1, in2) == good_out) + in1 = "a" * 1000 in2 = "b" * 10 * 1000 # This should be fast (< 1 second) - matchStrings(in1, in2) + #matchStrings(in1, in2) print("Tests passed.") diff --git a/transcribe.py b/transcribe.py index 2e4457e..1327515 100644 --- a/transcribe.py +++ b/transcribe.py @@ -2,7 +2,6 @@ import argparse import copy -import string_matcher import os import osc_ctrl # python3 -m pip install pydub @@ -12,6 +11,7 @@ from pydub import effects as pydub_effects # python3 -m pip install pyaudio # License: MIT. import pyaudio +import steamvr import sys import threading import time @@ -42,21 +42,29 @@ class AudioState: # PyAudio stream object stream = None - frames = [] - frames_lock = threading.Lock() - text = "" - text_lock = threading.Lock() + committed_text = "" + frames = [] + # Locks access to `text`, `frames`, and audio stored on disk. + lock = threading.Lock() - record_audio = True - transcribe_audio = True - send_audio = True + # Used to tell the threads when to stop. + run_app = True transcribe_sleep_duration_min_s = 0.05 transcribe_sleep_duration_max_s = 1.50 transcribe_no_change_count = 0 transcribe_sleep_duration = transcribe_sleep_duration_min_s - # The language the user is speaking in. + + tx_state = osc_ctrl.OscTxState() + + # The transcription thread transcribes without holding locks, then + # blocks on it. Thus we need some way to tell the transcription + # thread to drop that transcription. + drop_transcription = False + + # The language the user is speaking in. Default is English but user may set + # this to whatever they want. language = whisper.tokenizer.TO_LANGUAGE_CODE["english"] # When the user says `over`, we stop displaying new transcriptions until @@ -104,18 +112,18 @@ def getMicStream(which_mic): return audio_state -# Continuously records audio as long as audio_state.record_audio is set. +# Continuously records audio as long as audio_state.run_app is set. def recordAudio(audio_state): print("Recording audio") - while audio_state.record_audio: + while audio_state.run_app: data = audio_state.stream.read(audio_state.CHUNK) - audio_state.frames_lock.acquire() + audio_state.lock.acquire() audio_state.frames.append(data) max_frames = int(audio_state.RATE * audio_state.MAX_LENGTH_S / audio_state.CHUNK) if len(audio_state.frames) > max_frames: audio_state.frames = audio_state.frames[-1 * max_frames :] - audio_state.frames_lock.release() + audio_state.lock.release() print("Done recording") @@ -130,9 +138,9 @@ def saveAudio(audio_state, filename): wf.setsampwidth(audio_state.p.get_sample_size(audio_state.FORMAT)) wf.setframerate(audio_state.RATE) - audio_state.frames_lock.acquire() + audio_state.lock.acquire() frames = copy.deepcopy(audio_state.frames) - audio_state.frames_lock.release() + audio_state.lock.release() wf.writeframes(b''.join(frames)) wf.close() @@ -164,20 +172,23 @@ def resetAudioLocked(audio_state): resetDiskAudioLocked(audio_state, audio_state.VOICE_AUDIO_FILENAME) + audio_state.committed_text = "" audio_state.text = "" - osc_ctrl.clear(audio_state.osc_client) + +def resetDisplayLocked(audio_state): + osc_ctrl.clear(audio_state.osc_client, audio_state.tx_state) def resetAudio(audio_state): - audio_state.frames_lock.acquire() + audio_state.lock.acquire() resetAudioLocked(audio_state) - audio_state.frames_lock.release() + audio_state.lock.release() # Transcribe the audio recorded in a file. def transcribe(audio_state, model, filename): - audio_state.frames_lock.acquire() + audio_state.lock.acquire() audio = whisper.load_audio(filename) - audio_state.frames_lock.release() + audio_state.lock.release() audio = whisper.pad_or_trim(audio) mel = whisper.log_mel_spectrogram(audio).to(model.device) @@ -186,7 +197,7 @@ def transcribe(audio_state, model, filename): beam_size = 5) result = whisper.decode(model, mel, options) - if result.no_speech_prob > 0.15: + if result.no_speech_prob > 0.60: print("no speech prob: {}".format(result.no_speech_prob)) return None @@ -201,7 +212,7 @@ def transcribe(audio_state, model, filename): return result.text def transcribeAudio(audio_state, model): - while audio_state.transcribe_audio == True: + while audio_state.run_app == True: # Pace this out print("sleep duration: {}".format(audio_state.transcribe_sleep_duration)) time.sleep(audio_state.transcribe_sleep_duration) @@ -226,18 +237,29 @@ def transcribeAudio(audio_state, model): text = transcribe(audio_state, model, audio_state.VOICE_AUDIO_FILENAME) if not text: continue + audio_state.lock.acquire() - audio_state.text_lock.acquire() + if audio_state.drop_transcription: + audio_state.drop_transcription = False + audio_state.lock.release() + continue + + # Hack: two consecutive identical transcriptions get "committed". + if text == audio_state.text: + print("Commit!") + old_commit = audio_state.committed_text + resetAudioLocked(audio_state) + audio_state.committed_text = old_commit + " " + text + audio_state.lock.release() + continue + else: + print("text: {}".format(text)) + print("audio_state.text: {}".format(audio_state.text)) words = ''.join(c for c in text.lower() if (c.isalpha() or c == " ")).split() if len(words) > 0: - if words[-1] == "clear": - resetAudio(audio_state) - audio_state.text_lock.release() - audio_state.display_paused = False - continue - elif words[-1] == "over": + if words[-1] == "over": words = words[0:-1] audio_state.display_paused = True @@ -247,32 +269,45 @@ def transcribeAudio(audio_state, model): #old_words = audio_state.text.split() #new_words = text.split() - audio_state.text = string_matcher.matchStrings(audio_state.text, - text, window_size = 5) + audio_state.text = text if old_text != audio_state.text: # We think the user said something, so reset the amount of # time we sleep between transcriptions to the minimum. audio_state.transcribe_no_change_count = 0 audio_state.transcribe_sleep_duration = audio_state.transcribe_sleep_duration_min_s - audio_state.text_lock.release() + audio_state.lock.release() def sendAudio(audio_state): - tx_state = osc_ctrl.OscTxState() - - while audio_state.send_audio == True: + while audio_state.run_app == True: if audio_state.display_paused: time.sleep(0.1) continue - audio_state.text_lock.acquire() - text = copy.deepcopy(audio_state.text) - osc_ctrl.sendMessageLazy(audio_state.osc_client, text, tx_state) - audio_state.text_lock.release() + audio_state.lock.acquire() + text = audio_state.committed_text + " " + audio_state.text + osc_ctrl.sendMessageLazy(audio_state.osc_client, text, audio_state.tx_state) + audio_state.lock.release() # Pace this out time.sleep(0.01) +def readControllerInput(audio_state): + session = steamvr.SessionState() + while audio_state.run_app == True: + time.sleep(0.05) + + event = steamvr.pollButtonPress(session) + + if event == steamvr.EVENT_RISING_EDGE: + print("event get") + audio_state.lock.acquire() + resetAudioLocked(audio_state) + resetDisplayLocked(audio_state) + audio_state.drop_transcription = True + audio_state.display_paused = False + audio_state.lock.release() + def transcribeLoop(mic: str, language: str): audio_state = getMicStream(mic) audio_state.language = whisper.tokenizer.TO_LANGUAGE_CODE[language] @@ -297,18 +332,24 @@ def transcribeLoop(mic: str, language: str): send_audio_thd.daemon = True send_audio_thd.start() + controller_input_thd = threading.Thread(target = readControllerInput, args = [audio_state]) + controller_input_thd.daemon = True + controller_input_thd.start() + print("Press enter or say 'Clear' to start a new message. Say 'Over' to " + "pause the display (saying 'Clear' resets it again).") for line in sys.stdin: resetAudio(audio_state) + resetDisplayLocked(audio_state) if "exit" in line or "quit" in line: break print("Joining threads") - audio_state.record_audio = False - audio_state.transcribe_audio = False + audio_state.run_app = False + audio_state.run_app = False record_audio_thd.join() transcribe_audio_thd.join() + controller_input_thd.join() if __name__ == "__main__": |
