From a4c1870f724f18e98c33468b4d038dd1c742e4bd Mon Sep 17 00:00:00 2001 From: yum Date: Tue, 26 Dec 2023 02:18:58 -0800 Subject: Add optional transcription & curation components Add a transcribe button, which transcribes each .wav file using openai/whisper-large-v2, producing a corresponding .txt file. Also add a TUI tool for WSL. This tool lets you view transcripts and delete them with vi-like commands. Useful for cleaning data. --- app.py | 170 ++++++++++++++++++++++++++++++++++++++++++++++- curate/Makefile | 25 +++++++ curate/ui.cc | 156 +++++++++++++++++++++++++++++++++++++++++++ whisper_requirements.txt | 3 + 4 files changed, 351 insertions(+), 3 deletions(-) create mode 100644 curate/Makefile create mode 100644 curate/ui.cc create mode 100644 whisper_requirements.txt diff --git a/app.py b/app.py index 064aef3..21fbbf7 100644 --- a/app.py +++ b/app.py @@ -6,6 +6,7 @@ import math import numpy as np import os import pyaudio +import subprocess import sys import time import typing @@ -157,6 +158,36 @@ class MicStream(AudioStream): result = b''.join(chunks) return result +class DiskStream(AudioStream): + def __init__(self, path: str): + fmt = None + if path.endswith(".mp3"): + fmt = "mp3" + elif path.endswith(".wav"): + fmt = "wav" + else: + raise NotImplementedError(f"Requested file type {path} " + \ + "is not supported") + print(f"Loading audio data", file=sys.stderr) + audio = AudioSegment.from_file(path, format=fmt) + audio = audio.set_channels(1) + audio = audio.set_frame_rate(16000) + frames = np.array(audio.get_array_of_samples()) + frames = np.int16(frames).tobytes() + + self.frames = frames + self.fps = 16000 + + def getSamples(self) -> bytes: + frames = self.frames + self.frames = b'' + return frames + + if len(frames) < nframes: + frames += np.zeros(nframes - len(frames), dtype=np.int16).tobytes() + + return frames + class AudioCollector: def __init__(self, stream: AudioStream): self.stream = stream @@ -365,7 +396,7 @@ class AppControl: run = True app_ctrl = AppControl() -def recordMeDaddy( +def recordAudio( mic_device: str, min_volume: float = -1.3, max_volume: float = -0.8 @@ -383,7 +414,7 @@ def recordMeDaddy( #collector_hd = NormalizingAudioCollector(collector_hd) collector_hd = CompressingAudioCollector(collector_hd) - min_silence_ms = 1000 + min_silence_ms = 250 max_speech_s = 30 segmenter = AudioSegmenter( min_silence_ms=min_silence_ms, @@ -438,6 +469,104 @@ def recordMeDaddy( collector_hd.keepLast(1.0) print("Stopped recording") +class Segment: + def __init__(self, + transcript: str, + start_ts: float, + end_ts: float, + wall_ts: float, + avg_logprob: float, + no_speech_prob: float, + compression_ratio: float): + self.transcript = transcript + # start_ts, end_ts are timestamps in seconds relative to `wall_ts`. + self.start_ts = start_ts + self.end_ts = end_ts + # wall_ts is the time.time() at which the oldest audio sample leading + # to this transcript was collected. + self.wall_ts = wall_ts + self.avg_logprob = avg_logprob + self.no_speech_prob = no_speech_prob + self.compression_ratio = compression_ratio + + def __str__(self): + ts = f"(ts: {self.start_ts}-{self.end_ts}) " + + wall_ts_start = datetime.utcfromtimestamp(self.start_ts + self.wall_ts).strftime('%H:%M:%S') + wall_ts_end = datetime.utcfromtimestamp(self.end_ts + self.wall_ts).strftime('%H:%M:%S') + wall_ts = f"(wall ts: {wall_ts_start}-{wall_ts_end}) " + + no_speech = f"(no_speech: {self.no_speech_prob}) " + avg_logprob = f"(avg_logprob: {self.avg_logprob}) " + return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob + +def pipInstall(pkgs: typing.List[str]) -> bool: + pkgs_str = " ".join(pkgs) + print(f"Installing {pkgs_str}") + env = os.environ.copy() + # cwd is set at top of __main__. We set PATH to ensure that installed + # Python packages have access to any binaries that come with them. + env["PATH"] = os.getcwd() + "/Python/Scripts;" + env['PATH'] + pip_proc = subprocess.Popen( + f"./Python/python.exe -m pip install {pkgs_str} --no-warn-script-location".split(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env) + pip_stdout, pip_stderr = pip_proc.communicate() + pip_stdout = pip_stdout.decode("utf-8") + pip_stderr = pip_stderr.decode("utf-8") + print(pip_stdout, file=sys.stderr) + print(pip_stderr, file=sys.stderr) + if pip_proc.returncode != 0: + print(f"`pip install {pkgs_str}` exited with {pip_proc.returncode}", + file=sys.stderr) + return False + return True + +class Whisper: + def __init__(self, + collector: AudioCollector): + self.collector = collector + + import torch + from transformers import pipeline + + whisper_model = "openai/whisper-large-v2" + print(f"Loading pipeline for {whisper_model}...") + self.pipe = pipeline( + "automatic-speech-recognition", + model="distil-whisper/distil-large-v2", + torch_dtype=torch.float16, + device="cuda", + ) + print(f"Done.") + + def transcribe(self, frames: bytes = None) -> typing.List[Segment]: + if frames is None: + frames = self.collector.getAudio() + # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on + # [-1, 1]. + audio = np.frombuffer(frames, + dtype=np.int16).flatten().astype(np.float32) / 32768.0 + + t0 = time.time() + res = self.pipe( + audio, + chunk_length_s=30, + batch_size=1) + + result = [Segment(res["text"], + 0, + 0, + self.collector.begin(), + 0, + 0, + 0)] + + t1 = time.time() + print(f"Transcription latency (s): {t1 - t0}: {result[0].transcript}") + return result + def getOutput() -> str: sys.stdout.flush() with open("output.log", "r") as f: @@ -447,6 +576,37 @@ def stopApp(): print("Requesting app stop") app_ctrl.run = False +def transcribeAudio(concatenated_path: str): + # Step 1: Install Whisper requirements + with open("whisper_requirements.txt", "r") as file: + requirements = file.read().splitlines() + if not pipInstall(requirements): + return + + # Step 2: Iterate over .wav files in the current working directory + whisper = Whisper(None) + for wav_file in os.listdir('.'): + if wav_file.endswith('.wav'): + if wav_file.endswith(os.path.basename(concatenated_path)): + print("Skipping concatenated file") + continue + # Step 3: Transcription pipeline + # TODO parameterize high fidelity framerate + print(f"Transcribing {wav_file}") + disk_stream = DiskStream(wav_file) + collector = CompressingAudioCollector(AudioCollector(disk_stream)) + whisper.collector = collector + + # Transcribe the audio + segments = whisper.transcribe() + + # Step 4: Save transcriptions + transcript_filename = wav_file.replace('.wav', '.txt') + with open(transcript_filename, 'w') as txt_file: + for segment in segments: + txt_file.write(segment.transcript + '\n') + print(f"Transcript generated at {transcript_filename}") + if __name__ == "__main__": abspath = os.path.abspath(__file__) dname = os.path.dirname(abspath) @@ -463,15 +623,19 @@ if __name__ == "__main__": max_volume = gr.Number(label="Maximum volume", value=-0.8) record_audio = gr.Button("Record audio") stop_recording = gr.Button("Stop recording") + transcribe_audio = gr.Button("Transcribe audio") concatenated_path = gr.Text(label="Combined audio filename", value="combined.wav") min_length = gr.Number(label="Minimum length (seconds)", value=3.0) concatenate_audio = gr.Button("Combine audio files") dbg_output = gr.Text(label="Output") - record_audio.click(recordMeDaddy, [mic_device, min_volume, max_volume], + record_audio.click(recordAudio, [mic_device, min_volume, max_volume], dbg_output) stop_recording.click(stopApp, [], dbg_output) + + transcribe_audio.click(transcribeAudio, [concatenated_path], dbg_output) + concatenate_audio.click(concatenateWavFiles, [concatenated_path], dbg_output) diff --git a/curate/Makefile b/curate/Makefile new file mode 100644 index 0000000..d4e63dc --- /dev/null +++ b/curate/Makefile @@ -0,0 +1,25 @@ +CC=g++ +CFLAGS=-c -O2 -std=c++20 +LDFLAGS=-lcurses + +EXE=ui +SRCS=ui.cc +OBJS=$(SRCS:.cc=.o) +HDRS= + +.PHONY: all +all: $(EXE) + +$(EXE): $(OBJS) + $(CC) $^ $(LDFLAGS) -o $@ + +%.o: %.cc %.h + $(CC) $(CFLAGS) $< -o $@ + +%.o: %.cc + $(CC) $(CFLAGS) $< -o $@ + +.PHONY: clean +clean: + @rm -f $(OBJS) $(EXE) + diff --git a/curate/ui.cc b/curate/ui.cc new file mode 100644 index 0000000..222bfe4 --- /dev/null +++ b/curate/ui.cc @@ -0,0 +1,156 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +typedef std::pair datapoint_t; + +const int PAGE_LINES = 40; +const int TRANSCRIPT_CHARS = 120; + +void getData( + const std::filesystem::path& data_path, + std::vector &datapoints, + std::map &transcripts) { + datapoints.clear(); + transcripts.clear(); + printw("Scanning for files at %s\n", data_path.string().c_str()); + for (const auto& entry : std::filesystem::directory_iterator(data_path)) { + //printw(" Checking file %s\n", entry.path().string().c_str()); + if (entry.is_regular_file()) { + std::filesystem::path filepath = entry.path(); + std::string filename = filepath.stem().string(); + + if (filepath.extension() == ".wav") { + std::filesystem::path txt_file = filepath.replace_extension(".txt"); + if (std::filesystem::exists(txt_file)) { + datapoints.emplace_back(filepath.string(), txt_file.string()); + std::ifstream fileStream(txt_file); + std::stringstream buffer; + buffer << fileStream.rdbuf(); + std::string contents = buffer.str(); + contents.erase(std::remove(contents.begin(), contents.end(), '\n'), contents.cend()); + contents.erase(std::remove(contents.begin(), contents.end(), '\r'), contents.cend()); + contents = contents.substr(0, TRANSCRIPT_CHARS); + transcripts[txt_file.string()] = contents; + } + } + } + } +} + + +int main(int argc, char* argv[]) { + const std::filesystem::path cwd = std::filesystem::current_path(); + std::filesystem::path data_path = std::filesystem::current_path(); + if (argc == 2) { + data_path = std::filesystem::path(argv[1]); + } + + // Initialize ncurses + initscr(); + cbreak(); + noecho(); + keypad(stdscr, TRUE); + + // Clear the screen and wait for 'q' or 'x' + bool run = true; + bool redraw = true; + int idx = 0; + int page_offset = 0; + + std::vector datapoints; + std::map transcripts; + + std::string digits; + while (run) { + clear(); + { + int cur_idx = 0; + getData(data_path, datapoints, transcripts); + for (const auto& [txt_path, transcript] : transcripts) { + if (cur_idx < page_offset * PAGE_LINES) { + ++cur_idx; + continue; + } + + char selector = ((cur_idx % PAGE_LINES) == idx) ? '>' : ' '; + printw("%02d %c %s: %s\n", (cur_idx % PAGE_LINES), selector, txt_path.c_str(), transcript.c_str()); + ++cur_idx; + + if (cur_idx >= (page_offset + 1) * PAGE_LINES) { + break; + } + } + } + refresh(); + + int ch = getch(); + if (ch == 'q') { + run = false; + continue; + } else if (ch == 'j') { + int step_sz = 1; + if (digits.size() > 0) { + step_sz = std::atoi(digits.c_str()); + digits.clear(); + } + + idx += step_sz; + idx = std::min(PAGE_LINES - 1, idx); + } else if (ch == 'k') { + int step_sz = 1; + if (digits.size() > 0) { + step_sz = std::atoi(digits.c_str()); + digits.clear(); + } + + idx -= step_sz; + idx = std::max(0, idx); + } else if (ch == KEY_NPAGE) { + ++page_offset; + } else if (ch == KEY_PPAGE) { + --page_offset; + page_offset = std::max(0, page_offset); + } else if (ch == 'x') { + int cur_idx = 0; + for (const auto& [txt_path, transcript] : transcripts) { + if (cur_idx != page_offset * PAGE_LINES + idx) { + ++cur_idx; + continue; + } + std::filesystem::path wav_file = std::filesystem::path(txt_path).replace_extension(".wav"); + std::filesystem::remove(txt_path); + std::filesystem::remove(wav_file); + break; + } + } else if (ch >= '0' && ch <= '9') { + digits += ch; + } else if (ch == 'g') { + int target = idx; + if (digits.size() > 0) { + target = std::atoi(digits.c_str()); + digits.clear(); + } + + idx = target; + idx = std::min(PAGE_LINES - 1, idx); + idx = std::max(0, idx); + } else if (ch == 27) { // ASCII value of esc key + digits.clear(); + } + } + + // End ncurses mode + endwin(); + + return 0; +} + diff --git a/whisper_requirements.txt b/whisper_requirements.txt new file mode 100644 index 0000000..e7547f3 --- /dev/null +++ b/whisper_requirements.txt @@ -0,0 +1,3 @@ +transformers==4.35.2 +--extra-index-url https://download.pytorch.org/whl/cu121 +torch -- cgit v1.2.3