summaryrefslogtreecommitdiffstats
path: root/Scripts
diff options
context:
space:
mode:
Diffstat (limited to 'Scripts')
-rw-r--r--Scripts/transcribe_v2.py190
1 files changed, 188 insertions, 2 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index c73c13a..6eba3e8 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -1,25 +1,31 @@
from datetime import datetime
+from emotes_v2 import EmotesState
from faster_whisper import WhisperModel
from functools import partial
+from profanity_filter import ProfanityFilter
from pydub import AudioSegment
-from whisper.normalizers import EnglishTextNormalizer
from scipy.optimize import minimize
-from emotes_v2 import EmotesState
+from sentence_splitter import split_text_into_sentences
+from whisper.normalizers import EnglishTextNormalizer
import app_config
import argparse
+import ctranslate2
import editdistance
import keybind_event_machine
import langcodes
+import lang_compat
import math
import numpy as np
import os
import osc_ctrl
import pyaudio
import steamvr
+import subprocess
import sys
import threading
import time
+import transformers
import typing
import vad
import wave
@@ -524,6 +530,172 @@ class VadCommitter:
latency_s,
audio=audio)
+class Plugin:
+ def __init__(self):
+ pass
+
+ def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+ return commit
+
+ def stop():
+ pass
+
+class TranslationPlugin(Plugin):
+ def __init__(self, cfg):
+ lang_bits = cfg["language_target"].split(" | ")
+ self.cfg = cfg
+ self.language_target = None
+ self.translator = None
+ self.tokenizer = None
+ if len(lang_bits) != 2:
+ return
+ self.language_target = lang_bits[1]
+
+ print("Translation requested", file=sys.stderr)
+ print("Installing torch and sentencepiece in virtual environment. "
+ "Nothing will print "
+ "for several minutes while these download (~2.4 GB).",
+ file=sys.stderr)
+
+ pip_proc = subprocess.Popen(
+ "Resources/Python/python.exe -m pip install sentencepiece torch".split(),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ pip_stdout, pip_stderr = pip_proc.communicate()
+ pip_stdout = pip_stdout.decode("utf-8")
+ pip_stderr = pip_stderr.decode("utf-8")
+ print(pip_stdout, file=sys.stderr)
+ print(pip_stderr, file=sys.stderr)
+ if pip_proc.returncode != 0:
+ print(f"Failed to set up for translation: `pip install torch` "
+ "exited with {pip_proc.returncode}", file=sys.stderr)
+
+ output_dir = "Resources/" + cfg["model_translation"]
+ # Provided by ctranslate2 Python package
+ cmd = "ct2-transformers-converter.exe --model facebook/" + \
+ cfg["model_translation"] + " --output_dir " + output_dir
+
+ print(f"Fetching translation algorithm ({cfg['model_translation']})",
+ file=sys.stderr)
+ if not os.path.exists(output_dir):
+ ct2_proc = subprocess.Popen(
+ cmd.split(),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ ct2_stdout, ct2_stderr = ct2_proc.communicate()
+ ct2_stdout = ct2_stdout.decode("utf-8")
+ ct2_stderr = ct2_stderr.decode("utf-8")
+ print(ct2_stdout, file=sys.stderr)
+ print(ct2_stderr, file=sys.stderr)
+ if ct2_proc.returncode != 0:
+ print(f"Failed to get NLLB model: ct2 process exited with "
+ "{ct2_proc.returncode}", file=sys.stderr)
+ print(f"Using model at {output_dir}", file=sys.stderr)
+
+ self.translator = ctranslate2.Translator(output_dir)
+
+ whisper_lang = cfg["language"]
+ nllb_lang = lang_compat.whisper_to_nllb[whisper_lang]
+
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+ "facebook/" + cfg["model_translation"],
+ src_lang=nllb_lang)
+
+ print(f"Translation ready to go", file=sys.stderr)
+
+ def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+ if not self.language_target:
+ return commit
+
+ def _translate_text(text: str) -> str:
+
+ whisper_lang = self.cfg["language"]
+ nllb_lang = lang_compat.whisper_to_nllb[whisper_lang]
+ ss_lang = lang_compat.nllb_to_ss[nllb_lang]
+ sentences = split_text_into_sentences(text, language=ss_lang)
+
+ translated_sentences = []
+ for sentence in sentences:
+ source = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(sentence))
+ target_prefix = [self.language_target]
+ results = self.translator.translate_batch([source], target_prefix=[target_prefix])
+ target = results[0].hypotheses[0][1:]
+ translated_sentence = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(target))
+ translated_sentences.append(translated_sentence)
+ translated = " ".join(translated_sentences)
+ return translated
+
+ commit.delta = _translate_text(commit.delta)
+ commit.preview = _translate_text(commit.preview)
+ return commit
+
+class LowercasePlugin(Plugin):
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+ if self.cfg["enable_lowercase_filter"]:
+ commit.delta = commit.delta.lower()
+ commit.preview = commit.preview.lower()
+ return commit
+
+class UppercasePlugin(Plugin):
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+ if self.cfg["enable_uppercase_filter"]:
+ commit.delta = commit.delta.upper()
+ commit.preview = commit.preview.upper()
+ return commit
+
+class TrailingPeriodPlugin(Plugin):
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+ if self.cfg["remove_trailing_period"]:
+ def _remove_trailing_period(s: str) -> str:
+ if len(s) > 0 and s[-1] == '.' and not s.endswith("..."):
+ s = s[0:len(s)-1]
+ return s
+ commit.delta = _remove_trailing_period(commit.delta)
+ commit.preview = _remove_trailing_period(commit.preview)
+ return commit
+
+class UwuPlugin(Plugin):
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+ if self.cfg["enable_uwu_filter"]:
+ def _to_uwu(s: str) -> str:
+ uwu_proc = subprocess.Popen(["Resources/Uwu/Uwwwu.exe", s],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ uwu_stdout, uwu_stderr = uwu_proc.communicate()
+ uwu_text = uwu_stdout.decode("utf-8")
+ uwu_text = uwu_text.replace("\n", "")
+ uwu_text = uwu_text.replace("\r", "")
+ return uwu_text
+ commit.delta = _to_uwu(commit.delta)
+ commit.preview = _to_uwu(commit.preview)
+ return commit
+
+class ProfanityPlugin(Plugin):
+ def __init__(self, cfg):
+ self.cfg = cfg
+ en_profanity_path = os.path.abspath("Resources/Profanity/en")
+ self.filter = ProfanityFilter(en_profanity_path)
+ if cfg["enable_profanity_filter"]:
+ self.filter.load()
+
+ def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+ if self.cfg["enable_profanity_filter"]:
+ commit.delta = self.filter.filter(commit.delta)
+ commit.preview = self.filter.filter(commit.preview)
+ return commit
+
class OscPager:
def __init__(self, cfg):
self.osc_state = osc_ctrl.OscState(cfg["chars_per_sync"],
@@ -678,6 +850,9 @@ def transcriptionThread(ctrl: ThreadControl):
commit = ctrl.committer.getDelta()
+ for plugin in ctrl.plugins:
+ commit = plugin.transform(commit)
+
if len(commit.delta) > 0 or len(commit.preview) > 0:
# Hard-cap displayed transcript length at 4k characters to prevent
# runaway memory use in UI. Keep the full transcript to avoid
@@ -702,6 +877,8 @@ def transcriptionThread(ctrl: ThreadControl):
ctrl.transcript += commit.delta
ctrl.preview = ctrl.transcript + commit.preview
+ for plugin in ctrl.plugins:
+ plugin.stop()
def vrInputThread(ctrl: ThreadControl):
RECORD_STATE = 0
@@ -936,6 +1113,13 @@ def run(cfg):
ctrl.collector = collector
ctrl.whisper = whisper
ctrl.committer = committer
+ ctrl.plugins = []
+ ctrl.plugins.append(TranslationPlugin(cfg))
+ ctrl.plugins.append(UppercasePlugin(cfg))
+ ctrl.plugins.append(LowercasePlugin(cfg))
+ ctrl.plugins.append(TrailingPeriodPlugin(cfg))
+ ctrl.plugins.append(ProfanityPlugin(cfg))
+ ctrl.plugins.append(UwuPlugin(cfg))
ctrl.pager = pager
ctrl.transcript = ""
ctrl.preview = ""
@@ -973,6 +1157,8 @@ def run(cfg):
print("Done", file=sys.stderr)
if __name__ == "__main__":
+ sys.stdout.reconfigure(encoding="utf-8")
+
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, help="Path to app config YAML file.")
args = parser.parse_args()