From 2dc2f63686fc0137931f675f579d3e528861433d Mon Sep 17 00:00:00 2001
From: yum
Date: Sun, 10 Sep 2023 14:45:45 -0700
Subject: Add plugin interface

... and use it to implement translation and text filters.

Also fix display of non-English characters in browser src.
---
 Scripts/transcribe_v2.py | 217 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 215 insertions(+), 2 deletions(-)

diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index c73c13a..6eba3e8 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -1,25 +1,31 @@
 from datetime import datetime
+from emotes_v2 import EmotesState
 from faster_whisper import WhisperModel
 from functools import partial
+from profanity_filter import ProfanityFilter
 from pydub import AudioSegment
-from whisper.normalizers import EnglishTextNormalizer
 from scipy.optimize import minimize
-from emotes_v2 import EmotesState
+from sentence_splitter import split_text_into_sentences
+from whisper.normalizers import EnglishTextNormalizer
 
 import app_config
 import argparse
+import ctranslate2
 import editdistance
 import keybind_event_machine
 import langcodes
+import lang_compat
 import math
 import numpy as np
 import os
 import osc_ctrl
 import pyaudio
 import steamvr
+import subprocess
 import sys
 import threading
 import time
+import transformers
 import typing
 import vad
 import wave
@@ -524,6 +530,199 @@ class VadCommitter:
                 latency_s,
                 audio=audio)
 
+class Plugin:
+    """Base class for transcript post-processing plugins.
+
+    Subclasses override transform() to rewrite the `delta` and `preview`
+    text of a commit. The base implementation changes nothing.
+    """
+
+    def __init__(self):
+        pass
+
+    def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+        """Return `commit` unchanged; subclasses override this."""
+        return commit
+
+    def stop(self):
+        """Hook that transcriptionThread calls on every plugin; no-op here."""
+        pass
+
+class TranslationPlugin(Plugin):
+    """Translates commit text with a ctranslate2 model and an HF tokenizer.
+
+    Active only when cfg["language_target"] splits on " | " into exactly
+    two parts; the second part is passed to the model as the target prefix.
+    """
+
+    def __init__(self, cfg):
+        lang_bits = cfg["language_target"].split(" | ")
+        self.cfg = cfg
+        self.language_target = None
+        self.translator = None
+        self.tokenizer = None
+        if len(lang_bits) != 2:
+            return
+        self.language_target = lang_bits[1]
+
+        print("Translation requested", file=sys.stderr)
+        print("Installing torch and sentencepiece in virtual environment. "
+                "Nothing will print "
+                "for several minutes while these download (~2.4 GB).",
+                file=sys.stderr)
+
+        pip_proc = subprocess.Popen(
+            "Resources/Python/python.exe -m pip install sentencepiece torch".split(),
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE)
+        pip_stdout, pip_stderr = pip_proc.communicate()
+        pip_stdout = pip_stdout.decode("utf-8")
+        pip_stderr = pip_stderr.decode("utf-8")
+        print(pip_stdout, file=sys.stderr)
+        print(pip_stderr, file=sys.stderr)
+        if pip_proc.returncode != 0:
+            # The f prefix must sit on the literal that holds the placeholder.
+            print("Failed to set up for translation: `pip install torch` "
+                    f"exited with {pip_proc.returncode}", file=sys.stderr)
+
+        output_dir = "Resources/" + cfg["model_translation"]
+        # Provided by ctranslate2 Python package
+        cmd = "ct2-transformers-converter.exe --model facebook/" + \
+                cfg["model_translation"] + " --output_dir " + output_dir
+
+        print(f"Fetching translation algorithm ({cfg['model_translation']})",
+                file=sys.stderr)
+        if not os.path.exists(output_dir):
+            ct2_proc = subprocess.Popen(
+                cmd.split(),
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE)
+            ct2_stdout, ct2_stderr = ct2_proc.communicate()
+            ct2_stdout = ct2_stdout.decode("utf-8")
+            ct2_stderr = ct2_stderr.decode("utf-8")
+            print(ct2_stdout, file=sys.stderr)
+            print(ct2_stderr, file=sys.stderr)
+            if ct2_proc.returncode != 0:
+                print("Failed to get NLLB model: ct2 process exited with "
+                        f"{ct2_proc.returncode}", file=sys.stderr)
+        print(f"Using model at {output_dir}", file=sys.stderr)
+
+        self.translator = ctranslate2.Translator(output_dir)
+
+        whisper_lang = cfg["language"]
+        nllb_lang = lang_compat.whisper_to_nllb[whisper_lang]
+
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+            "facebook/" + cfg["model_translation"],
+            src_lang=nllb_lang)
+
+        print(f"Translation ready to go", file=sys.stderr)
+
+    def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+        """Translate commit.delta and commit.preview sentence by sentence."""
+        if not self.language_target:
+            return commit
+
+        def _translate_text(text: str) -> str:
+
+            whisper_lang = self.cfg["language"]
+            nllb_lang = lang_compat.whisper_to_nllb[whisper_lang]
+            ss_lang = lang_compat.nllb_to_ss[nllb_lang]
+            sentences = split_text_into_sentences(text, language=ss_lang)
+
+            translated_sentences = []
+            for sentence in sentences:
+                source = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(sentence))
+                target_prefix = [self.language_target]
+                results = self.translator.translate_batch([source], target_prefix=[target_prefix])
+                # [1:] skips the first token (presumably the target language tag).
+                target = results[0].hypotheses[0][1:]
+                translated_sentence = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(target))
+                translated_sentences.append(translated_sentence)
+            translated = " ".join(translated_sentences)
+            return translated
+
+        commit.delta = _translate_text(commit.delta)
+        commit.preview = _translate_text(commit.preview)
+        return commit
+
+class LowercasePlugin(Plugin):
+    """Lowercases commit text when cfg["enable_lowercase_filter"] is set."""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+        if self.cfg["enable_lowercase_filter"]:
+            commit.delta = commit.delta.lower()
+            commit.preview = commit.preview.lower()
+        return commit
+
+class UppercasePlugin(Plugin):
+    """Uppercases commit text when cfg["enable_uppercase_filter"] is set."""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+        if self.cfg["enable_uppercase_filter"]:
+            commit.delta = commit.delta.upper()
+            commit.preview = commit.preview.upper()
+        return commit
+
+class TrailingPeriodPlugin(Plugin):
+    """Strips one trailing "." (but never "...") when enabled in cfg."""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+        if self.cfg["remove_trailing_period"]:
+            def _remove_trailing_period(s: str) -> str:
+                if s.endswith(".") and not s.endswith("..."):
+                    s = s[:-1]
+                return s
+            commit.delta = _remove_trailing_period(commit.delta)
+            commit.preview = _remove_trailing_period(commit.preview)
+        return commit
+
+class UwuPlugin(Plugin):
+    """Pipes commit text through Resources/Uwu/Uwwwu.exe when enabled."""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+        if self.cfg["enable_uwu_filter"]:
+            def _to_uwu(s: str) -> str:
+                uwu_proc = subprocess.Popen(["Resources/Uwu/Uwwwu.exe", s],
+                        stdout=subprocess.PIPE,
+                        stderr=subprocess.PIPE)
+                uwu_stdout, uwu_stderr = uwu_proc.communicate()
+                uwu_text = uwu_stdout.decode("utf-8")
+                uwu_text = uwu_text.replace("\n", "")
+                uwu_text = uwu_text.replace("\r", "")
+                return uwu_text
+            commit.delta = _to_uwu(commit.delta)
+            commit.preview = _to_uwu(commit.preview)
+        return commit
+
+class ProfanityPlugin(Plugin):
+    """Runs commit text through ProfanityFilter when enabled in cfg."""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        en_profanity_path = os.path.abspath("Resources/Profanity/en")
+        self.filter = ProfanityFilter(en_profanity_path)
+        if cfg["enable_profanity_filter"]:
+            self.filter.load()
+
+    def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+        if self.cfg["enable_profanity_filter"]:
+            commit.delta = self.filter.filter(commit.delta)
+            commit.preview = self.filter.filter(commit.preview)
+        return commit
+
 class OscPager:
     def __init__(self, cfg):
         self.osc_state = osc_ctrl.OscState(cfg["chars_per_sync"],
@@ -678,6 +877,9 @@ def transcriptionThread(ctrl: ThreadControl):
 
         commit = ctrl.committer.getDelta()
 
+        for plugin in ctrl.plugins:
+            commit = plugin.transform(commit)
+
         if len(commit.delta) > 0 or len(commit.preview) > 0:
             # Hard-cap displayed transcript length at 4k characters to prevent
             # runaway memory use in UI. Keep the full transcript to avoid
@@ -702,6 +904,8 @@ def transcriptionThread(ctrl: ThreadControl):
             ctrl.transcript += commit.delta
             ctrl.preview = ctrl.transcript + commit.preview
 
+    for plugin in ctrl.plugins:
+        plugin.stop()
 
 def vrInputThread(ctrl: ThreadControl):
     RECORD_STATE = 0
@@ -936,6 +1140,13 @@ def run(cfg):
     ctrl.collector = collector
     ctrl.whisper = whisper
     ctrl.committer = committer
+    ctrl.plugins = []
+    ctrl.plugins.append(TranslationPlugin(cfg))
+    ctrl.plugins.append(UppercasePlugin(cfg))
+    ctrl.plugins.append(LowercasePlugin(cfg))
+    ctrl.plugins.append(TrailingPeriodPlugin(cfg))
+    ctrl.plugins.append(ProfanityPlugin(cfg))
+    ctrl.plugins.append(UwuPlugin(cfg))
     ctrl.pager = pager
     ctrl.transcript = ""
     ctrl.preview = ""
@@ -973,6 +1184,8 @@ def run(cfg):
     print("Done", file=sys.stderr)
 
 if __name__ == "__main__":
+    sys.stdout.reconfigure(encoding="utf-8")
+
     parser = argparse.ArgumentParser()
     parser.add_argument("--config", type=str, help="Path to app config YAML file.")
     args = parser.parse_args()
-- 
cgit v1.2.3