From c2bc70c18d2fd1c3601b32f2a93b3b4a704786a5 Mon Sep 17 00:00:00 2001 From: yum Date: Mon, 18 Sep 2023 21:00:56 -0700 Subject: Reimplement BrowserSource as a StreamingPlugin BrowserSource now fades text out continuously over time. TODO * Delete C++ webserver, browsersource, transcript code * Add UI for text age fading --- BrowserSource/index.html | 81 +++++++++++++++----------- GUI/GUI/GUI/Frame.cpp | 2 + Scripts/browser_src.py | 128 +++++++++++++++++++++++++++++++++++++++++ Scripts/transcribe_pipeline.py | 28 +++++++++ Scripts/transcribe_v2.py | 26 +-------- 5 files changed, 207 insertions(+), 58 deletions(-) create mode 100644 Scripts/browser_src.py create mode 100644 Scripts/transcribe_pipeline.py diff --git a/BrowserSource/index.html b/BrowserSource/index.html index 28053e3..216e013 100644 --- a/BrowserSource/index.html +++ b/BrowserSource/index.html @@ -60,42 +60,53 @@
- - $('#content').html(contentHtml); - $('#content').css("background-color", "#22222280"); - }, - error: function(jqXHR, textStatus, errorThrown) { - console.error('Error getting transcript: ', textStatus, errorThrown); - } - }); - scrollToBottom(); - } - setInterval(getTranscript, /*interval_ms=*/100); - diff --git a/GUI/GUI/GUI/Frame.cpp b/GUI/GUI/GUI/Frame.cpp index 30ccbde..23ac38c 100644 --- a/GUI/GUI/GUI/Frame.cpp +++ b/GUI/GUI/GUI/Frame.cpp @@ -2492,6 +2492,7 @@ void Frame::OnAppStart(wxCommandEvent& event) { EnsureVirtualEnv(/*block=*/true); }; +#if 0 obs_app_ = std::async(std::launch::async, [this, enable_browser_src, browser_src_port]() -> bool { if (enable_browser_src) { @@ -2500,6 +2501,7 @@ void Frame::OnAppStart(wxCommandEvent& event) { } return true; }); +#endif const std::string config_path(AppConfig::kConfigPath); py_app_ = std::move(PythonWrapper::StartApp(*app_c_, config_path, transcribe_out_, diff --git a/Scripts/browser_src.py b/Scripts/browser_src.py new file mode 100644 index 0000000..befb2db --- /dev/null +++ b/Scripts/browser_src.py @@ -0,0 +1,128 @@ +from transcribe_pipeline import StreamingPlugin, TranscriptCommit +from urllib.parse import urlparse + +import copy +import json +import http.server +import os +import socketserver +import threading +import time +import transcribe_pipeline +import typing + +class HTTPServer: + def __init__(self, port: int): + self.port = port + self.route_map = {} + self.httpd = None + + def register_file_handler(self, http_method: str, path: str, file_path: str): + print(f"File handler registered at {os.getcwd()}") + def handler(): + if os.path.exists(file_path): + with open(file_path, 'r', encoding='utf-8') as f: + return 200, f.read().replace('%PORT%', str(self.port)), 'text/html' + else: + return 404, {'error': 'file not found'}, 'application/json' + self.route_map[(http_method, path)] = handler + + def register_json_handler(self, http_method: str, path: str, handler): + self.route_map[(http_method, path)] = handler + + def run(self): + def handler(*args, **kwargs): + MyHandler(http_server_instance=self, *args, **kwargs) + + with socketserver.TCPServer(("", self.port), handler) as httpd: + self.httpd = httpd + print(f"Webserver running at port {self.port}") + httpd.serve_forever() + print(f"Webserver exiting") + self.httpd = None + + def stop(self): + if self.httpd: + self.httpd.shutdown() + + +class MyHandler(http.server.BaseHTTPRequestHandler): + def __init__(self, *args, http_server_instance=None, **kwargs): + self.http_server_instance = http_server_instance + super().__init__(*args, **kwargs) + + def do_GET(self): + self.handle_request('GET') + + def handle_request(self, method: str): + parsed_path = urlparse(self.path) + if (method, parsed_path.path) in self.http_server_instance.route_map: + status_code, response_content, content_type = \ + self.http_server_instance.route_map[(method, parsed_path.path)]() + self.send_response(status_code) + self.send_header('Content-Type', content_type) + self.end_headers() + if content_type == 'application/json': + self.wfile.write(json.dumps(response_content).encode('utf-8')) + else: + self.wfile.write(response_content.encode('utf-8')) + else: + self.send_response(404) + self.send_header('Content-Type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'error': 'not found'}).encode('utf-8')) + + +class BrowserSource(StreamingPlugin): + def __init__(self, cfg: typing.Dict): + port = cfg["browser_src_port"] + print(f"Browser source running on port {port}") + self.commits = [] + self.preview_commit = None + self.http_server = HTTPServer(port) + self.http_server.register_json_handler('GET', '/api/v0/transcript', self.get_transcript_json) + + index_html_path = os.path.join("Resources", "BrowserSource", "index.html") + self.http_server.register_file_handler('GET', '/', index_html_path) + self.http_server.register_file_handler('GET', '/index.html', index_html_path) + + # Start the HTTP server in a new thread + self.server_thread = threading.Thread(target=self.run) + self.server_thread.start() + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + original_commit = commit + commit = copy.deepcopy(original_commit) + del commit.audio + if commit.delta: + self.commits.append(commit) + self.preview_commit = commit + return original_commit + + # return (http_code, body, content_type) + def get_transcript_json(self) -> typing.Tuple[int, str, str]: + processed_commits = [vars(commit) for commit in self.commits] + transcript_data = { + 'commits': processed_commits, + 'preview': vars(self.preview_commit) if self.preview_commit else None, + 'ts': time.time() + } + return 200, json.dumps(transcript_data), 'text/json' + + def run(self): + self.http_server.run() + + def stop(self): + self.http_server.stop() + self.server_thread.join() + + +# Example usage +def my_callback() -> typing.Tuple[int, typing.Dict[str, str]]: + return 200, {'message': 'Hello, world!'}, 'text/json' + +if __name__ == '__main__': + server = HTTPServer(port=8080) + server.register_json_handler('GET', '/api/v0/transcript', my_callback) + server.run() + diff --git a/Scripts/transcribe_pipeline.py b/Scripts/transcribe_pipeline.py new file mode 100644 index 0000000..3f48b08 --- /dev/null +++ b/Scripts/transcribe_pipeline.py @@ -0,0 +1,28 @@ +import time + + +class TranscriptCommit: + def __init__(self, + delta: str, + preview: str, + latency_s: int = None, + thresh_at_commit: int = None, + audio: bytes = None): + self.delta = delta + self.preview = preview + self.latency_s = latency_s + self.thresh_at_commit = thresh_at_commit + self.audio = audio + self.ts = time.time() + + +class StreamingPlugin: + def __init__(self): + pass + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + return commit + + def stop(self): + pass + diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index e8d7ef6..2bf605d 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -1,3 +1,4 @@ +from browser_src import BrowserSource from datetime import datetime from emotes_v2 import EmotesState from faster_whisper import WhisperModel @@ -5,6 +6,7 @@ from functools import partial from profanity_filter import ProfanityFilter from pydub import AudioSegment from sentence_splitter import split_text_into_sentences +from transcribe_pipeline import StreamingPlugin, TranscriptCommit import app_config import argparse @@ -458,19 +460,6 @@ class Whisper: s.avg_logprob, s.no_speech_prob, s.compression_ratio)) return res -class TranscriptCommit: - def __init__(self, - delta: str, - preview: str, - latency_s: int = None, - thresh_at_commit: int = None, - audio: bytes = None): - self.delta = delta - self.preview = preview - self.latency_s = latency_s - self.thresh_at_commit = thresh_at_commit - self.audio = audio - def saveAudio(audio: bytes, path: str): with wave.open(path, 'wb') as wf: print(f"Saving audio to {path}", file=sys.stderr) @@ -545,16 +534,6 @@ def install_in_venv(pkgs: typing.List[str]) -> bool: print(f"`pip install {pkgs_str}` exited with {pip_proc.returncode}", file=sys.stderr) -class StreamingPlugin: - def __init__(self): - pass - - def transform(self, commit: TranscriptCommit) -> TranscriptCommit: - return commit - - def stop(self): - pass - class TranslationPlugin(StreamingPlugin): def __init__(self, cfg): lang_bits = cfg["language_target"].split(" | ") @@ -1166,6 +1145,7 @@ def run(cfg): ctrl.plugins.append(LowercasePlugin(cfg)) ctrl.plugins.append(ProfanityPlugin(cfg)) ctrl.plugins.append(UwuPlugin(cfg)) + ctrl.plugins.append(BrowserSource(cfg)) ctrl.filters = [] ctrl.filters.append(TrailingPeriodFilter(cfg)) -- cgit v1.2.3