from datetime import datetime, timezone
from faster_whisper import WhisperModel
from functools import partial
from pydub import AudioSegment
from whisper.normalizers import EnglishTextNormalizer
from scipy.optimize import minimize

import app_config
import argparse
import editdistance
import langcodes
import math
import numpy as np
import os
import pyaudio
import time
import typing


class AudioStream():
    """Base class for audio sources.

    Subclasses override `getSamples()` to hand back raw audio bytes laid out
    according to the class constants below.
    """

    FORMAT = pyaudio.paInt16
    # Size of each frame (audio sample), in bytes. If you change FORMAT, make
    # sure this stays up to date!
    FRAME_SZ = 2
    # Frames per second.
    FPS = 16000
    CHANNELS = 1

    def __init__(self):
        pass

    def getSamples(self) -> bytes:
        """Return the audio bytes that became available since the last call."""
        raise NotImplementedError("getSamples is not implemented!")


class DiskStream(AudioStream):
    """Audio stream backed by an .mp3 or .wav file loaded fully into memory."""

    def __init__(self, path: str, pace_at_real_time: bool):
        """Load `path` and convert it to the AudioStream sample layout.

        Args:
            path: Audio file to load. Only .mp3 and .wav are supported.
            pace_at_real_time: If True, `getSamples()` releases audio at
                wall-clock speed; otherwise the first call returns everything.

        Raises:
            NotImplementedError: If `path` has an unsupported extension.
        """
        super().__init__()

        # Compare case-insensitively so "CLIP.MP3" is accepted too.
        lowered = path.lower()
        if lowered.endswith(".mp3"):
            fmt = "mp3"
        elif lowered.endswith(".wav"):
            fmt = "wav"
        else:
            raise NotImplementedError(f"Requested file type {path} " +
                                      "is not supported")

        print("Loading audio data")
        audio = AudioSegment.from_file(path, format=fmt)
        # pydub's set_* methods return a *new* AudioSegment rather than
        # mutating in place, so every result has to be kept.
        audio = audio.set_channels(AudioStream.CHANNELS)
        # Force 16-bit samples so the int16 cast below cannot wrap wider data.
        audio = audio.set_sample_width(AudioStream.FRAME_SZ)
        # TODO(yum) replace manual decimation code with this!
        audio = audio.set_frame_rate(AudioStream.FPS)

        samples = np.array(audio.get_array_of_samples())
        self.frames = np.int16(samples).tobytes()
        self.wall_ts = time.time()
        self.pace_at_real_time = pace_at_real_time
        print("Loaded data")

    def getSamples(self) -> bytes:
        """Return (and consume) the next slice of the loaded audio.

        When pacing at real time, the slice length matches the wall-clock
        time elapsed since the previous call (or since loading).
        """
        if not self.pace_at_real_time:
            everything = self.frames
            self.frames = b''
            return everything

        now = time.time()
        # time.time() may step backwards; never let that turn into a negative
        # slice bound, which would hand out audio from the wrong end.
        nframes = max(0, int((now - self.wall_ts) * AudioStream.FPS))
        self.wall_ts = now

        n_bytes = nframes * AudioStream.FRAME_SZ
        head = self.frames[:n_bytes]
        self.frames = self.frames[n_bytes:]
        return head


class MicStream(AudioStream):
    """Audio stream fed by a pyaudio input device via a stream callback."""

    CHUNK_SZ = 1024

    def __init__(self, which_mic: str):
        """Open the requested microphone and start streaming from it.

        Args:
            which_mic: "index" or "focusrite" to search input devices by
                name; any other value is parsed as a numeric device ID.

        Raises:
            KeyError: If no input device matches the requested name.
            ValueError: If `which_mic` is neither a known name nor an integer.
        """
        self.p = pyaudio.PyAudio()
        self.stream = None
        self.sample_rate = None
        # Each time pyaudio gives us audio data, it's in the form of a chunk of
        # samples. We keep these in a list to keep the audio callback as light
        # as possible. Whenever downstream layers want data, we collapse the
        # list into a single array of data (a bytes object).
        self.chunks = []

        print(f"Finding mic {which_mic}")
        self.dumpMicDevices()

        # Substrings used to recognise the named microphones.
        name_hints = {
            "index": "Digital Audio Interface",
            "focusrite": "Focusrite",
        }

        got_match = False
        device_index = -1
        if which_mic in name_hints:
            target_str = name_hints[which_mic]
            for i, device_name in self._inputDevices():
                if target_str in device_name:
                    print(f"Got matching mic: {device_name}")
                    device_index = i
                    got_match = True
                    break
        else:
            print(f"Mic {which_mic} requested, treating it as a numerical " +
                  "device ID")
            device_index = int(which_mic)
            got_match = True

        if not got_match:
            raise KeyError(f"Mic {which_mic} not found")

        info = self.p.get_device_info_by_host_api_device_index(0, device_index)
        print(f"Found mic {which_mic}: {info['name']}")
        self.sample_rate = int(info['defaultSampleRate'])
        print(f"Mic sample rate: {self.sample_rate}")

        self.stream = self.p.open(
            rate=self.sample_rate,
            channels=AudioStream.CHANNELS,
            format=AudioStream.FORMAT,
            input=True,
            frames_per_buffer=MicStream.CHUNK_SZ,
            input_device_index=device_index,
            stream_callback=self.onAudioFramesAvailable)

        self.stream.start_stream()

        AudioStream.__init__(self)

    def _inputDevices(self):
        """Yield (device id, name) for each host-API-0 device with inputs."""
        host_info = self.p.get_host_api_info_by_index(0)
        for i in range(host_info.get('deviceCount')):
            dev = self.p.get_device_info_by_host_api_device_index(0, i)
            if dev.get('maxInputChannels') > 0:
                yield i, dev.get('name')

    def dumpMicDevices(self):
        """Print the id and name of every available input device."""
        for i, device_name in self._inputDevices():
            print("Input Device id ", i, " - ", device_name)

    def onAudioFramesAvailable(self,
                               frames,
                               frame_count,
                               time_info,
                               status_flags):
        """pyaudio stream callback: decimate a chunk and queue it.

        Returns the `(frames, pyaudio.paContinue)` tuple so the stream keeps
        running.
        """
        # In pyaudio, a `frame` is a single sample of audio data.
        frame_len = AudioStream.FRAME_SZ

        # The mic probably has a higher sample rate than Whisper wants, so
        # decrease the sample rate by dropping samples. Note that this
        # algorithm only works if the mic's rate is higher than whisper's
        # expected rate.
        keep_every = float(self.sample_rate) / AudioStream.FPS

        # Gather the kept samples in a list and join once, rather than
        # growing a bytes object one sample at a time.
        kept = []
        next_frame = 0.0
        for i in range(frame_count):
            if i >= next_frame:
                kept.append(frames[i * frame_len:(i + 1) * frame_len])
                next_frame += keep_every

        self.chunks.append(b''.join(kept))

        return (frames, pyaudio.paContinue)

    def getSamples(self) -> bytes:
        """Return all audio queued by the callback since the last call."""
        chunks = self.chunks
        self.chunks = []
        return b''.join(chunks)


class AudioCollector:
    """Accumulates audio from a stream and tracks its wall-clock position.

    `wall_ts` is the wall-clock time of the oldest byte still held in
    `frames`; `now()` is the wall-clock time of the newest.
    """

    def __init__(self, stream: AudioStream):
        self.stream = stream
        self.frames = b''
        # Note: by design, this is the only spot where we anchor our timestamps
        # against the real world. This is done to make it possible to profile
        # test cases which read from disk (at much faster than real speed) in
        # the same way that we profile real-time data.
        self.wall_ts = time.time()

    def getAudio(self) -> bytes:
        """Pull any new samples from the stream and return the whole buffer."""
        frames = self.stream.getSamples()
        if frames:
            self.frames += frames
        return self.frames

    def dropAudioPrefix(self, dur_s: float):
        """Discard the oldest `dur_s` seconds of buffered audio.

        `wall_ts` moves forward by exactly the amount of audio dropped, so it
        keeps describing the oldest sample still in the buffer. Requests
        longer than the buffer drop everything; negative requests drop
        nothing.
        """
        n_bytes = int(dur_s * self.stream.FPS) * self.stream.FRAME_SZ
        n_bytes = max(0, min(n_bytes, len(self.frames)))
        self.frames = self.frames[n_bytes:]
        # Advance by the duration that was removed, not by what remains.
        self.wall_ts += n_bytes / (self.stream.FPS * self.stream.FRAME_SZ)

    def dropAudio(self):
        """Discard the entire buffer, advancing `wall_ts` past it."""
        self.wall_ts += self.duration()
        self.frames = b''

    def duration(self) -> float:
        """Return the length of the buffered audio, in seconds."""
        return len(self.frames) / (self.stream.FPS * self.stream.FRAME_SZ)

    def now(self) -> float:
        """Return the wall-clock timestamp of the end of the buffer."""
        return self.wall_ts + self.duration()


# A segment of transcribed audio. `start_ts` and `end_ts` are floating point
# number of seconds since the beginning of audio data.
class Segment:
    def __init__(self,
                 transcript: str,
                 start_ts: float,
                 end_ts: float,
                 wall_ts: float,
                 avg_logprob: float,
                 no_speech_prob: float):
        self.transcript = transcript
        # start_ts, end_ts are timestamps in seconds relative to `wall_ts`.
        self.start_ts = start_ts
        self.end_ts = end_ts
        # wall_ts is the time.time() at which the oldest audio sample leading
        # to this transcript was collected.
        self.wall_ts = wall_ts
        self.avg_logprob = avg_logprob
        self.no_speech_prob = no_speech_prob

    def __str__(self):
        def as_utc_clock(epoch_s: float) -> str:
            # Timezone-aware replacement for the deprecated
            # datetime.utcfromtimestamp(); the formatted text is the same.
            stamp = datetime.fromtimestamp(epoch_s, tz=timezone.utc)
            return stamp.strftime('%H:%M:%S')

        ts = f"(ts: {self.start_ts}-{self.end_ts}) "
        wall_ts_start = as_utc_clock(self.start_ts + self.wall_ts)
        wall_ts_end = as_utc_clock(self.end_ts + self.wall_ts)
        wall_ts = f"(wall ts: {wall_ts_start}-{wall_ts_end}) "
        no_speech = f"(no_speech: {self.no_speech_prob}) "
        avg_logprob = f"(avg_logprob: {self.avg_logprob}) "
        return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob


class Whisper:
    """Runs a faster-whisper model over the audio held by an AudioCollector."""

    def __init__(self, collector: AudioCollector, cfg: typing.Dict):
        """Load the model named by `cfg`.

        Reads the "model", "use_cpu" and "gpu_idx" keys here; `transcribe()`
        additionally reads "language".
        """
        self.collector = collector
        self.model = None
        self.cfg = cfg

        # Models live in a "Models" directory next to this source file.
        here = os.path.dirname(os.path.abspath(__file__))
        model_root = os.path.join(here, "Models", cfg["model"])
        print("Model {} will be saved to {}".format(cfg["model"], model_root))

        model_device = "cpu" if cfg["use_cpu"] else "cuda"

        # If the model directory already exists, point at it directly and
        # restrict loading to local files; otherwise pass the model name.
        already_on_disk = os.path.exists(model_root)
        model_str = model_root if already_on_disk else cfg["model"]

        self.model = WhisperModel(model_str,
                                  device=model_device,
                                  device_index=cfg["gpu_idx"],
                                  compute_type="int8",
                                  download_root=model_root,
                                  local_files_only=already_on_disk)

    def transcribe(self) -> typing.List[Segment]:
        """Transcribe everything currently buffered in the collector.

        Returns:
            One Segment per model segment, each stamped with the collector's
            current `wall_ts`.
        """
        frames = self.collector.getAudio()
        # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on
        # [-1, 1].
        audio = np.frombuffer(frames,
                              dtype=np.int16).flatten().astype(np.float32) / 32768.0

        segments, info = self.model.transcribe(
            audio,
            beam_size=5,
            language=langcodes.find(self.cfg["language"]).language,
            temperature=0.0,
            log_prob_threshold=-0.8,
            vad_filter=True,
            condition_on_previous_text=True,
            without_timestamps=False)

        return [
            Segment(s.text, s.start, s.end,
                    self.collector.wall_ts,
                    s.avg_logprob, s.no_speech_prob)
            for s in segments
        ]


class TranscriptCommit:
    """Result of one committer step.

    `delta` is the newly committed text (empty if nothing was committed),
    `preview` is the full current transcription of the buffer, and
    `latency_s` / `thresh_at_commit` are set only when a commit happened.
    """

    def __init__(self,
                 delta: str,
                 preview: str,
                 latency_s: typing.Optional[float] = None,
                 thresh_at_commit: typing.Optional[int] = None):
        self.delta = delta
        self.preview = preview
        self.latency_s = latency_s
        self.thresh_at_commit = thresh_at_commit


class FuzzyRepeatCommitter:
    """Commits a segment once repeated transcriptions of it roughly agree."""

    def __init__(self,
                 collector: AudioCollector,
                 whisper: Whisper,
                 last_n_must_match: int = 4,
                 edit_thresh_min: int = 1,
                 edit_thresh_grow_begin_s: float = 1.5,
                 edit_thresh_grow_halflife_s: float = 1.5,
                 min_segment_age_s: float = 1.0):
        self.collector = collector
        self.whisper = whisper
        # List of candidate segments. Once these all match, we commit the
        # corresponding audio data.
        self.candidates = []

        self.last_n_must_match = last_n_must_match
        self.edit_thresh_min = edit_thresh_min
        self.edit_thresh_grow_begin_s = edit_thresh_grow_begin_s
        self.edit_thresh_grow_halflife_s = edit_thresh_grow_halflife_s
        self.min_segment_age_s = min_segment_age_s

    def getDelta(self) -> TranscriptCommit:
        """Transcribe the buffer once and commit its first segment if stable.

        Returns:
            A TranscriptCommit whose `delta` is empty unless both rules below
            passed. On commit, the committed audio is dropped from the
            collector.
        """
        segments = self.whisper.transcribe()
        preview = ''.join(s.transcript for s in segments)

        if not segments:
            return TranscriptCommit("", preview, None)

        s = segments[0]

        # Still gathering candidates: record this one and wait for more.
        if len(self.candidates) < self.last_n_must_match:
            if self.candidates and s.wall_ts != self.candidates[0].wall_ts:
                # The buffer's anchor moved, so the earlier candidates no
                # longer describe the same audio.
                print("Frames dropped, committer resetting candidates")
                self.candidates = []
            else:
                self.candidates.append(s)
            return TranscriptCommit("", preview, None)

        # Rule 1: last n segments must be within a certain edit distance of
        # each other. This edit distance starts low and increases exponentially
        # as the buffer size grows, thus allowing the check to get weaker under
        # compute pressure.
        # NOTE(review): once growth starts, the computed value replaces
        # edit_thresh_min rather than being floored by it — confirm intended.
        edit_thresh = self.edit_thresh_min
        dt = self.collector.now() - (self.collector.wall_ts + s.start_ts)
        if dt > self.edit_thresh_grow_begin_s:
            dt -= self.edit_thresh_grow_begin_s
            edit_thresh = int(math.ceil(
                2**(dt / self.edit_thresh_grow_halflife_s)))

        # Index of the last candidate that disagrees with its predecessor;
        # everything before it is discarded.
        drop_candidates = 0
        neighbours = zip(self.candidates, self.candidates[1:])
        for i, (prev, cur) in enumerate(neighbours, start=1):
            if editdistance.eval(prev.transcript, cur.transcript) > edit_thresh:
                drop_candidates = i
        if drop_candidates != 0:
            self.candidates = self.candidates[drop_candidates:]
            return TranscriptCommit("", preview, None)

        candidate = self.candidates[-1]

        # Rule 2: no committing segments that are fewer than the configured
        # number of seconds old.
        age_s = self.collector.now() - (candidate.end_ts + candidate.wall_ts)
        if age_s < self.min_segment_age_s:
            self.candidates = []
            return TranscriptCommit("", preview, None)

        # Got a candidate! Commit it and return.
        self.candidates = []
        latency_s = self.collector.now() - (candidate.wall_ts +
                                            candidate.start_ts)
        self.collector.dropAudioPrefix(candidate.end_ts)
        return TranscriptCommit(candidate.transcript, preview, latency_s,
                                thresh_at_commit=edit_thresh)


def evaluate(cfg,
             audio_path: str,
             control_path: str,
             last_n_must_match: int = 4,
             edit_thresh_min: int = 1,
             edit_thresh_grow_begin_s: float = 1.5,
             edit_thresh_grow_halflife_s: float = 1.5,
             min_segment_age_s: float = 1.0
             ):
    """Score one committer configuration against a reference transcript.

    Plays `audio_path` at real-time pace through the committer, then compares
    the normalized result with the normalized contents of `control_path`.

    Returns:
        edit distance * average commit latency (lower is better). If nothing
        at all was committed, returns `math.inf` instead of dividing by zero.
    """
    stream = DiskStream(audio_path, True)
    collector = AudioCollector(stream)
    whisper = Whisper(collector, cfg)
    com = FuzzyRepeatCommitter(
        collector, whisper,
        last_n_must_match=last_n_must_match,
        edit_thresh_min=edit_thresh_min,
        edit_thresh_grow_begin_s=edit_thresh_grow_begin_s,
        edit_thresh_grow_halflife_s=edit_thresh_grow_halflife_s,
        min_segment_age_s=min_segment_age_s)

    transcript = ""
    commits = []
    while len(stream.frames) > 0:
        commit = com.getDelta()

        # The file ran out during this step: flush whatever is left in the
        # preview as the final commit.
        if len(stream.frames) == 0:
            commit.delta = commit.preview
            commit.latency_s = 0

        if len(commit.delta) > 0:
            commits.append(commit)
            transcript += commit.delta
        #if len(commit.delta):
        #    print(f"transcript: {transcript}")
        #    print(f"commit latency: {commit.latency_s}")
        #    print(f"commit thresh: {commit.thresh_at_commit}")

    with open(control_path, "r", encoding="utf-8") as f:
        control = f.read()

    normalizer = EnglishTextNormalizer()
    control = normalizer(control)
    experiment = normalizer(transcript)

    dist = editdistance.eval(control, experiment)
    if commits:
        avg_latency = sum(c.latency_s for c in commits) / len(commits)
        score = dist * avg_latency
    else:
        # No commits means no latency to average; report the worst possible
        # score rather than crashing or rewarding an empty transcript.
        avg_latency = math.inf
        score = math.inf

    print("PARAMS")
    print(f"last_n_must_match: {last_n_must_match}")
    print(f"edit_thresh_min: {edit_thresh_min}")
    print(f"edit_thresh_grow_begin_s: {edit_thresh_grow_begin_s}")
    print(f"edit_thresh_grow_halflife_s: {edit_thresh_grow_halflife_s}")
    print(f"min_segment_age_s: {min_segment_age_s}")
    print("RESULTS")
    print(f"edit distance: {dist}")
    print(f"avg latency: {avg_latency}")
    print(f"num commits: {len(commits)}")
    print(f"final transcript: {transcript}")
    print(f"score: {score}")
    return score


def optimize(cfg, audio_path: str, control_path: str):
    """Search for committer parameters that minimize `evaluate()`'s score.

    Returns:
        The parameter vector found by scipy's L-BFGS-B, ordered as
        [last_n_must_match, edit_thresh_min, edit_thresh_grow_begin_s,
        edit_thresh_grow_halflife_s, min_segment_age_s].
    """
    def wrapper_to_optimize(x):
        return evaluate(
            cfg,
            audio_path,
            control_path,
            int(x[0]),  # last_n_must_match
            int(x[1]),  # edit_thresh_min
            x[2],       # edit_thresh_grow_begin_s
            x[3],       # edit_thresh_grow_halflife_s
            x[4]        # min_segment_age_s
        )

    initial_guess = [4, 1, 1.5, 1.5, 1.0]
    bounds = [
        (1, 5),     # last_n_must_match
        (1, 100),   # edit_thresh_min
        (0, 10),    # edit_thresh_grow_begin_s
        (0.1, 10),  # edit_thresh_grow_halflife_s
        (0, 3)      # min_segment_age_s
    ]

    result = minimize(
        wrapper_to_optimize,
        initial_guess,
        bounds=bounds,
        method='L-BFGS-B',
    )

    optimized_params = result.x
    print("Optimized Parameters:")
    print(f"last_n_must_match: {int(optimized_params[0])}")
    print(f"edit_thresh_min: {int(optimized_params[1])}")
    print(f"edit_thresh_grow_begin_s: {optimized_params[2]}")
    print(f"edit_thresh_grow_halflife_s: {optimized_params[3]}")
    print(f"min_segment_age_s: {optimized_params[4]}")

    return optimized_params


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str,
                        help="Path to app config YAML file.")
    args = parser.parse_args()

    cfg = app_config.getConfig(args.config)

    optimize(cfg,
             "Evaluate/declaration_short/audio.mp3",
             "Evaluate/declaration_short/control.txt")