summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore3
-rw-r--r--Images/unigram_lut_for_visualization.pngbin0 -> 489395 bytes
-rw-r--r--LICENSE7
-rw-r--r--README.md45
-rw-r--r--app_config.py39
-rw-r--r--config.yaml18
-rw-r--r--hi.py77
-rw-r--r--requirements.txt4
-rw-r--r--shared_thread_data.py9
-rw-r--r--stt.py581
-rw-r--r--vad.py313
11 files changed, 1068 insertions, 28 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a102cf0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+.*.sw[po]
+*.meta
+
diff --git a/Images/unigram_lut_for_visualization.png b/Images/unigram_lut_for_visualization.png
new file mode 100644
index 0000000..622419d
--- /dev/null
+++ b/Images/unigram_lut_for_visualization.png
Binary files differ
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..1ebdcb5
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,7 @@
+Copyright 2025 yum_food
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/README.md b/README.md
index eaeceea..abb0576 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,40 @@
# Optimized text paging for VRChat
+This repo provides code to help you send English text into VRChat. It includes:
+
+1. Training code to produce an English-language tokenizer of any vocabulary
+ size.
+2. Code to turn your tokenizer into a lookup table for GPU decoding.
+3. Unity code to generate an animator to shuttle data from OSC to material
+ properties.
+4. OSC code to talk to your Unity animator.
+
+To get started, see Quick Start.
+
+## Quick start
+
+1. Clone this repo.
+2. Clone my toon shader, [2ner](https://github.com/yum-food/2ner).
+3. Install Lyuma's av3emulator.
+4. Drag STT.prefab onto your avatar's root.
+5. Enter play mode.
+6. Open PowerShell.
+
+```bash
+$ cd ~
+$ mkdir tmp
+$ cd tmp
+$ python.exe -m venv venv
+$ ./venv/Scripts/Activate.ps1
+$ pushd /path/to/FastTextPaging/
+$ pip3 install -r requirements.txt
+$ python3 ./hi.py
+```
+
+7. Start typing.
+
+## Design overview
+
It is sometimes useful to send text data into VRChat, for example for
speech-to-text (STT). This is typically done naively, with a "block" of
n 8-bit characters\* sent in along with an 8-bit pointer. Since avatars can only
@@ -19,7 +54,7 @@ used. Thus to reach a typical reading speed, you need to use (260/4.7) = 55.5
OSC bits. The goal of this module is to get more out of these bits by
compressing text over the wire.
-## Unigram tokenizer
+### Unigram tokenizer
Byte pair encoding (BPE) is an encoding scheme frequently used in natural
language processing (NLP) contexts. For any language with a fixed character set
@@ -127,7 +162,7 @@ bits naive rate bpe rate speedup factor
I reserve 39 token slots for sequences of whitespace characters of length 2-40. This helps simplify formatting. To end a line or position text, you can just send in the exact right number of spaces, and a fixed-width font renderer will position things as intended.
-## Paging data into shader
+### Paging data into shader
Sending this data to a shader is pretty simple:
@@ -224,7 +259,7 @@ void GetTokens(uint screen_ptr, out uint block_ptr, out uint tokens[BLOCK_WIDTH]
}
```
-## GPU decoding
+### GPU decoding
Now we have to translate the tokens into text. I do this with a texture laid out as follows:
@@ -236,6 +271,10 @@ My tokenizer's vocabulary is 65,536 tokens. If we add up the lengths of every to
So, the entire vocabulary - length+offset head and content - requires a 32-bit RGBA texture with 232,419 slots. We'll just jam this into a 512x512 texture, at an occupancy ratio of 88.66% (11.34% waste). The total VRAM usage of that lookup table (LUT) is 1 MiB.
+![Unigram tokenizer texture](Images/unigram_lut_for_visualization.png)
+
+*A 64K vocabulary tokenizer I trained on Wikipedia and OpenSubtitles.*
+
We want to implement this API:
```c
diff --git a/app_config.py b/app_config.py
new file mode 100644
index 0000000..f911456
--- /dev/null
+++ b/app_config.py
@@ -0,0 +1,39 @@
+import os
+import sys
+import typing
+
+def getConfig(path: str) -> typing.Dict[str, typing.Union[str, float, int, bool]]:
+ # Helper functions to detect and convert the type
+ def is_int(value: str) -> bool:
+ try:
+ int(value)
+ return True
+ except ValueError:
+ return False
+
+ def is_float(value: str) -> bool:
+ try:
+ float(value)
+ return True
+ except ValueError:
+ return False
+
+ def convert_value(key: str, value: str):
+ if key.startswith(("enable_", "remove_", "use_", "clear_")):
+ return bool(int(value))
+ elif is_int(value):
+ return int(value)
+ elif is_float(value):
+ return float(value)
+ else:
+ return value
+
+ config = {}
+ with open(path, 'r') as file:
+ for line in file:
+ key_value = line.strip().split(": ", maxsplit=1)
+ key = key_value[0]
+ value = key_value[1] if len(key_value) > 1 else ""
+ config[key] = convert_value(key, value.strip())
+ return config
+
diff --git a/config.yaml b/config.yaml
new file mode 100644
index 0000000..164b4e6
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,18 @@
+compute_type: int8
+enable_debug_mode: 0
+enable_previews: 1
+language: english
+gpu_idx: 0
+max_speech_duration_s: 10
+min_silence_duration_ms: 250
+microphone: motu
+model: turbo
+reset_after_silence_s: 15
+transcription_loop_delay_ms: 100
+use_cpu: 0
+
+block_width: 2
+num_blocks: 40
+rows: 10
+cols: 24
+
diff --git a/hi.py b/hi.py
index 7c68071..0129958 100644
--- a/hi.py
+++ b/hi.py
@@ -1,8 +1,11 @@
+import app_config
import argparse
from math import floor, ceil
import msvcrt
from pythonosc import udp_client
import sentencepiece as spm
+from shared_thread_data import SharedThreadData
+import stt
import sys
import threading
import time
@@ -22,10 +25,7 @@ def get_tokenizer():
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument("--block-width", type=int, default=2, help="Number of elements sent in one block")
- parser.add_argument("--num-blocks", type=int, default=40, help="Number of blocks in animator")
- parser.add_argument("--rows", type=int, default=8, help="Number of rows in the chatbox")
- parser.add_argument("--cols", type=int, default=20, help="Number of columns in the chatbox")
+ parser.add_argument("--config", type=str, help="Path to config file (YAML).", required=True)
return parser.parse_args()
def assert_equal(a, b):
@@ -244,27 +244,48 @@ def getOscClient(ip = "127.0.0.1", port = 9000):
class InputState:
def __init__(self):
+ self.page = 0
# Initialize the known state of the board to empty array. This will cause
# our paging logic to re-send everything the first time around.
self.blocks = []
self.visual_pointers = []
pass
-def handle_input(state: InputState, line: str, tokenizer, osc_client, args):
- line_wrapped = wrap_line(line, args.cols)
+def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg):
+ line_wrapped = wrap_line(line, cfg["cols"])
if TESTS_ENABLED:
for line in line_wrapped:
- assert_equal(len(line), args.cols)
+ assert_equal(len(line), cfg["cols"])
if LOG_LEVEL == 2:
print(f"Wrapped lines: {line_wrapped}")
+
+ # Get several blank lines whenever we roll over.
+ # It's better for the reader to have some continuity when the board pages
+ # over. If we simply replaced the entire screen, it would be harder to
+ # understand.
+ line_rollover = cfg["rows"] - 2
+ blank_line = ' ' * cfg["cols"]
+ # We show a full page, then only `line_rollover` additional lines per page.
+ end_ptr = cfg["rows"]
+ which_page = 0
+ while end_ptr < len(line_wrapped):
+ end_ptr += line_rollover
+ which_page += 1
+ if state.page != which_page:
+ state.blocks = []
+ state.visual_pointers = []
+ state.page = which_page
+ line_wrapped = line_wrapped[end_ptr-cfg["rows"]:]
+
+ # Get blocks and visual pointers.
blocks, visual_pointers = get_blocks(line_wrapped, tokenizer,
- args.block_width, args.num_blocks)
- # Wrap visual pointers. This is mostly done just to stay within our
- # limited 8-bit precision budget. The shader also has to wrap in order
- # to properly display things.
- visual_pointers = [ ptr % (args.rows * args.cols) for ptr in visual_pointers ]
+ cfg["block_width"], cfg["num_blocks"])
+
+ # Note that because we only send one page of data at a time, we don't have
+ # to worry about wrapping visual pointers! We will basically never run out
+ # of space.
indices, diff_blocks, diff_visual_pointers = calc_diff(state.blocks, state.visual_pointers, blocks, visual_pointers)
- indices = [idx % args.num_blocks for idx in indices]
+ indices = [idx % cfg["num_blocks"] for idx in indices]
# Send only one block at a time to make things snappier in interactive use
# case.
# TODO use a continuation (yield) instead of returning. Then we can be a
@@ -285,41 +306,46 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, args):
send_data(osc_client, [indices[0]], [diff_blocks[0]], [diff_visual_pointers[0]])
-class SharedThreadData:
- def __init__(self):
- self.word = ""
- self.word_lock = threading.Lock()
- self.exit_event = threading.Event()
-
-def osc_thread(shared_data: SharedThreadData, cli_args):
+def osc_thread(shared_data: SharedThreadData):
tokenizer = get_tokenizer()
osc_client = getOscClient()
# Prime the board
print("Priming the board")
input_state = InputState()
- handle_input(input_state, "", tokenizer, osc_client, cli_args)
+ handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg)
while not shared_data.exit_event.is_set():
word_copy = ""
with shared_data.word_lock:
word_copy = shared_data.word
- handle_input(input_state, word_copy, tokenizer, osc_client, cli_args)
+ handle_input(input_state, word_copy, tokenizer, osc_client, shared_data.cfg)
time.sleep(0.01)
if __name__ == "__main__":
cli_args = parse_args()
-
- shared_data = SharedThreadData()
+ cfg = app_config.getConfig(cli_args.config)
+ shared_data = SharedThreadData(cfg)
osc_thread = threading.Thread(
target=osc_thread,
- args=(shared_data, cli_args))
+ args=(shared_data,))
osc_thread.start()
+ transcribe_thread = threading.Thread(
+ target=stt.transcriptionThread,
+ args=(shared_data,))
+ transcribe_thread.start()
+
word_is_over = False
local_word = ""
while True:
char_bytes = msvcrt.getch()
+ if char_bytes == b'\x03': # ctrl+C
+ break
+
+ time.sleep(0.1)
+ continue
+
try:
char = char_bytes.decode('utf-8')
@@ -354,4 +380,5 @@ if __name__ == "__main__":
print(local_word + "_")
shared_data.exit_event.set()
osc_thread.join()
+ transcribe_thread.join()
diff --git a/requirements.txt b/requirements.txt
index 104e7b8..1043fae 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,9 @@
datasets
+faster-whisper
+langcodes
pillow
+pyaudio
+pydub
python-osc
unidecode
sentencepiece
diff --git a/shared_thread_data.py b/shared_thread_data.py
new file mode 100644
index 0000000..ba0a419
--- /dev/null
+++ b/shared_thread_data.py
@@ -0,0 +1,9 @@
+import threading
+
+class SharedThreadData:
+ def __init__(self, cfg):
+ self.word = ""
+ self.word_lock = threading.Lock()
+ self.exit_event = threading.Event()
+ self.cfg = cfg
+
diff --git a/stt.py b/stt.py
new file mode 100644
index 0000000..34ef2e9
--- /dev/null
+++ b/stt.py
@@ -0,0 +1,581 @@
+from faster_whisper import WhisperModel
+import langcodes
+import numpy as np
+import os
+import pyaudio
+from pydub import AudioSegment
+from shared_thread_data import SharedThreadData
+import sys
+import time
+import typing
+import vad
+
+class AudioStream():
+ FORMAT = pyaudio.paInt16
+ # Size of each frame (audio sample), in bytes. If you change FORMAT, make
+ # sure this stays up to date!
+ FRAME_SZ = 2
+ # Frames per second.
+ FPS = 16000
+ CHANNELS = 1
+ def __init__(self):
+ pass
+
+ def getSamples(self) -> bytes:
+ raise NotImplementedError("getSamples is not implemented!")
+
+# Audio stream backed by a live PyAudio input device. Audio arrives through a
+# PyAudio stream callback, is reduced to AudioStream.FPS by dropping samples,
+# and is buffered in `self.chunks` until getSamples() drains it.
+class MicStream(AudioStream):
+ # Passed to PyAudio as `frames_per_buffer` when the stream is opened.
+ CHUNK_SZ = 1024
+
+ # Find the requested input device, open it at its default sample rate and
+ # start streaming. `which_mic` is either one of the aliases handled below
+ # ("index", "focusrite", "motu", "beyond") or a numeric device ID.
+ # Raises KeyError if no input device name contains the alias's target
+ # string; int() raises ValueError if a non-alias value is not numeric.
+ def __init__(self, which_mic: str):
+ self.p = pyaudio.PyAudio()
+ self.stream = None
+ self.sample_rate = None
+ # Each time pyaudio gives us audio data, it's in the form of a chunk of
+ # samples. We keep these in a list to keep the audio callback as light
+ # as possible. Whenever downstream layers want data, we collapse the
+ # list into a single array of data (a bytes object).
+ self.chunks = []
+ # If set, incoming frames are simply discarded.
+ self.paused = False
+
+ print(f"Finding mic {which_mic}", file=sys.stderr)
+ self.dumpMicDevices()
+
+ got_match = False
+ device_index = -1
+ # Map each known alias to a substring of the device name to search for.
+ if which_mic == "index":
+ target_str = "Digital Audio Interface"
+ elif which_mic == "focusrite":
+ target_str = "Focusrite"
+ elif which_mic == "motu":
+ target_str = "In 1-2 (MOTU M Series)"
+ elif which_mic == "beyond":
+ target_str = "Microphone (Beyond)"
+ else:
+ print(f"Mic {which_mic} requested, treating it as a numerical " +
+ "device ID", file=sys.stderr)
+ device_index = int(which_mic)
+ got_match = True
+ # Alias path: scan the input devices of host API 0 for the first whose
+ # name contains target_str.
+ if not got_match:
+ info = self.p.get_host_api_info_by_index(0)
+ numdevices = info.get('deviceCount')
+ for i in range(0, numdevices):
+ if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
+ device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name')
+ if target_str in device_name:
+ print(f"Got matching mic: {device_name}",
+ file=sys.stderr)
+ device_index = i
+ got_match = True
+ break
+ if not got_match:
+ raise KeyError(f"Mic {which_mic} not found")
+
+ info = self.p.get_device_info_by_host_api_device_index(0, device_index)
+ print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr)
+ self.sample_rate = int(info['defaultSampleRate'])
+ print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr)
+
+ # Open at the device's own rate; the callback reduces it to
+ # AudioStream.FPS.
+ self.stream = self.p.open(
+ rate=self.sample_rate,
+ channels=AudioStream.CHANNELS,
+ format=AudioStream.FORMAT,
+ input=True,
+ frames_per_buffer=MicStream.CHUNK_SZ,
+ input_device_index=device_index,
+ stream_callback=self.onAudioFramesAvailable)
+
+ self.stream.start_stream()
+
+ AudioStream.__init__(self)
+
+ # Set or clear the paused flag read by the audio callback.
+ def pause(self, state: bool = True):
+ self.paused = state
+
+ # Print the index and name of every device on host API 0 that reports at
+ # least one input channel.
+ def dumpMicDevices(self):
+ info = self.p.get_host_api_info_by_index(0)
+ numdevices = info.get('deviceCount')
+
+ for i in range(0, numdevices):
+ if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
+ device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name')
+ print("Input Device id ", i, " - ", device_name)
+
+ # Stream callback registered with PyAudio in __init__. Appends one chunk of
+ # audio at AudioStream.FPS to self.chunks and returns
+ # (frames, pyaudio.paContinue).
+ def onAudioFramesAvailable(self,
+ frames,
+ frame_count,
+ time_info,
+ status_flags):
+ if self.paused:
+ # Don't literally pause, just start returning silence. This allows
+ # the `min_segment_age_s` check to work while paused.
+ n_frames = int(frame_count * AudioStream.FPS /
+ float(self.sample_rate))
+ self.chunks.append(np.zeros(n_frames,
+ dtype=np.int16).tobytes())
+ return (frames, pyaudio.paContinue)
+
+ decimated = b''
+ # In pyaudio, a `frame` is a single sample of audio data.
+ frame_len = AudioStream.FRAME_SZ
+ # Fractional index of the next input frame to keep. It restarts at 0.0
+ # on every callback.
+ next_frame = 0.0
+ # The mic probably has a higher sample rate than Whisper wants, so
+ # decrease the sample rate by dropping samples. Note that this
+ # algorithm only works if the mic's rate is higher than whisper's
+ # expected rate.
+ keep_every = float(self.sample_rate) / AudioStream.FPS
+ # NOTE(review): this grows a bytes object one sample at a time; joining
+ # a list of slices would avoid the repeated concatenation.
+ for i in range(frame_count):
+ if i >= next_frame:
+ decimated += frames[i*frame_len:(i+1)*frame_len]
+ next_frame += keep_every
+ self.chunks.append(decimated)
+
+ return (frames, pyaudio.paContinue)
+
+ # Drain the buffered chunks and return them joined into one bytes object.
+ # No timestamp is returned.
+ # NOTE(review): the list swap is not guarded by a lock; presumably the
+ # callback runs on another thread, so confirm a chunk cannot be lost here.
+ def getSamples(self) -> bytes:
+ chunks = self.chunks
+ self.chunks = []
+ result = b''.join(chunks)
+ return result
+
+class AudioCollector:
+ def __init__(self, stream: AudioStream):
+ self.stream = stream
+ self.frames = b''
+ # Note: by design, this is the only spot where we anchor our timestamps
+ # against the real world. This is done to make it possible to profile
+ # test cases which read from disk (at much faster than real speed) in
+ # the same way that we profile real-time data.
+ self.wall_ts = time.time()
+
+ def getAudio(self) -> bytes:
+ frames = self.stream.getSamples()
+ if frames:
+ self.frames += frames
+ return self.frames
+
+ def dropAudioPrefix(self, dur_s: float) -> bytes:
+ n_bytes = int(dur_s * AudioStream.FPS) * self.stream.FRAME_SZ
+ n_bytes = min(n_bytes, len(self.frames))
+ cut_portion = self.frames[:n_bytes]
+ self.frames = self.frames[n_bytes:]
+ self.wall_ts += float(n_bytes / self.stream.FRAME_SZ) / self.stream.FPS
+ return cut_portion
+
+ def dropAudioPrefixByFrames(self, dur_frames: int) -> bytes:
+ n_bytes = dur_frames * self.stream.FRAME_SZ
+ n_bytes = min(n_bytes, len(self.frames))
+ cut_portion = self.frames[:n_bytes]
+ self.frames = self.frames[n_bytes:]
+ self.wall_ts += float(n_bytes / self.stream.FRAME_SZ) / self.stream.FPS
+ return cut_portion
+
+ def keepLast(self, dur_s: float) -> bytes:
+ drop_len = max(0, self.duration() - dur_s)
+ return self.dropAudioPrefix(drop_len)
+
+ def dropAudio(self):
+ self.wall_ts += self.duration()
+ cut_portion = self.frames
+ self.frames = b''
+ return cut_portion
+
+ def duration(self):
+ return len(self.frames) / (AudioStream.FPS * self.stream.FRAME_SZ)
+
+ def begin(self):
+ return self.wall_ts
+
+ def now(self):
+ return self.begin() + self.duration()
+
+class AudioCollectorFilter:
+ def __init__(self, parent: AudioCollector):
+ self.parent = parent
+
+ def getAudio(self) -> bytes:
+ return self.parent.getAudio()
+ def dropAudioPrefix(self, dur_s: float):
+ return self.parent.dropAudioPrefix(dur_s)
+ def dropAudioPrefixByFrames(self, dur_frames: int):
+ return self.parent.dropAudioPrefixByFrames(dur_frames)
+ def keepLast(self, dur_s):
+ return self.parent.keepLast(dur_s)
+ def dropAudio(self):
+ return self.parent.dropAudio()
+ def duration(self):
+ return self.parent.duration()
+ def begin(self):
+ return self.parent.begin()
+ def now(self):
+ return self.parent.now()
+
+# Audio collector that enforces a minimum length on its audio data.
+class LengthEnforcingAudioCollector(AudioCollectorFilter):
+ def __init__(self, parent: AudioCollector, min_duration_s: float):
+ AudioCollectorFilter.__init__(self, parent)
+ self.min_duration_s = min_duration_s
+
+ def getAudio(self) -> bytes:
+ audio = self.parent.getAudio()
+ min_duration_frames = int(self.min_duration_s * AudioStream.FPS)
+ pad_len_frames = max(0, min_duration_frames - int(len(audio) /
+ AudioStream.FRAME_SZ))
+ pad = np.zeros(pad_len_frames, dtype=np.int16).tobytes()
+ return pad + audio
+
+class NormalizingAudioCollector(AudioCollectorFilter):
+ def __init__(self, parent: AudioCollector):
+ AudioCollectorFilter.__init__(self, parent)
+
+ def getAudio(self) -> bytes:
+ audio = self.parent.getAudio()
+
+ audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ,
+ frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS)
+ audio = audio.normalize()
+
+ frames = np.array(audio.get_array_of_samples())
+ frames = np.int16(frames).tobytes()
+
+ return frames
+
+class CompressingAudioCollector(AudioCollectorFilter):
+ def __init__(self, parent: AudioCollector):
+ AudioCollectorFilter.__init__(self, parent)
+
+ def getAudio(self) -> bytes:
+ audio = self.parent.getAudio()
+
+ audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ,
+ frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS)
+ # subtle compression has a slight positive effect on my benchmark
+ audio = audio.compress_dynamic_range(threshold=-10, ratio=2.0)
+
+ frames = np.array(audio.get_array_of_samples())
+ frames = np.int16(frames).tobytes()
+
+ return frames
+
+class AudioSegmenter:
+ def __init__(self,
+ min_silence_ms=250,
+ max_speech_s=5):
+ self.vad_options = vad.VadOptions(
+ min_silence_duration_ms=min_silence_ms,
+ max_speech_duration_s=max_speech_s)
+ pass
+
+ def segmentAudio(self, audio: bytes):
+ audio = np.frombuffer(audio,
+ dtype=np.int16).flatten().astype(np.float32) / 32768.0
+ return vad.get_speech_timestamps(audio, vad_options=self.vad_options)
+
+ # Returns the stable cutoff (if any) and whether there are any segments.
+ def getStableCutoff(self, audio: bytes) -> typing.Tuple[int, bool]:
+ min_delta_frames = int((self.vad_options.min_silence_duration_ms *
+ AudioStream.FPS) / 1000.0)
+ cutoff = None
+
+ last_end = None
+ segments = self.segmentAudio(audio)
+
+ for i in range(len(segments)):
+ s = segments[i]
+ #print(f"s: {s}")
+ #print(f"last_end: {last_end}")
+
+ if last_end:
+ delta_frames = s['start'] - last_end
+ #print(f"delta frames: {delta_frames}")
+ if delta_frames > min_delta_frames:
+ cutoff = s['start']
+ else:
+ last_end = s['end']
+
+ if i == len(segments) - 1:
+ now = int(len(audio) / AudioStream.FRAME_SZ)
+ #print(f"now: {now}")
+ #print(f"min d: {min_delta_frames}")
+ delta_frames = now - s['end']
+ if delta_frames > min_delta_frames:
+ cutoff = now - int(min_delta_frames / 2)
+
+ return (cutoff, len(segments) > 0)
+
+# A segment of transcribed audio. `start_ts` and `end_ts` are floating point
+# number of seconds since the beginning of audio data.
+class Segment:
+ def __init__(self,
+ transcript: str,
+ start_ts: float,
+ end_ts: float,
+ wall_ts: float,
+ avg_logprob: float,
+ no_speech_prob: float,
+ compression_ratio: float):
+ self.transcript = transcript
+ # start_ts, end_ts are timestamps in seconds relative to `wall_ts`.
+ self.start_ts = start_ts
+ self.end_ts = end_ts
+ # wall_ts is the time.time() at which the oldest audio sample leading
+ # to this transcript was collected.
+ self.wall_ts = wall_ts
+ self.avg_logprob = avg_logprob
+ self.no_speech_prob = no_speech_prob
+ self.compression_ratio = compression_ratio
+
+ def __str__(self):
+ ts = f"(ts: {self.start_ts}-{self.end_ts}) "
+
+ wall_ts_start = datetime.utcfromtimestamp(self.start_ts + self.wall_ts).strftime('%H:%M:%S')
+ wall_ts_end = datetime.utcfromtimestamp(self.end_ts + self.wall_ts).strftime('%H:%M:%S')
+ wall_ts = f"(wall ts: {wall_ts_start}-{wall_ts_end}) "
+
+ no_speech = f"(no_speech: {self.no_speech_prob}) "
+ avg_logprob = f"(avg_logprob: {self.avg_logprob}) "
+ return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob
+
+# Wraps a faster-whisper WhisperModel: loads the model named in the config and
+# turns raw int16 audio into a filtered list of Segment objects.
+class Whisper:
+ # `collector` supplies audio when transcribe() is called without frames.
+ # `cfg` is read for "model", "use_cpu", "gpu_idx", "compute_type",
+ # "language" and "enable_debug_mode".
+ def __init__(self,
+ collector: AudioCollector,
+ cfg: typing.Dict):
+ self.collector = collector
+ self.model = None
+ self.cfg = cfg
+
+ # Models live in a "Models" directory next to (not inside) the directory
+ # that contains this file.
+ abspath = os.path.abspath(__file__)
+ my_dir = os.path.dirname(abspath)
+ parent_dir = os.path.dirname(my_dir)
+
+ model_str = cfg["model"]
+ model_root = os.path.join(parent_dir, "Models",
+ os.path.normpath(model_str))
+ print(f"Model {cfg['model']} will be saved to {model_root}",
+ file=sys.stderr)
+
+ model_device = "cuda"
+ if cfg["use_cpu"]:
+ model_device = "cpu"
+
+ # If the model directory already exists, restrict the loader to local
+ # files.
+ already_downloaded = os.path.exists(model_root)
+
+ self.model = WhisperModel(model_str,
+ device = model_device,
+ device_index = cfg["gpu_idx"],
+ compute_type = cfg["compute_type"],
+ download_root = model_root,
+ local_files_only = already_downloaded)
+
+ # Transcribe `frames` (raw int16 audio bytes), or the collector's current
+ # audio when `frames` is None. Returns the Segment objects that survive the
+ # hallucination and quality filters below; each carries the collector's
+ # begin() time as its wall_ts.
+ def transcribe(self, frames: bytes = None) -> typing.List[Segment]:
+ if frames is None:
+ frames = self.collector.getAudio()
+ # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on
+ # [-1, 1].
+ audio = np.frombuffer(frames,
+ dtype=np.int16).flatten().astype(np.float32) / 32768.0
+
+ t0 = time.time()
+ segments, info = self.model.transcribe(
+ audio,
+ language = langcodes.find(self.cfg["language"]).language,
+ vad_filter = True,
+ temperature=0.0,
+ without_timestamps = False)
+ res = []
+ for s in segments:
+ # Manual touchup. I see a decent number of hallucinations sneaking
+ # in with high `no_speech_prob` and modest `avg_logprob`.
+ if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5:
+ if self.cfg["enable_debug_mode"]:
+ print(f"Drop probable hallucination (case 1) " +
+ f"(text='{s.text}', " +
+ f"no_speech_prob={s.no_speech_prob}, " +
+ f"avg_logprob={s.avg_logprob})", file=sys.stderr)
+ continue
+ # Another touchup targeted at the vexatious "thanks for watching!"
+ # hallucination. This triggers a lot when listening to
+ # instrumental/electronic music.
+ if s.no_speech_prob > 0.15 and s.avg_logprob < -0.7:
+ if self.cfg["enable_debug_mode"]:
+ print(f"Drop probable hallucination (case 2) " +
+ f"(text='{s.text}', " +
+ f"no_speech_prob={s.no_speech_prob}, " +
+ f"avg_logprob={s.avg_logprob})", file=sys.stderr)
+ continue
+ if self.cfg["enable_debug_mode"]:
+ print(f"s get: {s}")
+ # Drop segments with very low average log-probability or a very high
+ # compression ratio. These two drops are not logged.
+ if s.avg_logprob < -1.0:
+ continue
+ if s.compression_ratio > 2.4:
+ continue
+ res.append(Segment(s.text, s.start, s.end,
+ self.collector.begin(),
+ s.avg_logprob, s.no_speech_prob, s.compression_ratio))
+ t1 = time.time()
+ if self.cfg["enable_debug_mode"]:
+ print(f"Transcription latency (s): {t1 - t0}")
+ return res
+
+class TranscriptCommit:
+ def __init__(self,
+ delta: str,
+ preview: str,
+ latency_s: float = None,
+ thresh_at_commit: int = None,
+ audio: bytes = None,
+ duration_s: float = None,
+ start_ts: float = None):
+ self.delta = delta
+ self.preview = preview
+ self.latency_s = latency_s
+ self.thresh_at_commit = thresh_at_commit
+ self.audio = audio
+ # Time at which the commit is generated
+ self.ts = time.time()
+ # Time corresponding to the start of the segment
+ self.start_ts = start_ts
+ # The duration of the audio segment, in seconds.
+ self.duration_s = duration_s
+
+
+# Decides, using VAD, when a prefix of the buffered audio is stable enough to
+# transcribe and commit, and optionally produces a preview transcript of the
+# audio that is still pending.
+class VadCommitter:
+ def __init__(self,
+ cfg: typing.Dict,
+ collector: AudioCollector,
+ whisper: Whisper,
+ segmenter: AudioSegmenter):
+ self.cfg = cfg
+ self.collector = collector
+ self.whisper = whisper
+ self.segmenter = segmenter
+
+ # Run one commit attempt and return it as a TranscriptCommit.
+ # `delta` is the transcript of the audio cut off at the stable cutoff
+ # (empty if nothing was committed). `preview` is the transcript of the
+ # remaining audio when cfg["enable_previews"] is set and VAD found speech.
+ def getDelta(self) -> TranscriptCommit:
+ audio = self.collector.getAudio()
+ stable_cutoff, has_audio = self.segmenter.getStableCutoff(audio)
+
+ delta = ""
+ commit_audio = None
+ latency_s = None
+ duration_s = self.collector.duration()
+ start_ts = self.collector.begin()
+
+ # NOTE(review): a cutoff of 0 is falsy and is treated like "no cutoff".
+ if has_audio and stable_cutoff:
+ #print(f"stable cutoff get: {stable_cutoff}", file=sys.stderr)
+ # now() - begin() equals the duration of the buffered audio.
+ latency_s = self.collector.now() - self.collector.begin()
+ duration_s = stable_cutoff / AudioStream.FPS
+ start_ts = self.collector.begin()
+ commit_audio = self.collector.dropAudioPrefixByFrames(stable_cutoff)
+
+ segments = self.whisper.transcribe(commit_audio)
+ delta = ''.join(s.transcript for s in segments)
+ # Re-read the buffer so `audio` is what remains after the cut.
+ audio = self.collector.getAudio()
+ if self.cfg["enable_debug_mode"]:
+ for s in segments:
+ print(f"commit segment: {s}", file=sys.stderr)
+ print(f"delta get: {delta}", file=sys.stderr)
+
+ # Disabled debugging aid that would save the committed audio to a .wav
+ # file. NOTE(review): `datetime` and `saveAudio` are not imported or
+ # defined in the visible part of this file.
+ if False:
+ ts = datetime.fromtimestamp(self.collector.now() - latency_s)
+ filename = str(ts.strftime('%Y_%m_%d__%H-%M-%S')) + ".wav"
+ saveAudio(commit_audio, filename)
+
+ preview = ""
+ if self.cfg["enable_previews"] and has_audio:
+ segments = self.whisper.transcribe(audio)
+ preview = "".join(s.transcript for s in segments)
+
+ if not has_audio:
+ #print("VAD detects no audio, skip transcription", file=sys.stderr)
+ # No speech detected: keep only the last second of audio.
+ self.collector.keepLast(1.0)
+
+ return TranscriptCommit(
+ delta.strip(),
+ preview.strip(),
+ latency_s,
+ audio=audio,
+ duration_s=duration_s,
+ start_ts=start_ts)
+
+# Thread entry point for speech-to-text. Builds the mic -> collector ->
+# normalize -> compress -> Whisper pipeline from shared_data.cfg, then polls
+# it until shared_data.exit_event is set, publishing the running transcript
+# plus the current preview into shared_data.word under word_lock.
+# NOTE(review): the indentation of this function is lost in this view, so the
+# exact nesting of the statements below should be confirmed against the
+# original file.
+def transcriptionThread(shared_data: SharedThreadData):
+ # Last commit that carried a non-empty delta; used to measure silence.
+ last_stable_commit = None
+
+ stream = MicStream(shared_data.cfg["microphone"])
+ collector = AudioCollector(stream)
+ collector = NormalizingAudioCollector(collector)
+ collector = CompressingAudioCollector(collector)
+ whisper = Whisper(collector, shared_data.cfg)
+ segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"],
+ max_speech_s=shared_data.cfg["max_speech_duration_s"])
+ committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter)
+
+ transcript = ""
+ preview = ""
+
+ while not shared_data.exit_event.is_set():
+ time.sleep(shared_data.cfg["transcription_loop_delay_ms"] / 1000.0);
+
+ # NOTE(review): `op` is never read in the visible code.
+ op = None
+
+ commit = committer.getDelta()
+
+ if len(commit.delta) > 0 or len(commit.preview) > 0:
+ # Avoid re-sending text after long pauses. User controls the length
+ # of the pause in the UI.
+ if shared_data.cfg["reset_after_silence_s"] > 0:
+ silence_duration = 0
+ if last_stable_commit:
+ # Silence is the gap between the end of the last committed audio
+ # and the start of this commit's audio.
+ last_commit_end_ts = \
+ last_stable_commit.start_ts + \
+ last_stable_commit.duration_s
+ silence_duration = commit.start_ts - last_commit_end_ts
+ if silence_duration > shared_data.cfg["reset_after_silence_s"]:
+ print(f"Resetting transcript after {silence_duration}-second "
+ "silence", file=sys.stderr)
+ transcript = ""
+ preview = ""
+ if commit.delta:
+ last_stable_commit = commit
+
+ # Hard-cap displayed transcript length at 4k characters to prevent
+ # runaway memory use in UI. Keep the full transcript to avoid
+ # breaking OSC pager.
+ # NOTE(review): the line below truncates `transcript` itself to its last
+ # 4096 characters, so no fuller copy is kept here.
+ transcript = transcript[-4096:]
+ # Concatenate two strings, inserting one space if `a` is non-empty and
+ # does not already end with a space.
+ def join_segments(a, b):
+ if len(a) > 0 and a[-1] != ' ':
+ return a + ' ' + b
+ else:
+ return a + b
+ transcript = join_segments(transcript, commit.delta)
+ preview = commit.preview
+
+ try:
+ print(f"Transcript: {transcript}")
+ except UnicodeEncodeError:
+ print("Failed to encode transcript - discarding delta",
+ file=sys.stderr)
+ continue
+ try:
+ print(f"Preview: {preview}")
+ except UnicodeEncodeError:
+ print("Failed to encode preview - discarding", file=sys.stderr)
+
+ # Publish committed text followed by the preview to the other threads.
+ with shared_data.word_lock:
+ shared_data.word = join_segments(transcript, preview)
+
+ if shared_data.cfg["enable_debug_mode"]:
+ print(f"commit latency: {commit.latency_s}", file=sys.stderr)
+ print(f"commit thresh: {commit.thresh_at_commit}",
+ file=sys.stderr)
+
+ # NOTE(review): the adjustments to commit.delta and commit.preview below
+ # are not read again in the visible code; `commit` is replaced on the next
+ # iteration.
+ if len(transcript) > 0 and \
+ (not transcript.endswith(' ')) and \
+ (not commit.delta.startswith(' ')):
+ commit.delta = ' ' + commit.delta
+ if len(commit.delta) > 0 and \
+ (not commit.delta.endswith(' ')) and \
+ (not commit.preview.startswith(' ')):
+ commit.preview = ' ' + commit.preview
+
diff --git a/vad.py b/vad.py
new file mode 100644
index 0000000..10a72d3
--- /dev/null
+++ b/vad.py
@@ -0,0 +1,313 @@
+# MIT License
+#
+# Copyright (c) 2023 Guillaume Klein
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import bisect
+import functools
+import os
+import warnings
+
+from typing import List, NamedTuple, Optional
+
+import numpy as np
+
+
+# The code below is adapted from https://github.com/snakers4/silero-vad.
class VadOptions(NamedTuple):
    """Tunable parameters for Silero-based voice activity detection.

    Attributes:
        threshold: Probability cut-off. The Silero model emits one speech
            probability per audio window; windows at or above this value count
            as speech. Tuning per dataset gives the best results, but 0.5 is a
            reasonable starting point.
        min_speech_duration_ms: Speech chunks shorter than this are dropped.
        max_speech_duration_s: Upper bound on the length of one speech chunk,
            in seconds. A chunk that grows past it is cut at the most recent
            sufficiently long silence (about 100 ms) when there is one, and
            otherwise cut hard just before the limit.
        min_silence_duration_ms: How much silence must follow speech before
            the current chunk is closed.
        window_size_samples: Number of samples handed to the model per step.
            WARNING: Silero VAD was trained with 512, 1024 and 1536 samples at
            a 16000 Hz sample rate; other values may hurt accuracy.
        speech_pad_ms: Padding added to both sides of every final chunk.
    """

    # Detection sensitivity.
    threshold: float = 0.5

    # Chunk length limits.
    min_speech_duration_ms: int = 250
    max_speech_duration_s: float = float("inf")

    # Silence required to end a chunk.
    min_silence_duration_ms: int = 2000

    # Model window and output padding.
    window_size_samples: int = 1024
    speech_pad_ms: int = 400
+
+
def get_speech_timestamps(
    audio: np.ndarray,
    vad_options: Optional[VadOptions] = None,
    **kwargs,
) -> List[dict]:
    """Locate the speech regions of an audio signal with the Silero VAD model.

    The audio is scored window by window, the per-window speech probabilities
    are turned into chunks by a small state machine with hysteresis, and the
    resulting chunks are padded on both sides.

    Args:
        audio: One dimensional float array. A 16000 Hz sample rate is
            hard-coded below.
        vad_options: Options for VAD processing.
        kwargs: VAD options given as keyword arguments (kept for backward
            compatibility). Only consulted when ``vad_options`` is None.

    Returns:
        List of dicts, each with the "start" and "end" sample index of one
        speech chunk.
    """
    if vad_options is None:
        vad_options = VadOptions(**kwargs)

    threshold = vad_options.threshold
    window_size_samples = vad_options.window_size_samples

    if window_size_samples not in [512, 1024, 1536]:
        warnings.warn(
            "Unusual window_size_samples! Supported window_size_samples:\n"
            " - [512, 1024, 1536] for 16000 sampling_rate"
        )

    # Convert every duration option into a sample count.
    sampling_rate = 16000
    min_speech_samples = sampling_rate * vad_options.min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * vad_options.speech_pad_ms / 1000
    max_speech_samples = (
        sampling_rate * vad_options.max_speech_duration_s
        - window_size_samples
        - 2 * speech_pad_samples
    )
    min_silence_samples = sampling_rate * vad_options.min_silence_duration_ms / 1000
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000

    audio_length_samples = len(audio)

    model = get_vad_model()
    state = model.get_initial_state(batch_size=1)

    # Pass 1: one speech probability per window; the last window is
    # zero-padded up to the full window size.
    speech_probs = []
    for offset in range(0, audio_length_samples, window_size_samples):
        window = audio[offset : offset + window_size_samples]
        shortfall = window_size_samples - len(window)
        if shortfall > 0:
            window = np.pad(window, (0, int(shortfall)))
        speech_prob, state = model(window, state, sampling_rate)
        speech_probs.append(speech_prob)

    # Pass 2: turn probabilities into chunks.
    triggered = False
    speeches = []
    current_speech = {}
    # Hysteresis: speech starts at `threshold` but only stops below this.
    neg_threshold = threshold - 0.15

    # Candidate end of the current chunk (lets short silences pass).
    temp_end = 0
    # Candidate cut points used when a chunk hits the maximum length.
    prev_end = next_start = 0

    for i, speech_prob in enumerate(speech_probs):
        pos = window_size_samples * i

        if speech_prob >= threshold:
            if temp_end:
                # Speech resumed: forget the tentative end.
                temp_end = 0
                if next_start < prev_end:
                    next_start = pos
            if not triggered:
                triggered = True
                current_speech["start"] = pos
                continue

        if triggered and pos - current_speech["start"] > max_speech_samples:
            # The chunk is too long. Cut at the last qualifying silence if one
            # was seen, otherwise cut right here.
            had_silence = bool(prev_end)
            # After cutting at a silence, keep going with a new chunk only if
            # speech already resumed after that silence; otherwise we previously
            # reached silence (< neg_thres) and are still not speech (< thres).
            resume = had_silence and next_start >= prev_end
            current_speech["end"] = prev_end if had_silence else pos
            speeches.append(current_speech)
            current_speech = {"start": next_start} if resume else {}
            triggered = resume
            prev_end = next_start = temp_end = 0
            if not had_silence:
                continue

        if triggered and speech_prob < neg_threshold:
            if not temp_end:
                temp_end = pos
            silence_len = pos - temp_end
            # Only remember silences long enough to be worth cutting at.
            if silence_len > min_silence_samples_at_max_speech:
                prev_end = temp_end
            if silence_len < min_silence_samples:
                continue
            # Enough silence: close the chunk, keeping it only if long enough.
            current_speech["end"] = temp_end
            if current_speech["end"] - current_speech["start"] > min_speech_samples:
                speeches.append(current_speech)
            current_speech = {}
            prev_end = next_start = temp_end = 0
            triggered = False

    # A chunk still open at the end of the audio runs to the last sample.
    if (
        current_speech
        and (audio_length_samples - current_speech["start"]) > min_speech_samples
    ):
        current_speech["end"] = audio_length_samples
        speeches.append(current_speech)

    # Pass 3: pad the chunks; neighbours closer than two pads split the gap.
    last_index = len(speeches) - 1
    for i, speech in enumerate(speeches):
        if i == 0:
            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))

        if i == last_index:
            speech["end"] = int(
                min(audio_length_samples, speech["end"] + speech_pad_samples)
            )
            continue

        following = speeches[i + 1]
        silence_duration = following["start"] - speech["end"]
        if silence_duration < 2 * speech_pad_samples:
            half_gap = silence_duration // 2
            speech["end"] += int(half_gap)
            following["start"] = int(max(0, following["start"] - half_gap))
        else:
            speech["end"] = int(
                min(audio_length_samples, speech["end"] + speech_pad_samples)
            )
            following["start"] = int(
                max(0, following["start"] - speech_pad_samples)
            )

    return speeches
+
+
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
    """Return the audio covered by ``chunks``, joined end to end.

    Each chunk is a dict with "start" and "end" sample indices into ``audio``.
    With no chunks, an empty float32 array is returned.
    """
    if len(chunks) == 0:
        return np.array([], dtype=np.float32)

    pieces = []
    for span in chunks:
        begin, end = span["start"], span["end"]
        pieces.append(audio[begin:end])
    return np.concatenate(pieces)
+
+
class SpeechTimestampsMap:
    """Helper class to restore original speech timestamps.

    Given the speech chunks that were kept from an audio signal, maps a time
    on the concatenated (silence-removed) timeline back to the corresponding
    time in the original audio.
    """

    def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
        """Precompute, per chunk, its end on the new timeline and the silence cut before it.

        Args:
            chunks: Dicts with "start" and "end" sample indices in the original
                audio, in order.
            sampling_rate: Samples per second, used to convert samples to seconds.
            time_precision: Number of decimals kept by ``get_original_time``.
        """
        self.sampling_rate = sampling_rate
        self.time_precision = time_precision
        # End of each chunk, in samples, on the silence-removed timeline.
        self.chunk_end_sample = []
        # Seconds of audio removed before each chunk.
        self.total_silence_before = []

        previous_end = 0
        silent_samples = 0

        for chunk in chunks:
            silent_samples += chunk["start"] - previous_end
            previous_end = chunk["end"]

            self.chunk_end_sample.append(chunk["end"] - silent_samples)
            self.total_silence_before.append(silent_samples / sampling_rate)

    def get_original_time(
        self,
        time: float,
        chunk_index: Optional[int] = None,
    ) -> float:
        """Translate ``time`` (seconds, silence-removed timeline) to the original audio.

        Args:
            time: Time in seconds on the silence-removed timeline.
            chunk_index: Chunk the time belongs to; looked up with
                ``get_chunk_index`` when omitted.

        Returns:
            The original time in seconds, rounded to ``time_precision`` decimals.
            When the map was built from no chunks, nothing was removed, so the
            rounded input time is returned instead of raising IndexError.
        """
        if not self.total_silence_before:
            return round(time, self.time_precision)

        if chunk_index is None:
            chunk_index = self.get_chunk_index(time)

        total_silence_before = self.total_silence_before[chunk_index]
        return round(total_silence_before + time, self.time_precision)

    def get_chunk_index(self, time: float) -> int:
        """Return the index of the chunk containing ``time``.

        Times past the end of the last chunk are clamped to the last chunk.
        Returns -1 when the map holds no chunks.
        """
        sample = int(time * self.sampling_rate)
        return min(
            bisect.bisect(self.chunk_end_sample, sample),
            len(self.chunk_end_sample) - 1,
        )
+
+
@functools.lru_cache
def get_vad_model():
    """Build the Silero VAD model on first use and reuse it afterwards.

    The ONNX file is looked up in the ``Models`` directory that sits next to
    this source file.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(module_dir, "Models/silero_vad.onnx")
    return SileroVADModel(model_path)
+
+
class SileroVADModel:
    """Thin wrapper that runs the Silero VAD ONNX model through onnxruntime."""

    def __init__(self, path):
        """Create a single-threaded CPU inference session for the model at ``path``.

        Raises:
            RuntimeError: If the onnxruntime package cannot be imported.
        """
        # Imported here so the module can be loaded without onnxruntime.
        try:
            import onnxruntime
        except ImportError as err:
            raise RuntimeError(
                "Applying the VAD filter requires the onnxruntime package"
            ) from err

        session_options = onnxruntime.SessionOptions()
        session_options.inter_op_num_threads = 1
        session_options.intra_op_num_threads = 1
        session_options.log_severity_level = 4

        self.session = onnxruntime.InferenceSession(
            path,
            providers=["CPUExecutionProvider"],
            sess_options=session_options,
        )

    def get_initial_state(self, batch_size: int):
        """Return zeroed ``(h, c)`` state arrays of shape (2, batch_size, 64)."""
        state_shape = (2, batch_size, 64)
        return (
            np.zeros(state_shape, dtype=np.float32),
            np.zeros(state_shape, dtype=np.float32),
        )

    def __call__(self, x, state, sr: int):
        """Run the model on one audio chunk.

        Args:
            x: Audio chunk; a 1-D array is given a leading batch axis.
            state: The ``(h, c)`` pair from ``get_initial_state`` or from the
                previous call.
            sr: Sample rate passed to the model.

        Returns:
            A tuple of the model output and the updated ``(h, c)`` state.

        Raises:
            ValueError: If ``x`` has more than two dimensions, or if the chunk
                is too short for the given sample rate.
        """
        rank = len(x.shape)
        if rank == 1:
            x = np.expand_dims(x, 0)
        elif rank > 2:
            raise ValueError(
                f"Too many dimensions for input audio chunk {rank}"
            )

        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        h, c = state
        feeds = {
            "input": x,
            "h": h,
            "c": c,
            "sr": np.array(sr, dtype="int64"),
        }
        out, h, c = self.session.run(None, feeds)

        return out, (h, c)