From 2fd5771ae4c8b7774b859422eb00216af07ef4fa Mon Sep 17 00:00:00 2001 From: yum Date: Mon, 3 Oct 2022 21:57:51 -0700 Subject: Introduce STT proof-of-concept Using OpenAI's whisper neural network, we can do local STT. Translation quality is good, system resource usage is minimal (1 GB VRAM), latency is much lower than cloud-based translation. * Add transcribe.py * Creates 3 threads: * One saves mic audio to a buffer * One passes mic audio to the STT * One sends the transcribed text to the board * Main thread listens for input. Press enter to start a new message. * Add osc_ctrl.sendMessageLazy, a simple diff-based message sending utility. * A little complexity: it only sends 1 empty cell per call, allowing us to quickly say new things without having to wait for the whole buffer to clear. --- transcribe.py | 175 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 transcribe.py (limited to 'transcribe.py') diff --git a/transcribe.py b/transcribe.py new file mode 100644 index 0000000..4548214 --- /dev/null +++ b/transcribe.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 + +import copy +import fileinput +import os +import osc_ctrl +# python3 -m pip install pyaudio +import pyaudio +import threading +import time +import wave +# python3 -m pip install git+https://github.com/openai/whisper.git +# python3 -m pip install torch -f https://download.pytorch.org/whl/torch_stable.html +import whisper + +class AudioState: + CHUNK = 1024 + FORMAT = pyaudio.paInt16 + CHANNELS = 1 + # This matches the framerate expected by whisper. + RATE = 16000 + + # The maximum length that recordAudio() will put into frames before it + # starts dropping from the start. + MAX_LENGTH_S = 30 + + # PyAudio object + p = None + + # PyAudio stream object + stream = None + + frames = [] + frames_lock = threading.Lock() + + text = "" + text_lock = threading.Lock() + + record_audio = True + transcribe_audio = True + send_audio = True + +def getMicStream(): + audio_state = AudioState() + audio_state.p = pyaudio.PyAudio() + + info = audio_state.p.get_host_api_info_by_index(0) + numdevices = info.get('deviceCount') + + print("Finding index mic...") + got_match = False + device_index = -1 + while got_match == False: + for i in range(0, numdevices): + if (audio_state.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: + device_name = audio_state.p.get_device_info_by_host_api_device_index(0, i).get('name') + #print("Input Device id ", i, " - ", device_name) + if "Digital Audio Interface" in device_name: + print("Got match: {}".format(device_name)) + device_index = i + got_match = True + if got_match == False: + print("No match, sleeping") + time.sleep(3) + + audio_state.stream = audio_state.p.open(format=audio_state.FORMAT, + channels=audio_state.CHANNELS, rate=audio_state.RATE, + input=True, frames_per_buffer=audio_state.CHUNK, + input_device_index=device_index) + + return audio_state + +# Continuously records audio as long as audio_state.record_audio is set. +def recordAudio(audio_state): + print("Recording audio") + while audio_state.record_audio: + data = audio_state.stream.read(audio_state.CHUNK) + + audio_state.frames_lock.acquire() + audio_state.frames.append(data) + max_frames = int(audio_state.RATE * audio_state.MAX_LENGTH_S / audio_state.CHUNK) + if len(audio_state.frames) > max_frames: + audio_state.frames = audio_state.frames[-1 * max_frames :] + audio_state.frames_lock.release() + + print("Done recording") + +# Saves audio. recordAudio() may continue running while this takes place. +def saveAudio(audio_state, filename): + wf = wave.open(filename, 'wb') + wf.setnchannels(audio_state.CHANNELS) + wf.setsampwidth(audio_state.p.get_sample_size(audio_state.FORMAT)) + wf.setframerate(audio_state.RATE) + + audio_state.frames_lock.acquire() + frames = copy.deepcopy(audio_state.frames) + audio_state.frames_lock.release() + + wf.writeframes(b''.join(frames)) + wf.close() + +def resetAudio(audio_state): + audio_state.frames_lock.acquire() + audio_state.frames = [] + audio_state.frames_lock.release() + +# Transcribe the audio recorded in a file. +def transcribe(model, filename): + print("Loading audio") + audio = whisper.load_audio(filename) + audio = whisper.pad_or_trim(audio) + mel = whisper.log_mel_spectrogram(audio).to(model.device) + options = whisper.DecodingOptions(language = "en") + result = whisper.decode(model, mel, options) + print("Transcribed text: {}".format(result.text)) + return result.text + +def transcribeAudio(audio_state, model): + while audio_state.transcribe_audio == True: + print("Saving audio") + saveAudio(audio_state, "audio.wav") + + print("Beginning transcription") + text = transcribe(model, "audio.wav") + + audio_state.text_lock.acquire() + audio_state.text = text + audio_state.text_lock.release() + + # Pace this out + time.sleep(0.2) + +def sendAudio(audio_state): + tx_state = osc_ctrl.OscTxState() + while audio_state.send_audio == True: + audio_state.text_lock.acquire() + text = copy.deepcopy(audio_state.text) + audio_state.text_lock.release() + + osc_ctrl.sendMessageLazy(text, tx_state) + + # Pace this out + time.sleep(0.05) + +if __name__ == "__main__": + audio_state = getMicStream() + + record_audio_thd = threading.Thread(target = recordAudio, args = [audio_state]) + record_audio_thd.daemon = True + record_audio_thd.start() + + print("Safe to start talking") + + model = whisper.load_model("base") + + transcribe_audio_thd = threading.Thread(target = transcribeAudio, args = [audio_state, model]) + transcribe_audio_thd.daemon = True + transcribe_audio_thd.start() + + send_audio_thd = threading.Thread(target = sendAudio, args = [audio_state]) + send_audio_thd.daemon = True + send_audio_thd.start() + + print("Press enter to start a new message") + for line in fileinput.input(): + resetAudio(audio_state) + if "exit" in line or "quit" in line: + break + + print("Joining threads") + audio_state.record_audio = False + audio_state.transcribe_audio = False + record_audio_thd.join() + transcribe_audio_thd.join() + -- cgit v1.2.3