diff options
| author | yum <yum.food.vr@gmail.com> | 2023-09-03 13:23:50 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-09-03 13:23:50 -0700 |
| commit | 606d223f8ba9174a2984d7cb15e6e94ef6e48228 (patch) | |
| tree | afd1b19fe801d9aac54b4e5bbe4a671e5df2217c | |
| parent | e9b5b4f1da2a8ff07b2d13e5e63dae491325251d (diff) | |
Experiment with Collector filters
Try adding two filters on top of the usual AudioCollector:
* Minimum length preservation: never report fewer than N seconds worth
of audio data. Pad with silence as needed.
* Volume normalizing: normalize audio volume.
Using my benchmark of 30-second audio clips from 3 speakers (lower is
better):
length enf + norm = 87.118
nothing = 90.917
norm = 94.538
length = 111.402
Both together are a slight improvement, but independently degrade the
result by a lot. I also observed more hallucinations in a conversational
pattern when using them vs. not. So I'll phase them out.
I'm still curious about *compression* as opposed to normalization.
| -rw-r--r-- | Scripts/transcribe_v2.py | 222 |
1 files changed, 162 insertions, 60 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 3cf8fe7..2c8c57d 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -31,7 +31,7 @@ class AudioStream(): raise NotImplementedError("getSamples is not implemented!") class DiskStream(AudioStream): - def __init__(self, path: str, pace_at_real_time: bool): + def __init__(self, path: str): fmt = None if path.endswith(".mp3"): fmt = "mp3" @@ -42,29 +42,23 @@ class DiskStream(AudioStream): "is not supported") print(f"Loading audio data") audio = AudioSegment.from_file(path, format=fmt) - audio.set_channels(1) + audio = audio.set_channels(1) # TODO(yum) replace manual decimation code with this! audio = audio.set_frame_rate(16000) frames = np.array(audio.get_array_of_samples()) frames = np.int16(frames).tobytes() self.frames = frames - self.wall_ts = time.time() - self.pace_at_real_time = pace_at_real_time print(f"Loaded data") def getSamples(self) -> bytes: - if self.pace_at_real_time: - now = time.time() - nframes = int((now - self.wall_ts) * AudioStream.FPS) - self.wall_ts = now - frames = self.frames[0:nframes * AudioStream.FRAME_SZ]; - self.frames = self.frames[nframes * AudioStream.FRAME_SZ:] - return frames - - frames = self.frames - self.frames = b'' + # Give out samples at a fixed rate to minimize + # noise. + give_s = 0.2 + nframes = int(give_s * AudioStream.FPS) + frames = self.frames[0:nframes * AudioStream.FRAME_SZ]; + self.frames = self.frames[nframes * AudioStream.FRAME_SZ:] return frames class MicStream(AudioStream): @@ -188,6 +182,10 @@ class AudioCollector: self.frames = self.frames[n_bytes:] self.wall_ts = self.wall_ts + self.duration() + def keepLast(self, dur_s: float): + drop_len = max(0, self.duration() - dur_s) + self.dropAudioPrefix(drop_len) + def dropAudio(self): self.wall_ts += self.duration() self.frames = b'' @@ -195,8 +193,60 @@ class AudioCollector: def duration(self): return len(self.frames) / (self.stream.FPS * self.stream.FRAME_SZ) + def begin(self): + return self.wall_ts + + def now(self): + return self.begin() + self.duration() + +class AudioCollectorFilter: + def __init__(self, parent: AudioCollector): + self.parent = parent + + def getAudio(self) -> bytes: + return self.parent.getAudio() + def dropAudioPrefix(self, dur_s: float): + return self.parent.dropAudioPrefix(dur_s) + def keepLast(self, dur_s): + return self.parent.keepLast(dur_s) + def dropAudio(self): + return self.parent.dropAudio() + def duration(self): + return self.parent.duration() + def begin(self): + return self.parent.begin() def now(self): - return self.wall_ts + self.duration() + return self.parent.now() + +# Audio collector that enforces a minimum length on its audio data. +class LengthEnforcingAudioCollector(AudioCollectorFilter): + def __init__(self, parent: AudioCollector, min_duration_s: float): + AudioCollectorFilter.__init__(self, parent) + self.min_duration_s = min_duration_s + + def getAudio(self) -> bytes: + audio = self.parent.getAudio() + min_duration_frames = int(self.min_duration_s * AudioStream.FPS) + pad_len_frames = max(0, min_duration_frames - int(len(audio) / + AudioStream.FRAME_SZ)) + pad = np.zeros(pad_len_frames, dtype=np.int16).tobytes() + return pad + audio + +class NormalizingAudioCollector(AudioCollectorFilter): + def __init__(self, parent: AudioCollector): + AudioCollectorFilter.__init__(self, parent) + + def getAudio(self) -> bytes: + audio = self.parent.getAudio() + + audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ, + frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) + audio = audio.normalize() + + frames = np.array(audio.get_array_of_samples()) + frames = np.int16(frames).tobytes() + + return frames # A segment of transcribed audio. `start_ts` and `end_ts` are floating point # number of seconds since the beginning of audio data. @@ -277,7 +327,7 @@ class Whisper: res = [] for s in segments: res.append(Segment(s.text, s.start, s.end, - self.collector.wall_ts, + self.collector.begin(), s.avg_logprob, s.no_speech_prob)) return res @@ -292,15 +342,17 @@ class TranscriptCommit: self.latency_s = latency_s self.thresh_at_commit = thresh_at_commit +# Commits audio when the transcription layer repeats the same transcript, +# within some fuzzy match distance. class FuzzyRepeatCommitter: def __init__(self, collector: AudioCollector, whisper: Whisper, - last_n_must_match: int = 4, - edit_thresh_min: int = 1, + last_n_must_match: int = 2, + edit_thresh_min: float = 1, edit_thresh_grow_begin_s: float = 1.5, - edit_thresh_grow_halflife_s: float = 1.5, - min_segment_age_s: float = 1.0): + edit_thresh_grow_halflife_s: float = 0.5, + min_segment_age_s: float = 0.5): self.collector = collector self.whisper = whisper # List of candidate segments. Once these all match, we commit the @@ -317,6 +369,7 @@ class FuzzyRepeatCommitter: preview = ''.join(s.transcript for s in segments) if len(segments) == 0: + self.collector.keepLast(1.0) return TranscriptCommit("", preview, None) s = segments[0] @@ -338,11 +391,11 @@ class FuzzyRepeatCommitter: # as the buffer size grows, thus allowing the check to get weaker under # compute pressure. edit_thresh = self.edit_thresh_min - dt = self.collector.now() - (self.collector.wall_ts + s.start_ts) + dt = self.collector.now() - (self.collector.begin() + s.start_ts) if dt > self.edit_thresh_grow_begin_s: dt -= self.edit_thresh_grow_begin_s - edit_thresh = int(math.ceil(2**(dt / - self.edit_thresh_grow_halflife_s))) + edit_thresh = math.ceil(2**(dt / + self.edit_thresh_grow_halflife_s)) drop_candidates = 0 for i in range(1, len(self.candidates)): @@ -374,15 +427,17 @@ class FuzzyRepeatCommitter: def evaluate(cfg, audio_path: str, control_path: str, - last_n_must_match: int = 4, - edit_thresh_min: int = 1, + last_n_must_match: int = 2, + edit_thresh_min: float = 1, edit_thresh_grow_begin_s: float = 1.5, - edit_thresh_grow_halflife_s: float = 1.5, - min_segment_age_s: float = 1.0 + edit_thresh_grow_halflife_s: float = 0.5, + min_segment_age_s: float = 0.5 ): - stream = DiskStream(audio_path, True) + stream = DiskStream(audio_path) collector = AudioCollector(stream) + collector = LengthEnforcingAudioCollector(collector, 5.0) + collector = NormalizingAudioCollector(collector) whisper = Whisper(collector, cfg) com = FuzzyRepeatCommitter(collector, whisper, last_n_must_match=last_n_must_match, @@ -392,6 +447,14 @@ def evaluate(cfg, min_segment_age_s=min_segment_age_s) transcript = "" commits = [] + + print(f"PARAMS") + print(f"last_n_must_match: {last_n_must_match}") + print(f"edit_thresh_min: {edit_thresh_min}") + print(f"edit_thresh_grow_begin_s: {edit_thresh_grow_begin_s}") + print(f"edit_thresh_grow_halflife_s: {edit_thresh_grow_halflife_s}") + print(f"min_segment_age_s: {min_segment_age_s}") + while len(stream.frames) > 0: commit = com.getDelta() @@ -404,10 +467,10 @@ def evaluate(cfg, transcript += commit.delta - #if len(commit.delta): - # print(f"transcript: {transcript}") - # print(f"commit latency: {commit.latency_s}") - # print(f"commit thresh: {commit.thresh_at_commit}") + if False and len(commit.delta): + print(f"transcript: {transcript}") + print(f"commit latency: {commit.latency_s}") + print(f"commit thresh: {commit.thresh_at_commit}") with open(control_path, "r") as f: control = f.read() @@ -422,44 +485,40 @@ def evaluate(cfg, dist = editdistance.eval(control, experiment) - print(f"PARAMS") - print(f"last_n_must_match: {last_n_must_match}") - print(f"edit_thresh_min: {edit_thresh_min}") - print(f"edit_thresh_grow_begin_s: {edit_thresh_grow_begin_s}") - print(f"edit_thresh_grow_halflife_s: {edit_thresh_grow_halflife_s}") - print(f"min_segment_age_s: {min_segment_age_s}") print(f"RESULTS") print(f"edit distance: {dist}") print(f"avg latency: {avg_latency}") print(f"num commits: {len(commits)}") print(f"final transcript: {transcript}") - score = dist * avg_latency + score = (3 + (dist/len(control)) * 100) * avg_latency print(f"score: {score}") return score def optimize(cfg, - audio_path: str, - control_path: str): + experiments: typing.List[typing.Tuple[str, str]]): def wrapper_to_optimize(x): - return evaluate( - cfg, - audio_path, - control_path, - int(x[0]), # last_n_must_match - int(x[1]), # edit_thresh_min - x[2], # edit_thresh_grow_begin_s - x[3], # edit_thresh_grow_halflife_s - x[4] # min_segment_age_s - ) - - initial_guess = [4, 1, 1.5, 1.5, 1.0] + s = 0 + for audio_path, control_path in experiments: + s += evaluate( + cfg, + audio_path, + control_path, + int(x[0]), # last_n_must_match + 2**x[1], # edit_thresh_min + (2**x[2])-1,# edit_thresh_grow_begin_s + x[3], # edit_thresh_grow_halflife_s + x[4] # min_segment_age_s + ) + return s + + initial_guess = [2.3, 1, 1.75, 1.5, 0.5] bounds = [ - (1, 5), # last_n_must_match - (1, 100), # edit_thresh_min - (0, 10), # edit_thresh_grow_begin_s - (0.1, 10), # edit_thresh_grow_halflife_s + (2, 3), # last_n_must_match + (1, 4), # edit_thresh_min + (0, 2.5), # edit_thresh_grow_begin_s + (0.1, 2), # edit_thresh_grow_halflife_s (0, 3) # min_segment_age_s ] @@ -468,19 +527,47 @@ def optimize(cfg, initial_guess, bounds=bounds, method='L-BFGS-B', + options={"maxfun": int((60/.5)*12), + "eps": 0.2}, ) optimized_params = result.x print("Optimized Parameters:") print(f"last_n_must_match: {int(optimized_params[0])}") - print(f"edit_thresh_min: {int(optimized_params[1])}") + print(f"edit_thresh_min: {optimized_params[1]}") print(f"edit_thresh_grow_begin_s: {optimized_params[2]}") print(f"edit_thresh_grow_halflife_s: {optimized_params[3]}") print(f"min_segment_age_s: {optimized_params[4]}") return optimized_params +def run(cfg): + stream = MicStream(cfg["microphone"]) + + collector = AudioCollector(stream) + #collector = LengthEnforcingAudioCollector(collector, 5.0) + #collector = NormalizingAudioCollector(collector) + + whisper = Whisper(collector, cfg) + com = FuzzyRepeatCommitter(collector, whisper) + transcript = "" + commits = [] + + while True: + commit = com.getDelta() + + if len(commit.delta) > 0: + commits.append(commit) + + transcript += commit.delta + + print(f"{transcript}{commit.preview}") + + if True and len(commit.delta): + print(f"commit latency: {commit.latency_s}") + print(f"commit thresh: {commit.thresh_at_commit}") + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, help="Path to app config YAML file.") @@ -488,7 +575,22 @@ if __name__ == "__main__": cfg = app_config.getConfig(args.config) - optimize(cfg, "Evaluate/declaration_short/audio.mp3", - "Evaluate/declaration_short/control.txt") + experiments = [ + ("Evaluate/declaration_short/audio.mp3", + "Evaluate/declaration_short/control.txt"), + ("Evaluate/moist/audio.mp3", + "Evaluate/moist/control.txt"), + ("Evaluate/vei/audio.mp3", + "Evaluate/vei/control.txt"), + ] + + if False: + sum = 0 + for audio, control in experiments: + sum += evaluate(cfg, audio, control) + print(f"Total score: {sum}") + else: + #optimize(cfg, experiments) + run(cfg) |
