diff options
Diffstat (limited to 'Scripts')
| -rw-r--r-- | Scripts/transcribe.py | 849 | ||||
| -rw-r--r-- | Scripts/vad.py | 315 |
2 files changed, 315 insertions, 849 deletions
diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py deleted file mode 100644 index 7098400..0000000 --- a/Scripts/transcribe.py +++ /dev/null @@ -1,849 +0,0 @@ -#!/usr/bin/env python3 - -from datetime import datetime -from emotes_v2 import EmotesState -from faster_whisper import WhisperModel -from functools import partial -from math import ceil, floor -from profanity_filter import ProfanityFilter -from sentence_splitter import split_text_into_sentences - -import argparse -import app_config -import copy -import ctranslate2 -import editdistance -import generate_utils -import keybind_event_machine -import keyboard -import lang_compat -import langcodes -import numpy as np -import os -import osc_ctrl -import pyaudio -import steamvr -import subprocess -import sys -import threading -import time -import transformers -import typing -import wave -import winsound - -class AudioState: - def __init__(self): - self.CHUNK = 1024 - self.FORMAT = pyaudio.paInt16 - self.CHANNELS = 1 - # This matches the framerate expected by whisper. - self.RATE = 16000 - - # The maximum length that recordAudio() will put into frames before it - # starts dropping from the start. - self.MAX_LENGTH_S = 300 - # The minimum length that recordAudio() will wait for before saving audio. - self.MIN_LENGTH_S = 1 - - # PyAudio object - self.p = None - - # PyAudio stream object - self.stream = None - - self.preview_text = "" - self.text = "" - self.filtered_text = "" - - # The edit distance under which two consecutive transcripts are - # considered to match. This affects how easily `preview_text` - # gets appended to `text`. - self.commit_fuzz_threshold = 1 - - # If set, profanity in transcriptions will have their vowels replaced - # with asterisks. Only works in English. - self.profanity_filter: ProfanityFilter = None - - # List of: - # List of tuples of: - # Segment start time, end time, and text - self.ranges_ls = [] - self.frames = [] - self.drop_samples_till_i = -1 - - # Locks access to `text`. 
- self.transcribe_lock = threading.Lock() - - # Locks access to `frames`, and audio stored on disk. - self.audio_lock = threading.Lock() - - # Audio events that should play. Input thread appends to this list, - # audio feedback thread drains it. - self.audio_events = [] - self.AUDIO_EVENT_TOGGLE_ON = 1 - self.AUDIO_EVENT_TOGGLE_OFF = 2 - self.AUDIO_EVENT_DISMISS = 3 - self.AUDIO_EVENT_UPDATE = 4 - - # Used to tell the threads when to stop. - self.run_app = True - - self.transcribe_sleep_duration_min_s = 0.05 - self.transcribe_sleep_duration_max_s = 5.00 - self.transcribe_no_change_count = 0 - self.transcribe_sleep_duration = self.transcribe_sleep_duration_min_s - - # The transcription thread transcribes without holding locks, then - # blocks on it. Thus we need some way to tell the transcription - # thread to drop that transcription. - self.drop_transcription = False - - # The language the user is speaking in. Default is English but user may set - # this to whatever they want. - self.language = "english" - - self.audio_paused = False - - self.osc_state = osc_ctrl.OscState(generate_utils.config.CHARS_PER_SYNC, - generate_utils.config.BOARD_ROWS, - generate_utils.config.BOARD_COLS) - - def sleepInterruptible(self, dur_s, stride_ms = 5): - timeout = time.time() + dur_s - while self.audio_paused and self.run_app and time.time() < timeout: - time.sleep(stride_ms / 1000.0) - -def dumpMicDevices(): - p = pyaudio.PyAudio() - info = p.get_host_api_info_by_index(0) - numdevices = info.get('deviceCount') - - for i in range(0, numdevices): - if (p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: - device_name = p.get_device_info_by_host_api_device_index(0, i).get('name') - print("Input Device id ", i, " - ", device_name) - -def onAudioFramesAvailable( - audio_state, - input_rate, - frames, - frame_count, - time_info, - status_flags): - # Reduce sample rate from mic rate to Whisper rate by dropping frames. 
- decimated = b'' - frame_len = int(len(frames) / frame_count) - next_frame = 0.0 - keep_every = float(input_rate) / audio_state.RATE - #print(f"Keep every {keep_every}th frame") - #print(f"len frames: {len(frames)}") - #print(f"len decimated: {len(decimated)}") - i = 0 - for i in range(0, frame_count): - if i >= next_frame: - decimated += frames[i*frame_len:(i+1)*frame_len] - next_frame += keep_every - i += 1 - - if not audio_state.audio_paused: - audio_state.frames.append(decimated) - - # If buffer is getting long, tell the transcription loop to be more ready - # to accept transcripts. - fps = int(input_rate / audio_state.CHUNK) - cur_len_s = len(audio_state.frames) / fps - double_at_s = 3.0 - double_every_s = 1.5 - delta_s = cur_len_s - double_at_s - n_doubles = ceil(delta_s / double_every_s) - if n_doubles >= 1: - audio_state.commit_fuzz_threshold = 2 ** n_doubles - else: - audio_state.commit_fuzz_threshold = 1 - - if audio_state.drop_samples_till_i > 0: - # Caller wants us to keep this many *whisper* samples, assuming that - # we're getting one full frame every (1024 / 16KHz) seconds. - # However we really get one full whisper frame a little slower, since - # mics usually have a higher sample rate than 16 KHz (see decimation - # code above). - # The ratio of (mic sample rate) / (16KHz) is simply `keep_every`. 
- n_frames_to_drop = float(audio_state.drop_samples_till_i) / audio_state.CHUNK - n_frames_to_drop *= keep_every - n_frames_to_drop_int = int(floor(n_frames_to_drop)) - if audio_state.cfg["enable_debug_mode"]: - print(f"Dropping {n_frames_to_drop_int} frames, buffer has {len(audio_state.frames)} frames total") - # First drop every whole chunk - audio_state.frames = audio_state.frames[n_frames_to_drop_int:] - # Then drop the part of the most recent chunk we no longer want - if len(audio_state.frames) > 0: - n_samples_to_drop = int(ceil((n_frames_to_drop % 1.0) * audio_state.CHUNK / keep_every)) - if audio_state.cfg["enable_debug_mode"]: - print(f"Zeroing {n_samples_to_drop} samples in frame 0") - print(f"Frame 0 has length {len(audio_state.frames[0])}") - bytes_per_sample = 2 - audio_state.frames[0] = b'00' * n_samples_to_drop + audio_state.frames[0][n_samples_to_drop * bytes_per_sample:] - audio_state.drop_samples_till_i = -1 - - max_frames = int(input_rate * audio_state.MAX_LENGTH_S / - audio_state.CHUNK) - if len(audio_state.frames) > max_frames: - audio_state.frames = audio_state.frames[-1 * max_frames:] - - # Now enforce a minimum duration on frames. This reduces cases where the - # STT hallucinates random things. In the Whisper paper, they enforce a - # minimum audio buffer duration of 5.0 seconds, so I do the same here. 
- empty_chunk = b'00' * int(ceil(audio_state.CHUNK / keep_every)) - chunk_duration_s = float(audio_state.CHUNK) / audio_state.RATE - cur_duration_s = len(audio_state.frames) * chunk_duration_s - desired_min_duration_s = 5.0 - delta_duration_s = desired_min_duration_s - cur_duration_s - if delta_duration_s > 0: - delta_chunks = int(ceil(delta_duration_s / chunk_duration_s)) - if audio_state.cfg["enable_debug_mode"]: - print(f"Padding with {delta_duration_s} seconds ({delta_chunks} chunks) of silence") - print(f"Each chunk has {len(empty_chunk)} samples") - audio_state.frames = [empty_chunk] * delta_chunks + audio_state.frames - - return (frames, pyaudio.paContinue) - -def getMicStream(which_mic) -> AudioState: - audio_state = AudioState() - audio_state.p = pyaudio.PyAudio() - - print("Finding mic {}...".format(which_mic)) - dumpMicDevices() - got_match = False - device_index = -1 - focusrite_str = "Focusrite" - index_str = "Digital Audio Interface" - if which_mic == "index": - target_str = index_str - elif which_mic == "focusrite": - target_str = focusrite_str - else: - print("Mic {} requested, treating it as a numerical device ID".format(which_mic)) - device_index = int(which_mic) - got_match = True - - while got_match == False: - info = audio_state.p.get_host_api_info_by_index(0) - numdevices = info.get('deviceCount') - for i in range(0, numdevices): - if (audio_state.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: - device_name = audio_state.p.get_device_info_by_host_api_device_index(0, i).get('name') - if target_str in device_name: - print("Got match: {}".format(device_name)) - device_index = i - got_match = True - break - if got_match == False: - print("No match, sleeping") - time.sleep(3) - - info = audio_state.p.get_device_info_by_host_api_device_index(0, device_index) - input_rate = int(info['defaultSampleRate']) - print("input rate: {}".format(input_rate)) - - # Bind audio_state to onAudioFramesAvailable - callback = 
partial(onAudioFramesAvailable, audio_state, input_rate) - - audio_state.stream = audio_state.p.open( - rate=input_rate, - channels=audio_state.CHANNELS, - format=audio_state.FORMAT, - input=True, frames_per_buffer=audio_state.CHUNK, - input_device_index=device_index, - stream_callback=callback) - - audio_state.stream.start_stream() - - return audio_state - -def resetAudioLocked(audio_state): - audio_state.frames = [] - audio_state.transcribe_no_change_count = 0 - audio_state.transcribe_sleep_duration = \ - audio_state.transcribe_sleep_duration_min_s - - if audio_state.cfg["reset_on_toggle"]: - if audio_state.cfg["enable_debug_mode"]: - print("resetAudioLocked resetting text") - audio_state.text = "" - audio_state.preview_text = "" - audio_state.filtered_text = "" - -def resetDisplayLocked(audio_state): - osc_ctrl.clear(audio_state.osc_state) - -# Transcribe the audio recorded in a file. -# Returns two strings: committed text, and preview text. -# Committed text is temporally stable. Preview text is *not* temporally stable, -# but is lower latency than committed text. -def transcribe(audio_state, model, frames, use_cpu: bool) -> typing.Tuple[str,str]: - start_time = time.time() - - frames = audio_state.frames - # Convert from signed 16-bit int [-32768, 32767] to signed 16-bit float on - # [-1, 1]. - # We should technically acquire a lock to protect frames, but this is - # really slow and in practice it doesn't make the app crash, so who cares. 
- frames = np.asarray(audio_state.frames) - audio = np.frombuffer(frames, np.int16).flatten().astype(np.float32) / 32768.0 - - segments, info = model.transcribe( - audio, - beam_size = 5, - language = langcodes.find(audio_state.cfg["language"]).language, - temperature = 0.0, - log_prob_threshold = -0.8, - vad_filter = True, - condition_on_previous_text = True, - without_timestamps = False) - ranges = [] - for s in segments: - if s.avg_logprob < -0.8 or s.no_speech_prob > 0.6: - continue - if audio_state.cfg["enable_debug_mode"]: - print(f"Segment: {s}") - ranges.append((s.start, s.end, s.text)) - audio_state.ranges_ls.append(ranges) - - committed_text = "" - if True: - # Tuple of (start time, end time, transcript) - first_segments = [] - for ranges in audio_state.ranges_ls: - for segment in ranges: - first_segments.append(segment) - break - if len(first_segments) >= 4: - # Hack: require convergence across many frames to give the - # algorithm a longer buffer to work with. - c0 = first_segments[-1] - c1 = first_segments[-2] - c2 = first_segments[-3] - c3 = first_segments[-4] - - c0_c1_d = editdistance.eval(c0[2], c1[2]) - c1_c2_d = editdistance.eval(c1[2], c2[2]) - c2_c3_d = editdistance.eval(c2[2], c3[2]) - - max_edit = audio_state.commit_fuzz_threshold - - if audio_state.cfg["enable_debug_mode"]: - print(f"c0: {c0}, c1: {c1}, c2: {c2}, c3: {c3}") - if c0_c1_d < max_edit and c1_c2_d < max_edit and c2_c3_d < max_edit: - # For simplicity, completely reset saved audio ranges. - audio_state.ranges_ls = [] - committed_text = c0[2] - if audio_state.cfg["enable_debug_mode"]: - print(f"Dropping frames until {c0[1]}") - n_samples_to_drop = int(ceil(audio_state.RATE * c0[1])) - audio_state.drop_samples_till_i = n_samples_to_drop - while audio_state.drop_samples_till_i == n_samples_to_drop: - # To prevent a race, wait until those audio samples are - # dropped by the microphone capture thread before returning. 
- time.sleep(.001) - - preview_text = "" - for seg in ranges: - if seg[2] == committed_text: - continue - preview_text += seg[2] - - return (committed_text, preview_text) - -def transcribeAudio(audio_state): - print("Ready!") - last_transcribe_time = time.time() - while audio_state.run_app == True: - # Pace this out. - # If `preview_text` is not empty, then we're still transcribing a - # message, so don't enter the idle path. - if audio_state.audio_paused and len(audio_state.preview_text) == 0: - audio_state.sleepInterruptible(audio_state.transcribe_sleep_duration) - - audio_state.transcribe_no_change_count += 1 - # Increase sleep time. Code below will set sleep time back to minimum - # if a change is detected. - longer_sleep_dur = audio_state.transcribe_sleep_duration - longer_sleep_dur += audio_state.transcribe_sleep_duration_min_s * (1.3**audio_state.transcribe_no_change_count) - audio_state.transcribe_sleep_duration = min( - 1000 * 1000, - longer_sleep_dur) - - text, preview_text = transcribe(audio_state, audio_state.cfg["model"], audio_state.frames, - audio_state.cfg["use_cpu"]) - if len(text) == 0 and len(preview_text) == 0: - if audio_state.cfg["enable_debug_mode"]: - print("no transcription, spin ({} seconds)".format(time.time() - last_transcribe_time)) - last_transcribe_time = time.time() - # Prevent audio buffer from holding more than a few seconds of silence - # before real speech. 
- audio_state.MAX_LENGTH_S = 5 - continue - else: - audio_state.MAX_LENGTH_S = 300 - - if audio_state.drop_transcription: - audio_state.drop_transcription = False - audio_state.text = "" - audio_state.preview_text = "" - audio_state.filtered_text = "" - if audio_state.cfg["enable_debug_mode"]: - print("drop transcription ({} seconds)".format(time.time() - last_transcribe_time)) - last_transcribe_time = time.time() - continue - - old_text = audio_state.text - audio_state.text += text - audio_state.preview_text = preview_text - - if len(preview_text) == 0: - print("Finalized: 1") - else: - print("Finalized: 0") - - # Hard cap transcript at 4096 chars. Letting it grow longer than this - # eventually causes lag. This happens routinely when streaming. Capping - # like this does not affect the visible portion of the transcript in - # OBS, but it might affect the visible portion in-game. (Don't make - # your friends read more than 4k characters on a fucking chatbox.) - audio_state.text = audio_state.text[-4096:] - - now = time.time() - if audio_state.cfg["enable_debug_mode"]: - print("Raw transcription ({} seconds): {}".format( - now - last_transcribe_time, - audio_state.text + audio_state.preview_text)) - last_transcribe_time = now - print(f"Commit text: {text}") - print(f"Preview text: {preview_text}") - - # Translate if requested. 
- translated = audio_state.text + audio_state.preview_text - if audio_state.language_target: - whisper_lang = audio_state.cfg["language"] - nllb_lang = lang_compat.whisper_to_nllb[whisper_lang] - ss_lang = lang_compat.nllb_to_ss[nllb_lang] - sentences = split_text_into_sentences(translated, language=ss_lang) - - translated_sentences = [] - for sentence in sentences: - source = audio_state.tokenizer.convert_ids_to_tokens(audio_state.tokenizer.encode(sentence)) - target_prefix = [audio_state.language_target] - results = audio_state.translator.translate_batch([source], target_prefix=[target_prefix]) - target = results[0].hypotheses[0][1:] - translated_sentence = audio_state.tokenizer.decode(audio_state.tokenizer.convert_tokens_to_ids(target)) - translated_sentences.append(translated_sentence) - translated = " ".join(translated_sentences) - print(f"Translation: {translated}") - - # Apply filters to transcription - filtered_text = translated - if audio_state.cfg["enable_uwu_filter"]: - uwu_proc = subprocess.Popen(["Resources/Models/Uwwwu.exe", filtered_text], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - uwu_stdout, uwu_stderr = uwu_proc.communicate() - uwu_text = uwu_stdout.decode("utf-8") - uwu_text = uwu_text.replace("\n", "") - uwu_text = uwu_text.replace("\r", "") - filtered_text = uwu_text - if audio_state.cfg["remove_trailing_period"]: - if len(filtered_text) > 0 and filtered_text[-1] == '.' 
and not filtered_text.endswith("..."): - filtered_text = filtered_text[0:len(filtered_text)-1] - if audio_state.cfg["enable_uppercase_filter"]: - filtered_text = filtered_text.upper() - if audio_state.cfg["enable_lowercase_filter"]: - filtered_text = filtered_text.lower() - if audio_state.cfg["enable_profanity_filter"]: - filtered_text = audio_state.profanity_filter.filter(filtered_text) - audio_state.filtered_text = filtered_text - - now = time.time() - print("Transcription ({} seconds): {}".format( - now - last_transcribe_time, - filtered_text)) - last_transcribe_time = now - - if old_text != audio_state.text + audio_state.preview_text: - # We think the user said something, so reset the amount of - # time we sleep between transcriptions to the minimum. - audio_state.transcribe_no_change_count = 0 - audio_state.transcribe_sleep_duration = audio_state.transcribe_sleep_duration_min_s - -def sendAudio(audio_state): - estate = EmotesState() - while audio_state.run_app == True: - text = audio_state.filtered_text - if audio_state.cfg["use_builtin"]: - ret = osc_ctrl.pageMessageBuiltin(audio_state.osc_state, text) - time.sleep(1.5) - else: - ret = osc_ctrl.pageMessage(audio_state.osc_state, text, estate) - is_paging = (ret == False) - - # Pace this out - time.sleep(0.01) - -def readKeyboardInput(audio_state): - machine = keybind_event_machine.KeybindEventMachine(audio_state.cfg["keybind"]) - last_press_time = 0 - - # double pressing the keybind - double_press_timeout = 0.5 - - RECORD_STATE = 0 - PAUSE_STATE = 1 - state = PAUSE_STATE - - while audio_state.run_app == True: - time.sleep(0.05) - - cur_press_time = machine.getNextPressTime() - if cur_press_time == 0: - continue - - EVENT_SINGLE_PRESS = 0 - EVENT_DOUBLE_PRESS = 1 - if last_press_time == 0: - event = EVENT_SINGLE_PRESS - elif cur_press_time - last_press_time < double_press_timeout: - event = EVENT_DOUBLE_PRESS - else: - event = EVENT_SINGLE_PRESS - last_press_time = cur_press_time - - if event == 
EVENT_DOUBLE_PRESS: - state = PAUSE_STATE - if not audio_state.cfg["use_builtin"]: - osc_ctrl.toggleBoard(audio_state.osc_state.client, False) - - if audio_state.cfg["reset_on_toggle"]: - if audio_state.cfg["enable_debug_mode"]: - print("Toggle detected, dropping transcript (1)") - audio_state.drop_transcription = True - else: - if audio_state.cfg["enable_debug_mode"]: - print("Toggle detected, committing preview text (1)") - audio_state.text += audio_state.preview_text - audio_state.audio_paused = True - resetAudioLocked(audio_state) - resetDisplayLocked(audio_state) - continue - - # Short hold - if state == RECORD_STATE: - state = PAUSE_STATE - if not audio_state.cfg["use_builtin"]: - osc_ctrl.lockWorld(audio_state.osc_state.client, True) - audio_state.transcribe_sleep_duration = audio_state.transcribe_sleep_duration_min_s - - audio_state.audio_paused = True - - if audio_state.cfg["enable_local_beep"]: - audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_OFF) - elif state == PAUSE_STATE: - state = RECORD_STATE - if not audio_state.cfg["use_builtin"]: - osc_ctrl.toggleBoard(audio_state.osc_state.client, True) - osc_ctrl.lockWorld(audio_state.osc_state.client, False) - osc_ctrl.ellipsis(audio_state.osc_state.client, True) - if audio_state.cfg["reset_on_toggle"]: - if audio_state.cfg["enable_debug_mode"]: - print("Toggle detected, dropping transcript (2)") - audio_state.drop_transcription = True - else: - if audio_state.cfg["enable_debug_mode"]: - print("Toggle detected, committing preview text (2)") - audio_state.text += audio_state.preview_text - audio_state.audio_paused = False - - resetAudioLocked(audio_state) - resetDisplayLocked(audio_state) - - if audio_state.cfg["enable_local_beep"]: - audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_ON) - - -def audioFeedbackThread(audio_state): - with open(os.path.abspath("Resources/Sounds/Noise_On_Quiet.wav"), "rb") as f: - waveform0 = f.read() - with 
open(os.path.abspath("Resources/Sounds/Noise_Off_Quiet.wav"), "rb") as f: - waveform1 = f.read() - with open(os.path.abspath("Resources/Sounds/Dismiss_Noise_Quiet.wav"), "rb") as f: - waveform2 = f.read() - with open(os.path.abspath("Resources/Sounds/KB_Noise_Off_Quiet.wav"), "rb") as f: - waveform3 = f.read() - while audio_state.run_app == True: - time.sleep(0.01) - - if len(audio_state.audio_events) == 0: - continue - - event = audio_state.audio_events[0] - audio_state.audio_events = audio_state.audio_events[1:] - - waveform = waveform0 - if event == audio_state.AUDIO_EVENT_TOGGLE_ON: - waveform = waveform0 - elif event == audio_state.AUDIO_EVENT_TOGGLE_OFF: - waveform = waveform1 - elif event == audio_state.AUDIO_EVENT_DISMISS: - waveform = waveform2 - elif event == audio_state.AUDIO_EVENT_UPDATE: - waveform = waveform3 - winsound.PlaySound(waveform, winsound.SND_MEMORY) - -def readControllerInput(audio_state): - RECORD_STATE = 0 - PAUSE_STATE = 1 - state = PAUSE_STATE - - hand_id = audio_state.cfg["button"].split()[0] - button_id = audio_state.cfg["button"].split()[1] - - # Rough description of state machine: - # Single short press: toggle transcription - # Medium press: dismiss custom chatbox - # Long press: update chatbox in place - # Medium press + long press: type transcription - - last_rising = time.time() - last_medium_press_end = 0 - - button_generator = steamvr.pollButtonPress(hand=hand_id, button=button_id) - while audio_state.run_app == True: - time.sleep(0.01) - event = next(button_generator) - - if event.opcode == steamvr.EVENT_RISING_EDGE: - last_rising = time.time() - - if state == PAUSE_STATE: - resetAudioLocked(audio_state) - resetDisplayLocked(audio_state) - audio_state.drop_transcription = True - audio_state.audio_paused = False - - elif event.opcode == steamvr.EVENT_FALLING_EDGE: - now = time.time() - if now - last_rising > 1.5: - # Long press: treat as the end of transcription. 
- state = PAUSE_STATE - if not audio_state.cfg["use_builtin"]: - osc_ctrl.lockWorld(audio_state.osc_state.client, True) - audio_state.transcribe_sleep_duration = audio_state.transcribe_sleep_duration_min_s - audio_state.audio_paused = True - - if last_rising - last_medium_press_end < 1.0: - # Type transcription - if audio_state.cfg["enable_local_beep"]: - audio_state.audio_events.append(audio_state.AUDIO_EVENT_UPDATE) - keyboard.write(audio_state.filtered_text) - else: - if audio_state.cfg["enable_local_beep"]: - audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_OFF) - - elif now - last_rising > 0.5: - # Medium press - last_medium_press_end = now - state = PAUSE_STATE - - if audio_state.cfg["enable_local_beep"]: - audio_state.audio_events.append(audio_state.AUDIO_EVENT_DISMISS) - - if not audio_state.cfg["use_builtin"]: - osc_ctrl.toggleBoard(audio_state.osc_state.client, False) - - resetAudioLocked(audio_state) - resetDisplayLocked(audio_state) - audio_state.drop_transcription = True - audio_state.audio_paused = True - else: - # Short hold - if state == RECORD_STATE: - state = PAUSE_STATE - if not audio_state.cfg["use_builtin"]: - osc_ctrl.lockWorld(audio_state.osc_state.client, True) - audio_state.transcribe_sleep_duration = audio_state.transcribe_sleep_duration_min_s - - audio_state.audio_paused = True - - if audio_state.cfg["enable_local_beep"]: - audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_OFF) - elif state == PAUSE_STATE: - state = RECORD_STATE - if not audio_state.cfg["use_builtin"]: - osc_ctrl.toggleBoard(audio_state.osc_state.client, True) - osc_ctrl.lockWorld(audio_state.osc_state.client, False) - osc_ctrl.ellipsis(audio_state.osc_state.client, True) - if audio_state.cfg["reset_on_toggle"]: - if audio_state.cfg["enable_debug_mode"]: - print("Toggle detected, dropping transcript (3)") - audio_state.drop_transcription = True - else: - if audio_state.cfg["enable_debug_mode"]: - print("Toggle detected, committing preview text 
(3)") - audio_state.text += audio_state.preview_text - - resetAudioLocked(audio_state) - resetDisplayLocked(audio_state) - - if audio_state.cfg["enable_local_beep"]: - audio_state.audio_events.append(audio_state.AUDIO_EVENT_TOGGLE_ON) - -# model should correspond to one of the Whisper models defined in -# whisper/__init__.py. Examples: tiny, base, small, medium. -def transcribeLoop(config_path: str): - cfg = app_config.getConfig(config_path) - - generate_utils.config.BYTES_PER_CHAR = int(cfg["bytes_per_char"]) - generate_utils.config.CHARS_PER_SYNC = int(cfg["chars_per_sync"]) - generate_utils.config.BOARD_ROWS = int(cfg["rows"]) - generate_utils.config.BOARD_COLS = int(cfg["cols"]) - - audio_state = getMicStream(cfg["microphone"]) - audio_state.cfg = cfg - - # Set up profanity filter - en_profanity_path = os.path.abspath("Resources/Profanity/en") - audio_state.profanity_filter = ProfanityFilter(en_profanity_path) - if cfg["enable_profanity_filter"]: - audio_state.profanity_filter.load() - - lang_bits = cfg["language_target"].split(" | ") - if len(lang_bits) == 2: - lang_code = lang_bits[1] - audio_state.language_target = lang_code - else: - audio_state.language_target = None - - if audio_state.language_target: - print("Translation requested") - - print("Installing torch and sentencepiece in virtual environment. 
" - "Nothing will print " - "for several minutes while these download (~2.4 GB).") - pip_proc = subprocess.Popen( - "Resources/Python/python.exe -m pip install sentencepiece torch".split(), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - pip_stdout, pip_stderr = pip_proc.communicate() - pip_stdout = pip_stdout.decode("utf-8") - pip_stderr = pip_stderr.decode("utf-8") - print(pip_stdout) - print(pip_stderr) - if pip_proc.returncode != 0: - print(f"Failed to set up for translation: `pip install torch` " - "exited with {pip_proc.returncode}") - - output_dir = "Resources/" + cfg["model_translation"] - # Provided by ctranslate2 Python package - cmd = "ct2-transformers-converter.exe --model facebook/" + \ - cfg["model_translation"] + " --output_dir " + output_dir - - print(f"Fetching translation algorithm ({cfg['model_translation']})") - if not os.path.exists(output_dir): - ct2_proc = subprocess.Popen( - cmd.split(), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - ct2_stdout, ct2_stderr = ct2_proc.communicate() - ct2_stdout = ct2_stdout.decode("utf-8") - ct2_stderr = ct2_stderr.decode("utf-8") - print(ct2_stdout) - print(ct2_stderr) - if ct2_proc.returncode != 0: - print(f"Failed to get NLLB model: ct2 process exited with " - "{ct2_proc.returncode}") - print(f"Using model at {output_dir}") - - audio_state.translator = ctranslate2.Translator(output_dir) - - whisper_lang = cfg["language"] - nllb_lang = lang_compat.whisper_to_nllb[whisper_lang] - - audio_state.tokenizer = transformers.AutoTokenizer.from_pretrained( - "facebook/" + cfg["model_translation"], - src_lang=nllb_lang) - - print(f"Translation ready to go") - - abspath = os.path.abspath(__file__) - dname = os.path.dirname(abspath) - model_root = os.path.join(dname, "Models", cfg["model"]) - - print("Model {} will be saved to {}".format(cfg["model"], model_root)) - - model_device = "cuda" - if cfg["use_cpu"]: - model_device = "cpu" - - download_it = os.path.exists(model_root) - if download_it: - 
cfg["model"] = model_root - cfg["model"] = WhisperModel(cfg["model"], - device = model_device, - device_index = cfg["gpu_idx"], - compute_type = "int8", - download_root = model_root, - local_files_only = download_it) - - transcribe_audio_thd = threading.Thread(target = transcribeAudio, args = [audio_state]) - transcribe_audio_thd.daemon = True - transcribe_audio_thd.start() - - send_audio_thd = threading.Thread(target = sendAudio, args = [audio_state]) - send_audio_thd.daemon = True - send_audio_thd.start() - - controller_input_thd = threading.Thread(target = readControllerInput, args = [audio_state]) - controller_input_thd.daemon = True - controller_input_thd.start() - - audio_feedback_thd = threading.Thread(target = audioFeedbackThread, args = [audio_state]) - audio_feedback_thd.daemon = True - audio_feedback_thd.start() - - keyboard_input_thd = threading.Thread(target = readKeyboardInput, args = [audio_state]) - keyboard_input_thd.daemon = True - keyboard_input_thd.start() - - for line in sys.stdin: - audio_state.transcribe_lock.acquire() - audio_state.audio_lock.acquire() - resetAudioLocked(audio_state) - resetDisplayLocked(audio_state) - audio_state.drop_transcription = True - audio_state.audio_paused = False - audio_state.audio_lock.release() - audio_state.transcribe_lock.release() - if "exit" in line or "quit" in line: - break - - print("Joining threads") - audio_state.run_app = False - transcribe_audio_thd.join() - controller_input_thd.join() - audio_feedback_thd.join() - keyboard_input_thd.join() - -if __name__ == "__main__": - sys.stdout.reconfigure(encoding="utf-8") - - print("args: {}".format(" ".join(sys.argv))) - - print(f"Set cwd to {os.getcwd()}") - - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, help="Path to app config YAML file.") - args = parser.parse_args() - - print(f"PATH: {os.environ['PATH']}") - - transcribeLoop(args.config) - diff --git a/Scripts/vad.py b/Scripts/vad.py new file mode 100644 index 
# Scripts/vad.py -- new file (index 0000000..25f0ad0).
#
# MIT License
#
# Copyright (c) 2023 Guillaume Klein
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Voice activity detection helpers built around the Silero VAD ONNX model."""

import bisect
import functools
import os
import warnings

from typing import List, NamedTuple, Optional

import numpy as np


# The code below is adapted from https://github.com/snakers4/silero-vad.
class VadOptions(NamedTuple):
    """VAD options.

    Attributes:
      threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
        probabilities ABOVE this value are considered as SPEECH. It is better to tune this
        parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
      min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are
        thrown out.
      max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
        than max_speech_duration_s will be split at the timestamp of the last silence that
        lasts more than about 100ms (98 ms in the implementation), if any, to prevent
        aggressive cutting. Otherwise, they will be split aggressively just before
        max_speech_duration_s.
      min_silence_duration_ms: At the end of each speech chunk, wait for
        min_silence_duration_ms before separating it.
      window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
        Values other than these may affect model performance!!
      speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side.
    """

    threshold: float = 0.5
    min_speech_duration_ms: int = 250
    max_speech_duration_s: float = float("inf")
    min_silence_duration_ms: int = 2000
    window_size_samples: int = 1024
    speech_pad_ms: int = 400


def get_speech_timestamps(
    audio: np.ndarray,
    vad_options: Optional[VadOptions] = None,
    **kwargs,
) -> List[dict]:
    """Split a long audio signal into speech chunks using silero VAD.

    The audio is processed at a fixed sampling rate of 16000 Hz.

    Args:
      audio: One dimensional float array.
      vad_options: Options for VAD processing.
      kwargs: VAD options passed as keyword arguments for backward compatibility.
        They are only used when ``vad_options`` is None; otherwise they are ignored.

    Returns:
      List of dicts containing begin ("start") and end ("end") samples of each
      speech chunk. The list is empty when the audio is empty or no speech is found.
    """
    if vad_options is None:
        vad_options = VadOptions(**kwargs)

    threshold = vad_options.threshold
    min_speech_duration_ms = vad_options.min_speech_duration_ms
    max_speech_duration_s = vad_options.max_speech_duration_s
    min_silence_duration_ms = vad_options.min_silence_duration_ms
    window_size_samples = vad_options.window_size_samples
    speech_pad_ms = vad_options.speech_pad_ms

    if window_size_samples not in [512, 1024, 1536]:
        warnings.warn(
            "Unusual window_size_samples! Supported window_size_samples:\n"
            " - [512, 1024, 1536] for 16000 sampling_rate"
        )

    sampling_rate = 16000
    # Convert every duration option into a sample count at the fixed rate.
    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    # Leave room for one analysis window plus the padding added on both sides.
    max_speech_samples = (
        sampling_rate * max_speech_duration_s
        - window_size_samples
        - 2 * speech_pad_samples
    )
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    # Shortest silence (98 ms) that may serve as a split point once a chunk
    # exceeds max_speech_samples.
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000

    audio_length_samples = len(audio)
    if audio_length_samples == 0:
        # No window would be analysed, so the result is always empty; return
        # before loading the ONNX model.
        return []

    model = get_vad_model()
    state = model.get_initial_state(batch_size=1)

    # Run the model window by window, carrying its recurrent state along.
    speech_probs = []
    for current_start_sample in range(0, audio_length_samples, window_size_samples):
        chunk = audio[current_start_sample : current_start_sample + window_size_samples]
        if len(chunk) < window_size_samples:
            # Zero-pad the last, partial window up to the full window size.
            chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
        speech_prob, state = model(chunk, state, sampling_rate)
        speech_probs.append(speech_prob)

    triggered = False
    speeches = []
    current_speech = {}
    # Lower threshold used to decide that speech has stopped.
    neg_threshold = threshold - 0.15

    # to save potential segment end (and tolerate some silence)
    temp_end = 0
    # to save potential segment limits in case of maximum segment size reached
    prev_end = next_start = 0

    for i, speech_prob in enumerate(speech_probs):
        if (speech_prob >= threshold) and temp_end:
            temp_end = 0
            if next_start < prev_end:
                next_start = window_size_samples * i

        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech["start"] = window_size_samples * i
            continue

        if (
            triggered
            and (window_size_samples * i) - current_speech["start"] > max_speech_samples
        ):
            if prev_end:
                # Split at the last sufficiently long silence seen so far.
                current_speech["end"] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                # previously reached silence (< neg_thres) and is still not speech (< thres)
                if next_start < prev_end:
                    triggered = False
                else:
                    current_speech["start"] = next_start
                prev_end = next_start = temp_end = 0
            else:
                # No usable silence: cut at the current window.
                current_speech["end"] = window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = window_size_samples * i
            # condition to avoid cutting in very short silence
            if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
                prev_end = temp_end
            if (window_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                current_speech["end"] = temp_end
                if (
                    current_speech["end"] - current_speech["start"]
                ) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

    # Close a chunk that is still open when the audio ends.
    if (
        current_speech
        and (audio_length_samples - current_speech["start"]) > min_speech_samples
    ):
        current_speech["end"] = audio_length_samples
        speeches.append(current_speech)

    # Pad every chunk; when two neighbours are closer than twice the padding,
    # the gap between them is shared equally instead.
    for i, speech in enumerate(speeches):
        if i == 0:
            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i + 1]["start"] - speech["end"]
            if silence_duration < 2 * speech_pad_samples:
                speech["end"] += int(silence_duration // 2)
                speeches[i + 1]["start"] = int(
                    max(0, speeches[i + 1]["start"] - silence_duration // 2)
                )
            else:
                speech["end"] = int(
                    min(audio_length_samples, speech["end"] + speech_pad_samples)
                )
                speeches[i + 1]["start"] = int(
                    max(0, speeches[i + 1]["start"] - speech_pad_samples)
                )
        else:
            speech["end"] = int(
                min(audio_length_samples, speech["end"] + speech_pad_samples)
            )

    return speeches


def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
    """Collect and concatenate audio chunks.

    Args:
      audio: Array the chunks are sliced from.
      chunks: Dicts with "start" and "end" sample indices.

    Returns:
      The concatenation of ``audio[start:end]`` for every chunk, or an empty
      float32 array when ``chunks`` is empty.
    """
    if not chunks:
        return np.array([], dtype=np.float32)

    return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])


class SpeechTimestampsMap:
    """Helper class to restore original speech timestamps."""

    def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
        """Precompute, for each chunk, where it ends and how much silence precedes it.

        Args:
          chunks: Dicts with "start" and "end" sample indices.
          sampling_rate: Number of samples per second.
          time_precision: Number of decimals kept by get_original_time().
        """
        self.sampling_rate = sampling_rate
        self.time_precision = time_precision
        # End of each chunk, in samples, once the gaps before it are removed.
        self.chunk_end_sample = []
        # Cumulative gap, in seconds, that precedes each chunk.
        self.total_silence_before = []

        previous_end = 0
        silent_samples = 0

        for chunk in chunks:
            silent_samples += chunk["start"] - previous_end
            previous_end = chunk["end"]

            self.chunk_end_sample.append(chunk["end"] - silent_samples)
            self.total_silence_before.append(silent_samples / sampling_rate)

    def get_original_time(
        self,
        time: float,
        chunk_index: Optional[int] = None,
    ) -> float:
        """Map a time in the gap-free audio back to the original audio.

        Args:
          time: Time in seconds.
          chunk_index: Index of the chunk containing ``time``. It is looked up
            with get_chunk_index() when omitted.

        Returns:
          ``time`` plus the silence preceding the chunk, rounded to
          ``time_precision`` decimals. When the map holds no chunk, ``time``
          is only rounded.
        """
        if not self.total_silence_before:
            # No chunk was recorded, so there is no offset to add back (the
            # lookup below would otherwise raise IndexError).
            return round(time, self.time_precision)

        if chunk_index is None:
            chunk_index = self.get_chunk_index(time)

        total_silence_before = self.total_silence_before[chunk_index]
        return round(total_silence_before + time, self.time_precision)

    def get_chunk_index(self, time: float) -> int:
        """Return the index of the chunk containing ``time`` (in seconds).

        The result is clamped to the last chunk, and is 0 when the map holds
        no chunk at all.
        """
        sample = int(time * self.sampling_rate)
        # max() keeps the clamp non-negative for an empty map.
        last_index = max(len(self.chunk_end_sample) - 1, 0)
        return min(bisect.bisect(self.chunk_end_sample, sample), last_index)


@functools.lru_cache
def get_vad_model():
    """Returns the VAD model instance.

    The model is read from "Models/silero_vad.onnx" under the parent of the
    directory containing this file. The instance is cached by lru_cache, so
    repeated calls reuse it.
    """
    abspath = os.path.abspath(__file__)
    my_dir = os.path.dirname(abspath)
    parent_dir = os.path.dirname(my_dir)

    path = os.path.join(parent_dir, "Models/silero_vad.onnx")
    return SileroVADModel(path)


class SileroVADModel:
    """Wrapper around an onnxruntime session running the Silero VAD model."""

    def __init__(self, path):
        """Create a CPU inference session for the ONNX model at ``path``.

        Raises:
          RuntimeError: If the onnxruntime package cannot be imported.
        """
        # Imported lazily so the module stays importable without onnxruntime.
        try:
            import onnxruntime
        except ImportError as e:
            raise RuntimeError(
                "Applying the VAD filter requires the onnxruntime package"
            ) from e

        opts = onnxruntime.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        opts.log_severity_level = 4

        self.session = onnxruntime.InferenceSession(
            path,
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )

    def get_initial_state(self, batch_size: int):
        """Return the zero-filled ``(h, c)`` state, each of shape (2, batch_size, 64)."""
        h = np.zeros((2, batch_size, 64), dtype=np.float32)
        c = np.zeros((2, batch_size, 64), dtype=np.float32)
        return h, c

    def __call__(self, x, state, sr: int):
        """Run the model on one audio chunk.

        Args:
          x: Audio chunk as a 1-D array, or a 2-D array whose second axis
            holds the samples. A 1-D input gets a leading axis added.
          state: ``(h, c)`` pair from get_initial_state() or a previous call.
          sr: Sampling rate of the chunk.

        Returns:
          A tuple ``(out, state)``: the first model output and the updated
          ``(h, c)`` state.

        Raises:
          ValueError: If ``x`` has more than two dimensions, or if the chunk is
            too short for the sampling rate (``sr / samples > 31.25``),
            including a chunk with no samples.
        """
        if len(x.shape) == 1:
            x = np.expand_dims(x, 0)
        if len(x.shape) > 2:
            raise ValueError(
                f"Too many dimensions for input audio chunk {len(x.shape)}"
            )
        # Same test as ``sr / x.shape[1] > 31.25``, written as a multiplication
        # so a zero-length chunk raises ValueError, not ZeroDivisionError.
        if sr > 31.25 * x.shape[1]:
            raise ValueError("Input audio chunk is too short")

        h, c = state

        ort_inputs = {
            "input": x,
            "h": h,
            "c": c,
            "sr": np.array(sr, dtype="int64"),
        }

        out, h, c = self.session.run(None, ort_inputs)
        state = (h, c)

        return out, state
