diff options
| -rw-r--r-- | Scripts/evaluate_requirements.txt | 2 | ||||
| -rw-r--r-- | Scripts/requirements.txt | 1 | ||||
| -rw-r--r-- | Scripts/transcribe_v2.py | 494 |
3 files changed, 497 insertions, 0 deletions
diff --git a/Scripts/evaluate_requirements.txt b/Scripts/evaluate_requirements.txt new file mode 100644 index 0000000..21e8582 --- /dev/null +++ b/Scripts/evaluate_requirements.txt @@ -0,0 +1,2 @@ +git+https://github.com/openai/whisper.git +scipy diff --git a/Scripts/requirements.txt b/Scripts/requirements.txt index 8989ed1..cba3d15 100644 --- a/Scripts/requirements.txt +++ b/Scripts/requirements.txt @@ -8,6 +8,7 @@ language-data openvr pillow pyaudio +pydub python-osc pyyaml sentence_splitter diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py new file mode 100644 index 0000000..3cf8fe7 --- /dev/null +++ b/Scripts/transcribe_v2.py @@ -0,0 +1,494 @@ +from datetime import datetime +from faster_whisper import WhisperModel +from functools import partial +from pydub import AudioSegment +from whisper.normalizers import EnglishTextNormalizer +from scipy.optimize import minimize + +import app_config +import argparse +import editdistance +import langcodes +import math +import numpy as np +import os +import pyaudio +import time +import typing + +class AudioStream(): + FORMAT = pyaudio.paInt16 + # Size of each frame (audio sample), in bytes. If you change FORMAT, make + # sure this stays up to date! + FRAME_SZ = 2 + # Frames per second. + FPS = 16000 + CHANNELS = 1 + def __init__(self): + pass + + def getSamples(self) -> bytes: + raise NotImplementedError("getSamples is not implemented!") + +class DiskStream(AudioStream): + def __init__(self, path: str, pace_at_real_time: bool): + fmt = None + if path.endswith(".mp3"): + fmt = "mp3" + elif path.endswith(".wav"): + fmt = "wav" + else: + raise NotImplementedError(f"Requested file type {path} " + \ + "is not supported") + print(f"Loading audio data") + audio = AudioSegment.from_file(path, format=fmt) + audio.set_channels(1) + # TODO(yum) replace manual decimation code with this! + audio = audio.set_frame_rate(16000) + frames = np.array(audio.get_array_of_samples()) + frames = np.int16(frames).tobytes() + + self.frames = frames + self.wall_ts = time.time() + self.pace_at_real_time = pace_at_real_time + + print(f"Loaded data") + + def getSamples(self) -> bytes: + if self.pace_at_real_time: + now = time.time() + nframes = int((now - self.wall_ts) * AudioStream.FPS) + self.wall_ts = now + frames = self.frames[0:nframes * AudioStream.FRAME_SZ]; + self.frames = self.frames[nframes * AudioStream.FRAME_SZ:] + return frames + + frames = self.frames + self.frames = b'' + return frames + +class MicStream(AudioStream): + CHUNK_SZ = 1024 + + def __init__(self, which_mic: str): + self.p = pyaudio.PyAudio() + self.stream = None + self.sample_rate = None + # Each time pyaudio gives us audio data, it's in the form of a chunk of + # samples. We keep these in a list to keep the audio callback as light + # as possible. Whenever downstream layers want data, we collapse the + # list into a single array of data (a bytes object). + self.chunks = [] + + print(f"Finding mic {which_mic}") + self.dumpMicDevices() + + got_match = False + device_index = -1 + focusrite_str = "Focusrite" + index_str = "Digital Audio Interface" + if which_mic == "index": + target_str = index_str + elif which_mic == "focusrite": + target_str = focusrite_str + else: + print(f"Mic {which_mic} requested, treating it as a numerical " + + "device ID") + device_index = int(which_mic) + got_match = True + if not got_match: + info = self.p.get_host_api_info_by_index(0) + numdevices = info.get('deviceCount') + for i in range(0, numdevices): + if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: + device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') + if target_str in device_name: + print(f"Got matching mic: {device_name}") + device_index = i + got_match = True + break + if not got_match: + raise KeyError(f"Mic {which_mic} not found") + + info = self.p.get_device_info_by_host_api_device_index(0, device_index) + print(f"Found mic {which_mic}: {info['name']}") + self.sample_rate = int(info['defaultSampleRate']) + print(f"Mic sample rate: {self.sample_rate}") + + self.stream = self.p.open( + rate=self.sample_rate, + channels=AudioStream.CHANNELS, + format=AudioStream.FORMAT, + input=True, + frames_per_buffer=MicStream.CHUNK_SZ, + input_device_index=device_index, + stream_callback=self.onAudioFramesAvailable) + + self.stream.start_stream() + + AudioStream.__init__(self) + + def dumpMicDevices(self): + info = self.p.get_host_api_info_by_index(0) + numdevices = info.get('deviceCount') + + for i in range(0, numdevices): + if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: + device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') + print("Input Device id ", i, " - ", device_name) + + def onAudioFramesAvailable(self, + frames, + frame_count, + time_info, + status_flags): + decimated = b'' + # In pyaudio, a `frame` is a single sample of audio data. + frame_len = AudioStream.FRAME_SZ + next_frame = 0.0 + # The mic probably has a higher sample rate than Whisper wants, so + # decrease the sample rate by dropping samples. Note that this + # algorithm only works if the mic's rate is higher than whisper's + # expected rate. + keep_every = float(self.sample_rate) / AudioStream.FPS + for i in range(frame_count): + if i >= next_frame: + decimated += frames[i*frame_len:(i+1)*frame_len] + next_frame += keep_every + self.chunks.append(decimated) + + return (frames, pyaudio.paContinue) + + # Get audio data and the corresponding timestamp. + def getSamples(self) -> bytes: + chunks = self.chunks + self.chunks = [] + result = b''.join(chunks) + return result + +class AudioCollector: + def __init__(self, stream: AudioStream): + self.stream = stream + self.frames = b'' + # Note: by design, this is the only spot where we anchor our timestamps + # against the real world. This is done to make it possible to profile + # test cases which read from disk (at much faster than real speed) in + # the same way that we profile real-time data. + self.wall_ts = time.time() + + def getAudio(self) -> bytes: + frames = self.stream.getSamples() + if frames: + self.frames += frames + return self.frames + + def dropAudioPrefix(self, dur_s: float): + n_bytes = int(dur_s * self.stream.FPS) * self.stream.FRAME_SZ + n_bytes = min(n_bytes, len(self.frames)) + self.frames = self.frames[n_bytes:] + self.wall_ts = self.wall_ts + self.duration() + + def dropAudio(self): + self.wall_ts += self.duration() + self.frames = b'' + + def duration(self): + return len(self.frames) / (self.stream.FPS * self.stream.FRAME_SZ) + + def now(self): + return self.wall_ts + self.duration() + +# A segment of transcribed audio. `start_ts` and `end_ts` are floating point +# number of seconds since the beginning of audio data. +class Segment: + def __init__(self, + transcript: str, + start_ts: float, + end_ts: float, + wall_ts: float, + avg_logprob: float, + no_speech_prob: float): + self.transcript = transcript + # start_ts, end_ts are timestamps in seconds relative to `wall_ts`. + self.start_ts = start_ts + self.end_ts = end_ts + # wall_ts is the time.time() at which the oldest audio sample leading + # to this transcript was collected. + self.wall_ts = wall_ts + self.avg_logprob = avg_logprob + self.no_speech_prob = no_speech_prob + + def __str__(self): + ts = f"(ts: {self.start_ts}-{self.end_ts}) " + + wall_ts_start = datetime.utcfromtimestamp(self.start_ts + self.wall_ts).strftime('%H:%M:%S') + wall_ts_end = datetime.utcfromtimestamp(self.end_ts + self.wall_ts).strftime('%H:%M:%S') + wall_ts = f"(wall ts: {wall_ts_start}-{wall_ts_end}) " + + no_speech = f"(no_speech: {self.no_speech_prob}) " + avg_logprob = f"(avg_logprob: {self.avg_logprob}) " + return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob + +class Whisper: + def __init__(self, + collector: AudioCollector, + cfg: typing.Dict): + self.collector = collector + self.model = None + self.cfg = cfg + + abspath = os.path.abspath(__file__) + dname = os.path.dirname(abspath) + + model_root = os.path.join(dname, "Models", cfg["model"]) + print("Model {} will be saved to {}".format(cfg["model"], model_root)) + + model_device = "cuda" + if cfg["use_cpu"]: + model_device = "cpu" + + download_it = os.path.exists(model_root) + model_str = cfg["model"] + if download_it: + model_str = model_root + self.model = WhisperModel(model_str, + device = model_device, + device_index = cfg["gpu_idx"], + compute_type = "int8", + download_root = model_root, + local_files_only = download_it) + + def transcribe(self) -> typing.List[Segment]: + frames = self.collector.getAudio() + # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on + # [-1, 1]. + audio = np.frombuffer(frames, + dtype=np.int16).flatten().astype(np.float32) / 32768.0 + + segments, info = self.model.transcribe( + audio, + beam_size = 5, + language = langcodes.find(self.cfg["language"]).language, + temperature = 0.0, + log_prob_threshold = -0.8, + vad_filter = True, + condition_on_previous_text = True, + without_timestamps = False) + res = [] + for s in segments: + res.append(Segment(s.text, s.start, s.end, + self.collector.wall_ts, + s.avg_logprob, s.no_speech_prob)) + return res + +class TranscriptCommit: + def __init__(self, + delta: str, + preview: str, + latency_s: int = None, + thresh_at_commit: int = None): + self.delta = delta + self.preview = preview + self.latency_s = latency_s + self.thresh_at_commit = thresh_at_commit + +class FuzzyRepeatCommitter: + def __init__(self, + collector: AudioCollector, + whisper: Whisper, + last_n_must_match: int = 4, + edit_thresh_min: int = 1, + edit_thresh_grow_begin_s: float = 1.5, + edit_thresh_grow_halflife_s: float = 1.5, + min_segment_age_s: float = 1.0): + self.collector = collector + self.whisper = whisper + # List of candidate segments. Once these all match, we commit the + # corresponding audio data. + self.candidates = [] + self.last_n_must_match = last_n_must_match + self.edit_thresh_min = edit_thresh_min + self.edit_thresh_grow_begin_s = edit_thresh_grow_begin_s + self.edit_thresh_grow_halflife_s = edit_thresh_grow_halflife_s + self.min_segment_age_s = min_segment_age_s + + def getDelta(self) -> TranscriptCommit: + segments = self.whisper.transcribe() + preview = ''.join(s.transcript for s in segments) + + if len(segments) == 0: + return TranscriptCommit("", preview, None) + + s = segments[0] + + if len(self.candidates) < self.last_n_must_match: + if len(self.candidates) == 0: + self.candidates.append(s) + return TranscriptCommit("", preview, None) + s0 = self.candidates[0] + if s.wall_ts != s0.wall_ts: + print("Frames dropped, committer resetting candidates") + self.candidates = [] + return TranscriptCommit("", preview, None) + self.candidates.append(s) + return TranscriptCommit("", preview, None) + + # Rule 1: last n segments must be within a certain edit distance of + # each other. This edit distance starts low and increases exponentially + # as the buffer size grows, thus allowing the check to get weaker under + # compute pressure. + edit_thresh = self.edit_thresh_min + dt = self.collector.now() - (self.collector.wall_ts + s.start_ts) + if dt > self.edit_thresh_grow_begin_s: + dt -= self.edit_thresh_grow_begin_s + edit_thresh = int(math.ceil(2**(dt / + self.edit_thresh_grow_halflife_s))) + + drop_candidates = 0 + for i in range(1, len(self.candidates)): + prev = self.candidates[i-1] + cur = self.candidates[i] + dist = editdistance.eval(prev.transcript, cur.transcript) + if dist > edit_thresh: + drop_candidates = i + if drop_candidates != 0: + self.candidates = self.candidates[drop_candidates:] + return TranscriptCommit("", preview, None) + + candidate = self.candidates[-1] + + # Rule 2: no committing segments that are fewer than the configured + # number of seconds old. + if self.collector.now() - (candidate.end_ts + candidate.wall_ts) < self.min_segment_age_s: + self.candidates = [] + return TranscriptCommit("", preview, None) + + # Got a candidate! Commit it and return. + self.candidates = [] + latency_s = self.collector.now() - (candidate.wall_ts + candidate.start_ts) + self.collector.dropAudioPrefix(candidate.end_ts) + + return TranscriptCommit(candidate.transcript, preview, latency_s, + thresh_at_commit = edit_thresh) + +def evaluate(cfg, + audio_path: str, + control_path: str, + last_n_must_match: int = 4, + edit_thresh_min: int = 1, + edit_thresh_grow_begin_s: float = 1.5, + edit_thresh_grow_halflife_s: float = 1.5, + min_segment_age_s: float = 1.0 + ): + stream = DiskStream(audio_path, True) + + collector = AudioCollector(stream) + whisper = Whisper(collector, cfg) + com = FuzzyRepeatCommitter(collector, whisper, + last_n_must_match=last_n_must_match, + edit_thresh_min=edit_thresh_min, + edit_thresh_grow_begin_s=edit_thresh_grow_begin_s, + edit_thresh_grow_halflife_s=edit_thresh_grow_halflife_s, + min_segment_age_s=min_segment_age_s) + transcript = "" + commits = [] + while len(stream.frames) > 0: + commit = com.getDelta() + + if len(stream.frames) == 0: + commit.delta = commit.preview + commit.latency_s = 0 + + if len(commit.delta) > 0: + commits.append(commit) + + transcript += commit.delta + + #if len(commit.delta): + # print(f"transcript: {transcript}") + # print(f"commit latency: {commit.latency_s}") + # print(f"commit thresh: {commit.thresh_at_commit}") + + with open(control_path, "r") as f: + control = f.read() + normalizer = EnglishTextNormalizer() + control = normalizer(control) + experiment = normalizer(transcript) + + sum_latency = 0 + for commit in commits: + sum_latency += commit.latency_s + avg_latency = sum_latency / len(commits) + + dist = editdistance.eval(control, experiment) + + print(f"PARAMS") + print(f"last_n_must_match: {last_n_must_match}") + print(f"edit_thresh_min: {edit_thresh_min}") + print(f"edit_thresh_grow_begin_s: {edit_thresh_grow_begin_s}") + print(f"edit_thresh_grow_halflife_s: {edit_thresh_grow_halflife_s}") + print(f"min_segment_age_s: {min_segment_age_s}") + print(f"RESULTS") + print(f"edit distance: {dist}") + print(f"avg latency: {avg_latency}") + print(f"num commits: {len(commits)}") + print(f"final transcript: {transcript}") + + score = dist * avg_latency + print(f"score: {score}") + return score + +def optimize(cfg, + audio_path: str, + control_path: str): + + def wrapper_to_optimize(x): + return evaluate( + cfg, + audio_path, + control_path, + int(x[0]), # last_n_must_match + int(x[1]), # edit_thresh_min + x[2], # edit_thresh_grow_begin_s + x[3], # edit_thresh_grow_halflife_s + x[4] # min_segment_age_s + ) + + initial_guess = [4, 1, 1.5, 1.5, 1.0] + bounds = [ + (1, 5), # last_n_must_match + (1, 100), # edit_thresh_min + (0, 10), # edit_thresh_grow_begin_s + (0.1, 10), # edit_thresh_grow_halflife_s + (0, 3) # min_segment_age_s + ] + + result = minimize( + wrapper_to_optimize, + initial_guess, + bounds=bounds, + method='L-BFGS-B', + ) + + optimized_params = result.x + + print("Optimized Parameters:") + print(f"last_n_must_match: {int(optimized_params[0])}") + print(f"edit_thresh_min: {int(optimized_params[1])}") + print(f"edit_thresh_grow_begin_s: {optimized_params[2]}") + print(f"edit_thresh_grow_halflife_s: {optimized_params[3]}") + print(f"min_segment_age_s: {optimized_params[4]}") + + return optimized_params + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, help="Path to app config YAML file.") + args = parser.parse_args() + + cfg = app_config.getConfig(args.config) + + optimize(cfg, "Evaluate/declaration_short/audio.mp3", + "Evaluate/declaration_short/control.txt") + + |
