summary | refs | log | tree | commit | diff | stats
path: root/Scripts/transcribe_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'Scripts/transcribe_v2.py')
-rw-r--r-- Scripts/transcribe_v2.py 108
1 files changed, 68 insertions, 40 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 87b88af..6377ff4 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -532,7 +532,23 @@ class VadCommitter:
latency_s,
audio=audio)
-class Plugin:
+def install_in_venv(pkgs: typing.List[str]) -> bool:
+    """Install `pkgs` with pip using the bundled interpreter.
+
+    Runs `Resources/Python/python.exe -m pip install <pkgs...>`, waits for it
+    to finish, and echoes pip's captured stdout and stderr to sys.stderr.
+
+    Args:
+        pkgs: Package names/specifiers passed to `pip install`.
+
+    Returns:
+        True if pip exited with status 0, False otherwise.
+    """
+    pkgs_str = " ".join(pkgs)
+    print(f"Installing {pkgs_str}")
+    # Pass an argument list instead of splitting a formatted string, so each
+    # entry of `pkgs` reaches pip as exactly one argument.
+    pip_proc = subprocess.run(
+        ["Resources/Python/python.exe", "-m", "pip", "install", *pkgs],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE)
+    # errors="replace": pip output is only echoed for diagnostics, so a stray
+    # non-UTF-8 byte must not abort the install with UnicodeDecodeError.
+    print(pip_proc.stdout.decode("utf-8", errors="replace"), file=sys.stderr)
+    print(pip_proc.stderr.decode("utf-8", errors="replace"), file=sys.stderr)
+    if pip_proc.returncode != 0:
+        print(f"`pip install {pkgs_str}` exited with {pip_proc.returncode}",
+              file=sys.stderr)
+        return False
+    # Callers test the result (`if not install_in_venv(...)`), so success
+    # must be reported explicitly rather than falling through to None.
+    return True
+
+class StreamingPlugin:
def __init__(self):
pass
@@ -542,7 +558,7 @@ class Plugin:
def stop(self):
pass
-class TranslationPlugin(Plugin):
+class TranslationPlugin(StreamingPlugin):
def __init__(self, cfg):
lang_bits = cfg["language_target"].split(" | ")
self.cfg = cfg
@@ -554,23 +570,8 @@ class TranslationPlugin(Plugin):
self.language_target = lang_bits[1]
print("Translation requested", file=sys.stderr)
- print("Installing torch and sentencepiece in virtual environment. "
- "Nothing will print "
- "for several minutes while these download (~2.4 GB).",
- file=sys.stderr)
-
- pip_proc = subprocess.Popen(
- "Resources/Python/python.exe -m pip install sentencepiece torch".split(),
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
- pip_stdout, pip_stderr = pip_proc.communicate()
- pip_stdout = pip_stdout.decode("utf-8")
- pip_stderr = pip_stderr.decode("utf-8")
- print(pip_stdout, file=sys.stderr)
- print(pip_stderr, file=sys.stderr)
- if pip_proc.returncode != 0:
- print(f"Failed to set up for translation: `pip install torch` "
- "exited with {pip_proc.returncode}", file=sys.stderr)
+ if not install_in_venv(["torch", "sentencepiece"]):
+ return
output_dir = "Resources/" + cfg["model_translation"]
# Provided by ctranslate2 Python package
@@ -631,7 +632,7 @@ class TranslationPlugin(Plugin):
commit.preview = _translate_text(commit.preview)
return commit
-class LowercasePlugin(Plugin):
+class LowercasePlugin(StreamingPlugin):
def __init__(self, cfg):
self.cfg = cfg
@@ -641,7 +642,7 @@ class LowercasePlugin(Plugin):
commit.preview = commit.preview.lower()
return commit
-class UppercasePlugin(Plugin):
+class UppercasePlugin(StreamingPlugin):
def __init__(self, cfg):
self.cfg = cfg
@@ -651,21 +652,7 @@ class UppercasePlugin(Plugin):
commit.preview = commit.preview.upper()
return commit
-class TrailingPeriodPlugin(Plugin):
- def __init__(self, cfg):
- self.cfg = cfg
-
- def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
- if self.cfg["remove_trailing_period"]:
- def _remove_trailing_period(s: str) -> str:
- if len(s) > 0 and s[-1] == '.' and not s.endswith("..."):
- s = s[0:len(s)-1]
- return s
- commit.delta = _remove_trailing_period(commit.delta)
- commit.preview = _remove_trailing_period(commit.preview)
- return commit
-
-class UwuPlugin(Plugin):
+class UwuPlugin(StreamingPlugin):
def __init__(self, cfg):
self.cfg = cfg
@@ -684,7 +671,7 @@ class UwuPlugin(Plugin):
commit.preview = _to_uwu(commit.preview)
return commit
-class ProfanityPlugin(Plugin):
+class ProfanityPlugin(StreamingPlugin):
def __init__(self, cfg):
self.cfg = cfg
en_profanity_path = os.path.abspath("Resources/Profanity/en")
@@ -698,6 +685,34 @@ class ProfanityPlugin(Plugin):
commit.preview = self.filter.filter(commit.preview)
return commit
+class PresentationFilter:
+    """Base class for filters applied to the transcript and preview text.
+
+    The base implementation is a no-op; subclasses override transform().
+    """
+    def __init__(self):
+        """Nothing to set up in the base class."""
+
+    def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]:
+        """Return the (transcript, preview) pair unchanged."""
+        unchanged = (transcript, preview)
+        return unchanged
+
+    def stop(self):
+        """Nothing to tear down in the base class."""
+
+class TrailingPeriodFilter(PresentationFilter):
+    """Filter that removes one trailing period from the text.
+
+    The period is removed from the preview when the preview is non-empty,
+    and from the transcript otherwise. Text ending in "..." is left alone.
+    """
+    def __init__(self, cfg):
+        # cfg is read on every transform() call via
+        # cfg["remove_trailing_period"].
+        self.cfg = cfg
+
+    def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]:
+        """Return (transcript, preview) with the trailing period stripped.
+
+        Both strings are returned unchanged unless
+        cfg["remove_trailing_period"] is truthy.
+        """
+        if not self.cfg["remove_trailing_period"]:
+            return transcript, preview
+
+        def _remove_trailing_period(s: str) -> str:
+            if s.endswith(".") and not s.endswith("..."):
+                return s[:-1]
+            return s
+
+        # No debug output here: anything printed to stdout would be
+        # interleaved with the "Transcript:"/"Preview:" lines that
+        # transcriptionThread prints to stdout.
+        if preview:
+            preview = _remove_trailing_period(preview)
+        else:
+            transcript = _remove_trailing_period(transcript)
+        return transcript, preview
+
class OscPager:
def __init__(self, cfg):
self.osc_state = osc_ctrl.OscState(cfg["chars_per_sync"],
@@ -859,15 +874,22 @@ def transcriptionThread(ctrl: ThreadControl):
# Hard-cap displayed transcript length at 4k characters to prevent
# runaway memory use in UI. Keep the full transcript to avoid
# breaking OSC pager.
+ transcript = ctrl.transcript[-4096:] + commit.delta
+ preview = commit.preview
+
+ for filt in ctrl.filters:
+ transcript, preview = filt.transform(transcript, preview)
+
try:
- print(f"Transcript: {ctrl.transcript[-4096:]}{commit.delta}")
+ print(f"Transcript: {transcript}")
except UnicodeEncodeError:
print("Failed to encode transcript - discarding delta")
continue
try:
- print(f"Preview: {commit.preview}")
+ print(f"Preview: {preview}")
except UnicodeEncodeError:
print("Failed to encode preview - discarding")
+
if cfg["enable_debug_mode"]:
print(f"commit latency: {commit.latency_s}", file=sys.stderr)
print(f"commit thresh: {commit.thresh_at_commit}",
@@ -881,6 +903,8 @@ def transcriptionThread(ctrl: ThreadControl):
ctrl.preview = ctrl.transcript + commit.preview
for plugin in ctrl.plugins:
plugin.stop()
+ for filt in ctrl.filters:
+ filt.stop()
def vrInputThread(ctrl: ThreadControl):
RECORD_STATE = 0
@@ -1127,13 +1151,17 @@ def run(cfg):
ctrl.collector = collector
ctrl.whisper = whisper
ctrl.committer = committer
+
ctrl.plugins = []
ctrl.plugins.append(TranslationPlugin(cfg))
ctrl.plugins.append(UppercasePlugin(cfg))
ctrl.plugins.append(LowercasePlugin(cfg))
- ctrl.plugins.append(TrailingPeriodPlugin(cfg))
ctrl.plugins.append(ProfanityPlugin(cfg))
ctrl.plugins.append(UwuPlugin(cfg))
+
+ ctrl.filters = []
+ ctrl.filters.append(TrailingPeriodFilter(cfg))
+
ctrl.pager = pager
ctrl.transcript = ""
ctrl.preview = ""