diff options
| author | yum <yum.food.vr@gmail.com> | 2023-09-10 22:51:16 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-09-10 22:52:52 -0700 |
| commit | d3c325c4c4dd954a75267b013f33f5f3c5d041bc (patch) | |
| tree | b9ca3b87e14a1ccd5b9d6c525937121e6d9c34ab | |
| parent | 920d6dfeeac132488c85311512fe9e5da505c4a8 (diff) | |
Introduce notion of PresentationFilter
... and restructure trailing-period removal (previously TrailingPeriodPlugin)
as a filter (TrailingPeriodFilter) instead of as a plugin.
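Plugins (StreamingPlugin) have the power to change transcription data as it
comes along, one commit at a time, but don't have access to the entire
transcript. Filters (PresentationFilter) have access to the transcript as it is
displayed (the last 4096 characters plus the new delta) and to the preview, but
can't durably change it: their output is only used for what gets printed.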
TODO:
* Filters do not yet work with data passed through OSC.
| -rw-r--r-- | Scripts/transcribe_v2.py | 108 |
1 file changed, 68 insertions, 40 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 87b88af..6377ff4 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -532,7 +532,23 @@ class VadCommitter: latency_s, audio=audio) -class Plugin: +def install_in_venv(pkgs: typing.List[str]) -> bool: + pkgs_str = " ".join(pkgs) + print(f"Installing {pkgs_str}") + pip_proc = subprocess.Popen( + f"Resources/Python/python.exe -m pip install {pkgs_str}".split(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + pip_stdout, pip_stderr = pip_proc.communicate() + pip_stdout = pip_stdout.decode("utf-8") + pip_stderr = pip_stderr.decode("utf-8") + print(pip_stdout, file=sys.stderr) + print(pip_stderr, file=sys.stderr) + if pip_proc.returncode != 0: + print(f"`pip install {pkgs_str}` exited with {pip_proc.returncode}", + file=sys.stderr) + +class StreamingPlugin: def __init__(self): pass @@ -542,7 +558,7 @@ class Plugin: def stop(self): pass -class TranslationPlugin(Plugin): +class TranslationPlugin(StreamingPlugin): def __init__(self, cfg): lang_bits = cfg["language_target"].split(" | ") self.cfg = cfg @@ -554,23 +570,8 @@ class TranslationPlugin(Plugin): self.language_target = lang_bits[1] print("Translation requested", file=sys.stderr) - print("Installing torch and sentencepiece in virtual environment. 
" - "Nothing will print " - "for several minutes while these download (~2.4 GB).", - file=sys.stderr) - - pip_proc = subprocess.Popen( - "Resources/Python/python.exe -m pip install sentencepiece torch".split(), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - pip_stdout, pip_stderr = pip_proc.communicate() - pip_stdout = pip_stdout.decode("utf-8") - pip_stderr = pip_stderr.decode("utf-8") - print(pip_stdout, file=sys.stderr) - print(pip_stderr, file=sys.stderr) - if pip_proc.returncode != 0: - print(f"Failed to set up for translation: `pip install torch` " - "exited with {pip_proc.returncode}", file=sys.stderr) + if not install_in_venv(["torch", "sentencepiece"]): + return output_dir = "Resources/" + cfg["model_translation"] # Provided by ctranslate2 Python package @@ -631,7 +632,7 @@ class TranslationPlugin(Plugin): commit.preview = _translate_text(commit.preview) return commit -class LowercasePlugin(Plugin): +class LowercasePlugin(StreamingPlugin): def __init__(self, cfg): self.cfg = cfg @@ -641,7 +642,7 @@ class LowercasePlugin(Plugin): commit.preview = commit.preview.lower() return commit -class UppercasePlugin(Plugin): +class UppercasePlugin(StreamingPlugin): def __init__(self, cfg): self.cfg = cfg @@ -651,21 +652,7 @@ class UppercasePlugin(Plugin): commit.preview = commit.preview.upper() return commit -class TrailingPeriodPlugin(Plugin): - def __init__(self, cfg): - self.cfg = cfg - - def transform(self, commit: TranscriptCommit) -> TranscriptCommit: - if self.cfg["remove_trailing_period"]: - def _remove_trailing_period(s: str) -> str: - if len(s) > 0 and s[-1] == '.' 
and not s.endswith("..."): - s = s[0:len(s)-1] - return s - commit.delta = _remove_trailing_period(commit.delta) - commit.preview = _remove_trailing_period(commit.preview) - return commit - -class UwuPlugin(Plugin): +class UwuPlugin(StreamingPlugin): def __init__(self, cfg): self.cfg = cfg @@ -684,7 +671,7 @@ class UwuPlugin(Plugin): commit.preview = _to_uwu(commit.preview) return commit -class ProfanityPlugin(Plugin): +class ProfanityPlugin(StreamingPlugin): def __init__(self, cfg): self.cfg = cfg en_profanity_path = os.path.abspath("Resources/Profanity/en") @@ -698,6 +685,34 @@ class ProfanityPlugin(Plugin): commit.preview = self.filter.filter(commit.preview) return commit +class PresentationFilter: + def __init__(self): + pass + + def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]: + return transcript, preview + + def stop(self): + pass + +class TrailingPeriodFilter(PresentationFilter): + def __init__(self, cfg): + self.cfg = cfg + + def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]: + if self.cfg["remove_trailing_period"]: + def _remove_trailing_period(s: str) -> str: + if len(s) > 0 and s[-1] == '.' and not s.endswith("..."): + s = s[0:len(s)-1] + return s + if len(preview) == 0: + print("here") + transcript = _remove_trailing_period(transcript) + else: + print("there") + preview = _remove_trailing_period(preview) + return transcript, preview + class OscPager: def __init__(self, cfg): self.osc_state = osc_ctrl.OscState(cfg["chars_per_sync"], @@ -859,15 +874,22 @@ def transcriptionThread(ctrl: ThreadControl): # Hard-cap displayed transcript length at 4k characters to prevent # runaway memory use in UI. Keep the full transcript to avoid # breaking OSC pager. 
+ transcript = ctrl.transcript[-4096:] + commit.delta + preview = commit.preview + + for filt in ctrl.filters: + transcript, preview = filt.transform(transcript, preview) + try: - print(f"Transcript: {ctrl.transcript[-4096:]}{commit.delta}") + print(f"Transcript: {transcript}") except UnicodeEncodeError: print("Failed to encode transcript - discarding delta") continue try: - print(f"Preview: {commit.preview}") + print(f"Preview: {preview}") except UnicodeEncodeError: print("Failed to encode preview - discarding") + if cfg["enable_debug_mode"]: print(f"commit latency: {commit.latency_s}", file=sys.stderr) print(f"commit thresh: {commit.thresh_at_commit}", @@ -881,6 +903,8 @@ def transcriptionThread(ctrl: ThreadControl): ctrl.preview = ctrl.transcript + commit.preview for plugin in ctrl.plugins: plugin.stop() + for filt in ctrl.filters: + filt.stop() def vrInputThread(ctrl: ThreadControl): RECORD_STATE = 0 @@ -1127,13 +1151,17 @@ def run(cfg): ctrl.collector = collector ctrl.whisper = whisper ctrl.committer = committer + ctrl.plugins = [] ctrl.plugins.append(TranslationPlugin(cfg)) ctrl.plugins.append(UppercasePlugin(cfg)) ctrl.plugins.append(LowercasePlugin(cfg)) - ctrl.plugins.append(TrailingPeriodPlugin(cfg)) ctrl.plugins.append(ProfanityPlugin(cfg)) ctrl.plugins.append(UwuPlugin(cfg)) + + ctrl.filters = [] + ctrl.filters.append(TrailingPeriodFilter(cfg)) + ctrl.pager = pager ctrl.transcript = "" ctrl.preview = "" |
