summary refs log tree commit diff stats
diff options
context:
space:
mode:
author yum <yum.food.vr@gmail.com> 2023-09-03 16:56:29 -0700
committer yum <yum.food.vr@gmail.com> 2023-09-03 16:56:29 -0700
commit 2a4c6051acd8140bde6c1abad62bd613673de4b4 (patch)
tree e0c21e333c715f35f434e836f5476b0b57358c3f
parent 606d223f8ba9174a2984d7cb15e6e94ef6e48228 (diff)
Apply subtle compression to audio before transcribing
This has a slight positive effect on my benchmark.
-rw-r--r-- Scripts/transcribe_v2.py 48
1 file changed, 34 insertions, 14 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 2c8c57d..1904526 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -87,7 +87,7 @@ class MicStream(AudioStream):
target_str = focusrite_str
else:
print(f"Mic {which_mic} requested, treating it as a numerical " +
- "device ID")
+ "device ID", file=sys.stderr)
device_index = int(which_mic)
got_match = True
if not got_match:
@@ -97,7 +97,8 @@ class MicStream(AudioStream):
if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name')
if target_str in device_name:
- print(f"Got matching mic: {device_name}")
+ print(f"Got matching mic: {device_name}",
+ file=sys.stderr)
device_index = i
got_match = True
break
@@ -105,9 +106,9 @@ class MicStream(AudioStream):
raise KeyError(f"Mic {which_mic} not found")
info = self.p.get_device_info_by_host_api_device_index(0, device_index)
- print(f"Found mic {which_mic}: {info['name']}")
+ print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr)
self.sample_rate = int(info['defaultSampleRate'])
- print(f"Mic sample rate: {self.sample_rate}")
+ print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr)
self.stream = self.p.open(
rate=self.sample_rate,
@@ -129,7 +130,8 @@ class MicStream(AudioStream):
for i in range(0, numdevices):
if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name')
- print("Input Device id ", i, " - ", device_name)
+ print("Input Device id ", i, " - ", device_name,
+ file=sys.stderr)
def onAudioFramesAvailable(self,
frames,
@@ -248,6 +250,23 @@ class NormalizingAudioCollector(AudioCollectorFilter):
return frames
+class CompressingAudioCollector(AudioCollectorFilter):
+ def __init__(self, parent: AudioCollector):
+ AudioCollectorFilter.__init__(self, parent)
+
+ def getAudio(self) -> bytes:
+ audio = self.parent.getAudio()
+
+ audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ,
+ frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS)
+ # subtle compression has a slight positive effect on my benchmark
+ audio = audio.compress_dynamic_range(threshold=-10, ratio=2.0)
+
+ frames = np.array(audio.get_array_of_samples())
+ frames = np.int16(frames).tobytes()
+
+ return frames
+
# A segment of transcribed audio. `start_ts` and `end_ts` are floating point
# number of seconds since the beginning of audio data.
class Segment:
@@ -291,7 +310,8 @@ class Whisper:
dname = os.path.dirname(abspath)
model_root = os.path.join(dname, "Models", cfg["model"])
- print("Model {} will be saved to {}".format(cfg["model"], model_root))
+ print(f"Model {cfg['model']} will be saved to {model_root}",
+ file=sys.stderr)
model_device = "cuda"
if cfg["use_cpu"]:
@@ -320,9 +340,8 @@ class Whisper:
beam_size = 5,
language = langcodes.find(self.cfg["language"]).language,
temperature = 0.0,
- log_prob_threshold = -0.8,
+ log_prob_threshold = -1.0,
vad_filter = True,
- condition_on_previous_text = True,
without_timestamps = False)
res = []
for s in segments:
@@ -436,8 +455,9 @@ def evaluate(cfg,
stream = DiskStream(audio_path)
collector = AudioCollector(stream)
- collector = LengthEnforcingAudioCollector(collector, 5.0)
- collector = NormalizingAudioCollector(collector)
+ #collector = LengthEnforcingAudioCollector(collector, 5.0)
+ #collector = NormalizingAudioCollector(collector)
+ collector = CompressingAudioCollector(collector)
whisper = Whisper(collector, cfg)
com = FuzzyRepeatCommitter(collector, whisper,
last_n_must_match=last_n_must_match,
@@ -548,6 +568,7 @@ def run(cfg):
collector = AudioCollector(stream)
#collector = LengthEnforcingAudioCollector(collector, 5.0)
#collector = NormalizingAudioCollector(collector)
+ collector = CompressingAudioCollector(collector)
whisper = Whisper(collector, cfg)
com = FuzzyRepeatCommitter(collector, whisper)
@@ -562,11 +583,10 @@ def run(cfg):
transcript += commit.delta
- print(f"{transcript}{commit.preview}")
-
if True and len(commit.delta):
- print(f"commit latency: {commit.latency_s}")
- print(f"commit thresh: {commit.thresh_at_commit}")
+ print(f"{transcript}")
+ print(f"commit latency: {commit.latency_s}", file=sys.stderr)
+ print(f"commit thresh: {commit.thresh_at_commit}", file=sys.stderr)
if __name__ == "__main__":
parser = argparse.ArgumentParser()