diff options
| -rw-r--r-- | GenerateTextAnimator.cs | 200 | ||||
| -rw-r--r-- | README.md | 282 | ||||
| -rw-r--r-- | bpe_dump.py | 61 | ||||
| -rw-r--r-- | generate_bpe_lut.py | 119 | ||||
| -rw-r--r-- | generate_tokenizer.py | 525 | ||||
| -rw-r--r-- | hi.py | 357 | ||||
| -rw-r--r-- | requirements.txt | 6 | ||||
| -rw-r--r-- | tokenize_me.py | 28 |
8 files changed, 1578 insertions, 0 deletions
diff --git a/GenerateTextAnimator.cs b/GenerateTextAnimator.cs new file mode 100644 index 0000000..fb4a1ee --- /dev/null +++ b/GenerateTextAnimator.cs @@ -0,0 +1,200 @@ +#if UNITY_EDITOR +using AnimatorAsCode.V1; +using AnimatorAsCode.V1.ModularAvatar; +using YumTools; +using nadena.dev.ndmf; +using System.Collections.Generic; +using UnityEditor; +using UnityEditor.Animations; +using UnityEngine; +using VRC.SDK3.Avatars.Components; +using VRC.SDKBase; + +// This example uses NDMF. See https://github.com/bdunderscore/ndmf?tab=readme-ov-file#getting-started +[assembly: ExportsPlugin(typeof(GenerateTextAnimatorPlugin))] +namespace YumTools +{ + public class GenerateTextAnimator : MonoBehaviour, IEditorOnly + { + // The number of blocks addressable. + public int numBlocks = 10; + // The number of datums sent per block. + public int blockWidth = 5; + // The number of bytes per datum. + public int bytesPerDatum = 2; + + public string oscPointerParam = "_Unigram_Letter_Grid_OSC_Pointer"; + + // Data sent through OSC uses this pattern. + public string[] oscParamPerBlockDatumByte = {"_Unigram_Letter_Grid_OSC_Datum{0:00}_Byte{1:00}"}; + // The pattern of the parameters which this script will animate. + public string[] matPropPerBlockDatumByte = {"_Unigram_Letter_Grid_Data_Block{0:00}_Datum{1:00}_Byte{2:00}_Animated"}; + + public string[] oscParamPerBlockDatum = {}; + public string[] matPropPerBlockDatum = {}; + + public string[] oscParamPerBlock = {"_Unigram_Letter_Grid_OSC_Visual_Pointer"}; + public string[] matPropPerBlock = {"_Unigram_Letter_Grid_Block_{0:00}_Visual_Pointer_Animated"}; + } + + public class GenerateTextAnimatorPlugin : Plugin<GenerateTextAnimatorPlugin> + { + public override string QualifiedName => "yum.generate_chat_animator"; + public override string DisplayName => "Chat Animator"; + + private const string SystemName = "GenerateTextAnimator"; + // Direct blendtrees have special semantics with write defaults. We want + // them on. They will not fuck up the rest of the animator, whether it uses + // write defaults or not. + private const bool UseWriteDefaults = true; + + protected override void Configure() + { + InPhase(BuildPhase.Generating).Run($"Generate {DisplayName}", Generate); + } + + private AacFlBlendTreeDirect Add8BitBlendTree( + AacFlBase aac, + AacFlLayer layer, + GenerateTextAnimator cfg, + AacFlBlendTreeDirect tree, + string matProp, string oscParam, string oneParam) { + var offAnim = aac.NewClip().Animating(clip => + { + clip.Animates(cfg.GetComponent<Renderer>(), "material." + matProp).WithFrameCountUnit(keyframes => + keyframes.Constant(/*when=*/0, /*value=*/0)); + }); + var onAnim = aac.NewClip().Animating(clip => + { + clip.Animates(cfg.GetComponent<Renderer>(), "material." + matProp).WithFrameCountUnit(keyframes => + keyframes.Constant(/*when=*/0, /*value=*/255)); + }); + var subtree = aac.NewBlendTree().Simple1D(layer.FloatParameter(oscParam)) + .WithAnimation(offAnim, /*threshold=*/-1) + .WithAnimation(onAnim, /*threshold=*/ 1); + return tree.WithAnimation(subtree, layer.FloatParameter("AlwaysOne")); + } + + private void Generate(BuildContext ctx) + { + var components = ctx.AvatarRootTransform.GetComponentsInChildren<GenerateTextAnimator>(true); + if (components.Length == 0) return; + + var aac = AacV1.Create(new AacConfiguration + { + SystemName = SystemName, + AnimatorRoot = ctx.AvatarRootTransform, + DefaultValueRoot = ctx.AvatarRootTransform, + AssetKey = GUID.Generate().ToString(), + AssetContainer = ctx.AssetContainer, + ContainerMode = AacConfiguration.Container.OnlyWhenPersistenceRequired, + AssetContainerProvider = new NDMFContainerProvider(ctx), + DefaultsProvider = new AacDefaultsProvider(UseWriteDefaults) + }); + + // Create a new object in the scene. We will add Modular Avatar components inside it. + var modularAvatar = MaAc.Create(new GameObject(SystemName) + { + transform = { parent = ctx.AvatarRootTransform } + }); + + var ctrl = aac.NewAnimatorController(); + for (int component_i = 0; component_i < components.Length; component_i++) { + GenerateTextAnimator cfg = components[component_i]; + var layer = ctrl.NewLayer($"Chatbox Plumbing (component #{component_i})"); + + var baseState = layer.NewState("Entry (noop)").WithWriteDefaultsSetTo(false); + + // Create "always one" value to use in DBT. + // See https://vrc.school/docs/Other/DBT-Combining/ for details. + layer.OverrideValue(layer.FloatParameter("AlwaysOne"), 1.0f); + + // Create blendtrees. One for each block of data. + var block_trees = new List<AacFlState>(); + AacFlState last_block_state = null; + // For each block. + for (int i = 0; i < cfg.numBlocks; i++) { + var block_tree = aac.NewBlendTree().Direct(); + + // Create block-level animations. + for (int ii = 0; ii < cfg.oscParamPerBlock.Length; ii++) { + string matPropB = string.Format(cfg.matPropPerBlock[ii], i); + string oscParamB = cfg.oscParamPerBlock[ii]; + block_tree = Add8BitBlendTree(aac, layer, cfg, block_tree, matPropB, oscParamB, "AlwaysOne"); + } + + // For each datum per block. + for (int j = 0; j < cfg.blockWidth; j++) { + // Create (block, datum)-level animations. + for (int ii = 0; ii < cfg.oscParamPerBlockDatum.Length; ii++) { + string matPropBD = string.Format(cfg.matPropPerBlockDatum[ii], i, j); + string oscParamBD = string.Format(cfg.oscParamPerBlockDatum[ii], j); + block_tree = Add8BitBlendTree(aac, layer, cfg, block_tree, matPropBD, oscParamBD, "AlwaysOne"); + } + + // For each byte per datum. + for (int k = 0; k < cfg.bytesPerDatum; k++) { + // Create (block, datum, byte)-level animations. + for (int ii = 0; ii < cfg.oscParamPerBlockDatumByte.Length; ii++) { + string matPropBDB = string.Format(cfg.matPropPerBlockDatumByte[ii], i, j, k); + //Debug.Log($"animating property: {matPropBDB}"); + string oscParamBDB = string.Format(cfg.oscParamPerBlockDatumByte[ii], j, k); + block_tree = Add8BitBlendTree(aac, layer, cfg, block_tree, matPropBDB, oscParamBDB, "AlwaysOne"); + } + } + } + + var cur_block_state = layer.NewState($"Block {i}").WithAnimation(block_tree); + if (last_block_state != null) { + cur_block_state = cur_block_state.Under(last_block_state); + } + last_block_state = cur_block_state; + block_trees.Add(cur_block_state); + } + + // Create transitions to each block's blendtree. + for (int i = 0; i < cfg.numBlocks; i++) { + block_trees[i].TransitionsFromAny() + //.WithInterruption(TransitionInterruptionSource.Source) + .When(layer.IntParameter(cfg.oscPointerParam).IsEqualTo(i)); + } + + // Create sync params (VRCSDK) + for (int ii = 0; ii < cfg.oscParamPerBlock.Length; ii++) { + string bParam = cfg.oscParamPerBlock[ii]; + modularAvatar.NewParameter(layer.FloatParameter(bParam)); + } + for (int i = 0; i < cfg.blockWidth; i++) { + for (int ii = 0; ii < cfg.oscParamPerBlockDatum.Length; ii++) { + string bdParam = string.Format(cfg.oscParamPerBlockDatum[ii], i); + modularAvatar.NewParameter(layer.FloatParameter(bdParam)); + } + for (int j = 0; j < cfg.bytesPerDatum; j++) { + for (int ii = 0; ii < cfg.oscParamPerBlockDatumByte.Length; ii++) { + string bdbParam = string.Format(cfg.oscParamPerBlockDatumByte[ii], i, j); + modularAvatar.NewParameter(layer.FloatParameter(bdbParam)).WithDefaultValue(1.0f); + } + } + } + + modularAvatar.NewParameter(layer.IntParameter(cfg.oscPointerParam)); + } + + // By creating a Modular Avatar Merge Animator component, + // our animator controller will be added to the avatar's FX layer. + modularAvatar.NewMergeAnimator(ctrl.AnimatorController, VRCAvatarDescriptor.AnimLayerType.FX); + } + } + + // (For AAC 1.2.0 and above) This is recommended starting from NDMF 1.6.0. You only need to define this class once. + internal class NDMFContainerProvider : IAacAssetContainerProvider + { + private readonly BuildContext _ctx; + public NDMFContainerProvider(BuildContext ctx) => _ctx = ctx; + public void SaveAsPersistenceRequired(Object objectToAdd) => _ctx.AssetSaver.SaveAsset(objectToAdd); + public void SaveAsRegular(Object objectToAdd) { } // Let NDMF crawl our assets when it finishes + public void ClearPreviousAssets() { } // ClearPreviousAssets is never used in non-destructive contexts + } +} +#endif + diff --git a/README.md b/README.md new file mode 100644 index 0000000..eaeceea --- /dev/null +++ b/README.md @@ -0,0 +1,282 @@ +# Optimized text paging for VRChat + +It is sometimes useful to send text data into VRChat, for example for +speech-to-text (STT). This is typically done naively, with a "block" of +n 8-bit characters\* sent in along with an 8-bit pointer. Since avatars can only +send 256 bits at 3 Hz\*\* with OSC, this means you can only send (256 - 8) / 8 = +31 characters per sync. The average English word is 4.79 characters long, so +if we naively send in 1 character per byte, then we get a speed limit of 6.47 +words per sync. The works out to ~20 words per second or ~1200 words per minute +(wpm). Adults typically read at about 238-260 wpm. + +\* Typically ASCII encoding is used. + +\*\* Experimentally, 3 Hz is the fastest you can reliably page data with OSC in +busy instances. + +Sending in one character per byte gives you (1200/256) ~= 4.7 wpm per OSC bit +used. Thus to reach a typical reading speed, you need to use (260/4.7) = 55.5 +OSC bits. The goal of this module is to get more out of these bits by +compressing text over the wire. + +## Unigram tokenizer + +Byte pair encoding (BPE) is an encoding scheme frequently used in natural +language processing (NLP) contexts. For any language with a fixed character set +(e.g. an alphabet), you can count up how often each character is used in a +large corpus of text. Then you can repeatedly join together characters and +assign a unique token to joined characters in the order of the most frequently +occurring sequences. You wind up with a lookup table like this: + +``` +0: [UNK] +1: <s> +2: </s> +3: [CLS] +4: [SEP] +... +100: from +101: but +102: he +103: e +104: now +... +10000: eo +10001: is currently +10002: dish +10003: Mi +10004: 6) +``` + +The above example is from a unigram\* sentencepiece organizer that I trained. + +\* A unigram tokenizer is a variant of the byte-pair tokenizer. Where a byte +pair tokenizer views inputs as arbitrary sequences of bytes, a unigram +tokenizer views it as a sequence of letters. + +The tokenizer has a vocabulary size of 65,536 tokens. It was trained on +opensubtitles and 5% of wikipedia, with `unidecode` normalization applied to +limit training data to ASCII. Subword lengths are distributed as follows: + +Subword length histogram: +1: 95 +2: 1032 +3: 3350 +4: 5439 +5: 5445 +6: 5082 +7: 5334 +8: 5866 +9: 6172 +10: 5934 +11: 5329 +12: 4698 +13: 4016 +14: 3191 +15: 2434 +16: 2095 + +In the test set - 25% wikipedia 75% conversational English - +this tokenizer yields 4.896 characters per token. (Recall that +the average English word is 4.8 characters long. Not bad!) + +If we naively send these tokens into the game with a 16-bit number, we can send +floor((256-8)/16) = 15 tokens per sync. This gives us an average of 73.44 +characters per sync - more than 2x higher than the naive approach's 31. + +Here is how the unigram-tokenized scheme fairs against the naive scheme in +every possible configuration of bits used: + +``` +bits naive rate bpe rate speedup factor +8 n/a n/a n/a +16 1 n/a 0.000 +24 2 4.896 2.448 +32 3 4.896 1.632 +40 4 9.792 2.448 +48 5 9.792 1.958 +56 6 14.688 2.448 +64 7 14.688 2.098 +72 8 19.584 2.448 +80 9 19.584 2.176 +88 10 24.480 2.448 +96 11 24.480 2.225 +104 12 29.376 2.448 +112 13 29.376 2.260 +120 14 34.272 2.448 +128 15 34.272 2.285 +136 16 39.168 2.448 +144 17 39.168 2.304 +152 18 44.064 2.448 +160 19 44.064 2.319 +168 20 48.960 2.448 +176 21 48.960 2.331 +184 22 53.856 2.448 +192 23 53.856 2.342 +200 24 58.752 2.448 +208 25 58.752 2.350 +216 26 63.648 2.448 +224 27 63.648 2.357 +232 28 68.544 2.448 +240 29 68.544 2.364 +248 30 73.440 2.448 +256 31 73.440 2.369 +``` + +([Spreadsheet](https://docs.google.com/spreadsheets/d/1d9SEZvo3Q-6U_Wf9nuGRKXxndUhKn2V3Q0Ox0nOB4T4/edit?usp=sharing)) + +I reserve 39 token slots for sequences of whitespace characters of length 2-40. This helps simplify formatting. To end a line or position text, you can just send in the exact right number of spaces, and a fixed-width font renderer will position things as intended. + +## Paging data into shader + +Sending this data to a shader is pretty simple: + +- An OSC app encodes a string into tokens and pages it into the game with OSC. + - The app sends a pointer of *where* the tokens should be rendered along with them. Since the tokens can encode a variable length string, the pointer must be able to point to any spot in the rendering window. Thus we are limited to a 256-bit display with an 8-bit pointer, or a 64K display with a 16-bit pointer. We call this the visual pointer. + - A second pointer tells the animator which shader parameters the tokens and visual pointer should be written to. This can be 8-bit. We call this the logical pointer. +- Animator uses the logical pointer to decide which shader parameters to send the visual pointer and tokens to. + +Here is the expected speedup in every possible configuration with a 1-byte +overhead (1BO) or a 2-byte overhead (2BO): + +``` +bits naive rate bpe rate (1BO) speedup bpe rate (2BO) speedup +8 n/a n/a n/a n/a n/a +16 1 0.000 0.000 0.000 0.000 +24 2 0.000 0.000 0.000 0.000 +32 3 4.896 1.632 0.000 0.000 +40 4 4.896 1.224 4.896 1.224 +48 5 9.792 1.958 4.896 0.979 +56 6 9.792 1.632 9.792 1.632 +64 7 14.688 2.098 9.792 1.399 +72 8 14.688 1.836 14.688 1.836 +80 9 19.584 2.176 14.688 1.632 +88 10 19.584 1.958 19.584 1.958 +96 11 24.480 2.225 19.584 1.780 +104 12 24.480 2.040 24.480 2.040 +112 13 29.376 2.260 24.480 1.883 +120 14 29.376 2.098 29.376 2.098 +128 15 34.272 2.285 29.376 1.958 +136 16 34.272 2.142 34.272 2.142 +144 17 39.168 2.304 34.272 2.016 +152 18 39.168 2.176 39.168 2.176 +160 19 44.064 2.319 39.168 2.061 +168 20 44.064 2.203 44.064 2.203 +176 21 48.960 2.331 44.064 2.098 +184 22 48.960 2.225 48.960 2.225 +192 23 53.856 2.342 48.960 2.129 +200 24 53.856 2.244 53.856 2.244 +208 25 58.752 2.350 53.856 2.154 +216 26 58.752 2.260 58.752 2.260 +224 27 63.648 2.357 58.752 2.176 +232 28 63.648 2.273 63.648 2.273 +240 29 68.544 2.364 63.648 2.195 +248 30 68.544 2.285 68.544 2.285 +256 31 73.440 2.369 68.544 2.211 +``` + +([Spreadsheet](https://docs.google.com/spreadsheets/d/1d9SEZvo3Q-6U_Wf9nuGRKXxndUhKn2V3Q0Ox0nOB4T4/edit?usp=sharing)) + +As you can see, a 2-byte visual pointer is very damaging to the speedup at low bit budgets. So in bit-constrained setups we should definitely use a smaller display. + +Notably, *there is only one crossover point*. If all configurations except the 2-byte overhead 48-bit configuration, BPE-based paging is always\* faster. + +\* Always, going off of the *expected* rate. If you get unlucky and your tokens all decode to 1 character, then BPE-based paging is about 50% *slower* than naive encoding. + +Because the Unity animator sucks shit, we're going to decode tokens on the GPU. As a refresher, the GPU sees data like this: + +``` +_Text_Block00_Visual_Ptr: 0 +_Text_Block00_Token00: 13,766 +_Text_Block00_Token01: 84 +_Text_Block01_Visual_Ptr: 13 +_Text_Block01_Token00: 599 +_Text_Block01_Token01: 8,301 +... +``` + +I.e. it sees "blocks" of data with tokens and visual pointers. The visual pointer just says where on a grid it should draw the subwords represented by the tokens. + +The pixel shader can trivially work out what grid location the current pixel belongs to. By scanning through the visual pointers, it can work out which block it has to draw. + +We can generate a function like this: + +```c +#define BLOCK_WIDTH 2 +void GetBlock(uint which_block, out float data[BLOCK_WIDTH]) { + [loop] + for (uint i = 0; i < BLOCK_WIDTH; i++) { + data[i] = _Text_Blocks[which_block][i]; + } +} + +// Get the tokens that cover `screen_ptr`. Also returns `block_ptr`, the +// location where this block of tokens begins. +void GetTokens(uint screen_ptr, out uint block_ptr, out uint tokens[BLOCK_WIDTH]) { + uint which_block; + [loop] + for (uint i = 0; i < BLOCK_WIDTH; i++) { + if (screen_ptr >= _Text_Block_Visual_Ptrs[i]) { + which_block = i; + } + } + GetBlock(which_block, tokens); +} +``` + +## GPU decoding + +Now we have to translate the tokens into text. I do this with a texture laid out as follows: + +1. A fixed-length array of (offset, length) pairs. Offset is 24 bits, giving us an address space of about 16 million slots. Length is 8 bits, but as established above, the longest token is only 16 characters long. So we're wasting about 4 bits. This tells us we should use an RGBA texture. + +2. A variable length array of ASCII-encoded strings. Each slot is RGBA, so it can hold 4 characters. + +My tokenizer's vocabulary is 65,536 tokens. If we add up the lengths of every token, rounding them up to the nearest multiple of 4, we get the length 667,532 This means that we need 166,883 slots to fit the actual content of the vocabulary. + +So, the entire vocabulary - length+offset head and content - requires a 32-bit RGBA texture with 232,419 slots. We'll just jam this into a 512x512 texture, at an occupancy ratio of 88.66% (11.34% waste). The total VRAM usage of that lookup table (LUT) is 1 MiB. + +We want to implement this API: + +```c +uint GetChar(uint screen_ptr); +``` + +Internally, it must: + +1. Get tokens. (GetTokens - already done) +2. Figure out which token covers the screen\_ptr. +3. Figure out which character in the token covers the screen\_ptr. + +Let's break down [2]. We can get the length of a token with a single texture tap. So, naively, we can just scan through the tokens in the current block, add up their lengths, and stop once we find a token covering the current slot. The scan incurs a worst-case cost of `BLOCK_WIDTH` texture taps. The character lookup is then a single tap. + +```c +// Gets the length of the subword encoded by the token. Performs one texture +// tap. +void TokenLengthOffset(uint token, out uint length, out uint offset); +// Gets the nth character of the token stored at `token_offset`. +uint GetTokenChar(uint token_offset, uint nth); + +uint GetChar(uint screen_ptr) { + uint block_ptr; + uint tokens[BLOCK_WIDTH]; + GetTokens(screen_ptr, tokens, block_ptr); + uint start = block_ptr; + uint covering_token = 0; + uint token_ptr = block_ptr; + uint token_offset; + for (uint ii = 0; ii < BLOCK_WIDTH; ii++) { + uint token_length; + TokenLengthOffset(tokens[ii], token_length, token_offset); + if (token_ptr + token_length >= screen_ptr) { + covering_token = tokens[ii]; + } + token_ptr += token_length; + } + // covering_token covers screen_ptr. It starts at token_ptr. + return GetTokenChar(token_offset, screen_ptr - token_ptr); +} +``` + +That's actually it for the GPU decoding. Once you have the character, you can use standard fixed-width font rendering techniques to display it (e.g. [disinfo](https://github.com/yum-food/disinfo) and [msdf](https://github.com/Chlumsky/msdfgen)). + diff --git a/bpe_dump.py b/bpe_dump.py new file mode 100644 index 0000000..74ee08f --- /dev/null +++ b/bpe_dump.py @@ -0,0 +1,61 @@ +import math +import sentencepiece as spm + +def get_tokenizer(): + model_path = "./custom_unigram_tokenizer_65k/unigram.model" + sp = spm.SentencePieceProcessor() + sp.load(model_path) + return sp + +tokenizer = get_tokenizer() + +print(f"vocabulary size: {tokenizer.get_piece_size()}") +# Sentencepiece uses U+2581 (lower one eighth block) to indicate a space before +# a subword. +sp_space = chr(9601) +tokens_with_non_ascii = set() +subword_len_histo = dict() +# The sum of the lengths of each subword in the vocabulary. These are rounded +# up to 4 characters. +vocab_len_4c_quantized = 0 + +for i in range(tokenizer.get_piece_size()): + k = tokenizer.id_to_piece(i) + v = i + print(f" Original token ({v}): {repr(k)} ({' '.join(str(ord(k_c)) for k_c in k)})") + for k_c in k: + if ord(k_c) > 127 and ord(k_c) != 9601: + tokens_with_non_ascii.add(k) + break + k_processed = k.replace(sp_space, ' ') + if not k.startswith(sp_space) and k not in ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]: + k_processed = k + else: + k_processed = k_processed + + current_len = len(k_processed) + if current_len in subword_len_histo: + subword_len_histo[current_len] += 1 + else: + subword_len_histo[current_len] = 1 + + vocab_len_4c_quantized += math.ceil(current_len / 4.0) * 4.0 + print(f" {v}: {k_processed}") + +print(f"Num tokens with non-ascii: {len(tokens_with_non_ascii)} ({100 * len(tokens_with_non_ascii) / tokenizer.get_piece_size():.2f})%") + +print(f"Subword length histogram:") +avg_subword_len = 0 +total_pieces_for_avg = 0 +for k_len, v_count in sorted(subword_len_histo.items(), key=lambda x: x[0]): + avg_subword_len += k_len * v_count + total_pieces_for_avg += v_count + print(f" {k_len}: {v_count}") + +if total_pieces_for_avg > 0: + avg_subword_len /= total_pieces_for_avg + print(f"Average subword length: {avg_subword_len:.4f}") +else: + print("Average subword length: N/A (no pieces analyzed)") + +print(f"Sum of all subword lengths, quantized to 4 character chunks: {vocab_len_4c_quantized}") diff --git a/generate_bpe_lut.py b/generate_bpe_lut.py new file mode 100644 index 0000000..72b1201 --- /dev/null +++ b/generate_bpe_lut.py @@ -0,0 +1,119 @@ +from math import ceil, floor +from PIL import Image +from unidecode import unidecode +import sentencepiece as spm + +IMG_RES = 512 # square image + +def get_tokenizer(): + use_sentencepiece = True + + if not use_sentencepiece: + from tokenizers import Tokenizer + tokenizer_json = "./custom_wordpiece_tokenizer_65k/tokenizer.json" + print(f"Loading Tokenizers library tokenizer from: {tokenizer_json}") + return Tokenizer.from_file(tokenizer_json) + else: + model_path = "./custom_unigram_tokenizer_65k/unigram.model" + print(f"Loading SentencePiece tokenizer from: {model_path}") + sp = spm.SentencePieceProcessor() + sp.load(model_path) + print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}") + return sp + +def get_words(): + tokenizer = get_tokenizer() + + print(f"vocabulary size: {tokenizer.get_piece_size()}") + # sp_space = sentencepiece space. + # A special character sentencepiece uses to represent spaces before words. + sp_space = chr(9601) + words = [] + + # Accumulate words into a list, indexed by the token number. Sanitize them as + # you go. + for i in range(tokenizer.get_piece_size()): + word = tokenizer.id_to_piece(i) + tok = i + #print(f" Original token ({tok}): {repr(word)} ({' '.join(str(ord(c)) for c in word)})") + word_sanitized = "" + # Dirty hack: convert non-ASCII characters to nearest ASCII equivalent + for c in word: + if ord(c) > 127 and c != sp_space: + c_plain = unidecode(c) + print(f" Resolved {c} to {c_plain}") + word_sanitized += c_plain + else: + word_sanitized += c + # Replace sp_space with ' ' + word_sanitized = word_sanitized.replace(sp_space, ' ') + #print(f" {tok}: {word_sanitized}") + words.append(word_sanitized) + + # Special word: empty string. SentencePiece doesn't support this natively. + words.append('') + + return words + +# Fold a flat index into a IMG_RESxIMG_RES box. Return the (x,y) coordinate of +# the folded index. +def fold_idx(flat_idx): + return (flat_idx % IMG_RES, int(floor(flat_idx / IMG_RES))) + +def unfold_idx(coord): + return coord[0] + coord[1] * IMG_RES + +assert unfold_idx(fold_idx(1533125)) == 1533125 +assert unfold_idx(fold_idx(8538235)) == 8538235 +assert fold_idx(unfold_idx((192,235))) == (192,235) +assert fold_idx(unfold_idx((83,388))) == (83,388) + +def generate_lut(words, filename): + # Write the texture header. + black = (0, 0, 0, 255) + img = Image.new('RGBA', (IMG_RES, IMG_RES), black) + + # The header is `len(words)` slots long. Thus the actual LUT content starts at + # the index `len(words)`. + pixel_data = img.load() + lut_ptr = len(words) + for i in range(0, len(words)): + # Get pointer to the actual word data. + tok_ptr = lut_ptr + tok_len = len(words[i]) + rgba = ((tok_ptr >> 0) & 0xFF, + (tok_ptr >> 8) & 0xFF, + (tok_ptr >> 16) & 0xFF, + tok_len) + print(f"Writing {rgba} to {i} / {fold_idx(i)}") + idx_x, idx_y = fold_idx(i) + pixel_data[idx_x, idx_y] = rgba + + for j in range(0, ceil(tok_len/4.0)): + quad_ptr = tok_ptr + j + tok_0 = ord(words[i][j*4]) + tok_1 = ord(words[i][j*4+1] if tok_len > j*4+1 else ' ') + tok_2 = ord(words[i][j*4+2] if tok_len > j*4+2 else ' ') + tok_3 = ord(words[i][j*4+3] if tok_len > j*4+3 else ' ') + rgba = (tok_0, tok_1, tok_2, tok_3) + idx_x, idx_y = fold_idx(quad_ptr) + print(f" Writing {rgba} to {quad_ptr} / {fold_idx(quad_ptr)}") + pixel_data[idx_x, idx_y] = rgba + + # Advance the LUT ptr. Since we store 4 chars per pixel (RGBA), we advance + # it by ceil(tok_len/4). + lut_ptr += int(ceil(tok_len/4.0)) + + pretty = False + if pretty: + for y in range(0, IMG_RES): + for x in range(0, IMG_RES): + rgba = pixel_data[x, y] + pixel_data[x, y] = (rgba[0], rgba[1], rgba[2], 255) + + print(f"Saving to {filename}") + img.save(filename) + +if __name__ == "__main__": + words = get_words() + generate_lut(words, "bpe_lut.png") diff --git a/generate_tokenizer.py b/generate_tokenizer.py new file mode 100644 index 0000000..88903a8 --- /dev/null +++ b/generate_tokenizer.py @@ -0,0 +1,525 @@ +# !!! AI ARTIFACT !!! +# This file was primarily written with AI. + +import os +import argparse +from datasets import load_dataset, interleave_datasets +import sentencepiece as spm +import itertools +import random +from unidecode import unidecode + +# --- Dataset 1: Wikipedia --- +DATASET_NAME_WIKI = "wikipedia" +DATASET_CONFIG_WIKI = "20220301.en" + +DATASET_SPLIT_WIKI = "train" +TEXT_COLUMN_WIKI = "text" + +# --- Dataset 2: DailyDialog --- +DATASET_NAME_DD = "daily_dialog" +DATASET_SPLIT_DD = "train" +UTTERANCE_COLUMN_DD = "dialog" + +# --- Dataset 3: BlendedSkillTalk (includes Persona-Chat) --- +DATASET_NAME_BST = "blended_skill_talk" +DATASET_SPLIT_BST = "train" + +# --- Dataset 4: OpenSubtitles --- +DATASET_NAME_OS = "Helsinki-NLP/open_subtitles" +DATASET_LANG_PAIR_OS = ("en", "fr") # Load en-fr and use the 'en' part +DATASET_SPLIT_OS = "train" +TEXT_COLUMN_OS = "en" # Access via item['translation']['en'] + +# Reserve one space for special "empty" token. +VOCAB_SIZE = 65535 +OUTPUT_DIR = "./custom_unigram_tokenizer_65k" +MODEL_PREFIX = os.path.join(OUTPUT_DIR, "unigram") +MODEL_FILE = MODEL_PREFIX + ".model" +VOCAB_FILE = MODEL_PREFIX + ".vocab" + +UNK_TOKEN = "[UNK]" +PAD_TOKEN = "[PAD]" +CONTROL_SYMBOLS = ["[CLS]", "[SEP]", "[MASK]"] + +BATCH_SIZE = 1000 + +def wiki_iterator(dataset_wiki, batch_size=BATCH_SIZE): + if dataset_wiki: + for i in range(0, len(dataset_wiki), batch_size): + yield [unidecode(text) for text in dataset_wiki[i : i + batch_size][TEXT_COLUMN_WIKI] if text] + +def dd_iterator(dataset_dd, batch_size=BATCH_SIZE): + if dataset_dd: + current_batch = [] + for dialogue in dataset_dd: + utterances = dialogue[UTTERANCE_COLUMN_DD] + for utterance in utterances: + if utterance: + normalized_utterance = unidecode(utterance) + current_batch.append(normalized_utterance) + if len(current_batch) == batch_size: + yield current_batch + current_batch = [] + if current_batch: + yield current_batch + +def bst_iterator(dataset_bst, batch_size=BATCH_SIZE): + if dataset_bst: + current_batch = [] + for session in dataset_bst: + texts_to_add = [] + if session.get("previous_utterance"): + texts_to_add.append(session["previous_utterance"]) + if session.get("free_messages"): + texts_to_add.extend(session["free_messages"]) + if session.get("guided_messages"): + texts_to_add.extend(session["guided_messages"]) + + for text in texts_to_add: + if text and isinstance(text, str): + normalized_text = unidecode(text) + current_batch.append(normalized_text) + if len(current_batch) == batch_size: + yield current_batch + current_batch = [] + if current_batch: + yield current_batch + +def os_iterator(dataset_os, batch_size=BATCH_SIZE, lang_code=TEXT_COLUMN_OS): + if dataset_os: + current_batch = [] + for item in dataset_os: + text = item['translation'][lang_code] + if text: + normalized_text = unidecode(text) + current_batch.append(normalized_text) + if len(current_batch) == batch_size: + yield current_batch + current_batch = [] + if current_batch: + yield current_batch + +def count_wiki_items_chars(dataset_wiki): + item_count = 0 + char_count = 0 + if dataset_wiki: + item_count = len(dataset_wiki) + try: + char_count = sum(len(unidecode(text)) for text in dataset_wiki[TEXT_COLUMN_WIKI] if text and isinstance(text, str)) + except Exception: + print(f"Warning: Direct column access for char count failed for {DATASET_NAME_WIKI}. Iterating row by row (slower).") + char_count = sum(len(unidecode(row[TEXT_COLUMN_WIKI])) for row in dataset_wiki if row.get(TEXT_COLUMN_WIKI) and isinstance(row[TEXT_COLUMN_WIKI], str)) + + return item_count, char_count + +def count_dd_items_chars(dataset_dd): + item_count = 0 + char_count = 0 + if dataset_dd: + for dialogue in dataset_dd: + utterances = dialogue[UTTERANCE_COLUMN_DD] + for utterance in utterances: + if utterance and isinstance(utterance, str): + item_count += 1 + char_count += len(unidecode(utterance)) + return item_count, char_count + +def count_bst_items_chars(dataset_bst): + item_count = 0 + char_count = 0 + if dataset_bst: + for session in dataset_bst: + texts_to_process = [] + # Gather all potential text strings first + prev_utt = session.get("previous_utterance") + if prev_utt and isinstance(prev_utt, list): + if len(prev_utt) > 0 and isinstance(prev_utt[0], str): + texts_to_process.append(prev_utt[0]) + elif prev_utt and isinstance(prev_utt, str): + texts_to_process.append(prev_utt) + + free_msgs = session.get("free_messages") + if free_msgs and isinstance(free_msgs, list): + for item in free_msgs: + if isinstance(item, list): + texts_to_process.extend(msg for msg in item if msg and isinstance(msg, str)) + elif isinstance(item, str): + texts_to_process.append(item) + + guided_msgs = session.get("guided_messages") + if guided_msgs and isinstance(guided_msgs, list): + for item in guided_msgs: + if isinstance(item, list): + texts_to_process.extend(msg for msg in item if msg and isinstance(msg, str)) + elif isinstance(item, str): + texts_to_process.append(item) + + # Count items and chars from the gathered list + for text in texts_to_process: + if text and isinstance(text, str): + normalized_text = unidecode(text) + if normalized_text: + item_count += 1 + char_count += len(normalized_text) + + return item_count, char_count + +def count_os_items_chars(dataset_os, lang_code=TEXT_COLUMN_OS): + item_count = 0 + char_count = 0 + if dataset_os: + for item in dataset_os: + text = item['translation'][lang_code] + if text and isinstance(text, str): + normalized_text = unidecode(text) + if normalized_text: # Ensure not empty after unidecode + item_count += 1 + char_count += len(normalized_text) + return item_count, char_count + +def load_and_count_datasets(wiki_fraction, subtitles_fraction): + """Loads, potentially shrinks, and counts items/chars in datasets.""" + datasets = {} + counts = {} + total_items = 0 + total_chars = 0 + + # --- Wikipedia --- + print(f"Loading dataset 1: {DATASET_NAME_WIKI} ({DATASET_CONFIG_WIKI}), split: {DATASET_SPLIT_WIKI}") + dataset_wiki_full = load_dataset(DATASET_NAME_WIKI, DATASET_CONFIG_WIKI, split=DATASET_SPLIT_WIKI, trust_remote_code=True) + print(f"Original Wikipedia dataset size: {len(dataset_wiki_full):,}") + + # Shrink the Wikipedia dataset + split_test_size = 1.0 - wiki_fraction + shrunk_dataset_split = dataset_wiki_full.train_test_split(test_size=split_test_size, seed=random.randint(0, 1000000)) + dataset_wiki = shrunk_dataset_split['train'] + print(f"Using {wiki_fraction*100:.3f}% of Wikipedia dataset: {len(dataset_wiki):,} items") + + count_wiki, chars_wiki = count_wiki_items_chars(dataset_wiki) + print(f"Wikipedia dataset loaded (shrunk). Precise text items: {count_wiki:,}, Characters: {chars_wiki:,}") + datasets['wiki'] = dataset_wiki + counts['wiki'] = (count_wiki, chars_wiki) + total_items += count_wiki + total_chars += chars_wiki + + # --- DailyDialog --- + print(f"\nLoading dataset 2: {DATASET_NAME_DD}, split: {DATASET_SPLIT_DD}") + dataset_dd = load_dataset(DATASET_NAME_DD, split=DATASET_SPLIT_DD, trust_remote_code=True) + count_dd, chars_dd = count_dd_items_chars(dataset_dd) + print(f"DailyDialog dataset loaded. Precise text items (utterances): {count_dd:,}, Characters: {chars_dd:,}") + datasets['dd'] = dataset_dd + counts['dd'] = (count_dd, chars_dd) + total_items += count_dd + total_chars += chars_dd + + # --- BlendedSkillTalk --- + print(f"\nLoading dataset 3: {DATASET_NAME_BST}, split: {DATASET_SPLIT_BST}") + dataset_bst = load_dataset(DATASET_NAME_BST, split=DATASET_SPLIT_BST, trust_remote_code=True) + count_bst, chars_bst = count_bst_items_chars(dataset_bst) + print(f"BlendedSkillTalk dataset loaded. Precise text items (extracted): {count_bst:,}, Characters: {chars_bst:,}") + datasets['bst'] = dataset_bst + counts['bst'] = (count_bst, chars_bst) + total_items += count_bst + total_chars += chars_bst + + # --- OpenSubtitles --- + print(f"\nLoading dataset 4: {DATASET_NAME_OS}, lang_pair: {DATASET_LANG_PAIR_OS}, split: {DATASET_SPLIT_OS}") + # Note: OpenSubtitles can be very large. Consider streaming or specific configurations if memory is an issue. + # For now, loading a standard configuration. + dataset_os = load_dataset(DATASET_NAME_OS, lang1=DATASET_LANG_PAIR_OS[0], lang2=DATASET_LANG_PAIR_OS[1], split=DATASET_SPLIT_OS, trust_remote_code=True) + split_test_size = 1.0 - subtitles_fraction + dataset_os = dataset_os.train_test_split(test_size=split_test_size, seed=random.randint(0, 1000000))['train'] + count_os, chars_os = count_os_items_chars(dataset_os, lang_code=TEXT_COLUMN_OS) + print(f"OpenSubtitles dataset loaded. Precise text items: {count_os:,}, Characters: {chars_os:,}") + datasets['os'] = dataset_os + counts['os'] = (count_os, chars_os) + total_items += count_os + total_chars += chars_os + + print(f"\nTotal precise text items from loaded datasets: {total_items:,}") + print(f"Total characters from loaded datasets: {total_chars:,}") + + return datasets, counts + +def train_tokenizer(model_prefix, datasets, counts): + """Trains a Unigram tokenizer using SentencePiece and saves it.""" + total_items = sum(c[0] for c in counts.values()) + total_chars = sum(c[1] for c in counts.values()) + if total_items == 0: + print("Error: No items found in the datasets to train on. Exiting training.") + return + + print(f"\nTotal precise text items for training: {total_items:,}") + print(f"Total characters for training: {total_chars:,}") + + output_dir = os.path.dirname(model_prefix) + os.makedirs(output_dir, exist_ok=True) + + print(f"\nStarting SentencePiece Unigram tokenizer training with vocab size: {VOCAB_SIZE}") + print(f"Using combined dataset iterator.") + print(f"Output model prefix: {model_prefix}") + + iterators_to_chain = [] + def flatten_iterator(iterator): + for batch in iterator: + for item in batch: + yield item + + if datasets.get('wiki') and counts['wiki'][0] > 0: + iterators_to_chain.append(flatten_iterator(wiki_iterator(datasets['wiki']))) + if datasets.get('dd') and counts['dd'][0] > 0: + iterators_to_chain.append(flatten_iterator(dd_iterator(datasets['dd']))) + if datasets.get('bst') and counts['bst'][0] > 0: + iterators_to_chain.append(flatten_iterator(bst_iterator(datasets['bst']))) + if datasets.get('os') and counts['os'][0] > 0: + iterators_to_chain.append(flatten_iterator(os_iterator(datasets['os']))) + + if not iterators_to_chain: + print("Error: No valid dataset iterators available for training. Exiting.") + return + + combined_iterator = itertools.chain(*iterators_to_chain) + + # Include whitespace symbols so we can efficiently break lines. + # If we include the single space, it prevents the tokenizer from merging + # spaces with regular words, and tanks the average chars/token. + # This many tokens is kinda overkill, but it gives us a way to efficiently + # clear even fairly large boards, so I think it's worth. + whitespace_symbols = [] + for i in range(2, 40): + whitespace_symbols.append('▁' * i) + + spm.SentencePieceTrainer.train( + sentence_iterator=combined_iterator, + model_prefix=model_prefix, + vocab_size=VOCAB_SIZE, + model_type='unigram', + character_coverage=1.0, + unk_piece=UNK_TOKEN, + pad_piece=PAD_TOKEN, + control_symbols=CONTROL_SYMBOLS, + user_defined_symbols=whitespace_symbols, + # These whitespace options must be false, or else whitespace won't + # be respected when encoding. + add_dummy_prefix=False, + remove_extra_whitespaces=False, + split_by_whitespace=False, + num_threads=os.cpu_count(), + input_sentence_size=total_items, + ) + print("\nTraining finished.") + print(f"SentencePiece model saved to: {model_prefix}.model") + print(f"SentencePiece vocabulary saved to: {model_prefix}.vocab") + +def extract_text_samples(dataset, count, num_samples, text_extractor_func): + """Extracts a specified number of text samples from a dataset.""" + samples = [] + if dataset is None or count == 0 or num_samples == 0: + return samples + # Take samples from the beginning, ensure we don't exceed dataset size + actual_samples_to_take = min(num_samples, count) + # Use the provided function to extract text correctly for this dataset type + samples = text_extractor_func(dataset.select(range(actual_samples_to_take))) + return [unidecode(s) for s in samples if s and isinstance(s,str)] + +def wiki_text_extractor(dataset_slice): + """Extracts text from a slice of the Wikipedia dataset.""" + return [text for text in dataset_slice[TEXT_COLUMN_WIKI] if text] + +def dd_text_extractor(dataset_slice): + """Extracts text from a slice of the DailyDialog dataset.""" + texts = [] + for dialogue in dataset_slice: + texts.extend(utterance for utterance in dialogue[UTTERANCE_COLUMN_DD] if utterance) + return texts + +def bst_text_extractor(dataset_slice): + """Extracts text from a slice of the BlendedSkillTalk dataset.""" + texts = [] + for session in dataset_slice: + if session.get("previous_utterance"): + texts.append(session["previous_utterance"]) + if session.get("free_messages"): + texts.extend(session["free_messages"]) + if session.get("guided_messages"): + texts.extend(session["guided_messages"]) + return [t for t in texts if t and isinstance(t,str)] # Ensure only valid strings + +def os_text_extractor(dataset_slice, lang_code=TEXT_COLUMN_OS): + """Extracts text from a slice of the OpenSubtitles dataset.""" + texts = [] + for item in dataset_slice: + text = item['translation'][lang_code] + if text and isinstance(text, str): + texts.append(text) + return texts + +def test_tokenizer(model_path, datasets, counts, sample_size=1000): + print("\n--- Testing the trained Unigram (SentencePiece) tokenizer on data sample ---") + if sample_size <= 0: + print("Error: Sample size for testing must be positive.") + return + + sp = spm.SentencePieceProcessor() + sp.load(model_path) + print(f"Successfully loaded SentencePiece model from: {model_path}") + print(f"Vocabulary size: {sp.get_piece_size()}") + + # Identify available datasets with content + available_datasets = { + name: data + for name, data in datasets.items() + if counts[name][0] > 0 + } + num_available_datasets = len(available_datasets) + + if num_available_datasets == 0: + print("Warning: No data available in loaded datasets to sample for testing.") + return + + print(f"Found {num_available_datasets} non-empty dataset(s) for testing.") + print(f"Attempting to sample up to {sample_size} total items equally from these datasets...") + + # Calculate equal number of samples per available dataset + samples_per_dataset = sample_size // num_available_datasets + print(f"Targeting approximately {samples_per_dataset} samples per dataset.") + + test_samples = [] + dataset_extractors = { + 'wiki': wiki_text_extractor, + 'dd': dd_text_extractor, + 'bst': bst_text_extractor, + 'os': os_text_extractor, + } + + actual_samples_collected = {} + + # Sample equally from each available dataset + for name, dataset in available_datasets.items(): + count = counts[name][0] + extractor = dataset_extractors[name] + + # Determine the target number of samples for *this* dataset + num_samples_to_target = min(samples_per_dataset, count) + if num_samples_to_target <= 0: + print(f" Skipping '{name}' (no items requested or available).") + actual_samples_collected[name] = 0 + continue + + print(f" Sampling from '{name}' (target: {num_samples_to_target} items)...") + + items_before = len(test_samples) + # For OpenSubtitles, pass the lang_code if the extractor needs it + if name == 'os': + extracted_items = extract_text_samples(dataset, count, num_samples_to_target, lambda ds_slice: extractor(ds_slice, lang_code=TEXT_COLUMN_OS)) + else: + extracted_items = extract_text_samples(dataset, count, num_samples_to_target, extractor) + + # Limit the extracted items to the target number + final_samples_for_dataset = extracted_items[:num_samples_to_target] + + test_samples.extend(final_samples_for_dataset) + items_added = len(final_samples_for_dataset) + actual_samples_collected[name] = items_added + print(f" Added {items_added} items from '{name}'.") + + actual_sample_size = len(test_samples) + if actual_sample_size == 0: + print("\nCould not gather any samples for testing.") + print("--- Test finished ---") + return + + print(f"\n--- Starting Test Run ---") + print(f"Testing on {actual_sample_size} sampled text items (final count).") + print(f"Samples breakdown: {actual_samples_collected}") + + total_chars = 0 + total_tokens = 0 + examples_to_show = 5 + examples_shown = 0 + + random.shuffle(test_samples) + + for i, text_sample in enumerate(test_samples): + if not text_sample: # Should already be filtered by extractor, but double-check + continue + try: + tokens = sp.encode_as_pieces(text_sample) + num_tokens = len(tokens) + num_chars = len(text_sample) + total_tokens += num_tokens + total_chars += num_chars + + if examples_shown < examples_to_show: + print(f"\nSample {examples_shown + 1} (Overall index {i}):") # Clarify sample index + print(f" Text ({num_chars} chars): {text_sample[:100]}{'...' if len(text_sample)>100 else ''}") + print(f" Tokens ({num_tokens}): {tokens}") + examples_shown += 1 + except Exception as e: + print(f"\nWarning: Error encoding sample (Overall index {i}) with SentencePiece. Skipping sample. Error: {e}") + print(f"Problematic sample text: {text_sample[:100]}...") # Show problematic text + + # --- Test Summary --- + if total_tokens > 0 and total_chars > 0 : # Avoid division by zero + avg_chars_per_token = total_chars / total_tokens + print(f"\n--- Test Summary ---") + print(f"Tested on {actual_sample_size} text items.") + print(f"Final samples breakdown: {actual_samples_collected}") + print(f"Total characters in final sample: {total_chars:,}") + print(f"Total tokens in final sample: {total_tokens:,}") + print(f"Average characters per token: {avg_chars_per_token:.4f}") + elif actual_sample_size > 0: + print("\n--- Test Summary ---") + print(f"Tested on {actual_sample_size} text items.") + print(f"Final samples breakdown: {actual_samples_collected}") + print("No valid tokens were generated from the sample data (or samples were empty after unidecode).") + else: + # This case should be caught earlier, but included for completeness + print("\n--- Test Summary ---") + print("No samples were collected or processed.") + + print("--- Test finished ---") + +def parse_args(): + parser = argparse.ArgumentParser(description="Train and test a Unigram tokenizer on combined datasets.") + parser.add_argument("--train", action="store_true", help="Train the tokenizer on datasets") + parser.add_argument("--test", action="store_true", help="Test the trained tokenizer") + parser.add_argument("--sample_size", type=int, default=1000, help="Number of items to sample for testing") + parser.add_argument("--wiki_fraction", type=float, default=0.05, help="Fraction of Wikipedia dataset to use (e.g., 0.05 for 5%)") + parser.add_argument("--subtitles_fraction", type=float, + default=0.05, help="Fraction of opensubtitles dataset to use (e.g., 0.05 for 5%)") + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + + # Load datasets and calculate counts once + print("--- Loading and Counting Datasets ---") + datasets, counts = load_and_count_datasets(wiki_fraction=args.wiki_fraction, subtitles_fraction=args.subtitles_fraction) + print("--- Dataset Loading Finished ---") + + run_train = args.train + run_test = args.test + + # If no arguments provided, run both by default + if not args.train and not args.test: + run_train = True + run_test = True + + if run_train: + print("\n--- Starting Tokenizer Training ---") + train_tokenizer(MODEL_PREFIX, datasets, counts) + print("--- Training Finished ---") + + if run_test: + # Ensure tokenizer file exists if training didn't just run + if not os.path.exists(MODEL_FILE): + print(f"\nError: SentencePiece model file {MODEL_FILE} not found. Cannot run test.") + print("Please run with --train first or ensure the file exists.") + else: + print("\n--- Starting Tokenizer Testing ---") + test_tokenizer(MODEL_FILE, datasets, counts, sample_size=args.sample_size) + print("--- Testing Finished ---") + + print("\nScript finished.") @@ -0,0 +1,357 @@ +import argparse +from math import floor, ceil +import msvcrt +from pythonosc import udp_client +import sentencepiece as spm +import sys +import threading +import time + +TESTS_ENABLED = True + +# 0 = quiet, 1 = verbose, 2 = very verbose +LOG_LEVEL = 0 + +def get_tokenizer(): + model_path = "./custom_unigram_tokenizer_65k/unigram.model" + print(f"Loading SentencePiece tokenizer from: {model_path}") + sp = spm.SentencePieceProcessor() + sp.load(model_path) + print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}") + return sp + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--block-width", type=int, default=2, help="Number of elements sent in one block") + parser.add_argument("--num-blocks", type=int, default=40, help="Number of blocks in animator") + parser.add_argument("--rows", type=int, default=8, help="Number of rows in the chatbox") + parser.add_argument("--cols", type=int, default=20, help="Number of columns in the chatbox") + return parser.parse_args() + +def assert_equal(a, b): + err_msg = f"{a} != {b}" + assert a == b, err_msg + +# Turn a whitespace-delimited string into a list of strings no longer than +# `cols`. +# Preferentially breaks strings at whitespace boundaries. Preserves whitespace +# between words, except if that whitespace comes between lines. Breaks words +# longer than `cols` with a hyphen. +def wrap_line(line: str, cols): + # First, split line into alternating chunks of words and whitespace. + def get_sequences(line): + is_space = False + sequences = [] + seq_start = 0 + seq_end = -1 + for i in range(0, len(line)): + if line[i].isspace(): + if is_space: + seq_end = i + continue + # We were looking at text, now we see whitespace. + seq = line[seq_start:seq_end+1] + if len(seq) > 0: + sequences.append(seq) + seq_start = i + seq_end = i + is_space = True + else: + if not is_space: + seq_end = i + continue + # We were looking at whitespace, now we see text. + seq = line[seq_start:seq_end+1] + if len(seq) > 0: + sequences.append(seq) + seq_start = i + seq_end = i + is_space = False + sequences.append(line[seq_start:seq_end+1]) + return sequences + if TESTS_ENABLED: + assert_equal(get_sequences("foo"), ["foo"]) + assert_equal(get_sequences("foo bar"), ["foo", " ", "bar"]) + assert_equal(get_sequences(" foo bar"), [" ", "foo", " ", "bar"]) + assert_equal(get_sequences(" foo bar"), [" ", "foo", " ", "bar"]) + + # Next, greedily construct lines out of those sequences. + # Whitespace gets treated specially. If it would push us over the limit, we + # end the line and drop the whitespace. + sequences = get_sequences(line) + def coalesce_sequences(sequences, cols): + cur_line = "" + lines = [] + for seq in sequences: + if len(cur_line) + len(seq) <= cols: + cur_line += seq + continue + if seq.isspace(): + lines.append(cur_line) + cur_line = "" + continue + if len(cur_line) > 0: + lines.append(cur_line) + # Edge case: text sequence is longer than a line. + while len(seq) > cols: + seq_prefix = seq[0:cols-1] + "-" + seq = seq[cols-1:] + lines.append(seq_prefix) + cur_line = seq + if len(cur_line) > 0: + lines.append(cur_line) + return lines + if TESTS_ENABLED: + assert_equal(coalesce_sequences(get_sequences("foo bar"), 3), ["foo", "bar"]) + assert_equal(coalesce_sequences(get_sequences("foo bar"), 4), ["foo ", "bar"]) + assert_equal(coalesce_sequences(get_sequences("foo bar"), 4), ["foo", "bar"]) + assert_equal(coalesce_sequences(get_sequences("foobar"), 3), ["fo-", "ob-", "ar"]) + assert_equal(coalesce_sequences(get_sequences("f obar"), 3), ["f ", "ob-", "ar"]) + + lines = coalesce_sequences(sequences, cols) + + # Next, pad each line with whitespace. + def pad_lines(lines, cols): + for i in range(0, len(lines)): + lines[i] += ' ' * (cols - len(lines[i])) + return lines + if TESTS_ENABLED: + assert_equal(pad_lines(["foo", "ba"], 4), ["foo ", "ba "]) + assert_equal(pad_lines(["foo"], 2), ["foo"]) + + return pad_lines(lines, cols) + +def get_blocks(lines, tokenizer, block_width, num_blocks): + if LOG_LEVEL == 2: + print(f"Lines sent to tokenizer: {''.join(lines)}") + tokens = tokenizer.encode_as_ids(''.join(lines)) + if LOG_LEVEL == 2: + print(f"Tokens: {tokens}") + pieces = [] + for tok in tokens: + piece = tokenizer.id_to_piece(tok) + pieces.append(piece) + if LOG_LEVEL == 2: + print(f"Pieces: {pieces}") + + # Group tokens into blocks and pad with empty characters. + # Also get visual pointers - the location where each block will be rendered. + def get_blocks(): + blocks = [] + visual_pointer = 0 + visual_pointers = [] + for i in range(0, ceil(len(tokens) / block_width)): + visual_pointers.append(visual_pointer) + block = [] + for j in range(0, block_width): + if i*block_width + j >= len(tokens): + # Pad block with empty characters. 65535 is a special token. + block += [65535] * (block_width - len(block)) + break + block.append(tokens[i*block_width+j]) + visual_pointer += len(pieces[i*block_width+j]) + blocks.append(block) + return (blocks, visual_pointers) + blocks, visual_pointers = get_blocks() + if LOG_LEVEL == 2: + print(f"Blocks: {blocks}") + print(f"Visual pointers: {visual_pointers}") + + # Set all blocks up to the next `num_blocks` boundary to blank tokens. + # This handles the edge case where a prior message wrote data there which + # is covering up our new data. + def pad_blocks(blocks, visual_pointers): + cur_num_blocks = len(blocks) + num_pad_blocks = num_blocks - cur_num_blocks + for i in range(0, num_pad_blocks): + blocks.append([65535] * block_width) + visual_pointers.append(255) + return blocks, visual_pointers + blocks, visual_pointers = pad_blocks(blocks, visual_pointers) + if LOG_LEVEL == 2: + print(f"Blocks (padded): {blocks}") + print(f"Visual pointers (padded): {visual_pointers}") + + return blocks, visual_pointers + +def calc_diff(prev_blocks, prev_visual_pointers, cur_blocks, + cur_visual_pointers): + diff_indices = [] + diff_blocks = [] + diff_visual_pointers = [] + + for i in range(0, len(cur_blocks)): + if i >= len(prev_blocks): + diff_blocks.append(cur_blocks[i]) + diff_visual_pointers.append(cur_visual_pointers[i]) + diff_indices.append(i) + continue + if prev_blocks[i] != cur_blocks[i] or prev_visual_pointers[i] != cur_visual_pointers[i]: + diff_blocks.append(cur_blocks[i]) + diff_visual_pointers.append(cur_visual_pointers[i]) + diff_indices.append(i) + + return diff_indices, diff_blocks, diff_visual_pointers + +def send_data(osc_client, indices, blocks, visual_pointers): + def split_blocks_by_byte(blocks): + blocks_byte00 = [] + blocks_byte01 = [] + for block in blocks: + block_byte00 = [] + block_byte01 = [] + for datum in block: + block_byte00.append((datum >> 0) & 0xFF) + block_byte01.append((datum >> 8) & 0xFF) + blocks_byte00.append(block_byte00) + blocks_byte01.append(block_byte01) + return blocks_byte00, blocks_byte01 + + blocks_byte00, blocks_byte01 = split_blocks_by_byte(blocks) + if LOG_LEVEL == 2: + print(f"Blocks (byte 00): {blocks_byte00}") + print(f"Blocks (byte 01): {blocks_byte01}") + + def send_osc(osc_client, addr, data): + #print(f"Sending {data} to {addr}") + osc_client.send_message(addr, data) + + for i in range(0, len(blocks)): + lp_int = indices[i] + lp_param = "_Unigram_Letter_Grid_OSC_Pointer" + addr = "/avatar/parameters/" + lp_param + send_osc(osc_client, addr, lp_int) + + vp_float = (-127.5 + visual_pointers[i]) / 127.5 + vp_param = f"_Unigram_Letter_Grid_OSC_Visual_Pointer" + addr = "/avatar/parameters/" + vp_param + send_osc(osc_client, addr, vp_float) + if LOG_LEVEL == 2: + print(f"Sending block {blocks[i]} at {visual_pointers[i]} index {indices[i]}") + for j in range(0, len(blocks[i])): + byte00_float = (-127.5 + blocks_byte00[i][j]) / 127.5 + byte01_float = (-127.5 + blocks_byte01[i][j]) / 127.5 + byte00_param = f"_Unigram_Letter_Grid_OSC_Datum{j:02}_Byte00" + byte01_param = f"_Unigram_Letter_Grid_OSC_Datum{j:02}_Byte01" + addr = "/avatar/parameters/" + byte00_param + send_osc(osc_client, addr, byte00_float) + addr = "/avatar/parameters/" + byte01_param + send_osc(osc_client, addr, byte01_float) + time.sleep(0.34) + +def getOscClient(ip = "127.0.0.1", port = 9000): + return udp_client.SimpleUDPClient(ip, port) + +class InputState: + def __init__(self): + # Initialize the known state of the board to empty array. This will cause + # our paging logic to re-send everything the first time around. + self.blocks = [] + self.visual_pointers = [] + pass + +def handle_input(state: InputState, line: str, tokenizer, osc_client, args): + line_wrapped = wrap_line(line, args.cols) + if TESTS_ENABLED: + for line in line_wrapped: + assert_equal(len(line), args.cols) + if LOG_LEVEL == 2: + print(f"Wrapped lines: {line_wrapped}") + blocks, visual_pointers = get_blocks(line_wrapped, tokenizer, + args.block_width, args.num_blocks) + # Wrap visual pointers. This is mostly done just to stay within our + # limited 8-bit precision budget. The shader also has to wrap in order + # to properly display things. + visual_pointers = [ ptr % (args.rows * args.cols) for ptr in visual_pointers ] + indices, diff_blocks, diff_visual_pointers = calc_diff(state.blocks, state.visual_pointers, blocks, visual_pointers) + indices = [idx % args.num_blocks for idx in indices] + # Send only one block at a time to make things snappier in interactive use + # case. + # TODO use a continuation (yield) instead of returning. Then we can be a + # little lighter on the cpu. Measurements show that this script is + # already very light but we're clearly wasting a lot of work by + # re-tokenizing the entire input every time we send a block. + if len(indices) == 0: + return + if indices[0] == len(state.blocks): + state.blocks.append(diff_blocks[0]) + state.visual_pointers.append(diff_visual_pointers[0]) + elif indices[0] > len(state.blocks): + print(f"This should never happen!") + sys.exit(1) + else: + state.blocks[indices[0]] = diff_blocks[0] + state.visual_pointers[indices[0]] = diff_visual_pointers[0] + + send_data(osc_client, [indices[0]], [diff_blocks[0]], [diff_visual_pointers[0]]) + +class SharedThreadData: + def __init__(self): + self.word = "" + self.word_lock = threading.Lock() + self.exit_event = threading.Event() + +def osc_thread(shared_data: SharedThreadData, cli_args): + tokenizer = get_tokenizer() + osc_client = getOscClient() + + # Prime the board + print("Priming the board") + input_state = InputState() + handle_input(input_state, "", tokenizer, osc_client, cli_args) + + while not shared_data.exit_event.is_set(): + word_copy = "" + with shared_data.word_lock: + word_copy = shared_data.word + handle_input(input_state, word_copy, tokenizer, osc_client, cli_args) + time.sleep(0.01) + +if __name__ == "__main__": + cli_args = parse_args() + + shared_data = SharedThreadData() + osc_thread = threading.Thread( + target=osc_thread, + args=(shared_data, cli_args)) + osc_thread.start() + + word_is_over = False + local_word = "" + while True: + char_bytes = msvcrt.getch() + + try: + char = char_bytes.decode('utf-8') + if char == '\r' or char == '\n': + word_is_over = True + continue + except UnicodeDecodeError: + print(f"Unsupported character: {char_bytes}") + if char_bytes == b'\x00' or char_bytes == b'\xe0': + special_char = msvcrt.getch() + continue + + if char_bytes == b'\x03': # ctrl+C + break + elif char_bytes == b'\x08': # backspace + with shared_data.word_lock: + shared_data.word = shared_data.word[:-1] + local_word = shared_data.word + elif char_bytes == b'\x0c': # ctrl+L + with shared_data.word_lock: + shared_data.word = "" + local_word = shared_data.word + elif word_is_over: + with shared_data.word_lock: + shared_data.word = char + local_word = shared_data.word + word_is_over = False + else: + with shared_data.word_lock: + shared_data.word += char + local_word = shared_data.word + print(local_word + "_") + shared_data.exit_event.set() + osc_thread.join() + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..104e7b8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +datasets +pillow +python-osc +unidecode +sentencepiece + diff --git a/tokenize_me.py b/tokenize_me.py new file mode 100644 index 0000000..83be290 --- /dev/null +++ b/tokenize_me.py @@ -0,0 +1,28 @@ +import argparse +import sentencepiece as spm + +def get_tokenizer(): + model_path = "./custom_unigram_tokenizer_65k/unigram.model" + print(f"Loading SentencePiece tokenizer from: {model_path}") + sp = spm.SentencePieceProcessor() + sp.load(model_path) + print(f"Successfully loaded SentencePiece model. Vocab size: {sp.get_piece_size()}") + return sp + +def parse_args(): + parser = argparse.ArgumentParser(description="Tokenize a given string using a SentencePiece model.") + parser.add_argument("text", type=str, help="The string to tokenize.") + args = parser.parse_args() + return args + +args = parse_args() +tok = get_tokenizer() +tokens = tok.encode_as_pieces(args.text) +print("Tokens:", tokens) + +token_ids = tok.encode_as_ids(args.text) +print("Token IDs:", token_ids) + +# Split each token ID into two 8-bit chunks (high byte, low byte) +byte_pairs = [(tid >> 8, tid & 0xFF) for tid in token_ids] +print("Token ID Byte Pairs:", byte_pairs) |
