diff options
| author | yum <yum.food.vr@gmail.com> | 2023-09-03 16:56:29 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-09-03 16:56:29 -0700 |
| commit | 2a4c6051acd8140bde6c1abad62bd613673de4b4 (patch) | |
| tree | e0c21e333c715f35f434e836f5476b0b57358c3f /Scripts/transcribe_v2.py | |
| parent | 606d223f8ba9174a2984d7cb15e6e94ef6e48228 (diff) | |
Apply subtle compression to audio before transcribing
This has a slight positive effect on my benchmark.
Diffstat (limited to 'Scripts/transcribe_v2.py')
| -rw-r--r-- | Scripts/transcribe_v2.py | 48 |
1 file changed, 34 insertions, 14 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 2c8c57d..1904526 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -87,7 +87,7 @@ class MicStream(AudioStream): target_str = focusrite_str else: print(f"Mic {which_mic} requested, treating it as a numerical " + - "device ID") + "device ID", file=sys.stderr) device_index = int(which_mic) got_match = True if not got_match: @@ -97,7 +97,8 @@ class MicStream(AudioStream): if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') if target_str in device_name: - print(f"Got matching mic: {device_name}") + print(f"Got matching mic: {device_name}", + file=sys.stderr) device_index = i got_match = True break @@ -105,9 +106,9 @@ class MicStream(AudioStream): raise KeyError(f"Mic {which_mic} not found") info = self.p.get_device_info_by_host_api_device_index(0, device_index) - print(f"Found mic {which_mic}: {info['name']}") + print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr) self.sample_rate = int(info['defaultSampleRate']) - print(f"Mic sample rate: {self.sample_rate}") + print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr) self.stream = self.p.open( rate=self.sample_rate, @@ -129,7 +130,8 @@ class MicStream(AudioStream): for i in range(0, numdevices): if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') - print("Input Device id ", i, " - ", device_name) + print("Input Device id ", i, " - ", device_name, + file=sys.stderr) def onAudioFramesAvailable(self, frames, @@ -248,6 +250,23 @@ class NormalizingAudioCollector(AudioCollectorFilter): return frames +class CompressingAudioCollector(AudioCollectorFilter): + def __init__(self, parent: AudioCollector): + AudioCollectorFilter.__init__(self, parent) + + def getAudio(self) -> bytes: + audio = 
self.parent.getAudio() + + audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ, + frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) + # subtle compression has a slight positive effect on my benchmark + audio = audio.compress_dynamic_range(threshold=-10, ratio=2.0) + + frames = np.array(audio.get_array_of_samples()) + frames = np.int16(frames).tobytes() + + return frames + # A segment of transcribed audio. `start_ts` and `end_ts` are floating point # number of seconds since the beginning of audio data. class Segment: @@ -291,7 +310,8 @@ class Whisper: dname = os.path.dirname(abspath) model_root = os.path.join(dname, "Models", cfg["model"]) - print("Model {} will be saved to {}".format(cfg["model"], model_root)) + print(f"Model {cfg['model']} will be saved to {model_root}", + file=sys.stderr) model_device = "cuda" if cfg["use_cpu"]: @@ -320,9 +340,8 @@ class Whisper: beam_size = 5, language = langcodes.find(self.cfg["language"]).language, temperature = 0.0, - log_prob_threshold = -0.8, + log_prob_threshold = -1.0, vad_filter = True, - condition_on_previous_text = True, without_timestamps = False) res = [] for s in segments: @@ -436,8 +455,9 @@ def evaluate(cfg, stream = DiskStream(audio_path) collector = AudioCollector(stream) - collector = LengthEnforcingAudioCollector(collector, 5.0) - collector = NormalizingAudioCollector(collector) + #collector = LengthEnforcingAudioCollector(collector, 5.0) + #collector = NormalizingAudioCollector(collector) + collector = CompressingAudioCollector(collector) whisper = Whisper(collector, cfg) com = FuzzyRepeatCommitter(collector, whisper, last_n_must_match=last_n_must_match, @@ -548,6 +568,7 @@ def run(cfg): collector = AudioCollector(stream) #collector = LengthEnforcingAudioCollector(collector, 5.0) #collector = NormalizingAudioCollector(collector) + collector = CompressingAudioCollector(collector) whisper = Whisper(collector, cfg) com = FuzzyRepeatCommitter(collector, whisper) @@ -562,11 +583,10 @@ def 
run(cfg): transcript += commit.delta - print(f"{transcript}{commit.preview}") - if True and len(commit.delta): - print(f"commit latency: {commit.latency_s}") - print(f"commit thresh: {commit.thresh_at_commit}") + print(f"{transcript}") + print(f"commit latency: {commit.latency_s}", file=sys.stderr) + print(f"commit thresh: {commit.thresh_at_commit}", file=sys.stderr) if __name__ == "__main__": parser = argparse.ArgumentParser() |
