summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2023-12-26 02:18:58 -0800
committeryum <yum.food.vr@gmail.com>2023-12-26 02:18:58 -0800
commita4c1870f724f18e98c33468b4d038dd1c742e4bd (patch)
tree54721b70fd73997076199da0ea39e0ce5d2ef367
parente773bf75a562a8ed5afe72642ed39ba196ffab75 (diff)
Add optional transcription & curation components
Add a transcribe button, which transcribes each .wav file using openai/whisper-large-v2, producing a corresponding .txt file. Also add a TUI tool for WSL. This tool lets you view transcripts and delete them with vi-like commands. Useful for cleaning data.
-rw-r--r--app.py170
-rw-r--r--curate/Makefile25
-rw-r--r--curate/ui.cc156
-rw-r--r--whisper_requirements.txt3
4 files changed, 351 insertions, 3 deletions
diff --git a/app.py b/app.py
index 064aef3..21fbbf7 100644
--- a/app.py
+++ b/app.py
@@ -6,6 +6,7 @@ import math
import numpy as np
import os
import pyaudio
+import subprocess
import sys
import time
import typing
@@ -157,6 +158,36 @@ class MicStream(AudioStream):
result = b''.join(chunks)
return result
+class DiskStream(AudioStream):
+ def __init__(self, path: str):
+ fmt = None
+ if path.endswith(".mp3"):
+ fmt = "mp3"
+ elif path.endswith(".wav"):
+ fmt = "wav"
+ else:
+ raise NotImplementedError(f"Requested file type {path} " + \
+ "is not supported")
+ print(f"Loading audio data", file=sys.stderr)
+ audio = AudioSegment.from_file(path, format=fmt)
+ audio = audio.set_channels(1)
+ audio = audio.set_frame_rate(16000)
+ frames = np.array(audio.get_array_of_samples())
+ frames = np.int16(frames).tobytes()
+
+ self.frames = frames
+ self.fps = 16000
+
+ def getSamples(self) -> bytes:
+ frames = self.frames
+ self.frames = b''
+ return frames
+
+ if len(frames) < nframes:
+ frames += np.zeros(nframes - len(frames), dtype=np.int16).tobytes()
+
+ return frames
+
class AudioCollector:
def __init__(self, stream: AudioStream):
self.stream = stream
@@ -365,7 +396,7 @@ class AppControl:
run = True
app_ctrl = AppControl()
-def recordMeDaddy(
+def recordAudio(
mic_device: str,
min_volume: float = -1.3,
max_volume: float = -0.8
@@ -383,7 +414,7 @@ def recordMeDaddy(
#collector_hd = NormalizingAudioCollector(collector_hd)
collector_hd = CompressingAudioCollector(collector_hd)
- min_silence_ms = 1000
+ min_silence_ms = 250
max_speech_s = 30
segmenter = AudioSegmenter(
min_silence_ms=min_silence_ms,
@@ -438,6 +469,104 @@ def recordMeDaddy(
collector_hd.keepLast(1.0)
print("Stopped recording")
+class Segment:
+ def __init__(self,
+ transcript: str,
+ start_ts: float,
+ end_ts: float,
+ wall_ts: float,
+ avg_logprob: float,
+ no_speech_prob: float,
+ compression_ratio: float):
+ self.transcript = transcript
+ # start_ts, end_ts are timestamps in seconds relative to `wall_ts`.
+ self.start_ts = start_ts
+ self.end_ts = end_ts
+ # wall_ts is the time.time() at which the oldest audio sample leading
+ # to this transcript was collected.
+ self.wall_ts = wall_ts
+ self.avg_logprob = avg_logprob
+ self.no_speech_prob = no_speech_prob
+ self.compression_ratio = compression_ratio
+
+ def __str__(self):
+ ts = f"(ts: {self.start_ts}-{self.end_ts}) "
+
+ wall_ts_start = datetime.utcfromtimestamp(self.start_ts + self.wall_ts).strftime('%H:%M:%S')
+ wall_ts_end = datetime.utcfromtimestamp(self.end_ts + self.wall_ts).strftime('%H:%M:%S')
+ wall_ts = f"(wall ts: {wall_ts_start}-{wall_ts_end}) "
+
+ no_speech = f"(no_speech: {self.no_speech_prob}) "
+ avg_logprob = f"(avg_logprob: {self.avg_logprob}) "
+ return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob
+
+def pipInstall(pkgs: typing.List[str]) -> bool:
+ pkgs_str = " ".join(pkgs)
+ print(f"Installing {pkgs_str}")
+ env = os.environ.copy()
+ # cwd is set at top of __main__. We set PATH to ensure that installed
+ # Python packages have access to any binaries that come with them.
+ env["PATH"] = os.getcwd() + "/Python/Scripts;" + env['PATH']
+ pip_proc = subprocess.Popen(
+ f"./Python/python.exe -m pip install {pkgs_str} --no-warn-script-location".split(),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ env=env)
+ pip_stdout, pip_stderr = pip_proc.communicate()
+ pip_stdout = pip_stdout.decode("utf-8")
+ pip_stderr = pip_stderr.decode("utf-8")
+ print(pip_stdout, file=sys.stderr)
+ print(pip_stderr, file=sys.stderr)
+ if pip_proc.returncode != 0:
+ print(f"`pip install {pkgs_str}` exited with {pip_proc.returncode}",
+ file=sys.stderr)
+ return False
+ return True
+
+class Whisper:
+ def __init__(self,
+ collector: AudioCollector):
+ self.collector = collector
+
+ import torch
+ from transformers import pipeline
+
+ whisper_model = "openai/whisper-large-v2"
+ print(f"Loading pipeline for {whisper_model}...")
+ self.pipe = pipeline(
+ "automatic-speech-recognition",
+ model="distil-whisper/distil-large-v2",
+ torch_dtype=torch.float16,
+ device="cuda",
+ )
+ print(f"Done.")
+
+ def transcribe(self, frames: bytes = None) -> typing.List[Segment]:
+ if frames is None:
+ frames = self.collector.getAudio()
+ # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on
+ # [-1, 1].
+ audio = np.frombuffer(frames,
+ dtype=np.int16).flatten().astype(np.float32) / 32768.0
+
+ t0 = time.time()
+ res = self.pipe(
+ audio,
+ chunk_length_s=30,
+ batch_size=1)
+
+ result = [Segment(res["text"],
+ 0,
+ 0,
+ self.collector.begin(),
+ 0,
+ 0,
+ 0)]
+
+ t1 = time.time()
+ print(f"Transcription latency (s): {t1 - t0}: {result[0].transcript}")
+ return result
+
def getOutput() -> str:
sys.stdout.flush()
with open("output.log", "r") as f:
@@ -447,6 +576,37 @@ def stopApp():
print("Requesting app stop")
app_ctrl.run = False
+def transcribeAudio(concatenated_path: str):
+ # Step 1: Install Whisper requirements
+ with open("whisper_requirements.txt", "r") as file:
+ requirements = file.read().splitlines()
+ if not pipInstall(requirements):
+ return
+
+ # Step 2: Iterate over .wav files in the current working directory
+ whisper = Whisper(None)
+ for wav_file in os.listdir('.'):
+ if wav_file.endswith('.wav'):
+ if wav_file.endswith(os.path.basename(concatenated_path)):
+ print("Skipping concatenated file")
+ continue
+ # Step 3: Transcription pipeline
+ # TODO parameterize high fidelity framerate
+ print(f"Transcribing {wav_file}")
+ disk_stream = DiskStream(wav_file)
+ collector = CompressingAudioCollector(AudioCollector(disk_stream))
+ whisper.collector = collector
+
+ # Transcribe the audio
+ segments = whisper.transcribe()
+
+ # Step 4: Save transcriptions
+ transcript_filename = wav_file.replace('.wav', '.txt')
+ with open(transcript_filename, 'w') as txt_file:
+ for segment in segments:
+ txt_file.write(segment.transcript + '\n')
+ print(f"Transcript generated at {transcript_filename}")
+
if __name__ == "__main__":
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
@@ -463,15 +623,19 @@ if __name__ == "__main__":
max_volume = gr.Number(label="Maximum volume", value=-0.8)
record_audio = gr.Button("Record audio")
stop_recording = gr.Button("Stop recording")
+ transcribe_audio = gr.Button("Transcribe audio")
concatenated_path = gr.Text(label="Combined audio filename", value="combined.wav")
min_length = gr.Number(label="Minimum length (seconds)", value=3.0)
concatenate_audio = gr.Button("Combine audio files")
dbg_output = gr.Text(label="Output")
- record_audio.click(recordMeDaddy, [mic_device, min_volume, max_volume],
+ record_audio.click(recordAudio, [mic_device, min_volume, max_volume],
dbg_output)
stop_recording.click(stopApp, [], dbg_output)
+
+ transcribe_audio.click(transcribeAudio, [concatenated_path], dbg_output)
+
concatenate_audio.click(concatenateWavFiles, [concatenated_path],
dbg_output)
diff --git a/curate/Makefile b/curate/Makefile
new file mode 100644
index 0000000..d4e63dc
--- /dev/null
+++ b/curate/Makefile
@@ -0,0 +1,25 @@
+CC=g++
+CFLAGS=-c -O2 -std=c++20
+LDFLAGS=-lcurses
+
+EXE=ui
+SRCS=ui.cc
+OBJS=$(SRCS:.cc=.o)
+HDRS=
+
+.PHONY: all
+all: $(EXE)
+
+$(EXE): $(OBJS)
+ $(CC) $^ $(LDFLAGS) -o $@
+
+%.o: %.cc %.h
+ $(CC) $(CFLAGS) $< -o $@
+
+%.o: %.cc
+ $(CC) $(CFLAGS) $< -o $@
+
+.PHONY: clean
+clean:
+ @rm -f $(OBJS) $(EXE)
+
diff --git a/curate/ui.cc b/curate/ui.cc
new file mode 100644
index 0000000..222bfe4
--- /dev/null
+++ b/curate/ui.cc
@@ -0,0 +1,156 @@
+#include <curses.h>
+#include <ncurses.h>
+#include <stdio.h>
+
+#include <filesystem>
+#include <fstream>
+#include <iostream>
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+
+typedef std::pair<std::string, std::string> datapoint_t;
+
+const int PAGE_LINES = 40;
+const int TRANSCRIPT_CHARS = 120;
+
+void getData(
+ const std::filesystem::path& data_path,
+ std::vector<datapoint_t> &datapoints,
+ std::map<std::string, std::string> &transcripts) {
+ datapoints.clear();
+ transcripts.clear();
+ printw("Scanning for files at %s\n", data_path.string().c_str());
+ for (const auto& entry : std::filesystem::directory_iterator(data_path)) {
+ //printw(" Checking file %s\n", entry.path().string().c_str());
+ if (entry.is_regular_file()) {
+ std::filesystem::path filepath = entry.path();
+ std::string filename = filepath.stem().string();
+
+ if (filepath.extension() == ".wav") {
+ std::filesystem::path txt_file = filepath.replace_extension(".txt");
+ if (std::filesystem::exists(txt_file)) {
+ datapoints.emplace_back(filepath.string(), txt_file.string());
+ std::ifstream fileStream(txt_file);
+ std::stringstream buffer;
+ buffer << fileStream.rdbuf();
+ std::string contents = buffer.str();
+ contents.erase(std::remove(contents.begin(), contents.end(), '\n'), contents.cend());
+ contents.erase(std::remove(contents.begin(), contents.end(), '\r'), contents.cend());
+ contents = contents.substr(0, TRANSCRIPT_CHARS);
+ transcripts[txt_file.string()] = contents;
+ }
+ }
+ }
+ }
+}
+
+
+int main(int argc, char* argv[]) {
+ const std::filesystem::path cwd = std::filesystem::current_path();
+ std::filesystem::path data_path = std::filesystem::current_path();
+ if (argc == 2) {
+ data_path = std::filesystem::path(argv[1]);
+ }
+
+ // Initialize ncurses
+ initscr();
+ cbreak();
+ noecho();
+ keypad(stdscr, TRUE);
+
+ // Clear the screen and wait for 'q' or 'x'
+ bool run = true;
+ bool redraw = true;
+ int idx = 0;
+ int page_offset = 0;
+
+ std::vector<datapoint_t> datapoints;
+ std::map<std::string, std::string> transcripts;
+
+ std::string digits;
+ while (run) {
+ clear();
+ {
+ int cur_idx = 0;
+ getData(data_path, datapoints, transcripts);
+ for (const auto& [txt_path, transcript] : transcripts) {
+ if (cur_idx < page_offset * PAGE_LINES) {
+ ++cur_idx;
+ continue;
+ }
+
+ char selector = ((cur_idx % PAGE_LINES) == idx) ? '>' : ' ';
+ printw("%02d %c %s: %s\n", (cur_idx % PAGE_LINES), selector, txt_path.c_str(), transcript.c_str());
+ ++cur_idx;
+
+ if (cur_idx >= (page_offset + 1) * PAGE_LINES) {
+ break;
+ }
+ }
+ }
+ refresh();
+
+ int ch = getch();
+ if (ch == 'q') {
+ run = false;
+ continue;
+ } else if (ch == 'j') {
+ int step_sz = 1;
+ if (digits.size() > 0) {
+ step_sz = std::atoi(digits.c_str());
+ digits.clear();
+ }
+
+ idx += step_sz;
+ idx = std::min(PAGE_LINES - 1, idx);
+ } else if (ch == 'k') {
+ int step_sz = 1;
+ if (digits.size() > 0) {
+ step_sz = std::atoi(digits.c_str());
+ digits.clear();
+ }
+
+ idx -= step_sz;
+ idx = std::max(0, idx);
+ } else if (ch == KEY_NPAGE) {
+ ++page_offset;
+ } else if (ch == KEY_PPAGE) {
+ --page_offset;
+ page_offset = std::max(0, page_offset);
+ } else if (ch == 'x') {
+ int cur_idx = 0;
+ for (const auto& [txt_path, transcript] : transcripts) {
+ if (cur_idx != page_offset * PAGE_LINES + idx) {
+ ++cur_idx;
+ continue;
+ }
+ std::filesystem::path wav_file = std::filesystem::path(txt_path).replace_extension(".wav");
+ std::filesystem::remove(txt_path);
+ std::filesystem::remove(wav_file);
+ break;
+ }
+ } else if (ch >= '0' && ch <= '9') {
+ digits += ch;
+ } else if (ch == 'g') {
+ int target = idx;
+ if (digits.size() > 0) {
+ target = std::atoi(digits.c_str());
+ digits.clear();
+ }
+
+ idx = target;
+ idx = std::min(PAGE_LINES - 1, idx);
+ idx = std::max(0, idx);
+ } else if (ch == 27) { // ASCII value of esc key
+ digits.clear();
+ }
+ }
+
+ // End ncurses mode
+ endwin();
+
+ return 0;
+}
+
diff --git a/whisper_requirements.txt b/whisper_requirements.txt
new file mode 100644
index 0000000..e7547f3
--- /dev/null
+++ b/whisper_requirements.txt
@@ -0,0 +1,3 @@
+transformers==4.35.2
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch