diff options
| author | yum <yum.food.vr@gmail.com> | 2022-10-03 21:57:51 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2022-10-03 21:57:51 -0700 |
| commit | 2fd5771ae4c8b7774b859422eb00216af07ef4fa (patch) | |
| tree | 0b44b0f4736309836e103d38819082f5ee747c05 | |
| parent | d4556af258ae3911c83ece4b817335e8c5a2a2d2 (diff) | |
Introduce STT proof-of-concept
Using OpenAI's Whisper neural network, we can do STT locally. Transcription
quality is good, system resource usage is minimal (1 GB VRAM), and latency
is much lower than with cloud-based transcription.
* Add transcribe.py
  * Creates 3 threads:
    * One saves mic audio to a buffer
    * One passes mic audio to the STT
    * One sends the transcribed text to the board
  * Main thread listens for input. Press enter to start a new message.
* Add osc_ctrl.sendMessageLazy, a simple diff-based message sending utility.
  * A little complexity: it only sends 1 empty cell per call, allowing us to
    quickly say new things without having to wait for the whole buffer to
    clear.
| -rw-r--r-- | osc_ctrl.py | 97 | ||||
| -rw-r--r-- | transcribe.py | 175 |
2 files changed, 255 insertions, 17 deletions
diff --git a/osc_ctrl.py b/osc_ctrl.py index 6f4ac65..200ac54 100644 --- a/osc_ctrl.py +++ b/osc_ctrl.py @@ -17,9 +17,9 @@ from generate_utils import NUM_LAYERS from generate_utils import BOARD_ROWS from generate_utils import BOARD_COLS -#CELL_TX_TIME_S=3.0 -#CELL_TX_TIME_S=1.0 -CELL_TX_TIME_S=0.1 +# Based on a couple experiments, this seems like about as fast as we can go +# before players start losing events. +CELL_TX_TIME_S=0.3 def usage(): print("python3 -m pip install python-osc") @@ -63,11 +63,14 @@ generateEncoding(state) def encodeMessage(lines): result = [] # Pad the number of lines up to a multiple of BOARD_ROWS. - print("Pad {} lines".format(BOARD_ROWS - (len(lines) % BOARD_ROWS))) + #print("Pad {} lines".format(BOARD_ROWS - (len(lines) % BOARD_ROWS))) lines += [" "] * ((BOARD_ROWS - len(lines)) % BOARD_ROWS) for line in lines: - print("encode line {}".format(line)) + #print("encode line {}".format(line)) for char in line: + if not char in state.encoding: + print("skip unrecognized char {}".format(char)) + continue result.append(state.encoding[char]) result += [state.encoding[' ']] * (BOARD_COLS - len(line)) return result @@ -112,7 +115,6 @@ def sendMessageCellDiscrete(msg_cell, which_cell): s2 = ((floor(which_cell / 2) % 2) == 1) s3 = ((floor(which_cell / 1) % 2) == 1) - print("Cell s0/s1/s2/s3: {}/{}/{}/{}".format(s0,s1,s2,s3)) # Seek each layer to the current cell. for i in range(0, len(msg_cell)): updateCell(i, msg_cell[i], s0, s1, s2, s3) @@ -121,8 +123,6 @@ def sendMessageCellDiscrete(msg_cell, which_cell): time.sleep(CELL_TX_TIME_S / 3.0) # Enable each layer. - # TODO(yum_food) for some reason, if we don't active every layer, the - # desired subset won't reliably fire. Why? enable() # Wait for convergence. @@ -163,7 +163,50 @@ def splitMessage(msg): return lines -def sendMessage(msg): +class OscTxState: + # The message last sent to the board. 
+ last_msg_encoded = [] + empty_cells_to_send_per_call = 1 + +# Send a message to the board, but only overwrite cells that we know need to +# change. +def sendMessageLazy(msg, tx_state): + lines = splitMessage(msg) + msg_encoded = encodeMessage(lines) + msg_encoded_len = len(msg_encoded) + + empty_cells_sent = 0 + n_cells = ceil(msg_encoded_len / NUM_LAYERS) + for cell in range(0, n_cells): + if cell > 0 and cell % (2 ** generate_utils.INDEX_BITS) == 0: + # TODO(yum_food) support messages longer than one page + print("Page limit exceeded, no support yet") + return + + cell_begin = cell * NUM_LAYERS + cell_end = (cell + 1) * NUM_LAYERS + cell_msg = msg_encoded[cell_begin:cell_end] + last_cell_msg = [] + + # Skip cells we've already sent. This makes the board much more + # responsive. + if cell_end < len(tx_state.last_msg_encoded): + last_cell_msg = tx_state.last_msg_encoded[cell_begin:cell_end] + if cell_msg == last_cell_msg: + continue + + if cell_msg == [state.encoding[' ']] * NUM_LAYERS: + if empty_cells_sent >= tx_state.empty_cells_to_send_per_call: + print("empty cell budget exceeded") + tx_state.last_msg_encoded = msg_encoded[0:cell_end] + return + empty_cells_sent += 1 + + sendMessageCellDiscrete(cell_msg, cell) + + tx_state.last_msg_encoded = msg_encoded + +def sendMessage(msg, page_sleep_s): lines = splitMessage(msg) msg = encodeMessage(lines) msg_len = len(msg) @@ -175,14 +218,18 @@ def sendMessage(msg): n_cells = ceil(msg_len / NUM_LAYERS) print("n_cells: {}".format(n_cells)) for cell in range(0, n_cells): + if cell > 0 and cell % (2 ** generate_utils.INDEX_BITS) == 0: + print("Sleeping before sending next page") + time.sleep(page_sleep_s) + cell_begin = cell * NUM_LAYERS cell_end = (cell + 1) * NUM_LAYERS cell_msg = msg[cell_begin:cell_end] print("Send cell {}".format(cell)) sendMessageCellDiscrete(cell_msg, cell) - #sendMessageCellContinuous(cell_msg, cell) #closeBoard() + #clear() def sendRawMessage(msg): n_cells = ceil(len(msg) / NUM_LAYERS) @@ 
-190,49 +237,65 @@ def sendRawMessage(msg): cell_begin = cell * NUM_LAYERS cell_end = (cell + 1) * NUM_LAYERS cell_msg = msg[cell_begin:cell_end] - print("Send cell {}".format(cell)) + #print("Send cell {}".format(cell)) sendMessageCellDiscrete(cell_msg, cell) def clear(): sendRawMessage([state.encoding[' ']] * BOARD_ROWS * BOARD_COLS) def openBoard(): + print("Opening board... "), addr="/avatar/parameters/" + generate_utils.getResize0Param() client.send_message(addr, False) addr="/avatar/parameters/" + generate_utils.getResize1Param() client.send_message(addr, False) - time.sleep(0.3) + time.sleep(CELL_TX_TIME_S / 3.0) addr="/avatar/parameters/" + generate_utils.getResizeEnableParam() client.send_message(addr, True) - time.sleep(0.3) + # The animation is 0.5 seconds, with another 0.5 second buffer after. We + # want to stop in that buffer. + time.sleep(0.7) addr="/avatar/parameters/" + generate_utils.getResizeEnableParam() client.send_message(addr, False) + # Wait for the 1-second animation to complete, plus a wide margin for + # safety. + time.sleep(0.3 + 1) + print("done") + def closeBoard(): + print("Closing board... "), addr="/avatar/parameters/" + generate_utils.getResize0Param() client.send_message(addr, True) addr="/avatar/parameters/" + generate_utils.getResize1Param() client.send_message(addr, True) - time.sleep(0.1) + time.sleep(CELL_TX_TIME_S / 3.0) addr="/avatar/parameters/" + generate_utils.getResizeEnableParam() client.send_message(addr, True) - time.sleep(0.1) + # The animation is 0.5 seconds, with another 0.5 second buffer after. We + # want to stop in that buffer. 
+ time.sleep(0.7) addr="/avatar/parameters/" + generate_utils.getResizeEnableParam() client.send_message(addr, False) + time.sleep(1) + print("done") + if __name__ == "__main__": generateEncoding(state) + #closeBoard() clear() for line in fileinput.input(): - sendMessage(line) - time.sleep(1 + len(line) / 40.0) + page_sleep_s = 3 + sendMessage(line, page_sleep_s) + #time.sleep(2 + len(line) / 40.0) clear() diff --git a/transcribe.py b/transcribe.py new file mode 100644 index 0000000..4548214 --- /dev/null +++ b/transcribe.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3
+
+import copy
+import fileinput
+import os
+import osc_ctrl
+# python3 -m pip install pyaudio
+import pyaudio
+import threading
+import time
+import wave
+# python3 -m pip install git+https://github.com/openai/whisper.git
+# python3 -m pip install torch -f https://download.pytorch.org/whl/torch_stable.html
+import whisper
+
# State shared between the recording, transcription and sending threads.
#
# Stream parameters are class-level constants. Everything mutable (buffers,
# locks, run flags) is created per instance in __init__ so that two
# AudioState objects never share a frame buffer or a lock.
class AudioState:
    # Number of frames requested from the stream per read() call; also passed
    # to PyAudio as frames_per_buffer.
    CHUNK = 1024
    FORMAT = pyaudio.paInt16
    CHANNELS = 1
    # This matches the framerate expected by whisper.
    RATE = 16000

    # The maximum length that recordAudio() will put into frames before it
    # starts dropping from the start.
    MAX_LENGTH_S = 30

    def __init__(self):
        # PyAudio object
        self.p = None

        # PyAudio stream object
        self.stream = None

        # Raw audio chunks appended by recordAudio(); guarded by frames_lock.
        self.frames = []
        self.frames_lock = threading.Lock()

        # Most recent transcription result; guarded by text_lock.
        self.text = ""
        self.text_lock = threading.Lock()

        # Run flags polled by recordAudio(), transcribeAudio() and
        # sendAudio() respectively. Clear one to make its loop exit.
        self.record_audio = True
        self.transcribe_audio = True
        self.send_audio = True
+
# Finds an input device whose name contains `device_name_substr`, opens a
# mono input stream on it and returns the populated AudioState.
#
# Blocks until a matching device shows up, polling every `retry_sleep_s`
# seconds. If several devices match, the one with the highest index wins.
def getMicStream(device_name_substr="Digital Audio Interface", retry_sleep_s=3):
    audio_state = AudioState()
    audio_state.p = pyaudio.PyAudio()

    print("Finding index mic...")
    device_index = -1
    while device_index < 0:
        # Query the device count on every attempt; a count captured once
        # before the loop would never reflect a device added later.
        info = audio_state.p.get_host_api_info_by_index(0)
        numdevices = info.get('deviceCount')

        for i in range(numdevices):
            device_info = audio_state.p.get_device_info_by_host_api_device_index(0, i)
            # Skip output-only devices.
            if device_info.get('maxInputChannels') <= 0:
                continue
            device_name = device_info.get('name')
            #print("Input Device id ", i, " - ", device_name)
            if device_name_substr in device_name:
                print("Got match: {}".format(device_name))
                device_index = i

        if device_index < 0:
            print("No match, sleeping")
            time.sleep(retry_sleep_s)
            # PortAudio enumerates devices when it is initialized, so
            # re-create the PyAudio object to pick up a mic that was plugged
            # in while we slept.
            audio_state.p.terminate()
            audio_state.p = pyaudio.PyAudio()

    audio_state.stream = audio_state.p.open(
        format=audio_state.FORMAT,
        channels=audio_state.CHANNELS,
        rate=audio_state.RATE,
        input=True,
        frames_per_buffer=audio_state.CHUNK,
        input_device_index=device_index)

    return audio_state
+
# Continuously records audio as long as audio_state.record_audio is set.
#
# Each chunk read from audio_state.stream is appended to audio_state.frames
# under frames_lock. Once the buffer holds more than MAX_LENGTH_S seconds of
# audio, the oldest chunks are dropped.
def recordAudio(audio_state):
    print("Recording audio")

    # The chunk budget only depends on stream constants, so compute it once.
    max_frames = int(audio_state.RATE * audio_state.MAX_LENGTH_S / audio_state.CHUNK)

    while audio_state.record_audio:
        data = audio_state.stream.read(audio_state.CHUNK)

        # `with` guarantees the lock is released even if something below
        # raises, so the other threads can never be left deadlocked.
        with audio_state.frames_lock:
            audio_state.frames.append(data)
            # Trim by explicit count: a negative-index slice (`[-n:]`) would
            # keep the whole list when the budget is 0.
            excess = len(audio_state.frames) - max_frames
            if excess > 0:
                audio_state.frames = audio_state.frames[excess:]

    print("Done recording")
+
# Saves the currently buffered audio to `filename` as a WAV file.
# recordAudio() may continue running while this takes place: the buffer is
# snapshotted under frames_lock and the file is written from the snapshot.
def saveAudio(audio_state, filename):
    # A shallow copy is enough: the chunks are immutable bytes objects, so
    # only the list itself needs protecting from concurrent appends.
    with audio_state.frames_lock:
        frames = list(audio_state.frames)

    # The context manager closes the file even if a write fails.
    with wave.open(filename, 'wb') as wf:
        wf.setnchannels(audio_state.CHANNELS)
        wf.setsampwidth(audio_state.p.get_sample_size(audio_state.FORMAT))
        wf.setframerate(audio_state.RATE)
        wf.writeframes(b''.join(frames))
+
# Discards all buffered audio so the next transcription starts from scratch.
def resetAudio(audio_state):
    # Swap in a fresh list under the lock so a concurrent recordAudio() or
    # saveAudio() never observes a half-updated buffer.
    with audio_state.frames_lock:
        audio_state.frames = []
+
# Runs whisper over the audio stored in `filename` and returns the decoded
# English text. `model` is a loaded whisper model; the spectrogram is moved
# to the model's device before decoding.
def transcribe(model, filename):
    print("Loading audio")
    waveform = whisper.pad_or_trim(whisper.load_audio(filename))
    spectrogram = whisper.log_mel_spectrogram(waveform)
    decoded = whisper.decode(
        model,
        spectrogram.to(model.device),
        whisper.DecodingOptions(language = "en"),
    )
    transcript = decoded.text
    print("Transcribed text: {}".format(transcript))
    return transcript
+
# Repeatedly transcribes the buffered audio for as long as
# audio_state.transcribe_audio is set, publishing each result in
# audio_state.text.
#
# `filename` is the scratch WAV file used to hand audio to whisper; it is
# overwritten on every pass.
def transcribeAudio(audio_state, model, filename="audio.wav"):
    while audio_state.transcribe_audio:
        print("Saving audio")
        saveAudio(audio_state, filename)

        print("Beginning transcription")
        text = transcribe(model, filename)

        # Publish under the lock; `with` releases it even on error.
        with audio_state.text_lock:
            audio_state.text = text

        # Pace this out
        time.sleep(0.2)
+
# Repeatedly pushes the latest transcription to the board for as long as
# audio_state.send_audio is set. A single OscTxState is kept across
# iterations and handed to sendMessageLazy() on every call.
def sendAudio(audio_state):
    tx_state = osc_ctrl.OscTxState()
    while audio_state.send_audio:
        # Strings are immutable, so grabbing the reference under the lock is
        # a consistent snapshot; no copy is needed.
        with audio_state.text_lock:
            text = audio_state.text

        osc_ctrl.sendMessageLazy(text, tx_state)

        # Pace this out
        time.sleep(0.05)
+
# Entry point: start the recorder, transcriber and sender threads, then read
# stdin. Every line of input resets the audio buffer (i.e. starts a new
# message); a line containing "exit" or "quit" shuts everything down.
if __name__ == "__main__":
    audio_state = getMicStream()

    record_audio_thd = threading.Thread(target = recordAudio, args = [audio_state])
    record_audio_thd.daemon = True
    record_audio_thd.start()

    print("Safe to start talking")

    model = whisper.load_model("base")

    transcribe_audio_thd = threading.Thread(target = transcribeAudio, args = [audio_state, model])
    transcribe_audio_thd.daemon = True
    transcribe_audio_thd.start()

    send_audio_thd = threading.Thread(target = sendAudio, args = [audio_state])
    send_audio_thd.daemon = True
    send_audio_thd.start()

    print("Press enter to start a new message")
    for line in fileinput.input():
        resetAudio(audio_state)
        if "exit" in line or "quit" in line:
            break

    print("Joining threads")
    # Clear all three run flags, including the sender's, so every worker
    # loop exits before we tear down the audio stream.
    audio_state.record_audio = False
    audio_state.transcribe_audio = False
    audio_state.send_audio = False
    record_audio_thd.join()
    transcribe_audio_thd.join()
    send_audio_thd.join()

    # Release the input device only after the recorder has stopped reading.
    audio_state.stream.stop_stream()
    audio_state.stream.close()
    audio_state.p.terminate()
+
|
