summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2023-09-10 22:51:16 -0700
committeryum <yum.food.vr@gmail.com>2023-09-10 22:52:52 -0700
commitd3c325c4c4dd954a75267b013f33f5f3c5d041bc (patch)
treeb9ca3b87e14a1ccd5b9d6c525937121e6d9c34ab
parent920d6dfeeac132488c85311512fe9e5da505c4a8 (diff)
Introduce notion of PresentationFilter
... and restructure RemoveTrailingPeriod as a filter instead of as a plugin. Plugins have the power to change transcription data as it comes along, but don't have access to the entire transcript. Filters have access to the entire transcript but can't durably change it. TODO * This does not work with data passed through OSC
-rw-r--r--Scripts/transcribe_v2.py108
1 files changed, 68 insertions, 40 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 87b88af..6377ff4 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -532,7 +532,23 @@ class VadCommitter:
latency_s,
audio=audio)
-class Plugin:
+def install_in_venv(pkgs: typing.List[str]) -> bool:
+ pkgs_str = " ".join(pkgs)
+ print(f"Installing {pkgs_str}")
+ pip_proc = subprocess.Popen(
+ f"Resources/Python/python.exe -m pip install {pkgs_str}".split(),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ pip_stdout, pip_stderr = pip_proc.communicate()
+ pip_stdout = pip_stdout.decode("utf-8")
+ pip_stderr = pip_stderr.decode("utf-8")
+ print(pip_stdout, file=sys.stderr)
+ print(pip_stderr, file=sys.stderr)
+ if pip_proc.returncode != 0:
+ print(f"`pip install {pkgs_str}` exited with {pip_proc.returncode}",
+ file=sys.stderr)
+
+class StreamingPlugin:
def __init__(self):
pass
@@ -542,7 +558,7 @@ class Plugin:
def stop(self):
pass
-class TranslationPlugin(Plugin):
+class TranslationPlugin(StreamingPlugin):
def __init__(self, cfg):
lang_bits = cfg["language_target"].split(" | ")
self.cfg = cfg
@@ -554,23 +570,8 @@ class TranslationPlugin(Plugin):
self.language_target = lang_bits[1]
print("Translation requested", file=sys.stderr)
- print("Installing torch and sentencepiece in virtual environment. "
- "Nothing will print "
- "for several minutes while these download (~2.4 GB).",
- file=sys.stderr)
-
- pip_proc = subprocess.Popen(
- "Resources/Python/python.exe -m pip install sentencepiece torch".split(),
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
- pip_stdout, pip_stderr = pip_proc.communicate()
- pip_stdout = pip_stdout.decode("utf-8")
- pip_stderr = pip_stderr.decode("utf-8")
- print(pip_stdout, file=sys.stderr)
- print(pip_stderr, file=sys.stderr)
- if pip_proc.returncode != 0:
- print(f"Failed to set up for translation: `pip install torch` "
- "exited with {pip_proc.returncode}", file=sys.stderr)
+ if not install_in_venv(["torch", "sentencepiece"]):
+ return
output_dir = "Resources/" + cfg["model_translation"]
# Provided by ctranslate2 Python package
@@ -631,7 +632,7 @@ class TranslationPlugin(Plugin):
commit.preview = _translate_text(commit.preview)
return commit
-class LowercasePlugin(Plugin):
+class LowercasePlugin(StreamingPlugin):
def __init__(self, cfg):
self.cfg = cfg
@@ -641,7 +642,7 @@ class LowercasePlugin(Plugin):
commit.preview = commit.preview.lower()
return commit
-class UppercasePlugin(Plugin):
+class UppercasePlugin(StreamingPlugin):
def __init__(self, cfg):
self.cfg = cfg
@@ -651,21 +652,7 @@ class UppercasePlugin(Plugin):
commit.preview = commit.preview.upper()
return commit
-class TrailingPeriodPlugin(Plugin):
- def __init__(self, cfg):
- self.cfg = cfg
-
- def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
- if self.cfg["remove_trailing_period"]:
- def _remove_trailing_period(s: str) -> str:
- if len(s) > 0 and s[-1] == '.' and not s.endswith("..."):
- s = s[0:len(s)-1]
- return s
- commit.delta = _remove_trailing_period(commit.delta)
- commit.preview = _remove_trailing_period(commit.preview)
- return commit
-
-class UwuPlugin(Plugin):
+class UwuPlugin(StreamingPlugin):
def __init__(self, cfg):
self.cfg = cfg
@@ -684,7 +671,7 @@ class UwuPlugin(Plugin):
commit.preview = _to_uwu(commit.preview)
return commit
-class ProfanityPlugin(Plugin):
+class ProfanityPlugin(StreamingPlugin):
def __init__(self, cfg):
self.cfg = cfg
en_profanity_path = os.path.abspath("Resources/Profanity/en")
@@ -698,6 +685,34 @@ class ProfanityPlugin(Plugin):
commit.preview = self.filter.filter(commit.preview)
return commit
+class PresentationFilter:
+ def __init__(self):
+ pass
+
+ def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]:
+ return transcript, preview
+
+ def stop(self):
+ pass
+
+class TrailingPeriodFilter(PresentationFilter):
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]:
+ if self.cfg["remove_trailing_period"]:
+ def _remove_trailing_period(s: str) -> str:
+ if len(s) > 0 and s[-1] == '.' and not s.endswith("..."):
+ s = s[0:len(s)-1]
+ return s
+ if len(preview) == 0:
+ print("here")
+ transcript = _remove_trailing_period(transcript)
+ else:
+ print("there")
+ preview = _remove_trailing_period(preview)
+ return transcript, preview
+
class OscPager:
def __init__(self, cfg):
self.osc_state = osc_ctrl.OscState(cfg["chars_per_sync"],
@@ -859,15 +874,22 @@ def transcriptionThread(ctrl: ThreadControl):
# Hard-cap displayed transcript length at 4k characters to prevent
# runaway memory use in UI. Keep the full transcript to avoid
# breaking OSC pager.
+ transcript = ctrl.transcript[-4096:] + commit.delta
+ preview = commit.preview
+
+ for filt in ctrl.filters:
+ transcript, preview = filt.transform(transcript, preview)
+
try:
- print(f"Transcript: {ctrl.transcript[-4096:]}{commit.delta}")
+ print(f"Transcript: {transcript}")
except UnicodeEncodeError:
print("Failed to encode transcript - discarding delta")
continue
try:
- print(f"Preview: {commit.preview}")
+ print(f"Preview: {preview}")
except UnicodeEncodeError:
print("Failed to encode preview - discarding")
+
if cfg["enable_debug_mode"]:
print(f"commit latency: {commit.latency_s}", file=sys.stderr)
print(f"commit thresh: {commit.thresh_at_commit}",
@@ -881,6 +903,8 @@ def transcriptionThread(ctrl: ThreadControl):
ctrl.preview = ctrl.transcript + commit.preview
for plugin in ctrl.plugins:
plugin.stop()
+ for filt in ctrl.filters:
+ filt.stop()
def vrInputThread(ctrl: ThreadControl):
RECORD_STATE = 0
@@ -1127,13 +1151,17 @@ def run(cfg):
ctrl.collector = collector
ctrl.whisper = whisper
ctrl.committer = committer
+
ctrl.plugins = []
ctrl.plugins.append(TranslationPlugin(cfg))
ctrl.plugins.append(UppercasePlugin(cfg))
ctrl.plugins.append(LowercasePlugin(cfg))
- ctrl.plugins.append(TrailingPeriodPlugin(cfg))
ctrl.plugins.append(ProfanityPlugin(cfg))
ctrl.plugins.append(UwuPlugin(cfg))
+
+ ctrl.filters = []
+ ctrl.filters.append(TrailingPeriodFilter(cfg))
+
ctrl.pager = pager
ctrl.transcript = ""
ctrl.preview = ""