summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2023-09-03 13:23:50 -0700
committeryum <yum.food.vr@gmail.com>2023-09-03 13:23:50 -0700
commit606d223f8ba9174a2984d7cb15e6e94ef6e48228 (patch)
treeafd1b19fe801d9aac54b4e5bbe4a671e5df2217c
parente9b5b4f1da2a8ff07b2d13e5e63dae491325251d (diff)
Experiment with Collector filters
Try adding two filters on top of the usual AudioCollector: * Minimum length preservation: never report fewer than N seconds worth of audio data. Pad with silence as needed. * Volume normalizing: normalize audio volume. Using my benchmark of 30-second audio clips from 3 speakers (lower is better): length enf + norm = 87.118 nothing = 90.917 norm = 94.538 length = 111.402 Both together are a slight improvement, but independently degrade the result by a lot. I also observed more hallucinations in a conversational pattern when using them vs. not. So I'll phase them out. I'm still curious about *compression* as opposed to normalization.
-rw-r--r--Scripts/transcribe_v2.py222
1 files changed, 162 insertions, 60 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 3cf8fe7..2c8c57d 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -31,7 +31,7 @@ class AudioStream():
raise NotImplementedError("getSamples is not implemented!")
class DiskStream(AudioStream):
- def __init__(self, path: str, pace_at_real_time: bool):
+ def __init__(self, path: str):
fmt = None
if path.endswith(".mp3"):
fmt = "mp3"
@@ -42,29 +42,23 @@ class DiskStream(AudioStream):
"is not supported")
print(f"Loading audio data")
audio = AudioSegment.from_file(path, format=fmt)
- audio.set_channels(1)
+ audio = audio.set_channels(1)
# TODO(yum) replace manual decimation code with this!
audio = audio.set_frame_rate(16000)
frames = np.array(audio.get_array_of_samples())
frames = np.int16(frames).tobytes()
self.frames = frames
- self.wall_ts = time.time()
- self.pace_at_real_time = pace_at_real_time
print(f"Loaded data")
def getSamples(self) -> bytes:
- if self.pace_at_real_time:
- now = time.time()
- nframes = int((now - self.wall_ts) * AudioStream.FPS)
- self.wall_ts = now
- frames = self.frames[0:nframes * AudioStream.FRAME_SZ];
- self.frames = self.frames[nframes * AudioStream.FRAME_SZ:]
- return frames
-
- frames = self.frames
- self.frames = b''
+ # Give out samples at a fixed rate to minimize
+ # noise.
+ give_s = 0.2
+ nframes = int(give_s * AudioStream.FPS)
+ frames = self.frames[0:nframes * AudioStream.FRAME_SZ];
+ self.frames = self.frames[nframes * AudioStream.FRAME_SZ:]
return frames
class MicStream(AudioStream):
@@ -188,6 +182,10 @@ class AudioCollector:
self.frames = self.frames[n_bytes:]
self.wall_ts = self.wall_ts + self.duration()
+ def keepLast(self, dur_s: float):
+ drop_len = max(0, self.duration() - dur_s)
+ self.dropAudioPrefix(drop_len)
+
def dropAudio(self):
self.wall_ts += self.duration()
self.frames = b''
@@ -195,8 +193,60 @@ class AudioCollector:
def duration(self):
return len(self.frames) / (self.stream.FPS * self.stream.FRAME_SZ)
+ def begin(self):
+ return self.wall_ts
+
+ def now(self):
+ return self.begin() + self.duration()
+
+class AudioCollectorFilter:
+ def __init__(self, parent: AudioCollector):
+ self.parent = parent
+
+ def getAudio(self) -> bytes:
+ return self.parent.getAudio()
+ def dropAudioPrefix(self, dur_s: float):
+ return self.parent.dropAudioPrefix(dur_s)
+ def keepLast(self, dur_s):
+ return self.parent.keepLast(dur_s)
+ def dropAudio(self):
+ return self.parent.dropAudio()
+ def duration(self):
+ return self.parent.duration()
+ def begin(self):
+ return self.parent.begin()
def now(self):
- return self.wall_ts + self.duration()
+ return self.parent.now()
+
+# Audio collector that enforces a minimum length on its audio data.
+class LengthEnforcingAudioCollector(AudioCollectorFilter):
+ def __init__(self, parent: AudioCollector, min_duration_s: float):
+ AudioCollectorFilter.__init__(self, parent)
+ self.min_duration_s = min_duration_s
+
+ def getAudio(self) -> bytes:
+ audio = self.parent.getAudio()
+ min_duration_frames = int(self.min_duration_s * AudioStream.FPS)
+ pad_len_frames = max(0, min_duration_frames - int(len(audio) /
+ AudioStream.FRAME_SZ))
+ pad = np.zeros(pad_len_frames, dtype=np.int16).tobytes()
+ return pad + audio
+
+class NormalizingAudioCollector(AudioCollectorFilter):
+ def __init__(self, parent: AudioCollector):
+ AudioCollectorFilter.__init__(self, parent)
+
+ def getAudio(self) -> bytes:
+ audio = self.parent.getAudio()
+
+ audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ,
+ frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS)
+ audio = audio.normalize()
+
+ frames = np.array(audio.get_array_of_samples())
+ frames = np.int16(frames).tobytes()
+
+ return frames
# A segment of transcribed audio. `start_ts` and `end_ts` are floating point
# number of seconds since the beginning of audio data.
@@ -277,7 +327,7 @@ class Whisper:
res = []
for s in segments:
res.append(Segment(s.text, s.start, s.end,
- self.collector.wall_ts,
+ self.collector.begin(),
s.avg_logprob, s.no_speech_prob))
return res
@@ -292,15 +342,17 @@ class TranscriptCommit:
self.latency_s = latency_s
self.thresh_at_commit = thresh_at_commit
+# Commits audio when the transcription layer repeats the same transcript,
+# within some fuzzy match distance.
class FuzzyRepeatCommitter:
def __init__(self,
collector: AudioCollector,
whisper: Whisper,
- last_n_must_match: int = 4,
- edit_thresh_min: int = 1,
+ last_n_must_match: int = 2,
+ edit_thresh_min: float = 1,
edit_thresh_grow_begin_s: float = 1.5,
- edit_thresh_grow_halflife_s: float = 1.5,
- min_segment_age_s: float = 1.0):
+ edit_thresh_grow_halflife_s: float = 0.5,
+ min_segment_age_s: float = 0.5):
self.collector = collector
self.whisper = whisper
# List of candidate segments. Once these all match, we commit the
@@ -317,6 +369,7 @@ class FuzzyRepeatCommitter:
preview = ''.join(s.transcript for s in segments)
if len(segments) == 0:
+ self.collector.keepLast(1.0)
return TranscriptCommit("", preview, None)
s = segments[0]
@@ -338,11 +391,11 @@ class FuzzyRepeatCommitter:
# as the buffer size grows, thus allowing the check to get weaker under
# compute pressure.
edit_thresh = self.edit_thresh_min
- dt = self.collector.now() - (self.collector.wall_ts + s.start_ts)
+ dt = self.collector.now() - (self.collector.begin() + s.start_ts)
if dt > self.edit_thresh_grow_begin_s:
dt -= self.edit_thresh_grow_begin_s
- edit_thresh = int(math.ceil(2**(dt /
- self.edit_thresh_grow_halflife_s)))
+ edit_thresh = math.ceil(2**(dt /
+ self.edit_thresh_grow_halflife_s))
drop_candidates = 0
for i in range(1, len(self.candidates)):
@@ -374,15 +427,17 @@ class FuzzyRepeatCommitter:
def evaluate(cfg,
audio_path: str,
control_path: str,
- last_n_must_match: int = 4,
- edit_thresh_min: int = 1,
+ last_n_must_match: int = 2,
+ edit_thresh_min: float = 1,
edit_thresh_grow_begin_s: float = 1.5,
- edit_thresh_grow_halflife_s: float = 1.5,
- min_segment_age_s: float = 1.0
+ edit_thresh_grow_halflife_s: float = 0.5,
+ min_segment_age_s: float = 0.5
):
- stream = DiskStream(audio_path, True)
+ stream = DiskStream(audio_path)
collector = AudioCollector(stream)
+ collector = LengthEnforcingAudioCollector(collector, 5.0)
+ collector = NormalizingAudioCollector(collector)
whisper = Whisper(collector, cfg)
com = FuzzyRepeatCommitter(collector, whisper,
last_n_must_match=last_n_must_match,
@@ -392,6 +447,14 @@ def evaluate(cfg,
min_segment_age_s=min_segment_age_s)
transcript = ""
commits = []
+
+ print(f"PARAMS")
+ print(f"last_n_must_match: {last_n_must_match}")
+ print(f"edit_thresh_min: {edit_thresh_min}")
+ print(f"edit_thresh_grow_begin_s: {edit_thresh_grow_begin_s}")
+ print(f"edit_thresh_grow_halflife_s: {edit_thresh_grow_halflife_s}")
+ print(f"min_segment_age_s: {min_segment_age_s}")
+
while len(stream.frames) > 0:
commit = com.getDelta()
@@ -404,10 +467,10 @@ def evaluate(cfg,
transcript += commit.delta
- #if len(commit.delta):
- # print(f"transcript: {transcript}")
- # print(f"commit latency: {commit.latency_s}")
- # print(f"commit thresh: {commit.thresh_at_commit}")
+ if False and len(commit.delta):
+ print(f"transcript: {transcript}")
+ print(f"commit latency: {commit.latency_s}")
+ print(f"commit thresh: {commit.thresh_at_commit}")
with open(control_path, "r") as f:
control = f.read()
@@ -422,44 +485,40 @@ def evaluate(cfg,
dist = editdistance.eval(control, experiment)
- print(f"PARAMS")
- print(f"last_n_must_match: {last_n_must_match}")
- print(f"edit_thresh_min: {edit_thresh_min}")
- print(f"edit_thresh_grow_begin_s: {edit_thresh_grow_begin_s}")
- print(f"edit_thresh_grow_halflife_s: {edit_thresh_grow_halflife_s}")
- print(f"min_segment_age_s: {min_segment_age_s}")
print(f"RESULTS")
print(f"edit distance: {dist}")
print(f"avg latency: {avg_latency}")
print(f"num commits: {len(commits)}")
print(f"final transcript: {transcript}")
- score = dist * avg_latency
+ score = (3 + (dist/len(control)) * 100) * avg_latency
print(f"score: {score}")
return score
def optimize(cfg,
- audio_path: str,
- control_path: str):
+ experiments: typing.List[typing.Tuple[str, str]]):
def wrapper_to_optimize(x):
- return evaluate(
- cfg,
- audio_path,
- control_path,
- int(x[0]), # last_n_must_match
- int(x[1]), # edit_thresh_min
- x[2], # edit_thresh_grow_begin_s
- x[3], # edit_thresh_grow_halflife_s
- x[4] # min_segment_age_s
- )
-
- initial_guess = [4, 1, 1.5, 1.5, 1.0]
+ s = 0
+ for audio_path, control_path in experiments:
+ s += evaluate(
+ cfg,
+ audio_path,
+ control_path,
+ int(x[0]), # last_n_must_match
+ 2**x[1], # edit_thresh_min
+ (2**x[2])-1,# edit_thresh_grow_begin_s
+ x[3], # edit_thresh_grow_halflife_s
+ x[4] # min_segment_age_s
+ )
+ return s
+
+ initial_guess = [2.3, 1, 1.75, 1.5, 0.5]
bounds = [
- (1, 5), # last_n_must_match
- (1, 100), # edit_thresh_min
- (0, 10), # edit_thresh_grow_begin_s
- (0.1, 10), # edit_thresh_grow_halflife_s
+ (2, 3), # last_n_must_match
+ (1, 4), # edit_thresh_min
+ (0, 2.5), # edit_thresh_grow_begin_s
+ (0.1, 2), # edit_thresh_grow_halflife_s
(0, 3) # min_segment_age_s
]
@@ -468,19 +527,47 @@ def optimize(cfg,
initial_guess,
bounds=bounds,
method='L-BFGS-B',
+ options={"maxfun": int((60/.5)*12),
+ "eps": 0.2},
)
optimized_params = result.x
print("Optimized Parameters:")
print(f"last_n_must_match: {int(optimized_params[0])}")
- print(f"edit_thresh_min: {int(optimized_params[1])}")
+ print(f"edit_thresh_min: {optimized_params[1]}")
print(f"edit_thresh_grow_begin_s: {optimized_params[2]}")
print(f"edit_thresh_grow_halflife_s: {optimized_params[3]}")
print(f"min_segment_age_s: {optimized_params[4]}")
return optimized_params
+def run(cfg):
+ stream = MicStream(cfg["microphone"])
+
+ collector = AudioCollector(stream)
+ #collector = LengthEnforcingAudioCollector(collector, 5.0)
+ #collector = NormalizingAudioCollector(collector)
+
+ whisper = Whisper(collector, cfg)
+ com = FuzzyRepeatCommitter(collector, whisper)
+ transcript = ""
+ commits = []
+
+ while True:
+ commit = com.getDelta()
+
+ if len(commit.delta) > 0:
+ commits.append(commit)
+
+ transcript += commit.delta
+
+ print(f"{transcript}{commit.preview}")
+
+ if True and len(commit.delta):
+ print(f"commit latency: {commit.latency_s}")
+ print(f"commit thresh: {commit.thresh_at_commit}")
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, help="Path to app config YAML file.")
@@ -488,7 +575,22 @@ if __name__ == "__main__":
cfg = app_config.getConfig(args.config)
- optimize(cfg, "Evaluate/declaration_short/audio.mp3",
- "Evaluate/declaration_short/control.txt")
+ experiments = [
+ ("Evaluate/declaration_short/audio.mp3",
+ "Evaluate/declaration_short/control.txt"),
+ ("Evaluate/moist/audio.mp3",
+ "Evaluate/moist/control.txt"),
+ ("Evaluate/vei/audio.mp3",
+ "Evaluate/vei/control.txt"),
+ ]
+
+ if False:
+ sum = 0
+ for audio, control in experiments:
+ sum += evaluate(cfg, audio, control)
+ print(f"Total score: {sum}")
+ else:
+ #optimize(cfg, experiments)
+ run(cfg)