summaryrefslogtreecommitdiffstats
path: root/Scripts/transcribe_v2.py
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2023-09-02 20:43:18 -0700
committeryum <yum.food.vr@gmail.com>2023-09-02 20:43:18 -0700
commite9b5b4f1da2a8ff07b2d13e5e63dae491325251d (patch)
treeb4b030954839b429e6d3d2572e626c300ca52eec /Scripts/transcribe_v2.py
parentc0c53fc3f0aeb762d44ce43f123385b2c87869ca (diff)
Begin rewriting transcribe.py
A set of proper interfaces is called for. See #dev-update-spam in discord for drawing of design. Also add code to mechanically optimize committer parameters using an audio file. Not perfectly repeatable since it depends on the performance characteristics of the machine, but prob better than what we had before (nothing).
Diffstat (limited to 'Scripts/transcribe_v2.py')
-rw-r--r--Scripts/transcribe_v2.py494
1 files changed, 494 insertions, 0 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
new file mode 100644
index 0000000..3cf8fe7
--- /dev/null
+++ b/Scripts/transcribe_v2.py
@@ -0,0 +1,494 @@
+from datetime import datetime
+from faster_whisper import WhisperModel
+from functools import partial
+from pydub import AudioSegment
+from whisper.normalizers import EnglishTextNormalizer
+from scipy.optimize import minimize
+
+import app_config
+import argparse
+import editdistance
+import langcodes
+import math
+import numpy as np
+import os
+import pyaudio
+import time
+import typing
+
+class AudioStream():
+ FORMAT = pyaudio.paInt16
+ # Size of each frame (audio sample), in bytes. If you change FORMAT, make
+ # sure this stays up to date!
+ FRAME_SZ = 2
+ # Frames per second.
+ FPS = 16000
+ CHANNELS = 1
+ def __init__(self):
+ pass
+
+ def getSamples(self) -> bytes:
+ raise NotImplementedError("getSamples is not implemented!")
+
+class DiskStream(AudioStream):
+ def __init__(self, path: str, pace_at_real_time: bool):
+ fmt = None
+ if path.endswith(".mp3"):
+ fmt = "mp3"
+ elif path.endswith(".wav"):
+ fmt = "wav"
+ else:
+ raise NotImplementedError(f"Requested file type {path} " + \
+ "is not supported")
+ print(f"Loading audio data")
+ audio = AudioSegment.from_file(path, format=fmt)
+ audio.set_channels(1)
+ # TODO(yum) replace manual decimation code with this!
+ audio = audio.set_frame_rate(16000)
+ frames = np.array(audio.get_array_of_samples())
+ frames = np.int16(frames).tobytes()
+
+ self.frames = frames
+ self.wall_ts = time.time()
+ self.pace_at_real_time = pace_at_real_time
+
+ print(f"Loaded data")
+
+ def getSamples(self) -> bytes:
+ if self.pace_at_real_time:
+ now = time.time()
+ nframes = int((now - self.wall_ts) * AudioStream.FPS)
+ self.wall_ts = now
+ frames = self.frames[0:nframes * AudioStream.FRAME_SZ];
+ self.frames = self.frames[nframes * AudioStream.FRAME_SZ:]
+ return frames
+
+ frames = self.frames
+ self.frames = b''
+ return frames
+
+class MicStream(AudioStream):
+ CHUNK_SZ = 1024
+
+ def __init__(self, which_mic: str):
+ self.p = pyaudio.PyAudio()
+ self.stream = None
+ self.sample_rate = None
+ # Each time pyaudio gives us audio data, it's in the form of a chunk of
+ # samples. We keep these in a list to keep the audio callback as light
+ # as possible. Whenever downstream layers want data, we collapse the
+ # list into a single array of data (a bytes object).
+ self.chunks = []
+
+ print(f"Finding mic {which_mic}")
+ self.dumpMicDevices()
+
+ got_match = False
+ device_index = -1
+ focusrite_str = "Focusrite"
+ index_str = "Digital Audio Interface"
+ if which_mic == "index":
+ target_str = index_str
+ elif which_mic == "focusrite":
+ target_str = focusrite_str
+ else:
+ print(f"Mic {which_mic} requested, treating it as a numerical " +
+ "device ID")
+ device_index = int(which_mic)
+ got_match = True
+ if not got_match:
+ info = self.p.get_host_api_info_by_index(0)
+ numdevices = info.get('deviceCount')
+ for i in range(0, numdevices):
+ if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
+ device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name')
+ if target_str in device_name:
+ print(f"Got matching mic: {device_name}")
+ device_index = i
+ got_match = True
+ break
+ if not got_match:
+ raise KeyError(f"Mic {which_mic} not found")
+
+ info = self.p.get_device_info_by_host_api_device_index(0, device_index)
+ print(f"Found mic {which_mic}: {info['name']}")
+ self.sample_rate = int(info['defaultSampleRate'])
+ print(f"Mic sample rate: {self.sample_rate}")
+
+ self.stream = self.p.open(
+ rate=self.sample_rate,
+ channels=AudioStream.CHANNELS,
+ format=AudioStream.FORMAT,
+ input=True,
+ frames_per_buffer=MicStream.CHUNK_SZ,
+ input_device_index=device_index,
+ stream_callback=self.onAudioFramesAvailable)
+
+ self.stream.start_stream()
+
+ AudioStream.__init__(self)
+
+ def dumpMicDevices(self):
+ info = self.p.get_host_api_info_by_index(0)
+ numdevices = info.get('deviceCount')
+
+ for i in range(0, numdevices):
+ if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
+ device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name')
+ print("Input Device id ", i, " - ", device_name)
+
+ def onAudioFramesAvailable(self,
+ frames,
+ frame_count,
+ time_info,
+ status_flags):
+ decimated = b''
+ # In pyaudio, a `frame` is a single sample of audio data.
+ frame_len = AudioStream.FRAME_SZ
+ next_frame = 0.0
+ # The mic probably has a higher sample rate than Whisper wants, so
+ # decrease the sample rate by dropping samples. Note that this
+ # algorithm only works if the mic's rate is higher than whisper's
+ # expected rate.
+ keep_every = float(self.sample_rate) / AudioStream.FPS
+ for i in range(frame_count):
+ if i >= next_frame:
+ decimated += frames[i*frame_len:(i+1)*frame_len]
+ next_frame += keep_every
+ self.chunks.append(decimated)
+
+ return (frames, pyaudio.paContinue)
+
+ # Get audio data and the corresponding timestamp.
+ def getSamples(self) -> bytes:
+ chunks = self.chunks
+ self.chunks = []
+ result = b''.join(chunks)
+ return result
+
+class AudioCollector:
+ def __init__(self, stream: AudioStream):
+ self.stream = stream
+ self.frames = b''
+ # Note: by design, this is the only spot where we anchor our timestamps
+ # against the real world. This is done to make it possible to profile
+ # test cases which read from disk (at much faster than real speed) in
+ # the same way that we profile real-time data.
+ self.wall_ts = time.time()
+
+ def getAudio(self) -> bytes:
+ frames = self.stream.getSamples()
+ if frames:
+ self.frames += frames
+ return self.frames
+
+ def dropAudioPrefix(self, dur_s: float):
+ n_bytes = int(dur_s * self.stream.FPS) * self.stream.FRAME_SZ
+ n_bytes = min(n_bytes, len(self.frames))
+ self.frames = self.frames[n_bytes:]
+ self.wall_ts = self.wall_ts + self.duration()
+
+ def dropAudio(self):
+ self.wall_ts += self.duration()
+ self.frames = b''
+
+ def duration(self):
+ return len(self.frames) / (self.stream.FPS * self.stream.FRAME_SZ)
+
+ def now(self):
+ return self.wall_ts + self.duration()
+
+# A segment of transcribed audio. `start_ts` and `end_ts` are floating point
+# number of seconds since the beginning of audio data.
+class Segment:
+ def __init__(self,
+ transcript: str,
+ start_ts: float,
+ end_ts: float,
+ wall_ts: float,
+ avg_logprob: float,
+ no_speech_prob: float):
+ self.transcript = transcript
+ # start_ts, end_ts are timestamps in seconds relative to `wall_ts`.
+ self.start_ts = start_ts
+ self.end_ts = end_ts
+ # wall_ts is the time.time() at which the oldest audio sample leading
+ # to this transcript was collected.
+ self.wall_ts = wall_ts
+ self.avg_logprob = avg_logprob
+ self.no_speech_prob = no_speech_prob
+
+ def __str__(self):
+ ts = f"(ts: {self.start_ts}-{self.end_ts}) "
+
+ wall_ts_start = datetime.utcfromtimestamp(self.start_ts + self.wall_ts).strftime('%H:%M:%S')
+ wall_ts_end = datetime.utcfromtimestamp(self.end_ts + self.wall_ts).strftime('%H:%M:%S')
+ wall_ts = f"(wall ts: {wall_ts_start}-{wall_ts_end}) "
+
+ no_speech = f"(no_speech: {self.no_speech_prob}) "
+ avg_logprob = f"(avg_logprob: {self.avg_logprob}) "
+ return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob
+
+class Whisper:
+ def __init__(self,
+ collector: AudioCollector,
+ cfg: typing.Dict):
+ self.collector = collector
+ self.model = None
+ self.cfg = cfg
+
+ abspath = os.path.abspath(__file__)
+ dname = os.path.dirname(abspath)
+
+ model_root = os.path.join(dname, "Models", cfg["model"])
+ print("Model {} will be saved to {}".format(cfg["model"], model_root))
+
+ model_device = "cuda"
+ if cfg["use_cpu"]:
+ model_device = "cpu"
+
+ download_it = os.path.exists(model_root)
+ model_str = cfg["model"]
+ if download_it:
+ model_str = model_root
+ self.model = WhisperModel(model_str,
+ device = model_device,
+ device_index = cfg["gpu_idx"],
+ compute_type = "int8",
+ download_root = model_root,
+ local_files_only = download_it)
+
+ def transcribe(self) -> typing.List[Segment]:
+ frames = self.collector.getAudio()
+ # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on
+ # [-1, 1].
+ audio = np.frombuffer(frames,
+ dtype=np.int16).flatten().astype(np.float32) / 32768.0
+
+ segments, info = self.model.transcribe(
+ audio,
+ beam_size = 5,
+ language = langcodes.find(self.cfg["language"]).language,
+ temperature = 0.0,
+ log_prob_threshold = -0.8,
+ vad_filter = True,
+ condition_on_previous_text = True,
+ without_timestamps = False)
+ res = []
+ for s in segments:
+ res.append(Segment(s.text, s.start, s.end,
+ self.collector.wall_ts,
+ s.avg_logprob, s.no_speech_prob))
+ return res
+
+class TranscriptCommit:
+ def __init__(self,
+ delta: str,
+ preview: str,
+ latency_s: int = None,
+ thresh_at_commit: int = None):
+ self.delta = delta
+ self.preview = preview
+ self.latency_s = latency_s
+ self.thresh_at_commit = thresh_at_commit
+
+class FuzzyRepeatCommitter:
+ def __init__(self,
+ collector: AudioCollector,
+ whisper: Whisper,
+ last_n_must_match: int = 4,
+ edit_thresh_min: int = 1,
+ edit_thresh_grow_begin_s: float = 1.5,
+ edit_thresh_grow_halflife_s: float = 1.5,
+ min_segment_age_s: float = 1.0):
+ self.collector = collector
+ self.whisper = whisper
+ # List of candidate segments. Once these all match, we commit the
+ # corresponding audio data.
+ self.candidates = []
+ self.last_n_must_match = last_n_must_match
+ self.edit_thresh_min = edit_thresh_min
+ self.edit_thresh_grow_begin_s = edit_thresh_grow_begin_s
+ self.edit_thresh_grow_halflife_s = edit_thresh_grow_halflife_s
+ self.min_segment_age_s = min_segment_age_s
+
+ def getDelta(self) -> TranscriptCommit:
+ segments = self.whisper.transcribe()
+ preview = ''.join(s.transcript for s in segments)
+
+ if len(segments) == 0:
+ return TranscriptCommit("", preview, None)
+
+ s = segments[0]
+
+ if len(self.candidates) < self.last_n_must_match:
+ if len(self.candidates) == 0:
+ self.candidates.append(s)
+ return TranscriptCommit("", preview, None)
+ s0 = self.candidates[0]
+ if s.wall_ts != s0.wall_ts:
+ print("Frames dropped, committer resetting candidates")
+ self.candidates = []
+ return TranscriptCommit("", preview, None)
+ self.candidates.append(s)
+ return TranscriptCommit("", preview, None)
+
+ # Rule 1: last n segments must be within a certain edit distance of
+ # each other. This edit distance starts low and increases exponentially
+ # as the buffer size grows, thus allowing the check to get weaker under
+ # compute pressure.
+ edit_thresh = self.edit_thresh_min
+ dt = self.collector.now() - (self.collector.wall_ts + s.start_ts)
+ if dt > self.edit_thresh_grow_begin_s:
+ dt -= self.edit_thresh_grow_begin_s
+ edit_thresh = int(math.ceil(2**(dt /
+ self.edit_thresh_grow_halflife_s)))
+
+ drop_candidates = 0
+ for i in range(1, len(self.candidates)):
+ prev = self.candidates[i-1]
+ cur = self.candidates[i]
+ dist = editdistance.eval(prev.transcript, cur.transcript)
+ if dist > edit_thresh:
+ drop_candidates = i
+ if drop_candidates != 0:
+ self.candidates = self.candidates[drop_candidates:]
+ return TranscriptCommit("", preview, None)
+
+ candidate = self.candidates[-1]
+
+ # Rule 2: no committing segments that are fewer than the configured
+ # number of seconds old.
+ if self.collector.now() - (candidate.end_ts + candidate.wall_ts) < self.min_segment_age_s:
+ self.candidates = []
+ return TranscriptCommit("", preview, None)
+
+ # Got a candidate! Commit it and return.
+ self.candidates = []
+ latency_s = self.collector.now() - (candidate.wall_ts + candidate.start_ts)
+ self.collector.dropAudioPrefix(candidate.end_ts)
+
+ return TranscriptCommit(candidate.transcript, preview, latency_s,
+ thresh_at_commit = edit_thresh)
+
+def evaluate(cfg,
+ audio_path: str,
+ control_path: str,
+ last_n_must_match: int = 4,
+ edit_thresh_min: int = 1,
+ edit_thresh_grow_begin_s: float = 1.5,
+ edit_thresh_grow_halflife_s: float = 1.5,
+ min_segment_age_s: float = 1.0
+ ):
+ stream = DiskStream(audio_path, True)
+
+ collector = AudioCollector(stream)
+ whisper = Whisper(collector, cfg)
+ com = FuzzyRepeatCommitter(collector, whisper,
+ last_n_must_match=last_n_must_match,
+ edit_thresh_min=edit_thresh_min,
+ edit_thresh_grow_begin_s=edit_thresh_grow_begin_s,
+ edit_thresh_grow_halflife_s=edit_thresh_grow_halflife_s,
+ min_segment_age_s=min_segment_age_s)
+ transcript = ""
+ commits = []
+ while len(stream.frames) > 0:
+ commit = com.getDelta()
+
+ if len(stream.frames) == 0:
+ commit.delta = commit.preview
+ commit.latency_s = 0
+
+ if len(commit.delta) > 0:
+ commits.append(commit)
+
+ transcript += commit.delta
+
+ #if len(commit.delta):
+ # print(f"transcript: {transcript}")
+ # print(f"commit latency: {commit.latency_s}")
+ # print(f"commit thresh: {commit.thresh_at_commit}")
+
+ with open(control_path, "r") as f:
+ control = f.read()
+ normalizer = EnglishTextNormalizer()
+ control = normalizer(control)
+ experiment = normalizer(transcript)
+
+ sum_latency = 0
+ for commit in commits:
+ sum_latency += commit.latency_s
+ avg_latency = sum_latency / len(commits)
+
+ dist = editdistance.eval(control, experiment)
+
+ print(f"PARAMS")
+ print(f"last_n_must_match: {last_n_must_match}")
+ print(f"edit_thresh_min: {edit_thresh_min}")
+ print(f"edit_thresh_grow_begin_s: {edit_thresh_grow_begin_s}")
+ print(f"edit_thresh_grow_halflife_s: {edit_thresh_grow_halflife_s}")
+ print(f"min_segment_age_s: {min_segment_age_s}")
+ print(f"RESULTS")
+ print(f"edit distance: {dist}")
+ print(f"avg latency: {avg_latency}")
+ print(f"num commits: {len(commits)}")
+ print(f"final transcript: {transcript}")
+
+ score = dist * avg_latency
+ print(f"score: {score}")
+ return score
+
+def optimize(cfg,
+ audio_path: str,
+ control_path: str):
+
+ def wrapper_to_optimize(x):
+ return evaluate(
+ cfg,
+ audio_path,
+ control_path,
+ int(x[0]), # last_n_must_match
+ int(x[1]), # edit_thresh_min
+ x[2], # edit_thresh_grow_begin_s
+ x[3], # edit_thresh_grow_halflife_s
+ x[4] # min_segment_age_s
+ )
+
+ initial_guess = [4, 1, 1.5, 1.5, 1.0]
+ bounds = [
+ (1, 5), # last_n_must_match
+ (1, 100), # edit_thresh_min
+ (0, 10), # edit_thresh_grow_begin_s
+ (0.1, 10), # edit_thresh_grow_halflife_s
+ (0, 3) # min_segment_age_s
+ ]
+
+ result = minimize(
+ wrapper_to_optimize,
+ initial_guess,
+ bounds=bounds,
+ method='L-BFGS-B',
+ )
+
+ optimized_params = result.x
+
+ print("Optimized Parameters:")
+ print(f"last_n_must_match: {int(optimized_params[0])}")
+ print(f"edit_thresh_min: {int(optimized_params[1])}")
+ print(f"edit_thresh_grow_begin_s: {optimized_params[2]}")
+ print(f"edit_thresh_grow_halflife_s: {optimized_params[3]}")
+ print(f"min_segment_age_s: {optimized_params[4]}")
+
+ return optimized_params
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, help="Path to app config YAML file.")
+ args = parser.parse_args()
+
+ cfg = app_config.getConfig(args.config)
+
+ optimize(cfg, "Evaluate/declaration_short/audio.mp3",
+ "Evaluate/declaration_short/control.txt")
+
+