summaryrefslogtreecommitdiffstats
path: root/app/stt.py
diff options
context:
space:
mode:
Diffstat (limited to 'app/stt.py')
-rw-r--r--app/stt.py55
1 files changed, 47 insertions, 8 deletions
diff --git a/app/stt.py b/app/stt.py
index 34ef2e9..c157f6d 100644
--- a/app/stt.py
+++ b/app/stt.py
@@ -1,3 +1,4 @@
+from datetime import datetime
from faster_whisper import WhisperModel
import langcodes
import numpy as np
@@ -9,6 +10,11 @@ import sys
import time
import typing
import vad
+import wave
+
+
+APP_ROOT = os.path.dirname(os.path.abspath(__file__))
+PROJECT_ROOT = os.path.dirname(APP_ROOT)
class AudioStream():
FORMAT = pyaudio.paInt16
@@ -242,6 +248,26 @@ class NormalizingAudioCollector(AudioCollectorFilter):
return frames
+class BoostingAudioCollector(AudioCollectorFilter):
+ def __init__(self, parent: AudioCollector, target_dBFS: float, cfg: typing.Dict):
+ AudioCollectorFilter.__init__(self, parent)
+ self.target_dBFS = target_dBFS
+ self.cfg = cfg
+
+ def getAudio(self) -> bytes:
+ audio = self.parent.getAudio()
+
+ audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ,
+ frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS)
+ if self.cfg["enable_debug_mode"]:
+ print(f"Boosting audio from {audio.dBFS}dB to {self.target_dBFS}dB", file=sys.stderr)
+ audio = audio.apply_gain(self.target_dBFS - audio.dBFS)
+
+ frames = np.array(audio.get_array_of_samples())
+ frames = np.int16(frames).tobytes()
+
+ return frames
+
class CompressingAudioCollector(AudioCollectorFilter):
def __init__(self, parent: AudioCollector):
AudioCollectorFilter.__init__(self, parent)
@@ -441,6 +467,16 @@ class TranscriptCommit:
self.duration_s = duration_s
+def saveAudio(audio: bytes, path: str, cfg: typing.Dict):
+ with wave.open(path, 'wb') as wf:
+ if cfg["enable_debug_mode"]:
+ print(f"Saving audio to {path}", file=sys.stderr)
+ wf.setnchannels(AudioStream.CHANNELS)
+ wf.setsampwidth(AudioStream.FRAME_SZ)
+ wf.setframerate(AudioStream.FPS)
+ wf.writeframes(audio)
+
+
class VadCommitter:
def __init__(self,
cfg: typing.Dict,
@@ -463,7 +499,6 @@ class VadCommitter:
start_ts = self.collector.begin()
if has_audio and stable_cutoff:
- #print(f"stable cutoff get: {stable_cutoff}", file=sys.stderr)
latency_s = self.collector.now() - self.collector.begin()
duration_s = stable_cutoff / AudioStream.FPS
start_ts = self.collector.begin()
@@ -475,12 +510,16 @@ class VadCommitter:
if self.cfg["enable_debug_mode"]:
for s in segments:
print(f"commit segment: {s}", file=sys.stderr)
- print(f"delta get: {delta}", file=sys.stderr)
+ if len(delta) > 0:
+ print(f"delta get: {delta}", file=sys.stderr)
- if False:
+ if self.cfg["save_audio"] and len(delta) > 0:
ts = datetime.fromtimestamp(self.collector.now() - latency_s)
filename = str(ts.strftime('%Y_%m_%d__%H-%M-%S')) + ".wav"
- saveAudio(commit_audio, filename)
+ audio_dir = os.path.join(PROJECT_ROOT, "audio")
+ if not os.path.exists(audio_dir):
+ os.makedirs(audio_dir)
+ saveAudio(commit_audio, os.path.join(audio_dir, filename), self.cfg)
preview = ""
if self.cfg["enable_previews"] and has_audio:
@@ -488,7 +527,6 @@ class VadCommitter:
preview = "".join(s.transcript for s in segments)
if not has_audio:
- #print("VAD detects no audio, skip transcription", file=sys.stderr)
self.collector.keepLast(1.0)
return TranscriptCommit(
@@ -504,8 +542,9 @@ def transcriptionThread(shared_data: SharedThreadData):
stream = MicStream(shared_data.cfg["microphone"])
collector = AudioCollector(stream)
- collector = NormalizingAudioCollector(collector)
collector = CompressingAudioCollector(collector)
+ collector = NormalizingAudioCollector(collector)
+ collector = BoostingAudioCollector(collector, 0.0, shared_data.cfg)
whisper = Whisper(collector, shared_data.cfg)
segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"],
max_speech_s=shared_data.cfg["max_speech_duration_s"])
@@ -552,13 +591,13 @@ def transcriptionThread(shared_data: SharedThreadData):
preview = commit.preview
try:
- print(f"Transcript: {transcript}")
+ print(f"Transcript: {transcript}", flush=True)
except UnicodeEncodeError:
print("Failed to encode transcript - discarding delta",
file=sys.stderr)
continue
try:
- print(f"Preview: {preview}")
+ print(f"Preview: {preview}", flush=True)
except UnicodeEncodeError:
print("Failed to encode preview - discarding", file=sys.stderr)