summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2025-07-23 22:39:45 -0700
committeryum <yum.food.vr@gmail.com>2025-07-23 22:39:45 -0700
commitf6b93a20d754579008076e85f5c0a97e1bcbc258 (patch)
tree7288699d6f22e76c4f30636a37e94265b3ef7708
parentf3782c200c9a2ec2b77708da67b4127a38465ad1 (diff)
parent043a447133695bfd2285a534b941db972873a692 (diff)
Import FastTextPager repo
-rw-r--r--.cursorignore2
-rw-r--r--.gitignore5
-rw-r--r--Designs/fast_text_paging.md321
-rw-r--r--GenerateTextAnimator.cs200
-rw-r--r--Images/favicon.icobin0 -> 92015 bytes
-rw-r--r--Images/unigram_lut_for_visualization.pngbin0 -> 489395 bytes
-rw-r--r--LICENSE2
m---------Third_Party/Profanity0
-rw-r--r--app/app_config.py39
-rw-r--r--app/hi.py580
-rw-r--r--app/keybind_event_machine.py21
-rw-r--r--app/list_microphones.py24
-rw-r--r--app/profanity_filter.py43
-rw-r--r--app/requirements.txt12
-rw-r--r--app/shared_thread_data.py14
-rw-r--r--app/steamvr.py87
-rw-r--r--app/stt.py915
-rw-r--r--bpe_dump.py61
-rw-r--r--config.yaml32
-rw-r--r--generate_bpe_lut.py119
-rw-r--r--generate_tokenizer.py525
-rw-r--r--tokenize_me.py28
-rw-r--r--ui/.gitignore5
-rw-r--r--ui/build_scripts/setup-empty-venv.js25
-rw-r--r--ui/config-schema.js57
-rw-r--r--ui/index.html349
-rw-r--r--ui/index.js616
-rw-r--r--ui/package.json118
-rw-r--r--ui/postcss.config.js6
-rw-r--r--ui/preload.js17
-rw-r--r--ui/renderer.js531
-rw-r--r--ui/src/components.css122
-rw-r--r--ui/src/input.css3
-rw-r--r--ui/tailwind.config.js13
-rw-r--r--ui_design.md39
35 files changed, 4929 insertions, 2 deletions
diff --git a/.cursorignore b/.cursorignore
new file mode 100644
index 0000000..a8f4624
--- /dev/null
+++ b/.cursorignore
@@ -0,0 +1,2 @@
+**/node_modules
+**/site-packages \ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 0b41544..a82a975 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
# Ignore vim swap files.
-*.sw[po]
+.*.sw[po]
*.dll
+*.meta
+.venv_is_set_up
+
diff --git a/Designs/fast_text_paging.md b/Designs/fast_text_paging.md
new file mode 100644
index 0000000..abb0576
--- /dev/null
+++ b/Designs/fast_text_paging.md
@@ -0,0 +1,321 @@
+# Optimized text paging for VRChat
+
+This repo provides code to help you send English text into VRChat. It includes:
+
+1. Training code to produce an English-language tokenizer of any vocabulary
+ size.
+2. Code to turn your tokenizer into a lookup table for GPU decoding.
+3. Unity code to generate an animator to shuttle data from OSC to material
+ properties.
+4. OSC code to talk to your Unity animator.
+
+To get started, see Quick Start.
+
+## Quick start
+
+1. Clone this repo.
+2. Clone my toon shader, [2ner](https://github.com/yum-food/2ner).
+3. Install Lyuma's av3emulator.
+4. Drag STT.prefab onto your avatar's root.
+5. Enter play mode.
+6. Open PowerShell.
+
+```bash
+$ cd ~
+$ mkdir tmp
+$ cd tmp
+$ python.exe -m venv venv
+$ ./venv/Scripts/Activate.ps1
+$ pushd /path/to/FastTextPaging/
+$ pip3 install -r requirements.txt
+$ python3 ./hi.py
+```
+
+7. Start typing.
+
+## Design overview
+
+It is sometimes useful to send text data into VRChat, for example for
+speech-to-text (STT). This is typically done naively, with a "block" of
+n 8-bit characters\* sent in along with an 8-bit pointer. Since avatars can only
+send 256 bits at 3 Hz\*\* with OSC, this means you can only send (256 - 8) / 8 =
+31 characters per sync. The average English word is 4.79 characters long, so
+if we naively send in 1 character per byte, then we get a speed limit of 6.47
+words per sync. The works out to ~20 words per second or ~1200 words per minute
+(wpm). Adults typically read at about 238-260 wpm.
+
+\* Typically ASCII encoding is used.
+
+\*\* Experimentally, 3 Hz is the fastest you can reliably page data with OSC in
+busy instances.
+
+Sending in one character per byte gives you (1200/256) ~= 4.7 wpm per OSC bit
+used. Thus to reach a typical reading speed, you need to use (260/4.7) = 55.5
+OSC bits. The goal of this module is to get more out of these bits by
+compressing text over the wire.
+
+### Unigram tokenizer
+
+Byte pair encoding (BPE) is an encoding scheme frequently used in natural
+language processing (NLP) contexts. For any language with a fixed character set
+(e.g. an alphabet), you can count up how often each character is used in a
+large corpus of text. Then you can repeatedly join together characters and
+assign a unique token to joined characters in the order of the most frequently
+occurring sequences. You wind up with a lookup table like this:
+
+```
+0: [UNK]
+1: <s>
+2: </s>
+3: [CLS]
+4: [SEP]
+...
+100: from
+101: but
+102: he
+103: e
+104: now
+...
+10000: eo
+10001: is currently
+10002: dish
+10003: Mi
+10004: 6)
+```
+
+The above example is from a unigram\* sentencepiece organizer that I trained.
+
+\* A unigram tokenizer is a variant of the byte-pair tokenizer. Where a byte
+pair tokenizer views inputs as arbitrary sequences of bytes, a unigram
+tokenizer views it as a sequence of letters.
+
+The tokenizer has a vocabulary size of 65,536 tokens. It was trained on
+opensubtitles and 5% of wikipedia, with `unidecode` normalization applied to
+limit training data to ASCII. Subword lengths are distributed as follows:
+
+Subword length histogram:
+1: 95
+2: 1032
+3: 3350
+4: 5439
+5: 5445
+6: 5082
+7: 5334
+8: 5866
+9: 6172
+10: 5934
+11: 5329
+12: 4698
+13: 4016
+14: 3191
+15: 2434
+16: 2095
+
+In the test set - 25% wikipedia 75% conversational English -
+this tokenizer yields 4.896 characters per token. (Recall that
+the average English word is 4.8 characters long. Not bad!)
+
+If we naively send these tokens into the game with a 16-bit number, we can send
+floor((256-8)/16) = 15 tokens per sync. This gives us an average of 73.44
+characters per sync - more than 2x higher than the naive approach's 31.
+
+Here is how the unigram-tokenized scheme fairs against the naive scheme in
+every possible configuration of bits used:
+
+```
+bits naive rate bpe rate speedup factor
+8 n/a n/a n/a
+16 1 n/a 0.000
+24 2 4.896 2.448
+32 3 4.896 1.632
+40 4 9.792 2.448
+48 5 9.792 1.958
+56 6 14.688 2.448
+64 7 14.688 2.098
+72 8 19.584 2.448
+80 9 19.584 2.176
+88 10 24.480 2.448
+96 11 24.480 2.225
+104 12 29.376 2.448
+112 13 29.376 2.260
+120 14 34.272 2.448
+128 15 34.272 2.285
+136 16 39.168 2.448
+144 17 39.168 2.304
+152 18 44.064 2.448
+160 19 44.064 2.319
+168 20 48.960 2.448
+176 21 48.960 2.331
+184 22 53.856 2.448
+192 23 53.856 2.342
+200 24 58.752 2.448
+208 25 58.752 2.350
+216 26 63.648 2.448
+224 27 63.648 2.357
+232 28 68.544 2.448
+240 29 68.544 2.364
+248 30 73.440 2.448
+256 31 73.440 2.369
+```
+
+([Spreadsheet](https://docs.google.com/spreadsheets/d/1d9SEZvo3Q-6U_Wf9nuGRKXxndUhKn2V3Q0Ox0nOB4T4/edit?usp=sharing))
+
+I reserve 39 token slots for sequences of whitespace characters of length 2-40. This helps simplify formatting. To end a line or position text, you can just send in the exact right number of spaces, and a fixed-width font renderer will position things as intended.
+
+### Paging data into shader
+
+Sending this data to a shader is pretty simple:
+
+- An OSC app encodes a string into tokens and pages it into the game with OSC.
+ - The app sends a pointer of *where* the tokens should be rendered along with them. Since the tokens can encode a variable length string, the pointer must be able to point to any spot in the rendering window. Thus we are limited to a 256-bit display with an 8-bit pointer, or a 64K display with a 16-bit pointer. We call this the visual pointer.
+ - A second pointer tells the animator which shader parameters the tokens and visual pointer should be written to. This can be 8-bit. We call this the logical pointer.
+- Animator uses the logical pointer to decide which shader parameters to send the visual pointer and tokens to.
+
+Here is the expected speedup in every possible configuration with a 1-byte
+overhead (1BO) or a 2-byte overhead (2BO):
+
+```
+bits naive rate bpe rate (1BO) speedup bpe rate (2BO) speedup
+8 n/a n/a n/a n/a n/a
+16 1 0.000 0.000 0.000 0.000
+24 2 0.000 0.000 0.000 0.000
+32 3 4.896 1.632 0.000 0.000
+40 4 4.896 1.224 4.896 1.224
+48 5 9.792 1.958 4.896 0.979
+56 6 9.792 1.632 9.792 1.632
+64 7 14.688 2.098 9.792 1.399
+72 8 14.688 1.836 14.688 1.836
+80 9 19.584 2.176 14.688 1.632
+88 10 19.584 1.958 19.584 1.958
+96 11 24.480 2.225 19.584 1.780
+104 12 24.480 2.040 24.480 2.040
+112 13 29.376 2.260 24.480 1.883
+120 14 29.376 2.098 29.376 2.098
+128 15 34.272 2.285 29.376 1.958
+136 16 34.272 2.142 34.272 2.142
+144 17 39.168 2.304 34.272 2.016
+152 18 39.168 2.176 39.168 2.176
+160 19 44.064 2.319 39.168 2.061
+168 20 44.064 2.203 44.064 2.203
+176 21 48.960 2.331 44.064 2.098
+184 22 48.960 2.225 48.960 2.225
+192 23 53.856 2.342 48.960 2.129
+200 24 53.856 2.244 53.856 2.244
+208 25 58.752 2.350 53.856 2.154
+216 26 58.752 2.260 58.752 2.260
+224 27 63.648 2.357 58.752 2.176
+232 28 63.648 2.273 63.648 2.273
+240 29 68.544 2.364 63.648 2.195
+248 30 68.544 2.285 68.544 2.285
+256 31 73.440 2.369 68.544 2.211
+```
+
+([Spreadsheet](https://docs.google.com/spreadsheets/d/1d9SEZvo3Q-6U_Wf9nuGRKXxndUhKn2V3Q0Ox0nOB4T4/edit?usp=sharing))
+
+As you can see, a 2-byte visual pointer is very damaging to the speedup at low bit budgets. So in bit-constrained setups we should definitely use a smaller display.
+
+Notably, *there is only one crossover point*. If all configurations except the 2-byte overhead 48-bit configuration, BPE-based paging is always\* faster.
+
+\* Always, going off of the *expected* rate. If you get unlucky and your tokens all decode to 1 character, then BPE-based paging is about 50% *slower* than naive encoding.
+
+Because the Unity animator sucks shit, we're going to decode tokens on the GPU. As a refresher, the GPU sees data like this:
+
+```
+_Text_Block00_Visual_Ptr: 0
+_Text_Block00_Token00: 13,766
+_Text_Block00_Token01: 84
+_Text_Block01_Visual_Ptr: 13
+_Text_Block01_Token00: 599
+_Text_Block01_Token01: 8,301
+...
+```
+
+I.e. it sees "blocks" of data with tokens and visual pointers. The visual pointer just says where on a grid it should draw the subwords represented by the tokens.
+
+The pixel shader can trivially work out what grid location the current pixel belongs to. By scanning through the visual pointers, it can work out which block it has to draw.
+
+We can generate a function like this:
+
+```c
+#define BLOCK_WIDTH 2
+void GetBlock(uint which_block, out float data[BLOCK_WIDTH]) {
+ [loop]
+ for (uint i = 0; i < BLOCK_WIDTH; i++) {
+ data[i] = _Text_Blocks[which_block][i];
+ }
+}
+
+// Get the tokens that cover `screen_ptr`. Also returns `block_ptr`, the
+// location where this block of tokens begins.
+void GetTokens(uint screen_ptr, out uint block_ptr, out uint tokens[BLOCK_WIDTH]) {
+ uint which_block;
+ [loop]
+ for (uint i = 0; i < BLOCK_WIDTH; i++) {
+ if (screen_ptr >= _Text_Block_Visual_Ptrs[i]) {
+ which_block = i;
+ }
+ }
+ GetBlock(which_block, tokens);
+}
+```
+
+### GPU decoding
+
+Now we have to translate the tokens into text. I do this with a texture laid out as follows:
+
+1. A fixed-length array of (offset, length) pairs. Offset is 24 bits, giving us an address space of about 16 million slots. Length is 8 bits, but as established above, the longest token is only 16 characters long. So we're wasting about 4 bits. This tells us we should use an RGBA texture.
+
+2. A variable length array of ASCII-encoded strings. Each slot is RGBA, so it can hold 4 characters.
+
+My tokenizer's vocabulary is 65,536 tokens. If we add up the lengths of every token, rounding them up to the nearest multiple of 4, we get the length 667,532 This means that we need 166,883 slots to fit the actual content of the vocabulary.
+
+So, the entire vocabulary - length+offset head and content - requires a 32-bit RGBA texture with 232,419 slots. We'll just jam this into a 512x512 texture, at an occupancy ratio of 88.66% (11.34% waste). The total VRAM usage of that lookup table (LUT) is 1 MiB.
+
+![Unigram tokenizer texture](Images/unigram_lut_for_visualization.png)
+
+*A 64K vocabulary tokenizer I trained on Wikipedia and OpenSubtitles.*
+
+We want to implement this API:
+
+```c
+uint GetChar(uint screen_ptr);
+```
+
+Internally, it must:
+
+1. Get tokens. (GetTokens - already done)
+2. Figure out which token covers the screen\_ptr.
+3. Figure out which character in the token covers the screen\_ptr.
+
+Let's break down [2]. We can get the length of a token with a single texture tap. So, naively, we can just scan through the tokens in the current block, add up their lengths, and stop once we find a token covering the current slot. The scan incurs a worst-case cost of `BLOCK_WIDTH` texture taps. The character lookup is then a single tap.
+
+```c
+// Gets the length of the subword encoded by the token. Performs one texture
+// tap.
+void TokenLengthOffset(uint token, out uint length, out uint offset);
+// Gets the nth character of the token stored at `token_offset`.
+uint GetTokenChar(uint token_offset, uint nth);
+
+uint GetChar(uint screen_ptr) {
+ uint block_ptr;
+ uint tokens[BLOCK_WIDTH];
+ GetTokens(screen_ptr, tokens, block_ptr);
+ uint start = block_ptr;
+ uint covering_token = 0;
+ uint token_ptr = block_ptr;
+ uint token_offset;
+ for (uint ii = 0; ii < BLOCK_WIDTH; ii++) {
+ uint token_length;
+ TokenLengthOffset(tokens[ii], token_length, token_offset);
+ if (token_ptr + token_length >= screen_ptr) {
+ covering_token = tokens[ii];
+ }
+ token_ptr += token_length;
+ }
+ // covering_token covers screen_ptr. It starts at token_ptr.
+ return GetTokenChar(token_offset, screen_ptr - token_ptr);
+}
+```
+
+That's actually it for the GPU decoding. Once you have the character, you can use standard fixed-width font rendering techniques to display it (e.g. [disinfo](https://github.com/yum-food/disinfo) and [msdf](https://github.com/Chlumsky/msdfgen)).
+
diff --git a/GenerateTextAnimator.cs b/GenerateTextAnimator.cs
new file mode 100644
index 0000000..fb4a1ee
--- /dev/null
+++ b/GenerateTextAnimator.cs
@@ -0,0 +1,200 @@
+#if UNITY_EDITOR
+using AnimatorAsCode.V1;
+using AnimatorAsCode.V1.ModularAvatar;
+using YumTools;
+using nadena.dev.ndmf;
+using System.Collections.Generic;
+using UnityEditor;
+using UnityEditor.Animations;
+using UnityEngine;
+using VRC.SDK3.Avatars.Components;
+using VRC.SDKBase;
+
+// This example uses NDMF. See https://github.com/bdunderscore/ndmf?tab=readme-ov-file#getting-started
+[assembly: ExportsPlugin(typeof(GenerateTextAnimatorPlugin))]
+namespace YumTools
+{
+ public class GenerateTextAnimator : MonoBehaviour, IEditorOnly
+ {
+ // The number of blocks addressable.
+ public int numBlocks = 10;
+ // The number of datums sent per block.
+ public int blockWidth = 5;
+ // The number of bytes per datum.
+ public int bytesPerDatum = 2;
+
+ public string oscPointerParam = "_Unigram_Letter_Grid_OSC_Pointer";
+
+ // Data sent through OSC uses this pattern.
+ public string[] oscParamPerBlockDatumByte = {"_Unigram_Letter_Grid_OSC_Datum{0:00}_Byte{1:00}"};
+ // The pattern of the parameters which this script will animate.
+ public string[] matPropPerBlockDatumByte = {"_Unigram_Letter_Grid_Data_Block{0:00}_Datum{1:00}_Byte{2:00}_Animated"};
+
+ public string[] oscParamPerBlockDatum = {};
+ public string[] matPropPerBlockDatum = {};
+
+ public string[] oscParamPerBlock = {"_Unigram_Letter_Grid_OSC_Visual_Pointer"};
+ public string[] matPropPerBlock = {"_Unigram_Letter_Grid_Block_{0:00}_Visual_Pointer_Animated"};
+ }
+
+ public class GenerateTextAnimatorPlugin : Plugin<GenerateTextAnimatorPlugin>
+ {
+ public override string QualifiedName => "yum.generate_chat_animator";
+ public override string DisplayName => "Chat Animator";
+
+ private const string SystemName = "GenerateTextAnimator";
+ // Direct blendtrees have special semantics with write defaults. We want
+ // them on. They will not fuck up the rest of the animator, whether it uses
+ // write defaults or not.
+ private const bool UseWriteDefaults = true;
+
+ protected override void Configure()
+ {
+ InPhase(BuildPhase.Generating).Run($"Generate {DisplayName}", Generate);
+ }
+
+ private AacFlBlendTreeDirect Add8BitBlendTree(
+ AacFlBase aac,
+ AacFlLayer layer,
+ GenerateTextAnimator cfg,
+ AacFlBlendTreeDirect tree,
+ string matProp, string oscParam, string oneParam) {
+ var offAnim = aac.NewClip().Animating(clip =>
+ {
+ clip.Animates(cfg.GetComponent<Renderer>(), "material." + matProp).WithFrameCountUnit(keyframes =>
+ keyframes.Constant(/*when=*/0, /*value=*/0));
+ });
+ var onAnim = aac.NewClip().Animating(clip =>
+ {
+ clip.Animates(cfg.GetComponent<Renderer>(), "material." + matProp).WithFrameCountUnit(keyframes =>
+ keyframes.Constant(/*when=*/0, /*value=*/255));
+ });
+ var subtree = aac.NewBlendTree().Simple1D(layer.FloatParameter(oscParam))
+ .WithAnimation(offAnim, /*threshold=*/-1)
+ .WithAnimation(onAnim, /*threshold=*/ 1);
+ return tree.WithAnimation(subtree, layer.FloatParameter("AlwaysOne"));
+ }
+
+ private void Generate(BuildContext ctx)
+ {
+ var components = ctx.AvatarRootTransform.GetComponentsInChildren<GenerateTextAnimator>(true);
+ if (components.Length == 0) return;
+
+ var aac = AacV1.Create(new AacConfiguration
+ {
+ SystemName = SystemName,
+ AnimatorRoot = ctx.AvatarRootTransform,
+ DefaultValueRoot = ctx.AvatarRootTransform,
+ AssetKey = GUID.Generate().ToString(),
+ AssetContainer = ctx.AssetContainer,
+ ContainerMode = AacConfiguration.Container.OnlyWhenPersistenceRequired,
+ AssetContainerProvider = new NDMFContainerProvider(ctx),
+ DefaultsProvider = new AacDefaultsProvider(UseWriteDefaults)
+ });
+
+ // Create a new object in the scene. We will add Modular Avatar components inside it.
+ var modularAvatar = MaAc.Create(new GameObject(SystemName)
+ {
+ transform = { parent = ctx.AvatarRootTransform }
+ });
+
+ var ctrl = aac.NewAnimatorController();
+ for (int component_i = 0; component_i < components.Length; component_i++) {
+ GenerateTextAnimator cfg = components[component_i];
+ var layer = ctrl.NewLayer($"Chatbox Plumbing (component #{component_i})");
+
+ var baseState = layer.NewState("Entry (noop)").WithWriteDefaultsSetTo(false);
+
+ // Create "always one" value to use in DBT.
+ // See https://vrc.school/docs/Other/DBT-Combining/ for details.
+ layer.OverrideValue(layer.FloatParameter("AlwaysOne"), 1.0f);
+
+ // Create blendtrees. One for each block of data.
+ var block_trees = new List<AacFlState>();
+ AacFlState last_block_state = null;
+ // For each block.
+ for (int i = 0; i < cfg.numBlocks; i++) {
+ var block_tree = aac.NewBlendTree().Direct();
+
+ // Create block-level animations.
+ for (int ii = 0; ii < cfg.oscParamPerBlock.Length; ii++) {
+ string matPropB = string.Format(cfg.matPropPerBlock[ii], i);
+ string oscParamB = cfg.oscParamPerBlock[ii];
+ block_tree = Add8BitBlendTree(aac, layer, cfg, block_tree, matPropB, oscParamB, "AlwaysOne");
+ }
+
+ // For each datum per block.
+ for (int j = 0; j < cfg.blockWidth; j++) {
+ // Create (block, datum)-level animations.
+ for (int ii = 0; ii < cfg.oscParamPerBlockDatum.Length; ii++) {
+ string matPropBD = string.Format(cfg.matPropPerBlockDatum[ii], i, j);
+ string oscParamBD = string.Format(cfg.oscParamPerBlockDatum[ii], j);
+ block_tree = Add8BitBlendTree(aac, layer, cfg, block_tree, matPropBD, oscParamBD, "AlwaysOne");
+ }
+
+ // For each byte per datum.
+ for (int k = 0; k < cfg.bytesPerDatum; k++) {
+ // Create (block, datum, byte)-level animations.
+ for (int ii = 0; ii < cfg.oscParamPerBlockDatumByte.Length; ii++) {
+ string matPropBDB = string.Format(cfg.matPropPerBlockDatumByte[ii], i, j, k);
+ //Debug.Log($"animating property: {matPropBDB}");
+ string oscParamBDB = string.Format(cfg.oscParamPerBlockDatumByte[ii], j, k);
+ block_tree = Add8BitBlendTree(aac, layer, cfg, block_tree, matPropBDB, oscParamBDB, "AlwaysOne");
+ }
+ }
+ }
+
+ var cur_block_state = layer.NewState($"Block {i}").WithAnimation(block_tree);
+ if (last_block_state != null) {
+ cur_block_state = cur_block_state.Under(last_block_state);
+ }
+ last_block_state = cur_block_state;
+ block_trees.Add(cur_block_state);
+ }
+
+ // Create transitions to each block's blendtree.
+ for (int i = 0; i < cfg.numBlocks; i++) {
+ block_trees[i].TransitionsFromAny()
+ //.WithInterruption(TransitionInterruptionSource.Source)
+ .When(layer.IntParameter(cfg.oscPointerParam).IsEqualTo(i));
+ }
+
+ // Create sync params (VRCSDK)
+ for (int ii = 0; ii < cfg.oscParamPerBlock.Length; ii++) {
+ string bParam = cfg.oscParamPerBlock[ii];
+ modularAvatar.NewParameter(layer.FloatParameter(bParam));
+ }
+ for (int i = 0; i < cfg.blockWidth; i++) {
+ for (int ii = 0; ii < cfg.oscParamPerBlockDatum.Length; ii++) {
+ string bdParam = string.Format(cfg.oscParamPerBlockDatum[ii], i);
+ modularAvatar.NewParameter(layer.FloatParameter(bdParam));
+ }
+ for (int j = 0; j < cfg.bytesPerDatum; j++) {
+ for (int ii = 0; ii < cfg.oscParamPerBlockDatumByte.Length; ii++) {
+ string bdbParam = string.Format(cfg.oscParamPerBlockDatumByte[ii], i, j);
+ modularAvatar.NewParameter(layer.FloatParameter(bdbParam)).WithDefaultValue(1.0f);
+ }
+ }
+ }
+
+ modularAvatar.NewParameter(layer.IntParameter(cfg.oscPointerParam));
+ }
+
+ // By creating a Modular Avatar Merge Animator component,
+ // our animator controller will be added to the avatar's FX layer.
+ modularAvatar.NewMergeAnimator(ctrl.AnimatorController, VRCAvatarDescriptor.AnimLayerType.FX);
+ }
+ }
+
+ // (For AAC 1.2.0 and above) This is recommended starting from NDMF 1.6.0. You only need to define this class once.
+ internal class NDMFContainerProvider : IAacAssetContainerProvider
+ {
+ private readonly BuildContext _ctx;
+ public NDMFContainerProvider(BuildContext ctx) => _ctx = ctx;
+ public void SaveAsPersistenceRequired(Object objectToAdd) => _ctx.AssetSaver.SaveAsset(objectToAdd);
+ public void SaveAsRegular(Object objectToAdd) { } // Let NDMF crawl our assets when it finishes
+ public void ClearPreviousAssets() { } // ClearPreviousAssets is never used in non-destructive contexts
+ }
+}
+#endif
+
diff --git a/Images/favicon.ico b/Images/favicon.ico
new file mode 100644
index 0000000..25ea9ac
--- /dev/null
+++ b/Images/favicon.ico
Binary files differ
diff --git a/Images/unigram_lut_for_visualization.png b/Images/unigram_lut_for_visualization.png
new file mode 100644
index 0000000..622419d
--- /dev/null
+++ b/Images/unigram_lut_for_visualization.png
Binary files differ
diff --git a/LICENSE b/LICENSE
index fe6d817..ffe35bf 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,4 +1,4 @@
-Copyright 2022-2024 yum_food
+Copyright 2022-2025 yum_food
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
diff --git a/Third_Party/Profanity b/Third_Party/Profanity
new file mode 160000
+Subproject 5faf2ba42d7b1c0977169ec3611df25a3c08eb1
diff --git a/app/app_config.py b/app/app_config.py
new file mode 100644
index 0000000..f911456
--- /dev/null
+++ b/app/app_config.py
@@ -0,0 +1,39 @@
+import os
+import sys
+import typing
+
+def getConfig(path: str) -> typing.Dict[str, typing.Union[str, float, int, bool]]:
+ # Helper functions to detect and convert the type
+ def is_int(value: str) -> bool:
+ try:
+ int(value)
+ return True
+ except ValueError:
+ return False
+
+ def is_float(value: str) -> bool:
+ try:
+ float(value)
+ return True
+ except ValueError:
+ return False
+
+ def convert_value(key: str, value: str):
+ if key.startswith(("enable_", "remove_", "use_", "clear_")):
+ return bool(int(value))
+ elif is_int(value):
+ return int(value)
+ elif is_float(value):
+ return float(value)
+ else:
+ return value
+
+ config = {}
+ with open(path, 'r') as file:
+ for line in file:
+ key_value = line.strip().split(": ", maxsplit=1)
+ key = key_value[0]
+ value = key_value[1] if len(key_value) > 1 else ""
+ config[key] = convert_value(key, value.strip())
+ return config
+
diff --git a/app/hi.py b/app/hi.py
new file mode 100644
index 0000000..bb09418
--- /dev/null
+++ b/app/hi.py
@@ -0,0 +1,580 @@
+import app_config
+import argparse
+import io
+import keybind_event_machine
+from math import floor, ceil
+import msvcrt
+import os
+from pythonosc import udp_client
+import sentencepiece as spm
+import steamvr
+from shared_thread_data import SharedThreadData
+import stt
+import sys
+import threading
+import time
+import pygame
+
+sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
+sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
+
+# Initialize pygame mixer
+pygame.mixer.init()
+
+TESTS_ENABLED = True
+
+# 0 = quiet, 1 = verbose, 2 = very verbose
+LOG_LEVEL = 0
+
+APP_ROOT = os.path.dirname(os.path.abspath(__file__))
+PROJECT_ROOT = os.path.dirname(APP_ROOT)
+
+def get_tokenizer():
+ model_path = os.path.join(PROJECT_ROOT, "custom_unigram_tokenizer_65k", "unigram.model")
+ print(f"Loading SentencePiece tokenizer from: {model_path}")
+ sp = spm.SentencePieceProcessor()
+ sp.load(model_path)
+ print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}")
+ return sp
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, help="Path to config file (YAML).", required=True)
+ return parser.parse_args()
+
+def assert_equal(a, b):
+ err_msg = f"{a} != {b}"
+ assert a == b, err_msg
+
+# Turn a whitespace-delimited string into a list of strings no longer than
+# `cols`.
+# Preferentially breaks strings at whitespace boundaries. Preserves whitespace
+# between words, except if that whitespace comes between lines. Breaks words
+# longer than `cols` with a hyphen.
+def wrap_line(line: str, cols):
+ # First, split line into alternating chunks of words and whitespace.
+ def get_sequences(line):
+ is_space = False
+ sequences = []
+ seq_start = 0
+ seq_end = -1
+ for i in range(0, len(line)):
+ if line[i].isspace():
+ if is_space:
+ seq_end = i
+ continue
+ # We were looking at text, now we see whitespace.
+ seq = line[seq_start:seq_end+1]
+ if len(seq) > 0:
+ sequences.append(seq)
+ seq_start = i
+ seq_end = i
+ is_space = True
+ else:
+ if not is_space:
+ seq_end = i
+ continue
+ # We were looking at whitespace, now we see text.
+ seq = line[seq_start:seq_end+1]
+ if len(seq) > 0:
+ sequences.append(seq)
+ seq_start = i
+ seq_end = i
+ is_space = False
+ sequences.append(line[seq_start:seq_end+1])
+ return sequences
+ if TESTS_ENABLED:
+ assert_equal(get_sequences("foo"), ["foo"])
+ assert_equal(get_sequences("foo bar"), ["foo", " ", "bar"])
+ assert_equal(get_sequences(" foo bar"), [" ", "foo", " ", "bar"])
+ assert_equal(get_sequences(" foo bar"), [" ", "foo", " ", "bar"])
+
+ # Next, greedily construct lines out of those sequences.
+ # Whitespace gets treated specially. If it would push us over the limit, we
+ # end the line and drop the whitespace.
+ sequences = get_sequences(line)
+ def coalesce_sequences(sequences, cols):
+ cur_line = ""
+ lines = []
+ for seq in sequences:
+ if len(cur_line) + len(seq) <= cols:
+ cur_line += seq
+ continue
+ if seq.isspace():
+ lines.append(cur_line)
+ cur_line = ""
+ continue
+ if len(cur_line) > 0:
+ lines.append(cur_line)
+ # Edge case: text sequence is longer than a line.
+ while len(seq) > cols:
+ seq_prefix = seq[0:cols-1] + "-"
+ seq = seq[cols-1:]
+ lines.append(seq_prefix)
+ cur_line = seq
+ if len(cur_line) > 0:
+ lines.append(cur_line)
+ return lines
+ if TESTS_ENABLED:
+ assert_equal(coalesce_sequences(get_sequences("foo bar"), 3), ["foo", "bar"])
+ assert_equal(coalesce_sequences(get_sequences("foo bar"), 4), ["foo ", "bar"])
+ assert_equal(coalesce_sequences(get_sequences("foo bar"), 4), ["foo", "bar"])
+ assert_equal(coalesce_sequences(get_sequences("foobar"), 3), ["fo-", "ob-", "ar"])
+ assert_equal(coalesce_sequences(get_sequences("f obar"), 3), ["f ", "ob-", "ar"])
+
+ lines = coalesce_sequences(sequences, cols)
+
+ # Next, pad each line with whitespace.
+ def pad_lines(lines, cols):
+ for i in range(0, len(lines)):
+ lines[i] += ' ' * (cols - len(lines[i]))
+ return lines
+ if TESTS_ENABLED:
+ assert_equal(pad_lines(["foo", "ba"], 4), ["foo ", "ba "])
+ assert_equal(pad_lines(["foo"], 2), ["foo"])
+
+ return pad_lines(lines, cols)
+
+def get_blocks(lines, tokenizer, block_width, num_blocks):
+ if LOG_LEVEL == 2:
+ print(f"Lines sent to tokenizer: {''.join(lines)}")
+ tokens = tokenizer.encode_as_ids(''.join(lines))
+ if LOG_LEVEL == 2:
+ print(f"Tokens: {tokens}")
+ pieces = []
+ for tok in tokens:
+ piece = tokenizer.id_to_piece(tok)
+ pieces.append(piece)
+ if LOG_LEVEL == 2:
+ print(f"Pieces: {pieces}")
+
+ # Group tokens into blocks and pad with empty characters.
+ # Also get visual pointers - the location where each block will be rendered.
+ def get_blocks():
+ blocks = []
+ visual_pointer = 0
+ visual_pointers = []
+ for i in range(0, ceil(len(tokens) / block_width)):
+ visual_pointers.append(visual_pointer)
+ block = []
+ for j in range(0, block_width):
+ if i*block_width + j >= len(tokens):
+ # Pad block with empty characters. 65535 is a special token.
+ block += [65535] * (block_width - len(block))
+ break
+ block.append(tokens[i*block_width+j])
+ visual_pointer += len(pieces[i*block_width+j])
+ blocks.append(block)
+ return (blocks, visual_pointers)
+ blocks, visual_pointers = get_blocks()
+ if LOG_LEVEL == 2:
+ print(f"Blocks: {blocks}")
+ print(f"Visual pointers: {visual_pointers}")
+
+ # Set all blocks up to the next `num_blocks` boundary to blank tokens.
+ # This handles the edge case where a prior message wrote data there which
+ # is covering up our new data.
+ def pad_blocks(blocks, visual_pointers):
+ cur_num_blocks = len(blocks)
+ num_pad_blocks = num_blocks - cur_num_blocks
+ for i in range(0, num_pad_blocks):
+ blocks.append([65535] * block_width)
+ visual_pointers.append(255)
+ return blocks, visual_pointers
+ blocks, visual_pointers = pad_blocks(blocks, visual_pointers)
+ if LOG_LEVEL == 2:
+ print(f"Blocks (padded): {blocks}")
+ print(f"Visual pointers (padded): {visual_pointers}")
+
+ return blocks, visual_pointers
+
+def calc_diff(prev_blocks, prev_visual_pointers, cur_blocks,
+ cur_visual_pointers):
+ diff_indices = []
+ diff_blocks = []
+ diff_visual_pointers = []
+
+ for i in range(0, len(cur_blocks)):
+ if i >= len(prev_blocks):
+ diff_blocks.append(cur_blocks[i])
+ diff_visual_pointers.append(cur_visual_pointers[i])
+ diff_indices.append(i)
+ continue
+ if prev_blocks[i] != cur_blocks[i] or prev_visual_pointers[i] != cur_visual_pointers[i]:
+ diff_blocks.append(cur_blocks[i])
+ diff_visual_pointers.append(cur_visual_pointers[i])
+ diff_indices.append(i)
+
+ return diff_indices, diff_blocks, diff_visual_pointers
+
+def send_data(osc_client, indices, blocks, visual_pointers):
+ def split_blocks_by_byte(blocks):
+ blocks_byte00 = []
+ blocks_byte01 = []
+ for block in blocks:
+ block_byte00 = []
+ block_byte01 = []
+ for datum in block:
+ block_byte00.append((datum >> 0) & 0xFF)
+ block_byte01.append((datum >> 8) & 0xFF)
+ blocks_byte00.append(block_byte00)
+ blocks_byte01.append(block_byte01)
+ return blocks_byte00, blocks_byte01
+
+ blocks_byte00, blocks_byte01 = split_blocks_by_byte(blocks)
+ if LOG_LEVEL == 2:
+ print(f"Blocks (byte 00): {blocks_byte00}")
+ print(f"Blocks (byte 01): {blocks_byte01}")
+
+ def send_osc(osc_client, addr, data):
+ #print(f"Sending {data} to {addr}")
+ osc_client.send_message(addr, data)
+
+ for i in range(0, len(blocks)):
+ lp_int = indices[i]
+ lp_param = "_Unigram_Letter_Grid_OSC_Pointer"
+ addr = "/avatar/parameters/" + lp_param
+ send_osc(osc_client, addr, lp_int)
+
+ vp_float = (-127.5 + visual_pointers[i]) / 127.5
+ vp_param = f"_Unigram_Letter_Grid_OSC_Visual_Pointer"
+ addr = "/avatar/parameters/" + vp_param
+ send_osc(osc_client, addr, vp_float)
+ if LOG_LEVEL == 2:
+ print(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}")
+ for j in range(0, len(blocks[i])):
+ byte00_float = (-127.5 + blocks_byte00[i][j]) / 127.5
+ byte01_float = (-127.5 + blocks_byte01[i][j]) / 127.5
+ byte00_param = f"_Unigram_Letter_Grid_OSC_Datum{j:02}_Byte00"
+ byte01_param = f"_Unigram_Letter_Grid_OSC_Datum{j:02}_Byte01"
+ addr = "/avatar/parameters/" + byte00_param
+ send_osc(osc_client, addr, byte00_float)
+ addr = "/avatar/parameters/" + byte01_param
+ send_osc(osc_client, addr, byte01_float)
+ time.sleep(0.34)
+
+def getOscClient(ip = "127.0.0.1", port = 9000):
+ return udp_client.SimpleUDPClient(ip, port)
+
+class InputState:
+ def __init__(self):
+ self.page = 0
+ # Initialize the known state of the board to empty array. This will cause
+ # our paging logic to re-send everything the first time around.
+ self.blocks = []
+ self.visual_pointers = []
+ pass
+
+def handle_input(state: InputState, line: str, tokenizer, osc_client, cfg):
+ line_wrapped = wrap_line(line, cfg["cols"])
+ if TESTS_ENABLED:
+ for line in line_wrapped:
+ assert_equal(len(line), cfg["cols"])
+ if LOG_LEVEL == 2:
+ print(f"Wrapped lines: {line_wrapped}")
+
+ # Get several blank lines whenever we roll over.
+ # It's better for the reader to have some continuity when the board pages
+ # over. If we simply replaced the entire screen, it would be harder to
+ # understand.
+ line_rollover = cfg["rows"] - 2
+ blank_line = ' ' * cfg["cols"]
+ # We show a full page, then only `line_rollover` additional lines per page.
+ end_ptr = cfg["rows"]
+ which_page = 0
+ while end_ptr < len(line_wrapped):
+ end_ptr += line_rollover
+ which_page += 1
+ if state.page != which_page:
+ state.blocks = []
+ state.visual_pointers = []
+ state.page = which_page
+ line_wrapped = line_wrapped[end_ptr-cfg["rows"]:]
+
+ # Get blocks and visual pointers.
+ blocks, visual_pointers = get_blocks(line_wrapped, tokenizer,
+ cfg["block_width"], cfg["num_blocks"])
+
+ # Note that because we only send one page of data at a time, we don't have
+ # to worry about wrapping visual pointers! We will basically never run out
+ # of space.
+ indices, diff_blocks, diff_visual_pointers = calc_diff(state.blocks, state.visual_pointers, blocks, visual_pointers)
+ indices = [idx % cfg["num_blocks"] for idx in indices]
+ # Send only one block at a time to make things snappier in interactive use
+ # case.
+ # TODO use a continuation (yield) instead of returning. Then we can be a
+ # little lighter on the cpu. Measurements show that this script is
+ # already very light but we're clearly wasting a lot of work by
+ # re-tokenizing the entire input every time we send a block.
+ if len(indices) == 0:
+ return
+ if indices[0] == len(state.blocks):
+ state.blocks.append(diff_blocks[0])
+ state.visual_pointers.append(diff_visual_pointers[0])
+ elif indices[0] > len(state.blocks):
+ print(f"This should never happen!")
+ sys.exit(1)
+ else:
+ state.blocks[indices[0]] = diff_blocks[0]
+ state.visual_pointers[indices[0]] = diff_visual_pointers[0]
+
+ send_data(osc_client, [indices[0]], [diff_blocks[0]], [diff_visual_pointers[0]])
+
+def osc_thread(shared_data: SharedThreadData):
+ osc_client = getOscClient()
+
+ def join_segments(a, b):
+ if len(a) > 0 and a[-1] != ' ':
+ return a + ' ' + b
+ else:
+ return a + b
+
+ if shared_data.cfg["use_builtin"]:
+ last_change = time.time()
+ remote_word = ""
+ while not shared_data.exit_event.is_set():
+ time.sleep(0.1)
+ local_word = ""
+ with shared_data.word_lock:
+ local_word = join_segments(shared_data.transcript,
+ shared_data.preview)
+ local_word = local_word[-140:]
+ if local_word == remote_word:
+ continue
+ if time.time() - last_change < 1.5:
+ continue
+ addr = "/chatbox/input"
+ if shared_data.cfg["enable_debug_mode"]:
+ print(f"Send {local_word}", flush=True)
+ osc_client.send_message(addr, (local_word, True, False))
+ last_change = time.time()
+ remote_word = local_word
+ else:
+ # Custom chatbox
+ tokenizer = get_tokenizer()
+
+ # Prime the board
+ print("Priming the board")
+ input_state = InputState()
+ handle_input(input_state, "", tokenizer, osc_client, shared_data.cfg)
+
+ while not shared_data.exit_event.is_set():
+ word_copy = ""
+ with shared_data.word_lock:
+ word_copy = shared_data.word
+ handle_input(input_state, word_copy, tokenizer, osc_client, shared_data.cfg)
+ time.sleep(0.01)
+
+
+def vrInputThread(shared_data: SharedThreadData):
+ RECORD_STATE = 0
+ PAUSE_STATE = 1
+ state = PAUSE_STATE
+
+ hand_id = shared_data.cfg["button_hand"]
+ button_id = shared_data.cfg["button_type"]
+
+ # Rough description of state machine:
+ # Single short press: toggle transcription
+ # Medium press: dismiss custom chatbox
+ # Long press: update chatbox in place
+ # Medium press + long press: type transcription
+
+ last_rising = time.time()
+ last_medium_press_end = 0
+
+ waveform0 = os.path.join(PROJECT_ROOT, "Sounds/Noise_On_Quiet.wav")
+ waveform1 = os.path.join(PROJECT_ROOT, "Sounds/Noise_Off_Quiet.wav")
+ waveform2 = os.path.join(PROJECT_ROOT, "Sounds/Dismiss_Noise_Quiet.wav")
+ waveform3 = os.path.join(PROJECT_ROOT, "Sounds/KB_Noise_Off_Quiet.wav")
+
+ button_generator = steamvr.pollButtonPress(hand=hand_id, button=button_id,
+ shared_data=shared_data)
+ while not shared_data.exit_event.is_set():
+ time.sleep(0.01)
+ try:
+ event = next(button_generator)
+ except StopIteration:
+ break
+
+ with shared_data.word_lock:
+ if not shared_data.stream or not shared_data.collector:
+ continue
+
+ if event.opcode == steamvr.EVENT_RISING_EDGE:
+ last_rising = time.time()
+
+ if state == PAUSE_STATE:
+ shared_data.stream.pause(False)
+ shared_data.stream.getSamples()
+
+ elif event.opcode == steamvr.EVENT_FALLING_EDGE:
+ now = time.time()
+ if now - last_rising > 1.5:
+ # Long press: treat as the end of transcription.
+ state = PAUSE_STATE
+
+ shared_data.stream.pause(True)
+
+ if last_rising - last_medium_press_end < 1.0:
+ # Type transcription
+ play_sound_with_volume(waveform3, shared_data.cfg)
+ else:
+ play_sound_with_volume(waveform1, shared_data.cfg)
+
+ elif now - last_rising > 0.5:
+ # Medium press
+ print("CLEARING", file=sys.stderr)
+ last_medium_press_end = now
+ state = PAUSE_STATE
+ play_sound_with_volume(waveform2, shared_data.cfg)
+
+ # Flush the *entire* pipeline.
+ shared_data.stream.pause(True)
+ shared_data.stream.getSamples()
+ shared_data.collector.dropAudio()
+ shared_data.transcript = ""
+ shared_data.preview = ""
+ continue
+
+ # Short hold
+ if state == RECORD_STATE:
+ print("PAUSED", file=sys.stderr)
+ state = PAUSE_STATE
+
+ shared_data.stream.pause(True)
+ play_sound_with_volume(waveform1, shared_data.cfg)
+ elif state == PAUSE_STATE:
+ print("RECORDING", file=sys.stderr)
+ state = RECORD_STATE
+ if shared_data.cfg["reset_on_toggle"]:
+ if shared_data.cfg["enable_debug_mode"]:
+ print("Toggle detected, dropping transcript (3)",
+ file=sys.stderr)
+ shared_data.transcript = ""
+ shared_data.preview = ""
+ #audio_state.drop_transcription = True
+ else:
+ if shared_data.cfg["enable_debug_mode"]:
+ print("Toggle detected, committing preview text (3)",
+ file=sys.stderr)
+ #audio_state.text += audio_state.preview_text
+
+ shared_data.stream.pause(False)
+ play_sound_with_volume(waveform0, shared_data.cfg)
+
+
+def kbInputThread(shared_data: SharedThreadData):
+ machine = keybind_event_machine.KeybindEventMachine(shared_data.cfg["keybind"])
+ last_press_time = 0
+
+ # double pressing the keybind
+ double_press_timeout = 0.5
+
+ RECORD_STATE = 0
+ PAUSE_STATE = 1
+ state = PAUSE_STATE
+
+ waveform0 = os.path.join(PROJECT_ROOT, "Sounds/Noise_On_Quiet.wav")
+ waveform1 = os.path.join(PROJECT_ROOT, "Sounds/Noise_Off_Quiet.wav")
+ waveform2 = os.path.join(PROJECT_ROOT, "Sounds/Dismiss_Noise_Quiet.wav")
+ waveform3 = os.path.join(PROJECT_ROOT, "Sounds/KB_Noise_Off_Quiet.wav")
+
+ while not shared_data.exit_event.is_set():
+ time.sleep(0.01)
+
+ cur_press_time = machine.getNextPressTime()
+ if cur_press_time == 0:
+ continue
+
+ with shared_data.word_lock:
+ if not shared_data.stream or not shared_data.collector:
+ continue
+
+ EVENT_SINGLE_PRESS = 0
+ EVENT_DOUBLE_PRESS = 1
+ if last_press_time == 0:
+ event = EVENT_SINGLE_PRESS
+ elif cur_press_time - last_press_time < double_press_timeout:
+ event = EVENT_DOUBLE_PRESS
+ else:
+ event = EVENT_SINGLE_PRESS
+ last_press_time = cur_press_time
+
+ if event == EVENT_DOUBLE_PRESS:
+ print("CLEARING", file=sys.stderr)
+ state = PAUSE_STATE
+ play_sound_with_volume(waveform2, shared_data.cfg)
+
+ # Flush the *entire* pipeline.
+ shared_data.stream.pause(True)
+ shared_data.stream.getSamples()
+ shared_data.collector.dropAudio()
+ shared_data.transcript = ""
+ shared_data.preview = ""
+ continue
+
+ # Short hold
+ if state == RECORD_STATE:
+ print("PAUSED", file=sys.stderr)
+ state = PAUSE_STATE
+ shared_data.stream.pause(True)
+ play_sound_with_volume(waveform1, shared_data.cfg)
+ elif state == PAUSE_STATE:
+ print("RECORDING", file=sys.stderr)
+ state = RECORD_STATE
+ if shared_data.cfg["reset_on_toggle"]:
+ if shared_data.cfg["enable_debug_mode"]:
+ print("Toggle detected, dropping transcript (2)",
+ file=sys.stderr)
+ shared_data.transcript = ""
+ shared_data.preview = ""
+ else:
+ if shared_data.cfg["enable_debug_mode"]:
+ print("Toggle detected, committing preview text (2)",
+ file=sys.stderr)
+ shared_data.stream.pause(False)
+ play_sound_with_volume(waveform0, shared_data.cfg)
+
+def play_sound_with_volume(filepath, cfg):
+ """Play a WAV file with adjusted volume"""
+ volume = cfg.get("volume", 30)
+
+ try:
+ sound = pygame.mixer.Sound(filepath)
+ sound.set_volume(volume * 0.01)
+ sound.play()
+ except Exception as e:
+ print(f"Error playing sound {filepath}: {e}", file=sys.stderr)
+
+if __name__ == "__main__":
+ cli_args = parse_args()
+ cfg = app_config.getConfig(cli_args.config)
+ shared_data = SharedThreadData(cfg)
+ osc_thread = threading.Thread(
+ target=osc_thread,
+ args=(shared_data,))
+ osc_thread.start()
+
+ transcribe_thread = threading.Thread(
+ target=stt.transcriptionThread,
+ args=(shared_data,))
+ transcribe_thread.start()
+
+ vr_input_thd = threading.Thread(target=vrInputThread, args=(shared_data,))
+ vr_input_thd.start()
+
+ kb_input_thd = threading.Thread(target=kbInputThread, args=(shared_data,))
+ kb_input_thd.start()
+
+ word_is_over = False
+ local_word = ""
+ while True:
+ time.sleep(0.1)
+ continue
+ shared_data.exit_event.set()
+ osc_thread.join()
+ transcribe_thread.join()
+ vr_input_thd.join()
+ kb_input_thd.join()
+
diff --git a/app/keybind_event_machine.py b/app/keybind_event_machine.py
new file mode 100644
index 0000000..3ce6794
--- /dev/null
+++ b/app/keybind_event_machine.py
@@ -0,0 +1,21 @@
+import keyboard
+import time
+
+class KeybindEventMachine:
+ def __init__(self, keybind: str):
+ self.keybind = keybind
+ self.events = []
+ keyboard.add_hotkey(keybind, self.onPress)
+
+ def onPress(self) -> None:
+ self.events.append(time.time())
+
+ # Returns the timestamp when the keybind was pressed, or 0 if no keypresses
+ # are queued.
+ def getNextPressTime(self) -> int:
+ if len(self.events) == 0:
+ return 0
+ ret = self.events[0]
+ self.events = self.events[1:]
+ return ret
+
diff --git a/app/list_microphones.py b/app/list_microphones.py
new file mode 100644
index 0000000..a6b1f36
--- /dev/null
+++ b/app/list_microphones.py
@@ -0,0 +1,24 @@
+import pyaudio
+import json
+import sys
+
+try:
+ p = pyaudio.PyAudio()
+ info = p.get_host_api_info_by_index(0)
+ numdevices = info.get('deviceCount')
+
+ microphones = []
+ for i in range(0, numdevices):
+ device_info = p.get_device_info_by_host_api_device_index(0, i)
+ if device_info.get('maxInputChannels') > 0:
+ microphones.append({
+ 'index': i,
+ 'name': device_info.get('name'),
+ 'defaultSampleRate': device_info.get('defaultSampleRate')
+ })
+
+ print(json.dumps(microphones))
+ p.terminate()
+except Exception as e:
+ print(json.dumps({'error': str(e)}), file=sys.stderr)
+ sys.exit(1) \ No newline at end of file
diff --git a/app/profanity_filter.py b/app/profanity_filter.py
new file mode 100644
index 0000000..b8c84ed
--- /dev/null
+++ b/app/profanity_filter.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+
+class ProfanityFilter:
+ def __init__(self, en_path: str):
+ self.en_path = en_path
+ self.en_profanity = set()
+
+ def load(self):
+ with open(self.en_path, 'r') as f:
+ for line in f:
+ self.en_profanity.add(line.strip())
+
+ def filter(self, line: str, language_code: str = "en") -> str:
+ filtered = ""
+
+ if language_code not in {"en"}:
+ raise ValueError(f"Language code \"{language_code}\" is " +
+ "unsupported by the profanity filter")
+
+ # Translation table converting vowels to asterisks.
+ vowel_to_asterisk = str.maketrans('aeiouAEIOU', '**********')
+
+ result = []
+ for word in line.split():
+ word_clean = word.lower()
+ # Filter out non-alphabet characters from the word.
+ word_clean = ''.join([char for char in word_clean if char.isalpha()])
+ if word_clean in self.en_profanity:
+ result.append(word.translate(vowel_to_asterisk))
+ else:
+ result.append(word)
+
+ return " ".join(result)
+
+if __name__ == "__main__":
+ en_path = "/mnt/d/vrc/TaSTT/GUI/Profanity/Profanity/en"
+ p = ProfanityFilter(en_path)
+ p.load()
+ assert(p.filter("fuck") == "f*ck")
+ assert(p.filter("fuck!") == "f*ck!")
+ assert(p.filter("fuck shit") == "f*ck sh*t")
+ assert(p.filter("fuck shit this should not be filtered") == "f*ck sh*t this should not be filtered")
+ assert(p.filter("ASS") == "*SS")
diff --git a/app/requirements.txt b/app/requirements.txt
new file mode 100644
index 0000000..c8d69df
--- /dev/null
+++ b/app/requirements.txt
@@ -0,0 +1,12 @@
+faster-whisper
+hf-xet
+keyboard
+langcodes
+noisereduce
+pyaudio
+pygame
+pydub
+python-osc
+sentencepiece
+silero-vad
+openvr
diff --git a/app/shared_thread_data.py b/app/shared_thread_data.py
new file mode 100644
index 0000000..40885e8
--- /dev/null
+++ b/app/shared_thread_data.py
@@ -0,0 +1,14 @@
+import threading
+
+class SharedThreadData:
+ def __init__(self, cfg):
+ self.transcript = ""
+ self.preview = ""
+
+ self.stream = None
+ self.collector = None
+
+ self.word_lock = threading.Lock()
+ self.exit_event = threading.Event()
+ self.cfg = cfg
+
diff --git a/app/steamvr.py b/app/steamvr.py
new file mode 100644
index 0000000..64f34f5
--- /dev/null
+++ b/app/steamvr.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+
+import openvr as vr
+import sys
+import time
+
+EVENT_NONE = 0
+EVENT_RISING_EDGE = 1
+EVENT_FALLING_EDGE = 2
+
+class InputEvent:
+ def __init__(self,
+ opcode: int):
+ self.opcode = opcode
+
+# Checks if the given button on the given controller is pressed.
+def pollButtonPress(
+ hand: str = "right",
+ button: str = "b",
+ shared_data = None # SharedThreadData object
+ ) -> int:
+ hands = {}
+ hands["left"] = vr.TrackedControllerRole_LeftHand
+ hands["right"] = vr.TrackedControllerRole_RightHand
+
+ buttons = {}
+ buttons["a"] = vr.k_EButton_IndexController_A
+ buttons["b"] = vr.k_EButton_IndexController_B
+ buttons["thumbstick"] = vr.k_EButton_Axis0
+
+ system = None
+ first = True
+ while not shared_data.exit_event.is_set() and not system:
+ try:
+ system = vr.init(vr.VRApplication_Background)
+ except Exception as e:
+ if first:
+ print(f"Failed to start steamVR input thread: {repr(e)}", file=sys.stderr)
+ first = False
+ time.sleep(1)
+ last_packet = 0
+ event_high = False
+
+ while not shared_data.exit_event.is_set():
+ time.sleep(0.01)
+
+ lh_idx = system.getTrackedDeviceIndexForControllerRole(hands[hand])
+ #print("left hand device idx: {}".format(lh_idx))
+
+ got_state, state = system.getControllerState(lh_idx)
+ if not got_state:
+ continue
+
+ if state.unPacketNum == last_packet:
+ continue
+
+ # Clicking joysticks and moving joysticks fire the same events. To
+ # differentiate movement from clicking, we create a dead zone: if the event
+ # fires while the stick isn't moved far from center, we assume it's a
+ # click, not movement.
+ dead_zone_radius = 0.7
+
+ button_mask = (1 << buttons[button])
+ ret = EVENT_NONE
+ if (state.ulButtonPressed & button_mask) != 0 and\
+ (state.rAxis[0].x**2 + state.rAxis[0].y**2 < dead_zone_radius**2):
+ #print("button pressed: %016x" % state.ulButtonPressed)
+ #for i in range(0, 5):
+ # print("axis {} x: {} y: {}".format(i, state.rAxis[i].x, state.rAxis[i].y))
+ if not event_high:
+ yield InputEvent(EVENT_RISING_EDGE)
+ event_high = True
+ elif event_high:
+ event_high = False
+ yield InputEvent(EVENT_FALLING_EDGE)
+
+if __name__ == "__main__":
+ gen = pollButtonPress()
+ while True:
+ time.sleep(0.1)
+
+ event = pollButtonPress(session_state)
+ if event == EVENT_RISING_EDGE:
+ print("rising edge")
+ elif event == EVENT_FALLING_EDGE:
+ print("falling edge")
+
diff --git a/app/stt.py b/app/stt.py
new file mode 100644
index 0000000..18f0f60
--- /dev/null
+++ b/app/stt.py
@@ -0,0 +1,915 @@
+from datetime import datetime
+from faster_whisper import WhisperModel
+import json
+import langcodes
+import numpy as np
+import os
+import noisereduce as nr
+try:
+ from profanity_filter import ProfanityFilter
+ PROFANITY_FILTER_AVAILABLE = True
+except ImportError:
+ PROFANITY_FILTER_AVAILABLE = False
+ print("Warning: profanity_filter module not available", file=sys.stderr)
+import pyaudio
+from pydub import AudioSegment
+from shared_thread_data import SharedThreadData
+from silero_vad import load_silero_vad, get_speech_timestamps
+import sys
+import time
+import typing
+import wave
+
+APP_ROOT = os.path.dirname(os.path.abspath(__file__))
+PROJECT_ROOT = os.path.dirname(APP_ROOT)
+
+class AudioStream():
+ FORMAT = pyaudio.paInt16
+ # Size of each frame (audio sample), in bytes. If you change FORMAT, make
+ # sure this stays up to date!
+ FRAME_SZ = 2
+ # Frames per second.
+ FPS = 16000
+ CHANNELS = 1
+ def __init__(self):
+ pass
+
+ def getSamples(self) -> bytes:
+ raise NotImplementedError("getSamples is not implemented!")
+
+class MicStream(AudioStream):
+ CHUNK_SZ = 1024
+
+ def __init__(self, cfg: typing.Dict):
+ self.p = pyaudio.PyAudio()
+ self.stream = None
+ self.sample_rate = None
+ # Each time pyaudio gives us audio data, it's in the form of a chunk of
+ # samples. We keep these in a list to keep the audio callback as light
+ # as possible. Whenever downstream layers want data, we collapse the
+ # list into a single array of data (a bytes object).
+ self.chunks = []
+ # If set, incoming frames are simply discarded.
+ self.paused = False
+
+ which_mic = cfg["microphone"]
+
+ if cfg["enable_debug_mode"]:
+ print(f"Finding mic {which_mic}", file=sys.stderr)
+ self.dumpMicDevices()
+
+ got_match = False
+ device_index = -1
+ if which_mic == "index":
+ target_str = "Digital Audio Interface"
+ elif which_mic == "focusrite":
+ target_str = "Focusrite"
+ elif which_mic == "motu":
+ target_str = "In 1-2 (MOTU M Series)"
+ elif which_mic == "beyond":
+ target_str = "Microphone (Beyond)"
+ else:
+ if cfg["enable_debug_mode"]:
+ print(f"Mic {which_mic} requested, treating it as a numerical " +
+ "device ID", file=sys.stderr)
+ device_index = int(which_mic)
+ got_match = True
+ if not got_match:
+ info = self.p.get_host_api_info_by_index(0)
+ numdevices = info.get('deviceCount')
+ for i in range(0, numdevices):
+ if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
+ device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name')
+ if target_str in device_name:
+ print(f"Got matching mic: {device_name}",
+ file=sys.stderr)
+ device_index = i
+ got_match = True
+ break
+ if not got_match:
+ raise KeyError(f"Mic {which_mic} not found")
+
+ info = self.p.get_device_info_by_host_api_device_index(0, device_index)
+ if cfg["enable_debug_mode"]:
+ print(f"Found mic {which_mic}: {info['name']}", file=sys.stderr)
+ self.sample_rate = int(info['defaultSampleRate'])
+ if cfg["enable_debug_mode"]:
+ print(f"Mic sample rate: {self.sample_rate}", file=sys.stderr)
+
+ self.stream = self.p.open(
+ rate=self.sample_rate,
+ channels=AudioStream.CHANNELS,
+ format=AudioStream.FORMAT,
+ input=True,
+ frames_per_buffer=MicStream.CHUNK_SZ,
+ input_device_index=device_index,
+ stream_callback=self.onAudioFramesAvailable)
+
+ self.stream.start_stream()
+
+ AudioStream.__init__(self)
+
+ def pause(self, state: bool = True):
+ self.paused = state
+
+ def dumpMicDevices(self):
+ info = self.p.get_host_api_info_by_index(0)
+ numdevices = info.get('deviceCount')
+
+ for i in range(0, numdevices):
+ if (self.p.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
+ device_name = self.p.get_device_info_by_host_api_device_index(0, i).get('name')
+ print("Input Device id ", i, " - ", device_name)
+
+ def onAudioFramesAvailable(self,
+ frames,
+ frame_count,
+ time_info,
+ status_flags):
+ if self.paused:
+ # Don't literally pause, just start returning silence. This allows
+ # the `min_segment_age_s` check to work while paused.
+ n_frames = int(frame_count * AudioStream.FPS /
+ float(self.sample_rate))
+ self.chunks.append(np.zeros(n_frames,
+ dtype=np.int16).tobytes())
+ return (frames, pyaudio.paContinue)
+
+ decimated = b''
+ # In pyaudio, a `frame` is a single sample of audio data.
+ frame_len = AudioStream.FRAME_SZ
+ next_frame = 0.0
+ # The mic probably has a higher sample rate than Whisper wants, so
+ # decrease the sample rate by dropping samples. Note that this
+ # algorithm only works if the mic's rate is higher than whisper's
+ # expected rate.
+ keep_every = float(self.sample_rate) / AudioStream.FPS
+ for i in range(frame_count):
+ if i >= next_frame:
+ decimated += frames[i*frame_len:(i+1)*frame_len]
+ next_frame += keep_every
+ self.chunks.append(decimated)
+
+ return (frames, pyaudio.paContinue)
+
+ # Get audio data and the corresponding timestamp.
+ def getSamples(self) -> bytes:
+ chunks = self.chunks
+ self.chunks = []
+ result = b''.join(chunks)
+ return result
+
+class AudioCollector:
+ def __init__(self, stream: AudioStream):
+ self.stream = stream
+ self.frames = b''
+ # Note: by design, this is the only spot where we anchor our timestamps
+ # against the real world. This is done to make it possible to profile
+ # test cases which read from disk (at much faster than real speed) in
+ # the same way that we profile real-time data.
+ self.wall_ts = time.time()
+
+ def getAudio(self) -> bytes:
+ frames = self.stream.getSamples()
+ if frames:
+ self.frames += frames
+ return self.frames
+
+ def dropAudioPrefix(self, dur_s: float) -> bytes:
+ n_bytes = int(dur_s * AudioStream.FPS) * self.stream.FRAME_SZ
+ n_bytes = min(n_bytes, len(self.frames))
+ cut_portion = self.frames[:n_bytes]
+ self.frames = self.frames[n_bytes:]
+ self.wall_ts += float(n_bytes / self.stream.FRAME_SZ) / self.stream.FPS
+ return cut_portion
+
+ def dropAudioPrefixByFrames(self, dur_frames: int) -> bytes:
+ n_bytes = dur_frames * self.stream.FRAME_SZ
+ n_bytes = min(n_bytes, len(self.frames))
+ cut_portion = self.frames[:n_bytes]
+ self.frames = self.frames[n_bytes:]
+ self.wall_ts += float(n_bytes / self.stream.FRAME_SZ) / self.stream.FPS
+ return cut_portion
+
+ def keepLast(self, dur_s: float) -> bytes:
+ drop_len = max(0, self.duration() - dur_s)
+ return self.dropAudioPrefix(drop_len)
+
+ def dropAudio(self):
+ self.wall_ts += self.duration()
+ cut_portion = self.frames
+ self.frames = b''
+ return cut_portion
+
+ def duration(self):
+ return len(self.frames) / (AudioStream.FPS * self.stream.FRAME_SZ)
+
+ def begin(self):
+ return self.wall_ts
+
+ def now(self):
+ return self.begin() + self.duration()
+
+class AudioCollectorFilter:
+ def __init__(self, parent: AudioCollector):
+ self.parent = parent
+
+ def getAudio(self) -> bytes:
+ return self.parent.getAudio()
+ def dropAudioPrefix(self, dur_s: float):
+ return self.parent.dropAudioPrefix(dur_s)
+ def dropAudioPrefixByFrames(self, dur_frames: int):
+ return self.parent.dropAudioPrefixByFrames(dur_frames)
+ def keepLast(self, dur_s):
+ return self.parent.keepLast(dur_s)
+ def dropAudio(self):
+ return self.parent.dropAudio()
+ def duration(self):
+ return self.parent.duration()
+ def begin(self):
+ return self.parent.begin()
+ def now(self):
+ return self.parent.now()
+
+# Audio collector that enforces a minimum length on its audio data.
+class LengthEnforcingAudioCollector(AudioCollectorFilter):
+ def __init__(self, parent: AudioCollector, min_duration_s: float):
+ AudioCollectorFilter.__init__(self, parent)
+ self.min_duration_s = min_duration_s
+
+ def getAudio(self) -> bytes:
+ audio = self.parent.getAudio()
+ min_duration_frames = int(self.min_duration_s * AudioStream.FPS)
+ pad_len_frames = max(0, min_duration_frames - int(len(audio) /
+ AudioStream.FRAME_SZ))
+ pad = np.zeros(pad_len_frames, dtype=np.int16).tobytes()
+ return pad + audio
+
+class NormalizingAudioCollector(AudioCollectorFilter):
+ def __init__(self, parent: AudioCollector):
+ AudioCollectorFilter.__init__(self, parent)
+
+ def getAudio(self) -> bytes:
+ audio = self.parent.getAudio()
+
+ audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ,
+ frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS)
+ audio = audio.normalize()
+
+ frames = np.array(audio.get_array_of_samples())
+ frames = np.int16(frames).tobytes()
+
+ return frames
+
+class BoostingAudioCollector(AudioCollectorFilter):
+ def __init__(self, parent: AudioCollector,
+ target_dBFS: float,
+ max_gain_dB: float,
+ cfg: typing.Dict):
+ AudioCollectorFilter.__init__(self, parent)
+ self.target_dBFS = target_dBFS
+ self.max_gain_dB = max_gain_dB
+ self.cfg = cfg
+
+ def getAudio(self) -> bytes:
+ audio = self.parent.getAudio()
+
+ audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ,
+ frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS)
+ gain = min(self.target_dBFS - audio.dBFS, self.max_gain_dB)
+ if self.cfg["enable_debug_mode"]:
+ print(f"Boosting audio by {gain} dB (from {audio.dBFS} to {audio.dBFS + gain})", flush=True)
+ audio = audio.apply_gain(gain)
+
+ frames = np.array(audio.get_array_of_samples())
+ frames = np.int16(frames).tobytes()
+
+ return frames
+
+class CompressingAudioCollector(AudioCollectorFilter):
+ def __init__(self, parent: AudioCollector):
+ AudioCollectorFilter.__init__(self, parent)
+
+ def getAudio(self) -> bytes:
+ audio = self.parent.getAudio()
+
+ audio = AudioSegment(audio, sample_width=AudioStream.FRAME_SZ,
+ frame_rate=AudioStream.FPS, channels=AudioStream.CHANNELS)
+ # subtle compression has a slight positive effect on my benchmark
+ audio = audio.compress_dynamic_range(threshold=-10, ratio=2.0)
+
+ frames = np.array(audio.get_array_of_samples())
+ frames = np.int16(frames).tobytes()
+
+ return frames
+
+class NoiseReducingAudioCollector(AudioCollectorFilter):
+ def __init__(self, parent: AudioCollector, cfg: typing.Dict):
+ AudioCollectorFilter.__init__(self, parent)
+ self.cfg = cfg
+
+ def getAudio(self) -> bytes:
+ audio = self.parent.getAudio()
+ audio_array = np.frombuffer(audio, dtype=np.int16).astype(np.float32)
+
+ reduced_audio = nr.reduce_noise(
+ y=audio_array,
+ sr=AudioStream.FPS,
+ )
+
+ # Convert back to int16
+ reduced_audio = np.clip(reduced_audio, -32768, 32767)
+ frames = np.int16(reduced_audio).tobytes()
+
+ return frames
+
+class AudioSegmenter:
+ def __init__(self,
+ min_silence_ms=250,
+ max_speech_s=5,
+ min_speech_duration_ms=100):
+ self.min_silence_ms = min_silence_ms
+ self.max_speech_s = max_speech_s
+ self.min_speech_duration_ms = min_speech_duration_ms
+
+ # Load Silero VAD model
+ self.model = load_silero_vad()
+
+ self.vad_threshold = 0.3
+ self.min_silence_duration_ms = min_silence_ms
+ self.max_speech_duration_s = max_speech_s
+ self.min_speech_duration_ms = min_speech_duration_ms
+
+ def segmentAudio(self, audio: bytes):
+ # Convert audio bytes to numpy array expected by silero-vad
+ audio_array = np.frombuffer(audio,
+ dtype=np.int16).flatten().astype(np.float32) / 32768.0
+
+ # Get speech timestamps using silero-vad
+ # Note: silero-vad expects sample rate of 16000 Hz which matches AudioStream.FPS
+ speech_timestamps = get_speech_timestamps(
+ audio_array,
+ self.model,
+ sampling_rate=AudioStream.FPS,
+ threshold=self.vad_threshold,
+ min_silence_duration_ms=self.min_silence_duration_ms,
+ max_speech_duration_s=self.max_speech_duration_s,
+ min_speech_duration_ms=self.min_speech_duration_ms,
+ return_seconds=False # We want frame indices, not seconds
+ )
+
+ return speech_timestamps
+
+ # Returns the stable cutoff (if any) and whether there are any segments.
+ def getStableCutoff(self, audio: bytes) -> typing.Tuple[int, bool]:
+ min_delta_frames = int((self.min_silence_duration_ms *
+ AudioStream.FPS) / 1000.0)
+ cutoff = None
+
+ last_end = None
+ segments = self.segmentAudio(audio)
+
+ for i in range(len(segments)):
+ s = segments[i]
+ #print(f"s: {s}")
+ #print(f"last_end: {last_end}")
+
+ if last_end:
+ delta_frames = s['start'] - last_end
+ #print(f"delta frames: {delta_frames}")
+ if delta_frames > min_delta_frames:
+ cutoff = s['start']
+ else:
+ last_end = s['end']
+
+ if i == len(segments) - 1:
+ now = int(len(audio) / AudioStream.FRAME_SZ)
+ #print(f"now: {now}")
+ #print(f"min d: {min_delta_frames}")
+ delta_frames = now - s['end']
+ if delta_frames > min_delta_frames:
+ cutoff = now - int(min_delta_frames / 2)
+
+ return (cutoff, len(segments) > 0)
+
+# A segment of transcribed audio. `start_ts` and `end_ts` are floating point
+# number of seconds since the beginning of audio data.
+class Segment:
+ def __init__(self,
+ transcript: str,
+ start_ts: float,
+ end_ts: float,
+ wall_ts: float,
+ avg_logprob: float,
+ no_speech_prob: float,
+ compression_ratio: float):
+ self.transcript = transcript
+ # start_ts, end_ts are timestamps in seconds relative to `wall_ts`.
+ self.start_ts = start_ts
+ self.end_ts = end_ts
+ # wall_ts is the time.time() at which the oldest audio sample leading
+ # to this transcript was collected.
+ self.wall_ts = wall_ts
+ self.avg_logprob = avg_logprob
+ self.no_speech_prob = no_speech_prob
+ self.compression_ratio = compression_ratio
+
+ def __str__(self):
+ ts = f"(ts: {self.start_ts}-{self.end_ts}) "
+
+ wall_ts_start = datetime.utcfromtimestamp(self.start_ts + self.wall_ts).strftime('%H:%M:%S')
+ wall_ts_end = datetime.utcfromtimestamp(self.end_ts + self.wall_ts).strftime('%H:%M:%S')
+ wall_ts = f"(wall ts: {wall_ts_start}-{wall_ts_end}) "
+
+ no_speech = f"(no_speech: {self.no_speech_prob}) "
+ avg_logprob = f"(avg_logprob: {self.avg_logprob}) "
+ return f"{self.transcript} " + ts + wall_ts + no_speech + avg_logprob
+
+def join_segments(a, b):
+ if len(a) > 0 and a[-1] != ' ':
+ return a + ' ' + b
+ else:
+ return a + b
+
+class Whisper:
+ def __init__(self,
+ collector: AudioCollector,
+ cfg: typing.Dict):
+ self.collector = collector
+ self.model = None
+ self.cfg = cfg
+
+ model_str = cfg["model"]
+ model_root = os.path.join(PROJECT_ROOT, "Models",
+ os.path.normpath(model_str))
+ if cfg["enable_debug_mode"]:
+ print(f"Model {cfg['model']} will be saved to {model_root}",
+ file=sys.stderr)
+
+ model_device = "cuda"
+ compute_type = cfg["compute_type"]
+ if cfg["use_cpu"]:
+ model_device = "cpu"
+ compute_type = "int8"
+
+ already_downloaded = os.path.exists(model_root)
+
+ if not already_downloaded:
+ print(f"Model {model_str} not already downloaded, downloading now...", flush=True)
+
+ self.model = WhisperModel(model_str,
+ device = model_device,
+ device_index = cfg["gpu_idx"],
+ compute_type = compute_type,
+ download_root = model_root,
+ local_files_only = already_downloaded)
+
+ self.context_window_chars = 200 # Keep last 200 chars of context
+ self.recent_context = "" # Store recent committed text
+
+ def update_context(self, committed_text: str):
+ """Update the context with recently committed text."""
+ self.recent_context = join_segments(self.recent_context, committed_text).strip()
+ # Drop half of the context window.
+ if len(self.recent_context) > self.context_window_chars:
+ words = self.recent_context.split()
+ words = words[len(words)//2:]
+ self.recent_context = ' '.join(words)
+
+ def transcribe(self, frames: bytes = None) -> typing.List[Segment]:
+ if frames is None:
+ frames = self.collector.getAudio()
+
+ # Convert audio to float32
+ audio = np.frombuffer(frames,
+ dtype=np.int16).flatten().astype(np.float32) / 32768.0
+
+ # Build context-aware prompt
+ prompt = self._build_prompt()
+
+ if self.cfg["enable_debug_mode"]:
+ print(f"Prompt: {prompt}", flush=True)
+
+ t0 = time.time()
+ segments, info = self.model.transcribe(
+ audio,
+ language = langcodes.find(self.cfg["language"]).language,
+ vad_filter = True,
+ temperature=0.0,
+ without_timestamps = False,
+ initial_prompt=prompt,
+ beam_size=self.cfg.get("beam_size", 5),
+ best_of=self.cfg.get("best_of", 5),
+ condition_on_previous_text=True
+ )
+ res = []
+ for s in segments:
+ # Manual touchup. I see a decent number of hallucinations sneaking
+ # in with high `no_speech_prob` and modest `avg_logprob`.
+ if s.no_speech_prob > 0.6 and s.avg_logprob < -0.5:
+ if self.cfg["enable_debug_mode"]:
+ print(f"Drop probable hallucination (case 1) " +
+ f"(text='{s.text}', " +
+ f"no_speech_prob={s.no_speech_prob}, " +
+ f"avg_logprob={s.avg_logprob})", file=sys.stderr)
+ continue
+ # Another touchup targeted at the vexatious "thanks for watching!"
+ # hallucination. This triggers a lot when listening to
+ # instrumental/electronic music.
+ if s.no_speech_prob > 0.15 and s.avg_logprob < -0.7:
+ if self.cfg["enable_debug_mode"]:
+ print(f"Drop probable hallucination (case 2) " +
+ f"(text='{s.text}', " +
+ f"no_speech_prob={s.no_speech_prob}, " +
+ f"avg_logprob={s.avg_logprob})", file=sys.stderr)
+ continue
+ if s.avg_logprob < -0.75:
+ if self.cfg["enable_debug_mode"]:
+ print(f"Drop probable hallucination (case 3) " +
+ f"(text='{s.text}', " +
+ f"no_speech_prob={s.no_speech_prob}, " +
+ f"avg_logprob={s.avg_logprob})", file=sys.stderr)
+ continue
+ if self.cfg["enable_debug_mode"]:
+ print(f"s get: {s}")
+ if s.avg_logprob < -1.0:
+ continue
+ if s.compression_ratio > 2.4:
+ continue
+ res.append(Segment(s.text, s.start, s.end,
+ self.collector.begin(),
+ s.avg_logprob, s.no_speech_prob, s.compression_ratio))
+ t1 = time.time()
+ if self.cfg["enable_debug_mode"]:
+ print(f"Transcription latency (s): {t1 - t0}")
+ return res
+
+ def _build_prompt(self) -> str:
+ """Build a context-aware prompt for Whisper."""
+ user_prompt = self.cfg["user_prompt"]
+ context_prompt = ""
+ if self.recent_context and len(self.recent_context) > 0:
+ context_prompt = f"Here is the context so far: {self.recent_context}"
+
+ prompts = [user_prompt, context_prompt]
+ prompts = [p for p in prompts if p and len(p) > 0]
+ return " ".join(prompts)
+
+class TranscriptCommit:
+ def __init__(self,
+ delta: str,
+ preview: str,
+ latency_s: float = None,
+ thresh_at_commit: int = None,
+ audio: bytes = None,
+ duration_s: float = None,
+ start_ts: float = None):
+ self.delta = delta
+ self.preview = preview
+ self.latency_s = latency_s
+ self.thresh_at_commit = thresh_at_commit
+ self.audio = audio
+ # Time at which the commit is generated
+ self.ts = time.time()
+ # Time corresponding to the start of the segment
+ self.start_ts = start_ts
+ # The duration of the audio segment, in seconds.
+ self.duration_s = duration_s
+
+
+def saveAudio(audio: bytes, path: str, cfg: typing.Dict):
+ with wave.open(path, 'wb') as wf:
+ if cfg["enable_debug_mode"]:
+ print(f"Saving audio to {path}", file=sys.stderr)
+ wf.setnchannels(AudioStream.CHANNELS)
+ wf.setsampwidth(AudioStream.FRAME_SZ)
+ wf.setframerate(AudioStream.FPS)
+ wf.writeframes(audio)
+
+
+class SegmentLogger:
+ def __init__(self, cfg: typing.Dict):
+ self.cfg = cfg
+ self.enabled = cfg.get("enable_segment_logging", False)
+ self.session_data = []
+ self.log_file = None
+
+ if self.enabled:
+ log_dir = os.path.join(PROJECT_ROOT, "logs")
+ if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+ # Create file
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ self.log_file = os.path.join(log_dir, f"session_debug_{timestamp}.json")
+ print(f"Segment logging enabled. Logging to: {self.log_file}", file=sys.stderr)
+
+ def log_segment(self, segment: Segment, commit_type: str = "commit"):
+ if not self.enabled:
+ return
+
+ segment_data = {
+ "timestamp": datetime.now().isoformat(),
+ "type": commit_type,
+ "text": segment.transcript,
+ "start_ts": segment.start_ts,
+ "end_ts": segment.end_ts,
+ "wall_ts": segment.wall_ts,
+ "avg_logprob": segment.avg_logprob,
+ "no_speech_prob": segment.no_speech_prob,
+ "compression_ratio": segment.compression_ratio,
+ "duration": segment.end_ts - segment.start_ts
+ }
+
+ self.session_data.append(segment_data)
+
+ # Write to file incrementally
+ try:
+ with open(self.log_file, 'w') as f:
+ json.dump({
+ "session_start": self.session_data[0]["timestamp"] if self.session_data else None,
+ "segments": self.session_data
+ }, f, indent=2)
+ except Exception as e:
+ print(f"Error writing segment log: {e}", file=sys.stderr)
+
+ def close(self):
+ if self.enabled and self.session_data:
+ print(f"Session complete. Logged {len(self.session_data)} segments to {self.log_file}", file=sys.stderr)
+
+
+class VadCommitter:
+ def __init__(self,
+ cfg: typing.Dict,
+ collector: AudioCollector,
+ whisper: Whisper,
+ segmenter: AudioSegmenter,
+ segment_logger: SegmentLogger = None):
+ self.cfg = cfg
+ self.collector = collector
+ self.whisper = whisper
+ self.segmenter = segmenter
+ self.segment_logger = segment_logger
+
+ def getDelta(self) -> TranscriptCommit:
+ audio = self.collector.getAudio()
+ stable_cutoff, has_audio = self.segmenter.getStableCutoff(audio)
+
+ delta = ""
+ commit_audio = None
+ latency_s = None
+ duration_s = self.collector.duration()
+ start_ts = self.collector.begin()
+
+ if has_audio and stable_cutoff:
+ latency_s = self.collector.now() - self.collector.begin()
+ duration_s = stable_cutoff / AudioStream.FPS
+ start_ts = self.collector.begin()
+
+ # Get the filtered audio first, then extract the portion we need
+ filtered_audio = self.collector.getAudio()
+ commit_audio = filtered_audio[:stable_cutoff * AudioStream.FRAME_SZ]
+
+ # Now drop the prefix from the collector
+ self.collector.dropAudioPrefixByFrames(stable_cutoff)
+
+ segments = self.whisper.transcribe(commit_audio)
+ delta = ''.join(s.transcript for s in segments)
+
+ # Update whisper's context with the committed text
+ if delta.strip():
+ self.whisper.update_context(delta.strip())
+
+ if self.segment_logger:
+ for s in segments:
+ self.segment_logger.log_segment(s, "commit")
+
+ audio = self.collector.getAudio()
+ if self.cfg["enable_debug_mode"]:
+ for s in segments:
+ print(f"commit segment: {s}", file=sys.stderr)
+ if len(delta) > 0:
+ print(f"delta get: {delta}", file=sys.stderr)
+
+ if self.cfg["save_audio"] and len(delta) > 0:
+ ts = datetime.fromtimestamp(self.collector.now() - latency_s)
+ filename = str(ts.strftime('%Y_%m_%d__%H-%M-%S')) + delta.strip() + ".wav"
+ audio_dir = os.path.join(PROJECT_ROOT, "audio")
+ if not os.path.exists(audio_dir):
+ os.makedirs(audio_dir)
+ saveAudio(commit_audio, os.path.join(audio_dir, filename), self.cfg)
+
+ preview = ""
+ if self.cfg["enable_previews"] and has_audio:
+ segments = self.whisper.transcribe(audio)
+ preview = "".join(s.transcript for s in segments)
+
+ if self.segment_logger:
+ for s in segments:
+ self.segment_logger.log_segment(s, "preview")
+
+ if not has_audio:
+ self.collector.keepLast(1.0)
+
+ return TranscriptCommit(
+ delta.strip(),
+ preview.strip(),
+ latency_s,
+ audio=audio,
+ duration_s=duration_s,
+ start_ts=start_ts)
+
+
+class StreamingPlugin:
+ def __init__(self):
+ pass
+
+ def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+ return commit
+
+ def stop(self):
+ pass
+
+
+class LowercasePlugin(StreamingPlugin):
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+ if self.cfg["enable_lowercase_filter"]:
+ commit.delta = commit.delta.lower()
+ commit.preview = commit.preview.lower()
+ return commit
+
+
+class UppercasePlugin(StreamingPlugin):
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+ if self.cfg["enable_uppercase_filter"]:
+ commit.delta = commit.delta.upper()
+ commit.preview = commit.preview.upper()
+ return commit
+
+
+class ProfanityPlugin(StreamingPlugin):
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self.filter = None
+ if PROFANITY_FILTER_AVAILABLE and cfg["enable_profanity_filter"]:
+ en_profanity_path = os.path.join(PROJECT_ROOT, "Third_Party/Profanity/en")
+ try:
+ self.filter = ProfanityFilter(en_profanity_path)
+ self.filter.load()
+ except Exception as e:
+ print(f"Warning: Could not load profanity filter: {e}", file=sys.stderr)
+ self.filter = None
+
+ def transform(self, commit: TranscriptCommit) -> TranscriptCommit:
+ if self.cfg["enable_profanity_filter"] and self.filter:
+ commit.delta = self.filter.filter(commit.delta)
+ commit.preview = self.filter.filter(commit.preview)
+ return commit
+
+
+class PresentationFilter:
+ def __init__(self):
+ pass
+
+ def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]:
+ return transcript, preview
+
+ def stop(self):
+ pass
+
+
+class TrailingPeriodFilter(PresentationFilter):
+ def __init__(self, cfg):
+ self.cfg = cfg
+
+ def transform(self, transcript: str, preview: str) -> typing.Tuple[str, str]:
+ if self.cfg["remove_trailing_period"]:
+ def _remove_trailing_period(s: str) -> str:
+ if len(s) > 0 and s[-1] == '.' and not s.endswith("..."):
+ s = s[0:len(s)-1]
+ return s
+ if len(preview) == 0:
+ transcript = _remove_trailing_period(transcript)
+ else:
+ preview = _remove_trailing_period(preview)
+ return transcript, preview
+
+
+def transcriptionThread(shared_data: SharedThreadData):
+ last_stable_commit = None
+
+ stream = MicStream(shared_data.cfg)
+ collector = AudioCollector(stream)
+ collector = CompressingAudioCollector(collector)
+ collector = BoostingAudioCollector(collector, -16.0, 24.0,
+ shared_data.cfg)
+ collector = NoiseReducingAudioCollector(collector, shared_data.cfg)
+ #collector = NormalizingAudioCollector(collector)
+ whisper = Whisper(collector, shared_data.cfg)
+ segmenter = AudioSegmenter(min_silence_ms=shared_data.cfg["min_silence_duration_ms"],
+ max_speech_s=shared_data.cfg["max_speech_duration_s"],
+ min_speech_duration_ms=shared_data.cfg["min_speech_duration_ms"])
+
+ segment_logger = SegmentLogger(shared_data.cfg)
+ committer = VadCommitter(shared_data.cfg, collector, whisper, segmenter, segment_logger)
+
+ plugins = []
+ # plugins.append(TranslationPlugin(shared_data.cfg)) # Not implemented yet
+ plugins.append(UppercasePlugin(shared_data.cfg))
+ plugins.append(LowercasePlugin(shared_data.cfg))
+ plugins.append(ProfanityPlugin(shared_data.cfg))
+ # plugins.append(UwuPlugin(shared_data.cfg)) # Not implemented yet
+ # plugins.append(BrowserSource(shared_data.cfg)) # Not implemented yet
+
+ filters = []
+ filters.append(TrailingPeriodFilter(shared_data.cfg))
+
+ transcript = ""
+ preview = ""
+
+ with shared_data.word_lock:
+ shared_data.stream = stream
+ shared_data.collector = collector
+
+ print(f"Ready to go!", flush=True)
+
+ while not shared_data.exit_event.is_set():
+ time.sleep(shared_data.cfg["transcription_loop_delay_ms"] / 1000.0);
+
+ op = None
+
+ commit = committer.getDelta()
+
+ with shared_data.word_lock:
+ for plugin in plugins:
+ commit = plugin.transform(commit)
+
+ if len(commit.delta) > 0 or len(commit.preview) > 0:
+ # Avoid re-sending text after long pauses
+ if shared_data.cfg["reset_after_silence_s"] > 0:
+ silence_duration = 0
+ if last_stable_commit:
+ last_commit_end_ts = \
+ last_stable_commit.start_ts + \
+ last_stable_commit.duration_s
+ silence_duration = commit.start_ts - last_commit_end_ts
+ if silence_duration > shared_data.cfg["reset_after_silence_s"]:
+ if shared_data.cfg["enable_debug_mode"]:
+ print(f"Resetting transcript after {silence_duration}-second "
+ "silence", file=sys.stderr)
+ shared_data.transcript = ""
+ shared_data.preview = ""
+ whisper.recent_context = "" # Reset context too
+ if commit.delta:
+ last_stable_commit = commit
+
+ # Hard-cap displayed transcript length to prevent
+ # runaway memory use in UI. Keep the full transcript to avoid
+ # breaking OSC pager.
+ if len(shared_data.transcript) >= 1024:
+ shared_data.transcript = shared_data.transcript[-512:]
+ shared_data.transcript = \
+ join_segments(shared_data.transcript, commit.delta)
+ shared_data.preview = commit.preview
+
+ for filt in filters:
+ shared_data.transcript, shared_data.preview = \
+ filt.transform(shared_data.transcript,
+ shared_data.preview)
+
+ try:
+ print(f"Transcript: {shared_data.transcript}", flush=True)
+ except UnicodeEncodeError:
+ print("Failed to encode transcript - discarding delta",
+ file=sys.stderr)
+ continue
+ try:
+ print(f"Preview: {shared_data.preview}", flush=True)
+ except UnicodeEncodeError:
+ print("Failed to encode preview - discarding", file=sys.stderr)
+
+ if shared_data.cfg["enable_debug_mode"]:
+ print(f"commit latency: {commit.latency_s}", file=sys.stderr)
+ print(f"commit thresh: {commit.thresh_at_commit}",
+ file=sys.stderr)
+
+ if len(shared_data.transcript) > 0 and \
+ (not shared_data.transcript.endswith(' ')) and \
+ (not commit.delta.startswith(' ')):
+ commit.delta = ' ' + commit.delta
+ if len(commit.delta) > 0 and \
+ (not commit.delta.endswith(' ')) and \
+ (not commit.preview.startswith(' ')):
+ commit.preview = ' ' + commit.preview
+ for plugin in plugins:
+ plugin.stop()
+ for filt in filters:
+ filt.stop()
+ segment_logger.close()
+
diff --git a/bpe_dump.py b/bpe_dump.py
new file mode 100644
index 0000000..74ee08f
--- /dev/null
+++ b/bpe_dump.py
@@ -0,0 +1,61 @@
+import math
+import sentencepiece as spm
+
+def get_tokenizer():
+ model_path = "./custom_unigram_tokenizer_65k/unigram.model"
+ sp = spm.SentencePieceProcessor()
+ sp.load(model_path)
+ return sp
+
+tokenizer = get_tokenizer()
+
+print(f"vocabulary size: {tokenizer.get_piece_size()}")
+# Sentencepiece uses U+2581 (lower one eighth block) to indicate a space before
+# a subword.
+sp_space = chr(9601)
+tokens_with_non_ascii = set()
+subword_len_histo = dict()
+# The sum of the lengths of each subword in the vocabulary. These are rounded
+# up to 4 characters.
+vocab_len_4c_quantized = 0
+
+for i in range(tokenizer.get_piece_size()):
+ k = tokenizer.id_to_piece(i)
+ v = i
+ print(f" Original token ({v}): {repr(k)} ({' '.join(str(ord(k_c)) for k_c in k)})")
+ for k_c in k:
+ if ord(k_c) > 127 and ord(k_c) != 9601:
+ tokens_with_non_ascii.add(k)
+ break
+ k_processed = k.replace(sp_space, ' ')
+ if not k.startswith(sp_space) and k not in ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]:
+ k_processed = k
+ else:
+ k_processed = k_processed
+
+ current_len = len(k_processed)
+ if current_len in subword_len_histo:
+ subword_len_histo[current_len] += 1
+ else:
+ subword_len_histo[current_len] = 1
+
+ vocab_len_4c_quantized += math.ceil(current_len / 4.0) * 4.0
+ print(f" {v}: {k_processed}")
+
+print(f"Num tokens with non-ascii: {len(tokens_with_non_ascii)} ({100 * len(tokens_with_non_ascii) / tokenizer.get_piece_size():.2f})%")
+
+print(f"Subword length histogram:")
+avg_subword_len = 0
+total_pieces_for_avg = 0
+for k_len, v_count in sorted(subword_len_histo.items(), key=lambda x: x[0]):
+ avg_subword_len += k_len * v_count
+ total_pieces_for_avg += v_count
+ print(f" {k_len}: {v_count}")
+
+if total_pieces_for_avg > 0:
+ avg_subword_len /= total_pieces_for_avg
+ print(f"Average subword length: {avg_subword_len:.4f}")
+else:
+ print("Average subword length: N/A (no pieces analyzed)")
+
+print(f"Sum of all subword lengths, quantized to 4 character chunks: {vocab_len_4c_quantized}")
diff --git a/config.yaml b/config.yaml
new file mode 100644
index 0000000..9cec4a3
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,32 @@
+compute_type: float16
+language: english
+model: turbo
+microphone: 4
+user_prompt: Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm.
+keybind: ctrl+alt+x
+button_hand: right
+button_type: b
+gpu_idx: 0
+max_speech_duration_s: 10
+min_speech_duration_ms: 250
+min_silence_duration_ms: 100
+reset_after_silence_s: 15
+transcription_loop_delay_ms: 100
+block_width: 2
+num_blocks: 40
+rows: 10
+cols: 24
+beam_size: 5
+best_of: 5
+volume: 10
+enable_debug_mode: 0
+enable_previews: 1
+save_audio: 1
+enable_segment_logging: 1
+use_cpu: 0
+enable_lowercase_filter: 0
+enable_uppercase_filter: 0
+enable_profanity_filter: 0
+remove_trailing_period: 0
+reset_on_toggle: 1
+use_builtin: 1
diff --git a/generate_bpe_lut.py b/generate_bpe_lut.py
new file mode 100644
index 0000000..72b1201
--- /dev/null
+++ b/generate_bpe_lut.py
@@ -0,0 +1,119 @@
+from math import ceil, floor
+from PIL import Image
+from unidecode import unidecode
+import sentencepiece as spm
+
+IMG_RES = 512 # square image
+
+def get_tokenizer():
+ use_sentencepiece = True
+
+ if not use_sentencepiece:
+ from tokenizers import Tokenizer
+ tokenizer_json = "./custom_wordpiece_tokenizer_65k/tokenizer.json"
+ print(f"Loading Tokenizers library tokenizer from: {tokenizer_json}")
+ return Tokenizer.from_file(tokenizer_json)
+ else:
+ model_path = "./custom_unigram_tokenizer_65k/unigram.model"
+ print(f"Loading SentencePiece tokenizer from: {model_path}")
+ sp = spm.SentencePieceProcessor()
+ sp.load(model_path)
+ print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}")
+ return sp
+
+def get_words():
+ tokenizer = get_tokenizer()
+
+ print(f"vocabulary size: {tokenizer.get_piece_size()}")
+ # sp_space = sentencepiece space.
+ # A special character sentencepiece uses to represent spaces before words.
+ sp_space = chr(9601)
+ words = []
+
+ # Accumulate words into a list, indexed by the token number. Sanitize them as
+ # you go.
+ for i in range(tokenizer.get_piece_size()):
+ word = tokenizer.id_to_piece(i)
+ tok = i
+ #print(f" Original token ({tok}): {repr(word)} ({' '.join(str(ord(c)) for c in word)})")
+ word_sanitized = ""
+ # Dirty hack: convert non-ASCII characters to nearest ASCII equivalent
+ for c in word:
+ if ord(c) > 127 and c != sp_space:
+ c_plain = unidecode(c)
+ print(f" Resolved {c} to {c_plain}")
+ word_sanitized += c_plain
+ else:
+ word_sanitized += c
+ # Replace sp_space with ' '
+ word_sanitized = word_sanitized.replace(sp_space, ' ')
+ #print(f" {tok}: {word_sanitized}")
+ words.append(word_sanitized)
+
+ # Special word: empty string. SentencePiece doesn't support this natively.
+ words.append('')
+
+ return words
+
+# Fold a flat index into a IMG_RESxIMG_RES box. Return the (x,y) coordinate of
+# the folded index.
+def fold_idx(flat_idx):
+ return (flat_idx % IMG_RES, int(floor(flat_idx / IMG_RES)))
+
+def unfold_idx(coord):
+ return coord[0] + coord[1] * IMG_RES
+
+assert unfold_idx(fold_idx(1533125)) == 1533125
+assert unfold_idx(fold_idx(8538235)) == 8538235
+assert fold_idx(unfold_idx((192,235))) == (192,235)
+assert fold_idx(unfold_idx((83,388))) == (83,388)
+
+def generate_lut(words, filename):
+ # Write the texture header.
+ black = (0, 0, 0, 255)
+ img = Image.new('RGBA', (IMG_RES, IMG_RES), black)
+
+ # The header is `len(words)` slots long. Thus the actual LUT content starts at
+ # the index `len(words)`.
+ pixel_data = img.load()
+ lut_ptr = len(words)
+ for i in range(0, len(words)):
+ # Get pointer to the actual word data.
+ tok_ptr = lut_ptr
+ tok_len = len(words[i])
+ rgba = ((tok_ptr >> 0) & 0xFF,
+ (tok_ptr >> 8) & 0xFF,
+ (tok_ptr >> 16) & 0xFF,
+ tok_len)
+ print(f"Writing {rgba} to {i} / {fold_idx(i)}")
+ idx_x, idx_y = fold_idx(i)
+ pixel_data[idx_x, idx_y] = rgba
+
+ for j in range(0, ceil(tok_len/4.0)):
+ quad_ptr = tok_ptr + j
+ tok_0 = ord(words[i][j*4])
+ tok_1 = ord(words[i][j*4+1] if tok_len > j*4+1 else ' ')
+ tok_2 = ord(words[i][j*4+2] if tok_len > j*4+2 else ' ')
+ tok_3 = ord(words[i][j*4+3] if tok_len > j*4+3 else ' ')
+ rgba = (tok_0, tok_1, tok_2, tok_3)
+ idx_x, idx_y = fold_idx(quad_ptr)
+ print(f" Writing {rgba} to {quad_ptr} / {fold_idx(quad_ptr)}")
+ pixel_data[idx_x, idx_y] = rgba
+
+ # Advance the LUT ptr. Since we store 4 chars per pixel (RGBA), we advance
+ # it by ceil(tok_len/4).
+ lut_ptr += int(ceil(tok_len/4.0))
+
+ pretty = False
+ if pretty:
+ for y in range(0, IMG_RES):
+ for x in range(0, IMG_RES):
+ rgba = pixel_data[x, y]
+ pixel_data[x, y] = (rgba[0], rgba[1], rgba[2], 255)
+
+ print(f"Saving to {filename}")
+ img.save(filename)
+
+if __name__ == "__main__":
+ words = get_words()
+ generate_lut(words, "bpe_lut.png")
diff --git a/generate_tokenizer.py b/generate_tokenizer.py
new file mode 100644
index 0000000..88903a8
--- /dev/null
+++ b/generate_tokenizer.py
@@ -0,0 +1,525 @@
+# !!! AI ARTIFACT !!!
+# This file was primarily written with AI.
+
+import os
+import argparse
+from datasets import load_dataset, interleave_datasets
+import sentencepiece as spm
+import itertools
+import random
+from unidecode import unidecode
+
+# --- Dataset 1: Wikipedia ---
+DATASET_NAME_WIKI = "wikipedia"
+DATASET_CONFIG_WIKI = "20220301.en"
+
+DATASET_SPLIT_WIKI = "train"
+TEXT_COLUMN_WIKI = "text"
+
+# --- Dataset 2: DailyDialog ---
+DATASET_NAME_DD = "daily_dialog"
+DATASET_SPLIT_DD = "train"
+UTTERANCE_COLUMN_DD = "dialog"
+
+# --- Dataset 3: BlendedSkillTalk (includes Persona-Chat) ---
+DATASET_NAME_BST = "blended_skill_talk"
+DATASET_SPLIT_BST = "train"
+
+# --- Dataset 4: OpenSubtitles ---
+DATASET_NAME_OS = "Helsinki-NLP/open_subtitles"
+DATASET_LANG_PAIR_OS = ("en", "fr") # Load en-fr and use the 'en' part
+DATASET_SPLIT_OS = "train"
+TEXT_COLUMN_OS = "en" # Access via item['translation']['en']
+
+# Reserve one space for special "empty" token.
+VOCAB_SIZE = 65535
+OUTPUT_DIR = "./custom_unigram_tokenizer_65k"
+MODEL_PREFIX = os.path.join(OUTPUT_DIR, "unigram")
+MODEL_FILE = MODEL_PREFIX + ".model"
+VOCAB_FILE = MODEL_PREFIX + ".vocab"
+
+UNK_TOKEN = "[UNK]"
+PAD_TOKEN = "[PAD]"
+CONTROL_SYMBOLS = ["[CLS]", "[SEP]", "[MASK]"]
+
+BATCH_SIZE = 1000
+
+def wiki_iterator(dataset_wiki, batch_size=BATCH_SIZE):
+ if dataset_wiki:
+ for i in range(0, len(dataset_wiki), batch_size):
+ yield [unidecode(text) for text in dataset_wiki[i : i + batch_size][TEXT_COLUMN_WIKI] if text]
+
+def dd_iterator(dataset_dd, batch_size=BATCH_SIZE):
+ if dataset_dd:
+ current_batch = []
+ for dialogue in dataset_dd:
+ utterances = dialogue[UTTERANCE_COLUMN_DD]
+ for utterance in utterances:
+ if utterance:
+ normalized_utterance = unidecode(utterance)
+ current_batch.append(normalized_utterance)
+ if len(current_batch) == batch_size:
+ yield current_batch
+ current_batch = []
+ if current_batch:
+ yield current_batch
+
+def bst_iterator(dataset_bst, batch_size=BATCH_SIZE):
+ if dataset_bst:
+ current_batch = []
+ for session in dataset_bst:
+ texts_to_add = []
+ if session.get("previous_utterance"):
+ texts_to_add.append(session["previous_utterance"])
+ if session.get("free_messages"):
+ texts_to_add.extend(session["free_messages"])
+ if session.get("guided_messages"):
+ texts_to_add.extend(session["guided_messages"])
+
+ for text in texts_to_add:
+ if text and isinstance(text, str):
+ normalized_text = unidecode(text)
+ current_batch.append(normalized_text)
+ if len(current_batch) == batch_size:
+ yield current_batch
+ current_batch = []
+ if current_batch:
+ yield current_batch
+
+def os_iterator(dataset_os, batch_size=BATCH_SIZE, lang_code=TEXT_COLUMN_OS):
+ if dataset_os:
+ current_batch = []
+ for item in dataset_os:
+ text = item['translation'][lang_code]
+ if text:
+ normalized_text = unidecode(text)
+ current_batch.append(normalized_text)
+ if len(current_batch) == batch_size:
+ yield current_batch
+ current_batch = []
+ if current_batch:
+ yield current_batch
+
+def count_wiki_items_chars(dataset_wiki):
+ item_count = 0
+ char_count = 0
+ if dataset_wiki:
+ item_count = len(dataset_wiki)
+ try:
+ char_count = sum(len(unidecode(text)) for text in dataset_wiki[TEXT_COLUMN_WIKI] if text and isinstance(text, str))
+ except Exception:
+ print(f"Warning: Direct column access for char count failed for {DATASET_NAME_WIKI}. Iterating row by row (slower).")
+ char_count = sum(len(unidecode(row[TEXT_COLUMN_WIKI])) for row in dataset_wiki if row.get(TEXT_COLUMN_WIKI) and isinstance(row[TEXT_COLUMN_WIKI], str))
+
+ return item_count, char_count
+
+def count_dd_items_chars(dataset_dd):
+ item_count = 0
+ char_count = 0
+ if dataset_dd:
+ for dialogue in dataset_dd:
+ utterances = dialogue[UTTERANCE_COLUMN_DD]
+ for utterance in utterances:
+ if utterance and isinstance(utterance, str):
+ item_count += 1
+ char_count += len(unidecode(utterance))
+ return item_count, char_count
+
+def count_bst_items_chars(dataset_bst):
+ item_count = 0
+ char_count = 0
+ if dataset_bst:
+ for session in dataset_bst:
+ texts_to_process = []
+ # Gather all potential text strings first
+ prev_utt = session.get("previous_utterance")
+ if prev_utt and isinstance(prev_utt, list):
+ if len(prev_utt) > 0 and isinstance(prev_utt[0], str):
+ texts_to_process.append(prev_utt[0])
+ elif prev_utt and isinstance(prev_utt, str):
+ texts_to_process.append(prev_utt)
+
+ free_msgs = session.get("free_messages")
+ if free_msgs and isinstance(free_msgs, list):
+ for item in free_msgs:
+ if isinstance(item, list):
+ texts_to_process.extend(msg for msg in item if msg and isinstance(msg, str))
+ elif isinstance(item, str):
+ texts_to_process.append(item)
+
+ guided_msgs = session.get("guided_messages")
+ if guided_msgs and isinstance(guided_msgs, list):
+ for item in guided_msgs:
+ if isinstance(item, list):
+ texts_to_process.extend(msg for msg in item if msg and isinstance(msg, str))
+ elif isinstance(item, str):
+ texts_to_process.append(item)
+
+ # Count items and chars from the gathered list
+ for text in texts_to_process:
+ if text and isinstance(text, str):
+ normalized_text = unidecode(text)
+ if normalized_text:
+ item_count += 1
+ char_count += len(normalized_text)
+
+ return item_count, char_count
+
+def count_os_items_chars(dataset_os, lang_code=TEXT_COLUMN_OS):
+ item_count = 0
+ char_count = 0
+ if dataset_os:
+ for item in dataset_os:
+ text = item['translation'][lang_code]
+ if text and isinstance(text, str):
+ normalized_text = unidecode(text)
+ if normalized_text: # Ensure not empty after unidecode
+ item_count += 1
+ char_count += len(normalized_text)
+ return item_count, char_count
+
+def load_and_count_datasets(wiki_fraction, subtitles_fraction):
+ """Loads, potentially shrinks, and counts items/chars in datasets."""
+ datasets = {}
+ counts = {}
+ total_items = 0
+ total_chars = 0
+
+ # --- Wikipedia ---
+ print(f"Loading dataset 1: {DATASET_NAME_WIKI} ({DATASET_CONFIG_WIKI}), split: {DATASET_SPLIT_WIKI}")
+ dataset_wiki_full = load_dataset(DATASET_NAME_WIKI, DATASET_CONFIG_WIKI, split=DATASET_SPLIT_WIKI, trust_remote_code=True)
+ print(f"Original Wikipedia dataset size: {len(dataset_wiki_full):,}")
+
+ # Shrink the Wikipedia dataset
+ split_test_size = 1.0 - wiki_fraction
+ shrunk_dataset_split = dataset_wiki_full.train_test_split(test_size=split_test_size, seed=random.randint(0, 1000000))
+ dataset_wiki = shrunk_dataset_split['train']
+ print(f"Using {wiki_fraction*100:.3f}% of Wikipedia dataset: {len(dataset_wiki):,} items")
+
+ count_wiki, chars_wiki = count_wiki_items_chars(dataset_wiki)
+ print(f"Wikipedia dataset loaded (shrunk). Precise text items: {count_wiki:,}, Characters: {chars_wiki:,}")
+ datasets['wiki'] = dataset_wiki
+ counts['wiki'] = (count_wiki, chars_wiki)
+ total_items += count_wiki
+ total_chars += chars_wiki
+
+ # --- DailyDialog ---
+ print(f"\nLoading dataset 2: {DATASET_NAME_DD}, split: {DATASET_SPLIT_DD}")
+ dataset_dd = load_dataset(DATASET_NAME_DD, split=DATASET_SPLIT_DD, trust_remote_code=True)
+ count_dd, chars_dd = count_dd_items_chars(dataset_dd)
+ print(f"DailyDialog dataset loaded. Precise text items (utterances): {count_dd:,}, Characters: {chars_dd:,}")
+ datasets['dd'] = dataset_dd
+ counts['dd'] = (count_dd, chars_dd)
+ total_items += count_dd
+ total_chars += chars_dd
+
+ # --- BlendedSkillTalk ---
+ print(f"\nLoading dataset 3: {DATASET_NAME_BST}, split: {DATASET_SPLIT_BST}")
+ dataset_bst = load_dataset(DATASET_NAME_BST, split=DATASET_SPLIT_BST, trust_remote_code=True)
+ count_bst, chars_bst = count_bst_items_chars(dataset_bst)
+ print(f"BlendedSkillTalk dataset loaded. Precise text items (extracted): {count_bst:,}, Characters: {chars_bst:,}")
+ datasets['bst'] = dataset_bst
+ counts['bst'] = (count_bst, chars_bst)
+ total_items += count_bst
+ total_chars += chars_bst
+
+ # --- OpenSubtitles ---
+ print(f"\nLoading dataset 4: {DATASET_NAME_OS}, lang_pair: {DATASET_LANG_PAIR_OS}, split: {DATASET_SPLIT_OS}")
+ # Note: OpenSubtitles can be very large. Consider streaming or specific configurations if memory is an issue.
+ # For now, loading a standard configuration.
+ dataset_os = load_dataset(DATASET_NAME_OS, lang1=DATASET_LANG_PAIR_OS[0], lang2=DATASET_LANG_PAIR_OS[1], split=DATASET_SPLIT_OS, trust_remote_code=True)
+ split_test_size = 1.0 - subtitles_fraction
+ dataset_os = dataset_os.train_test_split(test_size=split_test_size, seed=random.randint(0, 1000000))['train']
+ count_os, chars_os = count_os_items_chars(dataset_os, lang_code=TEXT_COLUMN_OS)
+ print(f"OpenSubtitles dataset loaded. Precise text items: {count_os:,}, Characters: {chars_os:,}")
+ datasets['os'] = dataset_os
+ counts['os'] = (count_os, chars_os)
+ total_items += count_os
+ total_chars += chars_os
+
+ print(f"\nTotal precise text items from loaded datasets: {total_items:,}")
+ print(f"Total characters from loaded datasets: {total_chars:,}")
+
+ return datasets, counts
+
+def train_tokenizer(model_prefix, datasets, counts):
+ """Trains a Unigram tokenizer using SentencePiece and saves it."""
+ total_items = sum(c[0] for c in counts.values())
+ total_chars = sum(c[1] for c in counts.values())
+ if total_items == 0:
+ print("Error: No items found in the datasets to train on. Exiting training.")
+ return
+
+ print(f"\nTotal precise text items for training: {total_items:,}")
+ print(f"Total characters for training: {total_chars:,}")
+
+ output_dir = os.path.dirname(model_prefix)
+ os.makedirs(output_dir, exist_ok=True)
+
+ print(f"\nStarting SentencePiece Unigram tokenizer training with vocab size: {VOCAB_SIZE}")
+ print(f"Using combined dataset iterator.")
+ print(f"Output model prefix: {model_prefix}")
+
+ iterators_to_chain = []
+ def flatten_iterator(iterator):
+ for batch in iterator:
+ for item in batch:
+ yield item
+
+ if datasets.get('wiki') and counts['wiki'][0] > 0:
+ iterators_to_chain.append(flatten_iterator(wiki_iterator(datasets['wiki'])))
+ if datasets.get('dd') and counts['dd'][0] > 0:
+ iterators_to_chain.append(flatten_iterator(dd_iterator(datasets['dd'])))
+ if datasets.get('bst') and counts['bst'][0] > 0:
+ iterators_to_chain.append(flatten_iterator(bst_iterator(datasets['bst'])))
+ if datasets.get('os') and counts['os'][0] > 0:
+ iterators_to_chain.append(flatten_iterator(os_iterator(datasets['os'])))
+
+ if not iterators_to_chain:
+ print("Error: No valid dataset iterators available for training. Exiting.")
+ return
+
+ combined_iterator = itertools.chain(*iterators_to_chain)
+
+ # Include whitespace symbols so we can efficiently break lines.
+ # If we include the single space, it prevents the tokenizer from merging
+ # spaces with regular words, and tanks the average chars/token.
+ # This many tokens is kinda overkill, but it gives us a way to efficiently
+ # clear even fairly large boards, so I think it's worth.
+ whitespace_symbols = []
+ for i in range(2, 40):
+ whitespace_symbols.append('▁' * i)
+
+ spm.SentencePieceTrainer.train(
+ sentence_iterator=combined_iterator,
+ model_prefix=model_prefix,
+ vocab_size=VOCAB_SIZE,
+ model_type='unigram',
+ character_coverage=1.0,
+ unk_piece=UNK_TOKEN,
+ pad_piece=PAD_TOKEN,
+ control_symbols=CONTROL_SYMBOLS,
+ user_defined_symbols=whitespace_symbols,
+ # These whitespace options must be false, or else whitespace won't
+ # be respected when encoding.
+ add_dummy_prefix=False,
+ remove_extra_whitespaces=False,
+ split_by_whitespace=False,
+ num_threads=os.cpu_count(),
+ input_sentence_size=total_items,
+ )
+ print("\nTraining finished.")
+ print(f"SentencePiece model saved to: {model_prefix}.model")
+ print(f"SentencePiece vocabulary saved to: {model_prefix}.vocab")
+
+def extract_text_samples(dataset, count, num_samples, text_extractor_func):
+ """Extracts a specified number of text samples from a dataset."""
+ samples = []
+ if dataset is None or count == 0 or num_samples == 0:
+ return samples
+ # Take samples from the beginning, ensure we don't exceed dataset size
+ actual_samples_to_take = min(num_samples, count)
+ # Use the provided function to extract text correctly for this dataset type
+ samples = text_extractor_func(dataset.select(range(actual_samples_to_take)))
+ return [unidecode(s) for s in samples if s and isinstance(s,str)]
+
+def wiki_text_extractor(dataset_slice):
+ """Extracts text from a slice of the Wikipedia dataset."""
+ return [text for text in dataset_slice[TEXT_COLUMN_WIKI] if text]
+
+def dd_text_extractor(dataset_slice):
+ """Extracts text from a slice of the DailyDialog dataset."""
+ texts = []
+ for dialogue in dataset_slice:
+ texts.extend(utterance for utterance in dialogue[UTTERANCE_COLUMN_DD] if utterance)
+ return texts
+
+def bst_text_extractor(dataset_slice):
+ """Extracts text from a slice of the BlendedSkillTalk dataset."""
+ texts = []
+ for session in dataset_slice:
+ if session.get("previous_utterance"):
+ texts.append(session["previous_utterance"])
+ if session.get("free_messages"):
+ texts.extend(session["free_messages"])
+ if session.get("guided_messages"):
+ texts.extend(session["guided_messages"])
+ return [t for t in texts if t and isinstance(t,str)] # Ensure only valid strings
+
+def os_text_extractor(dataset_slice, lang_code=TEXT_COLUMN_OS):
+ """Extracts text from a slice of the OpenSubtitles dataset."""
+ texts = []
+ for item in dataset_slice:
+ text = item['translation'][lang_code]
+ if text and isinstance(text, str):
+ texts.append(text)
+ return texts
+
+def test_tokenizer(model_path, datasets, counts, sample_size=1000):
+ print("\n--- Testing the trained Unigram (SentencePiece) tokenizer on data sample ---")
+ if sample_size <= 0:
+ print("Error: Sample size for testing must be positive.")
+ return
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(model_path)
+ print(f"Successfully loaded SentencePiece model from: {model_path}")
+ print(f"Vocabulary size: {sp.get_piece_size()}")
+
+ # Identify available datasets with content
+ available_datasets = {
+ name: data
+ for name, data in datasets.items()
+ if counts[name][0] > 0
+ }
+ num_available_datasets = len(available_datasets)
+
+ if num_available_datasets == 0:
+ print("Warning: No data available in loaded datasets to sample for testing.")
+ return
+
+ print(f"Found {num_available_datasets} non-empty dataset(s) for testing.")
+ print(f"Attempting to sample up to {sample_size} total items equally from these datasets...")
+
+ # Calculate equal number of samples per available dataset
+ samples_per_dataset = sample_size // num_available_datasets
+ print(f"Targeting approximately {samples_per_dataset} samples per dataset.")
+
+ test_samples = []
+ dataset_extractors = {
+ 'wiki': wiki_text_extractor,
+ 'dd': dd_text_extractor,
+ 'bst': bst_text_extractor,
+ 'os': os_text_extractor,
+ }
+
+ actual_samples_collected = {}
+
+ # Sample equally from each available dataset
+ for name, dataset in available_datasets.items():
+ count = counts[name][0]
+ extractor = dataset_extractors[name]
+
+ # Determine the target number of samples for *this* dataset
+ num_samples_to_target = min(samples_per_dataset, count)
+ if num_samples_to_target <= 0:
+ print(f" Skipping '{name}' (no items requested or available).")
+ actual_samples_collected[name] = 0
+ continue
+
+ print(f" Sampling from '{name}' (target: {num_samples_to_target} items)...")
+
+ items_before = len(test_samples)
+ # For OpenSubtitles, pass the lang_code if the extractor needs it
+ if name == 'os':
+ extracted_items = extract_text_samples(dataset, count, num_samples_to_target, lambda ds_slice: extractor(ds_slice, lang_code=TEXT_COLUMN_OS))
+ else:
+ extracted_items = extract_text_samples(dataset, count, num_samples_to_target, extractor)
+
+ # Limit the extracted items to the target number
+ final_samples_for_dataset = extracted_items[:num_samples_to_target]
+
+ test_samples.extend(final_samples_for_dataset)
+ items_added = len(final_samples_for_dataset)
+ actual_samples_collected[name] = items_added
+ print(f" Added {items_added} items from '{name}'.")
+
+ actual_sample_size = len(test_samples)
+ if actual_sample_size == 0:
+ print("\nCould not gather any samples for testing.")
+ print("--- Test finished ---")
+ return
+
+ print(f"\n--- Starting Test Run ---")
+ print(f"Testing on {actual_sample_size} sampled text items (final count).")
+ print(f"Samples breakdown: {actual_samples_collected}")
+
+ total_chars = 0
+ total_tokens = 0
+ examples_to_show = 5
+ examples_shown = 0
+
+ random.shuffle(test_samples)
+
+ for i, text_sample in enumerate(test_samples):
+ if not text_sample: # Should already be filtered by extractor, but double-check
+ continue
+ try:
+ tokens = sp.encode_as_pieces(text_sample)
+ num_tokens = len(tokens)
+ num_chars = len(text_sample)
+ total_tokens += num_tokens
+ total_chars += num_chars
+
+ if examples_shown < examples_to_show:
+ print(f"\nSample {examples_shown + 1} (Overall index {i}):") # Clarify sample index
+ print(f" Text ({num_chars} chars): {text_sample[:100]}{'...' if len(text_sample)>100 else ''}")
+ print(f" Tokens ({num_tokens}): {tokens}")
+ examples_shown += 1
+ except Exception as e:
+ print(f"\nWarning: Error encoding sample (Overall index {i}) with SentencePiece. Skipping sample. Error: {e}")
+ print(f"Problematic sample text: {text_sample[:100]}...") # Show problematic text
+
+ # --- Test Summary ---
+ if total_tokens > 0 and total_chars > 0 : # Avoid division by zero
+ avg_chars_per_token = total_chars / total_tokens
+ print(f"\n--- Test Summary ---")
+ print(f"Tested on {actual_sample_size} text items.")
+ print(f"Final samples breakdown: {actual_samples_collected}")
+ print(f"Total characters in final sample: {total_chars:,}")
+ print(f"Total tokens in final sample: {total_tokens:,}")
+ print(f"Average characters per token: {avg_chars_per_token:.4f}")
+ elif actual_sample_size > 0:
+ print("\n--- Test Summary ---")
+ print(f"Tested on {actual_sample_size} text items.")
+ print(f"Final samples breakdown: {actual_samples_collected}")
+ print("No valid tokens were generated from the sample data (or samples were empty after unidecode).")
+ else:
+ # This case should be caught earlier, but included for completeness
+ print("\n--- Test Summary ---")
+ print("No samples were collected or processed.")
+
+ print("--- Test finished ---")
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Train and test a Unigram tokenizer on combined datasets.")
+ parser.add_argument("--train", action="store_true", help="Train the tokenizer on datasets")
+ parser.add_argument("--test", action="store_true", help="Test the trained tokenizer")
+ parser.add_argument("--sample_size", type=int, default=1000, help="Number of items to sample for testing")
+ parser.add_argument("--wiki_fraction", type=float, default=0.05, help="Fraction of Wikipedia dataset to use (e.g., 0.05 for 5%)")
+ parser.add_argument("--subtitles_fraction", type=float,
+ default=0.05, help="Fraction of opensubtitles dataset to use (e.g., 0.05 for 5%)")
+ return parser.parse_args()
+
+if __name__ == "__main__":
+ args = parse_args()
+
+ # Load datasets and calculate counts once
+ print("--- Loading and Counting Datasets ---")
+ datasets, counts = load_and_count_datasets(wiki_fraction=args.wiki_fraction, subtitles_fraction=args.subtitles_fraction)
+ print("--- Dataset Loading Finished ---")
+
+ run_train = args.train
+ run_test = args.test
+
+ # If no arguments provided, run both by default
+ if not args.train and not args.test:
+ run_train = True
+ run_test = True
+
+ if run_train:
+ print("\n--- Starting Tokenizer Training ---")
+ train_tokenizer(MODEL_PREFIX, datasets, counts)
+ print("--- Training Finished ---")
+
+ if run_test:
+ # Ensure tokenizer file exists if training didn't just run
+ if not os.path.exists(MODEL_FILE):
+ print(f"\nError: SentencePiece model file {MODEL_FILE} not found. Cannot run test.")
+ print("Please run with --train first or ensure the file exists.")
+ else:
+ print("\n--- Starting Tokenizer Testing ---")
+ test_tokenizer(MODEL_FILE, datasets, counts, sample_size=args.sample_size)
+ print("--- Testing Finished ---")
+
+ print("\nScript finished.")
diff --git a/tokenize_me.py b/tokenize_me.py
new file mode 100644
index 0000000..83be290
--- /dev/null
+++ b/tokenize_me.py
@@ -0,0 +1,28 @@
+import argparse
+import sentencepiece as spm
+
+def get_tokenizer():
+ model_path = "./custom_unigram_tokenizer_65k/unigram.model"
+ print(f"Loading SentencePiece tokenizer from: {model_path}")
+ sp = spm.SentencePieceProcessor()
+ sp.load(model_path)
+ print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}")
+ return sp
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Tokenize a given string using a SentencePiece model.")
+ parser.add_argument("text", type=str, help="The string to tokenize.")
+ args = parser.parse_args()
+ return args
+
+args = parse_args()
+tok = get_tokenizer()
+tokens = tok.encode_as_pieces(args.text)
+print("Tokens:", tokens)
+
+token_ids = tok.encode_as_ids(args.text)
+print("Token IDs:", token_ids)
+
+# Split each token ID into two 8-bit chunks (high byte, low byte)
+byte_pairs = [(tid >> 8, tid & 0xFF) for tid in token_ids]
+print("Token ID Byte Pairs:", byte_pairs)
diff --git a/ui/.gitignore b/ui/.gitignore
new file mode 100644
index 0000000..c1dbe3c
--- /dev/null
+++ b/ui/.gitignore
@@ -0,0 +1,5 @@
+build
+node_modules
+package-lock.json
+output.css
+dist
diff --git a/ui/build_scripts/setup-empty-venv.js b/ui/build_scripts/setup-empty-venv.js
new file mode 100644
index 0000000..0691a51
--- /dev/null
+++ b/ui/build_scripts/setup-empty-venv.js
@@ -0,0 +1,25 @@
+const { execSync } = require('child_process');
+const path = require('path');
+const fs = require('fs');
+
+const projectRoot = path.join(__dirname, '..', '..');
+const venvPath = path.join(projectRoot, 'venv_clean');
+const dllPath = path.join(projectRoot, 'dll_empty');
+
+console.log('Creating empty virtual environment and dll directory...');
+
+// Create empty dll directory
+if (!fs.existsSync(dllPath)) {
+ fs.mkdirSync(dllPath, { recursive: true });
+ console.log('Created empty dll directory');
+}
+
+try {
+ console.log('Creating new venv...');
+ execSync(`python -m venv "${venvPath}"`, { stdio: 'inherit' });
+ console.log('Empty venv created successfully!');
+} catch (error) {
+ console.error('Failed to create venv:', error);
+ process.exit(1);
+}
+
diff --git a/ui/config-schema.js b/ui/config-schema.js
new file mode 100644
index 0000000..fb90f3f
--- /dev/null
+++ b/ui/config-schema.js
@@ -0,0 +1,57 @@
+// Shared configuration schema with types and defaults
+const CONFIG_SCHEMA = {
+ // String fields
+ compute_type: { type: 'select', default: 'float16' },
+ language: { type: 'select', default: 'english' },
+ model: { type: 'select', default: 'turbo' },
+ microphone: { type: 'number', default: 0 },
+ user_prompt: { type: 'text', default: 'Use proper punctuation and grammar. Prefer spelled out numbers like one, eleven, twenty, etc. Mm.' },
+ keybind: { type: 'text', default: 'ctrl+alt+x' },
+ button_hand: { type: 'select', default: 'right' },
+ button_type: { type: 'select', default: 'b' },
+
+ // Number fields
+ gpu_idx: { type: 'number', default: 0 },
+ max_speech_duration_s: { type: 'number', default: 10 },
+ min_speech_duration_ms: { type: 'number', default: 250 },
+ min_silence_duration_ms: { type: 'number', default: 100 },
+ reset_after_silence_s: { type: 'number', default: 15 },
+ transcription_loop_delay_ms: { type: 'number', default: 100 },
+ block_width: { type: 'number', default: 2 },
+ num_blocks: { type: 'number', default: 40 },
+ rows: { type: 'number', default: 10 },
+ cols: { type: 'number', default: 24 },
+ beam_size: { type: 'number', default: 5 },
+ best_of: { type: 'number', default: 5 },
+ volume: { type: 'number', default: 30 },
+
+ // Boolean fields (stored as 1/0)
+ enable_debug_mode: { type: 'boolean', default: 0 },
+ enable_previews: { type: 'boolean', default: 1 },
+ save_audio: { type: 'boolean', default: 0 },
+ enable_segment_logging: { type: 'boolean', default: 0 },
+ use_cpu: { type: 'boolean', default: 0 },
+ enable_lowercase_filter: { type: 'boolean', default: 0 },
+ enable_uppercase_filter: { type: 'boolean', default: 0 },
+ enable_profanity_filter: { type: 'boolean', default: 0 },
+ remove_trailing_period: { type: 'boolean', default: 0 },
+ reset_on_toggle: { type: 'boolean', default: 0 },
+ use_builtin: { type: 'boolean', default: 1 }
+};
+
+// Helper to extract just the default values
+function getDefaultConfig() {
+ const defaults = {};
+ for (const [key, schema] of Object.entries(CONFIG_SCHEMA)) {
+ defaults[key] = schema.default;
+ }
+ return defaults;
+}
+
+// Export for both CommonJS (main process) and ES modules (renderer)
+if (typeof module !== 'undefined' && module.exports) {
+ module.exports = { CONFIG_SCHEMA, getDefaultConfig };
+} else {
+ window.CONFIG_SCHEMA = CONFIG_SCHEMA;
+ window.getDefaultConfig = getDefaultConfig;
+} \ No newline at end of file
diff --git a/ui/index.html b/ui/index.html
new file mode 100644
index 0000000..70eaa68
--- /dev/null
+++ b/ui/index.html
@@ -0,0 +1,349 @@
+<!DOCTYPE html>
+<html>
+<head>
+ <meta charset="UTF-8">
+ <meta http-equiv="Content-Security-Policy" content="default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'">
+ <title>TaSTT</title>
+ <link rel="stylesheet" href="output.css">
+</head>
+<body class="bg-gray-100">
+ <div class="container-fluid px-6 py-6 h-screen flex flex-col">
+ <div class="flex flex-1 gap-6 overflow-hidden">
+ <!-- Left Panel: Configuration Form -->
+ <div class="max-w-96 relative flex flex-col overflow-hidden rounded-lg">
+ <!-- Loading Overlay -->
+ <div id="loading-overlay" class="absolute inset-0 bg-white bg-opacity-75 backdrop-blur-sm z-50 hidden flex items-center justify-center">
+ <div class="text-center p-6">
+ <div class="animate-spin rounded-full h-12 w-12 border-b-2 border-blue-600 mx-auto mb-4"></div>
+ <p class="text-gray-700 font-medium"></p>
+ </div>
+ </div>
+
+ <!-- Scrollable form container -->
+ <div class="overflow-y-auto flex-1">
+ <form id="config-form" class="space-y-6 pr-3">
+ <!-- Basic settings (Always Visible) -->
+ <section class="config-section">
+ <div class="grid grid-cols-2 gap-4">
+ <div>
+ <label for="model" class="form-label">Model</label>
+ <select id="model" class="form-input">
+ <option value="tiny">tiny</option>
+ <option value="base">base</option>
+ <option value="small">small</option>
+ <option value="medium">medium</option>
+ <option value="large">large</option>
+ <option value="turbo">turbo</option>
+ </select>
+ </div>
+ <div>
+ <label for="language" class="form-label">Language</label>
+ <select id="language" class="form-input">
+ <option value="english">English</option>
+ <option value="spanish">Spanish</option>
+ <option value="french">French</option>
+ <option value="german">German</option>
+ <option value="italian">Italian</option>
+ <option value="portuguese">Portuguese</option>
+ <option value="russian">Russian</option>
+ <option value="chinese">Chinese</option>
+ <option value="japanese">Japanese</option>
+ <option value="korean">Korean</option>
+ </select>
+ </div>
+ <div class="col-span-2">
+ <label for="microphone" class="form-label">Microphone</label>
+ <div class="flex gap-2">
+ <select id="microphone" class="form-input flex-1">
+ <option value="">Loading microphones...</option>
+ </select>
+ <button type="button" id="refresh-microphones" class="btn btn-gray px-3 py-2 flex items-center" title="Refresh microphone list">
+ <svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
+ <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"/>
+ </svg>
+ </button>
+ </div>
+ </div>
+ <div>
+ <label for="button_hand" class="form-label">
+ VR Hand
+ </label>
+ <select id="button_hand" class="form-input">
+ <option value="left">Left</option>
+ <option value="right">Right</option>
+ </select>
+ </div>
+ <div>
+ <label for="button_type" class="form-label">
+ VR Button
+ </label>
+ <select id="button_type" class="form-input">
+ <option value="a">A</option>
+ <option value="b">B</option>
+ <option value="thumbstick">Thumbstick</option>
+ </select>
+ </div>
+ <div class="col-span-2">
+ <label for="keybind" class="form-label">
+ Keyboard Binding
+ </label>
+ <input type="text" id="keybind" value="f24" class="form-input" placeholder="f24">
+ </div>
+ </div>
+ </section>
+
+ <!-- Advanced settings toggle -->
+ <button type="button" id="toggle-advanced" class="flex items-center gap-2 text-gray-600 hover:text-gray-800 font-medium">
+ <svg id="chevron" class="w-5 h-5 transform transition-transform duration-200" fill="none" stroke="currentColor" viewBox="0 0 24 24">
+ <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 5l7 7-7 7"/>
+ </svg>
+ Advanced Settings
+ </button>
+
+ <!-- Advanced settings (initially hidden) -->
+ <div id="advanced-settings" class="hidden space-y-6">
+ <!-- Compute Settings -->
+ <section class="config-section">
+ <h2 class="section-title">Compute Settings</h2>
+ <div class="grid grid-cols-2 gap-4">
+ <div>
+ <label for="compute_type" class="form-label">Compute Type</label>
+ <select id="compute_type" class="form-input">
+ <option value="int8">int8</option>
+ <option value="float16">float16</option>
+ <option value="float32">float32</option>
+ </select>
+ </div>
+ <div>
+ <label for="gpu_idx" class="form-label">GPU Index</label>
+ <input type="number" id="gpu_idx" min="0" value="0" class="form-input">
+ </div>
+ <div class="col-span-2">
+ <label for="use_cpu" class="checkbox-label">
+ <input type="checkbox" id="use_cpu" class="mr-2">
+ <span class="checkbox-text">Use CPU</span>
+ </label>
+ </div>
+ </div>
+ </section>
+
+ <!-- Audio Settings -->
+ <section class="config-section">
+ <h2 class="section-title">Voice Activity Detection</h2>
+ <div class="grid grid-cols-2 gap-4">
+ <div>
+ <label for="max_speech_duration_s" class="form-label">Max Speech Duration (seconds)</label>
+ <input type="number" id="max_speech_duration_s" min="1" value="10" class="form-input">
+ </div>
+ <div>
+ <label for="min_speech_duration_ms" class="form-label">Min Speech Duration (ms)</label>
+ <input type="number" id="min_speech_duration_ms" min="0" value="100" class="form-input">
+ </div>
+ <div>
+ <label for="min_silence_duration_ms" class="form-label">Min Silence Duration (ms)</label>
+ <input type="number" id="min_silence_duration_ms" min="0" value="250" class="form-input">
+ </div>
+ <div>
+ <label for="reset_after_silence_s" class="form-label">Reset After Silence (seconds)</label>
+ <input type="number" id="reset_after_silence_s" min="1" value="15" class="form-input">
+ </div>
+ </div>
+ </section>
+
+ <!-- Transcription Settings -->
+ <section class="config-section">
+ <h2 class="section-title">Transcription Settings</h2>
+ <div>
+ <label for="user_prompt" class="form-label">
+ Prompt
+ <span class="text-gray-500 text-xs block mt-1"
+ title="Whisper is given this prompt before transcribing. It helps guide the transcription style. For example, you could improve the spelling of your friends' names with: 'My friends' names are Saoirse, Azariah, and Caoimhe.'">
+ (Hover for details)
+ </span>
+ </label>
+ <textarea id="user_prompt"
+ class="form-input h-20 resize-none"
+ placeholder="My friends' names are Saoirse, Azariah, and Caoimhe."></textarea>
+ </div>
+ <div class="grid grid-cols-2 gap-4 mt-4">
+ <div>
+ <label for="beam_size" class="form-label">
+ Beam size
+ <span class="text-gray-500 text-xs block mt-1"
+ title="Number of beams for beam search. Higher values may improve accuracy but increase compute time.">
+ (Search width)
+ </span>
+ </label>
+ <input type="number" id="beam_size" min="1" max="20" value="5" class="form-input">
+ </div>
+ <div>
+ <label for="best_of" class="form-label">
+ Best of
+ <span class="text-gray-500 text-xs block mt-1"
+ title="Number of candidates to generate when sampling. The best one will be selected.">
+ (Sampling candidates)
+ </span>
+ </label>
+ <input type="number" id="best_of" min="1" max="20" value="5" class="form-input">
+ </div>
+ </div>
+ </section>
+
+ <!-- Performance Settings -->
+ <section class="config-section">
+ <h2 class="section-title">Performance Settings</h2>
+ <div>
+ <label for="transcription_loop_delay_ms" class="form-label">Transcription Loop Delay (ms)</label>
+ <input type="number" id="transcription_loop_delay_ms" min="0" value="100" class="form-input">
+ </div>
+ </section>
+
+ <!-- Debug/Preview Settings -->
+ <section class="config-section">
+ <h2 class="section-title">Debug/Preview Settings</h2>
+ <div class="space-y-3">
+ <label for="enable_debug_mode" class="checkbox-label">
+ <input type="checkbox" id="enable_debug_mode" class="mr-2">
+ <span class="checkbox-text">Enable Debug Mode</span>
+ </label>
+ <label for="enable_previews" class="checkbox-label">
+ <input type="checkbox" id="enable_previews" checked class="mr-2">
+ <span class="checkbox-text">Enable Previews</span>
+ </label>
+ <label for="save_audio" class="checkbox-label">
+ <input type="checkbox" id="save_audio" class="mr-2">
+ <span class="checkbox-text">Save Audio Segments</span>
+ </label>
+ <label for="enable_segment_logging" class="checkbox-label">
+ <input type="checkbox" id="enable_segment_logging" class="mr-2">
+ <span class="checkbox-text">Log Segment Metadata</span>
+ </label>
+ </div>
+ </section>
+
+ <!-- Text Filters -->
+ <section class="config-section">
+ <h2 class="section-title">Text Filters</h2>
+ <div class="space-y-3">
+ <label for="enable_lowercase_filter" class="checkbox-label">
+ <input type="checkbox" id="enable_lowercase_filter" class="mr-2">
+ <span class="checkbox-text">Convert to lowercase</span>
+ </label>
+ <label for="enable_uppercase_filter" class="checkbox-label">
+ <input type="checkbox" id="enable_uppercase_filter" class="mr-2">
+ <span class="checkbox-text">Convert to uppercase</span>
+ </label>
+ <label for="enable_profanity_filter" class="checkbox-label">
+ <input type="checkbox" id="enable_profanity_filter" class="mr-2">
+ <span class="checkbox-text">Filter profanity</span>
+ </label>
+ <label for="remove_trailing_period" class="checkbox-label">
+ <input type="checkbox" id="remove_trailing_period" class="mr-2">
+ <span class="checkbox-text">Remove trailing period</span>
+ </label>
+ </div>
+ </section>
+
+ <!-- Input Settings -->
+ <section class="config-section">
+ <h2 class="section-title">Input Settings</h2>
+ <div class="space-y-4">
+ <label for="reset_on_toggle" class="checkbox-label">
+ <input type="checkbox" id="reset_on_toggle" class="mr-2">
+ <span class="checkbox-text">Reset transcript on toggle</span>
+ </label>
+ <div>
+ <label for="volume" class="form-label">
+ Local Beep Volume
+ <span id="volume-display" class="text-gray-500 text-sm ml-2">30%</span>
+ </label>
+ <input type="range" id="volume" min="0" max="100" step="10" value="30" class="form-input w-full">
+ </div>
+ </div>
+ </section>
+
+ <!-- Display Settings -->
+ <section class="config-section">
+ <h2 class="section-title">Custom Chatbox Settings</h2>
+ <div class="mb-4">
+ <label for="use_builtin" class="checkbox-label">
+ <input type="checkbox" id="use_builtin" class="mr-2">
+ <span class="checkbox-text">Use built-in VRChat chatbox</span>
+ </label>
+ </div>
+ <div class="grid grid-cols-2 gap-4">
+ <div>
+ <label for="block_width" class="form-label">Block Width</label>
+ <input type="number" id="block_width" min="1" value="2" class="form-input">
+ </div>
+ <div>
+ <label for="num_blocks" class="form-label">Number of Blocks</label>
+ <input type="number" id="num_blocks" min="1" value="40" class="form-input">
+ </div>
+ <div>
+ <label for="rows" class="form-label">Rows</label>
+ <input type="number" id="rows" min="1" value="10" class="form-input">
+ </div>
+ <div>
+ <label for="cols" class="form-label">Columns</label>
+ <input type="number" id="cols" min="1" value="24" class="form-input">
+ </div>
+ </div>
+ </section>
+
+ <!-- Configuration Settings -->
+ <section class="config-section">
+ <h2 class="section-title">Configuration</h2>
+ <div>
+ <button type="button" id="reset-config" class="btn btn-blue w-full">
+ Reset Config to Defaults
+ </button>
+ </div>
+ </section>
+
+ <!-- Virtual Environment Settings -->
+ <section class="config-section">
+ <h2 class="section-title">Virtual Environment</h2>
+ <div class="flex space-x-3">
+ <button type="button" id="setup-venv" class="btn btn-blue flex-1">
+ Setup venv
+ </button>
+ <button type="button" id="reset-venv" class="btn btn-blue flex-1">
+ Reset venv
+ </button>
+ </div>
+ </section>
+ </div>
+
+ <!-- Action Buttons -->
+ <div class="pb-6">
+ <div class="flex space-x-3">
+ <button type="button" id="start-process" class="btn btn-green flex-1">
+ Start
+ </button>
+ <button type="button" id="stop-process" class="btn btn-red flex-1">
+ Stop
+ </button>
+ </div>
+ </div>
+ </form>
+
+ <!-- Status Message -->
+ <div id="status-message" class="mt-6 p-4 rounded-md hidden"></div>
+ </div>
+ </div>
+
+ <!-- Right Panel: Python Console -->
+ <div class="flex-1 flex flex-col bg-gray-900 rounded-lg overflow-hidden">
+ <div id="python-console" class="flex-1 overflow-y-auto p-4 font-mono text-sm">
+ <div id="console-content" class="text-gray-300 whitespace-pre-wrap"></div>
+ </div>
+ </div>
+ </div>
+ </div>
+
+ <script src="config-schema.js"></script>
+ <script src="renderer.js"></script>
+</body>
+</html>
+
diff --git a/ui/index.js b/ui/index.js
new file mode 100644
index 0000000..63c633a
--- /dev/null
+++ b/ui/index.js
@@ -0,0 +1,616 @@
+const { app, BrowserWindow, ipcMain } = require('electron');
+const path = require('node:path');
+const fs = require('node:fs').promises;
+const yaml = require('js-yaml');
+const { spawn } = require('child_process');
+const https = require('https');
+const { CONFIG_SCHEMA, getDefaultConfig } = require('./config-schema.js');
+
+// Detect if we're running in development or production
+const isDev = !app.isPackaged;
+const APP_ROOT = isDev
+ ? path.join(__dirname, '..') // Development: go up from ui/ to project root
+ : process.resourcesPath; // Production: use Electron's resource path
+
+const CONFIG_PATH = path.join(APP_ROOT, 'config.yaml');
+
+let mainWindow;
+let runningProcess = null; // Track the running Python process
+
+// Required DLL files for CUDA/cuDNN support
+const REQUIRED_DLLS = [
+ 'cublas64_12.dll',
+ 'cublasLt64_12.dll',
+ 'cudnn64_9.dll',
+ 'cudnn_adv64_9.dll',
+ 'cudnn_cnn64_9.dll',
+ 'cudnn_engines_precompiled64_9.dll',
+ 'cudnn_engines_runtime_compiled64_9.dll',
+ 'cudnn_graph64_9.dll',
+ 'cudnn_heuristic64_9.dll',
+ 'cudnn_ops64_9.dll'
+];
+
+// Helper function to get the correct Python executable from venv
+function getVenvPython() {
+ const venvPath = path.join(APP_ROOT, 'venv');
+ const pythonPath = path.join(venvPath, 'Scripts', 'python.exe');
+ return pythonPath;
+}
+
+// Helper function to send Python output to renderer
+function sendPythonOutput(message, type = 'stdout') {
+ if (mainWindow && !mainWindow.isDestroyed()) {
+ mainWindow.webContents.send('python-output', { message, type });
+ }
+}
+
+// Helper function to create environment with DLL path
+function createPythonEnvironment() {
+ const dllPath = path.join(APP_ROOT, 'dll');
+ const binPath = path.join(APP_ROOT, 'bin');
+ const env = { ...process.env };
+ env.PATH = `${dllPath};${binPath};${env.PATH}`;
+ env.HF_HUB_DISABLE_SYMLINKS_WARNING = '1';
+ return env;
+}
+
+// Helper function to download a file from URL with progress
+function downloadFile(url, outputPath) {
+ return new Promise((resolve, reject) => {
+ const file = require('fs').createWriteStream(outputPath);
+ const fileName = path.basename(outputPath);
+
+ const request = https.get(url, (response) => {
+ if (response.statusCode === 200) {
+ const totalSize = parseInt(response.headers['content-length'], 10);
+ let downloadedSize = 0;
+ let lastProgressTime = Date.now();
+
+ response.on('data', (chunk) => {
+ downloadedSize += chunk.length;
+
+ // Log progress every 5 seconds
+ const now = Date.now();
+ if (totalSize && (now - lastProgressTime >= 5000)) {
+ const progress = Math.round((downloadedSize / totalSize) * 100);
+ const mb = (downloadedSize / 1024 / 1024).toFixed(1);
+ const totalMb = (totalSize / 1024 / 1024).toFixed(1);
+ sendPythonOutput(`Downloading ${fileName}: ${mb}/${totalMb} MB (${progress}%)`, 'info');
+ lastProgressTime = now;
+ }
+ });
+
+ response.pipe(file);
+
+ file.on('finish', () => {
+ file.close();
+ resolve();
+ });
+
+ file.on('error', (err) => {
+ fs.unlink(outputPath).catch(() => {}); // Clean up on error
+ reject(err);
+ });
+ } else {
+ file.close();
+ fs.unlink(outputPath).catch(() => {}); // Clean up on error
+ reject(new Error(`Failed to download: HTTP ${response.statusCode}`));
+ }
+ });
+
+ request.on('error', (err) => {
+ file.close();
+ fs.unlink(outputPath).catch(() => {}); // Clean up on error
+ reject(err);
+ });
+ });
+}
+
+function shouldFilterMessage(message) {
+ // Filter out pydub ffmpeg/avconv warning. It does not actually matter.
+ if (message.includes("Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work")) {
+ return true;
+ }
+ return false;
+}
+
+// Helper function to setup process event handlers
+function setupProcessHandlers(process) {
+ process.stdout.on('data', (data) => {
+ const text = data.toString();
+ sendPythonOutput(text.trimEnd(), 'stdout');
+ });
+
+ process.stderr.on('data', (data) => {
+ const text = data.toString();
+ if (!shouldFilterMessage(text)) {
+ sendPythonOutput(text.trimEnd(), 'stderr');
+ }
+ });
+
+ process.on('error', (error) => {
+ sendPythonOutput(`Process error: ${error.message}`, 'stderr');
+ runningProcess = null;
+ if (mainWindow && !mainWindow.isDestroyed()) {
+ mainWindow.webContents.send('process-stopped');
+ }
+ });
+
+ process.on('close', (code) => {
+ sendPythonOutput(`Process exited with code ${code}`, 'info');
+ runningProcess = null;
+ if (mainWindow && !mainWindow.isDestroyed()) {
+ mainWindow.webContents.send('process-stopped');
+ }
+ });
+}
+
+// Helper function to execute Python commands using venv
+function executePythonCommand(args, options = {}) {
+ return new Promise((resolve, reject) => {
+ const pythonPath = getVenvPython();
+ const commandStr = `${path.basename(pythonPath)} ${args.join(' ')}`;
+ sendPythonOutput(`> ${commandStr}`, 'info');
+
+ const spawnOptions = {
+ ...options,
+ env: createPythonEnvironment()
+ };
+
+ const pythonProcess = spawn(pythonPath, args, spawnOptions);
+
+ let stdout = '';
+ let stderr = '';
+
+ pythonProcess.stdout.on('data', (data) => {
+ const text = data.toString();
+ stdout += text;
+ sendPythonOutput(text.trimEnd(), 'stdout');
+ });
+
+ pythonProcess.stderr.on('data', (data) => {
+ const text = data.toString();
+ stderr += text;
+ // Filter out specific warning messages
+ if (!shouldFilterMessage(text)) {
+ sendPythonOutput(text.trimEnd(), 'stderr');
+ }
+ });
+
+ pythonProcess.on('error', (error) => {
+ sendPythonOutput(`Failed to start Python process: ${error.message}`, 'stderr');
+ reject({ error: error.message, stdout, stderr });
+ });
+
+ pythonProcess.on('close', (code) => {
+ if (code !== 0) {
+ sendPythonOutput(`Process exited with code ${code}`, 'stderr');
+ reject({ code, stdout, stderr });
+ } else {
+ resolve({ stdout, stderr });
+ }
+ });
+ });
+}
+
+function createWindow () {
+ mainWindow = new BrowserWindow({
+ width: 1000,
+ height: 800,
+ icon: path.join(APP_ROOT, 'Images', 'favicon.ico'),
+ webPreferences: {
+ preload: path.join(__dirname, 'preload.js'),
+ contextIsolation: true,
+ nodeIntegration: false
+ }
+ });
+
+ mainWindow.loadFile('index.html');
+}
+
+// Replace the DEFAULT_CONFIG constant with:
+const DEFAULT_CONFIG = getDefaultConfig();
+
+// IPC handlers
+ipcMain.handle('load-config', async () => {
+ try {
+ const fileContent = await fs.readFile(CONFIG_PATH, 'utf8');
+ return yaml.load(fileContent);
+ } catch (error) {
+ if (error.code === 'ENOENT') {
+ // Config file doesn't exist, create it with defaults
+ console.error('Config file not found, creating with defaults...');
+ try {
+ const yamlContent = yaml.dump(DEFAULT_CONFIG, { lineWidth: -1 });
+ await fs.writeFile(CONFIG_PATH, yamlContent, 'utf8');
+ console.error('Created config.yaml with default values');
+ return DEFAULT_CONFIG;
+ } catch (writeError) {
+ console.error('Error creating default config:', writeError);
+ // Return defaults even if we can't write the file
+ return DEFAULT_CONFIG;
+ }
+ }
+ console.error('Error loading config:', error);
+ throw error;
+ }
+});
+
+ipcMain.handle('save-config', async (event, config) => {
+ try {
+ const yamlContent = yaml.dump(config, { lineWidth: -1 });
+ await fs.writeFile(CONFIG_PATH, yamlContent, 'utf8');
+ return { success: true };
+ } catch (error) {
+ console.error('Error saving config:', error);
+ throw error;
+ }
+});
+
+ipcMain.handle('reset-config', async () => {
+ try {
+ // Check if the file exists first
+ try {
+ await fs.access(CONFIG_PATH);
+ // File exists, delete it
+ await fs.unlink(CONFIG_PATH);
+ console.error('Config file deleted successfully');
+ return { success: true, message: 'Configuration reset to defaults' };
+ } catch (error) {
+ if (error.code === 'ENOENT') {
+ // Config file doesn't exist, that's fine
+ return { success: true, message: 'Configuration already at defaults' };
+ }
+ throw error;
+ }
+ } catch (error) {
+ console.error('Error resetting config:', error);
+ throw new Error(`Failed to reset configuration: ${error.message}`);
+ }
+});
+
+ipcMain.handle('deleteVenvIndicatorFile', async () => {
+ const venvMarkerPath = path.join(APP_ROOT, '.venv_is_set_up');
+ try {
+ await fs.unlink(venvMarkerPath);
+ return { success: true, message: '.venv_is_set_up deleted successfully.' };
+ } catch (error) {
+ if (error.code === 'ENOENT') {
+ return { success: true, message: '.venv_is_set_up not found.' };
+ }
+ console.error('Error deleting .venv_is_set_up file:', error);
+ sendPythonOutput(`Error deleting .venv_is_set_up: ${error.message}`, 'stderr');
+ throw error;
+ }
+});
+
+// Generic function to ensure required files are present
+async function ensureRequiredFiles(config) {
+ const {
+ directoryName,
+ requiredFiles,
+ downloadBaseUrl,
+ resourceType
+ } = config;
+
+ const targetPath = path.join(APP_ROOT, directoryName);
+
+ try {
+ // Check if target directory exists, create it if not
+ try {
+ await fs.access(targetPath);
+ sendPythonOutput(`${resourceType} directory exists`, 'info');
+ } catch (error) {
+ if (error.code === 'ENOENT') {
+ sendPythonOutput(`Creating ${resourceType} directory...`, 'info');
+ await fs.mkdir(targetPath, { recursive: true });
+ sendPythonOutput(`${resourceType} directory created`, 'info');
+ } else {
+ throw error;
+ }
+ }
+
+ // Check each required file
+ const missingFiles = [];
+ for (const fileName of requiredFiles) {
+ const filePath = path.join(targetPath, fileName);
+ try {
+ await fs.access(filePath);
+ sendPythonOutput(`✓ ${fileName} exists`, 'info');
+ } catch (error) {
+ if (error.code === 'ENOENT') {
+ missingFiles.push(fileName);
+ sendPythonOutput(`✗ ${fileName} missing`, 'info');
+ } else {
+ throw error;
+ }
+ }
+ }
+
+ // Download missing files
+ if (missingFiles.length > 0) {
+ sendPythonOutput(`Downloading ${missingFiles.length} missing ${resourceType} file${missingFiles.length > 1 ? 's' : ''}...`, 'info');
+
+ for (const fileName of missingFiles) {
+ const filePath = path.join(targetPath, fileName);
+ const downloadUrl = `${downloadBaseUrl}/${fileName}`;
+
+ try {
+ sendPythonOutput(`Downloading ${fileName}...`, 'info');
+ await downloadFile(downloadUrl, filePath);
+ sendPythonOutput(`✓ Downloaded ${fileName}`, 'info');
+ } catch (downloadError) {
+ sendPythonOutput(`✗ Failed to download ${fileName}: ${downloadError.message}`, 'stderr');
+ throw new Error(`Failed to download ${fileName}: ${downloadError.message}`);
+ }
+ }
+
+ sendPythonOutput(`All missing ${resourceType} files downloaded successfully`, 'info');
+ } else {
+ sendPythonOutput(`All required ${resourceType} files are present`, 'info');
+ }
+
+ return {
+ success: true,
+ message: `${resourceType} setup complete. ${missingFiles.length} file${missingFiles.length > 1 ? 's' : ''} downloaded.`,
+ downloadedFiles: missingFiles
+ };
+ } catch (error) {
+ console.error(`Error setting up ${resourceType} files:`, error);
+ throw new Error(`${resourceType} setup failed: ${error.message}`);
+ }
+}
+
+// Update the install-requirements handler
+ipcMain.handle('install-requirements', async () => {
+ const requirementsPath = path.join(APP_ROOT, 'app', 'requirements.txt');
+ const venvMarkerPath = path.join(APP_ROOT, '.venv_is_set_up');
+
+ try {
+ // Check if venv is already set up
+ try {
+ await fs.access(venvMarkerPath);
+ return { success: true, message: 'Virtual environment already set up' };
+ } catch (error) {
+ // Marker doesn't exist, proceed with setup
+ }
+
+ // Check if requirements.txt exists
+ await fs.access(requirementsPath);
+
+ await executePythonCommand(['-m', 'pip', 'install', '-r', requirementsPath]);
+
+ await ensureRequiredFiles({
+ directoryName: 'dll',
+ requiredFiles: REQUIRED_DLLS,
+ downloadBaseUrl: 'https://yummers.dev/tastt/dll',
+ resourceType: 'DLL'
+ });
+
+ await fs.mkdir(path.join(APP_ROOT, 'Models'), { recursive: true });
+
+ await fs.writeFile(venvMarkerPath, new Date().toISOString(), 'utf8');
+ sendPythonOutput('Created .venv_is_set_up marker file', 'info');
+
+ return { success: true, message: 'Requirements and dependencies installed successfully' };
+ } catch (error) {
+ console.error('Error installing requirements:', error);
+ if (error.code === 'ENOENT') {
+ throw new Error('requirements.txt not found');
+ }
+ throw new Error(`Installation failed: ${error.stderr || error.error || 'Unknown error'}`);
+ }
+});
+
+ipcMain.handle('get-microphones', async () => {
+ const scriptPath = path.join(APP_ROOT, 'app', 'list_microphones.py');
+
+ try {
+ const result = await executePythonCommand([scriptPath]);
+ const microphones = JSON.parse(result.stdout.trim());
+ return microphones;
+ } catch (error) {
+ console.error('Failed to get microphones:', error);
+ throw new Error(`Failed to get microphones: ${error.stderr || error.error || 'Unknown error'}`);
+ }
+});
+
+// Helper function to safely delete directory contents
+async function clearDirectory(dirPath, dirName) {
+ try {
+ await fs.access(dirPath);
+ sendPythonOutput(`Clearing ${dirName} directory...`, 'info');
+
+ const files = await fs.readdir(dirPath);
+ let deletedCount = 0;
+
+ for (const file of files) {
+ const filePath = path.join(dirPath, file);
+
+ try {
+ await fs.rm(filePath, { recursive: true, force: true });
+ sendPythonOutput(`✗ Deleted file ${file}`, 'info');
+
+ deletedCount++;
+ } catch (deleteError) {
+ sendPythonOutput(`Warning: Could not delete ${file}: ${deleteError.message}`, 'stderr');
+ // Continue with other files even if one fails
+ }
+ }
+
+ sendPythonOutput(`${dirName} directory cleared`, 'info');
+ return deletedCount;
+ } catch (error) {
+ if (error.code === 'ENOENT') {
+ sendPythonOutput(`${dirName} directory doesn't exist, skipping`, 'info');
+ return 0;
+ } else {
+ sendPythonOutput(`Error clearing ${dirName} directory: ${error.message}`, 'stderr');
+ throw error;
+ }
+ }
+}
+
+ipcMain.handle('reset-venv', async () => {
+ const venvMarkerPath = path.join(APP_ROOT, '.venv_is_set_up');
+
+ try {
+ sendPythonOutput('Starting virtual environment reset...', 'info');
+
+ // Delete the venv marker file first
+ try {
+ await fs.unlink(venvMarkerPath);
+ sendPythonOutput('Deleted .venv_is_set_up marker file', 'info');
+ } catch (error) {
+ if (error.code !== 'ENOENT') {
+ sendPythonOutput(`Warning: Could not delete marker file: ${error.message}`, 'stderr');
+ }
+ }
+
+ // Get list of installed packages
+ sendPythonOutput('Getting list of installed packages...', 'info');
+ const freezeResult = await executePythonCommand(['-m', 'pip', 'freeze']);
+ const installedPackages = freezeResult.stdout.trim();
+
+ let uninstalledPackages = [];
+
+ if (!installedPackages) {
+ sendPythonOutput('No packages found to uninstall', 'info');
+ } else {
+ // Parse package names and filter out core packages
+ const packageLines = installedPackages.split('\n').filter(line => line.trim());
+ const packageNames = packageLines
+ .map(line => line.split('==')[0].trim())
+ .filter(name => name && !name.startsWith('#'));
+
+ const corePackages = ['pip', 'setuptools', 'wheel'];
+ const packagesToUninstall = packageNames.filter(name => !corePackages.includes(name.toLowerCase()));
+
+ if (packagesToUninstall.length === 0) {
+ sendPythonOutput('Only core packages found, nothing to uninstall', 'info');
+ } else {
+ sendPythonOutput(`Uninstalling ${packagesToUninstall.length} packages...`, 'info');
+
+ const uninstallArgs = ['-m', 'pip', 'uninstall', '-y', ...packagesToUninstall];
+ await executePythonCommand(uninstallArgs);
+ uninstalledPackages = packagesToUninstall;
+ }
+ }
+
+ // Clear downloaded files
+ sendPythonOutput('Clearing downloaded files...', 'info');
+
+ const dllPath = path.join(APP_ROOT, 'dll');
+ const modelsPath = path.join(APP_ROOT, 'Models');
+ const binPath = path.join(APP_ROOT, 'bin');
+
+ const deletedDlls = await clearDirectory(dllPath, 'DLL');
+ const deletedModels = await clearDirectory(modelsPath, 'Models');
+ const deletedBins = await clearDirectory(binPath, 'Binary');
+
+ const totalDeletedFiles = deletedDlls + deletedModels + deletedBins;
+
+ sendPythonOutput('Virtual environment reset successfully!', 'info');
+
+ return {
+ success: true,
+ message: `Virtual environment reset complete. Uninstalled ${uninstalledPackages.length} packages and deleted ${totalDeletedFiles} downloaded files.`,
+ uninstalledPackages,
+ deletedFiles: {
+ dlls: deletedDlls,
+ models: deletedModels,
+ binaries: deletedBins,
+ total: totalDeletedFiles
+ }
+ };
+ } catch (error) {
+ console.error('Error resetting virtual environment:', error);
+ throw new Error(`Virtual environment reset failed: ${error.message}`);
+ }
+});
+
+// Add handlers for starting and stopping the process
+ipcMain.handle('start-process', async () => {
+ if (runningProcess) {
+ throw new Error('Process is already running');
+ }
+
+ const scriptPath = path.join(APP_ROOT, 'app', 'hi.py');
+ const args = [scriptPath, '--config', CONFIG_PATH];
+
+ try {
+ const pythonPath = getVenvPython();
+ sendPythonOutput(`Starting process: ${path.basename(pythonPath)} ${args.join(' ')}`, 'info');
+
+ runningProcess = spawn(pythonPath, args, { env: createPythonEnvironment() });
+ setupProcessHandlers(runningProcess);
+
+ return { success: true };
+ } catch (error) {
+ runningProcess = null;
+ throw error;
+ }
+});
+
+ipcMain.handle('stop-process', async () => {
+ if (!runningProcess) {
+ sendPythonOutput('No process to stop', 'info');
+ return { success: true, forcefullyKilled: false };
+ }
+
+ return new Promise((resolve) => {
+ let forcefullyKilled = false;
+
+ // Set up a timeout to force kill after 10 seconds
+ const killTimeout = setTimeout(() => {
+ if (runningProcess) {
+ sendPythonOutput('Process did not stop gracefully, forcing termination...', 'stderr');
+ forcefullyKilled = true;
+ runningProcess.kill('SIGKILL');
+ }
+ }, 10000);
+
+ // Listen for the process to exit
+ runningProcess.once('exit', (code, signal) => {
+ clearTimeout(killTimeout);
+ runningProcess = null;
+
+ if (forcefullyKilled) {
+ sendPythonOutput('Process forcefully terminated', 'info');
+ } else {
+ sendPythonOutput('Process stopped gracefully', 'info');
+ }
+
+ resolve({ success: true, forcefullyKilled });
+ });
+
+ // Send termination signal
+ sendPythonOutput('Stopping process gracefully...', 'info');
+ runningProcess.kill('SIGTERM');
+ });
+});
+
+ipcMain.handle('get-process-state', () => {
+ return { isRunning: runningProcess !== null };
+});
+
+// Clean up on app quit
+app.on('before-quit', () => {
+ if (runningProcess) {
+ runningProcess.kill();
+ }
+});
+
+app.whenReady().then(() => {
+ createWindow();
+
+ app.on('activate', function () {
+ if (BrowserWindow.getAllWindows().length === 0) createWindow();
+ });
+});
+
+app.on('window-all-closed', function () {
+ app.quit();
+});
+
diff --git a/ui/package.json b/ui/package.json
new file mode 100644
index 0000000..d99424c
--- /dev/null
+++ b/ui/package.json
@@ -0,0 +1,118 @@
+{
+ "name": "TaSTT",
+ "version": "1.0.0",
+ "description": "Speech-to-text tool for VRChat",
+ "main": "index.js",
+ "homepage": "./",
+ "scripts": {
+ "start": "npm run build:css && electron .",
+ "build:css": "tailwindcss -i ./src/components.css -o ./output.css",
+ "watch:css": "tailwindcss -i ./src/components.css -o ./output.css --watch",
+ "dev": "concurrently \"npm run watch:css\" \"electron .\"",
+ "test": "echo \"Error: no test specified\" && exit 1",
+ "clean:meta": "node -e \"const fs=require('fs');const path=require('path');function deleteMeta(dir){fs.readdirSync(dir).forEach(f=>{const p=path.join(dir,f);if(f.endsWith('.meta'))fs.unlinkSync(p);else if(fs.statSync(p).isDirectory()&&!f.startsWith('.'))deleteMeta(p);})}deleteMeta('./node_modules')\"",
+ "prebuild": "node build_scripts/setup-empty-venv.js",
+ "dist": "npm run prebuild && npm run clean:meta && npm run build:css && electron-builder",
+ "dist:win": "npm run prebuild && npm run clean:meta && npm run build:css && electron-builder --win",
+ "dist:portable": "npm run prebuild && npm run clean:meta && npm run build:css && electron-builder --win portable",
+ "dist:zip": "npm run prebuild && npm run clean:meta && npm run build:css && electron-builder --win zip"
+ },
+ "build": {
+ "appId": "com.yum_food.tastt",
+ "productName": "TaSTT",
+ "directories": {
+ "output": "dist"
+ },
+ "files": [
+ "**/*",
+ "!dist/**/*",
+ "!src/**/*",
+ "!node_modules/**/{CHANGELOG.md,README.md,README,readme.md,readme}",
+ "!node_modules/**/{test,__tests__,tests,powered-test,example,examples}",
+ "!node_modules/**/*.d.ts",
+ "!node_modules/.bin",
+ "!.git/**/*",
+ "!.gitignore"
+ ],
+ "extraResources": [
+ {
+ "from": "../app",
+ "to": "app",
+ "filter": [
+ "**/*.py",
+ "requirements.txt",
+ "!**/__pycache__/**/*"
+ ]
+ },
+ {
+ "from": "../config.yaml",
+ "to": "config.yaml"
+ },
+ {
+ "from": "../Images",
+ "to": "Images",
+ "filter": ["**/*"]
+ },
+ {
+ "from": "../bin",
+ "to": "bin",
+ "filter": ["**/*"]
+ },
+ {
+ "from": "../venv_clean",
+ "to": "venv",
+ "filter": ["**/*"]
+ },
+ {
+ "from": "../dll_empty",
+ "to": "dll",
+ "filter": ["**/*"]
+ },
+ {
+ "from": "../Sounds",
+ "to": "Sounds",
+ "filter": ["*.wav"]
+ }
+ ],
+ "win": {
+ "icon": "../Images/favicon.ico",
+ "target": [
+ {
+ "target": "portable",
+ "arch": ["x64"]
+ },
+ {
+ "target": "zip",
+ "arch": ["x64"]
+ }
+ ]
+ },
+ "portable": {
+ "artifactName": "${productName}-${version}-portable.exe"
+ },
+ "nsis": {
+ "oneClick": false,
+ "allowToChangeInstallationDirectory": true
+ },
+ "compression": "normal",
+ "artifactName": "${productName}-${version}-${arch}.${ext}"
+ },
+ "keywords": [],
+ "author": "yum_food",
+ "license": "MIT",
+ "dependencies": {
+ "js-yaml": "^4.1.0"
+ },
+ "devDependencies": {
+ "@vitejs/plugin-vue": "^5.2.4",
+ "autoprefixer": "^10.4.21",
+ "concurrently": "^9.1.2",
+ "cross-env": "^7.0.3",
+ "electron": "^36.3.2",
+ "electron-builder": "^25.1.8",
+ "postcss": "^8.5.4",
+ "tailwindcss": "^3.4.17",
+ "vite": "^6.3.5",
+ "vue": "^3.5.16"
+ }
+}
diff --git a/ui/postcss.config.js b/ui/postcss.config.js
new file mode 100644
index 0000000..33ad091
--- /dev/null
+++ b/ui/postcss.config.js
@@ -0,0 +1,6 @@
+module.exports = {
+ plugins: {
+ tailwindcss: {},
+ autoprefixer: {},
+ },
+}
diff --git a/ui/preload.js b/ui/preload.js
new file mode 100644
index 0000000..6f6e54f
--- /dev/null
+++ b/ui/preload.js
@@ -0,0 +1,17 @@
+const { contextBridge, ipcRenderer } = require('electron');
+
+contextBridge.exposeInMainWorld('electronAPI', {
+ loadConfig: () => ipcRenderer.invoke('load-config'),
+ saveConfig: (config) => ipcRenderer.invoke('save-config', config),
+ resetConfig: () => ipcRenderer.invoke('reset-config'),
+ getMicrophones: () => ipcRenderer.invoke('get-microphones'),
+ installRequirements: () => ipcRenderer.invoke('install-requirements'),
+ deleteVenvIndicatorFile: () => ipcRenderer.invoke('deleteVenvIndicatorFile'),
+ resetVenv: () => ipcRenderer.invoke('reset-venv'),
+ startProcess: () => ipcRenderer.invoke('start-process'),
+ stopProcess: () => ipcRenderer.invoke('stop-process'),
+ getProcessState: () => ipcRenderer.invoke('get-process-state'),
+ onPythonOutput: (callback) => ipcRenderer.on('python-output', (event, data) => callback(data)),
+ onProcessStopped: (callback) => ipcRenderer.on('process-stopped', () => callback())
+});
+
diff --git a/ui/renderer.js b/ui/renderer.js
new file mode 100644
index 0000000..008e0da
--- /dev/null
+++ b/ui/renderer.js
@@ -0,0 +1,531 @@
+// Import configuration schema
+const CONFIG_FIELDS = window.CONFIG_SCHEMA;
+
+// Process state tracking
+let isProcessRunning = false;
+let buttonManager;
+let loadingOverlay;
+
+// Auto-save functionality with debouncing
+let saveTimeout;
+const SAVE_DELAY = 500;
+let isSettingValues = false;
+
+// Console management
+const consoleContent = document.getElementById('console-content');
+const MAX_CONSOLE_LINES = 512;
+let consoleLineCount = 0;
+
+// Button management system
+class ButtonManager {
+ constructor() {
+ this.buttons = {
+ start: document.getElementById('start-process'),
+ stop: document.getElementById('stop-process'),
+ setupVenv: document.getElementById('setup-venv'),
+ resetVenv: document.getElementById('reset-venv'),
+ refreshMicrophones: document.getElementById('refresh-microphones')
+ };
+
+ // Initialize button states - process is not running at startup
+ this.setProcessStopped();
+ }
+
+ setState(buttonName, disabled) {
+ const button = this.buttons[buttonName];
+ if (!button) return;
+
+ button.disabled = disabled;
+ }
+
+ setProcessRunning() {
+ this.setState('start', true);
+ this.setState('stop', false);
+ isProcessRunning = true;
+ }
+
+ setProcessStopped() {
+ this.setState('start', false);
+ this.setState('stop', true);
+ isProcessRunning = false;
+ }
+
+ async withButtonLoading(buttonName, asyncFn) {
+ this.setState(buttonName, true);
+ try {
+ return await asyncFn();
+ } finally {
+ this.setState(buttonName, false);
+ }
+ }
+}
+
+// Add loading overlay management
+class LoadingOverlay {
+ constructor() {
+ this.overlay = document.getElementById('loading-overlay');
+ this.form = document.getElementById('config-form');
+ this.messageElement = this.overlay.querySelector('p');
+ this.defaultMessage = 'Environment setup underway - please wait.';
+ this.originalStates = new Map(); // Track original disabled states
+ }
+
+ show(message = null) {
+ this.messageElement.textContent = message || this.defaultMessage;
+ this.overlay.classList.remove('hidden');
+ // Disable all form inputs and buttons in the entire left panel
+ const leftPanel = this.overlay.parentElement;
+ const inputs = leftPanel.querySelectorAll('input, select, textarea, button');
+ inputs.forEach(input => {
+ // Store original disabled state before disabling
+ this.originalStates.set(input, input.disabled);
+ input.disabled = true;
+ input.classList.add('opacity-50');
+ });
+ }
+
+ hide() {
+ this.overlay.classList.add('hidden');
+ // Restore original states of form inputs and buttons
+ const leftPanel = this.overlay.parentElement;
+ const inputs = leftPanel.querySelectorAll('input, select, textarea, button');
+ inputs.forEach(input => {
+ // Restore original disabled state
+ input.disabled = this.originalStates.get(input) || false;
+ input.classList.remove('opacity-50');
+ });
+ // Clear the stored states
+ this.originalStates.clear();
+ // Reset to default message
+ this.messageElement.textContent = this.defaultMessage;
+ }
+}
+
+// Handle status messages with better color management
+function showStatus(message, type = 'info') {
+ const statusEl = document.getElementById('status-message');
+ statusEl.textContent = message;
+
+ // Remove all status classes
+ const statusClasses = ['hidden', 'bg-green-100', 'bg-red-100', 'bg-blue-100', 'text-green-800', 'text-red-800', 'text-blue-800'];
+ statusEl.classList.remove(...statusClasses);
+
+ // Add appropriate classes based on type
+ const typeMap = {
+ success: ['bg-green-100', 'text-green-800'],
+ error: ['bg-red-100', 'text-red-800'],
+ info: ['bg-blue-100', 'text-blue-800']
+ };
+
+ statusEl.classList.add(...(typeMap[type] || typeMap.info));
+
+ // Also log to console
+ appendToConsole(message, type === 'error' ? 'stderr' : 'info');
+
+ setTimeout(() => statusEl.classList.add('hidden'), 5000);
+}
+
+// Get form values using field mappings
+function getFormValues() {
+ const config = {};
+
+ for (const [fieldName, fieldConfig] of Object.entries(CONFIG_FIELDS)) {
+ const element = document.getElementById(fieldName);
+ if (!element) continue;
+
+ switch (fieldConfig.type) {
+ case 'boolean':
+ config[fieldName] = element.checked ? 1 : 0;
+ break;
+ case 'number':
+ const numValue = parseInt(element.value);
+ config[fieldName] = isNaN(numValue) ? fieldConfig.default : numValue;
+ break;
+ case 'text':
+ config[fieldName] = element.value || fieldConfig.default;
+ break;
+ default:
+ config[fieldName] = element.value || fieldConfig.default;
+ }
+ }
+
+ return config;
+}
+
+// Set form values using field mappings
+function setFormValues(config) {
+ isSettingValues = true; // Disable auto-save temporarily
+
+ for (const [fieldName, fieldConfig] of Object.entries(CONFIG_FIELDS)) {
+ const element = document.getElementById(fieldName);
+ if (!element) continue;
+
+ const value = config[fieldName] ?? fieldConfig.default;
+
+ switch (fieldConfig.type) {
+ case 'boolean':
+ element.checked = value === 1;
+ break;
+ case 'text':
+ element.value = value || '';
+ break;
+ default:
+ element.value = value;
+ }
+ }
+
+ // Handle use_builtin toggle state
+ const useBuiltin = config.use_builtin === 1;
+ const customChatboxInputs = ['block_width', 'num_blocks', 'rows', 'cols'];
+ customChatboxInputs.forEach(inputId => {
+ const input = document.getElementById(inputId);
+ if (input) {
+ input.disabled = useBuiltin;
+ if (useBuiltin) {
+ input.classList.add('opacity-50', 'cursor-not-allowed');
+ } else {
+ input.classList.remove('opacity-50', 'cursor-not-allowed');
+ }
+ }
+ });
+
+ // Update volume display
+ if (config.volume !== undefined) {
+ const volumePercent = Math.round(config.volume);
+ document.getElementById('volume-display').textContent = `${volumePercent}%`;
+ }
+
+ isSettingValues = false; // Re-enable auto-save
+}
+
+function appendToConsole(message, type = 'stdout') {
+ const timestamp = new Date().toLocaleTimeString();
+ const timestampSpan = document.createElement('span');
+ timestampSpan.className = 'console-timestamp';
+ timestampSpan.textContent = `[${timestamp}] `;
+
+ const messageSpan = document.createElement('span');
+ messageSpan.className = `console-${type}`;
+ messageSpan.textContent = message;
+
+ const lineDiv = document.createElement('div');
+ lineDiv.appendChild(timestampSpan);
+ lineDiv.appendChild(messageSpan);
+
+ consoleContent.appendChild(lineDiv);
+ consoleLineCount++;
+
+ // Remove old lines if we exceed the limit
+ if (consoleLineCount > MAX_CONSOLE_LINES) {
+ // Calculate how many lines to remove (remove 10% to avoid frequent trimming)
+ const linesToRemove = Math.floor(MAX_CONSOLE_LINES * 0.1);
+
+ // Remove the oldest lines
+ for (let i = 0; i < linesToRemove; i++) {
+ if (consoleContent.firstChild) {
+ consoleContent.removeChild(consoleContent.firstChild);
+ }
+ }
+
+ consoleLineCount -= linesToRemove;
+
+ // Add a notice that lines were trimmed
+ const trimNotice = document.createElement('div');
+ trimNotice.className = 'console-info';
+ trimNotice.innerHTML = '<span class="console-timestamp">[System] </span><span class="console-info">... older lines removed to maintain performance ...</span>';
+ consoleContent.insertBefore(trimNotice, consoleContent.firstChild);
+ }
+
+ // Auto-scroll to bottom
+ const pythonConsole = document.getElementById('python-console');
+ pythonConsole.scrollTop = pythonConsole.scrollHeight;
+}
+
+// Async action handler with better error handling
+async function handleAsyncAction(actionName, actionFn) {
+ try {
+ const result = await actionFn();
+ if (result?.message) {
+ showStatus(result.message, 'success');
+ }
+ return result;
+ } catch (error) {
+ showStatus(`${actionName} failed: ${error.message}`, 'error');
+ throw error;
+ }
+}
+
+async function autoSaveConfig() {
+ if (isSettingValues) return;
+
+ clearTimeout(saveTimeout);
+ saveTimeout = setTimeout(async () => {
+ try {
+ const config = getFormValues();
+ await window.electronAPI.saveConfig(config);
+ showStatus('Configuration saved', 'success');
+
+ // Restart process if running
+ if (isProcessRunning) {
+ appendToConsole('Restarting process with new configuration...', 'info');
+
+ try {
+ await window.electronAPI.stopProcess();
+ await new Promise(resolve => setTimeout(resolve, 1000));
+ await window.electronAPI.startProcess();
+ buttonManager.setProcessRunning();
+ appendToConsole('Process restarted with new configuration', 'info');
+ } catch (error) {
+ appendToConsole(`Failed to restart process: ${error.message}`, 'stderr');
+ buttonManager.setProcessStopped();
+ }
+ }
+ } catch (error) {
+ showStatus(`Failed to save configuration: ${error.message}`, 'error');
+ }
+ }, SAVE_DELAY);
+}
+
+// Auto-save setup
+function setupAutoSave() {
+ const form = document.getElementById('config-form');
+ const inputs = form.querySelectorAll('input, select, textarea');
+
+ inputs.forEach(input => {
+ const eventType = input.type === 'checkbox' ? 'change' :
+ (input.type === 'number' || input.type === 'text' || input.tagName === 'TEXTAREA') ? 'input' : 'change';
+ input.addEventListener(eventType, autoSaveConfig);
+ });
+}
+
+// Microphone loading
+async function loadMicrophones() {
+ const microphoneSelect = document.getElementById('microphone');
+
+ try {
+ // Check/install requirements during startup
+ appendToConsole('Checking virtual environment and requirements...', 'info');
+ loadingOverlay.show('Setting up environment - this can take several minutes.');
+ try {
+ await handleAsyncAction('Install requirements', () => window.electronAPI.installRequirements());
+ } finally {
+ loadingOverlay.hide(); // Always hide overlay when done
+ }
+
+ appendToConsole('Loading available microphones...', 'info');
+ const microphones = await window.electronAPI.getMicrophones();
+
+ microphoneSelect.innerHTML = '';
+
+ if (microphones.length === 0) {
+ microphoneSelect.innerHTML = '<option value="" disabled>No microphones found</option>';
+ appendToConsole('No microphones found', 'stderr');
+ return;
+ }
+
+ appendToConsole(`Found ${microphones.length} microphone(s)`, 'info');
+ microphones.forEach(mic => {
+ const option = document.createElement('option');
+ option.value = mic.index.toString();
+ option.textContent = mic.name;
+ microphoneSelect.appendChild(option);
+ appendToConsole(` - ${mic.name} (Device ${mic.index})`, 'stdout');
+ });
+
+ // Restore previously selected microphone
+ try {
+ const config = await window.electronAPI.loadConfig();
+ if (config.microphone) {
+ microphoneSelect.value = config.microphone;
+ }
+ } catch (error) {
+ // Ignore config load errors here
+ }
+
+ } catch (error) {
+ appendToConsole(`Failed to load microphones: ${error.message}`, 'stderr');
+ microphoneSelect.innerHTML = '<option value="" disabled>Error loading microphones</option>';
+ }
+}
+
+// Event handlers setup
+function setupEventHandlers() {
+ // Advanced settings toggle
+ document.getElementById('toggle-advanced').addEventListener('click', () => {
+ const advancedSettings = document.getElementById('advanced-settings');
+ const chevron = document.getElementById('chevron');
+
+ if (advancedSettings.classList.contains('hidden')) {
+ advancedSettings.classList.remove('hidden');
+ chevron.classList.add('rotate-90');
+ } else {
+ advancedSettings.classList.add('hidden');
+ chevron.classList.remove('rotate-90');
+ }
+ });
+
+ // Use builtin chatbox toggle
+ document.getElementById('use_builtin').addEventListener('change', (e) => {
+ const customChatboxInputs = ['block_width', 'num_blocks', 'rows', 'cols'];
+ const isBuiltin = e.target.checked;
+
+ customChatboxInputs.forEach(inputId => {
+ const input = document.getElementById(inputId);
+ if (input) {
+ input.disabled = isBuiltin;
+ if (isBuiltin) {
+ input.classList.add('opacity-50', 'cursor-not-allowed');
+ } else {
+ input.classList.remove('opacity-50', 'cursor-not-allowed');
+ }
+ }
+ });
+ });
+
+ // Volume slider update
+ document.getElementById('volume').addEventListener('input', (e) => {
+ const volumePercent = Math.round(e.target.value);
+ document.getElementById('volume-display').textContent = `${volumePercent}%`;
+ });
+
+ // Setup virtual environment
+ document.getElementById('setup-venv').addEventListener('click', async () => {
+ loadingOverlay.show('Setting up virtual environment - please wait...'); // Show overlay with custom message
+ try {
+ await buttonManager.withButtonLoading('setupVenv', async () => {
+ await window.electronAPI.deleteVenvIndicatorFile();
+ await handleAsyncAction('Install requirements', () => window.electronAPI.installRequirements());
+ });
+ } finally {
+ loadingOverlay.hide(); // Always hide overlay when done
+ }
+ });
+
+ // Reset virtual environment
+ document.getElementById('reset-venv').addEventListener('click', async () => {
+ loadingOverlay.show('Resetting virtual environment - please wait...'); // Show overlay with custom message
+ try {
+ await buttonManager.withButtonLoading('resetVenv', async () => {
+ await handleAsyncAction('Reset virtual environment', () => window.electronAPI.resetVenv());
+ });
+ } finally {
+ loadingOverlay.hide(); // Always hide overlay when done
+ }
+ });
+
+ // Reset configuration
+ document.getElementById('reset-config').addEventListener('click', async () => {
+ const confirmReset = confirm('Are you sure you want to reset all settings to defaults? This cannot be undone.');
+ if (!confirmReset) return;
+
+ try {
+ // Stop process if running
+ const wasRunning = isProcessRunning;
+ if (wasRunning) {
+ appendToConsole('Stopping process before resetting configuration...', 'info');
+ await window.electronAPI.stopProcess();
+ buttonManager.setProcessStopped();
+ await new Promise(resolve => setTimeout(resolve, 500));
+ }
+
+ // Reset configuration
+ appendToConsole('Resetting configuration to defaults...', 'info');
+ const result = await window.electronAPI.resetConfig();
+
+ // Reload configuration with defaults
+ const config = await window.electronAPI.loadConfig();
+ setFormValues(config);
+
+ showStatus(result.message, 'success');
+ appendToConsole('Configuration reset successfully', 'info');
+
+ // Restart process if it was running
+ if (wasRunning) {
+ appendToConsole('Restarting process with default configuration...', 'info');
+ await window.electronAPI.startProcess();
+ buttonManager.setProcessRunning();
+ appendToConsole('Process restarted with default configuration', 'info');
+ }
+ } catch (error) {
+ showStatus(`Failed to reset configuration: ${error.message}`, 'error');
+ appendToConsole(`Failed to reset configuration: ${error.message}`, 'stderr');
+ }
+ });
+
+ // Refresh microphones
+ document.getElementById('refresh-microphones').addEventListener('click', async () => {
+ await buttonManager.withButtonLoading('refreshMicrophones', async () => {
+ await loadMicrophones();
+ });
+ });
+
+ // Start process
+ document.getElementById('start-process').addEventListener('click', async () => {
+ buttonManager.setState('start', true);
+
+ try {
+ // The installRequirements function will now check if venv is set up.
+ loadingOverlay.show('Verifying environment setup - please wait...'); // Show overlay with custom message
+ try {
+ await window.electronAPI.installRequirements();
+ appendToConsole('Virtual environment setup checked/completed', 'info');
+ } finally {
+ loadingOverlay.hide(); // Always hide overlay when done
+ }
+
+ await window.electronAPI.startProcess();
+ buttonManager.setProcessRunning();
+ appendToConsole('Process started successfully', 'info');
+ } catch (error) {
+ appendToConsole(`Failed to start process: ${error.message}`, 'stderr');
+ buttonManager.setState('start', false);
+ }
+ });
+
+ // Stop process
+ document.getElementById('stop-process').addEventListener('click', async () => {
+ buttonManager.setState('stop', true);
+
+ try {
+ await window.electronAPI.stopProcess();
+ appendToConsole('Process stop initiated', 'info');
+ } catch (error) {
+ appendToConsole(`Failed to stop process: ${error.message}`, 'stderr');
+ buttonManager.setState('stop', false);
+ }
+ });
+
+ // Listen for process stopped event
+ window.electronAPI.onProcessStopped(() => {
+ buttonManager.setProcessStopped();
+ });
+}
+
+// Initialize application
+window.addEventListener('load', async () => {
+ appendToConsole('TaSTT Configuration UI initialized', 'info');
+
+ loadingOverlay = new LoadingOverlay();
+ buttonManager = new ButtonManager();
+
+ // Set up Python output listener first so we capture all output
+ window.electronAPI.onPythonOutput((data) => {
+ appendToConsole(data.message, data.type);
+ });
+
+ // Load configuration
+ try {
+ const config = await window.electronAPI.loadConfig();
+ setFormValues(config);
+ appendToConsole('Configuration loaded', 'info');
+ } catch (error) {
+ appendToConsole(`Failed to load configuration: ${error.message}`, 'stderr');
+ }
+
+ // Load microphones
+ await loadMicrophones();
+
+ // Setup event handlers and auto-save
+ setupEventHandlers();
+ setupAutoSave();
+});
diff --git a/ui/src/components.css b/ui/src/components.css
new file mode 100644
index 0000000..2832e12
--- /dev/null
+++ b/ui/src/components.css
@@ -0,0 +1,122 @@
+@tailwind base;
+@tailwind components;
+@tailwind utilities;
+
+@layer components {
+ .config-section {
+ @apply bg-white rounded-lg shadow-md p-6;
+ }
+
+ .section-title {
+ @apply text-xl font-semibold text-gray-700 mb-4;
+ }
+
+ .form-label {
+ @apply block text-sm font-medium text-gray-700 mb-2;
+ }
+
+ .form-input {
+ @apply w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm;
+ }
+
+ .checkbox-label {
+ @apply flex items-center cursor-pointer hover:bg-gray-50 p-2 rounded;
+ }
+
+ .checkbox-text {
+ @apply text-sm text-gray-700;
+ }
+
+ .btn {
+ @apply px-4 py-2 font-medium text-sm rounded-md transition-colors focus:outline-none focus:ring-2 focus:ring-offset-2;
+ }
+
+ .btn-blue {
+ @apply bg-blue-600 text-white hover:bg-blue-700 focus:ring-blue-500;
+ }
+
+ .btn-green {
+ @apply bg-green-600 text-white hover:bg-green-700 focus:ring-green-500;
+ }
+
+ .btn-gray {
+ @apply bg-gray-600 text-white hover:bg-gray-700 focus:ring-gray-500;
+ }
+
+ .btn-red {
+ @apply bg-red-600 text-white hover:bg-red-700 focus:ring-red-500;
+ }
+
+ .btn-purple {
+ @apply bg-purple-600 text-white hover:bg-purple-700 focus:ring-purple-500;
+ }
+
+ .btn-orange {
+ @apply bg-orange-600 text-white hover:bg-orange-700 focus:ring-orange-500;
+ }
+}
+
+/* Console styling */
+#python-console {
+ background-color: #1a1a1a;
+ font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
+ line-height: 1.4;
+}
+
+#console-content {
+ word-wrap: break-word;
+}
+
+/* Console text colors */
+.console-stdout {
+ color: #a8cc8c;
+}
+
+.console-stderr {
+ color: #e88388;
+}
+
+.console-info {
+ color: #66c2cd;
+}
+
+.console-timestamp {
+ color: #6c7986;
+ font-size: 0.875rem;
+}
+
+/* Ensure full height layout */
+html, body {
+ height: 100%;
+ margin: 0;
+ padding: 0;
+}
+
+.container-fluid {
+ max-width: 100%;
+ height: 100vh;
+}
+
+/* Scrollbar styling for console */
+#python-console::-webkit-scrollbar {
+ width: 8px;
+}
+
+#python-console::-webkit-scrollbar-track {
+ background: #2a2a2a;
+}
+
+#python-console::-webkit-scrollbar-thumb {
+ background: #4a4a4a;
+ border-radius: 4px;
+}
+
+#python-console::-webkit-scrollbar-thumb:hover {
+ background: #5a5a5a;
+}
+
+/* Ensure buttons have proper disabled states */
+button:disabled {
+ cursor: not-allowed;
+ opacity: 0.5;
+}
diff --git a/ui/src/input.css b/ui/src/input.css
new file mode 100644
index 0000000..b5c61c9
--- /dev/null
+++ b/ui/src/input.css
@@ -0,0 +1,3 @@
+@tailwind base;
+@tailwind components;
+@tailwind utilities;
diff --git a/ui/tailwind.config.js b/ui/tailwind.config.js
new file mode 100644
index 0000000..804b7f0
--- /dev/null
+++ b/ui/tailwind.config.js
@@ -0,0 +1,13 @@
+/** @type {import('tailwindcss').Config} */
+module.exports = {
+ content: [
+ "./*.html",
+ "./*.js",
+ "./src/**/*.{html,js}"
+ ],
+ theme: {
+ extend: {},
+ },
+ plugins: [],
+}
+
diff --git a/ui_design.md b/ui_design.md
new file mode 100644
index 0000000..e1ff095
--- /dev/null
+++ b/ui_design.md
@@ -0,0 +1,39 @@
+# TaSTT UI
+
+The TaSTT is built using electron and tailwind.css.
+
+First, install nodejs. Open PowerShell as administrator:
+
+```bash
+# Delete any existing install.
+$ choco uninstall nodejs -y
+$ choco install nodejs-lts -y
+```
+
+To build the app:
+```
+$ npm install
+$ npm run dev
+```
+
+For posterity, this is how I set up the ui directory initially. In a non-admin PowerShell window:
+
+```bash
+# Check your node and npm versions.
+$ node -v
+v22.16.0
+$ npm -v
+10.9.2
+# Set up directory
+$ mkdir ui
+cd ui
+npm init -y
+npm install --save-dev electron
+# Get tailwind and deps
+npm install --save-dev tailwindcss@3 postcss autoprefixer concurrently cross-env
+npx tailwindcss init -p
+# Install vue.js
+npm install --save-dev vue@3 @vitejs/plugin-vue vite yaml
+npm install --save-dev js-yaml
+```
+