From 0c54e1fc74fe7677a0d4fef1c147c6e886d182db Mon Sep 17 00:00:00 2001 From: yum Date: Sun, 11 May 2025 22:22:48 -0700 Subject: code bomb --- GenerateTextAnimator.cs | 200 ++++++++++++++++++ README.md | 282 ++++++++++++++++++++++++++ bpe_dump.py | 61 ++++++ generate_bpe_lut.py | 119 +++++++++++ generate_tokenizer.py | 525 ++++++++++++++++++++++++++++++++++++++++++++++++ hi.py | 357 ++++++++++++++++++++++++++++++++ requirements.txt | 6 + tokenize_me.py | 28 +++ 8 files changed, 1578 insertions(+) create mode 100644 GenerateTextAnimator.cs create mode 100644 README.md create mode 100644 bpe_dump.py create mode 100644 generate_bpe_lut.py create mode 100644 generate_tokenizer.py create mode 100644 hi.py create mode 100644 requirements.txt create mode 100644 tokenize_me.py diff --git a/GenerateTextAnimator.cs b/GenerateTextAnimator.cs new file mode 100644 index 0000000..fb4a1ee --- /dev/null +++ b/GenerateTextAnimator.cs @@ -0,0 +1,200 @@ +#if UNITY_EDITOR +using AnimatorAsCode.V1; +using AnimatorAsCode.V1.ModularAvatar; +using YumTools; +using nadena.dev.ndmf; +using System.Collections.Generic; +using UnityEditor; +using UnityEditor.Animations; +using UnityEngine; +using VRC.SDK3.Avatars.Components; +using VRC.SDKBase; + +// This example uses NDMF. See https://github.com/bdunderscore/ndmf?tab=readme-ov-file#getting-started +[assembly: ExportsPlugin(typeof(GenerateTextAnimatorPlugin))] +namespace YumTools +{ + public class GenerateTextAnimator : MonoBehaviour, IEditorOnly + { + // The number of blocks addressable. + public int numBlocks = 10; + // The number of datums sent per block. + public int blockWidth = 5; + // The number of bytes per datum. + public int bytesPerDatum = 2; + + public string oscPointerParam = "_Unigram_Letter_Grid_OSC_Pointer"; + + // Data sent through OSC uses this pattern. + public string[] oscParamPerBlockDatumByte = {"_Unigram_Letter_Grid_OSC_Datum{0:00}_Byte{1:00}"}; + // The pattern of the parameters which this script will animate. 
+ public string[] matPropPerBlockDatumByte = {"_Unigram_Letter_Grid_Data_Block{0:00}_Datum{1:00}_Byte{2:00}_Animated"}; + + public string[] oscParamPerBlockDatum = {}; + public string[] matPropPerBlockDatum = {}; + + public string[] oscParamPerBlock = {"_Unigram_Letter_Grid_OSC_Visual_Pointer"}; + public string[] matPropPerBlock = {"_Unigram_Letter_Grid_Block_{0:00}_Visual_Pointer_Animated"}; + } + + public class GenerateTextAnimatorPlugin : Plugin + { + public override string QualifiedName => "yum.generate_chat_animator"; + public override string DisplayName => "Chat Animator"; + + private const string SystemName = "GenerateTextAnimator"; + // Direct blendtrees have special semantics with write defaults. We want + // them on. They will not fuck up the rest of the animator, whether it uses + // write defaults or not. + private const bool UseWriteDefaults = true; + + protected override void Configure() + { + InPhase(BuildPhase.Generating).Run($"Generate {DisplayName}", Generate); + } + + private AacFlBlendTreeDirect Add8BitBlendTree( + AacFlBase aac, + AacFlLayer layer, + GenerateTextAnimator cfg, + AacFlBlendTreeDirect tree, + string matProp, string oscParam, string oneParam) { + var offAnim = aac.NewClip().Animating(clip => + { + clip.Animates(cfg.GetComponent(), "material." + matProp).WithFrameCountUnit(keyframes => + keyframes.Constant(/*when=*/0, /*value=*/0)); + }); + var onAnim = aac.NewClip().Animating(clip => + { + clip.Animates(cfg.GetComponent(), "material." 
+ matProp).WithFrameCountUnit(keyframes => + keyframes.Constant(/*when=*/0, /*value=*/255)); + }); + var subtree = aac.NewBlendTree().Simple1D(layer.FloatParameter(oscParam)) + .WithAnimation(offAnim, /*threshold=*/-1) + .WithAnimation(onAnim, /*threshold=*/ 1); + return tree.WithAnimation(subtree, layer.FloatParameter("AlwaysOne")); + } + + private void Generate(BuildContext ctx) + { + var components = ctx.AvatarRootTransform.GetComponentsInChildren(true); + if (components.Length == 0) return; + + var aac = AacV1.Create(new AacConfiguration + { + SystemName = SystemName, + AnimatorRoot = ctx.AvatarRootTransform, + DefaultValueRoot = ctx.AvatarRootTransform, + AssetKey = GUID.Generate().ToString(), + AssetContainer = ctx.AssetContainer, + ContainerMode = AacConfiguration.Container.OnlyWhenPersistenceRequired, + AssetContainerProvider = new NDMFContainerProvider(ctx), + DefaultsProvider = new AacDefaultsProvider(UseWriteDefaults) + }); + + // Create a new object in the scene. We will add Modular Avatar components inside it. + var modularAvatar = MaAc.Create(new GameObject(SystemName) + { + transform = { parent = ctx.AvatarRootTransform } + }); + + var ctrl = aac.NewAnimatorController(); + for (int component_i = 0; component_i < components.Length; component_i++) { + GenerateTextAnimator cfg = components[component_i]; + var layer = ctrl.NewLayer($"Chatbox Plumbing (component #{component_i})"); + + var baseState = layer.NewState("Entry (noop)").WithWriteDefaultsSetTo(false); + + // Create "always one" value to use in DBT. + // See https://vrc.school/docs/Other/DBT-Combining/ for details. + layer.OverrideValue(layer.FloatParameter("AlwaysOne"), 1.0f); + + // Create blendtrees. One for each block of data. + var block_trees = new List(); + AacFlState last_block_state = null; + // For each block. + for (int i = 0; i < cfg.numBlocks; i++) { + var block_tree = aac.NewBlendTree().Direct(); + + // Create block-level animations. 
+ for (int ii = 0; ii < cfg.oscParamPerBlock.Length; ii++) { + string matPropB = string.Format(cfg.matPropPerBlock[ii], i); + string oscParamB = cfg.oscParamPerBlock[ii]; + block_tree = Add8BitBlendTree(aac, layer, cfg, block_tree, matPropB, oscParamB, "AlwaysOne"); + } + + // For each datum per block. + for (int j = 0; j < cfg.blockWidth; j++) { + // Create (block, datum)-level animations. + for (int ii = 0; ii < cfg.oscParamPerBlockDatum.Length; ii++) { + string matPropBD = string.Format(cfg.matPropPerBlockDatum[ii], i, j); + string oscParamBD = string.Format(cfg.oscParamPerBlockDatum[ii], j); + block_tree = Add8BitBlendTree(aac, layer, cfg, block_tree, matPropBD, oscParamBD, "AlwaysOne"); + } + + // For each byte per datum. + for (int k = 0; k < cfg.bytesPerDatum; k++) { + // Create (block, datum, byte)-level animations. + for (int ii = 0; ii < cfg.oscParamPerBlockDatumByte.Length; ii++) { + string matPropBDB = string.Format(cfg.matPropPerBlockDatumByte[ii], i, j, k); + //Debug.Log($"animating property: {matPropBDB}"); + string oscParamBDB = string.Format(cfg.oscParamPerBlockDatumByte[ii], j, k); + block_tree = Add8BitBlendTree(aac, layer, cfg, block_tree, matPropBDB, oscParamBDB, "AlwaysOne"); + } + } + } + + var cur_block_state = layer.NewState($"Block {i}").WithAnimation(block_tree); + if (last_block_state != null) { + cur_block_state = cur_block_state.Under(last_block_state); + } + last_block_state = cur_block_state; + block_trees.Add(cur_block_state); + } + + // Create transitions to each block's blendtree. 
+ for (int i = 0; i < cfg.numBlocks; i++) { + block_trees[i].TransitionsFromAny() + //.WithInterruption(TransitionInterruptionSource.Source) + .When(layer.IntParameter(cfg.oscPointerParam).IsEqualTo(i)); + } + + // Create sync params (VRCSDK) + for (int ii = 0; ii < cfg.oscParamPerBlock.Length; ii++) { + string bParam = cfg.oscParamPerBlock[ii]; + modularAvatar.NewParameter(layer.FloatParameter(bParam)); + } + for (int i = 0; i < cfg.blockWidth; i++) { + for (int ii = 0; ii < cfg.oscParamPerBlockDatum.Length; ii++) { + string bdParam = string.Format(cfg.oscParamPerBlockDatum[ii], i); + modularAvatar.NewParameter(layer.FloatParameter(bdParam)); + } + for (int j = 0; j < cfg.bytesPerDatum; j++) { + for (int ii = 0; ii < cfg.oscParamPerBlockDatumByte.Length; ii++) { + string bdbParam = string.Format(cfg.oscParamPerBlockDatumByte[ii], i, j); + modularAvatar.NewParameter(layer.FloatParameter(bdbParam)).WithDefaultValue(1.0f); + } + } + } + + modularAvatar.NewParameter(layer.IntParameter(cfg.oscPointerParam)); + } + + // By creating a Modular Avatar Merge Animator component, + // our animator controller will be added to the avatar's FX layer. + modularAvatar.NewMergeAnimator(ctrl.AnimatorController, VRCAvatarDescriptor.AnimLayerType.FX); + } + } + + // (For AAC 1.2.0 and above) This is recommended starting from NDMF 1.6.0. You only need to define this class once. 
+ internal class NDMFContainerProvider : IAacAssetContainerProvider + { + private readonly BuildContext _ctx; + public NDMFContainerProvider(BuildContext ctx) => _ctx = ctx; + public void SaveAsPersistenceRequired(Object objectToAdd) => _ctx.AssetSaver.SaveAsset(objectToAdd); + public void SaveAsRegular(Object objectToAdd) { } // Let NDMF crawl our assets when it finishes + public void ClearPreviousAssets() { } // ClearPreviousAssets is never used in non-destructive contexts + } +} +#endif + diff --git a/README.md b/README.md new file mode 100644 index 0000000..eaeceea --- /dev/null +++ b/README.md @@ -0,0 +1,282 @@ +# Optimized text paging for VRChat + +It is sometimes useful to send text data into VRChat, for example for +speech-to-text (STT). This is typically done naively, with a "block" of +n 8-bit characters\* sent in along with an 8-bit pointer. Since avatars can only +send 256 bits at 3 Hz\*\* with OSC, this means you can only send (256 - 8) / 8 = +31 characters per sync. The average English word is 4.79 characters long, so +if we naively send in 1 character per byte, then we get a speed limit of 6.47 +words per sync. This works out to ~20 words per second or ~1200 words per minute +(wpm). Adults typically read at about 238-260 wpm. + +\* Typically ASCII encoding is used. + +\*\* Experimentally, 3 Hz is the fastest you can reliably page data with OSC in +busy instances. + +Sending in one character per byte gives you (1200/256) ~= 4.7 wpm per OSC bit +used. Thus to reach a typical reading speed, you need to use (260/4.7) = 55.5 +OSC bits. The goal of this module is to get more out of these bits by +compressing text over the wire. + +## Unigram tokenizer + +Byte pair encoding (BPE) is an encoding scheme frequently used in natural +language processing (NLP) contexts. For any language with a fixed character set +(e.g. an alphabet), you can count up how often each character is used in a +large corpus of text. 
Then you can repeatedly join together characters and +assign a unique token to joined characters in the order of the most frequently +occurring sequences. You wind up with a lookup table like this: + +``` +0: [UNK] +1: <s> +2: </s> +3: [CLS] +4: [SEP] +... +100: from +101: but +102: he +103: e +104: now +... +10000: eo +10001: is currently +10002: dish +10003: Mi +10004: 6) +``` + +The above example is from a unigram\* sentencepiece tokenizer that I trained. + +\* A unigram tokenizer is a variant of the byte-pair tokenizer. Where a byte +pair tokenizer views inputs as arbitrary sequences of bytes, a unigram +tokenizer views them as sequences of letters. + +The tokenizer has a vocabulary size of 65,536 tokens. It was trained on +opensubtitles and 5% of wikipedia, with `unidecode` normalization applied to +limit training data to ASCII. Subword lengths are distributed as follows: + +Subword length histogram: +1: 95 +2: 1032 +3: 3350 +4: 5439 +5: 5445 +6: 5082 +7: 5334 +8: 5866 +9: 6172 +10: 5934 +11: 5329 +12: 4698 +13: 4016 +14: 3191 +15: 2434 +16: 2095 + +In the test set - 25% wikipedia 75% conversational English - +this tokenizer yields 4.896 characters per token. (Recall that +the average English word is 4.8 characters long. Not bad!) + +If we naively send these tokens into the game with a 16-bit number, we can send +floor((256-8)/16) = 15 tokens per sync. This gives us an average of 73.44 +characters per sync - more than 2x higher than the naive approach's 31. 
+ +Here is how the unigram-tokenized scheme fares against the naive scheme in +every possible configuration of bits used: + +``` +bits naive rate bpe rate speedup factor +8 n/a n/a n/a +16 1 n/a 0.000 +24 2 4.896 2.448 +32 3 4.896 1.632 +40 4 9.792 2.448 +48 5 9.792 1.958 +56 6 14.688 2.448 +64 7 14.688 2.098 +72 8 19.584 2.448 +80 9 19.584 2.176 +88 10 24.480 2.448 +96 11 24.480 2.225 +104 12 29.376 2.448 +112 13 29.376 2.260 +120 14 34.272 2.448 +128 15 34.272 2.285 +136 16 39.168 2.448 +144 17 39.168 2.304 +152 18 44.064 2.448 +160 19 44.064 2.319 +168 20 48.960 2.448 +176 21 48.960 2.331 +184 22 53.856 2.448 +192 23 53.856 2.342 +200 24 58.752 2.448 +208 25 58.752 2.350 +216 26 63.648 2.448 +224 27 63.648 2.357 +232 28 68.544 2.448 +240 29 68.544 2.364 +248 30 73.440 2.448 +256 31 73.440 2.369 +``` + +([Spreadsheet](https://docs.google.com/spreadsheets/d/1d9SEZvo3Q-6U_Wf9nuGRKXxndUhKn2V3Q0Ox0nOB4T4/edit?usp=sharing)) + +I reserve 39 token slots for sequences of whitespace characters of length 2-40. This helps simplify formatting. To end a line or position text, you can just send in the exact right number of spaces, and a fixed-width font renderer will position things as intended. + +## Paging data into shader + +Sending this data to a shader is pretty simple: + +- An OSC app encodes a string into tokens and pages it into the game with OSC. + - The app sends a pointer of *where* the tokens should be rendered along with them. Since the tokens can encode a variable length string, the pointer must be able to point to any spot in the rendering window. Thus we are limited to a 256-character display with an 8-bit pointer, or a 64K-character display with a 16-bit pointer. We call this the visual pointer. + - A second pointer tells the animator which shader parameters the tokens and visual pointer should be written to. This can be 8-bit. We call this the logical pointer. +- Animator uses the logical pointer to decide which shader parameters to send the visual pointer and tokens to. 
+ +Here is the expected speedup in every possible configuration with a 1-byte +overhead (1BO) or a 2-byte overhead (2BO): + +``` +bits naive rate bpe rate (1BO) speedup bpe rate (2BO) speedup +8 n/a n/a n/a n/a n/a +16 1 0.000 0.000 0.000 0.000 +24 2 0.000 0.000 0.000 0.000 +32 3 4.896 1.632 0.000 0.000 +40 4 4.896 1.224 4.896 1.224 +48 5 9.792 1.958 4.896 0.979 +56 6 9.792 1.632 9.792 1.632 +64 7 14.688 2.098 9.792 1.399 +72 8 14.688 1.836 14.688 1.836 +80 9 19.584 2.176 14.688 1.632 +88 10 19.584 1.958 19.584 1.958 +96 11 24.480 2.225 19.584 1.780 +104 12 24.480 2.040 24.480 2.040 +112 13 29.376 2.260 24.480 1.883 +120 14 29.376 2.098 29.376 2.098 +128 15 34.272 2.285 29.376 1.958 +136 16 34.272 2.142 34.272 2.142 +144 17 39.168 2.304 34.272 2.016 +152 18 39.168 2.176 39.168 2.176 +160 19 44.064 2.319 39.168 2.061 +168 20 44.064 2.203 44.064 2.203 +176 21 48.960 2.331 44.064 2.098 +184 22 48.960 2.225 48.960 2.225 +192 23 53.856 2.342 48.960 2.129 +200 24 53.856 2.244 53.856 2.244 +208 25 58.752 2.350 53.856 2.154 +216 26 58.752 2.260 58.752 2.260 +224 27 63.648 2.357 58.752 2.176 +232 28 63.648 2.273 63.648 2.273 +240 29 68.544 2.364 63.648 2.195 +248 30 68.544 2.285 68.544 2.285 +256 31 73.440 2.369 68.544 2.211 +``` + +([Spreadsheet](https://docs.google.com/spreadsheets/d/1d9SEZvo3Q-6U_Wf9nuGRKXxndUhKn2V3Q0Ox0nOB4T4/edit?usp=sharing)) + +As you can see, a 2-byte visual pointer is very damaging to the speedup at low bit budgets. So in bit-constrained setups we should definitely use a smaller display. + +Notably, *there is only one crossover point*. In all configurations large enough to hold at least one token, except the 2-byte overhead 48-bit configuration, BPE-based paging is always\* faster. + +\* Always, going off of the *expected* rate. If you get unlucky and your tokens all decode to 1 character, then BPE-based paging is about 50% *slower* than naive encoding. + +Because the Unity animator sucks shit, we're going to decode tokens on the GPU. 
As a refresher, the GPU sees data like this: + +``` +_Text_Block00_Visual_Ptr: 0 +_Text_Block00_Token00: 13,766 +_Text_Block00_Token01: 84 +_Text_Block01_Visual_Ptr: 13 +_Text_Block01_Token00: 599 +_Text_Block01_Token01: 8,301 +... +``` + +I.e. it sees "blocks" of data with tokens and visual pointers. The visual pointer just says where on a grid it should draw the subwords represented by the tokens. + +The pixel shader can trivially work out what grid location the current pixel belongs to. By scanning through the visual pointers, it can work out which block it has to draw. + +We can generate a function like this: + +```c +#define BLOCK_WIDTH 2 +void GetBlock(uint which_block, out float data[BLOCK_WIDTH]) { + [loop] + for (uint i = 0; i < BLOCK_WIDTH; i++) { + data[i] = _Text_Blocks[which_block][i]; + } +} + +// Get the tokens that cover `screen_ptr`. Also returns `block_ptr`, the +// location where this block of tokens begins. +void GetTokens(uint screen_ptr, out uint block_ptr, out uint tokens[BLOCK_WIDTH]) { + uint which_block; + [loop] + for (uint i = 0; i < BLOCK_WIDTH; i++) { + if (screen_ptr >= _Text_Block_Visual_Ptrs[i]) { + which_block = i; + } + } + GetBlock(which_block, tokens); +} +``` + +## GPU decoding + +Now we have to translate the tokens into text. I do this with a texture laid out as follows: + +1. A fixed-length array of (offset, length) pairs. Offset is 24 bits, giving us an address space of about 16 million slots. Length is 8 bits, but as established above, the longest token is only 16 characters long. So we're wasting about 4 bits. This tells us we should use an RGBA texture. + +2. A variable length array of ASCII-encoded strings. Each slot is RGBA, so it can hold 4 characters. + +My tokenizer's vocabulary is 65,536 tokens. If we add up the lengths of every token, rounding them up to the nearest multiple of 4, we get a total length of 667,532 characters. This means that we need 166,883 slots to fit the actual content of the vocabulary. 
+ +So, the entire vocabulary - length+offset head and content - requires a 32-bit RGBA texture with 232,419 slots. We'll just jam this into a 512x512 texture, at an occupancy ratio of 88.66% (11.34% waste). The total VRAM usage of that lookup table (LUT) is 1 MiB. + +We want to implement this API: + +```c +uint GetChar(uint screen_ptr); +``` + +Internally, it must: + +1. Get tokens. (GetTokens - already done) +2. Figure out which token covers the screen\_ptr. +3. Figure out which character in the token covers the screen\_ptr. + +Let's break down [2]. We can get the length of a token with a single texture tap. So, naively, we can just scan through the tokens in the current block, add up their lengths, and stop once we find a token covering the current slot. The scan incurs a worst-case cost of `BLOCK_WIDTH` texture taps. The character lookup is then a single tap. + +```c +// Gets the length of the subword encoded by the token. Performs one texture +// tap. +void TokenLengthOffset(uint token, out uint length, out uint offset); +// Gets the nth character of the token stored at `token_offset`. +uint GetTokenChar(uint token_offset, uint nth); + +uint GetChar(uint screen_ptr) { + uint block_ptr; + uint tokens[BLOCK_WIDTH]; + GetTokens(screen_ptr, tokens, block_ptr); + uint start = block_ptr; + uint covering_token = 0; + uint token_ptr = block_ptr; + uint token_offset; + for (uint ii = 0; ii < BLOCK_WIDTH; ii++) { + uint token_length; + TokenLengthOffset(tokens[ii], token_length, token_offset); + if (token_ptr + token_length >= screen_ptr) { + covering_token = tokens[ii]; + } + token_ptr += token_length; + } + // covering_token covers screen_ptr. It starts at token_ptr. + return GetTokenChar(token_offset, screen_ptr - token_ptr); +} +``` + +That's actually it for the GPU decoding. Once you have the character, you can use standard fixed-width font rendering techniques to display it (e.g. 
[disinfo](https://github.com/yum-food/disinfo) and [msdf](https://github.com/Chlumsky/msdfgen)). + diff --git a/bpe_dump.py b/bpe_dump.py new file mode 100644 index 0000000..74ee08f --- /dev/null +++ b/bpe_dump.py @@ -0,0 +1,61 @@ +import math +import sentencepiece as spm + +def get_tokenizer(): + model_path = "./custom_unigram_tokenizer_65k/unigram.model" + sp = spm.SentencePieceProcessor() + sp.load(model_path) + return sp + +tokenizer = get_tokenizer() + +print(f"vocabulary size: {tokenizer.get_piece_size()}") +# Sentencepiece uses U+2581 (lower one eighth block) to indicate a space before +# a subword. +sp_space = chr(9601) +tokens_with_non_ascii = set() +subword_len_histo = dict() +# The sum of the lengths of each subword in the vocabulary. These are rounded +# up to 4 characters. +vocab_len_4c_quantized = 0 + +for i in range(tokenizer.get_piece_size()): + k = tokenizer.id_to_piece(i) + v = i + print(f" Original token ({v}): {repr(k)} ({' '.join(str(ord(k_c)) for k_c in k)})") + for k_c in k: + if ord(k_c) > 127 and ord(k_c) != 9601: + tokens_with_non_ascii.add(k) + break + k_processed = k.replace(sp_space, ' ') + if not k.startswith(sp_space) and k not in ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]: + k_processed = k + else: + k_processed = k_processed + + current_len = len(k_processed) + if current_len in subword_len_histo: + subword_len_histo[current_len] += 1 + else: + subword_len_histo[current_len] = 1 + + vocab_len_4c_quantized += math.ceil(current_len / 4.0) * 4.0 + print(f" {v}: {k_processed}") + +print(f"Num tokens with non-ascii: {len(tokens_with_non_ascii)} ({100 * len(tokens_with_non_ascii) / tokenizer.get_piece_size():.2f})%") + +print(f"Subword length histogram:") +avg_subword_len = 0 +total_pieces_for_avg = 0 +for k_len, v_count in sorted(subword_len_histo.items(), key=lambda x: x[0]): + avg_subword_len += k_len * v_count + total_pieces_for_avg += v_count + print(f" {k_len}: {v_count}") + +if total_pieces_for_avg > 0: + avg_subword_len /= 
total_pieces_for_avg + print(f"Average subword length: {avg_subword_len:.4f}") +else: + print("Average subword length: N/A (no pieces analyzed)") + +print(f"Sum of all subword lengths, quantized to 4 character chunks: {vocab_len_4c_quantized}") diff --git a/generate_bpe_lut.py b/generate_bpe_lut.py new file mode 100644 index 0000000..72b1201 --- /dev/null +++ b/generate_bpe_lut.py @@ -0,0 +1,119 @@ +from math import ceil, floor +from PIL import Image +from unidecode import unidecode +import sentencepiece as spm + +IMG_RES = 512 # square image + +def get_tokenizer(): + use_sentencepiece = True + + if not use_sentencepiece: + from tokenizers import Tokenizer + tokenizer_json = "./custom_wordpiece_tokenizer_65k/tokenizer.json" + print(f"Loading Tokenizers library tokenizer from: {tokenizer_json}") + return Tokenizer.from_file(tokenizer_json) + else: + model_path = "./custom_unigram_tokenizer_65k/unigram.model" + print(f"Loading SentencePiece tokenizer from: {model_path}") + sp = spm.SentencePieceProcessor() + sp.load(model_path) + print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}") + return sp + +def get_words(): + tokenizer = get_tokenizer() + + print(f"vocabulary size: {tokenizer.get_piece_size()}") + # sp_space = sentencepiece space. + # A special character sentencepiece uses to represent spaces before words. + sp_space = chr(9601) + words = [] + + # Accumulate words into a list, indexed by the token number. Sanitize them as + # you go. 
+ for i in range(tokenizer.get_piece_size()): + word = tokenizer.id_to_piece(i) + tok = i + #print(f" Original token ({tok}): {repr(word)} ({' '.join(str(ord(c)) for c in word)})") + word_sanitized = "" + # Dirty hack: convert non-ASCII characters to nearest ASCII equivalent + for c in word: + if ord(c) > 127 and c != sp_space: + c_plain = unidecode(c) + print(f" Resolved {c} to {c_plain}") + word_sanitized += c_plain + else: + word_sanitized += c + # Replace sp_space with ' ' + word_sanitized = word_sanitized.replace(sp_space, ' ') + #print(f" {tok}: {word_sanitized}") + words.append(word_sanitized) + + # Special word: empty string. SentencePiece doesn't support this natively. + words.append('') + + return words + +# Fold a flat index into a IMG_RESxIMG_RES box. Return the (x,y) coordinate of +# the folded index. +def fold_idx(flat_idx): + return (flat_idx % IMG_RES, int(floor(flat_idx / IMG_RES))) + +def unfold_idx(coord): + return coord[0] + coord[1] * IMG_RES + +assert unfold_idx(fold_idx(1533125)) == 1533125 +assert unfold_idx(fold_idx(8538235)) == 8538235 +assert fold_idx(unfold_idx((192,235))) == (192,235) +assert fold_idx(unfold_idx((83,388))) == (83,388) + +def generate_lut(words, filename): + # Write the texture header. + black = (0, 0, 0, 255) + img = Image.new('RGBA', (IMG_RES, IMG_RES), black) + + # The header is `len(words)` slots long. Thus the actual LUT content starts at + # the index `len(words)`. + pixel_data = img.load() + lut_ptr = len(words) + for i in range(0, len(words)): + # Get pointer to the actual word data. 
+ tok_ptr = lut_ptr + tok_len = len(words[i]) + rgba = ((tok_ptr >> 0) & 0xFF, + (tok_ptr >> 8) & 0xFF, + (tok_ptr >> 16) & 0xFF, + tok_len) + print(f"Writing {rgba} to {i} / {fold_idx(i)}") + idx_x, idx_y = fold_idx(i) + pixel_data[idx_x, idx_y] = rgba + + for j in range(0, ceil(tok_len/4.0)): + quad_ptr = tok_ptr + j + tok_0 = ord(words[i][j*4]) + tok_1 = ord(words[i][j*4+1] if tok_len > j*4+1 else ' ') + tok_2 = ord(words[i][j*4+2] if tok_len > j*4+2 else ' ') + tok_3 = ord(words[i][j*4+3] if tok_len > j*4+3 else ' ') + rgba = (tok_0, tok_1, tok_2, tok_3) + idx_x, idx_y = fold_idx(quad_ptr) + print(f" Writing {rgba} to {quad_ptr} / {fold_idx(quad_ptr)}") + pixel_data[idx_x, idx_y] = rgba + + # Advance the LUT ptr. Since we store 4 chars per pixel (RGBA), we advance + # it by ceil(tok_len/4). + lut_ptr += int(ceil(tok_len/4.0)) + + pretty = False + if pretty: + for y in range(0, IMG_RES): + for x in range(0, IMG_RES): + rgba = pixel_data[x, y] + pixel_data[x, y] = (rgba[0], rgba[1], rgba[2], 255) + + print(f"Saving to {filename}") + img.save(filename) + +if __name__ == "__main__": + words = get_words() + generate_lut(words, "bpe_lut.png") diff --git a/generate_tokenizer.py b/generate_tokenizer.py new file mode 100644 index 0000000..88903a8 --- /dev/null +++ b/generate_tokenizer.py @@ -0,0 +1,525 @@ +# !!! AI ARTIFACT !!! +# This file was primarily written with AI. 
+ +import os +import argparse +from datasets import load_dataset, interleave_datasets +import sentencepiece as spm +import itertools +import random +from unidecode import unidecode + +# --- Dataset 1: Wikipedia --- +DATASET_NAME_WIKI = "wikipedia" +DATASET_CONFIG_WIKI = "20220301.en" + +DATASET_SPLIT_WIKI = "train" +TEXT_COLUMN_WIKI = "text" + +# --- Dataset 2: DailyDialog --- +DATASET_NAME_DD = "daily_dialog" +DATASET_SPLIT_DD = "train" +UTTERANCE_COLUMN_DD = "dialog" + +# --- Dataset 3: BlendedSkillTalk (includes Persona-Chat) --- +DATASET_NAME_BST = "blended_skill_talk" +DATASET_SPLIT_BST = "train" + +# --- Dataset 4: OpenSubtitles --- +DATASET_NAME_OS = "Helsinki-NLP/open_subtitles" +DATASET_LANG_PAIR_OS = ("en", "fr") # Load en-fr and use the 'en' part +DATASET_SPLIT_OS = "train" +TEXT_COLUMN_OS = "en" # Access via item['translation']['en'] + +# Reserve one space for special "empty" token. +VOCAB_SIZE = 65535 +OUTPUT_DIR = "./custom_unigram_tokenizer_65k" +MODEL_PREFIX = os.path.join(OUTPUT_DIR, "unigram") +MODEL_FILE = MODEL_PREFIX + ".model" +VOCAB_FILE = MODEL_PREFIX + ".vocab" + +UNK_TOKEN = "[UNK]" +PAD_TOKEN = "[PAD]" +CONTROL_SYMBOLS = ["[CLS]", "[SEP]", "[MASK]"] + +BATCH_SIZE = 1000 + +def wiki_iterator(dataset_wiki, batch_size=BATCH_SIZE): + if dataset_wiki: + for i in range(0, len(dataset_wiki), batch_size): + yield [unidecode(text) for text in dataset_wiki[i : i + batch_size][TEXT_COLUMN_WIKI] if text] + +def dd_iterator(dataset_dd, batch_size=BATCH_SIZE): + if dataset_dd: + current_batch = [] + for dialogue in dataset_dd: + utterances = dialogue[UTTERANCE_COLUMN_DD] + for utterance in utterances: + if utterance: + normalized_utterance = unidecode(utterance) + current_batch.append(normalized_utterance) + if len(current_batch) == batch_size: + yield current_batch + current_batch = [] + if current_batch: + yield current_batch + +def bst_iterator(dataset_bst, batch_size=BATCH_SIZE): + if dataset_bst: + current_batch = [] + for session in dataset_bst: 
+ texts_to_add = [] + if session.get("previous_utterance"): + texts_to_add.append(session["previous_utterance"]) + if session.get("free_messages"): + texts_to_add.extend(session["free_messages"]) + if session.get("guided_messages"): + texts_to_add.extend(session["guided_messages"]) + + for text in texts_to_add: + if text and isinstance(text, str): + normalized_text = unidecode(text) + current_batch.append(normalized_text) + if len(current_batch) == batch_size: + yield current_batch + current_batch = [] + if current_batch: + yield current_batch + +def os_iterator(dataset_os, batch_size=BATCH_SIZE, lang_code=TEXT_COLUMN_OS): + if dataset_os: + current_batch = [] + for item in dataset_os: + text = item['translation'][lang_code] + if text: + normalized_text = unidecode(text) + current_batch.append(normalized_text) + if len(current_batch) == batch_size: + yield current_batch + current_batch = [] + if current_batch: + yield current_batch + +def count_wiki_items_chars(dataset_wiki): + item_count = 0 + char_count = 0 + if dataset_wiki: + item_count = len(dataset_wiki) + try: + char_count = sum(len(unidecode(text)) for text in dataset_wiki[TEXT_COLUMN_WIKI] if text and isinstance(text, str)) + except Exception: + print(f"Warning: Direct column access for char count failed for {DATASET_NAME_WIKI}. 
Iterating row by row (slower).") + char_count = sum(len(unidecode(row[TEXT_COLUMN_WIKI])) for row in dataset_wiki if row.get(TEXT_COLUMN_WIKI) and isinstance(row[TEXT_COLUMN_WIKI], str)) + + return item_count, char_count + +def count_dd_items_chars(dataset_dd): + item_count = 0 + char_count = 0 + if dataset_dd: + for dialogue in dataset_dd: + utterances = dialogue[UTTERANCE_COLUMN_DD] + for utterance in utterances: + if utterance and isinstance(utterance, str): + item_count += 1 + char_count += len(unidecode(utterance)) + return item_count, char_count + +def count_bst_items_chars(dataset_bst): + item_count = 0 + char_count = 0 + if dataset_bst: + for session in dataset_bst: + texts_to_process = [] + # Gather all potential text strings first + prev_utt = session.get("previous_utterance") + if prev_utt and isinstance(prev_utt, list): + if len(prev_utt) > 0 and isinstance(prev_utt[0], str): + texts_to_process.append(prev_utt[0]) + elif prev_utt and isinstance(prev_utt, str): + texts_to_process.append(prev_utt) + + free_msgs = session.get("free_messages") + if free_msgs and isinstance(free_msgs, list): + for item in free_msgs: + if isinstance(item, list): + texts_to_process.extend(msg for msg in item if msg and isinstance(msg, str)) + elif isinstance(item, str): + texts_to_process.append(item) + + guided_msgs = session.get("guided_messages") + if guided_msgs and isinstance(guided_msgs, list): + for item in guided_msgs: + if isinstance(item, list): + texts_to_process.extend(msg for msg in item if msg and isinstance(msg, str)) + elif isinstance(item, str): + texts_to_process.append(item) + + # Count items and chars from the gathered list + for text in texts_to_process: + if text and isinstance(text, str): + normalized_text = unidecode(text) + if normalized_text: + item_count += 1 + char_count += len(normalized_text) + + return item_count, char_count + +def count_os_items_chars(dataset_os, lang_code=TEXT_COLUMN_OS): + item_count = 0 + char_count = 0 + if dataset_os: + for 
item in dataset_os: + text = item['translation'][lang_code] + if text and isinstance(text, str): + normalized_text = unidecode(text) + if normalized_text: # Ensure not empty after unidecode + item_count += 1 + char_count += len(normalized_text) + return item_count, char_count + +def load_and_count_datasets(wiki_fraction, subtitles_fraction): + """Loads, potentially shrinks, and counts items/chars in datasets.""" + datasets = {} + counts = {} + total_items = 0 + total_chars = 0 + + # --- Wikipedia --- + print(f"Loading dataset 1: {DATASET_NAME_WIKI} ({DATASET_CONFIG_WIKI}), split: {DATASET_SPLIT_WIKI}") + dataset_wiki_full = load_dataset(DATASET_NAME_WIKI, DATASET_CONFIG_WIKI, split=DATASET_SPLIT_WIKI, trust_remote_code=True) + print(f"Original Wikipedia dataset size: {len(dataset_wiki_full):,}") + + # Shrink the Wikipedia dataset + split_test_size = 1.0 - wiki_fraction + shrunk_dataset_split = dataset_wiki_full.train_test_split(test_size=split_test_size, seed=random.randint(0, 1000000)) + dataset_wiki = shrunk_dataset_split['train'] + print(f"Using {wiki_fraction*100:.3f}% of Wikipedia dataset: {len(dataset_wiki):,} items") + + count_wiki, chars_wiki = count_wiki_items_chars(dataset_wiki) + print(f"Wikipedia dataset loaded (shrunk). Precise text items: {count_wiki:,}, Characters: {chars_wiki:,}") + datasets['wiki'] = dataset_wiki + counts['wiki'] = (count_wiki, chars_wiki) + total_items += count_wiki + total_chars += chars_wiki + + # --- DailyDialog --- + print(f"\nLoading dataset 2: {DATASET_NAME_DD}, split: {DATASET_SPLIT_DD}") + dataset_dd = load_dataset(DATASET_NAME_DD, split=DATASET_SPLIT_DD, trust_remote_code=True) + count_dd, chars_dd = count_dd_items_chars(dataset_dd) + print(f"DailyDialog dataset loaded. 
Precise text items (utterances): {count_dd:,}, Characters: {chars_dd:,}") + datasets['dd'] = dataset_dd + counts['dd'] = (count_dd, chars_dd) + total_items += count_dd + total_chars += chars_dd + + # --- BlendedSkillTalk --- + print(f"\nLoading dataset 3: {DATASET_NAME_BST}, split: {DATASET_SPLIT_BST}") + dataset_bst = load_dataset(DATASET_NAME_BST, split=DATASET_SPLIT_BST, trust_remote_code=True) + count_bst, chars_bst = count_bst_items_chars(dataset_bst) + print(f"BlendedSkillTalk dataset loaded. Precise text items (extracted): {count_bst:,}, Characters: {chars_bst:,}") + datasets['bst'] = dataset_bst + counts['bst'] = (count_bst, chars_bst) + total_items += count_bst + total_chars += chars_bst + + # --- OpenSubtitles --- + print(f"\nLoading dataset 4: {DATASET_NAME_OS}, lang_pair: {DATASET_LANG_PAIR_OS}, split: {DATASET_SPLIT_OS}") + # Note: OpenSubtitles can be very large. Consider streaming or specific configurations if memory is an issue. + # For now, loading a standard configuration. + dataset_os = load_dataset(DATASET_NAME_OS, lang1=DATASET_LANG_PAIR_OS[0], lang2=DATASET_LANG_PAIR_OS[1], split=DATASET_SPLIT_OS, trust_remote_code=True) + split_test_size = 1.0 - subtitles_fraction + dataset_os = dataset_os.train_test_split(test_size=split_test_size, seed=random.randint(0, 1000000))['train'] + count_os, chars_os = count_os_items_chars(dataset_os, lang_code=TEXT_COLUMN_OS) + print(f"OpenSubtitles dataset loaded. 
Precise text items: {count_os:,}, Characters: {chars_os:,}") + datasets['os'] = dataset_os + counts['os'] = (count_os, chars_os) + total_items += count_os + total_chars += chars_os + + print(f"\nTotal precise text items from loaded datasets: {total_items:,}") + print(f"Total characters from loaded datasets: {total_chars:,}") + + return datasets, counts + +def train_tokenizer(model_prefix, datasets, counts): + """Trains a Unigram tokenizer using SentencePiece and saves it.""" + total_items = sum(c[0] for c in counts.values()) + total_chars = sum(c[1] for c in counts.values()) + if total_items == 0: + print("Error: No items found in the datasets to train on. Exiting training.") + return + + print(f"\nTotal precise text items for training: {total_items:,}") + print(f"Total characters for training: {total_chars:,}") + + output_dir = os.path.dirname(model_prefix) + os.makedirs(output_dir, exist_ok=True) + + print(f"\nStarting SentencePiece Unigram tokenizer training with vocab size: {VOCAB_SIZE}") + print(f"Using combined dataset iterator.") + print(f"Output model prefix: {model_prefix}") + + iterators_to_chain = [] + def flatten_iterator(iterator): + for batch in iterator: + for item in batch: + yield item + + if datasets.get('wiki') and counts['wiki'][0] > 0: + iterators_to_chain.append(flatten_iterator(wiki_iterator(datasets['wiki']))) + if datasets.get('dd') and counts['dd'][0] > 0: + iterators_to_chain.append(flatten_iterator(dd_iterator(datasets['dd']))) + if datasets.get('bst') and counts['bst'][0] > 0: + iterators_to_chain.append(flatten_iterator(bst_iterator(datasets['bst']))) + if datasets.get('os') and counts['os'][0] > 0: + iterators_to_chain.append(flatten_iterator(os_iterator(datasets['os']))) + + if not iterators_to_chain: + print("Error: No valid dataset iterators available for training. Exiting.") + return + + combined_iterator = itertools.chain(*iterators_to_chain) + + # Include whitespace symbols so we can efficiently break lines. 
+ # If we include the single space, it prevents the tokenizer from merging + # spaces with regular words, and tanks the average chars/token. + # This many tokens is kinda overkill, but it gives us a way to efficiently + # clear even fairly large boards, so I think it's worth. + whitespace_symbols = [] + for i in range(2, 40): + whitespace_symbols.append('▁' * i) + + spm.SentencePieceTrainer.train( + sentence_iterator=combined_iterator, + model_prefix=model_prefix, + vocab_size=VOCAB_SIZE, + model_type='unigram', + character_coverage=1.0, + unk_piece=UNK_TOKEN, + pad_piece=PAD_TOKEN, + control_symbols=CONTROL_SYMBOLS, + user_defined_symbols=whitespace_symbols, + # These whitespace options must be false, or else whitespace won't + # be respected when encoding. + add_dummy_prefix=False, + remove_extra_whitespaces=False, + split_by_whitespace=False, + num_threads=os.cpu_count(), + input_sentence_size=total_items, + ) + print("\nTraining finished.") + print(f"SentencePiece model saved to: {model_prefix}.model") + print(f"SentencePiece vocabulary saved to: {model_prefix}.vocab") + +def extract_text_samples(dataset, count, num_samples, text_extractor_func): + """Extracts a specified number of text samples from a dataset.""" + samples = [] + if dataset is None or count == 0 or num_samples == 0: + return samples + # Take samples from the beginning, ensure we don't exceed dataset size + actual_samples_to_take = min(num_samples, count) + # Use the provided function to extract text correctly for this dataset type + samples = text_extractor_func(dataset.select(range(actual_samples_to_take))) + return [unidecode(s) for s in samples if s and isinstance(s,str)] + +def wiki_text_extractor(dataset_slice): + """Extracts text from a slice of the Wikipedia dataset.""" + return [text for text in dataset_slice[TEXT_COLUMN_WIKI] if text] + +def dd_text_extractor(dataset_slice): + """Extracts text from a slice of the DailyDialog dataset.""" + texts = [] + for dialogue in dataset_slice: + 
texts.extend(utterance for utterance in dialogue[UTTERANCE_COLUMN_DD] if utterance) + return texts + +def bst_text_extractor(dataset_slice): + """Extracts text from a slice of the BlendedSkillTalk dataset.""" + texts = [] + for session in dataset_slice: + if session.get("previous_utterance"): + texts.append(session["previous_utterance"]) + if session.get("free_messages"): + texts.extend(session["free_messages"]) + if session.get("guided_messages"): + texts.extend(session["guided_messages"]) + return [t for t in texts if t and isinstance(t,str)] # Ensure only valid strings + +def os_text_extractor(dataset_slice, lang_code=TEXT_COLUMN_OS): + """Extracts text from a slice of the OpenSubtitles dataset.""" + texts = [] + for item in dataset_slice: + text = item['translation'][lang_code] + if text and isinstance(text, str): + texts.append(text) + return texts + +def test_tokenizer(model_path, datasets, counts, sample_size=1000): + print("\n--- Testing the trained Unigram (SentencePiece) tokenizer on data sample ---") + if sample_size <= 0: + print("Error: Sample size for testing must be positive.") + return + + sp = spm.SentencePieceProcessor() + sp.load(model_path) + print(f"Successfully loaded SentencePiece model from: {model_path}") + print(f"Vocabulary size: {sp.get_piece_size()}") + + # Identify available datasets with content + available_datasets = { + name: data + for name, data in datasets.items() + if counts[name][0] > 0 + } + num_available_datasets = len(available_datasets) + + if num_available_datasets == 0: + print("Warning: No data available in loaded datasets to sample for testing.") + return + + print(f"Found {num_available_datasets} non-empty dataset(s) for testing.") + print(f"Attempting to sample up to {sample_size} total items equally from these datasets...") + + # Calculate equal number of samples per available dataset + samples_per_dataset = sample_size // num_available_datasets + print(f"Targeting approximately {samples_per_dataset} samples per 
dataset.") + + test_samples = [] + dataset_extractors = { + 'wiki': wiki_text_extractor, + 'dd': dd_text_extractor, + 'bst': bst_text_extractor, + 'os': os_text_extractor, + } + + actual_samples_collected = {} + + # Sample equally from each available dataset + for name, dataset in available_datasets.items(): + count = counts[name][0] + extractor = dataset_extractors[name] + + # Determine the target number of samples for *this* dataset + num_samples_to_target = min(samples_per_dataset, count) + if num_samples_to_target <= 0: + print(f" Skipping '{name}' (no items requested or available).") + actual_samples_collected[name] = 0 + continue + + print(f" Sampling from '{name}' (target: {num_samples_to_target} items)...") + + items_before = len(test_samples) + # For OpenSubtitles, pass the lang_code if the extractor needs it + if name == 'os': + extracted_items = extract_text_samples(dataset, count, num_samples_to_target, lambda ds_slice: extractor(ds_slice, lang_code=TEXT_COLUMN_OS)) + else: + extracted_items = extract_text_samples(dataset, count, num_samples_to_target, extractor) + + # Limit the extracted items to the target number + final_samples_for_dataset = extracted_items[:num_samples_to_target] + + test_samples.extend(final_samples_for_dataset) + items_added = len(final_samples_for_dataset) + actual_samples_collected[name] = items_added + print(f" Added {items_added} items from '{name}'.") + + actual_sample_size = len(test_samples) + if actual_sample_size == 0: + print("\nCould not gather any samples for testing.") + print("--- Test finished ---") + return + + print(f"\n--- Starting Test Run ---") + print(f"Testing on {actual_sample_size} sampled text items (final count).") + print(f"Samples breakdown: {actual_samples_collected}") + + total_chars = 0 + total_tokens = 0 + examples_to_show = 5 + examples_shown = 0 + + random.shuffle(test_samples) + + for i, text_sample in enumerate(test_samples): + if not text_sample: # Should already be filtered by extractor, but 
double-check + continue + try: + tokens = sp.encode_as_pieces(text_sample) + num_tokens = len(tokens) + num_chars = len(text_sample) + total_tokens += num_tokens + total_chars += num_chars + + if examples_shown < examples_to_show: + print(f"\nSample {examples_shown + 1} (Overall index {i}):") # Clarify sample index + print(f" Text ({num_chars} chars): {text_sample[:100]}{'...' if len(text_sample)>100 else ''}") + print(f" Tokens ({num_tokens}): {tokens}") + examples_shown += 1 + except Exception as e: + print(f"\nWarning: Error encoding sample (Overall index {i}) with SentencePiece. Skipping sample. Error: {e}") + print(f"Problematic sample text: {text_sample[:100]}...") # Show problematic text + + # --- Test Summary --- + if total_tokens > 0 and total_chars > 0 : # Avoid division by zero + avg_chars_per_token = total_chars / total_tokens + print(f"\n--- Test Summary ---") + print(f"Tested on {actual_sample_size} text items.") + print(f"Final samples breakdown: {actual_samples_collected}") + print(f"Total characters in final sample: {total_chars:,}") + print(f"Total tokens in final sample: {total_tokens:,}") + print(f"Average characters per token: {avg_chars_per_token:.4f}") + elif actual_sample_size > 0: + print("\n--- Test Summary ---") + print(f"Tested on {actual_sample_size} text items.") + print(f"Final samples breakdown: {actual_samples_collected}") + print("No valid tokens were generated from the sample data (or samples were empty after unidecode).") + else: + # This case should be caught earlier, but included for completeness + print("\n--- Test Summary ---") + print("No samples were collected or processed.") + + print("--- Test finished ---") + +def parse_args(): + parser = argparse.ArgumentParser(description="Train and test a Unigram tokenizer on combined datasets.") + parser.add_argument("--train", action="store_true", help="Train the tokenizer on datasets") + parser.add_argument("--test", action="store_true", help="Test the trained tokenizer") + 
parser.add_argument("--sample_size", type=int, default=1000, help="Number of items to sample for testing") + parser.add_argument("--wiki_fraction", type=float, default=0.05, help="Fraction of Wikipedia dataset to use (e.g., 0.05 for 5%)") + parser.add_argument("--subtitles_fraction", type=float, + default=0.05, help="Fraction of opensubtitles dataset to use (e.g., 0.05 for 5%)") + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + + # Load datasets and calculate counts once + print("--- Loading and Counting Datasets ---") + datasets, counts = load_and_count_datasets(wiki_fraction=args.wiki_fraction, subtitles_fraction=args.subtitles_fraction) + print("--- Dataset Loading Finished ---") + + run_train = args.train + run_test = args.test + + # If no arguments provided, run both by default + if not args.train and not args.test: + run_train = True + run_test = True + + if run_train: + print("\n--- Starting Tokenizer Training ---") + train_tokenizer(MODEL_PREFIX, datasets, counts) + print("--- Training Finished ---") + + if run_test: + # Ensure tokenizer file exists if training didn't just run + if not os.path.exists(MODEL_FILE): + print(f"\nError: SentencePiece model file {MODEL_FILE} not found. 
Cannot run test.") + print("Please run with --train first or ensure the file exists.") + else: + print("\n--- Starting Tokenizer Testing ---") + test_tokenizer(MODEL_FILE, datasets, counts, sample_size=args.sample_size) + print("--- Testing Finished ---") + + print("\nScript finished.") diff --git a/hi.py b/hi.py new file mode 100644 index 0000000..7c68071 --- /dev/null +++ b/hi.py @@ -0,0 +1,357 @@ +import argparse +from math import floor, ceil +import msvcrt +from pythonosc import udp_client +import sentencepiece as spm +import sys +import threading +import time + +TESTS_ENABLED = True + +# 0 = quiet, 1 = verbose, 2 = very verbose +LOG_LEVEL = 0 + +def get_tokenizer(): + model_path = "./custom_unigram_tokenizer_65k/unigram.model" + print(f"Loading SentencePiece tokenizer from: {model_path}") + sp = spm.SentencePieceProcessor() + sp.load(model_path) + print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}") + return sp + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--block-width", type=int, default=2, help="Number of elements sent in one block") + parser.add_argument("--num-blocks", type=int, default=40, help="Number of blocks in animator") + parser.add_argument("--rows", type=int, default=8, help="Number of rows in the chatbox") + parser.add_argument("--cols", type=int, default=20, help="Number of columns in the chatbox") + return parser.parse_args() + +def assert_equal(a, b): + err_msg = f"{a} != {b}" + assert a == b, err_msg + +# Turn a whitespace-delimited string into a list of strings no longer than +# `cols`. +# Preferentially breaks strings at whitespace boundaries. Preserves whitespace +# between words, except if that whitespace comes between lines. Breaks words +# longer than `cols` with a hyphen. +def wrap_line(line: str, cols): + # First, split line into alternating chunks of words and whitespace. 
+ def get_sequences(line): + is_space = False + sequences = [] + seq_start = 0 + seq_end = -1 + for i in range(0, len(line)): + if line[i].isspace(): + if is_space: + seq_end = i + continue + # We were looking at text, now we see whitespace. + seq = line[seq_start:seq_end+1] + if len(seq) > 0: + sequences.append(seq) + seq_start = i + seq_end = i + is_space = True + else: + if not is_space: + seq_end = i + continue + # We were looking at whitespace, now we see text. + seq = line[seq_start:seq_end+1] + if len(seq) > 0: + sequences.append(seq) + seq_start = i + seq_end = i + is_space = False + sequences.append(line[seq_start:seq_end+1]) + return sequences + if TESTS_ENABLED: + assert_equal(get_sequences("foo"), ["foo"]) + assert_equal(get_sequences("foo bar"), ["foo", " ", "bar"]) + assert_equal(get_sequences(" foo bar"), [" ", "foo", " ", "bar"]) + assert_equal(get_sequences(" foo bar"), [" ", "foo", " ", "bar"]) + + # Next, greedily construct lines out of those sequences. + # Whitespace gets treated specially. If it would push us over the limit, we + # end the line and drop the whitespace. + sequences = get_sequences(line) + def coalesce_sequences(sequences, cols): + cur_line = "" + lines = [] + for seq in sequences: + if len(cur_line) + len(seq) <= cols: + cur_line += seq + continue + if seq.isspace(): + lines.append(cur_line) + cur_line = "" + continue + if len(cur_line) > 0: + lines.append(cur_line) + # Edge case: text sequence is longer than a line. 
+ while len(seq) > cols: + seq_prefix = seq[0:cols-1] + "-" + seq = seq[cols-1:] + lines.append(seq_prefix) + cur_line = seq + if len(cur_line) > 0: + lines.append(cur_line) + return lines + if TESTS_ENABLED: + assert_equal(coalesce_sequences(get_sequences("foo bar"), 3), ["foo", "bar"]) + assert_equal(coalesce_sequences(get_sequences("foo bar"), 4), ["foo ", "bar"]) + assert_equal(coalesce_sequences(get_sequences("foo bar"), 4), ["foo", "bar"]) + assert_equal(coalesce_sequences(get_sequences("foobar"), 3), ["fo-", "ob-", "ar"]) + assert_equal(coalesce_sequences(get_sequences("f obar"), 3), ["f ", "ob-", "ar"]) + + lines = coalesce_sequences(sequences, cols) + + # Next, pad each line with whitespace. + def pad_lines(lines, cols): + for i in range(0, len(lines)): + lines[i] += ' ' * (cols - len(lines[i])) + return lines + if TESTS_ENABLED: + assert_equal(pad_lines(["foo", "ba"], 4), ["foo ", "ba "]) + assert_equal(pad_lines(["foo"], 2), ["foo"]) + + return pad_lines(lines, cols) + +def get_blocks(lines, tokenizer, block_width, num_blocks): + if LOG_LEVEL == 2: + print(f"Lines sent to tokenizer: {''.join(lines)}") + tokens = tokenizer.encode_as_ids(''.join(lines)) + if LOG_LEVEL == 2: + print(f"Tokens: {tokens}") + pieces = [] + for tok in tokens: + piece = tokenizer.id_to_piece(tok) + pieces.append(piece) + if LOG_LEVEL == 2: + print(f"Pieces: {pieces}") + + # Group tokens into blocks and pad with empty characters. + # Also get visual pointers - the location where each block will be rendered. + def get_blocks(): + blocks = [] + visual_pointer = 0 + visual_pointers = [] + for i in range(0, ceil(len(tokens) / block_width)): + visual_pointers.append(visual_pointer) + block = [] + for j in range(0, block_width): + if i*block_width + j >= len(tokens): + # Pad block with empty characters. 65535 is a special token. 
+ block += [65535] * (block_width - len(block)) + break + block.append(tokens[i*block_width+j]) + visual_pointer += len(pieces[i*block_width+j]) + blocks.append(block) + return (blocks, visual_pointers) + blocks, visual_pointers = get_blocks() + if LOG_LEVEL == 2: + print(f"Blocks: {blocks}") + print(f"Visual pointers: {visual_pointers}") + + # Set all blocks up to the next `num_blocks` boundary to blank tokens. + # This handles the edge case where a prior message wrote data there which + # is covering up our new data. + def pad_blocks(blocks, visual_pointers): + cur_num_blocks = len(blocks) + num_pad_blocks = num_blocks - cur_num_blocks + for i in range(0, num_pad_blocks): + blocks.append([65535] * block_width) + visual_pointers.append(255) + return blocks, visual_pointers + blocks, visual_pointers = pad_blocks(blocks, visual_pointers) + if LOG_LEVEL == 2: + print(f"Blocks (padded): {blocks}") + print(f"Visual pointers (padded): {visual_pointers}") + + return blocks, visual_pointers + +def calc_diff(prev_blocks, prev_visual_pointers, cur_blocks, + cur_visual_pointers): + diff_indices = [] + diff_blocks = [] + diff_visual_pointers = [] + + for i in range(0, len(cur_blocks)): + if i >= len(prev_blocks): + diff_blocks.append(cur_blocks[i]) + diff_visual_pointers.append(cur_visual_pointers[i]) + diff_indices.append(i) + continue + if prev_blocks[i] != cur_blocks[i] or prev_visual_pointers[i] != cur_visual_pointers[i]: + diff_blocks.append(cur_blocks[i]) + diff_visual_pointers.append(cur_visual_pointers[i]) + diff_indices.append(i) + + return diff_indices, diff_blocks, diff_visual_pointers + +def send_data(osc_client, indices, blocks, visual_pointers): + def split_blocks_by_byte(blocks): + blocks_byte00 = [] + blocks_byte01 = [] + for block in blocks: + block_byte00 = [] + block_byte01 = [] + for datum in block: + block_byte00.append((datum >> 0) & 0xFF) + block_byte01.append((datum >> 8) & 0xFF) + blocks_byte00.append(block_byte00) + blocks_byte01.append(block_byte01) 
+ return blocks_byte00, blocks_byte01 + + blocks_byte00, blocks_byte01 = split_blocks_by_byte(blocks) + if LOG_LEVEL == 2: + print(f"Blocks (byte 00): {blocks_byte00}") + print(f"Blocks (byte 01): {blocks_byte01}") + + def send_osc(osc_client, addr, data): + #print(f"Sending {data} to {addr}") + osc_client.send_message(addr, data) + + for i in range(0, len(blocks)): + lp_int = indices[i] + lp_param = "_Unigram_Letter_Grid_OSC_Pointer" + addr = "/avatar/parameters/" + lp_param + send_osc(osc_client, addr, lp_int) + + vp_float = (-127.5 + visual_pointers[i]) / 127.5 + vp_param = f"_Unigram_Letter_Grid_OSC_Visual_Pointer" + addr = "/avatar/parameters/" + vp_param + send_osc(osc_client, addr, vp_float) + if LOG_LEVEL == 2: + print(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}") + for j in range(0, len(blocks[i])): + byte00_float = (-127.5 + blocks_byte00[i][j]) / 127.5 + byte01_float = (-127.5 + blocks_byte01[i][j]) / 127.5 + byte00_param = f"_Unigram_Letter_Grid_OSC_Datum{j:02}_Byte00" + byte01_param = f"_Unigram_Letter_Grid_OSC_Datum{j:02}_Byte01" + addr = "/avatar/parameters/" + byte00_param + send_osc(osc_client, addr, byte00_float) + addr = "/avatar/parameters/" + byte01_param + send_osc(osc_client, addr, byte01_float) + time.sleep(0.34) + +def getOscClient(ip = "127.0.0.1", port = 9000): + return udp_client.SimpleUDPClient(ip, port) + +class InputState: + def __init__(self): + # Initialize the known state of the board to empty array. This will cause + # our paging logic to re-send everything the first time around. 
+ self.blocks = [] + self.visual_pointers = [] + pass + +def handle_input(state: InputState, line: str, tokenizer, osc_client, args): + line_wrapped = wrap_line(line, args.cols) + if TESTS_ENABLED: + for line in line_wrapped: + assert_equal(len(line), args.cols) + if LOG_LEVEL == 2: + print(f"Wrapped lines: {line_wrapped}") + blocks, visual_pointers = get_blocks(line_wrapped, tokenizer, + args.block_width, args.num_blocks) + # Wrap visual pointers. This is mostly done just to stay within our + # limited 8-bit precision budget. The shader also has to wrap in order + # to properly display things. + visual_pointers = [ ptr % (args.rows * args.cols) for ptr in visual_pointers ] + indices, diff_blocks, diff_visual_pointers = calc_diff(state.blocks, state.visual_pointers, blocks, visual_pointers) + indices = [idx % args.num_blocks for idx in indices] + # Send only one block at a time to make things snappier in interactive use + # case. + # TODO use a continuation (yield) instead of returning. Then we can be a + # little lighter on the cpu. Measurements show that this script is + # already very light but we're clearly wasting a lot of work by + # re-tokenizing the entire input every time we send a block. 
+ if len(indices) == 0: + return + if indices[0] == len(state.blocks): + state.blocks.append(diff_blocks[0]) + state.visual_pointers.append(diff_visual_pointers[0]) + elif indices[0] > len(state.blocks): + print(f"This should never happen!") + sys.exit(1) + else: + state.blocks[indices[0]] = diff_blocks[0] + state.visual_pointers[indices[0]] = diff_visual_pointers[0] + + send_data(osc_client, [indices[0]], [diff_blocks[0]], [diff_visual_pointers[0]]) + +class SharedThreadData: + def __init__(self): + self.word = "" + self.word_lock = threading.Lock() + self.exit_event = threading.Event() + +def osc_thread(shared_data: SharedThreadData, cli_args): + tokenizer = get_tokenizer() + osc_client = getOscClient() + + # Prime the board + print("Priming the board") + input_state = InputState() + handle_input(input_state, "", tokenizer, osc_client, cli_args) + + while not shared_data.exit_event.is_set(): + word_copy = "" + with shared_data.word_lock: + word_copy = shared_data.word + handle_input(input_state, word_copy, tokenizer, osc_client, cli_args) + time.sleep(0.01) + +if __name__ == "__main__": + cli_args = parse_args() + + shared_data = SharedThreadData() + osc_thread = threading.Thread( + target=osc_thread, + args=(shared_data, cli_args)) + osc_thread.start() + + word_is_over = False + local_word = "" + while True: + char_bytes = msvcrt.getch() + + try: + char = char_bytes.decode('utf-8') + if char == '\r' or char == '\n': + word_is_over = True + continue + except UnicodeDecodeError: + print(f"Unsupported character: {char_bytes}") + if char_bytes == b'\x00' or char_bytes == b'\xe0': + special_char = msvcrt.getch() + continue + + if char_bytes == b'\x03': # ctrl+C + break + elif char_bytes == b'\x08': # backspace + with shared_data.word_lock: + shared_data.word = shared_data.word[:-1] + local_word = shared_data.word + elif char_bytes == b'\x0c': # ctrl+L + with shared_data.word_lock: + shared_data.word = "" + local_word = shared_data.word + elif word_is_over: + with 
shared_data.word_lock: + shared_data.word = char + local_word = shared_data.word + word_is_over = False + else: + with shared_data.word_lock: + shared_data.word += char + local_word = shared_data.word + print(local_word + "_") + shared_data.exit_event.set() + osc_thread.join() + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..104e7b8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +datasets +pillow +python-osc +unidecode +sentencepiece + diff --git a/tokenize_me.py b/tokenize_me.py new file mode 100644 index 0000000..83be290 --- /dev/null +++ b/tokenize_me.py @@ -0,0 +1,28 @@ +import argparse +import sentencepiece as spm + +def get_tokenizer(): + model_path = "./custom_unigram_tokenizer_65k/unigram.model" + print(f"Loading SentencePiece tokenizer from: {model_path}") + sp = spm.SentencePieceProcessor() + sp.load(model_path) + print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}") + return sp + +def parse_args(): + parser = argparse.ArgumentParser(description="Tokenize a given string using a SentencePiece model.") + parser.add_argument("text", type=str, help="The string to tokenize.") + args = parser.parse_args() + return args + +args = parse_args() +tok = get_tokenizer() +tokens = tok.encode_as_pieces(args.text) +print("Tokens:", tokens) + +token_ids = tok.encode_as_ids(args.text) +print("Token IDs:", token_ids) + +# Split each token ID into two 8-bit chunks (high byte, low byte) +byte_pairs = [(tid >> 8, tid & 0xFF) for tid in token_ids] +print("Token ID Byte Pairs:", byte_pairs) -- cgit v1.2.3 From f8e95c0b85288a10f435e0edabf43defa0c303ac Mon Sep 17 00:00:00 2001 From: yum Date: Sat, 17 May 2025 23:41:20 -0700 Subject: Add STT code --- .gitignore | 3 + Images/unigram_lut_for_visualization.png | Bin 0 -> 489395 bytes LICENSE | 7 + README.md | 45 ++- app_config.py | 39 +++ config.yaml | 18 + hi.py | 77 ++-- requirements.txt | 4 + shared_thread_data.py | 9 + stt.py | 581 
+++++++++++++++++++++++++++++++ vad.py | 313 +++++++++++++++++ 11 files changed, 1068 insertions(+), 28 deletions(-) create mode 100644 .gitignore create mode 100644 Images/unigram_lut_for_visualization.png create mode 100644 LICENSE create mode 100644 app_config.py create mode 100644 config.yaml create mode 100644 shared_thread_data.py create mode 100644 stt.py create mode 100644 vad.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a102cf0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.*.sw[po] +*.meta + diff --git a/Images/unigram_lut_for_visualization.png b/Images/unigram_lut_for_visualization.png new file mode 100644 index 0000000..622419d Binary files /dev/null and b/Images/unigram_lut_for_visualization.png differ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..1ebdcb5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,7 @@ +Copyright 2025 yum_food + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/README.md b/README.md index eaeceea..abb0576 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,40 @@ # Optimized text paging for VRChat +This repo provides code to help you send English text into VRChat. It includes: + +1. Training code to produce an English-language tokenizer of any vocabulary + size. +2. Code to turn your tokenizer into a lookup table for GPU decoding. +3. Unity code to generate an animator to shuttle data from OSC to material + properties. +4. OSC code to talk to your Unity animator. + +To get started, see Quick Start. + +## Quick start + +1. Clone this repo. +2. Clone my toon shader, [2ner](https://github.com/yum-food/2ner). +3. Install Lyuma's av3emulator. +4. Drag STT.prefab onto your avatar's root. +5. Enter play mode. +6. Open PowerShell. + +```bash +$ cd ~ +$ mkdir tmp +$ cd tmp +$ python.exe -m venv venv +$ ./venv/Scripts/Activate.ps1 +$ pushd /path/to/FastTextPaging/ +$ pip3 install -r requirements.txt +$ python3 ./hi.py +``` + +7. Start typing. + +## Design overview + It is sometimes useful to send text data into VRChat, for example for speech-to-text (STT). This is typically done naively, with a "block" of n 8-bit characters\* sent in along with an 8-bit pointer. Since avatars can only @@ -19,7 +54,7 @@ used. Thus to reach a typical reading speed, you need to use (260/4.7) = 55.5 OSC bits. The goal of this module is to get more out of these bits by compressing text over the wire. -## Unigram tokenizer +### Unigram tokenizer Byte pair encoding (BPE) is an encoding scheme frequently used in natural language processing (NLP) contexts. For any language with a fixed character set @@ -127,7 +162,7 @@ bits naive rate bpe rate speedup factor I reserve 39 token slots for sequences of whitespace characters of length 2-40. This helps simplify formatting. To end a line or position text, you can just send in the exact right number of spaces, and a fixed-width font renderer will position things as intended. 
-## Paging data into shader +### Paging data into shader Sending this data to a shader is pretty simple: @@ -224,7 +259,7 @@ void GetTokens(uint screen_ptr, out uint block_ptr, out uint tokens[BLOCK_WIDTH] } ``` -## GPU decoding +### GPU decoding Now we have to translate the tokens into text. I do this with a texture laid out as follows: @@ -236,6 +271,10 @@ My tokenizer's vocabulary is 65,536 tokens. If we add up the lengths of every to So, the entire vocabulary - length+offset head and content - requires a 32-bit RGBA texture with 232,419 slots. We'll just jam this into a 512x512 texture, at an occupancy ratio of 88.66% (11.34% waste). The total VRAM usage of that lookup table (LUT) is 1 MiB. +![Unigram tokenizer texture](Images/unigram_lut_for_visualization.png) + +*A 64K vocabulary tokenizer I trained on Wikipedia and OpenSubtitles.* + We want to implement this API: ```c diff --git a/app_config.py b/app_config.py new file mode 100644 index 0000000..f911456 --- /dev/null +++ b/app_config.py @@ -0,0 +1,39 @@ +import os +import sys +import typing + +def getConfig(path: str) -> typing.Dict[str, typing.Union[str, float, int, bool]]: + # Helper functions to detect and convert the type + def is_int(value: str) -> bool: + try: + int(value) + return True + except ValueError: + return False + + def is_float(value: str) -> bool: + try: + float(value) + return True + except ValueError: + return False + + def convert_value(key: str, value: str): + if key.startswith(("enable_", "remove_", "use_", "clear_")): + return bool(int(value)) + elif is_int(value): + return int(value) + elif is_float(value): + return float(value) + else: + return value + + config = {} + with open(path, 'r') as file: + for line in file: + key_value = line.strip().split(": ", maxsplit=1) + key = key_value[0] + value = key_value[1] if len(key_value) > 1 else "" + config[key] = convert_value(key, value.strip()) + return config + diff --git a/config.yaml b/config.yaml new file mode 100644 index 
0000000..164b4e6 --- /dev/null +++ b/config.yaml @@ -0,0 +1,18 @@ +compute_type: int8 +enable_debug_mode: 0 +enable_previews: 1 +language: english +gpu_idx: 0 +max_speech_duration_s: 10 +min_silence_duration_ms: 250 +microphone: motu +model: turbo +reset_after_silence_s: 15 +transcription_loop_delay_ms: 100 +use_cpu: 0 + +block_width: 2 +num_blocks: 40 +rows: 10 +cols: 24 + diff --git a/hi.py b/hi.py index 7c68071..0129958 100644 --- a/hi.py +++ b/hi.py @@ -1,8 +1,11 @@ +import app_config import argparse from math import floor, ceil import msvcrt from pythonosc import udp_client import sentencepiece as spm +from shared_thread_data import SharedThreadData +import stt import sys import threading import time @@ -22,10 +25,7 @@ def get_tokenizer(): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--block-width", type=int, default=2, help="Number of elements sent in one block") - parser.add_argument("--num-blocks", type=int, default=40, help="Number of blocks in animator") - parser.add_argument("--rows", type=int, default=8, help="Number of rows in the chatbox") - parser.add_argument("--cols", type=int, default=20, help="Number of columns in the chatbox") + parser.add_argument("--config", type=str, help="Path to config file (YAML).", required=True) return parser.parse_args() def assert_equal(a, b): @@ -244,27 +244,48 @@ def getOscClient(ip = "127.0.0.1", port = 9000): class InputState: def __init__(self): + self.page = 0 # Initialize the known state of the board to empty array. This will cause # our paging logic to re-send everything the first time around. 
self.blocks = [] self.visual_pointers = [] pass -def handle_input(state: InputState, line: str, tokenizer, osc_client, args): - line_wrapped = wrap_line(line, args.cols) +def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg): + line_wrapped = wrap_line(line, cfg["cols"]) if TESTS_ENABLED: for line in line_wrapped: - assert_equal(len(line), args.cols) + assert_equal(len(line), cfg["cols"]) if LOG_LEVEL == 2: print(f"Wrapped lines: {line_wrapped}") + + # Get several blank lines whenever we roll over. + # It's better for the reader to have some continuity when the board pages + # over. If we simply replaced the entire screen, it would be harder to + # understand. + line_rollover = cfg["rows"] - 2 + blank_line = ' ' * cfg["cols"] + # We show a full page, then only `line_rollover` additional lines per page. + end_ptr = cfg["rows"] + which_page = 0 + while end_ptr < len(line_wrapped): + end_ptr += line_rollover + which_page += 1 + if state.page != which_page: + state.blocks = [] + state.visual_pointers = [] + state.page = which_page + line_wrapped = line_wrapped[end_ptr-cfg["rows"]:] + + # Get blocks and visual pointers. blocks, visual_pointers = get_blocks(line_wrapped, tokenizer, - args.block_width, args.num_blocks) - # Wrap visual pointers. This is mostly done just to stay within our - # limited 8-bit precision budget. The shader also has to wrap in order - # to properly display things. - visual_pointers = [ ptr % (args.rows * args.cols) for ptr in visual_pointers ] + cfg["block_width"], cfg["num_blocks"]) + + # Note that because we only send one page of data at a time, we don't have + # to worry about wrapping visual pointers! We will basically never run out + # of space. 
indices, diff_blocks, diff_visual_pointers = calc_diff(state.blocks, state.visual_pointers, blocks, visual_pointers) - indices = [idx % args.num_blocks for idx in indices] + indices = [idx % cfg["num_blocks"] for idx in indices] # Send only one block at a time to make things snappier in interactive use # case. # TODO use a continuation (yield) instead of returning. Then we can be a @@ -285,41 +306,46 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, args): send_data(osc_client, [indices[0]], [diff_blocks[0]], [diff_visual_pointers[0]]) -class SharedThreadData: - def __init__(self): - self.word = "" - self.word_lock = threading.Lock() - self.exit_event = threading.Event() - -def osc_thread(shared_data: SharedThreadData, cli_args): +def osc_thread(shared_data: SharedThreadData): tokenizer = get_tokenizer() osc_client = getOscClient() # Prime the board print("Priming the board") input_state = InputState() - handle_input(input_state, "", tokenizer, osc_client, cli_args) + handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg) while not shared_data.exit_event.is_set(): word_copy = "" with shared_data.word_lock: word_copy = shared_data.word - handle_input(input_state, word_copy, tokenizer, osc_client, cli_args) + handle_input(input_state, word_copy, tokenizer, osc_client, shared_data.cfg) time.sleep(0.01) if __name__ == "__main__": cli_args = parse_args() - - shared_data = SharedThreadData() + cfg = app_config.getConfig(cli_args.config) + shared_data = SharedThreadData(cfg) osc_thread = threading.Thread( target=osc_thread, - args=(shared_data, cli_args)) + args=(shared_data,)) osc_thread.start() + transcribe_thread = threading.Thread( + target=stt.transcriptionThread, + args=(shared_data,)) + transcribe_thread.start() + word_is_over = False local_word = "" while True: char_bytes = msvcrt.getch() + if char_bytes == b'\x03': # ctrl+C + break + + time.sleep(0.1) + continue + try: char = char_bytes.decode('utf-8') @@ -354,4 +380,5 @@ if 
__name__ == "__main__": print(local_word + "_") shared_data.exit_event.set() osc_thread.join() + transcribe_thread.join() diff --git a/requirements.txt b/requirements.txt index 104e7b8..1043fae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,9 @@ datasets +faster-whisper +langcodes pillow +pyaudio +pydub python-osc unidecode sentencepiece diff --git a/shared_thread_data.py b/shared_thread_data.py new file mode 100644 index 0000000..ba0a419 --- /dev/null +++ b/shared_thread_data.py @@ -0,0 +1,9 @@ +import threading + +class SharedThreadData: + def __init__(self, cfg): + self.word = "" + self.word_lock = threading.Lock() + self.exit_event = threading.Event() + self.cfg = cfg + diff --git a/stt.py b/stt.py new file mode 100644 index 0000000..34ef2e9 --- /dev/null +++ b/stt.py @@ -0,0 +1,581 @@ +from faster_whisper import WhisperModel +import langcodes +import numpy as np +import os +import pyaudio +from pydub import AudioSegment +from shared_thread_data import SharedThreadData +import sys +import time +import typing +import vad + +class AudioStream(): + FORMAT = pyaudio.paInt16 + # Size of each frame (audio sample), in bytes. If you change FORMAT, make + # sure this stays up to date! + FRAME_SZ = 2 + # Frames per second. + FPS = 16000 + CHANNELS = 1 + def __init__(self): + pass + + def getSamples(self) -> bytes: + raise NotImplementedError("getSamples is not implemented!") + +class MicStream(AudioStream): + CHUNK_SZ = 1024 + + def __init__(self, which_mic: str): + self.p = pyaudio.PyAudio() + self.stream = None + self.sample_rate = None + # Each time pyaudio gives us audio data, it's in the form of a chunk of + # samples. We keep these in a list to keep the audio callback as light + # as possible. Whenever downstream layers want data, we collapse the + # list into a single array of data (a bytes object). + self.chunks = [] + # If set, incoming frames are simply discarded. 
+ self.paused = False + + print(f"Finding mic {which_mic}", file=sys.stderr) + self.dumpMicDevices() + + got_match = False + device_index = -1 + if which_mic == "index": + target_str = "Digital Audio Interface" + elif which_mic == "focusrite": + target_str = "Focusrite" + elif which_mic == "motu": + target_str = "In 1-2 (MOTU M Series)" + elif which_mic == "beyond": + target_str = "Microphone (Beyond)" + else: + print(f"Mic {which_mic} requested, treating it as a numerical " + + "device ID", file=sys.stderr) + device_index = int(which_mic) + got_match = True + if not got_match: + info = self.p.get_host_api_info_by_index(0) + numdevices = info.get('deviceCount') + for i in range(0, numdevices): + if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: + device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') + if target_str in device_name: + print(f"Got matching mic: {device_name}", + file=sys.stderr) + device_index = i + got_match = True + break + if not got_match: + raise KeyError(f"Mic {which_mic} not found") + + info = self.p.get_device_info_by_host_api_device_index(0, device_index) + print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr) + self.sample_rate = int(info['defaultSampleRate']) + print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr) + + self.stream = self.p.open( + rate=self.sample_rate, + channels=AudioStream.CHANNELS, + format=AudioStream.FORMAT, + input=True, + frames_per_buffer=MicStream.CHUNK_SZ, + input_device_index=device_index, + stream_callback=self.onAudioFramesAvailable) + + self.stream.start_stream() + + AudioStream.__init__(self) + + def pause(self, state: bool = True): + self.paused = state + + def dumpMicDevices(self): + info = self.p.get_host_api_info_by_index(0) + numdevices = info.get('deviceCount') + + for i in range(0, numdevices): + if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: + device_name = 
self.p.get_device_info_by_host_api_device_index(0, i).get('name') + print("Input Device id ", i, " - ", device_name) + + def onAudioFramesAvailable(self, + frames, + frame_count, + time_info, + status_flags): + if self.paused: + # Don't literally pause, just start returning silence. This allows + # the `min_segment_age_s` check to work while paused. + n_frames = int(frame_count * AudioStream.FPS / + float(self.sample_rate)) + self.chunks.append(np.zeros(n_frames, + dtype=np.int16).tobytes()) + return (frames, pyaudio.paContinue) + + decimated = b'' + # In pyaudio, a `frame` is a single sample of audio data. + frame_len = AudioStream.FRAME_SZ + next_frame = 0.0 + # The mic probably has a higher sample rate than Whisper wants, so + # decrease the sample rate by dropping samples. Note that this + # algorithm only works if the mic's rate is higher than whisper's + # expected rate. + keep_every = float(self.sample_rate) / AudioStream.FPS + for i in range(frame_count): + if i >= next_frame: + decimated += frames[i*frame_len:(i+1)*frame_len] + next_frame += keep_every + self.chunks.append(decimated) + + return (frames, pyaudio.paContinue) + + # Get audio data and the corresponding timestamp. + def getSamples(self) -> bytes: + chunks = self.chunks + self.chunks = [] + result = b''.join(chunks) + return result + +class AudioCollector: + def __init__(self, stream: AudioStream): + self.stream = stream + self.frames = b'' + # Note: by design, this is the only spot where we anchor our timestamps + # against the real world. This is done to make it possible to profile + # test cases which read from disk (at much faster than real speed) in + # the same way that we profile real-time data. 
+ self.wall_ts = time.time() + + def getAudio(self) -> bytes: + frames = self.stream.getSamples() + if frames: + self.frames += frames + return self.frames + + def dropAudioPrefix(self, dur_s: float) -> bytes: + n_bytes = int(dur_s * AudioStream.FPS) * self.stream.FRAME_SZ + n_bytes = min(n_bytes, len(self.frames)) + cut_portion = self.frames[:n_bytes] + self.frames = self.frames[n_bytes:] + self.wall_ts += float(n_bytes / self.stream.FRAME_SZ) / self.stream.FPS + return cut_portion + + def dropAudioPrefixByFrames(self, dur_frames: int) -> bytes: + n_bytes = dur_frames * self.stream.FRAME_SZ + n_bytes = min(n_bytes, len(self.frames)) + cut_portion = self.frames[:n_bytes] + self.frames = self.frames[n_bytes:] + self.wall_ts += float(n_bytes / self.stream.FRAME_SZ) / self.stream.FPS + return cut_portion + + def keepLast(self, dur_s: float) -> bytes: + drop_len = max(0, self.duration() - dur_s) + return self.dropAudioPrefix(drop_len) + + def dropAudio(self): + self.wall_ts += self.duration() + cut_portion = self.frames + self.frames = b'' + return cut_portion + + def duration(self): + return len(self.frames) / (AudioStream.FPS * self.stream.FRAME_SZ) + + def begin(self): + return self.wall_ts + + def now(self): + return self.begin() + self.duration() + +class AudioCollectorFilter: + def __init__(self, parent: AudioCollector): + self.parent = parent + + def getAudio(self) -> bytes: + return self.parent.getAudio() + def dropAudioPrefix(self, dur_s: float): + return self.parent.dropAudioPrefix(dur_s) + def dropAudioPrefixByFrames(self, dur_frames: int): + return self.parent.dropAudioPrefixByFrames(dur_frames) + def keepLast(self, dur_s): + return self.parent.keepLast(dur_s) + def dropAudio(self): + return self.parent.dropAudio() + def duration(self): + return self.parent.duration() + def begin(self): + return self.parent.begin() + def now(self): + return self.parent.now() + +# Audio collector that enforces a minimum length on its audio data. 
+class LengthEnforcingAudioCollector(AudioCollectorFilter): + def __init__(self, parent: AudioCollector, min_duration_s: float): + AudioCollectorFilter.__init__(self, parent) + self.min_duration_s = min_duration_s + + def getAudio(self) -> bytes: + audio = self.parent.getAudio() + min_duration_frames = int(self.min_duration_s * AudioStream.FPS) + pad_len_frames = max(0, min_duration_frames - int(len(audio) / + AudioStream.FRAME_SZ)) + pad = np.zeros(pad_len_frames, dtype=np.int16).tobytes() + return pad + audio + +class NormalizingAudioCollector(AudioCollectorFilter): + def __init__(self, parent: AudioCollector): + AudioCollectorFilter.__init__(self, parent) + + def getAudio(self) -> bytes: + audio = self.parent.getAudio() + + audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ, + frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) + audio = audio.normalize() + + frames = np.array(audio.get_array_of_samples()) + frames = np.int16(frames).tobytes() + + return frames + +class CompressingAudioCollector(AudioCollectorFilter): + def __init__(self, parent: AudioCollector): + AudioCollectorFilter.__init__(self, parent) + + def getAudio(self) -> bytes: + audio = self.parent.getAudio() + + audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ, + frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) + # subtle compression has a slight positive effect on my benchmark + audio = audio.compress_dynamic_range(threshold=-10, ratio=2.0) + + frames = np.array(audio.get_array_of_samples()) + frames = np.int16(frames).tobytes() + + return frames + +class AudioSegmenter: + def __init__(self, + min_silence_ms=250, + max_speech_s=5): + self.vad_options = vad.VadOptions( + min_silence_duration_ms=min_silence_ms, + max_speech_duration_s=max_speech_s) + pass + + def segmentAudio(self, audio: bytes): + audio = np.frombuffer(audio, + dtype=np.int16).flatten().astype(np.float32) / 32768.0 + return vad.get_speech_timestamps(audio, vad_options=self.vad_options) + + 
# Returns the stable cutoff (if any) and whether there are any segments. + def getStableCutoff(self, audio: bytes) -> typing.Tuple[int, bool]: + min_delta_frames = int((self.vad_options.min_silence_duration_ms * + AudioStream.FPS) / 1000.0) + cutoff = None + + last_end = None + segments = self.segmentAudio(audio) + + for i in range(len(segments)): + s = segments[i] + #print(f"s: {s}") + #print(f"last_end: {last_end}") + + if last_end: + delta_frames = s['start'] - last_end + #print(f"delta frames: {delta_frames}") + if delta_frames > min_delta_frames: + cutoff = s['start'] + else: + last_end = s['end'] + + if i == len(segments) - 1: + now = int(len(audio) / AudioStream.FRAME_SZ) + #print(f"now: {now}") + #print(f"min d: {min_delta_frames}") + delta_frames = now - s['end'] + if delta_frames > min_delta_frames: + cutoff = now - int(min_delta_frames / 2) + + return (cutoff, len(segments) > 0) + +# A segment of transcribed audio. `start_ts` and `end_ts` are floating point +# number of seconds since the beginning of audio data. +class Segment: + def __init__(self, + transcript: str, + start_ts: float, + end_ts: float, + wall_ts: float, + avg_logprob: float, + no_speech_prob: float, + compression_ratio: float): + self.transcript = transcript + # start_ts, end_ts are timestamps in seconds relative to `wall_ts`. + self.start_ts = start_ts + self.end_ts = end_ts + # wall_ts is the time.time() at which the oldest audio sample leading + # to this transcript was collected. 
+ self.wall_ts = wall_ts + self.avg_logprob = avg_logprob + self.no_speech_prob = no_speech_prob + self.compression_ratio = compression_ratio + + def __str__(self): + ts = f"(ts: {self.start_ts}-{self.end_ts}) " + + wall_ts_start = datetime.utcfromtimestamp(self.start_ts + self.wall_ts).strftime('%H:%M:%S') + wall_ts_end = datetime.utcfromtimestamp(self.end_ts + self.wall_ts).strftime('%H:%M:%S') + wall_ts = f"(wall ts: {wall_ts_start}-{wall_ts_end}) " + + no_speech = f"(no_speech: {self.no_speech_prob}) " + avg_logprob = f"(avg_logprob: {self.avg_logprob}) " + return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob + +class Whisper: + def __init__(self, + collector: AudioCollector, + cfg: typing.Dict): + self.collector = collector + self.model = None + self.cfg = cfg + + abspath = os.path.abspath(__file__) + my_dir = os.path.dirname(abspath) + parent_dir = os.path.dirname(my_dir) + + model_str = cfg["model"] + model_root = os.path.join(parent_dir, "Models", + os.path.normpath(model_str)) + print(f"Model {cfg['model']} will be saved to {model_root}", + file=sys.stderr) + + model_device = "cuda" + if cfg["use_cpu"]: + model_device = "cpu" + + already_downloaded = os.path.exists(model_root) + + self.model = WhisperModel(model_str, + device = model_device, + device_index = cfg["gpu_idx"], + compute_type = cfg["compute_type"], + download_root = model_root, + local_files_only = already_downloaded) + + def transcribe(self, frames: bytes = None) -> typing.List[Segment]: + if frames is None: + frames = self.collector.getAudio() + # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on + # [-1, 1]. + audio = np.frombuffer(frames, + dtype=np.int16).flatten().astype(np.float32) / 32768.0 + + t0 = time.time() + segments, info = self.model.transcribe( + audio, + language = langcodes.find(self.cfg["language"]).language, + vad_filter = True, + temperature=0.0, + without_timestamps = False) + res = [] + for s in segments: + # Manual touchup. 
I see a decent number of hallucinations sneaking + # in with high `no_speech_prob` and modest `avg_logprob`. + if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5: + if self.cfg["enable_debug_mode"]: + print(f"Drop probable hallucination (case 1) " + + f"(text='{s.text}', " + + f"no_speech_prob={s.no_speech_prob}, " + + f"avg_logprob={s.avg_logprob})", file=sys.stderr) + continue + # Another touchup targeted at the vexatious "thanks for watching!" + # hallucination. This triggers a lot when listening to + # instrumental/electronic music. + if s.no_speech_prob > 0.15 and s.avg_logprob < -0.7: + if self.cfg["enable_debug_mode"]: + print(f"Drop probable hallucination (case 2) " + + f"(text='{s.text}', " + + f"no_speech_prob={s.no_speech_prob}, " + + f"avg_logprob={s.avg_logprob})", file=sys.stderr) + continue + if self.cfg["enable_debug_mode"]: + print(f"s get: {s}") + if s.avg_logprob < -1.0: + continue + if s.compression_ratio > 2.4: + continue + res.append(Segment(s.text, s.start, s.end, + self.collector.begin(), + s.avg_logprob, s.no_speech_prob, s.compression_ratio)) + t1 = time.time() + if self.cfg["enable_debug_mode"]: + print(f"Transcription latency (s): {t1 - t0}") + return res + +class TranscriptCommit: + def __init__(self, + delta: str, + preview: str, + latency_s: float = None, + thresh_at_commit: int = None, + audio: bytes = None, + duration_s: float = None, + start_ts: float = None): + self.delta = delta + self.preview = preview + self.latency_s = latency_s + self.thresh_at_commit = thresh_at_commit + self.audio = audio + # Time at which the commit is generated + self.ts = time.time() + # Time corresponding to the start of the segment + self.start_ts = start_ts + # The duration of the audio segment, in seconds. 
+ self.duration_s = duration_s + + +class VadCommitter: + def __init__(self, + cfg: typing.Dict, + collector: AudioCollector, + whisper: Whisper, + segmenter: AudioSegmenter): + self.cfg = cfg + self.collector = collector + self.whisper = whisper + self.segmenter = segmenter + + def getDelta(self) -> TranscriptCommit: + audio = self.collector.getAudio() + stable_cutoff, has_audio = self.segmenter.getStableCutoff(audio) + + delta = "" + commit_audio = None + latency_s = None + duration_s = self.collector.duration() + start_ts = self.collector.begin() + + if has_audio and stable_cutoff: + #print(f"stable cutoff get: {stable_cutoff}", file=sys.stderr) + latency_s = self.collector.now() - self.collector.begin() + duration_s = stable_cutoff / AudioStream.FPS + start_ts = self.collector.begin() + commit_audio = self.collector.dropAudioPrefixByFrames(stable_cutoff) + + segments = self.whisper.transcribe(commit_audio) + delta = ''.join(s.transcript for s in segments) + audio = self.collector.getAudio() + if self.cfg["enable_debug_mode"]: + for s in segments: + print(f"commit segment: {s}", file=sys.stderr) + print(f"delta get: {delta}", file=sys.stderr) + + if False: + ts = datetime.fromtimestamp(self.collector.now() - latency_s) + filename = str(ts.strftime('%Y_%m_%d__%H-%M-%S')) + ".wav" + saveAudio(commit_audio, filename) + + preview = "" + if self.cfg["enable_previews"] and has_audio: + segments = self.whisper.transcribe(audio) + preview = "".join(s.transcript for s in segments) + + if not has_audio: + #print("VAD detects no audio, skip transcription", file=sys.stderr) + self.collector.keepLast(1.0) + + return TranscriptCommit( + delta.strip(), + preview.strip(), + latency_s, + audio=audio, + duration_s=duration_s, + start_ts=start_ts) + +def transcriptionThread(shared_data: SharedThreadData): + last_stable_commit = None + + stream = MicStream(shared_data.cfg["microphone"]) + collector = AudioCollector(stream) + collector = NormalizingAudioCollector(collector) + 
collector = CompressingAudioCollector(collector) + whisper = Whisper(collector, shared_data.cfg) + segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"], + max_speech_s=shared_data.cfg["max_speech_duration_s"]) + committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter) + + transcript = "" + preview = "" + + while not shared_data.exit_event.is_set(): + time.sleep(shared_data.cfg["transcription_loop_delay_ms"] / 1000.0); + + op = None + + commit = committer.getDelta() + + if len(commit.delta) > 0 or len(commit.preview) > 0: + # Avoid re-sending text after long pauses. User controls the length + # of the pause in the UI. + if shared_data.cfg["reset_after_silence_s"] > 0: + silence_duration = 0 + if last_stable_commit: + last_commit_end_ts = \ + last_stable_commit.start_ts + \ + last_stable_commit.duration_s + silence_duration = commit.start_ts - last_commit_end_ts + if silence_duration > shared_data.cfg["reset_after_silence_s"]: + print(f"Resetting transcript after {silence_duration}-second " + "silence", file=sys.stderr) + transcript = "" + preview = "" + if commit.delta: + last_stable_commit = commit + + # Hard-cap displayed transcript length at 4k characters to prevent + # runaway memory use in UI. Keep the full transcript to avoid + # breaking OSC pager. 
+ transcript = transcript[-4096:] + def join_segments(a, b): + if len(a) > 0 and a[-1] != ' ': + return a + ' ' + b + else: + return a + b + transcript = join_segments(transcript, commit.delta) + preview = commit.preview + + try: + print(f"Transcript: {transcript}") + except UnicodeEncodeError: + print("Failed to encode transcript - discarding delta", + file=sys.stderr) + continue + try: + print(f"Preview: {preview}") + except UnicodeEncodeError: + print("Failed to encode preview - discarding", file=sys.stderr) + + with shared_data.word_lock: + shared_data.word = join_segments(transcript, preview) + + if shared_data.cfg["enable_debug_mode"]: + print(f"commit latency: {commit.latency_s}", file=sys.stderr) + print(f"commit thresh: {commit.thresh_at_commit}", + file=sys.stderr) + + if len(transcript) > 0 and \ + (not transcript.endswith(' ')) and \ + (not commit.delta.startswith(' ')): + commit.delta = ' ' + commit.delta + if len(commit.delta) > 0 and \ + (not commit.delta.endswith(' ')) and \ + (not commit.preview.startswith(' ')): + commit.preview = ' ' + commit.preview + diff --git a/vad.py b/vad.py new file mode 100644 index 0000000..10a72d3 --- /dev/null +++ b/vad.py @@ -0,0 +1,313 @@ +# MIT License +# +# Copyright (c) 2023 Guillaume Klein +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import bisect +import functools +import os +import warnings + +from typing import List, NamedTuple, Optional + +import numpy as np + + +# The code below is adapted from https://github.com/snakers4/silero-vad. +class VadOptions(NamedTuple): + """VAD options. + + Attributes: + threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, + probabilities ABOVE this value are considered as SPEECH. It is better to tune this + parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. + min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. + max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer + than max_speech_duration_s will be split at the timestamp of the last silence that + lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be + split aggressively just before max_speech_duration_s. + min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms + before separating it + window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. + WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate. + Values other than these may affect model performance!! 
+ speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side + """ + + threshold: float = 0.5 + min_speech_duration_ms: int = 250 + max_speech_duration_s: float = float("inf") + min_silence_duration_ms: int = 2000 + window_size_samples: int = 1024 + speech_pad_ms: int = 400 + + +def get_speech_timestamps( + audio: np.ndarray, + vad_options: Optional[VadOptions] = None, + **kwargs, +) -> List[dict]: + """This method is used for splitting long audios into speech chunks using silero VAD. + + Args: + audio: One dimensional float array. + vad_options: Options for VAD processing. + kwargs: VAD options passed as keyword arguments for backward compatibility. + + Returns: + List of dicts containing begin and end samples of each speech chunk. + """ + if vad_options is None: + vad_options = VadOptions(**kwargs) + + threshold = vad_options.threshold + min_speech_duration_ms = vad_options.min_speech_duration_ms + max_speech_duration_s = vad_options.max_speech_duration_s + min_silence_duration_ms = vad_options.min_silence_duration_ms + window_size_samples = vad_options.window_size_samples + speech_pad_ms = vad_options.speech_pad_ms + + if window_size_samples not in [512, 1024, 1536]: + warnings.warn( + "Unusual window_size_samples! 
Supported window_size_samples:\n" + " - [512, 1024, 1536] for 16000 sampling_rate" + ) + + sampling_rate = 16000 + min_speech_samples = sampling_rate * min_speech_duration_ms / 1000 + speech_pad_samples = sampling_rate * speech_pad_ms / 1000 + max_speech_samples = ( + sampling_rate * max_speech_duration_s + - window_size_samples + - 2 * speech_pad_samples + ) + min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 + min_silence_samples_at_max_speech = sampling_rate * 98 / 1000 + + audio_length_samples = len(audio) + + model = get_vad_model() + state = model.get_initial_state(batch_size=1) + + speech_probs = [] + for current_start_sample in range(0, audio_length_samples, window_size_samples): + chunk = audio[current_start_sample : current_start_sample + window_size_samples] + if len(chunk) < window_size_samples: + chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk)))) + speech_prob, state = model(chunk, state, sampling_rate) + speech_probs.append(speech_prob) + + triggered = False + speeches = [] + current_speech = {} + neg_threshold = threshold - 0.15 + + # to save potential segment end (and tolerate some silence) + temp_end = 0 + # to save potential segment limits in case of maximum segment size reached + prev_end = next_start = 0 + + for i, speech_prob in enumerate(speech_probs): + if (speech_prob >= threshold) and temp_end: + temp_end = 0 + if next_start < prev_end: + next_start = window_size_samples * i + + if (speech_prob >= threshold) and not triggered: + triggered = True + current_speech["start"] = window_size_samples * i + continue + + if ( + triggered + and (window_size_samples * i) - current_speech["start"] > max_speech_samples + ): + if prev_end: + current_speech["end"] = prev_end + speeches.append(current_speech) + current_speech = {} + # previously reached silence (< neg_thres) and is still not speech (< thres) + if next_start < prev_end: + triggered = False + else: + current_speech["start"] = next_start + prev_end = 
next_start = temp_end = 0 + else: + current_speech["end"] = window_size_samples * i + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if (speech_prob < neg_threshold) and triggered: + if not temp_end: + temp_end = window_size_samples * i + # condition to avoid cutting in very short silence + if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech: + prev_end = temp_end + if (window_size_samples * i) - temp_end < min_silence_samples: + continue + else: + current_speech["end"] = temp_end + if ( + current_speech["end"] - current_speech["start"] + ) > min_speech_samples: + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if ( + current_speech + and (audio_length_samples - current_speech["start"]) > min_speech_samples + ): + current_speech["end"] = audio_length_samples + speeches.append(current_speech) + + for i, speech in enumerate(speeches): + if i == 0: + speech["start"] = int(max(0, speech["start"] - speech_pad_samples)) + if i != len(speeches) - 1: + silence_duration = speeches[i + 1]["start"] - speech["end"] + if silence_duration < 2 * speech_pad_samples: + speech["end"] += int(silence_duration // 2) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - silence_duration // 2) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - speech_pad_samples) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + + return speeches + + +def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray: + """Collects and concatenates audio chunks.""" + if not chunks: + return np.array([], dtype=np.float32) + + return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks]) + + +class 
SpeechTimestampsMap: + """Helper class to restore original speech timestamps.""" + + def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2): + self.sampling_rate = sampling_rate + self.time_precision = time_precision + self.chunk_end_sample = [] + self.total_silence_before = [] + + previous_end = 0 + silent_samples = 0 + + for chunk in chunks: + silent_samples += chunk["start"] - previous_end + previous_end = chunk["end"] + + self.chunk_end_sample.append(chunk["end"] - silent_samples) + self.total_silence_before.append(silent_samples / sampling_rate) + + def get_original_time( + self, + time: float, + chunk_index: Optional[int] = None, + ) -> float: + if chunk_index is None: + chunk_index = self.get_chunk_index(time) + + total_silence_before = self.total_silence_before[chunk_index] + return round(total_silence_before + time, self.time_precision) + + def get_chunk_index(self, time: float) -> int: + sample = int(time * self.sampling_rate) + return min( + bisect.bisect(self.chunk_end_sample, sample), + len(self.chunk_end_sample) - 1, + ) + + +@functools.lru_cache +def get_vad_model(): + """Returns the VAD model instance.""" + abspath = os.path.abspath(__file__) + my_dir = os.path.dirname(abspath) + path = os.path.join(my_dir, "Models/silero_vad.onnx") + return SileroVADModel(path) + + +class SileroVADModel: + def __init__(self, path): + try: + import onnxruntime + except ImportError as e: + raise RuntimeError( + "Applying the VAD filter requires the onnxruntime package" + ) from e + + opts = onnxruntime.SessionOptions() + opts.inter_op_num_threads = 1 + opts.intra_op_num_threads = 1 + opts.log_severity_level = 4 + + self.session = onnxruntime.InferenceSession( + path, + providers=["CPUExecutionProvider"], + sess_options=opts, + ) + + def get_initial_state(self, batch_size: int): + h = np.zeros((2, batch_size, 64), dtype=np.float32) + c = np.zeros((2, batch_size, 64), dtype=np.float32) + return h, c + + def __call__(self, x, state, sr: 
int): + if len(x.shape) == 1: + x = np.expand_dims(x, 0) + if len(x.shape) > 2: + raise ValueError( + f"Too many dimensions for input audio chunk {len(x.shape)}" + ) + if sr / x.shape[1] > 31.25: + raise ValueError("Input audio chunk is too short") + + h, c = state + + ort_inputs = { + "input": x, + "h": h, + "c": c, + "sr": np.array(sr, dtype="int64"), + } + + out, h, c = self.session.run(None, ort_inputs) + state = (h, c) + + return out, state -- cgit v1.2.3 From 0ebc79354ace812731a5c9a0a670cecd1ea941d7 Mon Sep 17 00:00:00 2001 From: yum Date: Thu, 29 May 2025 15:03:06 -0700 Subject: Move core app logic into folder --- app/app_config.py | 39 ++++ app/hi.py | 384 ++++++++++++++++++++++++++++++ app/requirements.txt | 7 + app/shared_thread_data.py | 9 + app/stt.py | 581 ++++++++++++++++++++++++++++++++++++++++++++++ app/vad.py | 313 +++++++++++++++++++++++++ app_config.py | 39 ---- hi.py | 384 ------------------------------ requirements.txt | 10 - shared_thread_data.py | 9 - stt.py | 581 ---------------------------------------------- vad.py | 313 ------------------------- 12 files changed, 1333 insertions(+), 1336 deletions(-) create mode 100644 app/app_config.py create mode 100644 app/hi.py create mode 100644 app/requirements.txt create mode 100644 app/shared_thread_data.py create mode 100644 app/stt.py create mode 100644 app/vad.py delete mode 100644 app_config.py delete mode 100644 hi.py delete mode 100644 requirements.txt delete mode 100644 shared_thread_data.py delete mode 100644 stt.py delete mode 100644 vad.py diff --git a/app/app_config.py b/app/app_config.py new file mode 100644 index 0000000..f911456 --- /dev/null +++ b/app/app_config.py @@ -0,0 +1,39 @@ +import os +import sys +import typing + +def getConfig(path: str) -> typing.Dict[str, typing.Union[str, float, int, bool]]: + # Helper functions to detect and convert the type + def is_int(value: str) -> bool: + try: + int(value) + return True + except ValueError: + return False + + def is_float(value: 
str) -> bool: + try: + float(value) + return True + except ValueError: + return False + + def convert_value(key: str, value: str): + if key.startswith(("enable_", "remove_", "use_", "clear_")): + return bool(int(value)) + elif is_int(value): + return int(value) + elif is_float(value): + return float(value) + else: + return value + + config = {} + with open(path, 'r') as file: + for line in file: + key_value = line.strip().split(": ", maxsplit=1) + key = key_value[0] + value = key_value[1] if len(key_value) > 1 else "" + config[key] = convert_value(key, value.strip()) + return config + diff --git a/app/hi.py b/app/hi.py new file mode 100644 index 0000000..0129958 --- /dev/null +++ b/app/hi.py @@ -0,0 +1,384 @@ +import app_config +import argparse +from math import floor, ceil +import msvcrt +from pythonosc import udp_client +import sentencepiece as spm +from shared_thread_data import SharedThreadData +import stt +import sys +import threading +import time + +TESTS_ENABLED = True + +# 0 = quiet, 1 = verbose, 2 = very verbose +LOG_LEVEL = 0 + +def get_tokenizer(): + model_path = "./custom_unigram_tokenizer_65k/unigram.model" + print(f"Loading SentencePiece tokenizer from: {model_path}") + sp = spm.SentencePieceProcessor() + sp.load(model_path) + print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}") + return sp + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, help="Path to config file (YAML).", required=True) + return parser.parse_args() + +def assert_equal(a, b): + err_msg = f"{a} != {b}" + assert a == b, err_msg + +# Turn a whitespace-delimited string into a list of strings no longer than +# `cols`. +# Preferentially breaks strings at whitespace boundaries. Preserves whitespace +# between words, except if that whitespace comes between lines. Breaks words +# longer than `cols` with a hyphen. 
+def wrap_line(line: str, cols): + # First, split line into alternating chunks of words and whitespace. + def get_sequences(line): + is_space = False + sequences = [] + seq_start = 0 + seq_end = -1 + for i in range(0, len(line)): + if line[i].isspace(): + if is_space: + seq_end = i + continue + # We were looking at text, now we see whitespace. + seq = line[seq_start:seq_end+1] + if len(seq) > 0: + sequences.append(seq) + seq_start = i + seq_end = i + is_space = True + else: + if not is_space: + seq_end = i + continue + # We were looking at whitespace, now we see text. + seq = line[seq_start:seq_end+1] + if len(seq) > 0: + sequences.append(seq) + seq_start = i + seq_end = i + is_space = False + sequences.append(line[seq_start:seq_end+1]) + return sequences + if TESTS_ENABLED: + assert_equal(get_sequences("foo"), ["foo"]) + assert_equal(get_sequences("foo bar"), ["foo", " ", "bar"]) + assert_equal(get_sequences(" foo bar"), [" ", "foo", " ", "bar"]) + assert_equal(get_sequences(" foo bar"), [" ", "foo", " ", "bar"]) + + # Next, greedily construct lines out of those sequences. + # Whitespace gets treated specially. If it would push us over the limit, we + # end the line and drop the whitespace. + sequences = get_sequences(line) + def coalesce_sequences(sequences, cols): + cur_line = "" + lines = [] + for seq in sequences: + if len(cur_line) + len(seq) <= cols: + cur_line += seq + continue + if seq.isspace(): + lines.append(cur_line) + cur_line = "" + continue + if len(cur_line) > 0: + lines.append(cur_line) + # Edge case: text sequence is longer than a line. 
+ while len(seq) > cols: + seq_prefix = seq[0:cols-1] + "-" + seq = seq[cols-1:] + lines.append(seq_prefix) + cur_line = seq + if len(cur_line) > 0: + lines.append(cur_line) + return lines + if TESTS_ENABLED: + assert_equal(coalesce_sequences(get_sequences("foo bar"), 3), ["foo", "bar"]) + assert_equal(coalesce_sequences(get_sequences("foo bar"), 4), ["foo ", "bar"]) + assert_equal(coalesce_sequences(get_sequences("foo bar"), 4), ["foo", "bar"]) + assert_equal(coalesce_sequences(get_sequences("foobar"), 3), ["fo-", "ob-", "ar"]) + assert_equal(coalesce_sequences(get_sequences("f obar"), 3), ["f ", "ob-", "ar"]) + + lines = coalesce_sequences(sequences, cols) + + # Next, pad each line with whitespace. + def pad_lines(lines, cols): + for i in range(0, len(lines)): + lines[i] += ' ' * (cols - len(lines[i])) + return lines + if TESTS_ENABLED: + assert_equal(pad_lines(["foo", "ba"], 4), ["foo ", "ba "]) + assert_equal(pad_lines(["foo"], 2), ["foo"]) + + return pad_lines(lines, cols) + +def get_blocks(lines, tokenizer, block_width, num_blocks): + if LOG_LEVEL == 2: + print(f"Lines sent to tokenizer: {''.join(lines)}") + tokens = tokenizer.encode_as_ids(''.join(lines)) + if LOG_LEVEL == 2: + print(f"Tokens: {tokens}") + pieces = [] + for tok in tokens: + piece = tokenizer.id_to_piece(tok) + pieces.append(piece) + if LOG_LEVEL == 2: + print(f"Pieces: {pieces}") + + # Group tokens into blocks and pad with empty characters. + # Also get visual pointers - the location where each block will be rendered. + def get_blocks(): + blocks = [] + visual_pointer = 0 + visual_pointers = [] + for i in range(0, ceil(len(tokens) / block_width)): + visual_pointers.append(visual_pointer) + block = [] + for j in range(0, block_width): + if i*block_width + j >= len(tokens): + # Pad block with empty characters. 65535 is a special token. 
+ block += [65535] * (block_width - len(block)) + break + block.append(tokens[i*block_width+j]) + visual_pointer += len(pieces[i*block_width+j]) + blocks.append(block) + return (blocks, visual_pointers) + blocks, visual_pointers = get_blocks() + if LOG_LEVEL == 2: + print(f"Blocks: {blocks}") + print(f"Visual pointers: {visual_pointers}") + + # Set all blocks up to the next `num_blocks` boundary to blank tokens. + # This handles the edge case where a prior message wrote data there which + # is covering up our new data. + def pad_blocks(blocks, visual_pointers): + cur_num_blocks = len(blocks) + num_pad_blocks = num_blocks - cur_num_blocks + for i in range(0, num_pad_blocks): + blocks.append([65535] * block_width) + visual_pointers.append(255) + return blocks, visual_pointers + blocks, visual_pointers = pad_blocks(blocks, visual_pointers) + if LOG_LEVEL == 2: + print(f"Blocks (padded): {blocks}") + print(f"Visual pointers (padded): {visual_pointers}") + + return blocks, visual_pointers + +def calc_diff(prev_blocks, prev_visual_pointers, cur_blocks, + cur_visual_pointers): + diff_indices = [] + diff_blocks = [] + diff_visual_pointers = [] + + for i in range(0, len(cur_blocks)): + if i >= len(prev_blocks): + diff_blocks.append(cur_blocks[i]) + diff_visual_pointers.append(cur_visual_pointers[i]) + diff_indices.append(i) + continue + if prev_blocks[i] != cur_blocks[i] or prev_visual_pointers[i] != cur_visual_pointers[i]: + diff_blocks.append(cur_blocks[i]) + diff_visual_pointers.append(cur_visual_pointers[i]) + diff_indices.append(i) + + return diff_indices, diff_blocks, diff_visual_pointers + +def send_data(osc_client, indices, blocks, visual_pointers): + def split_blocks_by_byte(blocks): + blocks_byte00 = [] + blocks_byte01 = [] + for block in blocks: + block_byte00 = [] + block_byte01 = [] + for datum in block: + block_byte00.append((datum >> 0) & 0xFF) + block_byte01.append((datum >> 8) & 0xFF) + blocks_byte00.append(block_byte00) + blocks_byte01.append(block_byte01) 
+ return blocks_byte00, blocks_byte01 + + blocks_byte00, blocks_byte01 = split_blocks_by_byte(blocks) + if LOG_LEVEL == 2: + print(f"Blocks (byte 00): {blocks_byte00}") + print(f"Blocks (byte 01): {blocks_byte01}") + + def send_osc(osc_client, addr, data): + #print(f"Sending {data} to {addr}") + osc_client.send_message(addr, data) + + for i in range(0, len(blocks)): + lp_int = indices[i] + lp_param = "_Unigram_Letter_Grid_OSC_Pointer" + addr = "/avatar/parameters/" + lp_param + send_osc(osc_client, addr, lp_int) + + vp_float = (-127.5 + visual_pointers[i]) / 127.5 + vp_param = f"_Unigram_Letter_Grid_OSC_Visual_Pointer" + addr = "/avatar/parameters/" + vp_param + send_osc(osc_client, addr, vp_float) + if LOG_LEVEL == 2: + print(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}") + for j in range(0, len(blocks[i])): + byte00_float = (-127.5 + blocks_byte00[i][j]) / 127.5 + byte01_float = (-127.5 + blocks_byte01[i][j]) / 127.5 + byte00_param = f"_Unigram_Letter_Grid_OSC_Datum{j:02}_Byte00" + byte01_param = f"_Unigram_Letter_Grid_OSC_Datum{j:02}_Byte01" + addr = "/avatar/parameters/" + byte00_param + send_osc(osc_client, addr, byte00_float) + addr = "/avatar/parameters/" + byte01_param + send_osc(osc_client, addr, byte01_float) + time.sleep(0.34) + +def getOscClient(ip = "127.0.0.1", port = 9000): + return udp_client.SimpleUDPClient(ip, port) + +class InputState: + def __init__(self): + self.page = 0 + # Initialize the known state of the board to empty array. This will cause + # our paging logic to re-send everything the first time around. + self.blocks = [] + self.visual_pointers = [] + pass + +def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg): + line_wrapped = wrap_line(line, cfg["cols"]) + if TESTS_ENABLED: + for line in line_wrapped: + assert_equal(len(line), cfg["cols"]) + if LOG_LEVEL == 2: + print(f"Wrapped lines: {line_wrapped}") + + # Get several blank lines whenever we roll over. 
+ # It's better for the reader to have some continuity when the board pages + # over. If we simply replaced the entire screen, it would be harder to + # understand. + line_rollover = cfg["rows"] - 2 + blank_line = ' ' * cfg["cols"] + # We show a full page, then only `line_rollover` additional lines per page. + end_ptr = cfg["rows"] + which_page = 0 + while end_ptr < len(line_wrapped): + end_ptr += line_rollover + which_page += 1 + if state.page != which_page: + state.blocks = [] + state.visual_pointers = [] + state.page = which_page + line_wrapped = line_wrapped[end_ptr-cfg["rows"]:] + + # Get blocks and visual pointers. + blocks, visual_pointers = get_blocks(line_wrapped, tokenizer, + cfg["block_width"], cfg["num_blocks"]) + + # Note that because we only send one page of data at a time, we don't have + # to worry about wrapping visual pointers! We will basically never run out + # of space. + indices, diff_blocks, diff_visual_pointers = calc_diff(state.blocks, state.visual_pointers, blocks, visual_pointers) + indices = [idx % cfg["num_blocks"] for idx in indices] + # Send only one block at a time to make things snappier in interactive use + # case. + # TODO use a continuation (yield) instead of returning. Then we can be a + # little lighter on the cpu. Measurements show that this script is + # already very light but we're clearly wasting a lot of work by + # re-tokenizing the entire input every time we send a block. 
+ if len(indices) == 0: + return + if indices[0] == len(state.blocks): + state.blocks.append(diff_blocks[0]) + state.visual_pointers.append(diff_visual_pointers[0]) + elif indices[0] > len(state.blocks): + print(f"This should never happen!") + sys.exit(1) + else: + state.blocks[indices[0]] = diff_blocks[0] + state.visual_pointers[indices[0]] = diff_visual_pointers[0] + + send_data(osc_client, [indices[0]], [diff_blocks[0]], [diff_visual_pointers[0]]) + +def osc_thread(shared_data: SharedThreadData): + tokenizer = get_tokenizer() + osc_client = getOscClient() + + # Prime the board + print("Priming the board") + input_state = InputState() + handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg) + + while not shared_data.exit_event.is_set(): + word_copy = "" + with shared_data.word_lock: + word_copy = shared_data.word + handle_input(input_state, word_copy, tokenizer, osc_client, shared_data.cfg) + time.sleep(0.01) + +if __name__ == "__main__": + cli_args = parse_args() + cfg = app_config.getConfig(cli_args.config) + shared_data = SharedThreadData(cfg) + osc_thread = threading.Thread( + target=osc_thread, + args=(shared_data,)) + osc_thread.start() + + transcribe_thread = threading.Thread( + target=stt.transcriptionThread, + args=(shared_data,)) + transcribe_thread.start() + + word_is_over = False + local_word = "" + while True: + char_bytes = msvcrt.getch() + if char_bytes == b'\x03': # ctrl+C + break + + time.sleep(0.1) + continue + + + try: + char = char_bytes.decode('utf-8') + if char == '\r' or char == '\n': + word_is_over = True + continue + except UnicodeDecodeError: + print(f"Unsupported character: {char_bytes}") + if char_bytes == b'\x00' or char_bytes == b'\xe0': + special_char = msvcrt.getch() + continue + + if char_bytes == b'\x03': # ctrl+C + break + elif char_bytes == b'\x08': # backspace + with shared_data.word_lock: + shared_data.word = shared_data.word[:-1] + local_word = shared_data.word + elif char_bytes == b'\x0c': # ctrl+L + with 
shared_data.word_lock: + shared_data.word = "" + local_word = shared_data.word + elif word_is_over: + with shared_data.word_lock: + shared_data.word = char + local_word = shared_data.word + word_is_over = False + else: + with shared_data.word_lock: + shared_data.word += char + local_word = shared_data.word + print(local_word + "_") + shared_data.exit_event.set() + osc_thread.join() + transcribe_thread.join() + diff --git a/app/requirements.txt b/app/requirements.txt new file mode 100644 index 0000000..4e79312 --- /dev/null +++ b/app/requirements.txt @@ -0,0 +1,7 @@ +faster-whisper +langcodes +pyaudio +pydub +python-osc +sentencepiece + diff --git a/app/shared_thread_data.py b/app/shared_thread_data.py new file mode 100644 index 0000000..ba0a419 --- /dev/null +++ b/app/shared_thread_data.py @@ -0,0 +1,9 @@ +import threading + +class SharedThreadData: + def __init__(self, cfg): + self.word = "" + self.word_lock = threading.Lock() + self.exit_event = threading.Event() + self.cfg = cfg + diff --git a/app/stt.py b/app/stt.py new file mode 100644 index 0000000..34ef2e9 --- /dev/null +++ b/app/stt.py @@ -0,0 +1,581 @@ +from faster_whisper import WhisperModel +import langcodes +import numpy as np +import os +import pyaudio +from pydub import AudioSegment +from shared_thread_data import SharedThreadData +import sys +import time +import typing +import vad + +class AudioStream(): + FORMAT = pyaudio.paInt16 + # Size of each frame (audio sample), in bytes. If you change FORMAT, make + # sure this stays up to date! + FRAME_SZ = 2 + # Frames per second. + FPS = 16000 + CHANNELS = 1 + def __init__(self): + pass + + def getSamples(self) -> bytes: + raise NotImplementedError("getSamples is not implemented!") + +class MicStream(AudioStream): + CHUNK_SZ = 1024 + + def __init__(self, which_mic: str): + self.p = pyaudio.PyAudio() + self.stream = None + self.sample_rate = None + # Each time pyaudio gives us audio data, it's in the form of a chunk of + # samples. 
We keep these in a list to keep the audio callback as light + # as possible. Whenever downstream layers want data, we collapse the + # list into a single array of data (a bytes object). + self.chunks = [] + # If set, incoming frames are simply discarded. + self.paused = False + + print(f"Finding mic {which_mic}", file=sys.stderr) + self.dumpMicDevices() + + got_match = False + device_index = -1 + if which_mic == "index": + target_str = "Digital Audio Interface" + elif which_mic == "focusrite": + target_str = "Focusrite" + elif which_mic == "motu": + target_str = "In 1-2 (MOTU M Series)" + elif which_mic == "beyond": + target_str = "Microphone (Beyond)" + else: + print(f"Mic {which_mic} requested, treating it as a numerical " + + "device ID", file=sys.stderr) + device_index = int(which_mic) + got_match = True + if not got_match: + info = self.p.get_host_api_info_by_index(0) + numdevices = info.get('deviceCount') + for i in range(0, numdevices): + if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: + device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') + if target_str in device_name: + print(f"Got matching mic: {device_name}", + file=sys.stderr) + device_index = i + got_match = True + break + if not got_match: + raise KeyError(f"Mic {which_mic} not found") + + info = self.p.get_device_info_by_host_api_device_index(0, device_index) + print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr) + self.sample_rate = int(info['defaultSampleRate']) + print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr) + + self.stream = self.p.open( + rate=self.sample_rate, + channels=AudioStream.CHANNELS, + format=AudioStream.FORMAT, + input=True, + frames_per_buffer=MicStream.CHUNK_SZ, + input_device_index=device_index, + stream_callback=self.onAudioFramesAvailable) + + self.stream.start_stream() + + AudioStream.__init__(self) + + def pause(self, state: bool = True): + self.paused = state + + def 
dumpMicDevices(self): + info = self.p.get_host_api_info_by_index(0) + numdevices = info.get('deviceCount') + + for i in range(0, numdevices): + if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: + device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') + print("Input Device id ", i, " - ", device_name) + + def onAudioFramesAvailable(self, + frames, + frame_count, + time_info, + status_flags): + if self.paused: + # Don't literally pause, just start returning silence. This allows + # the `min_segment_age_s` check to work while paused. + n_frames = int(frame_count * AudioStream.FPS / + float(self.sample_rate)) + self.chunks.append(np.zeros(n_frames, + dtype=np.int16).tobytes()) + return (frames, pyaudio.paContinue) + + decimated = b'' + # In pyaudio, a `frame` is a single sample of audio data. + frame_len = AudioStream.FRAME_SZ + next_frame = 0.0 + # The mic probably has a higher sample rate than Whisper wants, so + # decrease the sample rate by dropping samples. Note that this + # algorithm only works if the mic's rate is higher than whisper's + # expected rate. + keep_every = float(self.sample_rate) / AudioStream.FPS + for i in range(frame_count): + if i >= next_frame: + decimated += frames[i*frame_len:(i+1)*frame_len] + next_frame += keep_every + self.chunks.append(decimated) + + return (frames, pyaudio.paContinue) + + # Get audio data and the corresponding timestamp. + def getSamples(self) -> bytes: + chunks = self.chunks + self.chunks = [] + result = b''.join(chunks) + return result + +class AudioCollector: + def __init__(self, stream: AudioStream): + self.stream = stream + self.frames = b'' + # Note: by design, this is the only spot where we anchor our timestamps + # against the real world. This is done to make it possible to profile + # test cases which read from disk (at much faster than real speed) in + # the same way that we profile real-time data. 
+ self.wall_ts = time.time() + + def getAudio(self) -> bytes: + frames = self.stream.getSamples() + if frames: + self.frames += frames + return self.frames + + def dropAudioPrefix(self, dur_s: float) -> bytes: + n_bytes = int(dur_s * AudioStream.FPS) * self.stream.FRAME_SZ + n_bytes = min(n_bytes, len(self.frames)) + cut_portion = self.frames[:n_bytes] + self.frames = self.frames[n_bytes:] + self.wall_ts += float(n_bytes / self.stream.FRAME_SZ) / self.stream.FPS + return cut_portion + + def dropAudioPrefixByFrames(self, dur_frames: int) -> bytes: + n_bytes = dur_frames * self.stream.FRAME_SZ + n_bytes = min(n_bytes, len(self.frames)) + cut_portion = self.frames[:n_bytes] + self.frames = self.frames[n_bytes:] + self.wall_ts += float(n_bytes / self.stream.FRAME_SZ) / self.stream.FPS + return cut_portion + + def keepLast(self, dur_s: float) -> bytes: + drop_len = max(0, self.duration() - dur_s) + return self.dropAudioPrefix(drop_len) + + def dropAudio(self): + self.wall_ts += self.duration() + cut_portion = self.frames + self.frames = b'' + return cut_portion + + def duration(self): + return len(self.frames) / (AudioStream.FPS * self.stream.FRAME_SZ) + + def begin(self): + return self.wall_ts + + def now(self): + return self.begin() + self.duration() + +class AudioCollectorFilter: + def __init__(self, parent: AudioCollector): + self.parent = parent + + def getAudio(self) -> bytes: + return self.parent.getAudio() + def dropAudioPrefix(self, dur_s: float): + return self.parent.dropAudioPrefix(dur_s) + def dropAudioPrefixByFrames(self, dur_frames: int): + return self.parent.dropAudioPrefixByFrames(dur_frames) + def keepLast(self, dur_s): + return self.parent.keepLast(dur_s) + def dropAudio(self): + return self.parent.dropAudio() + def duration(self): + return self.parent.duration() + def begin(self): + return self.parent.begin() + def now(self): + return self.parent.now() + +# Audio collector that enforces a minimum length on its audio data. 
+class LengthEnforcingAudioCollector(AudioCollectorFilter): + def __init__(self, parent: AudioCollector, min_duration_s: float): + AudioCollectorFilter.__init__(self, parent) + self.min_duration_s = min_duration_s + + def getAudio(self) -> bytes: + audio = self.parent.getAudio() + min_duration_frames = int(self.min_duration_s * AudioStream.FPS) + pad_len_frames = max(0, min_duration_frames - int(len(audio) / + AudioStream.FRAME_SZ)) + pad = np.zeros(pad_len_frames, dtype=np.int16).tobytes() + return pad + audio + +class NormalizingAudioCollector(AudioCollectorFilter): + def __init__(self, parent: AudioCollector): + AudioCollectorFilter.__init__(self, parent) + + def getAudio(self) -> bytes: + audio = self.parent.getAudio() + + audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ, + frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) + audio = audio.normalize() + + frames = np.array(audio.get_array_of_samples()) + frames = np.int16(frames).tobytes() + + return frames + +class CompressingAudioCollector(AudioCollectorFilter): + def __init__(self, parent: AudioCollector): + AudioCollectorFilter.__init__(self, parent) + + def getAudio(self) -> bytes: + audio = self.parent.getAudio() + + audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ, + frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) + # subtle compression has a slight positive effect on my benchmark + audio = audio.compress_dynamic_range(threshold=-10, ratio=2.0) + + frames = np.array(audio.get_array_of_samples()) + frames = np.int16(frames).tobytes() + + return frames + +class AudioSegmenter: + def __init__(self, + min_silence_ms=250, + max_speech_s=5): + self.vad_options = vad.VadOptions( + min_silence_duration_ms=min_silence_ms, + max_speech_duration_s=max_speech_s) + pass + + def segmentAudio(self, audio: bytes): + audio = np.frombuffer(audio, + dtype=np.int16).flatten().astype(np.float32) / 32768.0 + return vad.get_speech_timestamps(audio, vad_options=self.vad_options) + + 
# Returns the stable cutoff (if any) and whether there are any segments. + def getStableCutoff(self, audio: bytes) -> typing.Tuple[int, bool]: + min_delta_frames = int((self.vad_options.min_silence_duration_ms * + AudioStream.FPS) / 1000.0) + cutoff = None + + last_end = None + segments = self.segmentAudio(audio) + + for i in range(len(segments)): + s = segments[i] + #print(f"s: {s}") + #print(f"last_end: {last_end}") + + if last_end: + delta_frames = s['start'] - last_end + #print(f"delta frames: {delta_frames}") + if delta_frames > min_delta_frames: + cutoff = s['start'] + else: + last_end = s['end'] + + if i == len(segments) - 1: + now = int(len(audio) / AudioStream.FRAME_SZ) + #print(f"now: {now}") + #print(f"min d: {min_delta_frames}") + delta_frames = now - s['end'] + if delta_frames > min_delta_frames: + cutoff = now - int(min_delta_frames / 2) + + return (cutoff, len(segments) > 0) + +# A segment of transcribed audio. `start_ts` and `end_ts` are floating point +# number of seconds since the beginning of audio data. +class Segment: + def __init__(self, + transcript: str, + start_ts: float, + end_ts: float, + wall_ts: float, + avg_logprob: float, + no_speech_prob: float, + compression_ratio: float): + self.transcript = transcript + # start_ts, end_ts are timestamps in seconds relative to `wall_ts`. + self.start_ts = start_ts + self.end_ts = end_ts + # wall_ts is the time.time() at which the oldest audio sample leading + # to this transcript was collected. 
+ self.wall_ts = wall_ts + self.avg_logprob = avg_logprob + self.no_speech_prob = no_speech_prob + self.compression_ratio = compression_ratio + + def __str__(self): + ts = f"(ts: {self.start_ts}-{self.end_ts}) " + + wall_ts_start = datetime.utcfromtimestamp(self.start_ts + self.wall_ts).strftime('%H:%M:%S') + wall_ts_end = datetime.utcfromtimestamp(self.end_ts + self.wall_ts).strftime('%H:%M:%S') + wall_ts = f"(wall ts: {wall_ts_start}-{wall_ts_end}) " + + no_speech = f"(no_speech: {self.no_speech_prob}) " + avg_logprob = f"(avg_logprob: {self.avg_logprob}) " + return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob + +class Whisper: + def __init__(self, + collector: AudioCollector, + cfg: typing.Dict): + self.collector = collector + self.model = None + self.cfg = cfg + + abspath = os.path.abspath(__file__) + my_dir = os.path.dirname(abspath) + parent_dir = os.path.dirname(my_dir) + + model_str = cfg["model"] + model_root = os.path.join(parent_dir, "Models", + os.path.normpath(model_str)) + print(f"Model {cfg['model']} will be saved to {model_root}", + file=sys.stderr) + + model_device = "cuda" + if cfg["use_cpu"]: + model_device = "cpu" + + already_downloaded = os.path.exists(model_root) + + self.model = WhisperModel(model_str, + device = model_device, + device_index = cfg["gpu_idx"], + compute_type = cfg["compute_type"], + download_root = model_root, + local_files_only = already_downloaded) + + def transcribe(self, frames: bytes = None) -> typing.List[Segment]: + if frames is None: + frames = self.collector.getAudio() + # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on + # [-1, 1]. + audio = np.frombuffer(frames, + dtype=np.int16).flatten().astype(np.float32) / 32768.0 + + t0 = time.time() + segments, info = self.model.transcribe( + audio, + language = langcodes.find(self.cfg["language"]).language, + vad_filter = True, + temperature=0.0, + without_timestamps = False) + res = [] + for s in segments: + # Manual touchup. 
I see a decent number of hallucinations sneaking + # in with high `no_speech_prob` and modest `avg_logprob`. + if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5: + if self.cfg["enable_debug_mode"]: + print(f"Drop probable hallucination (case 1) " + + f"(text='{s.text}', " + + f"no_speech_prob={s.no_speech_prob}, " + + f"avg_logprob={s.avg_logprob})", file=sys.stderr) + continue + # Another touchup targeted at the vexatious "thanks for watching!" + # hallucination. This triggers a lot when listening to + # instrumental/electronic music. + if s.no_speech_prob > 0.15 and s.avg_logprob < -0.7: + if self.cfg["enable_debug_mode"]: + print(f"Drop probable hallucination (case 2) " + + f"(text='{s.text}', " + + f"no_speech_prob={s.no_speech_prob}, " + + f"avg_logprob={s.avg_logprob})", file=sys.stderr) + continue + if self.cfg["enable_debug_mode"]: + print(f"s get: {s}") + if s.avg_logprob < -1.0: + continue + if s.compression_ratio > 2.4: + continue + res.append(Segment(s.text, s.start, s.end, + self.collector.begin(), + s.avg_logprob, s.no_speech_prob, s.compression_ratio)) + t1 = time.time() + if self.cfg["enable_debug_mode"]: + print(f"Transcription latency (s): {t1 - t0}") + return res + +class TranscriptCommit: + def __init__(self, + delta: str, + preview: str, + latency_s: float = None, + thresh_at_commit: int = None, + audio: bytes = None, + duration_s: float = None, + start_ts: float = None): + self.delta = delta + self.preview = preview + self.latency_s = latency_s + self.thresh_at_commit = thresh_at_commit + self.audio = audio + # Time at which the commit is generated + self.ts = time.time() + # Time corresponding to the start of the segment + self.start_ts = start_ts + # The duration of the audio segment, in seconds. 
+ self.duration_s = duration_s + + +class VadCommitter: + def __init__(self, + cfg: typing.Dict, + collector: AudioCollector, + whisper: Whisper, + segmenter: AudioSegmenter): + self.cfg = cfg + self.collector = collector + self.whisper = whisper + self.segmenter = segmenter + + def getDelta(self) -> TranscriptCommit: + audio = self.collector.getAudio() + stable_cutoff, has_audio = self.segmenter.getStableCutoff(audio) + + delta = "" + commit_audio = None + latency_s = None + duration_s = self.collector.duration() + start_ts = self.collector.begin() + + if has_audio and stable_cutoff: + #print(f"stable cutoff get: {stable_cutoff}", file=sys.stderr) + latency_s = self.collector.now() - self.collector.begin() + duration_s = stable_cutoff / AudioStream.FPS + start_ts = self.collector.begin() + commit_audio = self.collector.dropAudioPrefixByFrames(stable_cutoff) + + segments = self.whisper.transcribe(commit_audio) + delta = ''.join(s.transcript for s in segments) + audio = self.collector.getAudio() + if self.cfg["enable_debug_mode"]: + for s in segments: + print(f"commit segment: {s}", file=sys.stderr) + print(f"delta get: {delta}", file=sys.stderr) + + if False: + ts = datetime.fromtimestamp(self.collector.now() - latency_s) + filename = str(ts.strftime('%Y_%m_%d__%H-%M-%S')) + ".wav" + saveAudio(commit_audio, filename) + + preview = "" + if self.cfg["enable_previews"] and has_audio: + segments = self.whisper.transcribe(audio) + preview = "".join(s.transcript for s in segments) + + if not has_audio: + #print("VAD detects no audio, skip transcription", file=sys.stderr) + self.collector.keepLast(1.0) + + return TranscriptCommit( + delta.strip(), + preview.strip(), + latency_s, + audio=audio, + duration_s=duration_s, + start_ts=start_ts) + +def transcriptionThread(shared_data: SharedThreadData): + last_stable_commit = None + + stream = MicStream(shared_data.cfg["microphone"]) + collector = AudioCollector(stream) + collector = NormalizingAudioCollector(collector) + 
collector = CompressingAudioCollector(collector) + whisper = Whisper(collector, shared_data.cfg) + segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"], + max_speech_s=shared_data.cfg["max_speech_duration_s"]) + committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter) + + transcript = "" + preview = "" + + while not shared_data.exit_event.is_set(): + time.sleep(shared_data.cfg["transcription_loop_delay_ms"] / 1000.0); + + op = None + + commit = committer.getDelta() + + if len(commit.delta) > 0 or len(commit.preview) > 0: + # Avoid re-sending text after long pauses. User controls the length + # of the pause in the UI. + if shared_data.cfg["reset_after_silence_s"] > 0: + silence_duration = 0 + if last_stable_commit: + last_commit_end_ts = \ + last_stable_commit.start_ts + \ + last_stable_commit.duration_s + silence_duration = commit.start_ts - last_commit_end_ts + if silence_duration > shared_data.cfg["reset_after_silence_s"]: + print(f"Resetting transcript after {silence_duration}-second " + "silence", file=sys.stderr) + transcript = "" + preview = "" + if commit.delta: + last_stable_commit = commit + + # Hard-cap displayed transcript length at 4k characters to prevent + # runaway memory use in UI. Keep the full transcript to avoid + # breaking OSC pager. 
+ transcript = transcript[-4096:] + def join_segments(a, b): + if len(a) > 0 and a[-1] != ' ': + return a + ' ' + b + else: + return a + b + transcript = join_segments(transcript, commit.delta) + preview = commit.preview + + try: + print(f"Transcript: {transcript}") + except UnicodeEncodeError: + print("Failed to encode transcript - discarding delta", + file=sys.stderr) + continue + try: + print(f"Preview: {preview}") + except UnicodeEncodeError: + print("Failed to encode preview - discarding", file=sys.stderr) + + with shared_data.word_lock: + shared_data.word = join_segments(transcript, preview) + + if shared_data.cfg["enable_debug_mode"]: + print(f"commit latency: {commit.latency_s}", file=sys.stderr) + print(f"commit thresh: {commit.thresh_at_commit}", + file=sys.stderr) + + if len(transcript) > 0 and \ + (not transcript.endswith(' ')) and \ + (not commit.delta.startswith(' ')): + commit.delta = ' ' + commit.delta + if len(commit.delta) > 0 and \ + (not commit.delta.endswith(' ')) and \ + (not commit.preview.startswith(' ')): + commit.preview = ' ' + commit.preview + diff --git a/app/vad.py b/app/vad.py new file mode 100644 index 0000000..10a72d3 --- /dev/null +++ b/app/vad.py @@ -0,0 +1,313 @@ +# MIT License +# +# Copyright (c) 2023 Guillaume Klein +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import bisect +import functools +import os +import warnings + +from typing import List, NamedTuple, Optional + +import numpy as np + + +# The code below is adapted from https://github.com/snakers4/silero-vad. +class VadOptions(NamedTuple): + """VAD options. + + Attributes: + threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, + probabilities ABOVE this value are considered as SPEECH. It is better to tune this + parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. + min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. + max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer + than max_speech_duration_s will be split at the timestamp of the last silence that + lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be + split aggressively just before max_speech_duration_s. + min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms + before separating it + window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. + WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate. + Values other than these may affect model performance!! 
+ speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side + """ + + threshold: float = 0.5 + min_speech_duration_ms: int = 250 + max_speech_duration_s: float = float("inf") + min_silence_duration_ms: int = 2000 + window_size_samples: int = 1024 + speech_pad_ms: int = 400 + + +def get_speech_timestamps( + audio: np.ndarray, + vad_options: Optional[VadOptions] = None, + **kwargs, +) -> List[dict]: + """This method is used for splitting long audios into speech chunks using silero VAD. + + Args: + audio: One dimensional float array. + vad_options: Options for VAD processing. + kwargs: VAD options passed as keyword arguments for backward compatibility. + + Returns: + List of dicts containing begin and end samples of each speech chunk. + """ + if vad_options is None: + vad_options = VadOptions(**kwargs) + + threshold = vad_options.threshold + min_speech_duration_ms = vad_options.min_speech_duration_ms + max_speech_duration_s = vad_options.max_speech_duration_s + min_silence_duration_ms = vad_options.min_silence_duration_ms + window_size_samples = vad_options.window_size_samples + speech_pad_ms = vad_options.speech_pad_ms + + if window_size_samples not in [512, 1024, 1536]: + warnings.warn( + "Unusual window_size_samples! 
Supported window_size_samples:\n" + " - [512, 1024, 1536] for 16000 sampling_rate" + ) + + sampling_rate = 16000 + min_speech_samples = sampling_rate * min_speech_duration_ms / 1000 + speech_pad_samples = sampling_rate * speech_pad_ms / 1000 + max_speech_samples = ( + sampling_rate * max_speech_duration_s + - window_size_samples + - 2 * speech_pad_samples + ) + min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 + min_silence_samples_at_max_speech = sampling_rate * 98 / 1000 + + audio_length_samples = len(audio) + + model = get_vad_model() + state = model.get_initial_state(batch_size=1) + + speech_probs = [] + for current_start_sample in range(0, audio_length_samples, window_size_samples): + chunk = audio[current_start_sample : current_start_sample + window_size_samples] + if len(chunk) < window_size_samples: + chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk)))) + speech_prob, state = model(chunk, state, sampling_rate) + speech_probs.append(speech_prob) + + triggered = False + speeches = [] + current_speech = {} + neg_threshold = threshold - 0.15 + + # to save potential segment end (and tolerate some silence) + temp_end = 0 + # to save potential segment limits in case of maximum segment size reached + prev_end = next_start = 0 + + for i, speech_prob in enumerate(speech_probs): + if (speech_prob >= threshold) and temp_end: + temp_end = 0 + if next_start < prev_end: + next_start = window_size_samples * i + + if (speech_prob >= threshold) and not triggered: + triggered = True + current_speech["start"] = window_size_samples * i + continue + + if ( + triggered + and (window_size_samples * i) - current_speech["start"] > max_speech_samples + ): + if prev_end: + current_speech["end"] = prev_end + speeches.append(current_speech) + current_speech = {} + # previously reached silence (< neg_thres) and is still not speech (< thres) + if next_start < prev_end: + triggered = False + else: + current_speech["start"] = next_start + prev_end = 
next_start = temp_end = 0 + else: + current_speech["end"] = window_size_samples * i + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if (speech_prob < neg_threshold) and triggered: + if not temp_end: + temp_end = window_size_samples * i + # condition to avoid cutting in very short silence + if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech: + prev_end = temp_end + if (window_size_samples * i) - temp_end < min_silence_samples: + continue + else: + current_speech["end"] = temp_end + if ( + current_speech["end"] - current_speech["start"] + ) > min_speech_samples: + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if ( + current_speech + and (audio_length_samples - current_speech["start"]) > min_speech_samples + ): + current_speech["end"] = audio_length_samples + speeches.append(current_speech) + + for i, speech in enumerate(speeches): + if i == 0: + speech["start"] = int(max(0, speech["start"] - speech_pad_samples)) + if i != len(speeches) - 1: + silence_duration = speeches[i + 1]["start"] - speech["end"] + if silence_duration < 2 * speech_pad_samples: + speech["end"] += int(silence_duration // 2) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - silence_duration // 2) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - speech_pad_samples) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + + return speeches + + +def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray: + """Collects and concatenates audio chunks.""" + if not chunks: + return np.array([], dtype=np.float32) + + return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks]) + + +class 
SpeechTimestampsMap: + """Helper class to restore original speech timestamps.""" + + def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2): + self.sampling_rate = sampling_rate + self.time_precision = time_precision + self.chunk_end_sample = [] + self.total_silence_before = [] + + previous_end = 0 + silent_samples = 0 + + for chunk in chunks: + silent_samples += chunk["start"] - previous_end + previous_end = chunk["end"] + + self.chunk_end_sample.append(chunk["end"] - silent_samples) + self.total_silence_before.append(silent_samples / sampling_rate) + + def get_original_time( + self, + time: float, + chunk_index: Optional[int] = None, + ) -> float: + if chunk_index is None: + chunk_index = self.get_chunk_index(time) + + total_silence_before = self.total_silence_before[chunk_index] + return round(total_silence_before + time, self.time_precision) + + def get_chunk_index(self, time: float) -> int: + sample = int(time * self.sampling_rate) + return min( + bisect.bisect(self.chunk_end_sample, sample), + len(self.chunk_end_sample) - 1, + ) + + +@functools.lru_cache +def get_vad_model(): + """Returns the VAD model instance.""" + abspath = os.path.abspath(__file__) + my_dir = os.path.dirname(abspath) + path = os.path.join(my_dir, "Models/silero_vad.onnx") + return SileroVADModel(path) + + +class SileroVADModel: + def __init__(self, path): + try: + import onnxruntime + except ImportError as e: + raise RuntimeError( + "Applying the VAD filter requires the onnxruntime package" + ) from e + + opts = onnxruntime.SessionOptions() + opts.inter_op_num_threads = 1 + opts.intra_op_num_threads = 1 + opts.log_severity_level = 4 + + self.session = onnxruntime.InferenceSession( + path, + providers=["CPUExecutionProvider"], + sess_options=opts, + ) + + def get_initial_state(self, batch_size: int): + h = np.zeros((2, batch_size, 64), dtype=np.float32) + c = np.zeros((2, batch_size, 64), dtype=np.float32) + return h, c + + def __call__(self, x, state, sr: 
int): + if len(x.shape) == 1: + x = np.expand_dims(x, 0) + if len(x.shape) > 2: + raise ValueError( + f"Too many dimensions for input audio chunk {len(x.shape)}" + ) + if sr / x.shape[1] > 31.25: + raise ValueError("Input audio chunk is too short") + + h, c = state + + ort_inputs = { + "input": x, + "h": h, + "c": c, + "sr": np.array(sr, dtype="int64"), + } + + out, h, c = self.session.run(None, ort_inputs) + state = (h, c) + + return out, state diff --git a/app_config.py b/app_config.py deleted file mode 100644 index f911456..0000000 --- a/app_config.py +++ /dev/null @@ -1,39 +0,0 @@ -import os -import sys -import typing - -def getConfig(path: str) -> typing.Dict[str, typing.Union[str, float, int, bool]]: - # Helper functions to detect and convert the type - def is_int(value: str) -> bool: - try: - int(value) - return True - except ValueError: - return False - - def is_float(value: str) -> bool: - try: - float(value) - return True - except ValueError: - return False - - def convert_value(key: str, value: str): - if key.startswith(("enable_", "remove_", "use_", "clear_")): - return bool(int(value)) - elif is_int(value): - return int(value) - elif is_float(value): - return float(value) - else: - return value - - config = {} - with open(path, 'r') as file: - for line in file: - key_value = line.strip().split(": ", maxsplit=1) - key = key_value[0] - value = key_value[1] if len(key_value) > 1 else "" - config[key] = convert_value(key, value.strip()) - return config - diff --git a/hi.py b/hi.py deleted file mode 100644 index 0129958..0000000 --- a/hi.py +++ /dev/null @@ -1,384 +0,0 @@ -import app_config -import argparse -from math import floor, ceil -import msvcrt -from pythonosc import udp_client -import sentencepiece as spm -from shared_thread_data import SharedThreadData -import stt -import sys -import threading -import time - -TESTS_ENABLED = True - -# 0 = quiet, 1 = verbose, 2 = very verbose -LOG_LEVEL = 0 - -def get_tokenizer(): - model_path = 
"./custom_unigram_tokenizer_65k/unigram.model" - print(f"Loading SentencePiece tokenizer from: {model_path}") - sp = spm.SentencePieceProcessor() - sp.load(model_path) - print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}") - return sp - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, help="Path to config file (YAML).", required=True) - return parser.parse_args() - -def assert_equal(a, b): - err_msg = f"{a} != {b}" - assert a == b, err_msg - -# Turn a whitespace-delimited string into a list of strings no longer than -# `cols`. -# Preferentially breaks strings at whitespace boundaries. Preserves whitespace -# between words, except if that whitespace comes between lines. Breaks words -# longer than `cols` with a hyphen. -def wrap_line(line: str, cols): - # First, split line into alternating chunks of words and whitespace. - def get_sequences(line): - is_space = False - sequences = [] - seq_start = 0 - seq_end = -1 - for i in range(0, len(line)): - if line[i].isspace(): - if is_space: - seq_end = i - continue - # We were looking at text, now we see whitespace. - seq = line[seq_start:seq_end+1] - if len(seq) > 0: - sequences.append(seq) - seq_start = i - seq_end = i - is_space = True - else: - if not is_space: - seq_end = i - continue - # We were looking at whitespace, now we see text. - seq = line[seq_start:seq_end+1] - if len(seq) > 0: - sequences.append(seq) - seq_start = i - seq_end = i - is_space = False - sequences.append(line[seq_start:seq_end+1]) - return sequences - if TESTS_ENABLED: - assert_equal(get_sequences("foo"), ["foo"]) - assert_equal(get_sequences("foo bar"), ["foo", " ", "bar"]) - assert_equal(get_sequences(" foo bar"), [" ", "foo", " ", "bar"]) - assert_equal(get_sequences(" foo bar"), [" ", "foo", " ", "bar"]) - - # Next, greedily construct lines out of those sequences. - # Whitespace gets treated specially. 
If it would push us over the limit, we - # end the line and drop the whitespace. - sequences = get_sequences(line) - def coalesce_sequences(sequences, cols): - cur_line = "" - lines = [] - for seq in sequences: - if len(cur_line) + len(seq) <= cols: - cur_line += seq - continue - if seq.isspace(): - lines.append(cur_line) - cur_line = "" - continue - if len(cur_line) > 0: - lines.append(cur_line) - # Edge case: text sequence is longer than a line. - while len(seq) > cols: - seq_prefix = seq[0:cols-1] + "-" - seq = seq[cols-1:] - lines.append(seq_prefix) - cur_line = seq - if len(cur_line) > 0: - lines.append(cur_line) - return lines - if TESTS_ENABLED: - assert_equal(coalesce_sequences(get_sequences("foo bar"), 3), ["foo", "bar"]) - assert_equal(coalesce_sequences(get_sequences("foo bar"), 4), ["foo ", "bar"]) - assert_equal(coalesce_sequences(get_sequences("foo bar"), 4), ["foo", "bar"]) - assert_equal(coalesce_sequences(get_sequences("foobar"), 3), ["fo-", "ob-", "ar"]) - assert_equal(coalesce_sequences(get_sequences("f obar"), 3), ["f ", "ob-", "ar"]) - - lines = coalesce_sequences(sequences, cols) - - # Next, pad each line with whitespace. - def pad_lines(lines, cols): - for i in range(0, len(lines)): - lines[i] += ' ' * (cols - len(lines[i])) - return lines - if TESTS_ENABLED: - assert_equal(pad_lines(["foo", "ba"], 4), ["foo ", "ba "]) - assert_equal(pad_lines(["foo"], 2), ["foo"]) - - return pad_lines(lines, cols) - -def get_blocks(lines, tokenizer, block_width, num_blocks): - if LOG_LEVEL == 2: - print(f"Lines sent to tokenizer: {''.join(lines)}") - tokens = tokenizer.encode_as_ids(''.join(lines)) - if LOG_LEVEL == 2: - print(f"Tokens: {tokens}") - pieces = [] - for tok in tokens: - piece = tokenizer.id_to_piece(tok) - pieces.append(piece) - if LOG_LEVEL == 2: - print(f"Pieces: {pieces}") - - # Group tokens into blocks and pad with empty characters. - # Also get visual pointers - the location where each block will be rendered. 
- def get_blocks(): - blocks = [] - visual_pointer = 0 - visual_pointers = [] - for i in range(0, ceil(len(tokens) / block_width)): - visual_pointers.append(visual_pointer) - block = [] - for j in range(0, block_width): - if i*block_width + j >= len(tokens): - # Pad block with empty characters. 65535 is a special token. - block += [65535] * (block_width - len(block)) - break - block.append(tokens[i*block_width+j]) - visual_pointer += len(pieces[i*block_width+j]) - blocks.append(block) - return (blocks, visual_pointers) - blocks, visual_pointers = get_blocks() - if LOG_LEVEL == 2: - print(f"Blocks: {blocks}") - print(f"Visual pointers: {visual_pointers}") - - # Set all blocks up to the next `num_blocks` boundary to blank tokens. - # This handles the edge case where a prior message wrote data there which - # is covering up our new data. - def pad_blocks(blocks, visual_pointers): - cur_num_blocks = len(blocks) - num_pad_blocks = num_blocks - cur_num_blocks - for i in range(0, num_pad_blocks): - blocks.append([65535] * block_width) - visual_pointers.append(255) - return blocks, visual_pointers - blocks, visual_pointers = pad_blocks(blocks, visual_pointers) - if LOG_LEVEL == 2: - print(f"Blocks (padded): {blocks}") - print(f"Visual pointers (padded): {visual_pointers}") - - return blocks, visual_pointers - -def calc_diff(prev_blocks, prev_visual_pointers, cur_blocks, - cur_visual_pointers): - diff_indices = [] - diff_blocks = [] - diff_visual_pointers = [] - - for i in range(0, len(cur_blocks)): - if i >= len(prev_blocks): - diff_blocks.append(cur_blocks[i]) - diff_visual_pointers.append(cur_visual_pointers[i]) - diff_indices.append(i) - continue - if prev_blocks[i] != cur_blocks[i] or prev_visual_pointers[i] != cur_visual_pointers[i]: - diff_blocks.append(cur_blocks[i]) - diff_visual_pointers.append(cur_visual_pointers[i]) - diff_indices.append(i) - - return diff_indices, diff_blocks, diff_visual_pointers - -def send_data(osc_client, indices, blocks, visual_pointers): 
- def split_blocks_by_byte(blocks): - blocks_byte00 = [] - blocks_byte01 = [] - for block in blocks: - block_byte00 = [] - block_byte01 = [] - for datum in block: - block_byte00.append((datum >> 0) & 0xFF) - block_byte01.append((datum >> 8) & 0xFF) - blocks_byte00.append(block_byte00) - blocks_byte01.append(block_byte01) - return blocks_byte00, blocks_byte01 - - blocks_byte00, blocks_byte01 = split_blocks_by_byte(blocks) - if LOG_LEVEL == 2: - print(f"Blocks (byte 00): {blocks_byte00}") - print(f"Blocks (byte 01): {blocks_byte01}") - - def send_osc(osc_client, addr, data): - #print(f"Sending {data} to {addr}") - osc_client.send_message(addr, data) - - for i in range(0, len(blocks)): - lp_int = indices[i] - lp_param = "_Unigram_Letter_Grid_OSC_Pointer" - addr = "/avatar/parameters/" + lp_param - send_osc(osc_client, addr, lp_int) - - vp_float = (-127.5 + visual_pointers[i]) / 127.5 - vp_param = f"_Unigram_Letter_Grid_OSC_Visual_Pointer" - addr = "/avatar/parameters/" + vp_param - send_osc(osc_client, addr, vp_float) - if LOG_LEVEL == 2: - print(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}") - for j in range(0, len(blocks[i])): - byte00_float = (-127.5 + blocks_byte00[i][j]) / 127.5 - byte01_float = (-127.5 + blocks_byte01[i][j]) / 127.5 - byte00_param = f"_Unigram_Letter_Grid_OSC_Datum{j:02}_Byte00" - byte01_param = f"_Unigram_Letter_Grid_OSC_Datum{j:02}_Byte01" - addr = "/avatar/parameters/" + byte00_param - send_osc(osc_client, addr, byte00_float) - addr = "/avatar/parameters/" + byte01_param - send_osc(osc_client, addr, byte01_float) - time.sleep(0.34) - -def getOscClient(ip = "127.0.0.1", port = 9000): - return udp_client.SimpleUDPClient(ip, port) - -class InputState: - def __init__(self): - self.page = 0 - # Initialize the known state of the board to empty array. This will cause - # our paging logic to re-send everything the first time around. 
- self.blocks = [] - self.visual_pointers = [] - pass - -def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg): - line_wrapped = wrap_line(line, cfg["cols"]) - if TESTS_ENABLED: - for line in line_wrapped: - assert_equal(len(line), cfg["cols"]) - if LOG_LEVEL == 2: - print(f"Wrapped lines: {line_wrapped}") - - # Get several blank lines whenever we roll over. - # It's better for the reader to have some continuity when the board pages - # over. If we simply replaced the entire screen, it would be harder to - # understand. - line_rollover = cfg["rows"] - 2 - blank_line = ' ' * cfg["cols"] - # We show a full page, then only `line_rollover` additional lines per page. - end_ptr = cfg["rows"] - which_page = 0 - while end_ptr < len(line_wrapped): - end_ptr += line_rollover - which_page += 1 - if state.page != which_page: - state.blocks = [] - state.visual_pointers = [] - state.page = which_page - line_wrapped = line_wrapped[end_ptr-cfg["rows"]:] - - # Get blocks and visual pointers. - blocks, visual_pointers = get_blocks(line_wrapped, tokenizer, - cfg["block_width"], cfg["num_blocks"]) - - # Note that because we only send one page of data at a time, we don't have - # to worry about wrapping visual pointers! We will basically never run out - # of space. - indices, diff_blocks, diff_visual_pointers = calc_diff(state.blocks, state.visual_pointers, blocks, visual_pointers) - indices = [idx % cfg["num_blocks"] for idx in indices] - # Send only one block at a time to make things snappier in interactive use - # case. - # TODO use a continuation (yield) instead of returning. Then we can be a - # little lighter on the cpu. Measurements show that this script is - # already very light but we're clearly wasting a lot of work by - # re-tokenizing the entire input every time we send a block. 
- if len(indices) == 0: - return - if indices[0] == len(state.blocks): - state.blocks.append(diff_blocks[0]) - state.visual_pointers.append(diff_visual_pointers[0]) - elif indices[0] > len(state.blocks): - print(f"This should never happen!") - sys.exit(1) - else: - state.blocks[indices[0]] = diff_blocks[0] - state.visual_pointers[indices[0]] = diff_visual_pointers[0] - - send_data(osc_client, [indices[0]], [diff_blocks[0]], [diff_visual_pointers[0]]) - -def osc_thread(shared_data: SharedThreadData): - tokenizer = get_tokenizer() - osc_client = getOscClient() - - # Prime the board - print("Priming the board") - input_state = InputState() - handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg) - - while not shared_data.exit_event.is_set(): - word_copy = "" - with shared_data.word_lock: - word_copy = shared_data.word - handle_input(input_state, word_copy, tokenizer, osc_client, shared_data.cfg) - time.sleep(0.01) - -if __name__ == "__main__": - cli_args = parse_args() - cfg = app_config.getConfig(cli_args.config) - shared_data = SharedThreadData(cfg) - osc_thread = threading.Thread( - target=osc_thread, - args=(shared_data,)) - osc_thread.start() - - transcribe_thread = threading.Thread( - target=stt.transcriptionThread, - args=(shared_data,)) - transcribe_thread.start() - - word_is_over = False - local_word = "" - while True: - char_bytes = msvcrt.getch() - if char_bytes == b'\x03': # ctrl+C - break - - time.sleep(0.1) - continue - - - try: - char = char_bytes.decode('utf-8') - if char == '\r' or char == '\n': - word_is_over = True - continue - except UnicodeDecodeError: - print(f"Unsupported character: {char_bytes}") - if char_bytes == b'\x00' or char_bytes == b'\xe0': - special_char = msvcrt.getch() - continue - - if char_bytes == b'\x03': # ctrl+C - break - elif char_bytes == b'\x08': # backspace - with shared_data.word_lock: - shared_data.word = shared_data.word[:-1] - local_word = shared_data.word - elif char_bytes == b'\x0c': # ctrl+L - with 
shared_data.word_lock: - shared_data.word = "" - local_word = shared_data.word - elif word_is_over: - with shared_data.word_lock: - shared_data.word = char - local_word = shared_data.word - word_is_over = False - else: - with shared_data.word_lock: - shared_data.word += char - local_word = shared_data.word - print(local_word + "_") - shared_data.exit_event.set() - osc_thread.join() - transcribe_thread.join() - diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 1043fae..0000000 --- a/requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -datasets -faster-whisper -langcodes -pillow -pyaudio -pydub -python-osc -unidecode -sentencepiece - diff --git a/shared_thread_data.py b/shared_thread_data.py deleted file mode 100644 index ba0a419..0000000 --- a/shared_thread_data.py +++ /dev/null @@ -1,9 +0,0 @@ -import threading - -class SharedThreadData: - def __init__(self, cfg): - self.word = "" - self.word_lock = threading.Lock() - self.exit_event = threading.Event() - self.cfg = cfg - diff --git a/stt.py b/stt.py deleted file mode 100644 index 34ef2e9..0000000 --- a/stt.py +++ /dev/null @@ -1,581 +0,0 @@ -from faster_whisper import WhisperModel -import langcodes -import numpy as np -import os -import pyaudio -from pydub import AudioSegment -from shared_thread_data import SharedThreadData -import sys -import time -import typing -import vad - -class AudioStream(): - FORMAT = pyaudio.paInt16 - # Size of each frame (audio sample), in bytes. If you change FORMAT, make - # sure this stays up to date! - FRAME_SZ = 2 - # Frames per second. - FPS = 16000 - CHANNELS = 1 - def __init__(self): - pass - - def getSamples(self) -> bytes: - raise NotImplementedError("getSamples is not implemented!") - -class MicStream(AudioStream): - CHUNK_SZ = 1024 - - def __init__(self, which_mic: str): - self.p = pyaudio.PyAudio() - self.stream = None - self.sample_rate = None - # Each time pyaudio gives us audio data, it's in the form of a chunk of - # samples. 
We keep these in a list to keep the audio callback as light - # as possible. Whenever downstream layers want data, we collapse the - # list into a single array of data (a bytes object). - self.chunks = [] - # If set, incoming frames are simply discarded. - self.paused = False - - print(f"Finding mic {which_mic}", file=sys.stderr) - self.dumpMicDevices() - - got_match = False - device_index = -1 - if which_mic == "index": - target_str = "Digital Audio Interface" - elif which_mic == "focusrite": - target_str = "Focusrite" - elif which_mic == "motu": - target_str = "In 1-2 (MOTU M Series)" - elif which_mic == "beyond": - target_str = "Microphone (Beyond)" - else: - print(f"Mic {which_mic} requested, treating it as a numerical " + - "device ID", file=sys.stderr) - device_index = int(which_mic) - got_match = True - if not got_match: - info = self.p.get_host_api_info_by_index(0) - numdevices = info.get('deviceCount') - for i in range(0, numdevices): - if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: - device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') - if target_str in device_name: - print(f"Got matching mic: {device_name}", - file=sys.stderr) - device_index = i - got_match = True - break - if not got_match: - raise KeyError(f"Mic {which_mic} not found") - - info = self.p.get_device_info_by_host_api_device_index(0, device_index) - print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr) - self.sample_rate = int(info['defaultSampleRate']) - print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr) - - self.stream = self.p.open( - rate=self.sample_rate, - channels=AudioStream.CHANNELS, - format=AudioStream.FORMAT, - input=True, - frames_per_buffer=MicStream.CHUNK_SZ, - input_device_index=device_index, - stream_callback=self.onAudioFramesAvailable) - - self.stream.start_stream() - - AudioStream.__init__(self) - - def pause(self, state: bool = True): - self.paused = state - - def 
dumpMicDevices(self): - info = self.p.get_host_api_info_by_index(0) - numdevices = info.get('deviceCount') - - for i in range(0, numdevices): - if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: - device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name') - print("Input Device id ", i, " - ", device_name) - - def onAudioFramesAvailable(self, - frames, - frame_count, - time_info, - status_flags): - if self.paused: - # Don't literally pause, just start returning silence. This allows - # the `min_segment_age_s` check to work while paused. - n_frames = int(frame_count * AudioStream.FPS / - float(self.sample_rate)) - self.chunks.append(np.zeros(n_frames, - dtype=np.int16).tobytes()) - return (frames, pyaudio.paContinue) - - decimated = b'' - # In pyaudio, a `frame` is a single sample of audio data. - frame_len = AudioStream.FRAME_SZ - next_frame = 0.0 - # The mic probably has a higher sample rate than Whisper wants, so - # decrease the sample rate by dropping samples. Note that this - # algorithm only works if the mic's rate is higher than whisper's - # expected rate. - keep_every = float(self.sample_rate) / AudioStream.FPS - for i in range(frame_count): - if i >= next_frame: - decimated += frames[i*frame_len:(i+1)*frame_len] - next_frame += keep_every - self.chunks.append(decimated) - - return (frames, pyaudio.paContinue) - - # Get audio data and the corresponding timestamp. - def getSamples(self) -> bytes: - chunks = self.chunks - self.chunks = [] - result = b''.join(chunks) - return result - -class AudioCollector: - def __init__(self, stream: AudioStream): - self.stream = stream - self.frames = b'' - # Note: by design, this is the only spot where we anchor our timestamps - # against the real world. This is done to make it possible to profile - # test cases which read from disk (at much faster than real speed) in - # the same way that we profile real-time data. 
- self.wall_ts = time.time() - - def getAudio(self) -> bytes: - frames = self.stream.getSamples() - if frames: - self.frames += frames - return self.frames - - def dropAudioPrefix(self, dur_s: float) -> bytes: - n_bytes = int(dur_s * AudioStream.FPS) * self.stream.FRAME_SZ - n_bytes = min(n_bytes, len(self.frames)) - cut_portion = self.frames[:n_bytes] - self.frames = self.frames[n_bytes:] - self.wall_ts += float(n_bytes / self.stream.FRAME_SZ) / self.stream.FPS - return cut_portion - - def dropAudioPrefixByFrames(self, dur_frames: int) -> bytes: - n_bytes = dur_frames * self.stream.FRAME_SZ - n_bytes = min(n_bytes, len(self.frames)) - cut_portion = self.frames[:n_bytes] - self.frames = self.frames[n_bytes:] - self.wall_ts += float(n_bytes / self.stream.FRAME_SZ) / self.stream.FPS - return cut_portion - - def keepLast(self, dur_s: float) -> bytes: - drop_len = max(0, self.duration() - dur_s) - return self.dropAudioPrefix(drop_len) - - def dropAudio(self): - self.wall_ts += self.duration() - cut_portion = self.frames - self.frames = b'' - return cut_portion - - def duration(self): - return len(self.frames) / (AudioStream.FPS * self.stream.FRAME_SZ) - - def begin(self): - return self.wall_ts - - def now(self): - return self.begin() + self.duration() - -class AudioCollectorFilter: - def __init__(self, parent: AudioCollector): - self.parent = parent - - def getAudio(self) -> bytes: - return self.parent.getAudio() - def dropAudioPrefix(self, dur_s: float): - return self.parent.dropAudioPrefix(dur_s) - def dropAudioPrefixByFrames(self, dur_frames: int): - return self.parent.dropAudioPrefixByFrames(dur_frames) - def keepLast(self, dur_s): - return self.parent.keepLast(dur_s) - def dropAudio(self): - return self.parent.dropAudio() - def duration(self): - return self.parent.duration() - def begin(self): - return self.parent.begin() - def now(self): - return self.parent.now() - -# Audio collector that enforces a minimum length on its audio data. 
-class LengthEnforcingAudioCollector(AudioCollectorFilter): - def __init__(self, parent: AudioCollector, min_duration_s: float): - AudioCollectorFilter.__init__(self, parent) - self.min_duration_s = min_duration_s - - def getAudio(self) -> bytes: - audio = self.parent.getAudio() - min_duration_frames = int(self.min_duration_s * AudioStream.FPS) - pad_len_frames = max(0, min_duration_frames - int(len(audio) / - AudioStream.FRAME_SZ)) - pad = np.zeros(pad_len_frames, dtype=np.int16).tobytes() - return pad + audio - -class NormalizingAudioCollector(AudioCollectorFilter): - def __init__(self, parent: AudioCollector): - AudioCollectorFilter.__init__(self, parent) - - def getAudio(self) -> bytes: - audio = self.parent.getAudio() - - audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ, - frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) - audio = audio.normalize() - - frames = np.array(audio.get_array_of_samples()) - frames = np.int16(frames).tobytes() - - return frames - -class CompressingAudioCollector(AudioCollectorFilter): - def __init__(self, parent: AudioCollector): - AudioCollectorFilter.__init__(self, parent) - - def getAudio(self) -> bytes: - audio = self.parent.getAudio() - - audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ, - frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) - # subtle compression has a slight positive effect on my benchmark - audio = audio.compress_dynamic_range(threshold=-10, ratio=2.0) - - frames = np.array(audio.get_array_of_samples()) - frames = np.int16(frames).tobytes() - - return frames - -class AudioSegmenter: - def __init__(self, - min_silence_ms=250, - max_speech_s=5): - self.vad_options = vad.VadOptions( - min_silence_duration_ms=min_silence_ms, - max_speech_duration_s=max_speech_s) - pass - - def segmentAudio(self, audio: bytes): - audio = np.frombuffer(audio, - dtype=np.int16).flatten().astype(np.float32) / 32768.0 - return vad.get_speech_timestamps(audio, vad_options=self.vad_options) - - 
# Returns the stable cutoff (if any) and whether there are any segments. - def getStableCutoff(self, audio: bytes) -> typing.Tuple[int, bool]: - min_delta_frames = int((self.vad_options.min_silence_duration_ms * - AudioStream.FPS) / 1000.0) - cutoff = None - - last_end = None - segments = self.segmentAudio(audio) - - for i in range(len(segments)): - s = segments[i] - #print(f"s: {s}") - #print(f"last_end: {last_end}") - - if last_end: - delta_frames = s['start'] - last_end - #print(f"delta frames: {delta_frames}") - if delta_frames > min_delta_frames: - cutoff = s['start'] - else: - last_end = s['end'] - - if i == len(segments) - 1: - now = int(len(audio) / AudioStream.FRAME_SZ) - #print(f"now: {now}") - #print(f"min d: {min_delta_frames}") - delta_frames = now - s['end'] - if delta_frames > min_delta_frames: - cutoff = now - int(min_delta_frames / 2) - - return (cutoff, len(segments) > 0) - -# A segment of transcribed audio. `start_ts` and `end_ts` are floating point -# number of seconds since the beginning of audio data. -class Segment: - def __init__(self, - transcript: str, - start_ts: float, - end_ts: float, - wall_ts: float, - avg_logprob: float, - no_speech_prob: float, - compression_ratio: float): - self.transcript = transcript - # start_ts, end_ts are timestamps in seconds relative to `wall_ts`. - self.start_ts = start_ts - self.end_ts = end_ts - # wall_ts is the time.time() at which the oldest audio sample leading - # to this transcript was collected. 
- self.wall_ts = wall_ts - self.avg_logprob = avg_logprob - self.no_speech_prob = no_speech_prob - self.compression_ratio = compression_ratio - - def __str__(self): - ts = f"(ts: {self.start_ts}-{self.end_ts}) " - - wall_ts_start = datetime.utcfromtimestamp(self.start_ts + self.wall_ts).strftime('%H:%M:%S') - wall_ts_end = datetime.utcfromtimestamp(self.end_ts + self.wall_ts).strftime('%H:%M:%S') - wall_ts = f"(wall ts: {wall_ts_start}-{wall_ts_end}) " - - no_speech = f"(no_speech: {self.no_speech_prob}) " - avg_logprob = f"(avg_logprob: {self.avg_logprob}) " - return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob - -class Whisper: - def __init__(self, - collector: AudioCollector, - cfg: typing.Dict): - self.collector = collector - self.model = None - self.cfg = cfg - - abspath = os.path.abspath(__file__) - my_dir = os.path.dirname(abspath) - parent_dir = os.path.dirname(my_dir) - - model_str = cfg["model"] - model_root = os.path.join(parent_dir, "Models", - os.path.normpath(model_str)) - print(f"Model {cfg['model']} will be saved to {model_root}", - file=sys.stderr) - - model_device = "cuda" - if cfg["use_cpu"]: - model_device = "cpu" - - already_downloaded = os.path.exists(model_root) - - self.model = WhisperModel(model_str, - device = model_device, - device_index = cfg["gpu_idx"], - compute_type = cfg["compute_type"], - download_root = model_root, - local_files_only = already_downloaded) - - def transcribe(self, frames: bytes = None) -> typing.List[Segment]: - if frames is None: - frames = self.collector.getAudio() - # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on - # [-1, 1]. - audio = np.frombuffer(frames, - dtype=np.int16).flatten().astype(np.float32) / 32768.0 - - t0 = time.time() - segments, info = self.model.transcribe( - audio, - language = langcodes.find(self.cfg["language"]).language, - vad_filter = True, - temperature=0.0, - without_timestamps = False) - res = [] - for s in segments: - # Manual touchup. 
I see a decent number of hallucinations sneaking - # in with high `no_speech_prob` and modest `avg_logprob`. - if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5: - if self.cfg["enable_debug_mode"]: - print(f"Drop probable hallucination (case 1) " + - f"(text='{s.text}', " + - f"no_speech_prob={s.no_speech_prob}, " + - f"avg_logprob={s.avg_logprob})", file=sys.stderr) - continue - # Another touchup targeted at the vexatious "thanks for watching!" - # hallucination. This triggers a lot when listening to - # instrumental/electronic music. - if s.no_speech_prob > 0.15 and s.avg_logprob < -0.7: - if self.cfg["enable_debug_mode"]: - print(f"Drop probable hallucination (case 2) " + - f"(text='{s.text}', " + - f"no_speech_prob={s.no_speech_prob}, " + - f"avg_logprob={s.avg_logprob})", file=sys.stderr) - continue - if self.cfg["enable_debug_mode"]: - print(f"s get: {s}") - if s.avg_logprob < -1.0: - continue - if s.compression_ratio > 2.4: - continue - res.append(Segment(s.text, s.start, s.end, - self.collector.begin(), - s.avg_logprob, s.no_speech_prob, s.compression_ratio)) - t1 = time.time() - if self.cfg["enable_debug_mode"]: - print(f"Transcription latency (s): {t1 - t0}") - return res - -class TranscriptCommit: - def __init__(self, - delta: str, - preview: str, - latency_s: float = None, - thresh_at_commit: int = None, - audio: bytes = None, - duration_s: float = None, - start_ts: float = None): - self.delta = delta - self.preview = preview - self.latency_s = latency_s - self.thresh_at_commit = thresh_at_commit - self.audio = audio - # Time at which the commit is generated - self.ts = time.time() - # Time corresponding to the start of the segment - self.start_ts = start_ts - # The duration of the audio segment, in seconds. 
- self.duration_s = duration_s - - -class VadCommitter: - def __init__(self, - cfg: typing.Dict, - collector: AudioCollector, - whisper: Whisper, - segmenter: AudioSegmenter): - self.cfg = cfg - self.collector = collector - self.whisper = whisper - self.segmenter = segmenter - - def getDelta(self) -> TranscriptCommit: - audio = self.collector.getAudio() - stable_cutoff, has_audio = self.segmenter.getStableCutoff(audio) - - delta = "" - commit_audio = None - latency_s = None - duration_s = self.collector.duration() - start_ts = self.collector.begin() - - if has_audio and stable_cutoff: - #print(f"stable cutoff get: {stable_cutoff}", file=sys.stderr) - latency_s = self.collector.now() - self.collector.begin() - duration_s = stable_cutoff / AudioStream.FPS - start_ts = self.collector.begin() - commit_audio = self.collector.dropAudioPrefixByFrames(stable_cutoff) - - segments = self.whisper.transcribe(commit_audio) - delta = ''.join(s.transcript for s in segments) - audio = self.collector.getAudio() - if self.cfg["enable_debug_mode"]: - for s in segments: - print(f"commit segment: {s}", file=sys.stderr) - print(f"delta get: {delta}", file=sys.stderr) - - if False: - ts = datetime.fromtimestamp(self.collector.now() - latency_s) - filename = str(ts.strftime('%Y_%m_%d__%H-%M-%S')) + ".wav" - saveAudio(commit_audio, filename) - - preview = "" - if self.cfg["enable_previews"] and has_audio: - segments = self.whisper.transcribe(audio) - preview = "".join(s.transcript for s in segments) - - if not has_audio: - #print("VAD detects no audio, skip transcription", file=sys.stderr) - self.collector.keepLast(1.0) - - return TranscriptCommit( - delta.strip(), - preview.strip(), - latency_s, - audio=audio, - duration_s=duration_s, - start_ts=start_ts) - -def transcriptionThread(shared_data: SharedThreadData): - last_stable_commit = None - - stream = MicStream(shared_data.cfg["microphone"]) - collector = AudioCollector(stream) - collector = NormalizingAudioCollector(collector) - 
collector = CompressingAudioCollector(collector) - whisper = Whisper(collector, shared_data.cfg) - segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"], - max_speech_s=shared_data.cfg["max_speech_duration_s"]) - committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter) - - transcript = "" - preview = "" - - while not shared_data.exit_event.is_set(): - time.sleep(shared_data.cfg["transcription_loop_delay_ms"] / 1000.0); - - op = None - - commit = committer.getDelta() - - if len(commit.delta) > 0 or len(commit.preview) > 0: - # Avoid re-sending text after long pauses. User controls the length - # of the pause in the UI. - if shared_data.cfg["reset_after_silence_s"] > 0: - silence_duration = 0 - if last_stable_commit: - last_commit_end_ts = \ - last_stable_commit.start_ts + \ - last_stable_commit.duration_s - silence_duration = commit.start_ts - last_commit_end_ts - if silence_duration > shared_data.cfg["reset_after_silence_s"]: - print(f"Resetting transcript after {silence_duration}-second " - "silence", file=sys.stderr) - transcript = "" - preview = "" - if commit.delta: - last_stable_commit = commit - - # Hard-cap displayed transcript length at 4k characters to prevent - # runaway memory use in UI. Keep the full transcript to avoid - # breaking OSC pager. 
- transcript = transcript[-4096:] - def join_segments(a, b): - if len(a) > 0 and a[-1] != ' ': - return a + ' ' + b - else: - return a + b - transcript = join_segments(transcript, commit.delta) - preview = commit.preview - - try: - print(f"Transcript: {transcript}") - except UnicodeEncodeError: - print("Failed to encode transcript - discarding delta", - file=sys.stderr) - continue - try: - print(f"Preview: {preview}") - except UnicodeEncodeError: - print("Failed to encode preview - discarding", file=sys.stderr) - - with shared_data.word_lock: - shared_data.word = join_segments(transcript, preview) - - if shared_data.cfg["enable_debug_mode"]: - print(f"commit latency: {commit.latency_s}", file=sys.stderr) - print(f"commit thresh: {commit.thresh_at_commit}", - file=sys.stderr) - - if len(transcript) > 0 and \ - (not transcript.endswith(' ')) and \ - (not commit.delta.startswith(' ')): - commit.delta = ' ' + commit.delta - if len(commit.delta) > 0 and \ - (not commit.delta.endswith(' ')) and \ - (not commit.preview.startswith(' ')): - commit.preview = ' ' + commit.preview - diff --git a/vad.py b/vad.py deleted file mode 100644 index 10a72d3..0000000 --- a/vad.py +++ /dev/null @@ -1,313 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Guillaume Klein -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import bisect -import functools -import os -import warnings - -from typing import List, NamedTuple, Optional - -import numpy as np - - -# The code below is adapted from https://github.com/snakers4/silero-vad. -class VadOptions(NamedTuple): - """VAD options. - - Attributes: - threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, - probabilities ABOVE this value are considered as SPEECH. It is better to tune this - parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. - min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. - max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer - than max_speech_duration_s will be split at the timestamp of the last silence that - lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be - split aggressively just before max_speech_duration_s. - min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms - before separating it - window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. - WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate. - Values other than these may affect model performance!! 
- speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side - """ - - threshold: float = 0.5 - min_speech_duration_ms: int = 250 - max_speech_duration_s: float = float("inf") - min_silence_duration_ms: int = 2000 - window_size_samples: int = 1024 - speech_pad_ms: int = 400 - - -def get_speech_timestamps( - audio: np.ndarray, - vad_options: Optional[VadOptions] = None, - **kwargs, -) -> List[dict]: - """This method is used for splitting long audios into speech chunks using silero VAD. - - Args: - audio: One dimensional float array. - vad_options: Options for VAD processing. - kwargs: VAD options passed as keyword arguments for backward compatibility. - - Returns: - List of dicts containing begin and end samples of each speech chunk. - """ - if vad_options is None: - vad_options = VadOptions(**kwargs) - - threshold = vad_options.threshold - min_speech_duration_ms = vad_options.min_speech_duration_ms - max_speech_duration_s = vad_options.max_speech_duration_s - min_silence_duration_ms = vad_options.min_silence_duration_ms - window_size_samples = vad_options.window_size_samples - speech_pad_ms = vad_options.speech_pad_ms - - if window_size_samples not in [512, 1024, 1536]: - warnings.warn( - "Unusual window_size_samples! 
Supported window_size_samples:\n" - " - [512, 1024, 1536] for 16000 sampling_rate" - ) - - sampling_rate = 16000 - min_speech_samples = sampling_rate * min_speech_duration_ms / 1000 - speech_pad_samples = sampling_rate * speech_pad_ms / 1000 - max_speech_samples = ( - sampling_rate * max_speech_duration_s - - window_size_samples - - 2 * speech_pad_samples - ) - min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 - min_silence_samples_at_max_speech = sampling_rate * 98 / 1000 - - audio_length_samples = len(audio) - - model = get_vad_model() - state = model.get_initial_state(batch_size=1) - - speech_probs = [] - for current_start_sample in range(0, audio_length_samples, window_size_samples): - chunk = audio[current_start_sample : current_start_sample + window_size_samples] - if len(chunk) < window_size_samples: - chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk)))) - speech_prob, state = model(chunk, state, sampling_rate) - speech_probs.append(speech_prob) - - triggered = False - speeches = [] - current_speech = {} - neg_threshold = threshold - 0.15 - - # to save potential segment end (and tolerate some silence) - temp_end = 0 - # to save potential segment limits in case of maximum segment size reached - prev_end = next_start = 0 - - for i, speech_prob in enumerate(speech_probs): - if (speech_prob >= threshold) and temp_end: - temp_end = 0 - if next_start < prev_end: - next_start = window_size_samples * i - - if (speech_prob >= threshold) and not triggered: - triggered = True - current_speech["start"] = window_size_samples * i - continue - - if ( - triggered - and (window_size_samples * i) - current_speech["start"] > max_speech_samples - ): - if prev_end: - current_speech["end"] = prev_end - speeches.append(current_speech) - current_speech = {} - # previously reached silence (< neg_thres) and is still not speech (< thres) - if next_start < prev_end: - triggered = False - else: - current_speech["start"] = next_start - prev_end = 
next_start = temp_end = 0 - else: - current_speech["end"] = window_size_samples * i - speeches.append(current_speech) - current_speech = {} - prev_end = next_start = temp_end = 0 - triggered = False - continue - - if (speech_prob < neg_threshold) and triggered: - if not temp_end: - temp_end = window_size_samples * i - # condition to avoid cutting in very short silence - if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech: - prev_end = temp_end - if (window_size_samples * i) - temp_end < min_silence_samples: - continue - else: - current_speech["end"] = temp_end - if ( - current_speech["end"] - current_speech["start"] - ) > min_speech_samples: - speeches.append(current_speech) - current_speech = {} - prev_end = next_start = temp_end = 0 - triggered = False - continue - - if ( - current_speech - and (audio_length_samples - current_speech["start"]) > min_speech_samples - ): - current_speech["end"] = audio_length_samples - speeches.append(current_speech) - - for i, speech in enumerate(speeches): - if i == 0: - speech["start"] = int(max(0, speech["start"] - speech_pad_samples)) - if i != len(speeches) - 1: - silence_duration = speeches[i + 1]["start"] - speech["end"] - if silence_duration < 2 * speech_pad_samples: - speech["end"] += int(silence_duration // 2) - speeches[i + 1]["start"] = int( - max(0, speeches[i + 1]["start"] - silence_duration // 2) - ) - else: - speech["end"] = int( - min(audio_length_samples, speech["end"] + speech_pad_samples) - ) - speeches[i + 1]["start"] = int( - max(0, speeches[i + 1]["start"] - speech_pad_samples) - ) - else: - speech["end"] = int( - min(audio_length_samples, speech["end"] + speech_pad_samples) - ) - - return speeches - - -def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray: - """Collects and concatenates audio chunks.""" - if not chunks: - return np.array([], dtype=np.float32) - - return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks]) - - -class 
SpeechTimestampsMap: - """Helper class to restore original speech timestamps.""" - - def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2): - self.sampling_rate = sampling_rate - self.time_precision = time_precision - self.chunk_end_sample = [] - self.total_silence_before = [] - - previous_end = 0 - silent_samples = 0 - - for chunk in chunks: - silent_samples += chunk["start"] - previous_end - previous_end = chunk["end"] - - self.chunk_end_sample.append(chunk["end"] - silent_samples) - self.total_silence_before.append(silent_samples / sampling_rate) - - def get_original_time( - self, - time: float, - chunk_index: Optional[int] = None, - ) -> float: - if chunk_index is None: - chunk_index = self.get_chunk_index(time) - - total_silence_before = self.total_silence_before[chunk_index] - return round(total_silence_before + time, self.time_precision) - - def get_chunk_index(self, time: float) -> int: - sample = int(time * self.sampling_rate) - return min( - bisect.bisect(self.chunk_end_sample, sample), - len(self.chunk_end_sample) - 1, - ) - - -@functools.lru_cache -def get_vad_model(): - """Returns the VAD model instance.""" - abspath = os.path.abspath(__file__) - my_dir = os.path.dirname(abspath) - path = os.path.join(my_dir, "Models/silero_vad.onnx") - return SileroVADModel(path) - - -class SileroVADModel: - def __init__(self, path): - try: - import onnxruntime - except ImportError as e: - raise RuntimeError( - "Applying the VAD filter requires the onnxruntime package" - ) from e - - opts = onnxruntime.SessionOptions() - opts.inter_op_num_threads = 1 - opts.intra_op_num_threads = 1 - opts.log_severity_level = 4 - - self.session = onnxruntime.InferenceSession( - path, - providers=["CPUExecutionProvider"], - sess_options=opts, - ) - - def get_initial_state(self, batch_size: int): - h = np.zeros((2, batch_size, 64), dtype=np.float32) - c = np.zeros((2, batch_size, 64), dtype=np.float32) - return h, c - - def __call__(self, x, state, sr: 
int): - if len(x.shape) == 1: - x = np.expand_dims(x, 0) - if len(x.shape) > 2: - raise ValueError( - f"Too many dimensions for input audio chunk {len(x.shape)}" - ) - if sr / x.shape[1] > 31.25: - raise ValueError("Input audio chunk is too short") - - h, c = state - - ort_inputs = { - "input": x, - "h": h, - "c": c, - "sr": np.array(sr, dtype="int64"), - } - - out, h, c = self.session.run(None, ort_inputs) - state = (h, c) - - return out, state -- cgit v1.2.3 From 1ede199387c072a85e8757a6aaec04d2c7cdeba4 Mon Sep 17 00:00:00 2001 From: yum Date: Thu, 29 May 2025 15:56:51 -0700 Subject: Add basic electron+tailwind hello world --- ui/.gitignore | 3 +++ ui/index.html | 20 ++++++++++++++++++++ ui/index.js | 29 +++++++++++++++++++++++++++++ ui/package.json | 24 ++++++++++++++++++++++++ ui/postcss.config.js | 6 ++++++ ui/preload.js | 7 +++++++ ui/src/input.css | 3 +++ ui/tailwind.config.js | 12 ++++++++++++ ui_design.md | 29 +++++++++++++++++++++++++++++ 9 files changed, 133 insertions(+) create mode 100644 ui/.gitignore create mode 100644 ui/index.html create mode 100644 ui/index.js create mode 100644 ui/package.json create mode 100644 ui/postcss.config.js create mode 100644 ui/preload.js create mode 100644 ui/src/input.css create mode 100644 ui/tailwind.config.js create mode 100644 ui_design.md diff --git a/ui/.gitignore b/ui/.gitignore new file mode 100644 index 0000000..2109e19 --- /dev/null +++ b/ui/.gitignore @@ -0,0 +1,3 @@ +build +node_modules +package-lock.json diff --git a/ui/index.html b/ui/index.html new file mode 100644 index 0000000..240e6ca --- /dev/null +++ b/ui/index.html @@ -0,0 +1,20 @@ + + + + + + Hello World! + + +
+

+ Hello World! +

+

+ Welcome to your Electron app with Tailwind CSS! +

+
+ + + + diff --git a/ui/index.js b/ui/index.js new file mode 100644 index 0000000..9751fb2 --- /dev/null +++ b/ui/index.js @@ -0,0 +1,29 @@ +const { app, BrowserWindow, ipcMain } = require('electron'); +const path = require('node:path'); + +function createWindow () { + const mainWindow = new BrowserWindow({ + width: 800, + height: 600, + webPreferences: { + preload: path.join(__dirname, 'preload.js'), + contextIsolation: true, + nodeIntegration: false + } + }); + + mainWindow.loadFile('index.html'); +} + +app.whenReady().then(() => { + createWindow(); + + app.on('activate', function () { + if (BrowserWindow.getAllWindows().length === 0) createWindow(); + }); +}); + +app.on('window-all-closed', function () { + if (process.platform !== 'darwin') app.quit(); +}); + diff --git a/ui/package.json b/ui/package.json new file mode 100644 index 0000000..1c56341 --- /dev/null +++ b/ui/package.json @@ -0,0 +1,24 @@ +{ + "name": "TaSTT", + "version": "1.0.0", + "description": "Speech-to-text tool for VRChat", + "main": "index.js", + "scripts": { + "start": "npm run build:css && electron .", + "build:css": "tailwindcss -i ./src/input.css -o ./build/output.css", + "watch:css": "tailwindcss -i ./src/input.css -o ./build/output.css --watch", + "dev": "npm run watch:css & electron .", + "test": "echo \"Error: no test specified\" && exit 1" + }, + "keywords": [], + "author": "yum_food", + "license": "MIT", + "devDependencies": { + "autoprefixer": "^10.4.21", + "concurrently": "^9.1.2", + "cross-env": "^7.0.3", + "electron": "^36.3.2", + "postcss": "^8.5.4", + "tailwindcss": "^3.4.17" + } +} diff --git a/ui/postcss.config.js b/ui/postcss.config.js new file mode 100644 index 0000000..33ad091 --- /dev/null +++ b/ui/postcss.config.js @@ -0,0 +1,6 @@ +module.exports = { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +} diff --git a/ui/preload.js b/ui/preload.js new file mode 100644 index 0000000..9f87d19 --- /dev/null +++ b/ui/preload.js @@ -0,0 +1,7 @@ +const { contextBridge, 
ipcRenderer } = require('electron'); + +contextBridge.exposeInMainWorld('electronAPI', { +}); + +console.log('Preload script loaded.'); + diff --git a/ui/src/input.css b/ui/src/input.css new file mode 100644 index 0000000..b5c61c9 --- /dev/null +++ b/ui/src/input.css @@ -0,0 +1,3 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; diff --git a/ui/tailwind.config.js b/ui/tailwind.config.js new file mode 100644 index 0000000..fa93053 --- /dev/null +++ b/ui/tailwind.config.js @@ -0,0 +1,12 @@ +/** @type {import('tailwindcss').Config} */ +module.exports = { + content: [ + "./index.html", + "./src/**/*.{html,js}", + ], + theme: { + extend: {}, + }, + plugins: [], +} + diff --git a/ui_design.md b/ui_design.md new file mode 100644 index 0000000..e38c632 --- /dev/null +++ b/ui_design.md @@ -0,0 +1,29 @@ +# TaSTT UI + +The TaSTT is built using electron and tailwind.css. + +First, install nodejs. Open PowerShell as administrator: + +```bash +# Delete any existing install. +$ choco uninstall nodejs -y +$ choco install nodejs-lts -y +``` + +Now open a non-admin PowerShell terminal: + +```bash +# Check your node and npm versions. +$ node -v +v22.16.0 +$ npm -v +10.9.2 +# Set up directory +$ mkdir ui +cd ui +npm init -y +npm install --save-dev electron +# Get tailwind and deps +npm install --save-dev tailwindcss@3 postcss autoprefixer concurrently cross-env +npx tailwindcss init -p +``` -- cgit v1.2.3 From 82a5b3805b2a54faea501ee362419330664c277a Mon Sep 17 00:00:00 2001 From: yum Date: Thu, 29 May 2025 17:23:09 -0700 Subject: Begin roughing out STT UI HEAVILY VIBE CODED! 
--- ui/index.html | 189 +++++++++++++++++++++++++++++++++++--- ui/index.js | 155 ++++++++++++++++++++++++++++++- ui/package.json | 14 ++- ui/preload.js | 6 ++ ui/renderer.js | 249 ++++++++++++++++++++++++++++++++++++++++++++++++++ ui/src/components.css | 110 ++++++++++++++++++++++ ui/tailwind.config.js | 5 +- 7 files changed, 708 insertions(+), 20 deletions(-) create mode 100644 ui/renderer.js create mode 100644 ui/src/components.css diff --git a/ui/index.html b/ui/index.html index 240e6ca..14cc354 100644 --- a/ui/index.html +++ b/ui/index.html @@ -3,18 +3,185 @@ - Hello World! - - -
-

- Hello World! -

-

- Welcome to your Electron app with Tailwind CSS! -

+ TaSTT + + + +
+

TaSTT

+ +
+ +
+
+ +
+
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + + + + + + +
+ +
+
+ + + +
+ + +
+
+

Python Output

+ +
+
+
+
+
+
- + + diff --git a/ui/index.js b/ui/index.js index 9751fb2..0a7fdf9 100644 --- a/ui/index.js +++ b/ui/index.js @@ -1,10 +1,71 @@ const { app, BrowserWindow, ipcMain } = require('electron'); const path = require('node:path'); +const fs = require('node:fs').promises; +const yaml = require('js-yaml'); +const { spawn } = require('child_process'); + +let mainWindow; + +// Helper function to get the correct Python executable from venv +function getVenvPython() { + const venvPath = path.join(__dirname, '..', 'venv'); + const isWindows = process.platform === 'win32'; + const pythonExecutable = isWindows ? 'python.exe' : 'python'; + const pythonPath = path.join(venvPath, isWindows ? 'Scripts' : 'bin', pythonExecutable); + return pythonPath; +} + +// Helper function to send Python output to renderer +function sendPythonOutput(message, type = 'stdout') { + if (mainWindow && !mainWindow.isDestroyed()) { + mainWindow.webContents.send('python-output', { message, type }); + } +} + +// Helper function to execute Python commands using venv +function executePythonCommand(args, options = {}) { + return new Promise((resolve, reject) => { + const pythonPath = getVenvPython(); + const commandStr = `${path.basename(pythonPath)} ${args.join(' ')}`; + sendPythonOutput(`> ${commandStr}`, 'info'); + + const pythonProcess = spawn(pythonPath, args, options); + + let stdout = ''; + let stderr = ''; + + pythonProcess.stdout.on('data', (data) => { + const text = data.toString(); + stdout += text; + sendPythonOutput(text.trimEnd(), 'stdout'); + }); + + pythonProcess.stderr.on('data', (data) => { + const text = data.toString(); + stderr += text; + sendPythonOutput(text.trimEnd(), 'stderr'); + }); + + pythonProcess.on('error', (error) => { + sendPythonOutput(`Failed to start Python process: ${error.message}`, 'stderr'); + reject({ error: error.message, stdout, stderr }); + }); + + pythonProcess.on('close', (code) => { + if (code !== 0) { + sendPythonOutput(`Process exited with code ${code}`, 
'stderr'); + reject({ code, stdout, stderr }); + } else { + resolve({ stdout, stderr }); + } + }); + }); +} function createWindow () { - const mainWindow = new BrowserWindow({ - width: 800, - height: 600, + mainWindow = new BrowserWindow({ + width: 1000, + height: 800, webPreferences: { preload: path.join(__dirname, 'preload.js'), contextIsolation: true, @@ -15,6 +76,94 @@ function createWindow () { mainWindow.loadFile('index.html'); } +// Path to config.yaml (one level up from ui directory) +const configPath = path.join(__dirname, '..', 'config.yaml'); + +// IPC handlers +ipcMain.handle('load-config', async () => { + try { + const fileContent = await fs.readFile(configPath, 'utf8'); + return yaml.load(fileContent); + } catch (error) { + console.error('Error loading config:', error); + throw error; + } +}); + +ipcMain.handle('save-config', async (event, config) => { + try { + const yamlContent = yaml.dump(config, { lineWidth: -1 }); + await fs.writeFile(configPath, yamlContent, 'utf8'); + return { success: true }; + } catch (error) { + console.error('Error saving config:', error); + throw error; + } +}); + +ipcMain.handle('restart-app', () => { + app.relaunch(); + app.exit(); +}); + +ipcMain.handle('install-requirements', async (event) => { + const requirementsPath = path.join(__dirname, '..', 'app', 'requirements.txt'); + + try { + // Check if requirements.txt exists + await fs.access(requirementsPath); + + const result = await executePythonCommand(['-m', 'pip', 'install', '-r', requirementsPath]); + + return { success: true, message: 'Requirements installed successfully' }; + } catch (error) { + console.error('Error installing requirements:', error); + if (error.code === 'ENOENT') { + throw new Error('requirements.txt not found'); + } + throw new Error(`Installation failed: ${error.stderr || error.error || 'Unknown error'}`); + } +}); + +ipcMain.handle('get-microphones', async () => { + const pythonScript = ` +import pyaudio +import json +import sys + +try: + p = 
pyaudio.PyAudio() + info = p.get_host_api_info_by_index(0) + numdevices = info.get('deviceCount') + + microphones = [] + for i in range(0, numdevices): + device_info = p.get_device_info_by_host_api_device_index(0, i) + if device_info.get('maxInputChannels') > 0: + microphones.append({ + 'index': i, + 'name': device_info.get('name'), + 'defaultSampleRate': device_info.get('defaultSampleRate') + }) + + print(json.dumps(microphones)) + p.terminate() +except Exception as e: + print(json.dumps({'error': str(e)}), file=sys.stderr) + sys.exit(1) +`; + + try { + const result = await executePythonCommand(['-c', pythonScript]); + const microphones = JSON.parse(result.stdout.trim()); + console.log('Successfully retrieved microphones:', microphones); + return microphones; + } catch (error) { + console.error('Failed to get microphones:', error); + throw new Error(`Failed to get microphones: ${error.stderr || error.error || 'Unknown error'}`); + } +}); + app.whenReady().then(() => { createWindow(); diff --git a/ui/package.json b/ui/package.json index 1c56341..fee2d67 100644 --- a/ui/package.json +++ b/ui/package.json @@ -5,20 +5,26 @@ "main": "index.js", "scripts": { "start": "npm run build:css && electron .", - "build:css": "tailwindcss -i ./src/input.css -o ./build/output.css", - "watch:css": "tailwindcss -i ./src/input.css -o ./build/output.css --watch", - "dev": "npm run watch:css & electron .", + "build:css": "tailwindcss -i ./src/components.css -o ./build/output.css", + "watch:css": "tailwindcss -i ./src/components.css -o ./build/output.css --watch", + "dev": "concurrently \"npm run watch:css\" \"electron .\"", "test": "echo \"Error: no test specified\" && exit 1" }, "keywords": [], "author": "yum_food", "license": "MIT", + "dependencies": { + "js-yaml": "^4.1.0" + }, "devDependencies": { + "@vitejs/plugin-vue": "^5.2.4", "autoprefixer": "^10.4.21", "concurrently": "^9.1.2", "cross-env": "^7.0.3", "electron": "^36.3.2", "postcss": "^8.5.4", - "tailwindcss": "^3.4.17" + 
"tailwindcss": "^3.4.17", + "vite": "^6.3.5", + "vue": "^3.5.16" } } diff --git a/ui/preload.js b/ui/preload.js index 9f87d19..108bffe 100644 --- a/ui/preload.js +++ b/ui/preload.js @@ -1,6 +1,12 @@ const { contextBridge, ipcRenderer } = require('electron'); contextBridge.exposeInMainWorld('electronAPI', { + loadConfig: () => ipcRenderer.invoke('load-config'), + saveConfig: (config) => ipcRenderer.invoke('save-config', config), + restartApp: () => ipcRenderer.invoke('restart-app'), + getMicrophones: () => ipcRenderer.invoke('get-microphones'), + installRequirements: () => ipcRenderer.invoke('install-requirements'), + onPythonOutput: (callback) => ipcRenderer.on('python-output', (event, data) => callback(data)) }); console.log('Preload script loaded.'); diff --git a/ui/renderer.js b/ui/renderer.js new file mode 100644 index 0000000..83c652c --- /dev/null +++ b/ui/renderer.js @@ -0,0 +1,249 @@ +// Handle status messages +function showStatus(message, type = 'info') { + const statusEl = document.getElementById('status-message'); + statusEl.textContent = message; + statusEl.classList.remove('hidden', 'bg-green-100', 'bg-red-100', 'bg-blue-100', 'text-green-800', 'text-red-800', 'text-blue-800'); + + if (type === 'success') { + statusEl.classList.add('bg-green-100', 'text-green-800'); + } else if (type === 'error') { + statusEl.classList.add('bg-red-100', 'text-red-800'); + } else { + statusEl.classList.add('bg-blue-100', 'text-blue-800'); + } + + // Also log to console + appendToConsole(message, type === 'error' ? 'stderr' : 'info'); + + setTimeout(() => { + statusEl.classList.add('hidden'); + }, 5000); +} + +// Get form values +function getFormValues() { + return { + compute_type: document.getElementById('compute_type').value, + enable_debug_mode: document.getElementById('enable_debug_mode').checked ? 1 : 0, + enable_previews: document.getElementById('enable_previews').checked ? 
1 : 0, + language: document.getElementById('language').value, + gpu_idx: parseInt(document.getElementById('gpu_idx').value), + max_speech_duration_s: parseInt(document.getElementById('max_speech_duration_s').value), + min_silence_duration_ms: parseInt(document.getElementById('min_silence_duration_ms').value), + microphone: document.getElementById('microphone').value, + model: document.getElementById('model').value, + reset_after_silence_s: parseInt(document.getElementById('reset_after_silence_s').value), + transcription_loop_delay_ms: parseInt(document.getElementById('transcription_loop_delay_ms').value), + use_cpu: document.getElementById('use_cpu').checked ? 1 : 0, + block_width: parseInt(document.getElementById('block_width').value), + num_blocks: parseInt(document.getElementById('num_blocks').value), + rows: parseInt(document.getElementById('rows').value), + cols: parseInt(document.getElementById('cols').value) + }; +} + +// Add a flag to prevent auto-save during programmatic updates +let isSettingValues = false; + +// Set form values +function setFormValues(config) { + isSettingValues = true; // Disable auto-save temporarily + + document.getElementById('compute_type').value = config.compute_type || 'int8'; + document.getElementById('enable_debug_mode').checked = config.enable_debug_mode === 1; + document.getElementById('enable_previews').checked = config.enable_previews === 1; + document.getElementById('language').value = config.language || 'english'; + document.getElementById('gpu_idx').value = config.gpu_idx || 0; + document.getElementById('max_speech_duration_s').value = config.max_speech_duration_s || 10; + document.getElementById('min_silence_duration_ms').value = config.min_silence_duration_ms || 250; + document.getElementById('microphone').value = config.microphone || 'motu'; + document.getElementById('model').value = config.model || 'turbo'; + document.getElementById('reset_after_silence_s').value = config.reset_after_silence_s || 15; + 
document.getElementById('transcription_loop_delay_ms').value = config.transcription_loop_delay_ms || 100; + document.getElementById('use_cpu').checked = config.use_cpu === 1; + document.getElementById('block_width').value = config.block_width || 2; + document.getElementById('num_blocks').value = config.num_blocks || 40; + document.getElementById('rows').value = config.rows || 10; + document.getElementById('cols').value = config.cols || 24; + + isSettingValues = false; // Re-enable auto-save +} + +// Toggle advanced settings +document.getElementById('toggle-advanced').addEventListener('click', () => { + const advancedSettings = document.getElementById('advanced-settings'); + const chevron = document.getElementById('chevron'); + + if (advancedSettings.classList.contains('hidden')) { + advancedSettings.classList.remove('hidden'); + chevron.classList.add('rotate-90'); + } else { + advancedSettings.classList.add('hidden'); + chevron.classList.remove('rotate-90'); + } +}); + +// Simplify button handlers by extracting common patterns +async function handleAsyncAction(actionName, actionFn) { + try { + const result = await actionFn(); + if (result && result.message) { + showStatus(result.message, 'success'); + } + return result; + } catch (error) { + showStatus(`${actionName} failed: ${error.message}`, 'error'); + throw error; + } +} + +// Auto-save functionality with debouncing +let saveTimeout; +const SAVE_DELAY = 500; // milliseconds + +async function autoSaveConfig() { + if (isSettingValues) return; // Don't save during programmatic updates + + clearTimeout(saveTimeout); + saveTimeout = setTimeout(async () => { + try { + const config = getFormValues(); + await window.electronAPI.saveConfig(config); + showStatus('Configuration saved', 'success'); + } catch (error) { + showStatus(`Failed to save configuration: ${error.message}`, 'error'); + } + }, SAVE_DELAY); +} + +// Add event listeners to all form inputs for auto-save +function setupAutoSave() { + // Get all form 
inputs + const form = document.getElementById('config-form'); + const inputs = form.querySelectorAll('input, select'); + + // Add change listener to each input + inputs.forEach(input => { + if (input.type === 'checkbox') { + input.addEventListener('change', autoSaveConfig); + } else if (input.type === 'number' || input.type === 'text') { + input.addEventListener('input', autoSaveConfig); + } else if (input.tagName === 'SELECT') { + input.addEventListener('change', autoSaveConfig); + } + }); +} + +// Update the setup-venv handler +document.getElementById('setup-venv').addEventListener('click', async () => { + const setupButton = document.getElementById('setup-venv'); + setupButton.disabled = true; + setupButton.classList.add('opacity-50', 'cursor-not-allowed'); + + try { + await handleAsyncAction('Install requirements', async () => { + return await window.electronAPI.installRequirements(); + }); + // Reload microphones after successful installation + await loadMicrophones(); + } finally { + setupButton.disabled = false; + setupButton.classList.remove('opacity-50', 'cursor-not-allowed'); + } +}); + +// Simplified microphone loading +async function loadMicrophones() { + const microphoneSelect = document.getElementById('microphone'); + + try { + appendToConsole('Loading available microphones...', 'info'); + const microphones = await window.electronAPI.getMicrophones(); + + microphoneSelect.innerHTML = ''; + + if (microphones.length === 0) { + microphoneSelect.innerHTML = ''; + appendToConsole('No microphones found', 'stderr'); + return; + } + + appendToConsole(`Found ${microphones.length} microphone(s)`, 'info'); + microphones.forEach(mic => { + const option = document.createElement('option'); + option.value = mic.index.toString(); + option.textContent = mic.name; + microphoneSelect.appendChild(option); + appendToConsole(` - ${mic.name} (Device ${mic.index})`, 'stdout'); + }); + + // Restore previously selected microphone if possible + try { + const config = await 
window.electronAPI.loadConfig(); + if (config.microphone) { + microphoneSelect.value = config.microphone; + } + } catch (error) { + // Ignore config load errors here + } + + } catch (error) { + appendToConsole(`Failed to load microphones: ${error.message}`, 'stderr'); + microphoneSelect.innerHTML = ''; + } +} + +// Update window load to include auto-save setup +window.addEventListener('load', async () => { + appendToConsole('TaSTT Configuration UI initialized', 'info'); + + // Load config first + try { + const config = await window.electronAPI.loadConfig(); + setFormValues(config); + appendToConsole('Configuration loaded', 'info'); + } catch (error) { + appendToConsole(`Failed to load configuration: ${error.message}`, 'stderr'); + } + + // Load microphones + await loadMicrophones(); + + // Set up auto-save after everything is loaded + setupAutoSave(); +}); + +// Console management +const consoleContent = document.getElementById('console-content'); + +function appendToConsole(message, type = 'stdout') { + const timestamp = new Date().toLocaleTimeString(); + const timestampSpan = document.createElement('span'); + timestampSpan.className = 'console-timestamp'; + timestampSpan.textContent = `[${timestamp}] `; + + const messageSpan = document.createElement('span'); + messageSpan.className = `console-${type}`; + messageSpan.textContent = message; + + const lineDiv = document.createElement('div'); + lineDiv.appendChild(timestampSpan); + lineDiv.appendChild(messageSpan); + + consoleContent.appendChild(lineDiv); + + // Auto-scroll to bottom + const pythonConsole = document.getElementById('python-console'); + pythonConsole.scrollTop = pythonConsole.scrollHeight; +} + +// Clear console button +document.getElementById('clear-console').addEventListener('click', () => { + consoleContent.innerHTML = ''; + appendToConsole('Console cleared', 'info'); +}); + +// Listen for Python output +window.electronAPI.onPythonOutput((data) => { + appendToConsole(data.message, data.type); +}); \ 
No newline at end of file diff --git a/ui/src/components.css b/ui/src/components.css new file mode 100644 index 0000000..be046ea --- /dev/null +++ b/ui/src/components.css @@ -0,0 +1,110 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; + +@layer components { + .config-section { + @apply bg-white rounded-lg shadow-md p-6; + } + + .section-title { + @apply text-xl font-semibold text-gray-700 mb-4; + } + + .form-label { + @apply block text-sm font-medium text-gray-700 mb-2; + } + + .form-input { + @apply w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm; + } + + .checkbox-label { + @apply flex items-center cursor-pointer hover:bg-gray-50 p-2 rounded; + } + + .checkbox-text { + @apply text-sm text-gray-700; + } + + .btn { + @apply px-4 py-2 font-medium text-sm rounded-md transition-colors focus:outline-none focus:ring-2 focus:ring-offset-2; + } + + .btn-blue { + @apply bg-blue-600 text-white hover:bg-blue-700 focus:ring-blue-500; + } + + .btn-green { + @apply bg-green-600 text-white hover:bg-green-700 focus:ring-green-500; + } + + .btn-gray { + @apply bg-gray-600 text-white hover:bg-gray-700 focus:ring-gray-500; + } +} + +/* Console styling */ +#python-console { + background-color: #1a1a1a; + font-family: 'Consolas', 'Monaco', 'Courier New', monospace; + line-height: 1.4; +} + +#console-content { + word-wrap: break-word; +} + +/* Console text colors */ +.console-stdout { + color: #a8cc8c; +} + +.console-stderr { + color: #e88388; +} + +.console-info { + color: #66c2cd; +} + +.console-timestamp { + color: #6c7986; + font-size: 0.875rem; +} + +/* Ensure full height layout */ +html, body { + height: 100%; + margin: 0; + padding: 0; +} + +.container-fluid { + max-width: 100%; + height: 100vh; +} + +/* Scrollbar styling for console */ +#python-console::-webkit-scrollbar { + width: 8px; +} + +#python-console::-webkit-scrollbar-track { + background: #2a2a2a; +} + 
+#python-console::-webkit-scrollbar-thumb { + background: #4a4a4a; + border-radius: 4px; +} + +#python-console::-webkit-scrollbar-thumb:hover { + background: #5a5a5a; +} + +/* Ensure buttons have proper disabled states */ +button:disabled { + cursor: not-allowed; + opacity: 0.5; +} diff --git a/ui/tailwind.config.js b/ui/tailwind.config.js index fa93053..804b7f0 100644 --- a/ui/tailwind.config.js +++ b/ui/tailwind.config.js @@ -1,8 +1,9 @@ /** @type {import('tailwindcss').Config} */ module.exports = { content: [ - "./index.html", - "./src/**/*.{html,js}", + "./*.html", + "./*.js", + "./src/**/*.{html,js}" ], theme: { extend: {}, -- cgit v1.2.3 From f97cef182de55b6dbae8d2bc0477acfca6cc1f66 Mon Sep 17 00:00:00 2001 From: yum Date: Thu, 29 May 2025 19:45:48 -0700 Subject: More UI work 1. main STT app works in new project structure 2. UI dumps mics on startup to populate mic list 3. add missing deps (hf-xet, wave) 4. normalize audio volume when transcribing. Probably still wrong tbqh. 5. add checkbox to save audio segments & improve logic so only segments with speech get saved. 6. 
add default config settings --- app/hi.py | 7 +- app/list_microphones.py | 24 ++++++ app/requirements.txt | 3 +- app/stt.py | 55 ++++++++++++-- app/vad.py | 3 +- config.yaml | 7 +- ui/index.html | 27 ++++--- ui/index.js | 196 ++++++++++++++++++++++++++++++++++++++---------- ui/preload.js | 5 +- ui/renderer.js | 87 ++++++++++++++++++++- ui/src/components.css | 4 + ui_design.md | 3 + 12 files changed, 355 insertions(+), 66 deletions(-) create mode 100644 app/list_microphones.py diff --git a/app/hi.py b/app/hi.py index 0129958..0d80b9d 100644 --- a/app/hi.py +++ b/app/hi.py @@ -2,6 +2,7 @@ import app_config import argparse from math import floor, ceil import msvcrt +import os from pythonosc import udp_client import sentencepiece as spm from shared_thread_data import SharedThreadData @@ -15,8 +16,11 @@ TESTS_ENABLED = True # 0 = quiet, 1 = verbose, 2 = very verbose LOG_LEVEL = 0 +APP_ROOT = os.path.dirname(os.path.abspath(__file__)) +PROJECT_ROOT = os.path.dirname(APP_ROOT) + def get_tokenizer(): - model_path = "./custom_unigram_tokenizer_65k/unigram.model" + model_path = os.path.join(PROJECT_ROOT, "custom_unigram_tokenizer_65k", "unigram.model") print(f"Loading SentencePiece tokenizer from: {model_path}") sp = spm.SentencePieceProcessor() sp.load(model_path) @@ -346,7 +350,6 @@ if __name__ == "__main__": time.sleep(0.1) continue - try: char = char_bytes.decode('utf-8') if char == '\r' or char == '\n': diff --git a/app/list_microphones.py b/app/list_microphones.py new file mode 100644 index 0000000..a6b1f36 --- /dev/null +++ b/app/list_microphones.py @@ -0,0 +1,24 @@ +import pyaudio +import json +import sys + +try: + p = pyaudio.PyAudio() + info = p.get_host_api_info_by_index(0) + numdevices = info.get('deviceCount') + + microphones = [] + for i in range(0, numdevices): + device_info = p.get_device_info_by_host_api_device_index(0, i) + if device_info.get('maxInputChannels') > 0: + microphones.append({ + 'index': i, + 'name': device_info.get('name'), + 
'defaultSampleRate': device_info.get('defaultSampleRate') + }) + + print(json.dumps(microphones)) + p.terminate() +except Exception as e: + print(json.dumps({'error': str(e)}), file=sys.stderr) + sys.exit(1) \ No newline at end of file diff --git a/app/requirements.txt b/app/requirements.txt index 4e79312..07f94cd 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -1,7 +1,8 @@ faster-whisper +hf-xet langcodes pyaudio pydub python-osc sentencepiece - +wave diff --git a/app/stt.py b/app/stt.py index 34ef2e9..c157f6d 100644 --- a/app/stt.py +++ b/app/stt.py @@ -1,3 +1,4 @@ +from datetime import datetime from faster_whisper import WhisperModel import langcodes import numpy as np @@ -9,6 +10,11 @@ import sys import time import typing import vad +import wave + + +APP_ROOT = os.path.dirname(os.path.abspath(__file__)) +PROJECT_ROOT = os.path.dirname(APP_ROOT) class AudioStream(): FORMAT = pyaudio.paInt16 @@ -242,6 +248,26 @@ class NormalizingAudioCollector(AudioCollectorFilter): return frames +class BoostingAudioCollector(AudioCollectorFilter): + def __init__(self, parent: AudioCollector, target_dBFS: float, cfg: typing.Dict): + AudioCollectorFilter.__init__(self, parent) + self.target_dBFS = target_dBFS + self.cfg = cfg + + def getAudio(self) -> bytes: + audio = self.parent.getAudio() + + audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ, + frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) + if self.cfg["enable_debug_mode"]: + print(f"Boosting audio from {audio.dBFS}dB to {self.target_dBFS}dB", file=sys.stderr) + audio = audio.apply_gain(self.target_dBFS - audio.dBFS) + + frames = np.array(audio.get_array_of_samples()) + frames = np.int16(frames).tobytes() + + return frames + class CompressingAudioCollector(AudioCollectorFilter): def __init__(self, parent: AudioCollector): AudioCollectorFilter.__init__(self, parent) @@ -441,6 +467,16 @@ class TranscriptCommit: self.duration_s = duration_s +def saveAudio(audio: bytes, path: str, cfg: 
typing.Dict): + with wave.open(path, 'wb') as wf: + if cfg["enable_debug_mode"]: + print(f"Saving audio to {path}", file=sys.stderr) + wf.setnchannels(AudioStream.CHANNELS) + wf.setsampwidth(AudioStream.FRAME_SZ) + wf.setframerate(AudioStream.FPS) + wf.writeframes(audio) + + class VadCommitter: def __init__(self, cfg: typing.Dict, @@ -463,7 +499,6 @@ class VadCommitter: start_ts = self.collector.begin() if has_audio and stable_cutoff: - #print(f"stable cutoff get: {stable_cutoff}", file=sys.stderr) latency_s = self.collector.now() - self.collector.begin() duration_s = stable_cutoff / AudioStream.FPS start_ts = self.collector.begin() @@ -475,12 +510,16 @@ class VadCommitter: if self.cfg["enable_debug_mode"]: for s in segments: print(f"commit segment: {s}", file=sys.stderr) - print(f"delta get: {delta}", file=sys.stderr) + if len(delta) > 0: + print(f"delta get: {delta}", file=sys.stderr) - if False: + if self.cfg["save_audio"] and len(delta) > 0: ts = datetime.fromtimestamp(self.collector.now() - latency_s) filename = str(ts.strftime('%Y_%m_%d__%H-%M-%S')) + ".wav" - saveAudio(commit_audio, filename) + audio_dir = os.path.join(PROJECT_ROOT, "audio") + if not os.path.exists(audio_dir): + os.makedirs(audio_dir) + saveAudio(commit_audio, os.path.join(audio_dir, filename), self.cfg) preview = "" if self.cfg["enable_previews"] and has_audio: @@ -488,7 +527,6 @@ class VadCommitter: preview = "".join(s.transcript for s in segments) if not has_audio: - #print("VAD detects no audio, skip transcription", file=sys.stderr) self.collector.keepLast(1.0) return TranscriptCommit( @@ -504,8 +542,9 @@ def transcriptionThread(shared_data: SharedThreadData): stream = MicStream(shared_data.cfg["microphone"]) collector = AudioCollector(stream) - collector = NormalizingAudioCollector(collector) collector = CompressingAudioCollector(collector) + collector = NormalizingAudioCollector(collector) + collector = BoostingAudioCollector(collector, 0.0, shared_data.cfg) whisper = 
Whisper(collector, shared_data.cfg) segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"], max_speech_s=shared_data.cfg["max_speech_duration_s"]) @@ -552,13 +591,13 @@ def transcriptionThread(shared_data: SharedThreadData): preview = commit.preview try: - print(f"Transcript: {transcript}") + print(f"Transcript: {transcript}", flush=True) except UnicodeEncodeError: print("Failed to encode transcript - discarding delta", file=sys.stderr) continue try: - print(f"Preview: {preview}") + print(f"Preview: {preview}", flush=True) except UnicodeEncodeError: print("Failed to encode preview - discarding", file=sys.stderr) diff --git a/app/vad.py b/app/vad.py index 10a72d3..1dea765 100644 --- a/app/vad.py +++ b/app/vad.py @@ -259,7 +259,8 @@ def get_vad_model(): """Returns the VAD model instance.""" abspath = os.path.abspath(__file__) my_dir = os.path.dirname(abspath) - path = os.path.join(my_dir, "Models/silero_vad.onnx") + parent_dir = os.path.dirname(my_dir) + path = os.path.join(parent_dir, "Models", "silero_vad.onnx") return SileroVADModel(path) diff --git a/config.yaml b/config.yaml index 164b4e6..34d88f1 100644 --- a/config.yaml +++ b/config.yaml @@ -1,18 +1,17 @@ -compute_type: int8 +compute_type: float16 enable_debug_mode: 0 enable_previews: 1 +save_audio: 0 language: english gpu_idx: 0 max_speech_duration_s: 10 min_silence_duration_ms: 250 -microphone: motu +microphone: 0 model: turbo reset_after_silence_s: 15 transcription_loop_delay_ms: 100 use_cpu: 0 - block_width: 2 num_blocks: 40 rows: 10 cols: 24 - diff --git a/ui/index.html b/ui/index.html index 14cc354..b06e56b 100644 --- a/ui/index.html +++ b/ui/index.html @@ -8,11 +8,9 @@
-

TaSTT

-
- -
+ +
@@ -127,6 +125,10 @@ Enable Previews +
@@ -156,9 +158,17 @@
- +
+ + + +
@@ -167,9 +177,8 @@
-
+
-

Python Output

diff --git a/ui/index.js b/ui/index.js index 0a7fdf9..a056156 100644 --- a/ui/index.js +++ b/ui/index.js @@ -4,14 +4,16 @@ const fs = require('node:fs').promises; const yaml = require('js-yaml'); const { spawn } = require('child_process'); +const APP_ROOT = path.join(__dirname, '..'); +const CONFIG_PATH = path.join(APP_ROOT, 'config.yaml'); + let mainWindow; +let runningProcess = null; // Track the running Python process // Helper function to get the correct Python executable from venv function getVenvPython() { - const venvPath = path.join(__dirname, '..', 'venv'); - const isWindows = process.platform === 'win32'; - const pythonExecutable = isWindows ? 'python.exe' : 'python'; - const pythonPath = path.join(venvPath, isWindows ? 'Scripts' : 'bin', pythonExecutable); + const venvPath = path.join(APP_ROOT, 'venv'); + const pythonPath = path.join(venvPath, 'Scripts', 'python.exe'); return pythonPath; } @@ -29,7 +31,17 @@ function executePythonCommand(args, options = {}) { const commandStr = `${path.basename(pythonPath)} ${args.join(' ')}`; sendPythonOutput(`> ${commandStr}`, 'info'); - const pythonProcess = spawn(pythonPath, args, options); + // Add dll directory to PATH for Windows DLL loading + const dllPath = path.join(APP_ROOT, 'dll'); + const env = { ...process.env }; + env.PATH = `${dllPath};${env.PATH}`; + + const spawnOptions = { + ...options, + env + }; + + const pythonProcess = spawn(pythonPath, args, spawnOptions); let stdout = ''; let stderr = ''; @@ -76,15 +88,47 @@ function createWindow () { mainWindow.loadFile('index.html'); } -// Path to config.yaml (one level up from ui directory) -const configPath = path.join(__dirname, '..', 'config.yaml'); +// Default configuration based on user's current config.yaml +const DEFAULT_CONFIG = { + compute_type: 'float16', + enable_debug_mode: 0, + enable_previews: 1, + save_audio: 0, + language: 'english', + gpu_idx: 0, + max_speech_duration_s: 10, + min_silence_duration_ms: 250, + microphone: 0, + model: 'turbo', + 
reset_after_silence_s: 15, + transcription_loop_delay_ms: 100, + use_cpu: 0, + block_width: 2, + num_blocks: 40, + rows: 10, + cols: 24 +}; // IPC handlers ipcMain.handle('load-config', async () => { try { - const fileContent = await fs.readFile(configPath, 'utf8'); + const fileContent = await fs.readFile(CONFIG_PATH, 'utf8'); return yaml.load(fileContent); } catch (error) { + if (error.code === 'ENOENT') { + // Config file doesn't exist, create it with defaults + console.log('Config file not found, creating with defaults...'); + try { + const yamlContent = yaml.dump(DEFAULT_CONFIG, { lineWidth: -1 }); + await fs.writeFile(CONFIG_PATH, yamlContent, 'utf8'); + console.log('Created config.yaml with default values'); + return DEFAULT_CONFIG; + } catch (writeError) { + console.error('Error creating default config:', writeError); + // Return defaults even if we can't write the file + return DEFAULT_CONFIG; + } + } console.error('Error loading config:', error); throw error; } @@ -93,7 +137,7 @@ ipcMain.handle('load-config', async () => { ipcMain.handle('save-config', async (event, config) => { try { const yamlContent = yaml.dump(config, { lineWidth: -1 }); - await fs.writeFile(configPath, yamlContent, 'utf8'); + await fs.writeFile(CONFIG_PATH, yamlContent, 'utf8'); return { success: true }; } catch (error) { console.error('Error saving config:', error); @@ -107,7 +151,7 @@ ipcMain.handle('restart-app', () => { }); ipcMain.handle('install-requirements', async (event) => { - const requirementsPath = path.join(__dirname, '..', 'app', 'requirements.txt'); + const requirementsPath = path.join(APP_ROOT, 'app', 'requirements.txt'); try { // Check if requirements.txt exists @@ -126,35 +170,10 @@ ipcMain.handle('install-requirements', async (event) => { }); ipcMain.handle('get-microphones', async () => { - const pythonScript = ` -import pyaudio -import json -import sys - -try: - p = pyaudio.PyAudio() - info = p.get_host_api_info_by_index(0) - numdevices = info.get('deviceCount') 
- - microphones = [] - for i in range(0, numdevices): - device_info = p.get_device_info_by_host_api_device_index(0, i) - if device_info.get('maxInputChannels') > 0: - microphones.append({ - 'index': i, - 'name': device_info.get('name'), - 'defaultSampleRate': device_info.get('defaultSampleRate') - }) - - print(json.dumps(microphones)) - p.terminate() -except Exception as e: - print(json.dumps({'error': str(e)}), file=sys.stderr) - sys.exit(1) -`; - + const scriptPath = path.join(APP_ROOT, 'app', 'list_microphones.py'); + try { - const result = await executePythonCommand(['-c', pythonScript]); + const result = await executePythonCommand([scriptPath]); const microphones = JSON.parse(result.stdout.trim()); console.log('Successfully retrieved microphones:', microphones); return microphones; @@ -164,6 +183,105 @@ except Exception as e: } }); +// Add handlers for starting and stopping the process +ipcMain.handle('start-process', async () => { + if (runningProcess) { + throw new Error('Process is already running'); + } + + const scriptPath = path.join(APP_ROOT, 'app', 'hi.py'); + const configPath = CONFIG_PATH; + + try { + const pythonPath = getVenvPython(); + const args = [scriptPath, '--config', configPath]; + + sendPythonOutput(`Starting process: ${path.basename(pythonPath)} ${args.join(' ')}`, 'info'); + + // Add dll directory to PATH for Windows DLL loading + const dllPath = path.join(APP_ROOT, 'dll'); + const env = { ...process.env }; + env.PATH = `${dllPath};${env.PATH}`; + + runningProcess = spawn(pythonPath, args, { env }); + + runningProcess.stdout.on('data', (data) => { + const text = data.toString(); + sendPythonOutput(text.trimEnd(), 'stdout'); + }); + + runningProcess.stderr.on('data', (data) => { + const text = data.toString(); + sendPythonOutput(text.trimEnd(), 'stderr'); + }); + + runningProcess.on('error', (error) => { + sendPythonOutput(`Process error: ${error.message}`, 'stderr'); + runningProcess = null; + if (mainWindow && !mainWindow.isDestroyed()) 
{ + mainWindow.webContents.send('process-stopped'); + } + }); + + runningProcess.on('close', (code) => { + sendPythonOutput(`Process exited with code ${code}`, 'info'); + runningProcess = null; + if (mainWindow && !mainWindow.isDestroyed()) { + mainWindow.webContents.send('process-stopped'); + } + }); + + return { success: true }; + } catch (error) { + runningProcess = null; + throw error; + } +}); + +ipcMain.handle('stop-process', async () => { + if (!runningProcess) { + throw new Error('No process is running'); + } + + return new Promise((resolve, reject) => { + let forcefullyKilled = false; + + // Set up a timeout to force kill after 10 seconds + const killTimeout = setTimeout(() => { + if (runningProcess) { + sendPythonOutput('Process did not stop gracefully, forcing termination...', 'stderr'); + forcefullyKilled = true; + runningProcess.kill(); + } + }, 10000); + + // Listen for the process to exit + runningProcess.once('exit', (code, signal) => { + clearTimeout(killTimeout); + runningProcess = null; + + if (forcefullyKilled) { + sendPythonOutput('Process forcefully terminated', 'info'); + } else { + sendPythonOutput('Process stopped gracefully', 'info'); + } + + resolve({ success: true, forcefullyKilled }); + }); + + // Send termination signal + sendPythonOutput('Stopping process gracefully...', 'info'); + runningProcess.kill(); + }); +}); + +// Clean up on app quit +app.on('before-quit', () => { + if (runningProcess) { + runningProcess.kill(); + } +}); + app.whenReady().then(() => { createWindow(); @@ -173,6 +291,6 @@ app.whenReady().then(() => { }); app.on('window-all-closed', function () { - if (process.platform !== 'darwin') app.quit(); + app.quit(); }); diff --git a/ui/preload.js b/ui/preload.js index 108bffe..e6c0623 100644 --- a/ui/preload.js +++ b/ui/preload.js @@ -6,7 +6,10 @@ contextBridge.exposeInMainWorld('electronAPI', { restartApp: () => ipcRenderer.invoke('restart-app'), getMicrophones: () => ipcRenderer.invoke('get-microphones'), 
installRequirements: () => ipcRenderer.invoke('install-requirements'), - onPythonOutput: (callback) => ipcRenderer.on('python-output', (event, data) => callback(data)) + startProcess: () => ipcRenderer.invoke('start-process'), + stopProcess: () => ipcRenderer.invoke('stop-process'), + onPythonOutput: (callback) => ipcRenderer.on('python-output', (event, data) => callback(data)), + onProcessStopped: (callback) => ipcRenderer.on('process-stopped', (event) => callback()) }); console.log('Preload script loaded.'); diff --git a/ui/renderer.js b/ui/renderer.js index 83c652c..b3f05a6 100644 --- a/ui/renderer.js +++ b/ui/renderer.js @@ -22,15 +22,20 @@ function showStatus(message, type = 'info') { // Get form values function getFormValues() { + const microphoneValue = document.getElementById('microphone').value; + // Convert to number if it's a numeric string (device index) + const microphoneForConfig = /^\d+$/.test(microphoneValue) ? parseInt(microphoneValue) : microphoneValue; + return { compute_type: document.getElementById('compute_type').value, enable_debug_mode: document.getElementById('enable_debug_mode').checked ? 1 : 0, enable_previews: document.getElementById('enable_previews').checked ? 1 : 0, + save_audio: document.getElementById('save_audio').checked ? 
1 : 0, language: document.getElementById('language').value, gpu_idx: parseInt(document.getElementById('gpu_idx').value), max_speech_duration_s: parseInt(document.getElementById('max_speech_duration_s').value), min_silence_duration_ms: parseInt(document.getElementById('min_silence_duration_ms').value), - microphone: document.getElementById('microphone').value, + microphone: microphoneForConfig, model: document.getElementById('model').value, reset_after_silence_s: parseInt(document.getElementById('reset_after_silence_s').value), transcription_loop_delay_ms: parseInt(document.getElementById('transcription_loop_delay_ms').value), @@ -52,6 +57,7 @@ function setFormValues(config) { document.getElementById('compute_type').value = config.compute_type || 'int8'; document.getElementById('enable_debug_mode').checked = config.enable_debug_mode === 1; document.getElementById('enable_previews').checked = config.enable_previews === 1; + document.getElementById('save_audio').checked = config.save_audio === 1; document.getElementById('language').value = config.language || 'english'; document.getElementById('gpu_idx').value = config.gpu_idx || 0; document.getElementById('max_speech_duration_s').value = config.max_speech_duration_s || 10; @@ -97,6 +103,30 @@ async function handleAsyncAction(actionName, actionFn) { } } +// Process control buttons +const startButton = document.getElementById('start-process'); +const stopButton = document.getElementById('stop-process'); + +// Helper functions for button state management +function setButtonState(button, disabled) { + button.disabled = disabled; + if (disabled) { + button.classList.add('opacity-50', 'cursor-not-allowed'); + } else { + button.classList.remove('opacity-50', 'cursor-not-allowed'); + } +} + +function setProcessRunningState() { + setButtonState(startButton, true); + setButtonState(stopButton, false); +} + +function setProcessStoppedState() { + setButtonState(startButton, false); + setButtonState(stopButton, true); +} + // 
Auto-save functionality with debouncing let saveTimeout; const SAVE_DELAY = 500; // milliseconds @@ -110,6 +140,31 @@ async function autoSaveConfig() { const config = getFormValues(); await window.electronAPI.saveConfig(config); showStatus('Configuration saved', 'success'); + + // Check if process is running (stop button is enabled means process is running) + const stopButton = document.getElementById('stop-process'); + + if (!stopButton.disabled) { + // Process is running, restart it with new config + appendToConsole('Restarting process with new configuration...', 'info'); + + try { + await window.electronAPI.stopProcess(); + + await new Promise(resolve => setTimeout(resolve, 1000)); + + await window.electronAPI.startProcess(); + + // Update button states to reflect running process + setProcessRunningState(); + + appendToConsole('Process restarted with new configuration', 'info'); + } catch (error) { + appendToConsole(`Failed to restart process: ${error.message}`, 'stderr'); + // Process is stopped, update button states + setProcessStoppedState(); + } + } } catch (error) { showStatus(`Failed to save configuration: ${error.message}`, 'error'); } @@ -246,4 +301,34 @@ document.getElementById('clear-console').addEventListener('click', () => { // Listen for Python output window.electronAPI.onPythonOutput((data) => { appendToConsole(data.message, data.type); +}); + +document.getElementById('start-process').addEventListener('click', async () => { + setButtonState(startButton, true); + + try { + await window.electronAPI.startProcess(); + setProcessRunningState(); + appendToConsole('Process started successfully', 'info'); + } catch (error) { + appendToConsole(`Failed to start process: ${error.message}`, 'stderr'); + setButtonState(startButton, false); + } +}); + +document.getElementById('stop-process').addEventListener('click', async () => { + setButtonState(stopButton, true); + + try { + const result = await window.electronAPI.stopProcess(); + appendToConsole('Process 
stop initiated', 'info'); + } catch (error) { + appendToConsole(`Failed to stop process: ${error.message}`, 'stderr'); + setButtonState(stopButton, false); + } +}); + +// Listen for process stopped event +window.electronAPI.onProcessStopped(() => { + setProcessStoppedState(); }); \ No newline at end of file diff --git a/ui/src/components.css b/ui/src/components.css index be046ea..d8d909d 100644 --- a/ui/src/components.css +++ b/ui/src/components.css @@ -42,6 +42,10 @@ .btn-gray { @apply bg-gray-600 text-white hover:bg-gray-700 focus:ring-gray-500; } + + .btn-red { + @apply bg-red-600 text-white hover:bg-red-700 focus:ring-red-500; + } } /* Console styling */ diff --git a/ui_design.md b/ui_design.md index e38c632..06eee65 100644 --- a/ui_design.md +++ b/ui_design.md @@ -26,4 +26,7 @@ npm install --save-dev electron # Get tailwind and deps npm install --save-dev tailwindcss@3 postcss autoprefixer concurrently cross-env npx tailwindcss init -p +# Install vue.js +npm install --save-dev vue@3 @vitejs/plugin-vue vite yaml +npm install --save-dev js-yaml ``` -- cgit v1.2.3 From e1b3f638a1ea448de9691f69eb62ebf4c3944c9f Mon Sep 17 00:00:00 2001 From: yum Date: Fri, 30 May 2025 02:50:55 -0700 Subject: More polish - Filters actually get applied now, huge accuracy boost - Use silero-vad python library instead of rolling our own - Expose prompt parameter - Auto setup venv on launch - Clean up python output - Auto acquire all dependencies on launch - Add icon --- .cursorignore | 2 + .gitignore | 2 +- Images/favicon.ico | Bin 0 -> 92015 bytes app/hi.py | 12 +- app/requirements.txt | 2 +- app/stt.py | 128 +++++++++--- app/vad.py | 314 ---------------------------- config.yaml | 1 + ui/index.html | 336 +++++++++++++++++------------- ui/index.js | 382 +++++++++++++++++++++++++++++----- ui/package.json | 76 ++++++- ui/preload.js | 7 +- ui/renderer.js | 564 +++++++++++++++++++++++++++++++------------------- ui/src/components.css | 8 + ui_design.md | 9 +- 15 files changed, 1085 
insertions(+), 758 deletions(-) create mode 100644 .cursorignore create mode 100644 Images/favicon.ico delete mode 100644 app/vad.py diff --git a/.cursorignore b/.cursorignore new file mode 100644 index 0000000..a8f4624 --- /dev/null +++ b/.cursorignore @@ -0,0 +1,2 @@ +**/node_modules +**/site-packages \ No newline at end of file diff --git a/.gitignore b/.gitignore index a102cf0..d3886ca 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ .*.sw[po] *.meta - +.venv_is_set_up diff --git a/Images/favicon.ico b/Images/favicon.ico new file mode 100644 index 0000000..25ea9ac Binary files /dev/null and b/Images/favicon.ico differ diff --git a/app/hi.py b/app/hi.py index 0d80b9d..e6877ff 100644 --- a/app/hi.py +++ b/app/hi.py @@ -330,10 +330,11 @@ if __name__ == "__main__": cli_args = parse_args() cfg = app_config.getConfig(cli_args.config) shared_data = SharedThreadData(cfg) - osc_thread = threading.Thread( - target=osc_thread, - args=(shared_data,)) - osc_thread.start() + if False: + osc_thread = threading.Thread( + target=osc_thread, + args=(shared_data,)) + osc_thread.start() transcribe_thread = threading.Thread( target=stt.transcriptionThread, @@ -382,6 +383,7 @@ if __name__ == "__main__": local_word = shared_data.word print(local_word + "_") shared_data.exit_event.set() - osc_thread.join() + if False: + osc_thread.join() transcribe_thread.join() diff --git a/app/requirements.txt b/app/requirements.txt index 07f94cd..f8b7069 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -5,4 +5,4 @@ pyaudio pydub python-osc sentencepiece -wave +silero-vad diff --git a/app/stt.py b/app/stt.py index c157f6d..7d76333 100644 --- a/app/stt.py +++ b/app/stt.py @@ -6,10 +6,10 @@ import os import pyaudio from pydub import AudioSegment from shared_thread_data import SharedThreadData +from silero_vad import load_silero_vad, get_speech_timestamps import sys import time import typing -import vad import wave @@ -33,7 +33,7 @@ class AudioStream(): class 
MicStream(AudioStream): CHUNK_SZ = 1024 - def __init__(self, which_mic: str): + def __init__(self, cfg: typing.Dict): self.p = pyaudio.PyAudio() self.stream = None self.sample_rate = None @@ -45,8 +45,11 @@ class MicStream(AudioStream): # If set, incoming frames are simply discarded. self.paused = False - print(f"Finding mic {which_mic}", file=sys.stderr) - self.dumpMicDevices() + which_mic = cfg["microphone"] + + if cfg["enable_debug_mode"]: + print(f"Finding mic {which_mic}", file=sys.stderr) + self.dumpMicDevices() got_match = False device_index = -1 @@ -59,8 +62,9 @@ class MicStream(AudioStream): elif which_mic == "beyond": target_str = "Microphone (Beyond)" else: - print(f"Mic {which_mic} requested, treating it as a numerical " + - "device ID", file=sys.stderr) + if cfg["enable_debug_mode"]: + print(f"Mic {which_mic} requested, treating it as a numerical " + + "device ID", file=sys.stderr) device_index = int(which_mic) got_match = True if not got_match: @@ -79,9 +83,11 @@ class MicStream(AudioStream): raise KeyError(f"Mic {which_mic} not found") info = self.p.get_device_info_by_host_api_device_index(0, device_index) - print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr) + if cfg["enable_debug_mode"]: + print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr) self.sample_rate = int(info['defaultSampleRate']) - print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr) + if cfg["enable_debug_mode"]: + print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr) self.stream = self.p.open( rate=self.sample_rate, @@ -289,19 +295,40 @@ class AudioSegmenter: def __init__(self, min_silence_ms=250, max_speech_s=5): - self.vad_options = vad.VadOptions( - min_silence_duration_ms=min_silence_ms, - max_speech_duration_s=max_speech_s) - pass + self.min_silence_ms = min_silence_ms + self.max_speech_s = max_speech_s + + # Load Silero VAD model + self.model = load_silero_vad() + + self.vad_threshold = 0.3 + self.min_silence_duration_ms = 
min_silence_ms + self.max_speech_duration_s = max_speech_s + + self.speech_pad_ms = 300 def segmentAudio(self, audio: bytes): - audio = np.frombuffer(audio, + # Convert audio bytes to numpy array expected by silero-vad + audio_array = np.frombuffer(audio, dtype=np.int16).flatten().astype(np.float32) / 32768.0 - return vad.get_speech_timestamps(audio, vad_options=self.vad_options) + + # Get speech timestamps using silero-vad + # Note: silero-vad expects sample rate of 16000 Hz which matches AudioStream.FPS + speech_timestamps = get_speech_timestamps( + audio_array, + self.model, + sampling_rate=AudioStream.FPS, + threshold=self.vad_threshold, + min_silence_duration_ms=self.min_silence_duration_ms, + max_speech_duration_s=self.max_speech_duration_s, + return_seconds=False # We want frame indices, not seconds + ) + + return speech_timestamps # Returns the stable cutoff (if any) and whether there are any segments. def getStableCutoff(self, audio: bytes) -> typing.Tuple[int, bool]: - min_delta_frames = int((self.vad_options.min_silence_duration_ms * + min_delta_frames = int((self.min_silence_duration_ms * AudioStream.FPS) / 1000.0) cutoff = None @@ -379,8 +406,9 @@ class Whisper: model_str = cfg["model"] model_root = os.path.join(parent_dir, "Models", os.path.normpath(model_str)) - print(f"Model {cfg['model']} will be saved to {model_root}", - file=sys.stderr) + if cfg["enable_debug_mode"]: + print(f"Model {cfg['model']} will be saved to {model_root}", + file=sys.stderr) model_device = "cuda" if cfg["use_cpu"]: @@ -395,21 +423,42 @@ class Whisper: download_root = model_root, local_files_only = already_downloaded) + self.context_window_chars = 200 # Keep last 200 chars of context + self.recent_context = "" # Store recent committed text + + def update_context(self, committed_text: str): + """Update the context with recently committed text.""" + self.recent_context = (self.recent_context + " " + committed_text).strip() + # Keep only the last N characters to avoid prompt 
getting too long + if len(self.recent_context) > self.context_window_chars: + self.recent_context = self.recent_context[-self.context_window_chars:] + def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: frames = self.collector.getAudio() - # Convert from signed 16-bit int [-32768, 32767] to signed 32-bit float on - # [-1, 1]. + + # Convert audio to float32 audio = np.frombuffer(frames, dtype=np.int16).flatten().astype(np.float32) / 32768.0 + # Build context-aware prompt + prompt = self._build_prompt() + t0 = time.time() segments, info = self.model.transcribe( audio, language = langcodes.find(self.cfg["language"]).language, vad_filter = True, temperature=0.0, - without_timestamps = False) + without_timestamps = False, + initial_prompt=prompt, + beam_size=5, + best_of=5, + condition_on_previous_text=True, + compression_ratio_threshold=2.4, + log_prob_threshold=-1.0, + no_speech_threshold=0.6 + ) res = [] for s in segments: # Manual touchup. I see a decent number of hallucinations sneaking @@ -445,6 +494,17 @@ class Whisper: print(f"Transcription latency (s): {t1 - t0}") return res + def _build_prompt(self) -> str: + """Build a context-aware prompt for Whisper.""" + user_prompt = self.cfg["user_prompt"] + context_prompt = "" + if self.recent_context and len(self.recent_context) > 0: + context_prompt = f"Here is the context so far: {self.recent_context}" + + prompts = [user_prompt, context_prompt] + prompts = [p for p in prompts if p and len(p) > 0] + return " ".join(prompts) + class TranscriptCommit: def __init__(self, delta: str, @@ -502,10 +562,21 @@ class VadCommitter: latency_s = self.collector.now() - self.collector.begin() duration_s = stable_cutoff / AudioStream.FPS start_ts = self.collector.begin() - commit_audio = self.collector.dropAudioPrefixByFrames(stable_cutoff) + + # Get the filtered audio first, then extract the portion we need + filtered_audio = self.collector.getAudio() + commit_audio = filtered_audio[:stable_cutoff 
* AudioStream.FRAME_SZ] + + # Now drop the prefix from the collector + self.collector.dropAudioPrefixByFrames(stable_cutoff) segments = self.whisper.transcribe(commit_audio) delta = ''.join(s.transcript for s in segments) + + # Update whisper's context with the committed text + if delta.strip(): + self.whisper.update_context(delta.strip()) + audio = self.collector.getAudio() if self.cfg["enable_debug_mode"]: for s in segments: @@ -540,11 +611,11 @@ class VadCommitter: def transcriptionThread(shared_data: SharedThreadData): last_stable_commit = None - stream = MicStream(shared_data.cfg["microphone"]) + stream = MicStream(shared_data.cfg) collector = AudioCollector(stream) collector = CompressingAudioCollector(collector) + collector = BoostingAudioCollector(collector, -12.0, shared_data.cfg) collector = NormalizingAudioCollector(collector) - collector = BoostingAudioCollector(collector, 0.0, shared_data.cfg) whisper = Whisper(collector, shared_data.cfg) segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"], max_speech_s=shared_data.cfg["max_speech_duration_s"]) @@ -553,6 +624,8 @@ def transcriptionThread(shared_data: SharedThreadData): transcript = "" preview = "" + print(f"Ready to go!", flush=True) + while not shared_data.exit_event.is_set(): time.sleep(shared_data.cfg["transcription_loop_delay_ms"] / 1000.0); @@ -561,8 +634,7 @@ def transcriptionThread(shared_data: SharedThreadData): commit = committer.getDelta() if len(commit.delta) > 0 or len(commit.preview) > 0: - # Avoid re-sending text after long pauses. User controls the length - # of the pause in the UI. 
+ # Avoid re-sending text after long pauses if shared_data.cfg["reset_after_silence_s"] > 0: silence_duration = 0 if last_stable_commit: @@ -571,10 +643,12 @@ def transcriptionThread(shared_data: SharedThreadData): last_stable_commit.duration_s silence_duration = commit.start_ts - last_commit_end_ts if silence_duration > shared_data.cfg["reset_after_silence_s"]: - print(f"Resetting transcript after {silence_duration}-second " - "silence", file=sys.stderr) + if shared_data.cfg["enable_debug_mode"]: + print(f"Resetting transcript after {silence_duration}-second " + "silence", file=sys.stderr) transcript = "" preview = "" + whisper.recent_context = "" # Reset context too if commit.delta: last_stable_commit = commit diff --git a/app/vad.py b/app/vad.py deleted file mode 100644 index 1dea765..0000000 --- a/app/vad.py +++ /dev/null @@ -1,314 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Guillaume Klein -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import bisect -import functools -import os -import warnings - -from typing import List, NamedTuple, Optional - -import numpy as np - - -# The code below is adapted from https://github.com/snakers4/silero-vad. -class VadOptions(NamedTuple): - """VAD options. - - Attributes: - threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, - probabilities ABOVE this value are considered as SPEECH. It is better to tune this - parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. - min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. - max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer - than max_speech_duration_s will be split at the timestamp of the last silence that - lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be - split aggressively just before max_speech_duration_s. - min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms - before separating it - window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. - WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate. - Values other than these may affect model performance!! - speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side - """ - - threshold: float = 0.5 - min_speech_duration_ms: int = 250 - max_speech_duration_s: float = float("inf") - min_silence_duration_ms: int = 2000 - window_size_samples: int = 1024 - speech_pad_ms: int = 400 - - -def get_speech_timestamps( - audio: np.ndarray, - vad_options: Optional[VadOptions] = None, - **kwargs, -) -> List[dict]: - """This method is used for splitting long audios into speech chunks using silero VAD. - - Args: - audio: One dimensional float array. - vad_options: Options for VAD processing. - kwargs: VAD options passed as keyword arguments for backward compatibility. 
- - Returns: - List of dicts containing begin and end samples of each speech chunk. - """ - if vad_options is None: - vad_options = VadOptions(**kwargs) - - threshold = vad_options.threshold - min_speech_duration_ms = vad_options.min_speech_duration_ms - max_speech_duration_s = vad_options.max_speech_duration_s - min_silence_duration_ms = vad_options.min_silence_duration_ms - window_size_samples = vad_options.window_size_samples - speech_pad_ms = vad_options.speech_pad_ms - - if window_size_samples not in [512, 1024, 1536]: - warnings.warn( - "Unusual window_size_samples! Supported window_size_samples:\n" - " - [512, 1024, 1536] for 16000 sampling_rate" - ) - - sampling_rate = 16000 - min_speech_samples = sampling_rate * min_speech_duration_ms / 1000 - speech_pad_samples = sampling_rate * speech_pad_ms / 1000 - max_speech_samples = ( - sampling_rate * max_speech_duration_s - - window_size_samples - - 2 * speech_pad_samples - ) - min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 - min_silence_samples_at_max_speech = sampling_rate * 98 / 1000 - - audio_length_samples = len(audio) - - model = get_vad_model() - state = model.get_initial_state(batch_size=1) - - speech_probs = [] - for current_start_sample in range(0, audio_length_samples, window_size_samples): - chunk = audio[current_start_sample : current_start_sample + window_size_samples] - if len(chunk) < window_size_samples: - chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk)))) - speech_prob, state = model(chunk, state, sampling_rate) - speech_probs.append(speech_prob) - - triggered = False - speeches = [] - current_speech = {} - neg_threshold = threshold - 0.15 - - # to save potential segment end (and tolerate some silence) - temp_end = 0 - # to save potential segment limits in case of maximum segment size reached - prev_end = next_start = 0 - - for i, speech_prob in enumerate(speech_probs): - if (speech_prob >= threshold) and temp_end: - temp_end = 0 - if next_start < prev_end: - 
next_start = window_size_samples * i - - if (speech_prob >= threshold) and not triggered: - triggered = True - current_speech["start"] = window_size_samples * i - continue - - if ( - triggered - and (window_size_samples * i) - current_speech["start"] > max_speech_samples - ): - if prev_end: - current_speech["end"] = prev_end - speeches.append(current_speech) - current_speech = {} - # previously reached silence (< neg_thres) and is still not speech (< thres) - if next_start < prev_end: - triggered = False - else: - current_speech["start"] = next_start - prev_end = next_start = temp_end = 0 - else: - current_speech["end"] = window_size_samples * i - speeches.append(current_speech) - current_speech = {} - prev_end = next_start = temp_end = 0 - triggered = False - continue - - if (speech_prob < neg_threshold) and triggered: - if not temp_end: - temp_end = window_size_samples * i - # condition to avoid cutting in very short silence - if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech: - prev_end = temp_end - if (window_size_samples * i) - temp_end < min_silence_samples: - continue - else: - current_speech["end"] = temp_end - if ( - current_speech["end"] - current_speech["start"] - ) > min_speech_samples: - speeches.append(current_speech) - current_speech = {} - prev_end = next_start = temp_end = 0 - triggered = False - continue - - if ( - current_speech - and (audio_length_samples - current_speech["start"]) > min_speech_samples - ): - current_speech["end"] = audio_length_samples - speeches.append(current_speech) - - for i, speech in enumerate(speeches): - if i == 0: - speech["start"] = int(max(0, speech["start"] - speech_pad_samples)) - if i != len(speeches) - 1: - silence_duration = speeches[i + 1]["start"] - speech["end"] - if silence_duration < 2 * speech_pad_samples: - speech["end"] += int(silence_duration // 2) - speeches[i + 1]["start"] = int( - max(0, speeches[i + 1]["start"] - silence_duration // 2) - ) - else: - speech["end"] = int( - 
min(audio_length_samples, speech["end"] + speech_pad_samples) - ) - speeches[i + 1]["start"] = int( - max(0, speeches[i + 1]["start"] - speech_pad_samples) - ) - else: - speech["end"] = int( - min(audio_length_samples, speech["end"] + speech_pad_samples) - ) - - return speeches - - -def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray: - """Collects and concatenates audio chunks.""" - if not chunks: - return np.array([], dtype=np.float32) - - return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks]) - - -class SpeechTimestampsMap: - """Helper class to restore original speech timestamps.""" - - def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2): - self.sampling_rate = sampling_rate - self.time_precision = time_precision - self.chunk_end_sample = [] - self.total_silence_before = [] - - previous_end = 0 - silent_samples = 0 - - for chunk in chunks: - silent_samples += chunk["start"] - previous_end - previous_end = chunk["end"] - - self.chunk_end_sample.append(chunk["end"] - silent_samples) - self.total_silence_before.append(silent_samples / sampling_rate) - - def get_original_time( - self, - time: float, - chunk_index: Optional[int] = None, - ) -> float: - if chunk_index is None: - chunk_index = self.get_chunk_index(time) - - total_silence_before = self.total_silence_before[chunk_index] - return round(total_silence_before + time, self.time_precision) - - def get_chunk_index(self, time: float) -> int: - sample = int(time * self.sampling_rate) - return min( - bisect.bisect(self.chunk_end_sample, sample), - len(self.chunk_end_sample) - 1, - ) - - -@functools.lru_cache -def get_vad_model(): - """Returns the VAD model instance.""" - abspath = os.path.abspath(__file__) - my_dir = os.path.dirname(abspath) - parent_dir = os.path.dirname(my_dir) - path = os.path.join(parent_dir, "Models", "silero_vad.onnx") - return SileroVADModel(path) - - -class SileroVADModel: - def __init__(self, path): - try: - 
import onnxruntime - except ImportError as e: - raise RuntimeError( - "Applying the VAD filter requires the onnxruntime package" - ) from e - - opts = onnxruntime.SessionOptions() - opts.inter_op_num_threads = 1 - opts.intra_op_num_threads = 1 - opts.log_severity_level = 4 - - self.session = onnxruntime.InferenceSession( - path, - providers=["CPUExecutionProvider"], - sess_options=opts, - ) - - def get_initial_state(self, batch_size: int): - h = np.zeros((2, batch_size, 64), dtype=np.float32) - c = np.zeros((2, batch_size, 64), dtype=np.float32) - return h, c - - def __call__(self, x, state, sr: int): - if len(x.shape) == 1: - x = np.expand_dims(x, 0) - if len(x.shape) > 2: - raise ValueError( - f"Too many dimensions for input audio chunk {len(x.shape)}" - ) - if sr / x.shape[1] > 31.25: - raise ValueError("Input audio chunk is too short") - - h, c = state - - ort_inputs = { - "input": x, - "h": h, - "c": c, - "sr": np.array(sr, dtype="int64"), - } - - out, h, c = self.session.run(None, ort_inputs) - state = (h, c) - - return out, state diff --git a/config.yaml b/config.yaml index 34d88f1..5eec7a2 100644 --- a/config.yaml +++ b/config.yaml @@ -1,6 +1,7 @@ compute_type: float16 enable_debug_mode: 0 enable_previews: 1 +user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. save_audio: 0 language: english gpu_idx: 0 diff --git a/ui/index.html b/ui/index.html index b06e56b..90f78c1 100644 --- a/ui/index.html +++ b/ui/index.html @@ -10,179 +10,229 @@
-
-
- -
-
-
- - -
-
- - -
-
- - -
-
-
- - - - - -
-
- -
diff --git a/ui/index.js b/ui/index.js index a056156..2420ece 100644 --- a/ui/index.js +++ b/ui/index.js @@ -3,6 +3,7 @@ const path = require('node:path'); const fs = require('node:fs').promises; const yaml = require('js-yaml'); const { spawn } = require('child_process'); +const https = require('https'); const APP_ROOT = path.join(__dirname, '..'); const CONFIG_PATH = path.join(APP_ROOT, 'config.yaml'); @@ -10,6 +11,20 @@ const CONFIG_PATH = path.join(APP_ROOT, 'config.yaml'); let mainWindow; let runningProcess = null; // Track the running Python process +// Required DLL files for CUDA/cuDNN support +const REQUIRED_DLLS = [ + 'cublas64_12.dll', + 'cublasLt64_12.dll', + 'cudnn64_9.dll', + 'cudnn_adv64_9.dll', + 'cudnn_cnn64_9.dll', + 'cudnn_engines_precompiled64_9.dll', + 'cudnn_engines_runtime_compiled64_9.dll', + 'cudnn_graph64_9.dll', + 'cudnn_heuristic64_9.dll', + 'cudnn_ops64_9.dll' +]; + // Helper function to get the correct Python executable from venv function getVenvPython() { const venvPath = path.join(APP_ROOT, 'venv'); @@ -24,6 +39,78 @@ function sendPythonOutput(message, type = 'stdout') { } } +// Helper function to create environment with DLL path +function createPythonEnvironment() { + const dllPath = path.join(APP_ROOT, 'dll'); + const binPath = path.join(APP_ROOT, 'bin'); + const env = { ...process.env }; + env.PATH = `${dllPath};${binPath};${env.PATH}`; + env.HF_HUB_DISABLE_SYMLINKS_WARNING = '1'; + return env; +} + +// Helper function to download a file from URL +function downloadFile(url, outputPath) { + return new Promise((resolve, reject) => { + const file = require('fs').createWriteStream(outputPath); + + const request = https.get(url, (response) => { + if (response.statusCode === 200) { + response.pipe(file); + + file.on('finish', () => { + file.close(); + resolve(); + }); + + file.on('error', (err) => { + fs.unlink(outputPath).catch(() => {}); // Clean up on error + reject(err); + }); + } else { + file.close(); + 
fs.unlink(outputPath).catch(() => {}); // Clean up on error + reject(new Error(`Failed to download: HTTP ${response.statusCode}`)); + } + }); + + request.on('error', (err) => { + file.close(); + fs.unlink(outputPath).catch(() => {}); // Clean up on error + reject(err); + }); + }); +} + +// Helper function to setup process event handlers +function setupProcessHandlers(process) { + process.stdout.on('data', (data) => { + const text = data.toString(); + sendPythonOutput(text.trimEnd(), 'stdout'); + }); + + process.stderr.on('data', (data) => { + const text = data.toString(); + sendPythonOutput(text.trimEnd(), 'stderr'); + }); + + process.on('error', (error) => { + sendPythonOutput(`Process error: ${error.message}`, 'stderr'); + runningProcess = null; + if (mainWindow && !mainWindow.isDestroyed()) { + mainWindow.webContents.send('process-stopped'); + } + }); + + process.on('close', (code) => { + sendPythonOutput(`Process exited with code ${code}`, 'info'); + runningProcess = null; + if (mainWindow && !mainWindow.isDestroyed()) { + mainWindow.webContents.send('process-stopped'); + } + }); +} + // Helper function to execute Python commands using venv function executePythonCommand(args, options = {}) { return new Promise((resolve, reject) => { @@ -31,14 +118,9 @@ function executePythonCommand(args, options = {}) { const commandStr = `${path.basename(pythonPath)} ${args.join(' ')}`; sendPythonOutput(`> ${commandStr}`, 'info'); - // Add dll directory to PATH for Windows DLL loading - const dllPath = path.join(APP_ROOT, 'dll'); - const env = { ...process.env }; - env.PATH = `${dllPath};${env.PATH}`; - const spawnOptions = { ...options, - env + env: createPythonEnvironment() }; const pythonProcess = spawn(pythonPath, args, spawnOptions); @@ -78,6 +160,7 @@ function createWindow () { mainWindow = new BrowserWindow({ width: 1000, height: 800, + icon: path.join(APP_ROOT, 'Images', 'favicon.ico'), webPreferences: { preload: path.join(__dirname, 'preload.js'), contextIsolation: 
true, @@ -93,6 +176,7 @@ const DEFAULT_CONFIG = { compute_type: 'float16', enable_debug_mode: 0, enable_previews: 1, + user_prompt: 'Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc.', save_audio: 0, language: 'english', gpu_idx: 0, @@ -117,11 +201,11 @@ ipcMain.handle('load-config', async () => { } catch (error) { if (error.code === 'ENOENT') { // Config file doesn't exist, create it with defaults - console.log('Config file not found, creating with defaults...'); + console.error('Config file not found, creating with defaults...'); try { const yamlContent = yaml.dump(DEFAULT_CONFIG, { lineWidth: -1 }); await fs.writeFile(CONFIG_PATH, yamlContent, 'utf8'); - console.log('Created config.yaml with default values'); + console.error('Created config.yaml with default values'); return DEFAULT_CONFIG; } catch (writeError) { console.error('Error creating default config:', writeError); @@ -145,21 +229,138 @@ ipcMain.handle('save-config', async (event, config) => { } }); -ipcMain.handle('restart-app', () => { - app.relaunch(); - app.exit(); +ipcMain.handle('reset-config', async () => { + try { + // Check if the file exists first + try { + await fs.access(CONFIG_PATH); + // File exists, delete it + await fs.unlink(CONFIG_PATH); + console.error('Config file deleted successfully'); + return { success: true, message: 'Configuration reset to defaults' }; + } catch (error) { + if (error.code === 'ENOENT') { + // Config file doesn't exist, that's fine + return { success: true, message: 'Configuration already at defaults' }; + } + throw error; + } + } catch (error) { + console.error('Error resetting config:', error); + throw new Error(`Failed to reset configuration: ${error.message}`); + } }); -ipcMain.handle('install-requirements', async (event) => { +// Generic function to ensure required files are present +async function ensureRequiredFiles(config) { + const { + directoryName, + requiredFiles, + downloadBaseUrl, + resourceType + } = 
config; + + const targetPath = path.join(APP_ROOT, directoryName); + + try { + // Check if target directory exists, create it if not + try { + await fs.access(targetPath); + sendPythonOutput(`${resourceType} directory exists`, 'info'); + } catch (error) { + if (error.code === 'ENOENT') { + sendPythonOutput(`Creating ${resourceType} directory...`, 'info'); + await fs.mkdir(targetPath, { recursive: true }); + sendPythonOutput(`${resourceType} directory created`, 'info'); + } else { + throw error; + } + } + + // Check each required file + const missingFiles = []; + for (const fileName of requiredFiles) { + const filePath = path.join(targetPath, fileName); + try { + await fs.access(filePath); + sendPythonOutput(`✓ ${fileName} exists`, 'info'); + } catch (error) { + if (error.code === 'ENOENT') { + missingFiles.push(fileName); + sendPythonOutput(`✗ ${fileName} missing`, 'info'); + } else { + throw error; + } + } + } + + // Download missing files + if (missingFiles.length > 0) { + sendPythonOutput(`Downloading ${missingFiles.length} missing ${resourceType} file${missingFiles.length > 1 ? 's' : ''}...`, 'info'); + + for (const fileName of missingFiles) { + const filePath = path.join(targetPath, fileName); + const downloadUrl = `${downloadBaseUrl}/${fileName}`; + + try { + sendPythonOutput(`Downloading ${fileName}...`, 'info'); + await downloadFile(downloadUrl, filePath); + sendPythonOutput(`✓ Downloaded ${fileName}`, 'info'); + } catch (downloadError) { + sendPythonOutput(`✗ Failed to download ${fileName}: ${downloadError.message}`, 'stderr'); + throw new Error(`Failed to download ${fileName}: ${downloadError.message}`); + } + } + + sendPythonOutput(`All missing ${resourceType} files downloaded successfully`, 'info'); + } else { + sendPythonOutput(`All required ${resourceType} files are present`, 'info'); + } + + return { + success: true, + message: `${resourceType} setup complete. ${missingFiles.length} file${missingFiles.length > 1 ? 
's' : ''} downloaded.`, + downloadedFiles: missingFiles + }; + } catch (error) { + console.error(`Error setting up ${resourceType} files:`, error); + throw new Error(`${resourceType} setup failed: ${error.message}`); + } +} + +// Update the install-requirements handler +ipcMain.handle('install-requirements', async () => { const requirementsPath = path.join(APP_ROOT, 'app', 'requirements.txt'); + const venvMarkerPath = path.join(APP_ROOT, '.venv_is_set_up'); try { + // Check if venv is already set up + try { + await fs.access(venvMarkerPath); + sendPythonOutput('Virtual environment already set up, skipping installation', 'info'); + return { success: true, message: 'Virtual environment already set up' }; + } catch (error) { + // Marker doesn't exist, proceed with setup + } + // Check if requirements.txt exists await fs.access(requirementsPath); - const result = await executePythonCommand(['-m', 'pip', 'install', '-r', requirementsPath]); + await executePythonCommand(['-m', 'pip', 'install', '-r', requirementsPath]); + + await ensureRequiredFiles({ + directoryName: 'dll', + requiredFiles: REQUIRED_DLLS, + downloadBaseUrl: 'https://yummers.dev/tastt/dll', + resourceType: 'DLL' + }); + + await fs.mkdir(path.join(APP_ROOT, 'Models'), { recursive: true }); + + await fs.writeFile(venvMarkerPath, new Date().toISOString(), 'utf8'); + sendPythonOutput('Created .venv_is_set_up marker file', 'info'); - return { success: true, message: 'Requirements installed successfully' }; + return { success: true, message: 'Requirements and dependencies installed successfully' }; } catch (error) { console.error('Error installing requirements:', error); if (error.code === 'ENOENT') { @@ -175,7 +376,6 @@ ipcMain.handle('get-microphones', async () => { try { const result = await executePythonCommand([scriptPath]); const microphones = JSON.parse(result.stdout.trim()); - console.log('Successfully retrieved microphones:', microphones); return microphones; } catch (error) { console.error('Failed to 
get microphones:', error); @@ -183,53 +383,135 @@ ipcMain.handle('get-microphones', async () => { } }); -// Add handlers for starting and stopping the process -ipcMain.handle('start-process', async () => { - if (runningProcess) { - throw new Error('Process is already running'); +// Helper function to safely delete directory contents +async function clearDirectory(dirPath, dirName) { + try { + await fs.access(dirPath); + sendPythonOutput(`Clearing ${dirName} directory...`, 'info'); + + const files = await fs.readdir(dirPath); + let deletedCount = 0; + + for (const file of files) { + const filePath = path.join(dirPath, file); + + try { + await fs.rm(filePath, { recursive: true, force: true }); + sendPythonOutput(`✗ Deleted file ${file}`, 'info'); + + deletedCount++; + } catch (deleteError) { + sendPythonOutput(`Warning: Could not delete ${file}: ${deleteError.message}`, 'stderr'); + // Continue with other files even if one fails + } + } + + sendPythonOutput(`${dirName} directory cleared`, 'info'); + return deletedCount; + } catch (error) { + if (error.code === 'ENOENT') { + sendPythonOutput(`${dirName} directory doesn't exist, skipping`, 'info'); + return 0; + } else { + sendPythonOutput(`Error clearing ${dirName} directory: ${error.message}`, 'stderr'); + throw error; + } } +} - const scriptPath = path.join(APP_ROOT, 'app', 'hi.py'); - const configPath = CONFIG_PATH; +ipcMain.handle('reset-venv', async () => { + const venvMarkerPath = path.join(APP_ROOT, '.venv_is_set_up'); try { - const pythonPath = getVenvPython(); - const args = [scriptPath, '--config', configPath]; + sendPythonOutput('Starting virtual environment reset...', 'info'); - sendPythonOutput(`Starting process: ${path.basename(pythonPath)} ${args.join(' ')}`, 'info'); + // Delete the venv marker file first + try { + await fs.unlink(venvMarkerPath); + sendPythonOutput('Deleted .venv_is_set_up marker file', 'info'); + } catch (error) { + if (error.code !== 'ENOENT') { + sendPythonOutput(`Warning: Could 
not delete marker file: ${error.message}`, 'stderr'); + } + } + + // Get list of installed packages + sendPythonOutput('Getting list of installed packages...', 'info'); + const freezeResult = await executePythonCommand(['-m', 'pip', 'freeze']); + const installedPackages = freezeResult.stdout.trim(); + + let uninstalledPackages = []; + + if (!installedPackages) { + sendPythonOutput('No packages found to uninstall', 'info'); + } else { + // Parse package names and filter out core packages + const packageLines = installedPackages.split('\n').filter(line => line.trim()); + const packageNames = packageLines + .map(line => line.split('==')[0].trim()) + .filter(name => name && !name.startsWith('#')); + + const corePackages = ['pip', 'setuptools', 'wheel']; + const packagesToUninstall = packageNames.filter(name => !corePackages.includes(name.toLowerCase())); + + if (packagesToUninstall.length === 0) { + sendPythonOutput('Only core packages found, nothing to uninstall', 'info'); + } else { + sendPythonOutput(`Uninstalling ${packagesToUninstall.length} packages...`, 'info'); + + const uninstallArgs = ['-m', 'pip', 'uninstall', '-y', ...packagesToUninstall]; + await executePythonCommand(uninstallArgs); + uninstalledPackages = packagesToUninstall; + } + } + + // Clear downloaded files + sendPythonOutput('Clearing downloaded files...', 'info'); - // Add dll directory to PATH for Windows DLL loading const dllPath = path.join(APP_ROOT, 'dll'); - const env = { ...process.env }; - env.PATH = `${dllPath};${env.PATH}`; + const modelsPath = path.join(APP_ROOT, 'Models'); + const binPath = path.join(APP_ROOT, 'bin'); - runningProcess = spawn(pythonPath, args, { env }); + const deletedDlls = await clearDirectory(dllPath, 'DLL'); + const deletedModels = await clearDirectory(modelsPath, 'Models'); + const deletedBins = await clearDirectory(binPath, 'Binary'); - runningProcess.stdout.on('data', (data) => { - const text = data.toString(); - sendPythonOutput(text.trimEnd(), 'stdout'); - }); 
+ const totalDeletedFiles = deletedDlls + deletedModels + deletedBins; - runningProcess.stderr.on('data', (data) => { - const text = data.toString(); - sendPythonOutput(text.trimEnd(), 'stderr'); - }); + sendPythonOutput('Virtual environment reset successfully!', 'info'); - runningProcess.on('error', (error) => { - sendPythonOutput(`Process error: ${error.message}`, 'stderr'); - runningProcess = null; - if (mainWindow && !mainWindow.isDestroyed()) { - mainWindow.webContents.send('process-stopped'); + return { + success: true, + message: `Virtual environment reset complete. Uninstalled ${uninstalledPackages.length} packages and deleted ${totalDeletedFiles} downloaded files.`, + uninstalledPackages, + deletedFiles: { + dlls: deletedDlls, + models: deletedModels, + binaries: deletedBins, + total: totalDeletedFiles } - }); + }; + } catch (error) { + console.error('Error resetting virtual environment:', error); + throw new Error(`Virtual environment reset failed: ${error.message}`); + } +}); + +// Add handlers for starting and stopping the process +ipcMain.handle('start-process', async () => { + if (runningProcess) { + throw new Error('Process is already running'); + } + + const scriptPath = path.join(APP_ROOT, 'app', 'hi.py'); + const args = [scriptPath, '--config', CONFIG_PATH]; + + try { + const pythonPath = getVenvPython(); + sendPythonOutput(`Starting process: ${path.basename(pythonPath)} ${args.join(' ')}`, 'info'); - runningProcess.on('close', (code) => { - sendPythonOutput(`Process exited with code ${code}`, 'info'); - runningProcess = null; - if (mainWindow && !mainWindow.isDestroyed()) { - mainWindow.webContents.send('process-stopped'); - } - }); + runningProcess = spawn(pythonPath, args, { env: createPythonEnvironment() }); + setupProcessHandlers(runningProcess); return { success: true }; } catch (error) { @@ -243,7 +525,7 @@ ipcMain.handle('stop-process', async () => { throw new Error('No process is running'); } - return new Promise((resolve, reject) => { + 
return new Promise((resolve) => { let forcefullyKilled = false; // Set up a timeout to force kill after 10 seconds diff --git a/ui/package.json b/ui/package.json index fee2d67..3a58298 100644 --- a/ui/package.json +++ b/ui/package.json @@ -3,12 +3,85 @@ "version": "1.0.0", "description": "Speech-to-text tool for VRChat", "main": "index.js", + "homepage": "./", "scripts": { "start": "npm run build:css && electron .", "build:css": "tailwindcss -i ./src/components.css -o ./build/output.css", "watch:css": "tailwindcss -i ./src/components.css -o ./build/output.css --watch", "dev": "concurrently \"npm run watch:css\" \"electron .\"", - "test": "echo \"Error: no test specified\" && exit 1" + "test": "echo \"Error: no test specified\" && exit 1", + "dist": "npm run build:css && electron-builder", + "dist:win": "npm run build:css && electron-builder --win", + "dist:portable": "npm run build:css && electron-builder --win portable", + "dist:zip": "npm run build:css && electron-builder --win zip" + }, + "build": { + "appId": "com.yum_food.tastt", + "productName": "TaSTT", + "directories": { + "output": "dist" + }, + "files": [ + "**/*", + "!dist/**/*", + "!src/**/*", + "!node_modules/**/{CHANGELOG.md,README.md,README,readme.md,readme}", + "!node_modules/**/{test,__tests__,tests,powered-test,example,examples}", + "!node_modules/**/*.d.ts", + "!node_modules/.bin", + "!.git/**/*", + "!.gitignore" + ], + "extraResources": [ + { + "from": "../app", + "to": "app", + "filter": [ + "**/*.py", + "requirements.txt", + "!**/__pycache__/**/*" + ] + }, + { + "from": "../config.yaml", + "to": "config.yaml" + }, + { + "from": "../dll", + "to": "dll", + "filter": ["**/*"] + }, + { + "from": "../Images", + "to": "Images", + "filter": ["**/*"] + }, + { + "from": "../bin", + "to": "bin", + "filter": ["**/*"] + } + ], + "win": { + "icon": "../Images/logo.png", + "target": [ + { + "target": "portable", + "arch": ["x64"] + }, + { + "target": "zip", + "arch": ["x64"] + } + ] + }, + "portable": { + 
"artifactName": "${productName}-${version}-portable.exe" + }, + "nsis": { + "oneClick": false, + "allowToChangeInstallationDirectory": true + } }, "keywords": [], "author": "yum_food", @@ -22,6 +95,7 @@ "concurrently": "^9.1.2", "cross-env": "^7.0.3", "electron": "^36.3.2", + "electron-builder": "^25.1.8", "postcss": "^8.5.4", "tailwindcss": "^3.4.17", "vite": "^6.3.5", diff --git a/ui/preload.js b/ui/preload.js index e6c0623..35cc8d6 100644 --- a/ui/preload.js +++ b/ui/preload.js @@ -3,14 +3,13 @@ const { contextBridge, ipcRenderer } = require('electron'); contextBridge.exposeInMainWorld('electronAPI', { loadConfig: () => ipcRenderer.invoke('load-config'), saveConfig: (config) => ipcRenderer.invoke('save-config', config), - restartApp: () => ipcRenderer.invoke('restart-app'), + resetConfig: () => ipcRenderer.invoke('reset-config'), getMicrophones: () => ipcRenderer.invoke('get-microphones'), installRequirements: () => ipcRenderer.invoke('install-requirements'), + resetVenv: () => ipcRenderer.invoke('reset-venv'), startProcess: () => ipcRenderer.invoke('start-process'), stopProcess: () => ipcRenderer.invoke('stop-process'), onPythonOutput: (callback) => ipcRenderer.on('python-output', (event, data) => callback(data)), - onProcessStopped: (callback) => ipcRenderer.on('process-stopped', (event) => callback()) + onProcessStopped: (callback) => ipcRenderer.on('process-stopped', () => callback()) }); -console.log('Preload script loaded.'); - diff --git a/ui/renderer.js b/ui/renderer.js index b3f05a6..201eef6 100644 --- a/ui/renderer.js +++ b/ui/renderer.js @@ -1,99 +1,220 @@ -// Handle status messages +// Configuration and form field mappings +const CONFIG_FIELDS = { + // String fields + compute_type: { type: 'select', default: 'float16' }, + language: { type: 'select', default: 'english' }, + model: { type: 'select', default: 'turbo' }, + microphone: { type: 'number', default: 0 }, + user_prompt: { type: 'text', default: '' }, + + // Number fields + gpu_idx: { type: 
'number', default: 0 }, + max_speech_duration_s: { type: 'number', default: 10 }, + min_silence_duration_ms: { type: 'number', default: 250 }, + reset_after_silence_s: { type: 'number', default: 15 }, + transcription_loop_delay_ms: { type: 'number', default: 100 }, + block_width: { type: 'number', default: 2 }, + num_blocks: { type: 'number', default: 40 }, + rows: { type: 'number', default: 10 }, + cols: { type: 'number', default: 24 }, + + // Boolean fields (stored as 1/0) + enable_debug_mode: { type: 'boolean', default: 0 }, + enable_previews: { type: 'boolean', default: 1 }, + save_audio: { type: 'boolean', default: 0 }, + use_cpu: { type: 'boolean', default: 0 } +}; + +// Button management system +class ButtonManager { + constructor() { + this.buttons = { + start: document.getElementById('start-process'), + stop: document.getElementById('stop-process'), + setupVenv: document.getElementById('setup-venv'), + resetVenv: document.getElementById('reset-venv'), + refreshMicrophones: document.getElementById('refresh-microphones') + }; + } + + setState(buttonName, disabled) { + const button = this.buttons[buttonName]; + if (!button) return; + + button.disabled = disabled; + if (disabled) { + button.classList.add('opacity-50', 'cursor-not-allowed'); + } else { + button.classList.remove('opacity-50', 'cursor-not-allowed'); + } + } + + setProcessRunning() { + this.setState('start', true); + this.setState('stop', false); + } + + setProcessStopped() { + this.setState('start', false); + this.setState('stop', true); + } + + async withButtonLoading(buttonName, asyncFn) { + this.setState(buttonName, true); + try { + return await asyncFn(); + } finally { + this.setState(buttonName, false); + } + } +} + +const buttonManager = new ButtonManager(); + +// Add loading overlay management +class LoadingOverlay { + constructor() { + this.overlay = document.getElementById('loading-overlay'); + this.form = document.getElementById('config-form'); + this.messageElement = 
this.overlay.querySelector('p'); + this.defaultMessage = 'Environment setup underway - please wait.'; + } + + show(message = null) { + this.messageElement.textContent = message || this.defaultMessage; + this.overlay.classList.remove('hidden'); + // Disable all form inputs and buttons in the entire left panel + const leftPanel = this.overlay.parentElement; + const inputs = leftPanel.querySelectorAll('input, select, textarea, button'); + inputs.forEach(input => { + input.disabled = true; + input.classList.add('opacity-50'); + }); + } + + hide() { + this.overlay.classList.add('hidden'); + // Re-enable all form inputs and buttons in the entire left panel + const leftPanel = this.overlay.parentElement; + const inputs = leftPanel.querySelectorAll('input, select, textarea, button'); + inputs.forEach(input => { + input.disabled = false; + input.classList.remove('opacity-50'); + }); + // Reset to default message + this.messageElement.textContent = this.defaultMessage; + } +} + +const loadingOverlay = new LoadingOverlay(); + +// Add a flag to prevent auto-save during programmatic updates +let isSettingValues = false; + +// Handle status messages with better color management function showStatus(message, type = 'info') { const statusEl = document.getElementById('status-message'); statusEl.textContent = message; - statusEl.classList.remove('hidden', 'bg-green-100', 'bg-red-100', 'bg-blue-100', 'text-green-800', 'text-red-800', 'text-blue-800'); - - if (type === 'success') { - statusEl.classList.add('bg-green-100', 'text-green-800'); - } else if (type === 'error') { - statusEl.classList.add('bg-red-100', 'text-red-800'); - } else { - statusEl.classList.add('bg-blue-100', 'text-blue-800'); - } + + // Remove all status classes + const statusClasses = ['hidden', 'bg-green-100', 'bg-red-100', 'bg-blue-100', 'text-green-800', 'text-red-800', 'text-blue-800']; + statusEl.classList.remove(...statusClasses); + + // Add appropriate classes based on type + const typeMap = { + success: 
['bg-green-100', 'text-green-800'], + error: ['bg-red-100', 'text-red-800'], + info: ['bg-blue-100', 'text-blue-800'] + }; + + statusEl.classList.add(...(typeMap[type] || typeMap.info)); // Also log to console appendToConsole(message, type === 'error' ? 'stderr' : 'info'); - setTimeout(() => { - statusEl.classList.add('hidden'); - }, 5000); + setTimeout(() => statusEl.classList.add('hidden'), 5000); } -// Get form values +// Get form values using field mappings function getFormValues() { - const microphoneValue = document.getElementById('microphone').value; - // Convert to number if it's a numeric string (device index) - const microphoneForConfig = /^\d+$/.test(microphoneValue) ? parseInt(microphoneValue) : microphoneValue; - - return { - compute_type: document.getElementById('compute_type').value, - enable_debug_mode: document.getElementById('enable_debug_mode').checked ? 1 : 0, - enable_previews: document.getElementById('enable_previews').checked ? 1 : 0, - save_audio: document.getElementById('save_audio').checked ? 1 : 0, - language: document.getElementById('language').value, - gpu_idx: parseInt(document.getElementById('gpu_idx').value), - max_speech_duration_s: parseInt(document.getElementById('max_speech_duration_s').value), - min_silence_duration_ms: parseInt(document.getElementById('min_silence_duration_ms').value), - microphone: microphoneForConfig, - model: document.getElementById('model').value, - reset_after_silence_s: parseInt(document.getElementById('reset_after_silence_s').value), - transcription_loop_delay_ms: parseInt(document.getElementById('transcription_loop_delay_ms').value), - use_cpu: document.getElementById('use_cpu').checked ? 
1 : 0, - block_width: parseInt(document.getElementById('block_width').value), - num_blocks: parseInt(document.getElementById('num_blocks').value), - rows: parseInt(document.getElementById('rows').value), - cols: parseInt(document.getElementById('cols').value) - }; + const config = {}; + + for (const [fieldName, fieldConfig] of Object.entries(CONFIG_FIELDS)) { + const element = document.getElementById(fieldName); + if (!element) continue; + + switch (fieldConfig.type) { + case 'boolean': + config[fieldName] = element.checked ? 1 : 0; + break; + case 'number': + config[fieldName] = parseInt(element.value) || fieldConfig.default; + break; + case 'text': + config[fieldName] = element.value || fieldConfig.default; + break; + default: + config[fieldName] = element.value || fieldConfig.default; + } + } + + return config; } -// Add a flag to prevent auto-save during programmatic updates -let isSettingValues = false; - -// Set form values +// Set form values using field mappings function setFormValues(config) { isSettingValues = true; // Disable auto-save temporarily - document.getElementById('compute_type').value = config.compute_type || 'int8'; - document.getElementById('enable_debug_mode').checked = config.enable_debug_mode === 1; - document.getElementById('enable_previews').checked = config.enable_previews === 1; - document.getElementById('save_audio').checked = config.save_audio === 1; - document.getElementById('language').value = config.language || 'english'; - document.getElementById('gpu_idx').value = config.gpu_idx || 0; - document.getElementById('max_speech_duration_s').value = config.max_speech_duration_s || 10; - document.getElementById('min_silence_duration_ms').value = config.min_silence_duration_ms || 250; - document.getElementById('microphone').value = config.microphone || 'motu'; - document.getElementById('model').value = config.model || 'turbo'; - document.getElementById('reset_after_silence_s').value = config.reset_after_silence_s || 15; - 
document.getElementById('transcription_loop_delay_ms').value = config.transcription_loop_delay_ms || 100; - document.getElementById('use_cpu').checked = config.use_cpu === 1; - document.getElementById('block_width').value = config.block_width || 2; - document.getElementById('num_blocks').value = config.num_blocks || 40; - document.getElementById('rows').value = config.rows || 10; - document.getElementById('cols').value = config.cols || 24; + for (const [fieldName, fieldConfig] of Object.entries(CONFIG_FIELDS)) { + const element = document.getElementById(fieldName); + if (!element) continue; + + const value = config[fieldName] ?? fieldConfig.default; + + switch (fieldConfig.type) { + case 'boolean': + element.checked = value === 1; + break; + case 'text': + element.value = value || ''; + break; + default: + element.value = value; + } + } isSettingValues = false; // Re-enable auto-save } -// Toggle advanced settings -document.getElementById('toggle-advanced').addEventListener('click', () => { - const advancedSettings = document.getElementById('advanced-settings'); - const chevron = document.getElementById('chevron'); - - if (advancedSettings.classList.contains('hidden')) { - advancedSettings.classList.remove('hidden'); - chevron.classList.add('rotate-90'); - } else { - advancedSettings.classList.add('hidden'); - chevron.classList.remove('rotate-90'); - } -}); +// Console management +const consoleContent = document.getElementById('console-content'); + +function appendToConsole(message, type = 'stdout') { + const timestamp = new Date().toLocaleTimeString(); + const timestampSpan = document.createElement('span'); + timestampSpan.className = 'console-timestamp'; + timestampSpan.textContent = `[${timestamp}] `; + + const messageSpan = document.createElement('span'); + messageSpan.className = `console-${type}`; + messageSpan.textContent = message; + + const lineDiv = document.createElement('div'); + lineDiv.appendChild(timestampSpan); + lineDiv.appendChild(messageSpan); + 
+ consoleContent.appendChild(lineDiv); + + // Auto-scroll to bottom + const pythonConsole = document.getElementById('python-console'); + pythonConsole.scrollTop = pythonConsole.scrollHeight; +} -// Simplify button handlers by extracting common patterns +// Async action handler with better error handling async function handleAsyncAction(actionName, actionFn) { try { const result = await actionFn(); - if (result && result.message) { + if (result?.message) { showStatus(result.message, 'success'); } return result; @@ -103,36 +224,12 @@ async function handleAsyncAction(actionName, actionFn) { } } -// Process control buttons -const startButton = document.getElementById('start-process'); -const stopButton = document.getElementById('stop-process'); - -// Helper functions for button state management -function setButtonState(button, disabled) { - button.disabled = disabled; - if (disabled) { - button.classList.add('opacity-50', 'cursor-not-allowed'); - } else { - button.classList.remove('opacity-50', 'cursor-not-allowed'); - } -} - -function setProcessRunningState() { - setButtonState(startButton, true); - setButtonState(stopButton, false); -} - -function setProcessStoppedState() { - setButtonState(startButton, false); - setButtonState(stopButton, true); -} - // Auto-save functionality with debouncing let saveTimeout; -const SAVE_DELAY = 500; // milliseconds +const SAVE_DELAY = 500; async function autoSaveConfig() { - if (isSettingValues) return; // Don't save during programmatic updates + if (isSettingValues) return; clearTimeout(saveTimeout); saveTimeout = setTimeout(async () => { @@ -141,28 +238,19 @@ async function autoSaveConfig() { await window.electronAPI.saveConfig(config); showStatus('Configuration saved', 'success'); - // Check if process is running (stop button is enabled means process is running) - const stopButton = document.getElementById('stop-process'); - - if (!stopButton.disabled) { - // Process is running, restart it with new config + // Restart process if 
running + if (!buttonManager.buttons.stop.disabled) { appendToConsole('Restarting process with new configuration...', 'info'); try { await window.electronAPI.stopProcess(); - await new Promise(resolve => setTimeout(resolve, 1000)); - await window.electronAPI.startProcess(); - - // Update button states to reflect running process - setProcessRunningState(); - + buttonManager.setProcessRunning(); appendToConsole('Process restarted with new configuration', 'info'); } catch (error) { appendToConsole(`Failed to restart process: ${error.message}`, 'stderr'); - // Process is stopped, update button states - setProcessStoppedState(); + buttonManager.setProcessStopped(); } } } catch (error) { @@ -171,47 +259,32 @@ async function autoSaveConfig() { }, SAVE_DELAY); } -// Add event listeners to all form inputs for auto-save +// Auto-save setup function setupAutoSave() { - // Get all form inputs const form = document.getElementById('config-form'); - const inputs = form.querySelectorAll('input, select'); + const inputs = form.querySelectorAll('input, select, textarea'); - // Add change listener to each input inputs.forEach(input => { - if (input.type === 'checkbox') { - input.addEventListener('change', autoSaveConfig); - } else if (input.type === 'number' || input.type === 'text') { - input.addEventListener('input', autoSaveConfig); - } else if (input.tagName === 'SELECT') { - input.addEventListener('change', autoSaveConfig); - } + const eventType = input.type === 'checkbox' ? 'change' : + (input.type === 'number' || input.type === 'text' || input.tagName === 'TEXTAREA') ? 
'input' : 'change'; + input.addEventListener(eventType, autoSaveConfig); }); } -// Update the setup-venv handler -document.getElementById('setup-venv').addEventListener('click', async () => { - const setupButton = document.getElementById('setup-venv'); - setupButton.disabled = true; - setupButton.classList.add('opacity-50', 'cursor-not-allowed'); - - try { - await handleAsyncAction('Install requirements', async () => { - return await window.electronAPI.installRequirements(); - }); - // Reload microphones after successful installation - await loadMicrophones(); - } finally { - setupButton.disabled = false; - setupButton.classList.remove('opacity-50', 'cursor-not-allowed'); - } -}); - -// Simplified microphone loading +// Microphone loading async function loadMicrophones() { const microphoneSelect = document.getElementById('microphone'); try { + // Check/install requirements during startup + appendToConsole('Checking virtual environment and requirements...', 'info'); + loadingOverlay.show('Setting up environment - this can take several minutes.'); + try { + await handleAsyncAction('Install requirements', () => window.electronAPI.installRequirements()); + } finally { + loadingOverlay.hide(); // Always hide overlay when done + } + appendToConsole('Loading available microphones...', 'info'); const microphones = await window.electronAPI.getMicrophones(); @@ -232,7 +305,7 @@ async function loadMicrophones() { appendToConsole(` - ${mic.name} (Device ${mic.index})`, 'stdout'); }); - // Restore previously selected microphone if possible + // Restore previously selected microphone try { const config = await window.electronAPI.loadConfig(); if (config.microphone) { @@ -248,11 +321,144 @@ async function loadMicrophones() { } } -// Update window load to include auto-save setup +// Event handlers setup +function setupEventHandlers() { + // Advanced settings toggle + document.getElementById('toggle-advanced').addEventListener('click', () => { + const advancedSettings = 
document.getElementById('advanced-settings'); + const chevron = document.getElementById('chevron'); + + if (advancedSettings.classList.contains('hidden')) { + advancedSettings.classList.remove('hidden'); + chevron.classList.add('rotate-90'); + } else { + advancedSettings.classList.add('hidden'); + chevron.classList.remove('rotate-90'); + } + }); + + // Setup virtual environment + document.getElementById('setup-venv').addEventListener('click', async () => { + loadingOverlay.show('Setting up virtual environment - please wait...'); // Show overlay with custom message + try { + await buttonManager.withButtonLoading('setupVenv', async () => { + await handleAsyncAction('Install requirements', () => window.electronAPI.installRequirements()); + }); + } finally { + loadingOverlay.hide(); // Always hide overlay when done + } + }); + + // Reset virtual environment + document.getElementById('reset-venv').addEventListener('click', async () => { + loadingOverlay.show('Resetting virtual environment - please wait...'); // Show overlay with custom message + try { + await buttonManager.withButtonLoading('resetVenv', async () => { + await handleAsyncAction('Reset virtual environment', () => window.electronAPI.resetVenv()); + }); + } finally { + loadingOverlay.hide(); // Always hide overlay when done + } + }); + + // Reset configuration + document.getElementById('reset-config').addEventListener('click', async () => { + const confirmReset = confirm('Are you sure you want to reset all settings to defaults? 
This cannot be undone.'); + if (!confirmReset) return; + + try { + // Stop process if running + const wasRunning = !buttonManager.buttons.stop.disabled; + if (wasRunning) { + appendToConsole('Stopping process before resetting configuration...', 'info'); + await window.electronAPI.stopProcess(); + buttonManager.setProcessStopped(); + await new Promise(resolve => setTimeout(resolve, 500)); + } + + // Reset configuration + appendToConsole('Resetting configuration to defaults...', 'info'); + const result = await window.electronAPI.resetConfig(); + + // Reload configuration with defaults + const config = await window.electronAPI.loadConfig(); + setFormValues(config); + + showStatus(result.message, 'success'); + appendToConsole('Configuration reset successfully', 'info'); + + // Restart process if it was running + if (wasRunning) { + appendToConsole('Restarting process with default configuration...', 'info'); + await window.electronAPI.startProcess(); + buttonManager.setProcessRunning(); + appendToConsole('Process restarted with default configuration', 'info'); + } + } catch (error) { + showStatus(`Failed to reset configuration: ${error.message}`, 'error'); + appendToConsole(`Failed to reset configuration: ${error.message}`, 'stderr'); + } + }); + + // Refresh microphones + document.getElementById('refresh-microphones').addEventListener('click', async () => { + await buttonManager.withButtonLoading('refreshMicrophones', async () => { + await loadMicrophones(); + }); + }); + + // Start process + document.getElementById('start-process').addEventListener('click', async () => { + buttonManager.setState('start', true); + + try { + // The installRequirements function will now check if venv is set up. 
+ loadingOverlay.show('Verifying environment setup - please wait...'); // Show overlay with custom message + try { + await window.electronAPI.installRequirements(); + appendToConsole('Virtual environment setup checked/completed', 'info'); + } finally { + loadingOverlay.hide(); // Always hide overlay when done + } + + await window.electronAPI.startProcess(); + buttonManager.setProcessRunning(); + appendToConsole('Process started successfully', 'info'); + } catch (error) { + appendToConsole(`Failed to start process: ${error.message}`, 'stderr'); + buttonManager.setState('start', false); + } + }); + + // Stop process + document.getElementById('stop-process').addEventListener('click', async () => { + buttonManager.setState('stop', true); + + try { + await window.electronAPI.stopProcess(); + appendToConsole('Process stop initiated', 'info'); + } catch (error) { + appendToConsole(`Failed to stop process: ${error.message}`, 'stderr'); + buttonManager.setState('stop', false); + } + }); + + // Listen for process stopped event + window.electronAPI.onProcessStopped(() => { + buttonManager.setProcessStopped(); + }); +} + +// Initialize application window.addEventListener('load', async () => { appendToConsole('TaSTT Configuration UI initialized', 'info'); - // Load config first + // Set up Python output listener first so we capture all output + window.electronAPI.onPythonOutput((data) => { + appendToConsole(data.message, data.type); + }); + + // Load configuration try { const config = await window.electronAPI.loadConfig(); setFormValues(config); @@ -264,71 +470,7 @@ window.addEventListener('load', async () => { // Load microphones await loadMicrophones(); - // Set up auto-save after everything is loaded + // Setup event handlers and auto-save + setupEventHandlers(); setupAutoSave(); -}); - -// Console management -const consoleContent = document.getElementById('console-content'); - -function appendToConsole(message, type = 'stdout') { - const timestamp = new 
Date().toLocaleTimeString(); - const timestampSpan = document.createElement('span'); - timestampSpan.className = 'console-timestamp'; - timestampSpan.textContent = `[${timestamp}] `; - - const messageSpan = document.createElement('span'); - messageSpan.className = `console-${type}`; - messageSpan.textContent = message; - - const lineDiv = document.createElement('div'); - lineDiv.appendChild(timestampSpan); - lineDiv.appendChild(messageSpan); - - consoleContent.appendChild(lineDiv); - - // Auto-scroll to bottom - const pythonConsole = document.getElementById('python-console'); - pythonConsole.scrollTop = pythonConsole.scrollHeight; -} - -// Clear console button -document.getElementById('clear-console').addEventListener('click', () => { - consoleContent.innerHTML = ''; - appendToConsole('Console cleared', 'info'); -}); - -// Listen for Python output -window.electronAPI.onPythonOutput((data) => { - appendToConsole(data.message, data.type); -}); - -document.getElementById('start-process').addEventListener('click', async () => { - setButtonState(startButton, true); - - try { - await window.electronAPI.startProcess(); - setProcessRunningState(); - appendToConsole('Process started successfully', 'info'); - } catch (error) { - appendToConsole(`Failed to start process: ${error.message}`, 'stderr'); - setButtonState(startButton, false); - } -}); - -document.getElementById('stop-process').addEventListener('click', async () => { - setButtonState(stopButton, true); - - try { - const result = await window.electronAPI.stopProcess(); - appendToConsole('Process stop initiated', 'info'); - } catch (error) { - appendToConsole(`Failed to stop process: ${error.message}`, 'stderr'); - setButtonState(stopButton, false); - } -}); - -// Listen for process stopped event -window.electronAPI.onProcessStopped(() => { - setProcessStoppedState(); }); \ No newline at end of file diff --git a/ui/src/components.css b/ui/src/components.css index d8d909d..2832e12 100644 --- a/ui/src/components.css 
+++ b/ui/src/components.css @@ -46,6 +46,14 @@ .btn-red { @apply bg-red-600 text-white hover:bg-red-700 focus:ring-red-500; } + + .btn-purple { + @apply bg-purple-600 text-white hover:bg-purple-700 focus:ring-purple-500; + } + + .btn-orange { + @apply bg-orange-600 text-white hover:bg-orange-700 focus:ring-orange-500; + } } /* Console styling */ diff --git a/ui_design.md b/ui_design.md index 06eee65..e1ff095 100644 --- a/ui_design.md +++ b/ui_design.md @@ -10,7 +10,13 @@ $ choco uninstall nodejs -y $ choco install nodejs-lts -y ``` -Now open a non-admin PowerShell terminal: +To build the app: +``` +$ npm install +$ npm run dev +``` + +For posterity, this is how I set up the ui directory initially. In a non-admin PowerShell window: ```bash # Check your node and npm versions. @@ -30,3 +36,4 @@ npx tailwindcss init -p npm install --save-dev vue@3 @vitejs/plugin-vue vite yaml npm install --save-dev js-yaml ``` + -- cgit v1.2.3 From 7fb9c575aea4d318e9c14b82174d1b323171b62b Mon Sep 17 00:00:00 2001 From: yum Date: Fri, 30 May 2025 13:32:36 -0700 Subject: More stuff - fix unicode output from python terminal - fix cpu inference - add filters - add beam search params to UI - DRY up config definition in UI --- Third_Party/Profanity | 1 + app/hi.py | 4 ++ app/profanity_filter.py | 43 ++++++++++++++ app/stt.py | 151 ++++++++++++++++++++++++++++++++++++++++-------- config.yaml | 20 ++++--- ui/config-schema.js | 49 ++++++++++++++++ ui/index.html | 52 ++++++++++++++++- ui/index.js | 49 +++++++--------- ui/renderer.js | 31 ++-------- 9 files changed, 312 insertions(+), 88 deletions(-) create mode 160000 Third_Party/Profanity create mode 100644 app/profanity_filter.py create mode 100644 ui/config-schema.js diff --git a/Third_Party/Profanity b/Third_Party/Profanity new file mode 160000 index 0000000..5faf2ba --- /dev/null +++ b/Third_Party/Profanity @@ -0,0 +1 @@ +Subproject commit 5faf2ba42d7b1c0977169ec3611df25a3c08eb13 diff --git a/app/hi.py b/app/hi.py index e6877ff..bab0fd4 
100644 --- a/app/hi.py +++ b/app/hi.py @@ -1,5 +1,6 @@ import app_config import argparse +import io from math import floor, ceil import msvcrt import os @@ -11,6 +12,9 @@ import sys import threading import time +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') +sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') + TESTS_ENABLED = True # 0 = quiet, 1 = verbose, 2 = very verbose diff --git a/app/profanity_filter.py b/app/profanity_filter.py new file mode 100644 index 0000000..b8c84ed --- /dev/null +++ b/app/profanity_filter.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 + +class ProfanityFilter: + def __init__(self, en_path: str): + self.en_path = en_path + self.en_profanity = set() + + def load(self): + with open(self.en_path, 'r') as f: + for line in f: + self.en_profanity.add(line.strip()) + + def filter(self, line: str, language_code: str = "en") -> str: + filtered = "" + + if language_code not in {"en"}: + raise ValueError(f"Language code \"{language_code}\" is " + + "unsupported by the profanity filter") + + # Translation table converting vowels to asterisks. + vowel_to_asterisk = str.maketrans('aeiouAEIOU', '**********') + + result = [] + for word in line.split(): + word_clean = word.lower() + # Filter out non-alphabet characters from the word. 
+ word_clean = ''.join([char for char in word_clean if char.isalpha()]) + if word_clean in self.en_profanity: + result.append(word.translate(vowel_to_asterisk)) + else: + result.append(word) + + return " ".join(result) + +if __name__ == "__main__": + en_path = "/mnt/d/vrc/TaSTT/GUI/Profanity/Profanity/en" + p = ProfanityFilter(en_path) + p.load() + assert(p.filter("fuck") == "f*ck") + assert(p.filter("fuck!") == "f*ck!") + assert(p.filter("fuck shit") == "f*ck sh*t") + assert(p.filter("fuck shit this should not be filtered") == "f*ck sh*t this should not be filtered") + assert(p.filter("ASS") == "*SS") diff --git a/app/stt.py b/app/stt.py index 7d76333..a3988e1 100644 --- a/app/stt.py +++ b/app/stt.py @@ -3,6 +3,12 @@ from faster_whisper import WhisperModel import langcodes import numpy as np import os +try: + from profanity_filter import ProfanityFilter + PROFANITY_FILTER_AVAILABLE = True +except ImportError: + PROFANITY_FILTER_AVAILABLE = False + print("Warning: profanity_filter module not available", file=sys.stderr) import pyaudio from pydub import AudioSegment from shared_thread_data import SharedThreadData @@ -12,7 +18,6 @@ import time import typing import wave - APP_ROOT = os.path.dirname(os.path.abspath(__file__)) PROJECT_ROOT = os.path.dirname(APP_ROOT) @@ -297,21 +302,19 @@ class AudioSegmenter: max_speech_s=5): self.min_silence_ms = min_silence_ms self.max_speech_s = max_speech_s - + # Load Silero VAD model self.model = load_silero_vad() - + self.vad_threshold = 0.3 self.min_silence_duration_ms = min_silence_ms self.max_speech_duration_s = max_speech_s - - self.speech_pad_ms = 300 def segmentAudio(self, audio: bytes): # Convert audio bytes to numpy array expected by silero-vad audio_array = np.frombuffer(audio, dtype=np.int16).flatten().astype(np.float32) / 32768.0 - + # Get speech timestamps using silero-vad # Note: silero-vad expects sample rate of 16000 Hz which matches AudioStream.FPS speech_timestamps = get_speech_timestamps( @@ -323,7 +326,7 @@ 
class AudioSegmenter: max_speech_duration_s=self.max_speech_duration_s, return_seconds=False # We want frame indices, not seconds ) - + return speech_timestamps # Returns the stable cutoff (if any) and whether there are any segments. @@ -399,27 +402,25 @@ class Whisper: self.model = None self.cfg = cfg - abspath = os.path.abspath(__file__) - my_dir = os.path.dirname(abspath) - parent_dir = os.path.dirname(my_dir) - model_str = cfg["model"] - model_root = os.path.join(parent_dir, "Models", + model_root = os.path.join(PROJECT_ROOT, "Models", os.path.normpath(model_str)) if cfg["enable_debug_mode"]: print(f"Model {cfg['model']} will be saved to {model_root}", file=sys.stderr) model_device = "cuda" + compute_type = cfg["compute_type"] if cfg["use_cpu"]: model_device = "cpu" + compute_type = "int8" already_downloaded = os.path.exists(model_root) self.model = WhisperModel(model_str, device = model_device, device_index = cfg["gpu_idx"], - compute_type = cfg["compute_type"], + compute_type = compute_type, download_root = model_root, local_files_only = already_downloaded) @@ -436,14 +437,14 @@ class Whisper: def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: frames = self.collector.getAudio() - + # Convert audio to float32 audio = np.frombuffer(frames, dtype=np.int16).flatten().astype(np.float32) / 32768.0 # Build context-aware prompt prompt = self._build_prompt() - + t0 = time.time() segments, info = self.model.transcribe( audio, @@ -452,12 +453,9 @@ class Whisper: temperature=0.0, without_timestamps = False, initial_prompt=prompt, - beam_size=5, - best_of=5, - condition_on_previous_text=True, - compression_ratio_threshold=2.4, - log_prob_threshold=-1.0, - no_speech_threshold=0.6 + beam_size=self.cfg.get("beam_size", 5), + best_of=self.cfg.get("best_of", 5), + condition_on_previous_text=True ) res = [] for s in segments: @@ -562,21 +560,21 @@ class VadCommitter: latency_s = self.collector.now() - self.collector.begin() duration_s = 
stable_cutoff / AudioStream.FPS start_ts = self.collector.begin() - + # Get the filtered audio first, then extract the portion we need filtered_audio = self.collector.getAudio() commit_audio = filtered_audio[:stable_cutoff * AudioStream.FRAME_SZ] - + # Now drop the prefix from the collector self.collector.dropAudioPrefixByFrames(stable_cutoff) segments = self.whisper.transcribe(commit_audio) delta = ''.join(s.transcript for s in segments) - + # Update whisper's context with the committed text if delta.strip(): self.whisper.update_context(delta.strip()) - + audio = self.collector.getAudio() if self.cfg["enable_debug_mode"]: for s in segments: @@ -608,6 +606,88 @@ class VadCommitter: duration_s=duration_s, start_ts=start_ts) + +class StreamingPlugin: + def __init__(self): + pass + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + return commit + + def stop(self): + pass + + +class LowercasePlugin(StreamingPlugin): + def __init__(self, cfg): + self.cfg = cfg + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + if self.cfg["enable_lowercase_filter"]: + commit.delta = commit.delta.lower() + commit.preview = commit.preview.lower() + return commit + + +class UppercasePlugin(StreamingPlugin): + def __init__(self, cfg): + self.cfg = cfg + + def transform(self, commit: TranscriptCommit) -> TranscriptCommit: + if self.cfg["enable_uppercase_filter"]: + commit.delta = commit.delta.upper() + commit.preview = commit.preview.upper() + return commit + + +class ProfanityPlugin(StreamingPlugin): + def __init__(self, cfg): + self.cfg = cfg + self.filter = None + if PROFANITY_FILTER_AVAILABLE and cfg["enable_profanity_filter"]: + en_profanity_path = os.path.join(PROJECT_ROOT, "Third_Party/Profanity/en") + try: + self.filter = ProfanityFilter(en_profanity_path) + self.filter.load() + except Exception as e: + print(f"Warning: Could not load profanity filter: {e}", file=sys.stderr) + self.filter = None + + def transform(self, commit: 
TranscriptCommit) -> TranscriptCommit: + if self.cfg["enable_profanity_filter"] and self.filter: + commit.delta = self.filter.filter(commit.delta) + commit.preview = self.filter.filter(commit.preview) + return commit + + +class PresentationFilter: + def __init__(self): + pass + + def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]: + return transcript, preview + + def stop(self): + pass + + +class TrailingPeriodFilter(PresentationFilter): + def __init__(self, cfg): + self.cfg = cfg + + def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]: + if self.cfg["remove_trailing_period"]: + def _remove_trailing_period(s: str) -> str: + if len(s) > 0 and s[-1] == '.' and not s.endswith("..."): + s = s[0:len(s)-1] + return s + if len(preview) == 0: + transcript = _remove_trailing_period(transcript) + else: + preview = _remove_trailing_period(preview) + return transcript, preview + + def transcriptionThread(shared_data: SharedThreadData): last_stable_commit = None @@ -621,6 +701,17 @@ def transcriptionThread(shared_data: SharedThreadData): max_speech_s=shared_data.cfg["max_speech_duration_s"]) committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter) + plugins = [] + # plugins.append(TranslationPlugin(shared_data.cfg)) # Not implemented yet + plugins.append(UppercasePlugin(shared_data.cfg)) + plugins.append(LowercasePlugin(shared_data.cfg)) + plugins.append(ProfanityPlugin(shared_data.cfg)) + # plugins.append(UwuPlugin(shared_data.cfg)) # Not implemented yet + # plugins.append(BrowserSource(shared_data.cfg)) # Not implemented yet + + filters = [] + filters.append(TrailingPeriodFilter(shared_data.cfg)) + transcript = "" preview = "" @@ -633,6 +724,9 @@ def transcriptionThread(shared_data: SharedThreadData): commit = committer.getDelta() + for plugin in plugins: + commit = plugin.transform(commit) + if len(commit.delta) > 0 or len(commit.preview) > 0: # Avoid re-sending text after long pauses if 
shared_data.cfg["reset_after_silence_s"] > 0: @@ -664,6 +758,9 @@ def transcriptionThread(shared_data: SharedThreadData): transcript = join_segments(transcript, commit.delta) preview = commit.preview + for filt in filters: + transcript, preview = filt.transform(transcript, preview) + try: print(f"Transcript: {transcript}", flush=True) except UnicodeEncodeError: @@ -691,4 +788,8 @@ def transcriptionThread(shared_data: SharedThreadData): (not commit.delta.endswith(' ')) and \ (not commit.preview.startswith(' ')): commit.preview = ' ' + commit.preview + for plugin in plugins: + plugin.stop() + for filt in filters: + filt.stop() diff --git a/config.yaml b/config.yaml index 5eec7a2..fea03bb 100644 --- a/config.yaml +++ b/config.yaml @@ -1,18 +1,24 @@ compute_type: float16 -enable_debug_mode: 0 -enable_previews: 1 -user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. -save_audio: 0 language: english +model: turbo +microphone: 2 +user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm. 
gpu_idx: 0 max_speech_duration_s: 10 min_silence_duration_ms: 250 -microphone: 0 -model: turbo reset_after_silence_s: 15 transcription_loop_delay_ms: 100 -use_cpu: 0 block_width: 2 num_blocks: 40 rows: 10 cols: 24 +beam_size: 5 +best_of: 5 +enable_debug_mode: 0 +enable_previews: 1 +save_audio: 0 +use_cpu: 0 +enable_lowercase_filter: 0 +enable_uppercase_filter: 0 +enable_profanity_filter: 0 +remove_trailing_period: 0 diff --git a/ui/config-schema.js b/ui/config-schema.js new file mode 100644 index 0000000..b1108ff --- /dev/null +++ b/ui/config-schema.js @@ -0,0 +1,49 @@ +// Shared configuration schema with types and defaults +const CONFIG_SCHEMA = { + // String fields + compute_type: { type: 'select', default: 'float16' }, + language: { type: 'select', default: 'english' }, + model: { type: 'select', default: 'turbo' }, + microphone: { type: 'number', default: 0 }, + user_prompt: { type: 'text', default: 'Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm.' 
}, + + // Number fields + gpu_idx: { type: 'number', default: 0 }, + max_speech_duration_s: { type: 'number', default: 10 }, + min_silence_duration_ms: { type: 'number', default: 250 }, + reset_after_silence_s: { type: 'number', default: 15 }, + transcription_loop_delay_ms: { type: 'number', default: 100 }, + block_width: { type: 'number', default: 2 }, + num_blocks: { type: 'number', default: 40 }, + rows: { type: 'number', default: 10 }, + cols: { type: 'number', default: 24 }, + beam_size: { type: 'number', default: 5 }, + best_of: { type: 'number', default: 5 }, + + // Boolean fields (stored as 1/0) + enable_debug_mode: { type: 'boolean', default: 0 }, + enable_previews: { type: 'boolean', default: 1 }, + save_audio: { type: 'boolean', default: 0 }, + use_cpu: { type: 'boolean', default: 0 }, + enable_lowercase_filter: { type: 'boolean', default: 0 }, + enable_uppercase_filter: { type: 'boolean', default: 0 }, + enable_profanity_filter: { type: 'boolean', default: 0 }, + remove_trailing_period: { type: 'boolean', default: 0 } +}; + +// Helper to extract just the default values +function getDefaultConfig() { + const defaults = {}; + for (const [key, schema] of Object.entries(CONFIG_SCHEMA)) { + defaults[key] = schema.default; + } + return defaults; +} + +// Export for both CommonJS (main process) and ES modules (renderer) +if (typeof module !== 'undefined' && module.exports) { + module.exports = { CONFIG_SCHEMA, getDefaultConfig }; +} else { + window.CONFIG_SCHEMA = CONFIG_SCHEMA; + window.getDefaultConfig = getDefaultConfig; +} \ No newline at end of file diff --git a/ui/index.html b/ui/index.html index 90f78c1..97da3d2 100644 --- a/ui/index.html +++ b/ui/index.html @@ -10,9 +10,9 @@
-
+
- + diff --git a/ui/index.js b/ui/index.js index 2420ece..7717c92 100644 --- a/ui/index.js +++ b/ui/index.js @@ -4,6 +4,7 @@ const fs = require('node:fs').promises; const yaml = require('js-yaml'); const { spawn } = require('child_process'); const https = require('https'); +const { CONFIG_SCHEMA, getDefaultConfig } = require('./config-schema.js'); const APP_ROOT = path.join(__dirname, '..'); const CONFIG_PATH = path.join(APP_ROOT, 'config.yaml'); @@ -82,6 +83,14 @@ function downloadFile(url, outputPath) { }); } +function shouldFilterMessage(message) { + // Filter out pydub ffmpeg/avconv warning. It does not actually matter. + if (message.includes("Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work")) { + return true; + } + return false; +} + // Helper function to setup process event handlers function setupProcessHandlers(process) { process.stdout.on('data', (data) => { @@ -91,7 +100,9 @@ function setupProcessHandlers(process) { process.stderr.on('data', (data) => { const text = data.toString(); - sendPythonOutput(text.trimEnd(), 'stderr'); + if (!shouldFilterMessage(text)) { + sendPythonOutput(text.trimEnd(), 'stderr'); + } }); process.on('error', (error) => { @@ -137,7 +148,10 @@ function executePythonCommand(args, options = {}) { pythonProcess.stderr.on('data', (data) => { const text = data.toString(); stderr += text; - sendPythonOutput(text.trimEnd(), 'stderr'); + // Filter out specific warning messages + if (!shouldFilterMessage(text)) { + sendPythonOutput(text.trimEnd(), 'stderr'); + } }); pythonProcess.on('error', (error) => { @@ -171,27 +185,8 @@ function createWindow () { mainWindow.loadFile('index.html'); } -// Default configuration based on user's current config.yaml -const DEFAULT_CONFIG = { - compute_type: 'float16', - enable_debug_mode: 0, - enable_previews: 1, - user_prompt: 'Use proper punctuation and grammar. 
Prefer spelled out numbers like one, eleven, twenty, etc.', - save_audio: 0, - language: 'english', - gpu_idx: 0, - max_speech_duration_s: 10, - min_silence_duration_ms: 250, - microphone: 0, - model: 'turbo', - reset_after_silence_s: 15, - transcription_loop_delay_ms: 100, - use_cpu: 0, - block_width: 2, - num_blocks: 40, - rows: 10, - cols: 24 -}; +// Replace the DEFAULT_CONFIG constant with: +const DEFAULT_CONFIG = getDefaultConfig(); // IPC handlers ipcMain.handle('load-config', async () => { @@ -521,12 +516,12 @@ ipcMain.handle('start-process', async () => { }); ipcMain.handle('stop-process', async () => { - if (!runningProcess) { - throw new Error('No process is running'); - } - return new Promise((resolve) => { let forcefullyKilled = false; + + if (!runningProcess) { + resolve({ success: true, forcefullyKilled }); + } // Set up a timeout to force kill after 10 seconds const killTimeout = setTimeout(() => { diff --git a/ui/renderer.js b/ui/renderer.js index 201eef6..133a79b 100644 --- a/ui/renderer.js +++ b/ui/renderer.js @@ -1,29 +1,5 @@ -// Configuration and form field mappings -const CONFIG_FIELDS = { - // String fields - compute_type: { type: 'select', default: 'float16' }, - language: { type: 'select', default: 'english' }, - model: { type: 'select', default: 'turbo' }, - microphone: { type: 'number', default: 0 }, - user_prompt: { type: 'text', default: '' }, - - // Number fields - gpu_idx: { type: 'number', default: 0 }, - max_speech_duration_s: { type: 'number', default: 10 }, - min_silence_duration_ms: { type: 'number', default: 250 }, - reset_after_silence_s: { type: 'number', default: 15 }, - transcription_loop_delay_ms: { type: 'number', default: 100 }, - block_width: { type: 'number', default: 2 }, - num_blocks: { type: 'number', default: 40 }, - rows: { type: 'number', default: 10 }, - cols: { type: 'number', default: 24 }, - - // Boolean fields (stored as 1/0) - enable_debug_mode: { type: 'boolean', default: 0 }, - enable_previews: { type: 
'boolean', default: 1 }, - save_audio: { type: 'boolean', default: 0 }, - use_cpu: { type: 'boolean', default: 0 } -}; +// Import configuration schema +const CONFIG_FIELDS = window.CONFIG_SCHEMA; // Button management system class ButtonManager { @@ -35,6 +11,9 @@ class ButtonManager { resetVenv: document.getElementById('reset-venv'), refreshMicrophones: document.getElementById('refresh-microphones') }; + + // Initialize button states on construction + this.setProcessStopped(); } setState(buttonName, disabled) { -- cgit v1.2.3 From 73de7cb2d8fb964e7f76ab55420e9bc331bf7bea Mon Sep 17 00:00:00 2001 From: yum Date: Fri, 30 May 2025 21:31:05 -0700 Subject: More stuff - add desktop and vr input threads - add audio feedback for input - add volume control for audio feedback - add UI for custom chatbox/built in chatbox - add ability to dismiss built in chatbox (sync empty messages) - limit lines in python console - limit length of each transcript --- Sounds/Dismiss_Noise.wav | Bin 0 -> 192078 bytes Sounds/Dismiss_Noise_Quiet.wav | Bin 0 -> 192078 bytes Sounds/KB_Noise_Off.wav | Bin 0 -> 192078 bytes Sounds/KB_Noise_Off_Quiet.wav | Bin 0 -> 192078 bytes Sounds/KB_Noise_On.wav | Bin 0 -> 266318 bytes Sounds/KB_Noise_On_Quiet.wav | Bin 0 -> 266318 bytes Sounds/Noise_Off.wav | Bin 0 -> 67278 bytes Sounds/Noise_Off_Quiet.wav | Bin 0 -> 67278 bytes Sounds/Noise_On.wav | Bin 0 -> 67278 bytes Sounds/Noise_On_Quiet.wav | Bin 0 -> 67278 bytes Sounds/speech_noise.wav | Bin 0 -> 61518 bytes app/hi.py | 308 ++++++++++++++++++++++++++++++++++------- app/keybind_event_machine.py | 21 +++ app/requirements.txt | 3 + app/shared_thread_data.py | 7 +- app/steamvr.py | 87 ++++++++++++ app/stt.py | 143 ++++++++++--------- config.yaml | 15 +- ui/config-schema.js | 11 +- ui/index.html | 50 +++++++ ui/index.js | 16 ++- ui/preload.js | 1 + ui/renderer.js | 58 ++++++++ 23 files changed, 595 insertions(+), 125 deletions(-) create mode 100644 Sounds/Dismiss_Noise.wav create mode 100644 
Sounds/Dismiss_Noise_Quiet.wav create mode 100644 Sounds/KB_Noise_Off.wav create mode 100644 Sounds/KB_Noise_Off_Quiet.wav create mode 100644 Sounds/KB_Noise_On.wav create mode 100644 Sounds/KB_Noise_On_Quiet.wav create mode 100644 Sounds/Noise_Off.wav create mode 100644 Sounds/Noise_Off_Quiet.wav create mode 100644 Sounds/Noise_On.wav create mode 100644 Sounds/Noise_On_Quiet.wav create mode 100644 Sounds/speech_noise.wav create mode 100644 app/keybind_event_machine.py create mode 100644 app/steamvr.py diff --git a/Sounds/Dismiss_Noise.wav b/Sounds/Dismiss_Noise.wav new file mode 100644 index 0000000..fe60f21 Binary files /dev/null and b/Sounds/Dismiss_Noise.wav differ diff --git a/Sounds/Dismiss_Noise_Quiet.wav b/Sounds/Dismiss_Noise_Quiet.wav new file mode 100644 index 0000000..5c3b1cb Binary files /dev/null and b/Sounds/Dismiss_Noise_Quiet.wav differ diff --git a/Sounds/KB_Noise_Off.wav b/Sounds/KB_Noise_Off.wav new file mode 100644 index 0000000..64d9c6f Binary files /dev/null and b/Sounds/KB_Noise_Off.wav differ diff --git a/Sounds/KB_Noise_Off_Quiet.wav b/Sounds/KB_Noise_Off_Quiet.wav new file mode 100644 index 0000000..b965e6a Binary files /dev/null and b/Sounds/KB_Noise_Off_Quiet.wav differ diff --git a/Sounds/KB_Noise_On.wav b/Sounds/KB_Noise_On.wav new file mode 100644 index 0000000..a959041 Binary files /dev/null and b/Sounds/KB_Noise_On.wav differ diff --git a/Sounds/KB_Noise_On_Quiet.wav b/Sounds/KB_Noise_On_Quiet.wav new file mode 100644 index 0000000..e49513e Binary files /dev/null and b/Sounds/KB_Noise_On_Quiet.wav differ diff --git a/Sounds/Noise_Off.wav b/Sounds/Noise_Off.wav new file mode 100644 index 0000000..0d3843c Binary files /dev/null and b/Sounds/Noise_Off.wav differ diff --git a/Sounds/Noise_Off_Quiet.wav b/Sounds/Noise_Off_Quiet.wav new file mode 100644 index 0000000..d5c6171 Binary files /dev/null and b/Sounds/Noise_Off_Quiet.wav differ diff --git a/Sounds/Noise_On.wav b/Sounds/Noise_On.wav new file mode 100644 index 0000000..28c8f6b 
Binary files /dev/null and b/Sounds/Noise_On.wav differ diff --git a/Sounds/Noise_On_Quiet.wav b/Sounds/Noise_On_Quiet.wav new file mode 100644 index 0000000..79170f5 Binary files /dev/null and b/Sounds/Noise_On_Quiet.wav differ diff --git a/Sounds/speech_noise.wav b/Sounds/speech_noise.wav new file mode 100644 index 0000000..a6224ee Binary files /dev/null and b/Sounds/speech_noise.wav differ diff --git a/app/hi.py b/app/hi.py index bab0fd4..1297b37 100644 --- a/app/hi.py +++ b/app/hi.py @@ -1,25 +1,34 @@ import app_config import argparse import io +import keybind_event_machine from math import floor, ceil import msvcrt import os from pythonosc import udp_client import sentencepiece as spm +import steamvr from shared_thread_data import SharedThreadData import stt import sys import threading import time +import pygame sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') +# Initialize pygame mixer +pygame.mixer.init() + TESTS_ENABLED = True # 0 = quiet, 1 = verbose, 2 = very verbose LOG_LEVEL = 0 +# Global volume control (0.0 to 1.0) +VOLUME = 0.3 + APP_ROOT = os.path.dirname(os.path.abspath(__file__)) PROJECT_ROOT = os.path.dirname(APP_ROOT) @@ -315,79 +324,276 @@ def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg): send_data(osc_client, [indices[0]], [diff_blocks[0]], [diff_visual_pointers[0]]) def osc_thread(shared_data: SharedThreadData): - tokenizer = get_tokenizer() osc_client = getOscClient() - # Prime the board - print("Priming the board") - input_state = InputState() - handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg) + def join_segments(a, b): + if len(a) > 0 and a[-1] != ' ': + return a + ' ' + b + else: + return a + b + + if shared_data.cfg["use_builtin"]: + last_change = time.time() + remote_word = "" + while not shared_data.exit_event.is_set(): + time.sleep(0.1) + local_word = "" + with shared_data.word_lock: + local_word = 
join_segments(shared_data.transcript, + shared_data.preview) + local_word = local_word[-140:] + if local_word == remote_word: + continue + if time.time() - last_change < 1.5: + continue + addr = "/chatbox/input" + print(f"Send {local_word}", flush=True) + osc_client.send_message(addr, (local_word, True, False)) + last_change = time.time() + remote_word = local_word + else: + # Custom chatbox + tokenizer = get_tokenizer() + + # Prime the board + print("Priming the board") + input_state = InputState() + handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg) + + while not shared_data.exit_event.is_set(): + word_copy = "" + with shared_data.word_lock: + word_copy = shared_data.word + handle_input(input_state, word_copy, tokenizer, osc_client, shared_data.cfg) + time.sleep(0.01) + + +def vrInputThread(shared_data: SharedThreadData): + RECORD_STATE = 0 + PAUSE_STATE = 1 + state = PAUSE_STATE + + hand_id = shared_data.cfg["button_hand"] + button_id = shared_data.cfg["button_type"] + # Rough description of state machine: + # Single short press: toggle transcription + # Medium press: dismiss custom chatbox + # Long press: update chatbox in place + # Medium press + long press: type transcription + + last_rising = time.time() + last_medium_press_end = 0 + + waveform0 = os.path.join(PROJECT_ROOT, "Sounds/Noise_On_Quiet.wav") + waveform1 = os.path.join(PROJECT_ROOT, "Sounds/Noise_Off_Quiet.wav") + waveform2 = os.path.join(PROJECT_ROOT, "Sounds/Dismiss_Noise_Quiet.wav") + waveform3 = os.path.join(PROJECT_ROOT, "Sounds/KB_Noise_Off_Quiet.wav") + + button_generator = steamvr.pollButtonPress(hand=hand_id, button=button_id, + shared_data=shared_data) while not shared_data.exit_event.is_set(): - word_copy = "" + time.sleep(0.01) + try: + event = next(button_generator) + except StopIteration: + break + with shared_data.word_lock: - word_copy = shared_data.word - handle_input(input_state, word_copy, tokenizer, osc_client, shared_data.cfg) + if not shared_data.stream or 
not shared_data.collector: + continue + + if event.opcode == steamvr.EVENT_RISING_EDGE: + last_rising = time.time() + + if state == PAUSE_STATE: + shared_data.stream.pause(False) + shared_data.stream.getSamples() + + elif event.opcode == steamvr.EVENT_FALLING_EDGE: + now = time.time() + if now - last_rising > 1.5: + # Long press: treat as the end of transcription. + state = PAUSE_STATE + + shared_data.stream.pause(True) + + if last_rising - last_medium_press_end < 1.0: + # Type transcription + if shared_data.cfg["enable_local_beep"]: + play_sound_with_volume(waveform3) + else: + if shared_data.cfg["enable_local_beep"]: + play_sound_with_volume(waveform1) + + elif now - last_rising > 0.5: + # Medium press + print("CLEARING", file=sys.stderr) + last_medium_press_end = now + state = PAUSE_STATE + + if shared_data.cfg["enable_local_beep"]: + play_sound_with_volume(waveform2) + + # Flush the *entire* pipeline. + shared_data.stream.pause(True) + shared_data.stream.getSamples() + shared_data.collector.dropAudio() + shared_data.transcript = "" + shared_data.preview = "" + continue + + # Short hold + if state == RECORD_STATE: + print("PAUSED", file=sys.stderr) + state = PAUSE_STATE + + shared_data.stream.pause(True) + + if shared_data.cfg["enable_local_beep"]: + play_sound_with_volume(waveform1) + elif state == PAUSE_STATE: + print("RECORDING", file=sys.stderr) + state = RECORD_STATE + if shared_data.cfg["reset_on_toggle"]: + if shared_data.cfg["enable_debug_mode"]: + print("Toggle detected, dropping transcript (3)", + file=sys.stderr) + shared_data.transcript = "" + shared_data.preview = "" + #audio_state.drop_transcription = True + else: + if shared_data.cfg["enable_debug_mode"]: + print("Toggle detected, committing preview text (3)", + file=sys.stderr) + #audio_state.text += audio_state.preview_text + + shared_data.stream.pause(False) + + if shared_data.cfg["enable_local_beep"]: + play_sound_with_volume(waveform0) + + +def kbInputThread(shared_data: SharedThreadData): + 
machine = keybind_event_machine.KeybindEventMachine(shared_data.cfg["keybind"]) + last_press_time = 0 + + # double pressing the keybind + double_press_timeout = 0.5 + + RECORD_STATE = 0 + PAUSE_STATE = 1 + state = PAUSE_STATE + + waveform0 = os.path.join(PROJECT_ROOT, "Sounds/Noise_On_Quiet.wav") + waveform1 = os.path.join(PROJECT_ROOT, "Sounds/Noise_Off_Quiet.wav") + waveform2 = os.path.join(PROJECT_ROOT, "Sounds/Dismiss_Noise_Quiet.wav") + waveform3 = os.path.join(PROJECT_ROOT, "Sounds/KB_Noise_Off_Quiet.wav") + + while not shared_data.exit_event.is_set(): time.sleep(0.01) + cur_press_time = machine.getNextPressTime() + if cur_press_time == 0: + continue + + with shared_data.word_lock: + if not shared_data.stream or not shared_data.collector: + continue + + EVENT_SINGLE_PRESS = 0 + EVENT_DOUBLE_PRESS = 1 + if last_press_time == 0: + event = EVENT_SINGLE_PRESS + elif cur_press_time - last_press_time < double_press_timeout: + event = EVENT_DOUBLE_PRESS + else: + event = EVENT_SINGLE_PRESS + last_press_time = cur_press_time + + if event == EVENT_DOUBLE_PRESS: + print("CLEARING", file=sys.stderr) + state = PAUSE_STATE + + if shared_data.cfg["enable_local_beep"]: + play_sound_with_volume(waveform2) + + # Flush the *entire* pipeline. 
+ shared_data.stream.pause(True) + shared_data.stream.getSamples() + shared_data.collector.dropAudio() + shared_data.transcript = "" + shared_data.preview = "" + continue + + # Short hold + if state == RECORD_STATE: + print("PAUSED", file=sys.stderr) + state = PAUSE_STATE + + shared_data.stream.pause(True) + + if shared_data.cfg["enable_local_beep"]: + play_sound_with_volume(waveform1) + elif state == PAUSE_STATE: + print("RECORDING", file=sys.stderr) + state = RECORD_STATE + if shared_data.cfg["reset_on_toggle"]: + if shared_data.cfg["enable_debug_mode"]: + print("Toggle detected, dropping transcript (2)", + file=sys.stderr) + shared_data.transcript = "" + shared_data.preview = "" + else: + if shared_data.cfg["enable_debug_mode"]: + print("Toggle detected, committing preview text (2)", + file=sys.stderr) + #audio_state.text += audio_state.preview_text + + shared_data.stream.pause(False) + + if shared_data.cfg["enable_local_beep"]: + play_sound_with_volume(waveform0) + +def play_sound_with_volume(filepath): + """Play a WAV file with adjusted volume""" + volume = VOLUME + + try: + sound = pygame.mixer.Sound(filepath) + sound.set_volume(volume) + sound.play() + except Exception as e: + print(f"Error playing sound {filepath}: {e}", file=sys.stderr) + if __name__ == "__main__": cli_args = parse_args() cfg = app_config.getConfig(cli_args.config) shared_data = SharedThreadData(cfg) - if False: - osc_thread = threading.Thread( - target=osc_thread, - args=(shared_data,)) - osc_thread.start() + osc_thread = threading.Thread( + target=osc_thread, + args=(shared_data,)) + osc_thread.start() transcribe_thread = threading.Thread( target=stt.transcriptionThread, args=(shared_data,)) transcribe_thread.start() + vr_input_thd = threading.Thread(target=vrInputThread, args=(shared_data,)) + vr_input_thd.start() + + kb_input_thd = threading.Thread(target=kbInputThread, args=(shared_data,)) + kb_input_thd.start() + word_is_over = False local_word = "" while True: - char_bytes = 
msvcrt.getch() - if char_bytes == b'\x03': # ctrl+C - break - time.sleep(0.1) continue - - try: - char = char_bytes.decode('utf-8') - if char == '\r' or char == '\n': - word_is_over = True - continue - except UnicodeDecodeError: - print(f"Unsupported character: {char_bytes}") - if char_bytes == b'\x00' or char_bytes == b'\xe0': - special_char = msvcrt.getch() - continue - - if char_bytes == b'\x03': # ctrl+C - break - elif char_bytes == b'\x08': # backspace - with shared_data.word_lock: - shared_data.word = shared_data.word[:-1] - local_word = shared_data.word - elif char_bytes == b'\x0c': # ctrl+L - with shared_data.word_lock: - shared_data.word = "" - local_word = shared_data.word - elif word_is_over: - with shared_data.word_lock: - shared_data.word = char - local_word = shared_data.word - word_is_over = False - else: - with shared_data.word_lock: - shared_data.word += char - local_word = shared_data.word - print(local_word + "_") shared_data.exit_event.set() - if False: - osc_thread.join() + osc_thread.join() transcribe_thread.join() + vr_input_thd.join() + kb_input_thd.join() diff --git a/app/keybind_event_machine.py b/app/keybind_event_machine.py new file mode 100644 index 0000000..3ce6794 --- /dev/null +++ b/app/keybind_event_machine.py @@ -0,0 +1,21 @@ +import keyboard +import time + +class KeybindEventMachine: + def __init__(self, keybind: str): + self.keybind = keybind + self.events = [] + keyboard.add_hotkey(keybind, self.onPress) + + def onPress(self) -> None: + self.events.append(time.time()) + + # Returns the timestamp when the keybind was pressed, or 0 if no keypresses + # are queued. 
+ def getNextPressTime(self) -> int: + if len(self.events) == 0: + return 0 + ret = self.events[0] + self.events = self.events[1:] + return ret + diff --git a/app/requirements.txt b/app/requirements.txt index f8b7069..e68a16c 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -1,8 +1,11 @@ faster-whisper hf-xet +keyboard langcodes pyaudio +pygame pydub python-osc sentencepiece silero-vad +openvr diff --git a/app/shared_thread_data.py b/app/shared_thread_data.py index ba0a419..40885e8 100644 --- a/app/shared_thread_data.py +++ b/app/shared_thread_data.py @@ -2,7 +2,12 @@ import threading class SharedThreadData: def __init__(self, cfg): - self.word = "" + self.transcript = "" + self.preview = "" + + self.stream = None + self.collector = None + self.word_lock = threading.Lock() self.exit_event = threading.Event() self.cfg = cfg diff --git a/app/steamvr.py b/app/steamvr.py new file mode 100644 index 0000000..64f34f5 --- /dev/null +++ b/app/steamvr.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +import openvr as vr +import sys +import time + +EVENT_NONE = 0 +EVENT_RISING_EDGE = 1 +EVENT_FALLING_EDGE = 2 + +class InputEvent: + def __init__(self, + opcode: int): + self.opcode = opcode + +# Checks if the given button on the given controller is pressed. 
+def pollButtonPress( + hand: str = "right", + button: str = "b", + shared_data = None # SharedThreadData object + ) -> int: + hands = {} + hands["left"] = vr.TrackedControllerRole_LeftHand + hands["right"] = vr.TrackedControllerRole_RightHand + + buttons = {} + buttons["a"] = vr.k_EButton_IndexController_A + buttons["b"] = vr.k_EButton_IndexController_B + buttons["thumbstick"] = vr.k_EButton_Axis0 + + system = None + first = True + while not shared_data.exit_event.is_set() and not system: + try: + system = vr.init(vr.VRApplication_Background) + except Exception as e: + if first: + print(f"Failed to start steamVR input thread: {repr(e)}", file=sys.stderr) + first = False + time.sleep(1) + last_packet = 0 + event_high = False + + while not shared_data.exit_event.is_set(): + time.sleep(0.01) + + lh_idx = system.getTrackedDeviceIndexForControllerRole(hands[hand]) + #print("left hand device idx: {}".format(lh_idx)) + + got_state, state = system.getControllerState(lh_idx) + if not got_state: + continue + + if state.unPacketNum == last_packet: + continue + + # Clicking joysticks and moving joysticks fire the same events. To + # differentiate movement from clicking, we create a dead zone: if the event + # fires while the stick isn't moved far from center, we assume it's a + # click, not movement. 
+ dead_zone_radius = 0.7 + + button_mask = (1 << buttons[button]) + ret = EVENT_NONE + if (state.ulButtonPressed & button_mask) != 0 and\ + (state.rAxis[0].x**2 + state.rAxis[0].y**2 < dead_zone_radius**2): + #print("button pressed: %016x" % state.ulButtonPressed) + #for i in range(0, 5): + # print("axis {} x: {} y: {}".format(i, state.rAxis[i].x, state.rAxis[i].y)) + if not event_high: + yield InputEvent(EVENT_RISING_EDGE) + event_high = True + elif event_high: + event_high = False + yield InputEvent(EVENT_FALLING_EDGE) + +if __name__ == "__main__": + gen = pollButtonPress() + while True: + time.sleep(0.1) + + event = pollButtonPress(session_state) + if event == EVENT_RISING_EDGE: + print("rising edge") + elif event == EVENT_FALLING_EDGE: + print("falling edge") + diff --git a/app/stt.py b/app/stt.py index a3988e1..c1f4836 100644 --- a/app/stt.py +++ b/app/stt.py @@ -299,9 +299,11 @@ class CompressingAudioCollector(AudioCollectorFilter): class AudioSegmenter: def __init__(self, min_silence_ms=250, - max_speech_s=5): + max_speech_s=5, + min_speech_duration_ms=100): self.min_silence_ms = min_silence_ms self.max_speech_s = max_speech_s + self.min_speech_duration_ms = min_speech_duration_ms # Load Silero VAD model self.model = load_silero_vad() @@ -309,6 +311,7 @@ class AudioSegmenter: self.vad_threshold = 0.3 self.min_silence_duration_ms = min_silence_ms self.max_speech_duration_s = max_speech_s + self.min_speech_duration_ms = min_speech_duration_ms def segmentAudio(self, audio: bytes): # Convert audio bytes to numpy array expected by silero-vad @@ -324,6 +327,7 @@ class AudioSegmenter: threshold=self.vad_threshold, min_silence_duration_ms=self.min_silence_duration_ms, max_speech_duration_s=self.max_speech_duration_s, + min_speech_duration_ms=self.min_speech_duration_ms, return_seconds=False # We want frame indices, not seconds ) @@ -698,7 +702,8 @@ def transcriptionThread(shared_data: SharedThreadData): collector = NormalizingAudioCollector(collector) whisper = 
Whisper(collector, shared_data.cfg) segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"], - max_speech_s=shared_data.cfg["max_speech_duration_s"]) + max_speech_s=shared_data.cfg["max_speech_duration_s"], + min_speech_duration_ms=shared_data.cfg["min_speech_duration_ms"]) committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter) plugins = [] @@ -715,6 +720,10 @@ def transcriptionThread(shared_data: SharedThreadData): transcript = "" preview = "" + with shared_data.word_lock: + shared_data.stream = stream + shared_data.collector = collector + print(f"Ready to go!", flush=True) while not shared_data.exit_event.is_set(): @@ -724,70 +733,72 @@ def transcriptionThread(shared_data: SharedThreadData): commit = committer.getDelta() - for plugin in plugins: - commit = plugin.transform(commit) - - if len(commit.delta) > 0 or len(commit.preview) > 0: - # Avoid re-sending text after long pauses - if shared_data.cfg["reset_after_silence_s"] > 0: - silence_duration = 0 - if last_stable_commit: - last_commit_end_ts = \ - last_stable_commit.start_ts + \ - last_stable_commit.duration_s - silence_duration = commit.start_ts - last_commit_end_ts - if silence_duration > shared_data.cfg["reset_after_silence_s"]: - if shared_data.cfg["enable_debug_mode"]: - print(f"Resetting transcript after {silence_duration}-second " - "silence", file=sys.stderr) - transcript = "" - preview = "" - whisper.recent_context = "" # Reset context too - if commit.delta: - last_stable_commit = commit - - # Hard-cap displayed transcript length at 4k characters to prevent - # runaway memory use in UI. Keep the full transcript to avoid - # breaking OSC pager. 
- transcript = transcript[-4096:] - def join_segments(a, b): - if len(a) > 0 and a[-1] != ' ': - return a + ' ' + b - else: - return a + b - transcript = join_segments(transcript, commit.delta) - preview = commit.preview - - for filt in filters: - transcript, preview = filt.transform(transcript, preview) - - try: - print(f"Transcript: {transcript}", flush=True) - except UnicodeEncodeError: - print("Failed to encode transcript - discarding delta", - file=sys.stderr) - continue - try: - print(f"Preview: {preview}", flush=True) - except UnicodeEncodeError: - print("Failed to encode preview - discarding", file=sys.stderr) - - with shared_data.word_lock: - shared_data.word = join_segments(transcript, preview) - - if shared_data.cfg["enable_debug_mode"]: - print(f"commit latency: {commit.latency_s}", file=sys.stderr) - print(f"commit thresh: {commit.thresh_at_commit}", - file=sys.stderr) - - if len(transcript) > 0 and \ - (not transcript.endswith(' ')) and \ - (not commit.delta.startswith(' ')): - commit.delta = ' ' + commit.delta - if len(commit.delta) > 0 and \ - (not commit.delta.endswith(' ')) and \ - (not commit.preview.startswith(' ')): - commit.preview = ' ' + commit.preview + with shared_data.word_lock: + for plugin in plugins: + commit = plugin.transform(commit) + + if len(commit.delta) > 0 or len(commit.preview) > 0: + # Avoid re-sending text after long pauses + if shared_data.cfg["reset_after_silence_s"] > 0: + silence_duration = 0 + if last_stable_commit: + last_commit_end_ts = \ + last_stable_commit.start_ts + \ + last_stable_commit.duration_s + silence_duration = commit.start_ts - last_commit_end_ts + if silence_duration > shared_data.cfg["reset_after_silence_s"]: + if shared_data.cfg["enable_debug_mode"]: + print(f"Resetting transcript after {silence_duration}-second " + "silence", file=sys.stderr) + shared_data.transcript = "" + shared_data.preview = "" + whisper.recent_context = "" # Reset context too + if commit.delta: + last_stable_commit = commit + + 
# Hard-cap displayed transcript length to prevent + # runaway memory use in UI. Keep the full transcript to avoid + # breaking OSC pager. + if len(shared_data.transcript) >= 1024: + shared_data.transcript = shared_data.transcript[-512:] + def join_segments(a, b): + if len(a) > 0 and a[-1] != ' ': + return a + ' ' + b + else: + return a + b + shared_data.transcript = \ + join_segments(shared_data.transcript, commit.delta) + shared_data.preview = commit.preview + + for filt in filters: + shared_data.transcript, shared_data.preview = \ + filt.transform(shared_data.transcript, + shared_data.preview) + + try: + print(f"Transcript: {shared_data.transcript}", flush=True) + except UnicodeEncodeError: + print("Failed to encode transcript - discarding delta", + file=sys.stderr) + continue + try: + print(f"Preview: {shared_data.preview}", flush=True) + except UnicodeEncodeError: + print("Failed to encode preview - discarding", file=sys.stderr) + + if shared_data.cfg["enable_debug_mode"]: + print(f"commit latency: {commit.latency_s}", file=sys.stderr) + print(f"commit thresh: {commit.thresh_at_commit}", + file=sys.stderr) + + if len(shared_data.transcript) > 0 and \ + (not shared_data.transcript.endswith(' ')) and \ + (not commit.delta.startswith(' ')): + commit.delta = ' ' + commit.delta + if len(commit.delta) > 0 and \ + (not commit.delta.endswith(' ')) and \ + (not commit.preview.startswith(' ')): + commit.preview = ' ' + commit.preview for plugin in plugins: plugin.stop() for filt in filters: diff --git a/config.yaml b/config.yaml index fea03bb..6f4b65b 100644 --- a/config.yaml +++ b/config.yaml @@ -1,11 +1,15 @@ compute_type: float16 language: english model: turbo -microphone: 2 -user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm. +microphone: 1 +user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm. Phi, NOPPERS, clearrainbow, Noia, Kuuderekitten. 
+keybind: ctrl+alt+x +button_hand: right +button_type: b gpu_idx: 0 max_speech_duration_s: 10 -min_silence_duration_ms: 250 +min_speech_duration_ms: 250 +min_silence_duration_ms: 100 reset_after_silence_s: 15 transcription_loop_delay_ms: 100 block_width: 2 @@ -16,9 +20,12 @@ beam_size: 5 best_of: 5 enable_debug_mode: 0 enable_previews: 1 -save_audio: 0 +save_audio: 1 use_cpu: 0 enable_lowercase_filter: 0 enable_uppercase_filter: 0 enable_profanity_filter: 0 remove_trailing_period: 0 +reset_on_toggle: 0 +enable_local_beep: 1 +use_builtin: 1 diff --git a/ui/config-schema.js b/ui/config-schema.js index b1108ff..6b11277 100644 --- a/ui/config-schema.js +++ b/ui/config-schema.js @@ -6,11 +6,15 @@ const CONFIG_SCHEMA = { model: { type: 'select', default: 'turbo' }, microphone: { type: 'number', default: 0 }, user_prompt: { type: 'text', default: 'Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm.' }, + keybind: { type: 'text', default: 'ctrl+alt+x' }, + button_hand: { type: 'select', default: 'right' }, + button_type: { type: 'select', default: 'b' }, // Number fields gpu_idx: { type: 'number', default: 0 }, max_speech_duration_s: { type: 'number', default: 10 }, - min_silence_duration_ms: { type: 'number', default: 250 }, + min_speech_duration_ms: { type: 'number', default: 250 }, + min_silence_duration_ms: { type: 'number', default: 100 }, reset_after_silence_s: { type: 'number', default: 15 }, transcription_loop_delay_ms: { type: 'number', default: 100 }, block_width: { type: 'number', default: 2 }, @@ -28,7 +32,10 @@ const CONFIG_SCHEMA = { enable_lowercase_filter: { type: 'boolean', default: 0 }, enable_uppercase_filter: { type: 'boolean', default: 0 }, enable_profanity_filter: { type: 'boolean', default: 0 }, - remove_trailing_period: { type: 'boolean', default: 0 } + remove_trailing_period: { type: 'boolean', default: 0 }, + reset_on_toggle: { type: 'boolean', default: 0 }, + enable_local_beep: { type: 'boolean', 
default: 1 }, + use_builtin: { type: 'boolean', default: 1 } }; // Helper to extract just the default values diff --git a/ui/index.html b/ui/index.html index 97da3d2..99e64dd 100644 --- a/ui/index.html +++ b/ui/index.html @@ -64,6 +64,31 @@
+
+ + +
+
+ + +
+
+ + +
@@ -110,6 +135,10 @@
+
+ + +
@@ -211,9 +240,30 @@
+ +
+

Input Settings

+
+ + +
+
+

Custom Chatbox Settings

+
+ +
diff --git a/ui/index.js b/ui/index.js index 7717c92..24a7e13 100644 --- a/ui/index.js +++ b/ui/index.js @@ -246,6 +246,21 @@ ipcMain.handle('reset-config', async () => { } }); +ipcMain.handle('deleteVenvIndicatorFile', async () => { + const venvMarkerPath = path.join(APP_ROOT, '.venv_is_set_up'); + try { + await fs.unlink(venvMarkerPath); + return { success: true, message: '.venv_is_set_up deleted successfully.' }; + } catch (error) { + if (error.code === 'ENOENT') { + return { success: true, message: '.venv_is_set_up not found.' }; + } + console.error('Error deleting .venv_is_set_up file:', error); + sendPythonOutput(`Error deleting .venv_is_set_up: ${error.message}`, 'stderr'); + throw error; + } +}); + // Generic function to ensure required files are present async function ensureRequiredFiles(config) { const { @@ -332,7 +347,6 @@ ipcMain.handle('install-requirements', async () => { // Check if venv is already set up try { await fs.access(venvMarkerPath); - sendPythonOutput('Virtual environment already set up, skipping installation', 'info'); return { success: true, message: 'Virtual environment already set up' }; } catch (error) { // Marker doesn't exist, proceed with setup diff --git a/ui/preload.js b/ui/preload.js index 35cc8d6..f2e0a81 100644 --- a/ui/preload.js +++ b/ui/preload.js @@ -6,6 +6,7 @@ contextBridge.exposeInMainWorld('electronAPI', { resetConfig: () => ipcRenderer.invoke('reset-config'), getMicrophones: () => ipcRenderer.invoke('get-microphones'), installRequirements: () => ipcRenderer.invoke('install-requirements'), + deleteVenvIndicatorFile: () => ipcRenderer.invoke('deleteVenvIndicatorFile'), resetVenv: () => ipcRenderer.invoke('reset-venv'), startProcess: () => ipcRenderer.invoke('start-process'), stopProcess: () => ipcRenderer.invoke('stop-process'), diff --git a/ui/renderer.js b/ui/renderer.js index 133a79b..2f4c8f1 100644 --- a/ui/renderer.js +++ b/ui/renderer.js @@ -162,11 +162,28 @@ function setFormValues(config) { } } + // Handle 
use_builtin toggle state + const useBuiltin = config.use_builtin === 1; + const customChatboxInputs = ['block_width', 'num_blocks', 'rows', 'cols']; + customChatboxInputs.forEach(inputId => { + const input = document.getElementById(inputId); + if (input) { + input.disabled = useBuiltin; + if (useBuiltin) { + input.classList.add('opacity-50', 'cursor-not-allowed'); + } else { + input.classList.remove('opacity-50', 'cursor-not-allowed'); + } + } + }); + isSettingValues = false; // Re-enable auto-save } // Console management const consoleContent = document.getElementById('console-content'); +const MAX_CONSOLE_LINES = 512; +let consoleLineCount = 0; function appendToConsole(message, type = 'stdout') { const timestamp = new Date().toLocaleTimeString(); @@ -183,6 +200,28 @@ function appendToConsole(message, type = 'stdout') { lineDiv.appendChild(messageSpan); consoleContent.appendChild(lineDiv); + consoleLineCount++; + + // Remove old lines if we exceed the limit + if (consoleLineCount > MAX_CONSOLE_LINES) { + // Calculate how many lines to remove (remove 10% to avoid frequent trimming) + const linesToRemove = Math.floor(MAX_CONSOLE_LINES * 0.1); + + // Remove the oldest lines + for (let i = 0; i < linesToRemove; i++) { + if (consoleContent.firstChild) { + consoleContent.removeChild(consoleContent.firstChild); + } + } + + consoleLineCount -= linesToRemove; + + // Add a notice that lines were trimmed + const trimNotice = document.createElement('div'); + trimNotice.className = 'console-info'; + trimNotice.innerHTML = '[System] ... 
older lines removed to maintain performance ...'; + consoleContent.insertBefore(trimNotice, consoleContent.firstChild); + } // Auto-scroll to bottom const pythonConsole = document.getElementById('python-console'); @@ -316,11 +355,30 @@ function setupEventHandlers() { } }); + // Use builtin chatbox toggle + document.getElementById('use_builtin').addEventListener('change', (e) => { + const customChatboxInputs = ['block_width', 'num_blocks', 'rows', 'cols']; + const isBuiltin = e.target.checked; + + customChatboxInputs.forEach(inputId => { + const input = document.getElementById(inputId); + if (input) { + input.disabled = isBuiltin; + if (isBuiltin) { + input.classList.add('opacity-50', 'cursor-not-allowed'); + } else { + input.classList.remove('opacity-50', 'cursor-not-allowed'); + } + } + }); + }); + // Setup virtual environment document.getElementById('setup-venv').addEventListener('click', async () => { loadingOverlay.show('Setting up virtual environment - please wait...'); // Show overlay with custom message try { await buttonManager.withButtonLoading('setupVenv', async () => { + await window.electronAPI.deleteVenvIndicatorFile(); await handleAsyncAction('Install requirements', () => window.electronAPI.installRequirements()); }); } finally { -- cgit v1.2.3 From 790c91d7ad515c3c0a22ca1341316265b8f0d779 Mon Sep 17 00:00:00 2001 From: yum Date: Wed, 23 Jul 2025 17:41:49 -0700 Subject: bugfixes * fix model acquisition * fix local beepsnd * fix volume control --- app/hi.py | 45 ++++-------- app/requirements.txt | 1 + app/stt.py | 62 ++++++++++++---- config.yaml | 8 +-- ui/config-schema.js | 2 +- ui/index.html | 13 ++-- ui/index.js | 17 +++-- ui/preload.js | 1 + ui/renderer.js | 198 ++++++++++++++++++++++++++++----------------------- 9 files changed, 196 insertions(+), 151 deletions(-) diff --git a/app/hi.py b/app/hi.py index 1297b37..bb09418 100644 --- a/app/hi.py +++ b/app/hi.py @@ -26,9 +26,6 @@ TESTS_ENABLED = True # 0 = quiet, 1 = verbose, 2 = very verbose 
LOG_LEVEL = 0 -# Global volume control (0.0 to 1.0) -VOLUME = 0.3 - APP_ROOT = os.path.dirname(os.path.abspath(__file__)) PROJECT_ROOT = os.path.dirname(APP_ROOT) @@ -347,7 +344,8 @@ def osc_thread(shared_data: SharedThreadData): if time.time() - last_change < 1.5: continue addr = "/chatbox/input" - print(f"Send {local_word}", flush=True) + if shared_data.cfg["enable_debug_mode"]: + print(f"Send {local_word}", flush=True) osc_client.send_message(addr, (local_word, True, False)) last_change = time.time() remote_word = local_word @@ -420,20 +418,16 @@ def vrInputThread(shared_data: SharedThreadData): if last_rising - last_medium_press_end < 1.0: # Type transcription - if shared_data.cfg["enable_local_beep"]: - play_sound_with_volume(waveform3) + play_sound_with_volume(waveform3, shared_data.cfg) else: - if shared_data.cfg["enable_local_beep"]: - play_sound_with_volume(waveform1) + play_sound_with_volume(waveform1, shared_data.cfg) elif now - last_rising > 0.5: # Medium press print("CLEARING", file=sys.stderr) last_medium_press_end = now state = PAUSE_STATE - - if shared_data.cfg["enable_local_beep"]: - play_sound_with_volume(waveform2) + play_sound_with_volume(waveform2, shared_data.cfg) # Flush the *entire* pipeline. 
shared_data.stream.pause(True) @@ -449,9 +443,7 @@ def vrInputThread(shared_data: SharedThreadData): state = PAUSE_STATE shared_data.stream.pause(True) - - if shared_data.cfg["enable_local_beep"]: - play_sound_with_volume(waveform1) + play_sound_with_volume(waveform1, shared_data.cfg) elif state == PAUSE_STATE: print("RECORDING", file=sys.stderr) state = RECORD_STATE @@ -469,9 +461,7 @@ def vrInputThread(shared_data: SharedThreadData): #audio_state.text += audio_state.preview_text shared_data.stream.pause(False) - - if shared_data.cfg["enable_local_beep"]: - play_sound_with_volume(waveform0) + play_sound_with_volume(waveform0, shared_data.cfg) def kbInputThread(shared_data: SharedThreadData): @@ -514,9 +504,7 @@ def kbInputThread(shared_data: SharedThreadData): if event == EVENT_DOUBLE_PRESS: print("CLEARING", file=sys.stderr) state = PAUSE_STATE - - if shared_data.cfg["enable_local_beep"]: - play_sound_with_volume(waveform2) + play_sound_with_volume(waveform2, shared_data.cfg) # Flush the *entire* pipeline. 
shared_data.stream.pause(True) @@ -530,11 +518,8 @@ def kbInputThread(shared_data: SharedThreadData): if state == RECORD_STATE: print("PAUSED", file=sys.stderr) state = PAUSE_STATE - shared_data.stream.pause(True) - - if shared_data.cfg["enable_local_beep"]: - play_sound_with_volume(waveform1) + play_sound_with_volume(waveform1, shared_data.cfg) elif state == PAUSE_STATE: print("RECORDING", file=sys.stderr) state = RECORD_STATE @@ -548,20 +533,16 @@ def kbInputThread(shared_data: SharedThreadData): if shared_data.cfg["enable_debug_mode"]: print("Toggle detected, committing preview text (2)", file=sys.stderr) - #audio_state.text += audio_state.preview_text - shared_data.stream.pause(False) + play_sound_with_volume(waveform0, shared_data.cfg) - if shared_data.cfg["enable_local_beep"]: - play_sound_with_volume(waveform0) - -def play_sound_with_volume(filepath): +def play_sound_with_volume(filepath, cfg): """Play a WAV file with adjusted volume""" - volume = VOLUME + volume = cfg.get("volume", 30) try: sound = pygame.mixer.Sound(filepath) - sound.set_volume(volume) + sound.set_volume(volume * 0.01) sound.play() except Exception as e: print(f"Error playing sound {filepath}: {e}", file=sys.stderr) diff --git a/app/requirements.txt b/app/requirements.txt index e68a16c..c8d69df 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -2,6 +2,7 @@ faster-whisper hf-xet keyboard langcodes +noisereduce pyaudio pygame pydub diff --git a/app/stt.py b/app/stt.py index c1f4836..79ab0d1 100644 --- a/app/stt.py +++ b/app/stt.py @@ -3,6 +3,7 @@ from faster_whisper import WhisperModel import langcodes import numpy as np import os +import noisereduce as nr try: from profanity_filter import ProfanityFilter PROFANITY_FILTER_AVAILABLE = True @@ -260,9 +261,13 @@ class NormalizingAudioCollector(AudioCollectorFilter): return frames class BoostingAudioCollector(AudioCollectorFilter): - def __init__(self, parent: AudioCollector, target_dBFS: float, cfg: typing.Dict): + def 
__init__(self, parent: AudioCollector, + target_dBFS: float, + max_gain_dB: float, + cfg: typing.Dict): AudioCollectorFilter.__init__(self, parent) self.target_dBFS = target_dBFS + self.max_gain_dB = max_gain_dB self.cfg = cfg def getAudio(self) -> bytes: @@ -270,9 +275,10 @@ class BoostingAudioCollector(AudioCollectorFilter): audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ, frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS) + gain = min(self.target_dBFS - audio.dBFS, self.max_gain_dB) if self.cfg["enable_debug_mode"]: - print(f"Boosting audio from {audio.dBFS}dB to {self.target_dBFS}dB", file=sys.stderr) - audio = audio.apply_gain(self.target_dBFS - audio.dBFS) + print(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})", flush=True) + audio = audio.apply_gain(gain) frames = np.array(audio.get_array_of_samples()) frames = np.int16(frames).tobytes() @@ -296,6 +302,26 @@ class CompressingAudioCollector(AudioCollectorFilter): return frames +class NoiseReducingAudioCollector(AudioCollectorFilter): + def __init__(self, parent: AudioCollector, cfg: typing.Dict): + AudioCollectorFilter.__init__(self, parent) + self.cfg = cfg + + def getAudio(self) -> bytes: + audio = self.parent.getAudio() + audio_array = np.frombuffer(audio, dtype=np.int16).astype(np.float32) + + reduced_audio = nr.reduce_noise( + y=audio_array, + sr=AudioStream.FPS, + ) + + # Convert back to int16 + reduced_audio = np.clip(reduced_audio, -32768, 32767) + frames = np.int16(reduced_audio).tobytes() + + return frames + class AudioSegmenter: def __init__(self, min_silence_ms=250, @@ -398,6 +424,12 @@ class Segment: avg_logprob = f"(avg_logprob: {self.avg_logprob}) " return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob +def join_segments(a, b): + if len(a) > 0 and a[-1] != ' ': + return a + ' ' + b + else: + return a + b + class Whisper: def __init__(self, collector: AudioCollector, @@ -421,6 +453,9 @@ class Whisper: already_downloaded = 
os.path.exists(model_root) + if not already_downloaded: + print(f"Model {model_str} not already downloaded, downloading now...", flush=True) + self.model = WhisperModel(model_str, device = model_device, device_index = cfg["gpu_idx"], @@ -433,10 +468,12 @@ class Whisper: def update_context(self, committed_text: str): """Update the context with recently committed text.""" - self.recent_context = (self.recent_context + " " + committed_text).strip() - # Keep only the last N characters to avoid prompt getting too long + self.recent_context = join_segments(self.recent_context, committed_text).strip() + # Drop half of the context window. if len(self.recent_context) > self.context_window_chars: - self.recent_context = self.recent_context[-self.context_window_chars:] + words = self.recent_context.split() + words = words[len(words)//2:] + self.recent_context = ' '.join(words) def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: @@ -449,6 +486,8 @@ class Whisper: # Build context-aware prompt prompt = self._build_prompt() + print(f"Prompt: {prompt}", flush=True) + t0 = time.time() segments, info = self.model.transcribe( audio, @@ -698,8 +737,10 @@ def transcriptionThread(shared_data: SharedThreadData): stream = MicStream(shared_data.cfg) collector = AudioCollector(stream) collector = CompressingAudioCollector(collector) - collector = BoostingAudioCollector(collector, -12.0, shared_data.cfg) - collector = NormalizingAudioCollector(collector) + collector = BoostingAudioCollector(collector, -24.0, 24.0, + shared_data.cfg) + collector = NoiseReducingAudioCollector(collector, shared_data.cfg) + #collector = NormalizingAudioCollector(collector) whisper = Whisper(collector, shared_data.cfg) segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"], max_speech_s=shared_data.cfg["max_speech_duration_s"], @@ -761,11 +802,6 @@ def transcriptionThread(shared_data: SharedThreadData): # breaking OSC pager. 
if len(shared_data.transcript) >= 1024: shared_data.transcript = shared_data.transcript[-512:] - def join_segments(a, b): - if len(a) > 0 and a[-1] != ' ': - return a + ' ' + b - else: - return a + b shared_data.transcript = \ join_segments(shared_data.transcript, commit.delta) shared_data.preview = commit.preview diff --git a/config.yaml b/config.yaml index 6f4b65b..dfa2e1f 100644 --- a/config.yaml +++ b/config.yaml @@ -1,8 +1,8 @@ compute_type: float16 language: english model: turbo -microphone: 1 -user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm. Phi, NOPPERS, clearrainbow, Noia, Kuuderekitten. +microphone: 4 +user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm. keybind: ctrl+alt+x button_hand: right button_type: b @@ -18,6 +18,7 @@ rows: 10 cols: 24 beam_size: 5 best_of: 5 +volume: 10 enable_debug_mode: 0 enable_previews: 1 save_audio: 1 @@ -26,6 +27,5 @@ enable_lowercase_filter: 0 enable_uppercase_filter: 0 enable_profanity_filter: 0 remove_trailing_period: 0 -reset_on_toggle: 0 -enable_local_beep: 1 +reset_on_toggle: 1 use_builtin: 1 diff --git a/ui/config-schema.js b/ui/config-schema.js index 6b11277..bf91fce 100644 --- a/ui/config-schema.js +++ b/ui/config-schema.js @@ -23,6 +23,7 @@ const CONFIG_SCHEMA = { cols: { type: 'number', default: 24 }, beam_size: { type: 'number', default: 5 }, best_of: { type: 'number', default: 5 }, + volume: { type: 'number', default: 30 }, // Boolean fields (stored as 1/0) enable_debug_mode: { type: 'boolean', default: 0 }, @@ -34,7 +35,6 @@ const CONFIG_SCHEMA = { enable_profanity_filter: { type: 'boolean', default: 0 }, remove_trailing_period: { type: 'boolean', default: 0 }, reset_on_toggle: { type: 'boolean', default: 0 }, - enable_local_beep: { type: 'boolean', default: 1 }, use_builtin: { type: 'boolean', default: 1 } }; diff --git a/ui/index.html b/ui/index.html index 99e64dd..19c41ce 100644 --- 
a/ui/index.html +++ b/ui/index.html @@ -248,10 +248,13 @@ Reset transcript on toggle - +
+ + +
@@ -314,7 +317,7 @@ -
diff --git a/ui/index.js b/ui/index.js index 24a7e13..5a5d0a6 100644 --- a/ui/index.js +++ b/ui/index.js @@ -530,19 +530,20 @@ ipcMain.handle('start-process', async () => { }); ipcMain.handle('stop-process', async () => { + if (!runningProcess) { + sendPythonOutput('No process to stop', 'info'); + return { success: true, forcefullyKilled: false }; + } + return new Promise((resolve) => { let forcefullyKilled = false; - - if (!runningProcess) { - resolve({ success: true, forcefullyKilled }); - } // Set up a timeout to force kill after 10 seconds const killTimeout = setTimeout(() => { if (runningProcess) { sendPythonOutput('Process did not stop gracefully, forcing termination...', 'stderr'); forcefullyKilled = true; - runningProcess.kill(); + runningProcess.kill('SIGKILL'); } }, 10000); @@ -562,10 +563,14 @@ ipcMain.handle('stop-process', async () => { // Send termination signal sendPythonOutput('Stopping process gracefully...', 'info'); - runningProcess.kill(); + runningProcess.kill('SIGTERM'); }); }); +ipcMain.handle('get-process-state', () => { + return { isRunning: runningProcess !== null }; +}); + // Clean up on app quit app.on('before-quit', () => { if (runningProcess) { diff --git a/ui/preload.js b/ui/preload.js index f2e0a81..6f6e54f 100644 --- a/ui/preload.js +++ b/ui/preload.js @@ -10,6 +10,7 @@ contextBridge.exposeInMainWorld('electronAPI', { resetVenv: () => ipcRenderer.invoke('reset-venv'), startProcess: () => ipcRenderer.invoke('start-process'), stopProcess: () => ipcRenderer.invoke('stop-process'), + getProcessState: () => ipcRenderer.invoke('get-process-state'), onPythonOutput: (callback) => ipcRenderer.on('python-output', (event, data) => callback(data)), onProcessStopped: (callback) => ipcRenderer.on('process-stopped', () => callback()) }); diff --git a/ui/renderer.js b/ui/renderer.js index 2f4c8f1..008e0da 100644 --- a/ui/renderer.js +++ b/ui/renderer.js @@ -1,6 +1,21 @@ // Import configuration schema const CONFIG_FIELDS = window.CONFIG_SCHEMA; +// 
Process state tracking +let isProcessRunning = false; +let buttonManager; +let loadingOverlay; + +// Auto-save functionality with debouncing +let saveTimeout; +const SAVE_DELAY = 500; +let isSettingValues = false; + +// Console management +const consoleContent = document.getElementById('console-content'); +const MAX_CONSOLE_LINES = 512; +let consoleLineCount = 0; + // Button management system class ButtonManager { constructor() { @@ -11,33 +26,30 @@ class ButtonManager { resetVenv: document.getElementById('reset-venv'), refreshMicrophones: document.getElementById('refresh-microphones') }; - - // Initialize button states on construction + + // Initialize button states - process is not running at startup this.setProcessStopped(); } - + setState(buttonName, disabled) { const button = this.buttons[buttonName]; if (!button) return; - + button.disabled = disabled; - if (disabled) { - button.classList.add('opacity-50', 'cursor-not-allowed'); - } else { - button.classList.remove('opacity-50', 'cursor-not-allowed'); - } } - + setProcessRunning() { this.setState('start', true); this.setState('stop', false); + isProcessRunning = true; } - + setProcessStopped() { this.setState('start', false); this.setState('stop', true); + isProcessRunning = false; } - + async withButtonLoading(buttonName, asyncFn) { this.setState(buttonName, true); try { @@ -48,8 +60,6 @@ class ButtonManager { } } -const buttonManager = new ButtonManager(); - // Add loading overlay management class LoadingOverlay { constructor() { @@ -57,8 +67,9 @@ class LoadingOverlay { this.form = document.getElementById('config-form'); this.messageElement = this.overlay.querySelector('p'); this.defaultMessage = 'Environment setup underway - please wait.'; + this.originalStates = new Map(); // Track original disabled states } - + show(message = null) { this.messageElement.textContent = message || this.defaultMessage; this.overlay.classList.remove('hidden'); @@ -66,68 +77,69 @@ class LoadingOverlay { const leftPanel = 
this.overlay.parentElement; const inputs = leftPanel.querySelectorAll('input, select, textarea, button'); inputs.forEach(input => { + // Store original disabled state before disabling + this.originalStates.set(input, input.disabled); input.disabled = true; input.classList.add('opacity-50'); }); } - + hide() { this.overlay.classList.add('hidden'); - // Re-enable all form inputs and buttons in the entire left panel + // Restore original states of form inputs and buttons const leftPanel = this.overlay.parentElement; const inputs = leftPanel.querySelectorAll('input, select, textarea, button'); inputs.forEach(input => { - input.disabled = false; + // Restore original disabled state + input.disabled = this.originalStates.get(input) || false; input.classList.remove('opacity-50'); }); + // Clear the stored states + this.originalStates.clear(); // Reset to default message this.messageElement.textContent = this.defaultMessage; } } -const loadingOverlay = new LoadingOverlay(); - -// Add a flag to prevent auto-save during programmatic updates -let isSettingValues = false; - // Handle status messages with better color management function showStatus(message, type = 'info') { const statusEl = document.getElementById('status-message'); statusEl.textContent = message; - + // Remove all status classes const statusClasses = ['hidden', 'bg-green-100', 'bg-red-100', 'bg-blue-100', 'text-green-800', 'text-red-800', 'text-blue-800']; statusEl.classList.remove(...statusClasses); - + // Add appropriate classes based on type const typeMap = { success: ['bg-green-100', 'text-green-800'], error: ['bg-red-100', 'text-red-800'], info: ['bg-blue-100', 'text-blue-800'] }; - + statusEl.classList.add(...(typeMap[type] || typeMap.info)); - + // Also log to console appendToConsole(message, type === 'error' ? 
'stderr' : 'info'); - + setTimeout(() => statusEl.classList.add('hidden'), 5000); } // Get form values using field mappings function getFormValues() { const config = {}; - + for (const [fieldName, fieldConfig] of Object.entries(CONFIG_FIELDS)) { const element = document.getElementById(fieldName); if (!element) continue; - + switch (fieldConfig.type) { case 'boolean': config[fieldName] = element.checked ? 1 : 0; break; case 'number': - config[fieldName] = parseInt(element.value) || fieldConfig.default; + const numValue = parseInt(element.value); + config[fieldName] = isNaN(numValue) ? fieldConfig.default : numValue; break; case 'text': config[fieldName] = element.value || fieldConfig.default; @@ -136,20 +148,20 @@ function getFormValues() { config[fieldName] = element.value || fieldConfig.default; } } - + return config; } // Set form values using field mappings function setFormValues(config) { isSettingValues = true; // Disable auto-save temporarily - + for (const [fieldName, fieldConfig] of Object.entries(CONFIG_FIELDS)) { const element = document.getElementById(fieldName); if (!element) continue; - + const value = config[fieldName] ?? 
fieldConfig.default; - + switch (fieldConfig.type) { case 'boolean': element.checked = value === 1; @@ -161,7 +173,7 @@ function setFormValues(config) { element.value = value; } } - + // Handle use_builtin toggle state const useBuiltin = config.use_builtin === 1; const customChatboxInputs = ['block_width', 'num_blocks', 'rows', 'cols']; @@ -176,53 +188,54 @@ function setFormValues(config) { } } }); - + + // Update volume display + if (config.volume !== undefined) { + const volumePercent = Math.round(config.volume); + document.getElementById('volume-display').textContent = `${volumePercent}%`; + } + isSettingValues = false; // Re-enable auto-save } -// Console management -const consoleContent = document.getElementById('console-content'); -const MAX_CONSOLE_LINES = 512; -let consoleLineCount = 0; - function appendToConsole(message, type = 'stdout') { const timestamp = new Date().toLocaleTimeString(); const timestampSpan = document.createElement('span'); timestampSpan.className = 'console-timestamp'; timestampSpan.textContent = `[${timestamp}] `; - + const messageSpan = document.createElement('span'); messageSpan.className = `console-${type}`; messageSpan.textContent = message; - + const lineDiv = document.createElement('div'); lineDiv.appendChild(timestampSpan); lineDiv.appendChild(messageSpan); - + consoleContent.appendChild(lineDiv); consoleLineCount++; - + // Remove old lines if we exceed the limit if (consoleLineCount > MAX_CONSOLE_LINES) { // Calculate how many lines to remove (remove 10% to avoid frequent trimming) const linesToRemove = Math.floor(MAX_CONSOLE_LINES * 0.1); - + // Remove the oldest lines for (let i = 0; i < linesToRemove; i++) { if (consoleContent.firstChild) { consoleContent.removeChild(consoleContent.firstChild); } } - + consoleLineCount -= linesToRemove; - + // Add a notice that lines were trimmed const trimNotice = document.createElement('div'); trimNotice.className = 'console-info'; trimNotice.innerHTML = '[System] ... 
older lines removed to maintain performance ...'; consoleContent.insertBefore(trimNotice, consoleContent.firstChild); } - + // Auto-scroll to bottom const pythonConsole = document.getElementById('python-console'); pythonConsole.scrollTop = pythonConsole.scrollHeight; @@ -242,24 +255,20 @@ async function handleAsyncAction(actionName, actionFn) { } } -// Auto-save functionality with debouncing -let saveTimeout; -const SAVE_DELAY = 500; - async function autoSaveConfig() { if (isSettingValues) return; - + clearTimeout(saveTimeout); saveTimeout = setTimeout(async () => { try { const config = getFormValues(); await window.electronAPI.saveConfig(config); showStatus('Configuration saved', 'success'); - + // Restart process if running - if (!buttonManager.buttons.stop.disabled) { + if (isProcessRunning) { appendToConsole('Restarting process with new configuration...', 'info'); - + try { await window.electronAPI.stopProcess(); await new Promise(resolve => setTimeout(resolve, 1000)); @@ -281,9 +290,9 @@ async function autoSaveConfig() { function setupAutoSave() { const form = document.getElementById('config-form'); const inputs = form.querySelectorAll('input, select, textarea'); - + inputs.forEach(input => { - const eventType = input.type === 'checkbox' ? 'change' : + const eventType = input.type === 'checkbox' ? 'change' : (input.type === 'number' || input.type === 'text' || input.tagName === 'TEXTAREA') ? 
'input' : 'change'; input.addEventListener(eventType, autoSaveConfig); }); @@ -292,7 +301,7 @@ function setupAutoSave() { // Microphone loading async function loadMicrophones() { const microphoneSelect = document.getElementById('microphone'); - + try { // Check/install requirements during startup appendToConsole('Checking virtual environment and requirements...', 'info'); @@ -305,15 +314,15 @@ async function loadMicrophones() { appendToConsole('Loading available microphones...', 'info'); const microphones = await window.electronAPI.getMicrophones(); - + microphoneSelect.innerHTML = ''; - + if (microphones.length === 0) { microphoneSelect.innerHTML = ''; appendToConsole('No microphones found', 'stderr'); return; } - + appendToConsole(`Found ${microphones.length} microphone(s)`, 'info'); microphones.forEach(mic => { const option = document.createElement('option'); @@ -322,7 +331,7 @@ async function loadMicrophones() { microphoneSelect.appendChild(option); appendToConsole(` - ${mic.name} (Device ${mic.index})`, 'stdout'); }); - + // Restore previously selected microphone try { const config = await window.electronAPI.loadConfig(); @@ -332,7 +341,7 @@ async function loadMicrophones() { } catch (error) { // Ignore config load errors here } - + } catch (error) { appendToConsole(`Failed to load microphones: ${error.message}`, 'stderr'); microphoneSelect.innerHTML = ''; @@ -345,7 +354,7 @@ function setupEventHandlers() { document.getElementById('toggle-advanced').addEventListener('click', () => { const advancedSettings = document.getElementById('advanced-settings'); const chevron = document.getElementById('chevron'); - + if (advancedSettings.classList.contains('hidden')) { advancedSettings.classList.remove('hidden'); chevron.classList.add('rotate-90'); @@ -354,12 +363,12 @@ function setupEventHandlers() { chevron.classList.remove('rotate-90'); } }); - + // Use builtin chatbox toggle document.getElementById('use_builtin').addEventListener('change', (e) => { const 
customChatboxInputs = ['block_width', 'num_blocks', 'rows', 'cols']; const isBuiltin = e.target.checked; - + customChatboxInputs.forEach(inputId => { const input = document.getElementById(inputId); if (input) { @@ -372,7 +381,13 @@ function setupEventHandlers() { } }); }); - + + // Volume slider update + document.getElementById('volume').addEventListener('input', (e) => { + const volumePercent = Math.round(e.target.value); + document.getElementById('volume-display').textContent = `${volumePercent}%`; + }); + // Setup virtual environment document.getElementById('setup-venv').addEventListener('click', async () => { loadingOverlay.show('Setting up virtual environment - please wait...'); // Show overlay with custom message @@ -385,7 +400,7 @@ function setupEventHandlers() { loadingOverlay.hide(); // Always hide overlay when done } }); - + // Reset virtual environment document.getElementById('reset-venv').addEventListener('click', async () => { loadingOverlay.show('Resetting virtual environment - please wait...'); // Show overlay with custom message @@ -397,33 +412,33 @@ function setupEventHandlers() { loadingOverlay.hide(); // Always hide overlay when done } }); - + // Reset configuration document.getElementById('reset-config').addEventListener('click', async () => { const confirmReset = confirm('Are you sure you want to reset all settings to defaults? 
This cannot be undone.'); if (!confirmReset) return; - + try { // Stop process if running - const wasRunning = !buttonManager.buttons.stop.disabled; + const wasRunning = isProcessRunning; if (wasRunning) { appendToConsole('Stopping process before resetting configuration...', 'info'); await window.electronAPI.stopProcess(); buttonManager.setProcessStopped(); await new Promise(resolve => setTimeout(resolve, 500)); } - + // Reset configuration appendToConsole('Resetting configuration to defaults...', 'info'); const result = await window.electronAPI.resetConfig(); - + // Reload configuration with defaults const config = await window.electronAPI.loadConfig(); setFormValues(config); - + showStatus(result.message, 'success'); appendToConsole('Configuration reset successfully', 'info'); - + // Restart process if it was running if (wasRunning) { appendToConsole('Restarting process with default configuration...', 'info'); @@ -436,18 +451,18 @@ function setupEventHandlers() { appendToConsole(`Failed to reset configuration: ${error.message}`, 'stderr'); } }); - + // Refresh microphones document.getElementById('refresh-microphones').addEventListener('click', async () => { await buttonManager.withButtonLoading('refreshMicrophones', async () => { await loadMicrophones(); }); }); - + // Start process document.getElementById('start-process').addEventListener('click', async () => { buttonManager.setState('start', true); - + try { // The installRequirements function will now check if venv is set up. 
loadingOverlay.show('Verifying environment setup - please wait...'); // Show overlay with custom message @@ -457,7 +472,7 @@ function setupEventHandlers() { } finally { loadingOverlay.hide(); // Always hide overlay when done } - + await window.electronAPI.startProcess(); buttonManager.setProcessRunning(); appendToConsole('Process started successfully', 'info'); @@ -466,11 +481,11 @@ function setupEventHandlers() { buttonManager.setState('start', false); } }); - + // Stop process document.getElementById('stop-process').addEventListener('click', async () => { buttonManager.setState('stop', true); - + try { await window.electronAPI.stopProcess(); appendToConsole('Process stop initiated', 'info'); @@ -479,7 +494,7 @@ function setupEventHandlers() { buttonManager.setState('stop', false); } }); - + // Listen for process stopped event window.electronAPI.onProcessStopped(() => { buttonManager.setProcessStopped(); @@ -489,12 +504,15 @@ function setupEventHandlers() { // Initialize application window.addEventListener('load', async () => { appendToConsole('TaSTT Configuration UI initialized', 'info'); - + + loadingOverlay = new LoadingOverlay(); + buttonManager = new ButtonManager(); + // Set up Python output listener first so we capture all output window.electronAPI.onPythonOutput((data) => { appendToConsole(data.message, data.type); }); - + // Load configuration try { const config = await window.electronAPI.loadConfig(); @@ -503,11 +521,11 @@ window.addEventListener('load', async () => { } catch (error) { appendToConsole(`Failed to load configuration: ${error.message}`, 'stderr'); } - + // Load microphones await loadMicrophones(); - + // Setup event handlers and auto-save setupEventHandlers(); setupAutoSave(); -}); \ No newline at end of file +}); -- cgit v1.2.3 From 9bf33a4cad8196bfe7253c841ab5c35ffdbc0173 Mon Sep 17 00:00:00 2001 From: yum Date: Wed, 23 Jul 2025 19:05:15 -0700 Subject: add segment metadata logging feature Segment metadata can now be logged to a json as 
the app runs. The goal is to identify the params that heavily correlate with hallucinations. Also: * use 7zip for compression in build, speeding things up * log dll download progress every few seconds * shrink package --- app/stt.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++--- config.yaml | 1 + ui/.gitignore | 2 ++ ui/config-schema.js | 1 + ui/index.html | 6 ++++- ui/index.js | 28 +++++++++++++++++++-- ui/package.json | 35 ++++++++++++++++---------- 7 files changed, 126 insertions(+), 19 deletions(-) diff --git a/app/stt.py b/app/stt.py index 79ab0d1..f36de97 100644 --- a/app/stt.py +++ b/app/stt.py @@ -1,5 +1,6 @@ from datetime import datetime from faster_whisper import WhisperModel +import json import langcodes import numpy as np import os @@ -486,7 +487,8 @@ class Whisper: # Build context-aware prompt prompt = self._build_prompt() - print(f"Prompt: {prompt}", flush=True) + if self.cfg["enable_debug_mode"]: + print(f"Prompt: {prompt}", flush=True) t0 = time.time() segments, info = self.model.transcribe( @@ -578,16 +580,69 @@ def saveAudio(audio: bytes, path: str, cfg: typing.Dict): wf.writeframes(audio) +class SegmentLogger: + def __init__(self, cfg: typing.Dict): + self.cfg = cfg + self.enabled = cfg.get("enable_segment_logging", False) + self.session_data = [] + self.log_file = None + + if self.enabled: + log_dir = os.path.join(PROJECT_ROOT, "logs") + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # Create file + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.log_file = os.path.join(log_dir, f"session_debug_{timestamp}.json") + print(f"Segment logging enabled. 
Logging to: {self.log_file}", file=sys.stderr) + + def log_segment(self, segment: Segment, commit_type: str = "commit"): + if not self.enabled: + return + + segment_data = { + "timestamp": datetime.now().isoformat(), + "type": commit_type, + "text": segment.transcript, + "start_ts": segment.start_ts, + "end_ts": segment.end_ts, + "wall_ts": segment.wall_ts, + "avg_logprob": segment.avg_logprob, + "no_speech_prob": segment.no_speech_prob, + "compression_ratio": segment.compression_ratio, + "duration": segment.end_ts - segment.start_ts + } + + self.session_data.append(segment_data) + + # Write to file incrementally + try: + with open(self.log_file, 'w') as f: + json.dump({ + "session_start": self.session_data[0]["timestamp"] if self.session_data else None, + "segments": self.session_data + }, f, indent=2) + except Exception as e: + print(f"Error writing segment log: {e}", file=sys.stderr) + + def close(self): + if self.enabled and self.session_data: + print(f"Session complete. Logged {len(self.session_data)} segments to {self.log_file}", file=sys.stderr) + + class VadCommitter: def __init__(self, cfg: typing.Dict, collector: AudioCollector, whisper: Whisper, - segmenter: AudioSegmenter): + segmenter: AudioSegmenter, + segment_logger: SegmentLogger = None): self.cfg = cfg self.collector = collector self.whisper = whisper self.segmenter = segmenter + self.segment_logger = segment_logger def getDelta(self) -> TranscriptCommit: audio = self.collector.getAudio() @@ -618,6 +673,10 @@ class VadCommitter: if delta.strip(): self.whisper.update_context(delta.strip()) + if self.segment_logger: + for s in segments: + self.segment_logger.log_segment(s, "commit") + audio = self.collector.getAudio() if self.cfg["enable_debug_mode"]: for s in segments: @@ -638,6 +697,10 @@ class VadCommitter: segments = self.whisper.transcribe(audio) preview = "".join(s.transcript for s in segments) + if self.segment_logger: + for s in segments: + self.segment_logger.log_segment(s, "preview") + if 
not has_audio: self.collector.keepLast(1.0) @@ -745,7 +808,9 @@ def transcriptionThread(shared_data: SharedThreadData): segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"], max_speech_s=shared_data.cfg["max_speech_duration_s"], min_speech_duration_ms=shared_data.cfg["min_speech_duration_ms"]) - committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter) + + segment_logger = SegmentLogger(shared_data.cfg) + committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter, segment_logger) plugins = [] # plugins.append(TranslationPlugin(shared_data.cfg)) # Not implemented yet @@ -839,4 +904,5 @@ def transcriptionThread(shared_data: SharedThreadData): plugin.stop() for filt in filters: filt.stop() + segment_logger.close() diff --git a/config.yaml b/config.yaml index dfa2e1f..db25405 100644 --- a/config.yaml +++ b/config.yaml @@ -22,6 +22,7 @@ volume: 10 enable_debug_mode: 0 enable_previews: 1 save_audio: 1 +enable_segment_logging: 0 use_cpu: 0 enable_lowercase_filter: 0 enable_uppercase_filter: 0 diff --git a/ui/.gitignore b/ui/.gitignore index 2109e19..c1dbe3c 100644 --- a/ui/.gitignore +++ b/ui/.gitignore @@ -1,3 +1,5 @@ build node_modules package-lock.json +output.css +dist diff --git a/ui/config-schema.js b/ui/config-schema.js index bf91fce..fb90f3f 100644 --- a/ui/config-schema.js +++ b/ui/config-schema.js @@ -29,6 +29,7 @@ const CONFIG_SCHEMA = { enable_debug_mode: { type: 'boolean', default: 0 }, enable_previews: { type: 'boolean', default: 1 }, save_audio: { type: 'boolean', default: 0 }, + enable_segment_logging: { type: 'boolean', default: 0 }, use_cpu: { type: 'boolean', default: 0 }, enable_lowercase_filter: { type: 'boolean', default: 0 }, enable_uppercase_filter: { type: 'boolean', default: 0 }, diff --git a/ui/index.html b/ui/index.html index 19c41ce..29d4a78 100644 --- a/ui/index.html +++ b/ui/index.html @@ -4,7 +4,7 @@ TaSTT - +
@@ -214,6 +214,10 @@ Save Audio Segments +
diff --git a/ui/index.js b/ui/index.js index 5a5d0a6..afaaf7f 100644 --- a/ui/index.js +++ b/ui/index.js @@ -6,7 +6,12 @@ const { spawn } = require('child_process'); const https = require('https'); const { CONFIG_SCHEMA, getDefaultConfig } = require('./config-schema.js'); -const APP_ROOT = path.join(__dirname, '..'); +// Detect if we're running in development or production +const isDev = !app.isPackaged; +const APP_ROOT = isDev + ? path.join(__dirname, '..') // Development: go up from ui/ to project root + : process.resourcesPath; // Production: use Electron's resource path + const CONFIG_PATH = path.join(APP_ROOT, 'config.yaml'); let mainWindow; @@ -50,13 +55,32 @@ function createPythonEnvironment() { return env; } -// Helper function to download a file from URL +// Helper function to download a file from URL with progress function downloadFile(url, outputPath) { return new Promise((resolve, reject) => { const file = require('fs').createWriteStream(outputPath); + const fileName = path.basename(outputPath); const request = https.get(url, (response) => { if (response.statusCode === 200) { + const totalSize = parseInt(response.headers['content-length'], 10); + let downloadedSize = 0; + let lastProgressTime = Date.now(); + + response.on('data', (chunk) => { + downloadedSize += chunk.length; + + // Log progress every 5 seconds + const now = Date.now(); + if (totalSize && (now - lastProgressTime >= 5000)) { + const progress = Math.round((downloadedSize / totalSize) * 100); + const mb = (downloadedSize / 1024 / 1024).toFixed(1); + const totalMb = (totalSize / 1024 / 1024).toFixed(1); + sendPythonOutput(`Downloading ${fileName}: ${mb}/${totalMb} MB (${progress}%)`, 'info'); + lastProgressTime = now; + } + }); + response.pipe(file); file.on('finish', () => { diff --git a/ui/package.json b/ui/package.json index 3a58298..4742cd7 100644 --- a/ui/package.json +++ b/ui/package.json @@ -6,14 +6,16 @@ "homepage": "./", "scripts": { "start": "npm run build:css && electron .", - 
"build:css": "tailwindcss -i ./src/components.css -o ./build/output.css", - "watch:css": "tailwindcss -i ./src/components.css -o ./build/output.css --watch", + "build:css": "tailwindcss -i ./src/components.css -o ./output.css", + "watch:css": "tailwindcss -i ./src/components.css -o ./output.css --watch", "dev": "concurrently \"npm run watch:css\" \"electron .\"", "test": "echo \"Error: no test specified\" && exit 1", - "dist": "npm run build:css && electron-builder", - "dist:win": "npm run build:css && electron-builder --win", - "dist:portable": "npm run build:css && electron-builder --win portable", - "dist:zip": "npm run build:css && electron-builder --win zip" + "clean:meta": "node -e \"const fs=require('fs');const path=require('path');function deleteMeta(dir){fs.readdirSync(dir).forEach(f=>{const p=path.join(dir,f);if(f.endsWith('.meta'))fs.unlinkSync(p);else if(fs.statSync(p).isDirectory()&&!f.startsWith('.'))deleteMeta(p);})}deleteMeta('./node_modules')\"", + "prebuild": "node build_scripts/setup-empty-venv.js", + "dist": "npm run prebuild && npm run clean:meta && npm run build:css && electron-builder", + "dist:win": "npm run prebuild && npm run clean:meta && npm run build:css && electron-builder --win", + "dist:portable": "npm run prebuild && npm run clean:meta && npm run build:css && electron-builder --win portable", + "dist:zip": "npm run prebuild && npm run clean:meta && npm run build:css && electron-builder --win zip" }, "build": { "appId": "com.yum_food.tastt", @@ -46,11 +48,6 @@ "from": "../config.yaml", "to": "config.yaml" }, - { - "from": "../dll", - "to": "dll", - "filter": ["**/*"] - }, { "from": "../Images", "to": "Images", @@ -60,10 +57,20 @@ "from": "../bin", "to": "bin", "filter": ["**/*"] + }, + { + "from": "../venv_clean", + "to": "venv", + "filter": ["**/*"] + }, + { + "from": "../dll_empty", + "to": "dll", + "filter": ["**/*"] } ], "win": { - "icon": "../Images/logo.png", + "icon": "../Images/favicon.ico", "target": [ { "target": 
"portable", @@ -81,7 +88,9 @@ "nsis": { "oneClick": false, "allowToChangeInstallationDirectory": true - } + }, + "compression": "maximum", + "artifactName": "${productName}-${version}-${arch}.${ext}" }, "keywords": [], "author": "yum_food", -- cgit v1.2.3 From e1730a63538d2b1a23c948d25580612303733eba Mon Sep 17 00:00:00 2001 From: yum Date: Wed, 23 Jul 2025 19:51:35 -0700 Subject: Update avg_logprob cutoff, fix sounds, fix electron build --- app/stt.py | 9 ++- ui/build_scripts/setup-empty-venv.js | 25 +++++++ ui/index.html | 2 +- ui/index.js | 136 +++++++++++++++++------------------ ui/package.json | 7 +- 5 files changed, 108 insertions(+), 71 deletions(-) create mode 100644 ui/build_scripts/setup-empty-venv.js diff --git a/app/stt.py b/app/stt.py index f36de97..b476ac0 100644 --- a/app/stt.py +++ b/app/stt.py @@ -523,6 +523,13 @@ class Whisper: f"no_speech_prob={s.no_speech_prob}, " + f"avg_logprob={s.avg_logprob})", file=sys.stderr) continue + if s.avg_logprob < -0.75: + if self.cfg["enable_debug_mode"]: + print(f"Drop probable hallucination (case 3) " + + f"(text='{s.text}', " + + f"no_speech_prob={s.no_speech_prob}, " + + f"avg_logprob={s.avg_logprob})", file=sys.stderr) + continue if self.cfg["enable_debug_mode"]: print(f"s get: {s}") if s.avg_logprob < -1.0: @@ -686,7 +693,7 @@ class VadCommitter: if self.cfg["save_audio"] and len(delta) > 0: ts = datetime.fromtimestamp(self.collector.now() - latency_s) - filename = str(ts.strftime('%Y_%m_%d__%H-%M-%S')) + ".wav" + filename = str(ts.strftime('%Y_%m_%d__%H-%M-%S')) + delta.strip() + ".wav" audio_dir = os.path.join(PROJECT_ROOT, "audio") if not os.path.exists(audio_dir): os.makedirs(audio_dir) diff --git a/ui/build_scripts/setup-empty-venv.js b/ui/build_scripts/setup-empty-venv.js new file mode 100644 index 0000000..0691a51 --- /dev/null +++ b/ui/build_scripts/setup-empty-venv.js @@ -0,0 +1,25 @@ +const { execSync } = require('child_process'); +const path = require('path'); +const fs = require('fs'); + +const 
projectRoot = path.join(__dirname, '..', '..'); +const venvPath = path.join(projectRoot, 'venv_clean'); +const dllPath = path.join(projectRoot, 'dll_empty'); + +console.log('Creating empty virtual environment and dll directory...'); + +// Create empty dll directory +if (!fs.existsSync(dllPath)) { + fs.mkdirSync(dllPath, { recursive: true }); + console.log('Created empty dll directory'); +} + +try { + console.log('Creating new venv...'); + execSync(`python -m venv "${venvPath}"`, { stdio: 'inherit' }); + console.log('Empty venv created successfully!'); +} catch (error) { + console.error('Failed to create venv:', error); + process.exit(1); +} + diff --git a/ui/index.html b/ui/index.html index 29d4a78..70eaa68 100644 --- a/ui/index.html +++ b/ui/index.html @@ -216,7 +216,7 @@
diff --git a/ui/index.js b/ui/index.js index afaaf7f..63c633a 100644 --- a/ui/index.js +++ b/ui/index.js @@ -8,7 +8,7 @@ const { CONFIG_SCHEMA, getDefaultConfig } = require('./config-schema.js'); // Detect if we're running in development or production const isDev = !app.isPackaged; -const APP_ROOT = isDev +const APP_ROOT = isDev ? path.join(__dirname, '..') // Development: go up from ui/ to project root : process.resourcesPath; // Production: use Electron's resource path @@ -60,16 +60,16 @@ function downloadFile(url, outputPath) { return new Promise((resolve, reject) => { const file = require('fs').createWriteStream(outputPath); const fileName = path.basename(outputPath); - + const request = https.get(url, (response) => { if (response.statusCode === 200) { const totalSize = parseInt(response.headers['content-length'], 10); let downloadedSize = 0; let lastProgressTime = Date.now(); - + response.on('data', (chunk) => { downloadedSize += chunk.length; - + // Log progress every 5 seconds const now = Date.now(); if (totalSize && (now - lastProgressTime >= 5000)) { @@ -80,14 +80,14 @@ function downloadFile(url, outputPath) { lastProgressTime = now; } }); - + response.pipe(file); - + file.on('finish', () => { file.close(); resolve(); }); - + file.on('error', (err) => { fs.unlink(outputPath).catch(() => {}); // Clean up on error reject(err); @@ -98,7 +98,7 @@ function downloadFile(url, outputPath) { reject(new Error(`Failed to download: HTTP ${response.statusCode}`)); } }); - + request.on('error', (err) => { file.close(); fs.unlink(outputPath).catch(() => {}); // Clean up on error @@ -121,14 +121,14 @@ function setupProcessHandlers(process) { const text = data.toString(); sendPythonOutput(text.trimEnd(), 'stdout'); }); - + process.stderr.on('data', (data) => { const text = data.toString(); if (!shouldFilterMessage(text)) { sendPythonOutput(text.trimEnd(), 'stderr'); } }); - + process.on('error', (error) => { sendPythonOutput(`Process error: ${error.message}`, 'stderr'); 
runningProcess = null; @@ -136,7 +136,7 @@ function setupProcessHandlers(process) { mainWindow.webContents.send('process-stopped'); } }); - + process.on('close', (code) => { sendPythonOutput(`Process exited with code ${code}`, 'info'); runningProcess = null; @@ -152,23 +152,23 @@ function executePythonCommand(args, options = {}) { const pythonPath = getVenvPython(); const commandStr = `${path.basename(pythonPath)} ${args.join(' ')}`; sendPythonOutput(`> ${commandStr}`, 'info'); - + const spawnOptions = { ...options, env: createPythonEnvironment() }; - + const pythonProcess = spawn(pythonPath, args, spawnOptions); - + let stdout = ''; let stderr = ''; - + pythonProcess.stdout.on('data', (data) => { const text = data.toString(); stdout += text; sendPythonOutput(text.trimEnd(), 'stdout'); }); - + pythonProcess.stderr.on('data', (data) => { const text = data.toString(); stderr += text; @@ -177,12 +177,12 @@ function executePythonCommand(args, options = {}) { sendPythonOutput(text.trimEnd(), 'stderr'); } }); - + pythonProcess.on('error', (error) => { sendPythonOutput(`Failed to start Python process: ${error.message}`, 'stderr'); reject({ error: error.message, stdout, stderr }); }); - + pythonProcess.on('close', (code) => { if (code !== 0) { sendPythonOutput(`Process exited with code ${code}`, 'stderr'); @@ -287,15 +287,15 @@ ipcMain.handle('deleteVenvIndicatorFile', async () => { // Generic function to ensure required files are present async function ensureRequiredFiles(config) { - const { - directoryName, - requiredFiles, - downloadBaseUrl, - resourceType + const { + directoryName, + requiredFiles, + downloadBaseUrl, + resourceType } = config; - + const targetPath = path.join(APP_ROOT, directoryName); - + try { // Check if target directory exists, create it if not try { @@ -310,7 +310,7 @@ async function ensureRequiredFiles(config) { throw error; } } - + // Check each required file const missingFiles = []; for (const fileName of requiredFiles) { @@ -327,15 +327,15 @@ 
async function ensureRequiredFiles(config) { } } } - + // Download missing files if (missingFiles.length > 0) { sendPythonOutput(`Downloading ${missingFiles.length} missing ${resourceType} file${missingFiles.length > 1 ? 's' : ''}...`, 'info'); - + for (const fileName of missingFiles) { const filePath = path.join(targetPath, fileName); const downloadUrl = `${downloadBaseUrl}/${fileName}`; - + try { sendPythonOutput(`Downloading ${fileName}...`, 'info'); await downloadFile(downloadUrl, filePath); @@ -345,14 +345,14 @@ async function ensureRequiredFiles(config) { throw new Error(`Failed to download ${fileName}: ${downloadError.message}`); } } - + sendPythonOutput(`All missing ${resourceType} files downloaded successfully`, 'info'); } else { sendPythonOutput(`All required ${resourceType} files are present`, 'info'); } - - return { - success: true, + + return { + success: true, message: `${resourceType} setup complete. ${missingFiles.length} file${missingFiles.length > 1 ? 's' : ''} downloaded.`, downloadedFiles: missingFiles }; @@ -366,7 +366,7 @@ async function ensureRequiredFiles(config) { ipcMain.handle('install-requirements', async () => { const requirementsPath = path.join(APP_ROOT, 'app', 'requirements.txt'); const venvMarkerPath = path.join(APP_ROOT, '.venv_is_set_up'); - + try { // Check if venv is already set up try { @@ -375,10 +375,10 @@ ipcMain.handle('install-requirements', async () => { } catch (error) { // Marker doesn't exist, proceed with setup } - + // Check if requirements.txt exists await fs.access(requirementsPath); - + await executePythonCommand(['-m', 'pip', 'install', '-r', requirementsPath]); await ensureRequiredFiles({ @@ -389,10 +389,10 @@ ipcMain.handle('install-requirements', async () => { }); await fs.mkdir(path.join(APP_ROOT, 'Models'), { recursive: true }); - + await fs.writeFile(venvMarkerPath, new Date().toISOString(), 'utf8'); sendPythonOutput('Created .venv_is_set_up marker file', 'info'); - + return { success: true, message: 
'Requirements and dependencies installed successfully' }; } catch (error) { console.error('Error installing requirements:', error); @@ -405,7 +405,7 @@ ipcMain.handle('install-requirements', async () => { ipcMain.handle('get-microphones', async () => { const scriptPath = path.join(APP_ROOT, 'app', 'list_microphones.py'); - + try { const result = await executePythonCommand([scriptPath]); const microphones = JSON.parse(result.stdout.trim()); @@ -421,24 +421,24 @@ async function clearDirectory(dirPath, dirName) { try { await fs.access(dirPath); sendPythonOutput(`Clearing ${dirName} directory...`, 'info'); - + const files = await fs.readdir(dirPath); let deletedCount = 0; - + for (const file of files) { const filePath = path.join(dirPath, file); - + try { await fs.rm(filePath, { recursive: true, force: true }); sendPythonOutput(`✗ Deleted file ${file}`, 'info'); - + deletedCount++; } catch (deleteError) { sendPythonOutput(`Warning: Could not delete ${file}: ${deleteError.message}`, 'stderr'); // Continue with other files even if one fails } } - + sendPythonOutput(`${dirName} directory cleared`, 'info'); return deletedCount; } catch (error) { @@ -454,10 +454,10 @@ async function clearDirectory(dirPath, dirName) { ipcMain.handle('reset-venv', async () => { const venvMarkerPath = path.join(APP_ROOT, '.venv_is_set_up'); - + try { sendPythonOutput('Starting virtual environment reset...', 'info'); - + // Delete the venv marker file first try { await fs.unlink(venvMarkerPath); @@ -467,14 +467,14 @@ ipcMain.handle('reset-venv', async () => { sendPythonOutput(`Warning: Could not delete marker file: ${error.message}`, 'stderr'); } } - + // Get list of installed packages sendPythonOutput('Getting list of installed packages...', 'info'); const freezeResult = await executePythonCommand(['-m', 'pip', 'freeze']); const installedPackages = freezeResult.stdout.trim(); - + let uninstalledPackages = []; - + if (!installedPackages) { sendPythonOutput('No packages found to uninstall', 
'info'); } else { @@ -483,38 +483,38 @@ ipcMain.handle('reset-venv', async () => { const packageNames = packageLines .map(line => line.split('==')[0].trim()) .filter(name => name && !name.startsWith('#')); - + const corePackages = ['pip', 'setuptools', 'wheel']; const packagesToUninstall = packageNames.filter(name => !corePackages.includes(name.toLowerCase())); - + if (packagesToUninstall.length === 0) { sendPythonOutput('Only core packages found, nothing to uninstall', 'info'); } else { sendPythonOutput(`Uninstalling ${packagesToUninstall.length} packages...`, 'info'); - + const uninstallArgs = ['-m', 'pip', 'uninstall', '-y', ...packagesToUninstall]; await executePythonCommand(uninstallArgs); uninstalledPackages = packagesToUninstall; } } - + // Clear downloaded files sendPythonOutput('Clearing downloaded files...', 'info'); - + const dllPath = path.join(APP_ROOT, 'dll'); const modelsPath = path.join(APP_ROOT, 'Models'); const binPath = path.join(APP_ROOT, 'bin'); - + const deletedDlls = await clearDirectory(dllPath, 'DLL'); const deletedModels = await clearDirectory(modelsPath, 'Models'); const deletedBins = await clearDirectory(binPath, 'Binary'); - + const totalDeletedFiles = deletedDlls + deletedModels + deletedBins; - + sendPythonOutput('Virtual environment reset successfully!', 'info'); - - return { - success: true, + + return { + success: true, message: `Virtual environment reset complete. 
Uninstalled ${uninstalledPackages.length} packages and deleted ${totalDeletedFiles} downloaded files.`, uninstalledPackages, deletedFiles: { @@ -538,14 +538,14 @@ ipcMain.handle('start-process', async () => { const scriptPath = path.join(APP_ROOT, 'app', 'hi.py'); const args = [scriptPath, '--config', CONFIG_PATH]; - + try { const pythonPath = getVenvPython(); sendPythonOutput(`Starting process: ${path.basename(pythonPath)} ${args.join(' ')}`, 'info'); - + runningProcess = spawn(pythonPath, args, { env: createPythonEnvironment() }); setupProcessHandlers(runningProcess); - + return { success: true }; } catch (error) { runningProcess = null; @@ -561,7 +561,7 @@ ipcMain.handle('stop-process', async () => { return new Promise((resolve) => { let forcefullyKilled = false; - + // Set up a timeout to force kill after 10 seconds const killTimeout = setTimeout(() => { if (runningProcess) { @@ -570,21 +570,21 @@ ipcMain.handle('stop-process', async () => { runningProcess.kill('SIGKILL'); } }, 10000); - + // Listen for the process to exit runningProcess.once('exit', (code, signal) => { clearTimeout(killTimeout); runningProcess = null; - + if (forcefullyKilled) { sendPythonOutput('Process forcefully terminated', 'info'); } else { sendPythonOutput('Process stopped gracefully', 'info'); } - + resolve({ success: true, forcefullyKilled }); }); - + // Send termination signal sendPythonOutput('Stopping process gracefully...', 'info'); runningProcess.kill('SIGTERM'); diff --git a/ui/package.json b/ui/package.json index 4742cd7..d99424c 100644 --- a/ui/package.json +++ b/ui/package.json @@ -67,6 +67,11 @@ "from": "../dll_empty", "to": "dll", "filter": ["**/*"] + }, + { + "from": "../Sounds", + "to": "Sounds", + "filter": ["*.wav"] } ], "win": { @@ -89,7 +94,7 @@ "oneClick": false, "allowToChangeInstallationDirectory": true }, - "compression": "maximum", + "compression": "normal", "artifactName": "${productName}-${version}-${arch}.${ext}" }, "keywords": [], -- cgit v1.2.3 From 
043a447133695bfd2285a534b941db972873a692 Mon Sep 17 00:00:00 2001 From: yum Date: Wed, 23 Jul 2025 20:43:25 -0700 Subject: Set target loudness to -16, and enable segment metadata logging by default --- app/stt.py | 2 +- config.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/app/stt.py b/app/stt.py index b476ac0..18f0f60 100644 --- a/app/stt.py +++ b/app/stt.py @@ -807,7 +807,7 @@ def transcriptionThread(shared_data: SharedThreadData): stream = MicStream(shared_data.cfg) collector = AudioCollector(stream) collector = CompressingAudioCollector(collector) - collector = BoostingAudioCollector(collector, -24.0, 24.0, + collector = BoostingAudioCollector(collector, -16.0, 24.0, shared_data.cfg) collector = NoiseReducingAudioCollector(collector, shared_data.cfg) #collector = NormalizingAudioCollector(collector) diff --git a/config.yaml b/config.yaml index db25405..9cec4a3 100644 --- a/config.yaml +++ b/config.yaml @@ -22,7 +22,7 @@ volume: 10 enable_debug_mode: 0 enable_previews: 1 save_audio: 1 -enable_segment_logging: 0 +enable_segment_logging: 1 use_cpu: 0 enable_lowercase_filter: 0 enable_uppercase_filter: 0 -- cgit v1.2.3