summary | refs | log | tree | commit | diff | stats
path: root/Scripts
diff options
context:
space:
mode:
author yum <yum.food.vr@gmail.com> 2023-09-08 16:46:34 -0700
committer yum <yum.food.vr@gmail.com> 2023-09-08 16:48:38 -0700
commit cc857e7c73334ef47a415c0abdb51a49171b9afd (patch)
tree 4cb276ea0b94029ca53f5ec298141246ed564432 /Scripts
parent 31d84a0cfdd3c30c2fdf1dbfe4e855ae78555707 (diff)
Only transcribe if VAD detects something
Also: * DiskStream now returns silence once it runs out of data instead of just stopping. * Filter out Whisper segments that have both a high `no_speech_prob` and a low `avg_logprob`. * Add a `saveAudio` function, useful for debugging. * Lower the VAD minimum-silence cutoff from 500 ms to 250 ms. This is pretty accurate in benchmarks.
Diffstat (limited to 'Scripts')
-rw-r--r-- Scripts/transcribe_v2.py 65
1 files changed, 47 insertions, 18 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index f13a5aa..a7d1ad2 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -22,6 +22,7 @@ import threading
import time
import typing
import vad
+import wave
class ThreadControl:
def __init__(self, cfg):
@@ -71,6 +72,10 @@ class DiskStream(AudioStream):
nframes = int(give_s * AudioStream.FPS)
frames = self.frames[0:nframes * AudioStream.FRAME_SZ];
self.frames = self.frames[nframes * AudioStream.FRAME_SZ:]
+
+ if len(frames) < nframes:
+ frames += np.zeros(nframes - len(frames), dtype=np.int16).tobytes()
+
return frames
class MicStream(AudioStream):
@@ -319,7 +324,8 @@ class AudioSegmenter:
dtype=np.int16).flatten().astype(np.float32) / 32768.0
return vad.get_speech_timestamps(audio, vad_options=self.vad_options)
- def getStableCutoff(self, audio: bytes) -> int:
+ # Returns the stable cutoff (if any) and whether there are any segments.
+ def getStableCutoff(self, audio: bytes) -> typing.Tuple[int, bool]:
min_delta_frames = int((self.vad_options.min_silence_duration_ms *
AudioStream.FPS) / 1000)
cutoff = None
@@ -346,7 +352,7 @@ class AudioSegmenter:
if now - s['end'] > min_delta_frames:
cutoff = now - int(min_delta_frames / 2)
- return cutoff
+ return (cutoff, len(segments) > 0)
# A segment of transcribed audio. `start_ts` and `end_ts` are floating point
# number of seconds since the beginning of audio data.
@@ -420,14 +426,15 @@ class Whisper:
segments, info = self.model.transcribe(
audio,
- beam_size = 5,
language = langcodes.find(self.cfg["language"]).language,
- temperature = 0.0,
- log_prob_threshold = -1.0,
vad_filter = True,
without_timestamps = False)
res = []
for s in segments:
+ # Manual touchup. I see a decent number of hallucinations sneaking
+ # in with high `no_speech_prob` and modest `avg_logprob`.
+ if s.no_speech_prob > 0.8 and s.avg_logprob < -0.5:
+ continue
res.append(Segment(s.text, s.start, s.end,
self.collector.begin(),
s.avg_logprob, s.no_speech_prob))
@@ -446,6 +453,14 @@ class TranscriptCommit:
self.thresh_at_commit = thresh_at_commit
self.audio = audio
+def saveAudio(audio: bytes, path: str):
+ with wave.open(path, 'wb') as wf:
+ print(f"Saving audio to {path}")
+ wf.setnchannels(AudioStream.CHANNELS)
+ wf.setsampwidth(AudioStream.FRAME_SZ)
+ wf.setframerate(AudioStream.FPS)
+ wf.writeframes(audio)
+
class VadCommitter:
def __init__(self,
collector: AudioCollector,
@@ -457,23 +472,34 @@ class VadCommitter:
def getDelta(self) -> TranscriptCommit:
audio = self.collector.getAudio()
- stable_cutoff = self.segmenter.getStableCutoff(audio)
+ stable_cutoff, has_audio = self.segmenter.getStableCutoff(audio)
delta = ""
commit_audio = None
latency_s = None
if stable_cutoff:
#print(f"stable cutoff get: {stable_cutoff}")
- segments = self.whisper.transcribe(audio)
+ latency_s = self.collector.now() - self.collector.begin()
+ commit_audio = self.collector.dropAudioPrefixByFrames(stable_cutoff)
+
+ segments = self.whisper.transcribe(commit_audio)
+ for s in segments:
+ print(f"commit segment: {s}")
delta = ''.join(s.transcript for s in segments)
#print(f"delta get: {delta}")
- commit_begin = self.collector.begin()
- commit_audio = self.collector.dropAudioPrefixByFrames(stable_cutoff)
- latency_s = self.collector.now() - commit_begin
audio = self.collector.getAudio()
- segments = self.whisper.transcribe(audio)
- preview = ''.join(s.transcript for s in segments)
+ #ts = datetime.fromtimestamp(self.collector.now() - latency_s)
+ #filename = str(ts.strftime('%Y_%m_%d__%H-%M-%S')) + ".wav"
+ #saveAudio(commit_audio, filename)
+
+ preview = ""
+ if has_audio:
+ segments = self.whisper.transcribe(audio)
+ preview = "".join(s.transcript for s in segments)
+ else:
+ #print("VAD detects no audio, skip transcription")
+ self.collector.keepLast(1.0)
return TranscriptCommit(
delta,
@@ -525,16 +551,18 @@ def evaluate(cfg,
committer = VadCommitter(collector, whisper, segmenter)
transcript = ""
commits = []
+ last_commit_ts = None
- while len(stream.frames) > 0:
+ while True:
commit = committer.getDelta()
- if len(stream.frames) == 0:
- commit.delta = commit.preview
- commit.latency_s = 0
+ if last_commit_ts != None and collector.now() - last_commit_ts > 30:
+ break
if len(commit.delta) > 0:
+ print(f"Commit latency: {commit.latency_s}")
commits.append(commit)
+ last_commit_ts = collector.now()
transcript += commit.delta
preview = commit.preview
@@ -834,7 +862,7 @@ def run(cfg):
#collector = NormalizingAudioCollector(collector)
collector = CompressingAudioCollector(collector)
whisper = Whisper(collector, cfg)
- segmenter = AudioSegmenter(min_silence_ms=500)
+ segmenter = AudioSegmenter(min_silence_ms=250)
committer = VadCommitter(collector, whisper, segmenter)
pager = OscPager(cfg)
@@ -889,9 +917,10 @@ if __name__ == "__main__":
"Evaluate/vei/control.txt"),
]
- if True:
+ if False:
sum = 0
for audio, control in experiments:
+ print(f"Run experiment {audio} :: {control}")
sum += evaluate(cfg, audio, control)
print(f"Total score: {sum}")
else: