diff options
| author | yum <yum.food.vr@gmail.com> | 2022-10-15 18:14:40 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2022-10-15 18:14:40 -0700 |
| commit | eba68f4fb35078327b75e99c25100ec1154efb13 (patch) | |
| tree | ad39f825de5e4db1feb1ec55d09b83487c40c86f | |
| parent | 3059967c75a41ed79b40ed3f84adbb874b0c3a33 (diff) | |
Tweak transcribe.py
Slightly improve temporal stability and responsiveness at the cost of
limiting to a 30 second recording.
Before committing to a transcription, wait for two consecutive
transcriptions such that they are identical, or the former is a
prefix of the latter. This helps with temporal stability by eliminating
most one-off wildly inaccurate transcriptions.
Also make osc_ctrl.sendMessageLazy a little lazier, limiting it to 2
consecutive non-empty cells per call. This allows us to recover from
mistranscriptions faster.
| -rw-r--r-- | osc_ctrl.py | 8 | ||||
| -rw-r--r-- | transcribe.py | 36 |
2 files changed, 35 insertions, 9 deletions
diff --git a/osc_ctrl.py b/osc_ctrl.py index 40cb7d1..0352e6f 100644 --- a/osc_ctrl.py +++ b/osc_ctrl.py @@ -161,6 +161,7 @@ class OscTxState: # The message last sent to the board. last_msg_encoded = [] empty_cells_to_send_per_call = 1 + nonempty_cells_to_send_per_call = 2 # 0 indicates it's closed. 1 indicates half size. 2 indicates full size. board_size = 0 @@ -246,6 +247,7 @@ def sendMessageLazy(client, msg, tx_state): msg_encoded_len = len(msg_encoded) empty_cells_sent = 0 + nonempty_cells_sent = 0 n_cells = ceil(msg_encoded_len / NUM_LAYERS) for cell in range(0, n_cells): cell_begin = cell * NUM_LAYERS @@ -272,6 +274,12 @@ def sendMessageLazy(client, msg, tx_state): tx_state.last_msg_encoded = msg_encoded[0:cell_end] return False empty_cells_sent += 1 + else: + if nonempty_cells_sent >= tx_state.nonempty_cells_to_send_per_call: + print("nonempty cell budget exceeded") + tx_state.last_msg_encoded = msg_encoded[0:cell_end] + return False + nonempty_cells_sent += 1 sendMessageCellDiscrete(client, cell_msg, cell) diff --git a/transcribe.py b/transcribe.py index dc36541..8369c43 100644 --- a/transcribe.py +++ b/transcribe.py @@ -12,8 +12,7 @@ import time import wave # python3 -m pip install git+https://github.com/openai/whisper.git # python3 -m pip install torch -f https://download.pytorch.org/whl/torch_stable.html -from whisper import transcribe as whisper_transcribe -from whisper import load_model as whisper_load_model +import whisper class AudioState: CHUNK = 1024 @@ -24,9 +23,9 @@ class AudioState: # The maximum length that recordAudio() will put into frames before it # starts dropping from the start. - MAX_LENGTH_S = 90 + MAX_LENGTH_S = 25 # The minimum length that recordAudio() will wait for before saving audio. - MIN_LENGTH_S = 3 + MIN_LENGTH_S = 1 # PyAudio object p = None @@ -38,6 +37,9 @@ class AudioState: frames_lock = threading.Lock() text = "" + # To improve temporal stability, we require two consecutive identical + # transcriptions before "committing" to a transcription. + text_candidate = "" text_lock = threading.Lock() record_audio = True @@ -118,6 +120,8 @@ def saveAudio(audio_state, filename): wf.writeframes(b''.join(frames)) wf.close() + print("audio save") + def resetAudio(audio_state): audio_state.frames_lock.acquire() audio_state.frames = [] @@ -125,8 +129,16 @@ def resetAudio(audio_state): # Transcribe the audio recorded in a file. def transcribe(model, filename): - result = whisper_transcribe(model=model, audio=filename, language="en") - return result["text"] + + audio = whisper.load_audio(filename) + audio = whisper.pad_or_trim(audio) + mel = whisper.log_mel_spectrogram(audio).to(model.device) + _, probs = model.detect_language(mel) + print(f"Detected language: {max(probs, key=probs.get)}") + options = whisper.DecodingOptions() + result = whisper.decode(model, mel, options) + + return result.text def transcribeAudio(audio_state, model): while audio_state.transcribe_audio == True: @@ -140,13 +152,18 @@ def transcribeAudio(audio_state, model): text = transcribe(model, "audio.wav") audio_state.text_lock.acquire() - audio_state.text = text + + if text == audio_state.text_candidate or text.startswith(audio_state.text_candidate): + audio_state.text = text + audio_state.text_candidate = text + audio_state.text_lock.release() print("Transcription: {}".format(audio_state.text)) + print("Candidate: {}".format(audio_state.text_candidate)) # Pace this out - time.sleep(0.2) + time.sleep(0.05) def sendAudio(audio_state): tx_state = osc_ctrl.OscTxState() @@ -155,6 +172,7 @@ def sendAudio(audio_state): text = copy.deepcopy(audio_state.text) audio_state.text_lock.release() + print("here") osc_ctrl.sendMessageLazy(audio_state.osc_client, text, tx_state) # Pace this out @@ -179,7 +197,7 @@ if __name__ == "__main__": print("Safe to start talking") - model = whisper_load_model("base") + model = whisper.load_model("base") transcribe_audio_thd = threading.Thread(target = transcribeAudio, args = [audio_state, model]) transcribe_audio_thd.daemon = True |
