diff options
Diffstat (limited to 'Scripts')
| -rw-r--r-- | Scripts/transcribe_v2.py | 108 |
1 files changed, 68 insertions, 40 deletions
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 87b88af..6377ff4 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -532,7 +532,23 @@ class VadCommitter: latency_s, audio=audio) -class Plugin: +def install_in_venv(pkgs: typing.List[str]) -> bool: + pkgs_str = " ".join(pkgs) + print(f"Installing {pkgs_str}") + pip_proc = subprocess.Popen( + f"Resources/Python/python.exe -m pip install {pkgs_str}".split(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + pip_stdout, pip_stderr = pip_proc.communicate() + pip_stdout = pip_stdout.decode("utf-8") + pip_stderr = pip_stderr.decode("utf-8") + print(pip_stdout, file=sys.stderr) + print(pip_stderr, file=sys.stderr) + if pip_proc.returncode != 0: + print(f"`pip install {pkgs_str}` exited with {pip_proc.returncode}", + file=sys.stderr) + +class StreamingPlugin: def __init__(self): pass @@ -542,7 +558,7 @@ class Plugin: def stop(self): pass -class TranslationPlugin(Plugin): +class TranslationPlugin(StreamingPlugin): def __init__(self, cfg): lang_bits = cfg["language_target"].split(" | ") self.cfg = cfg @@ -554,23 +570,8 @@ class TranslationPlugin(Plugin): self.language_target = lang_bits[1] print("Translation requested", file=sys.stderr) - print("Installing torch and sentencepiece in virtual environment. " - "Nothing will print " - "for several minutes while these download (~2.4 GB).", - file=sys.stderr) - - pip_proc = subprocess.Popen( - "Resources/Python/python.exe -m pip install sentencepiece torch".split(), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - pip_stdout, pip_stderr = pip_proc.communicate() - pip_stdout = pip_stdout.decode("utf-8") - pip_stderr = pip_stderr.decode("utf-8") - print(pip_stdout, file=sys.stderr) - print(pip_stderr, file=sys.stderr) - if pip_proc.returncode != 0: - print(f"Failed to set up for translation: `pip install torch` " - "exited with {pip_proc.returncode}", file=sys.stderr) + if not install_in_venv(["torch", "sentencepiece"]): + return output_dir = "Resources/" + cfg["model_translation"] # Provided by ctranslate2 Python package @@ -631,7 +632,7 @@ class TranslationPlugin(Plugin): commit.preview = _translate_text(commit.preview) return commit -class LowercasePlugin(Plugin): +class LowercasePlugin(StreamingPlugin): def __init__(self, cfg): self.cfg = cfg @@ -641,7 +642,7 @@ class LowercasePlugin(Plugin): commit.preview = commit.preview.lower() return commit -class UppercasePlugin(Plugin): +class UppercasePlugin(StreamingPlugin): def __init__(self, cfg): self.cfg = cfg @@ -651,21 +652,7 @@ class UppercasePlugin(Plugin): commit.preview = commit.preview.upper() return commit -class TrailingPeriodPlugin(Plugin): - def __init__(self, cfg): - self.cfg = cfg - - def transform(self, commit: TranscriptCommit) -> TranscriptCommit: - if self.cfg["remove_trailing_period"]: - def _remove_trailing_period(s: str) -> str: - if len(s) > 0 and s[-1] == '.' and not s.endswith("..."): - s = s[0:len(s)-1] - return s - commit.delta = _remove_trailing_period(commit.delta) - commit.preview = _remove_trailing_period(commit.preview) - return commit - -class UwuPlugin(Plugin): +class UwuPlugin(StreamingPlugin): def __init__(self, cfg): self.cfg = cfg @@ -684,7 +671,7 @@ class UwuPlugin(Plugin): commit.preview = _to_uwu(commit.preview) return commit -class ProfanityPlugin(Plugin): +class ProfanityPlugin(StreamingPlugin): def __init__(self, cfg): self.cfg = cfg en_profanity_path = os.path.abspath("Resources/Profanity/en") @@ -698,6 +685,34 @@ class ProfanityPlugin(Plugin): commit.preview = self.filter.filter(commit.preview) return commit +class PresentationFilter: + def __init__(self): + pass + + def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]: + return transcript, preview + + def stop(self): + pass + +class TrailingPeriodFilter(PresentationFilter): + def __init__(self, cfg): + self.cfg = cfg + + def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]: + if self.cfg["remove_trailing_period"]: + def _remove_trailing_period(s: str) -> str: + if len(s) > 0 and s[-1] == '.' and not s.endswith("..."): + s = s[0:len(s)-1] + return s + if len(preview) == 0: + print("here") + transcript = _remove_trailing_period(transcript) + else: + print("there") + preview = _remove_trailing_period(preview) + return transcript, preview + class OscPager: def __init__(self, cfg): self.osc_state = osc_ctrl.OscState(cfg["chars_per_sync"], @@ -859,15 +874,22 @@ def transcriptionThread(ctrl: ThreadControl): # Hard-cap displayed transcript length at 4k characters to prevent # runaway memory use in UI. Keep the full transcript to avoid # breaking OSC pager. + transcript = ctrl.transcript[-4096:] + commit.delta + preview = commit.preview + + for filt in ctrl.filters: + transcript, preview = filt.transform(transcript, preview) + try: - print(f"Transcript: {ctrl.transcript[-4096:]}{commit.delta}") + print(f"Transcript: {transcript}") except UnicodeEncodeError: print("Failed to encode transcript - discarding delta") continue try: - print(f"Preview: {commit.preview}") + print(f"Preview: {preview}") except UnicodeEncodeError: print("Failed to encode preview - discarding") + if cfg["enable_debug_mode"]: print(f"commit latency: {commit.latency_s}", file=sys.stderr) print(f"commit thresh: {commit.thresh_at_commit}", @@ -881,6 +903,8 @@ def transcriptionThread(ctrl: ThreadControl): ctrl.preview = ctrl.transcript + commit.preview for plugin in ctrl.plugins: plugin.stop() + for filt in ctrl.filters: + filt.stop() def vrInputThread(ctrl: ThreadControl): RECORD_STATE = 0 @@ -1127,13 +1151,17 @@ def run(cfg): ctrl.collector = collector ctrl.whisper = whisper ctrl.committer = committer + ctrl.plugins = [] ctrl.plugins.append(TranslationPlugin(cfg)) ctrl.plugins.append(UppercasePlugin(cfg)) ctrl.plugins.append(LowercasePlugin(cfg)) - ctrl.plugins.append(TrailingPeriodPlugin(cfg)) ctrl.plugins.append(ProfanityPlugin(cfg)) ctrl.plugins.append(UwuPlugin(cfg)) + + ctrl.filters = [] + ctrl.filters.append(TrailingPeriodFilter(cfg)) + ctrl.pager = pager ctrl.transcript = "" ctrl.preview = "" |
