summaryrefslogtreecommitdiffstats
path: root/generate_tokenizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'generate_tokenizer.py')
-rw-r--r--generate_tokenizer.py525
1 files changed, 525 insertions, 0 deletions
diff --git a/generate_tokenizer.py b/generate_tokenizer.py
new file mode 100644
index 0000000..88903a8
--- /dev/null
+++ b/generate_tokenizer.py
@@ -0,0 +1,525 @@
+# !!! AI ARTIFACT !!!
+# This file was primarily written with AI.
+
+import os
+import argparse
+from datasets import load_dataset, interleave_datasets
+import sentencepiece as spm
+import itertools
+import random
+from unidecode import unidecode
+
+# --- Dataset 1: Wikipedia ---
+DATASET_NAME_WIKI = "wikipedia"
+DATASET_CONFIG_WIKI = "20220301.en"
+
+DATASET_SPLIT_WIKI = "train"
+TEXT_COLUMN_WIKI = "text"
+
+# --- Dataset 2: DailyDialog ---
+DATASET_NAME_DD = "daily_dialog"
+DATASET_SPLIT_DD = "train"
+UTTERANCE_COLUMN_DD = "dialog"
+
+# --- Dataset 3: BlendedSkillTalk (includes Persona-Chat) ---
+DATASET_NAME_BST = "blended_skill_talk"
+DATASET_SPLIT_BST = "train"
+
+# --- Dataset 4: OpenSubtitles ---
+DATASET_NAME_OS = "Helsinki-NLP/open_subtitles"
+DATASET_LANG_PAIR_OS = ("en", "fr") # Load en-fr and use the 'en' part
+DATASET_SPLIT_OS = "train"
+TEXT_COLUMN_OS = "en" # Access via item['translation']['en']
+
+# Reserve one space for special "empty" token.
+VOCAB_SIZE = 65535
+OUTPUT_DIR = "./custom_unigram_tokenizer_65k"
+MODEL_PREFIX = os.path.join(OUTPUT_DIR, "unigram")
+MODEL_FILE = MODEL_PREFIX + ".model"
+VOCAB_FILE = MODEL_PREFIX + ".vocab"
+
+UNK_TOKEN = "[UNK]"
+PAD_TOKEN = "[PAD]"
+CONTROL_SYMBOLS = ["[CLS]", "[SEP]", "[MASK]"]
+
+BATCH_SIZE = 1000
+
+def wiki_iterator(dataset_wiki, batch_size=BATCH_SIZE):
+ if dataset_wiki:
+ for i in range(0, len(dataset_wiki), batch_size):
+ yield [unidecode(text) for text in dataset_wiki[i : i + batch_size][TEXT_COLUMN_WIKI] if text]
+
+def dd_iterator(dataset_dd, batch_size=BATCH_SIZE):
+ if dataset_dd:
+ current_batch = []
+ for dialogue in dataset_dd:
+ utterances = dialogue[UTTERANCE_COLUMN_DD]
+ for utterance in utterances:
+ if utterance:
+ normalized_utterance = unidecode(utterance)
+ current_batch.append(normalized_utterance)
+ if len(current_batch) == batch_size:
+ yield current_batch
+ current_batch = []
+ if current_batch:
+ yield current_batch
+
+def bst_iterator(dataset_bst, batch_size=BATCH_SIZE):
+ if dataset_bst:
+ current_batch = []
+ for session in dataset_bst:
+ texts_to_add = []
+ if session.get("previous_utterance"):
+ texts_to_add.append(session["previous_utterance"])
+ if session.get("free_messages"):
+ texts_to_add.extend(session["free_messages"])
+ if session.get("guided_messages"):
+ texts_to_add.extend(session["guided_messages"])
+
+ for text in texts_to_add:
+ if text and isinstance(text, str):
+ normalized_text = unidecode(text)
+ current_batch.append(normalized_text)
+ if len(current_batch) == batch_size:
+ yield current_batch
+ current_batch = []
+ if current_batch:
+ yield current_batch
+
+def os_iterator(dataset_os, batch_size=BATCH_SIZE, lang_code=TEXT_COLUMN_OS):
+ if dataset_os:
+ current_batch = []
+ for item in dataset_os:
+ text = item['translation'][lang_code]
+ if text:
+ normalized_text = unidecode(text)
+ current_batch.append(normalized_text)
+ if len(current_batch) == batch_size:
+ yield current_batch
+ current_batch = []
+ if current_batch:
+ yield current_batch
+
+def count_wiki_items_chars(dataset_wiki):
+ item_count = 0
+ char_count = 0
+ if dataset_wiki:
+ item_count = len(dataset_wiki)
+ try:
+ char_count = sum(len(unidecode(text)) for text in dataset_wiki[TEXT_COLUMN_WIKI] if text and isinstance(text, str))
+ except Exception:
+ print(f"Warning: Direct column access for char count failed for {DATASET_NAME_WIKI}. Iterating row by row (slower).")
+ char_count = sum(len(unidecode(row[TEXT_COLUMN_WIKI])) for row in dataset_wiki if row.get(TEXT_COLUMN_WIKI) and isinstance(row[TEXT_COLUMN_WIKI], str))
+
+ return item_count, char_count
+
+def count_dd_items_chars(dataset_dd):
+ item_count = 0
+ char_count = 0
+ if dataset_dd:
+ for dialogue in dataset_dd:
+ utterances = dialogue[UTTERANCE_COLUMN_DD]
+ for utterance in utterances:
+ if utterance and isinstance(utterance, str):
+ item_count += 1
+ char_count += len(unidecode(utterance))
+ return item_count, char_count
+
+def count_bst_items_chars(dataset_bst):
+ item_count = 0
+ char_count = 0
+ if dataset_bst:
+ for session in dataset_bst:
+ texts_to_process = []
+ # Gather all potential text strings first
+ prev_utt = session.get("previous_utterance")
+ if prev_utt and isinstance(prev_utt, list):
+ if len(prev_utt) > 0 and isinstance(prev_utt[0], str):
+ texts_to_process.append(prev_utt[0])
+ elif prev_utt and isinstance(prev_utt, str):
+ texts_to_process.append(prev_utt)
+
+ free_msgs = session.get("free_messages")
+ if free_msgs and isinstance(free_msgs, list):
+ for item in free_msgs:
+ if isinstance(item, list):
+ texts_to_process.extend(msg for msg in item if msg and isinstance(msg, str))
+ elif isinstance(item, str):
+ texts_to_process.append(item)
+
+ guided_msgs = session.get("guided_messages")
+ if guided_msgs and isinstance(guided_msgs, list):
+ for item in guided_msgs:
+ if isinstance(item, list):
+ texts_to_process.extend(msg for msg in item if msg and isinstance(msg, str))
+ elif isinstance(item, str):
+ texts_to_process.append(item)
+
+ # Count items and chars from the gathered list
+ for text in texts_to_process:
+ if text and isinstance(text, str):
+ normalized_text = unidecode(text)
+ if normalized_text:
+ item_count += 1
+ char_count += len(normalized_text)
+
+ return item_count, char_count
+
+def count_os_items_chars(dataset_os, lang_code=TEXT_COLUMN_OS):
+ item_count = 0
+ char_count = 0
+ if dataset_os:
+ for item in dataset_os:
+ text = item['translation'][lang_code]
+ if text and isinstance(text, str):
+ normalized_text = unidecode(text)
+ if normalized_text: # Ensure not empty after unidecode
+ item_count += 1
+ char_count += len(normalized_text)
+ return item_count, char_count
+
+def load_and_count_datasets(wiki_fraction, subtitles_fraction):
+ """Loads, potentially shrinks, and counts items/chars in datasets."""
+ datasets = {}
+ counts = {}
+ total_items = 0
+ total_chars = 0
+
+ # --- Wikipedia ---
+ print(f"Loading dataset 1: {DATASET_NAME_WIKI} ({DATASET_CONFIG_WIKI}), split: {DATASET_SPLIT_WIKI}")
+ dataset_wiki_full = load_dataset(DATASET_NAME_WIKI, DATASET_CONFIG_WIKI, split=DATASET_SPLIT_WIKI, trust_remote_code=True)
+ print(f"Original Wikipedia dataset size: {len(dataset_wiki_full):,}")
+
+ # Shrink the Wikipedia dataset
+ split_test_size = 1.0 - wiki_fraction
+ shrunk_dataset_split = dataset_wiki_full.train_test_split(test_size=split_test_size, seed=random.randint(0, 1000000))
+ dataset_wiki = shrunk_dataset_split['train']
+ print(f"Using {wiki_fraction*100:.3f}% of Wikipedia dataset: {len(dataset_wiki):,} items")
+
+ count_wiki, chars_wiki = count_wiki_items_chars(dataset_wiki)
+ print(f"Wikipedia dataset loaded (shrunk). Precise text items: {count_wiki:,}, Characters: {chars_wiki:,}")
+ datasets['wiki'] = dataset_wiki
+ counts['wiki'] = (count_wiki, chars_wiki)
+ total_items += count_wiki
+ total_chars += chars_wiki
+
+ # --- DailyDialog ---
+ print(f"\nLoading dataset 2: {DATASET_NAME_DD}, split: {DATASET_SPLIT_DD}")
+ dataset_dd = load_dataset(DATASET_NAME_DD, split=DATASET_SPLIT_DD, trust_remote_code=True)
+ count_dd, chars_dd = count_dd_items_chars(dataset_dd)
+ print(f"DailyDialog dataset loaded. Precise text items (utterances): {count_dd:,}, Characters: {chars_dd:,}")
+ datasets['dd'] = dataset_dd
+ counts['dd'] = (count_dd, chars_dd)
+ total_items += count_dd
+ total_chars += chars_dd
+
+ # --- BlendedSkillTalk ---
+ print(f"\nLoading dataset 3: {DATASET_NAME_BST}, split: {DATASET_SPLIT_BST}")
+ dataset_bst = load_dataset(DATASET_NAME_BST, split=DATASET_SPLIT_BST, trust_remote_code=True)
+ count_bst, chars_bst = count_bst_items_chars(dataset_bst)
+ print(f"BlendedSkillTalk dataset loaded. Precise text items (extracted): {count_bst:,}, Characters: {chars_bst:,}")
+ datasets['bst'] = dataset_bst
+ counts['bst'] = (count_bst, chars_bst)
+ total_items += count_bst
+ total_chars += chars_bst
+
+ # --- OpenSubtitles ---
+ print(f"\nLoading dataset 4: {DATASET_NAME_OS}, lang_pair: {DATASET_LANG_PAIR_OS}, split: {DATASET_SPLIT_OS}")
+ # Note: OpenSubtitles can be very large. Consider streaming or specific configurations if memory is an issue.
+ # For now, loading a standard configuration.
+ dataset_os = load_dataset(DATASET_NAME_OS, lang1=DATASET_LANG_PAIR_OS[0], lang2=DATASET_LANG_PAIR_OS[1], split=DATASET_SPLIT_OS, trust_remote_code=True)
+ split_test_size = 1.0 - subtitles_fraction
+ dataset_os = dataset_os.train_test_split(test_size=split_test_size, seed=random.randint(0, 1000000))['train']
+ count_os, chars_os = count_os_items_chars(dataset_os, lang_code=TEXT_COLUMN_OS)
+ print(f"OpenSubtitles dataset loaded. Precise text items: {count_os:,}, Characters: {chars_os:,}")
+ datasets['os'] = dataset_os
+ counts['os'] = (count_os, chars_os)
+ total_items += count_os
+ total_chars += chars_os
+
+ print(f"\nTotal precise text items from loaded datasets: {total_items:,}")
+ print(f"Total characters from loaded datasets: {total_chars:,}")
+
+ return datasets, counts
+
+def train_tokenizer(model_prefix, datasets, counts):
+ """Trains a Unigram tokenizer using SentencePiece and saves it."""
+ total_items = sum(c[0] for c in counts.values())
+ total_chars = sum(c[1] for c in counts.values())
+ if total_items == 0:
+ print("Error: No items found in the datasets to train on. Exiting training.")
+ return
+
+ print(f"\nTotal precise text items for training: {total_items:,}")
+ print(f"Total characters for training: {total_chars:,}")
+
+ output_dir = os.path.dirname(model_prefix)
+ os.makedirs(output_dir, exist_ok=True)
+
+ print(f"\nStarting SentencePiece Unigram tokenizer training with vocab size: {VOCAB_SIZE}")
+ print(f"Using combined dataset iterator.")
+ print(f"Output model prefix: {model_prefix}")
+
+ iterators_to_chain = []
+ def flatten_iterator(iterator):
+ for batch in iterator:
+ for item in batch:
+ yield item
+
+ if datasets.get('wiki') and counts['wiki'][0] > 0:
+ iterators_to_chain.append(flatten_iterator(wiki_iterator(datasets['wiki'])))
+ if datasets.get('dd') and counts['dd'][0] > 0:
+ iterators_to_chain.append(flatten_iterator(dd_iterator(datasets['dd'])))
+ if datasets.get('bst') and counts['bst'][0] > 0:
+ iterators_to_chain.append(flatten_iterator(bst_iterator(datasets['bst'])))
+ if datasets.get('os') and counts['os'][0] > 0:
+ iterators_to_chain.append(flatten_iterator(os_iterator(datasets['os'])))
+
+ if not iterators_to_chain:
+ print("Error: No valid dataset iterators available for training. Exiting.")
+ return
+
+ combined_iterator = itertools.chain(*iterators_to_chain)
+
+ # Include whitespace symbols so we can efficiently break lines.
+ # If we include the single space, it prevents the tokenizer from merging
+ # spaces with regular words, and tanks the average chars/token.
+ # This many tokens is kinda overkill, but it gives us a way to efficiently
+ # clear even fairly large boards, so I think it's worth.
+ whitespace_symbols = []
+ for i in range(2, 40):
+ whitespace_symbols.append('▁' * i)
+
+ spm.SentencePieceTrainer.train(
+ sentence_iterator=combined_iterator,
+ model_prefix=model_prefix,
+ vocab_size=VOCAB_SIZE,
+ model_type='unigram',
+ character_coverage=1.0,
+ unk_piece=UNK_TOKEN,
+ pad_piece=PAD_TOKEN,
+ control_symbols=CONTROL_SYMBOLS,
+ user_defined_symbols=whitespace_symbols,
+ # These whitespace options must be false, or else whitespace won't
+ # be respected when encoding.
+ add_dummy_prefix=False,
+ remove_extra_whitespaces=False,
+ split_by_whitespace=False,
+ num_threads=os.cpu_count(),
+ input_sentence_size=total_items,
+ )
+ print("\nTraining finished.")
+ print(f"SentencePiece model saved to: {model_prefix}.model")
+ print(f"SentencePiece vocabulary saved to: {model_prefix}.vocab")
+
+def extract_text_samples(dataset, count, num_samples, text_extractor_func):
+ """Extracts a specified number of text samples from a dataset."""
+ samples = []
+ if dataset is None or count == 0 or num_samples == 0:
+ return samples
+ # Take samples from the beginning, ensure we don't exceed dataset size
+ actual_samples_to_take = min(num_samples, count)
+ # Use the provided function to extract text correctly for this dataset type
+ samples = text_extractor_func(dataset.select(range(actual_samples_to_take)))
+ return [unidecode(s) for s in samples if s and isinstance(s,str)]
+
+def wiki_text_extractor(dataset_slice):
+ """Extracts text from a slice of the Wikipedia dataset."""
+ return [text for text in dataset_slice[TEXT_COLUMN_WIKI] if text]
+
+def dd_text_extractor(dataset_slice):
+ """Extracts text from a slice of the DailyDialog dataset."""
+ texts = []
+ for dialogue in dataset_slice:
+ texts.extend(utterance for utterance in dialogue[UTTERANCE_COLUMN_DD] if utterance)
+ return texts
+
+def bst_text_extractor(dataset_slice):
+ """Extracts text from a slice of the BlendedSkillTalk dataset."""
+ texts = []
+ for session in dataset_slice:
+ if session.get("previous_utterance"):
+ texts.append(session["previous_utterance"])
+ if session.get("free_messages"):
+ texts.extend(session["free_messages"])
+ if session.get("guided_messages"):
+ texts.extend(session["guided_messages"])
+ return [t for t in texts if t and isinstance(t,str)] # Ensure only valid strings
+
+def os_text_extractor(dataset_slice, lang_code=TEXT_COLUMN_OS):
+ """Extracts text from a slice of the OpenSubtitles dataset."""
+ texts = []
+ for item in dataset_slice:
+ text = item['translation'][lang_code]
+ if text and isinstance(text, str):
+ texts.append(text)
+ return texts
+
+def test_tokenizer(model_path, datasets, counts, sample_size=1000):
+ print("\n--- Testing the trained Unigram (SentencePiece) tokenizer on data sample ---")
+ if sample_size <= 0:
+ print("Error: Sample size for testing must be positive.")
+ return
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(model_path)
+ print(f"Successfully loaded SentencePiece model from: {model_path}")
+ print(f"Vocabulary size: {sp.get_piece_size()}")
+
+ # Identify available datasets with content
+ available_datasets = {
+ name: data
+ for name, data in datasets.items()
+ if counts[name][0] > 0
+ }
+ num_available_datasets = len(available_datasets)
+
+ if num_available_datasets == 0:
+ print("Warning: No data available in loaded datasets to sample for testing.")
+ return
+
+ print(f"Found {num_available_datasets} non-empty dataset(s) for testing.")
+ print(f"Attempting to sample up to {sample_size} total items equally from these datasets...")
+
+ # Calculate equal number of samples per available dataset
+ samples_per_dataset = sample_size // num_available_datasets
+ print(f"Targeting approximately {samples_per_dataset} samples per dataset.")
+
+ test_samples = []
+ dataset_extractors = {
+ 'wiki': wiki_text_extractor,
+ 'dd': dd_text_extractor,
+ 'bst': bst_text_extractor,
+ 'os': os_text_extractor,
+ }
+
+ actual_samples_collected = {}
+
+ # Sample equally from each available dataset
+ for name, dataset in available_datasets.items():
+ count = counts[name][0]
+ extractor = dataset_extractors[name]
+
+ # Determine the target number of samples for *this* dataset
+ num_samples_to_target = min(samples_per_dataset, count)
+ if num_samples_to_target <= 0:
+ print(f" Skipping '{name}' (no items requested or available).")
+ actual_samples_collected[name] = 0
+ continue
+
+ print(f" Sampling from '{name}' (target: {num_samples_to_target} items)...")
+
+ items_before = len(test_samples)
+ # For OpenSubtitles, pass the lang_code if the extractor needs it
+ if name == 'os':
+ extracted_items = extract_text_samples(dataset, count, num_samples_to_target, lambda ds_slice: extractor(ds_slice, lang_code=TEXT_COLUMN_OS))
+ else:
+ extracted_items = extract_text_samples(dataset, count, num_samples_to_target, extractor)
+
+ # Limit the extracted items to the target number
+ final_samples_for_dataset = extracted_items[:num_samples_to_target]
+
+ test_samples.extend(final_samples_for_dataset)
+ items_added = len(final_samples_for_dataset)
+ actual_samples_collected[name] = items_added
+ print(f" Added {items_added} items from '{name}'.")
+
+ actual_sample_size = len(test_samples)
+ if actual_sample_size == 0:
+ print("\nCould not gather any samples for testing.")
+ print("--- Test finished ---")
+ return
+
+ print(f"\n--- Starting Test Run ---")
+ print(f"Testing on {actual_sample_size} sampled text items (final count).")
+ print(f"Samples breakdown: {actual_samples_collected}")
+
+ total_chars = 0
+ total_tokens = 0
+ examples_to_show = 5
+ examples_shown = 0
+
+ random.shuffle(test_samples)
+
+ for i, text_sample in enumerate(test_samples):
+ if not text_sample: # Should already be filtered by extractor, but double-check
+ continue
+ try:
+ tokens = sp.encode_as_pieces(text_sample)
+ num_tokens = len(tokens)
+ num_chars = len(text_sample)
+ total_tokens += num_tokens
+ total_chars += num_chars
+
+ if examples_shown < examples_to_show:
+ print(f"\nSample {examples_shown + 1} (Overall index {i}):") # Clarify sample index
+ print(f" Text ({num_chars} chars): {text_sample[:100]}{'...' if len(text_sample)>100 else ''}")
+ print(f" Tokens ({num_tokens}): {tokens}")
+ examples_shown += 1
+ except Exception as e:
+ print(f"\nWarning: Error encoding sample (Overall index {i}) with SentencePiece. Skipping sample. Error: {e}")
+ print(f"Problematic sample text: {text_sample[:100]}...") # Show problematic text
+
+ # --- Test Summary ---
+ if total_tokens > 0 and total_chars > 0 : # Avoid division by zero
+ avg_chars_per_token = total_chars / total_tokens
+ print(f"\n--- Test Summary ---")
+ print(f"Tested on {actual_sample_size} text items.")
+ print(f"Final samples breakdown: {actual_samples_collected}")
+ print(f"Total characters in final sample: {total_chars:,}")
+ print(f"Total tokens in final sample: {total_tokens:,}")
+ print(f"Average characters per token: {avg_chars_per_token:.4f}")
+ elif actual_sample_size > 0:
+ print("\n--- Test Summary ---")
+ print(f"Tested on {actual_sample_size} text items.")
+ print(f"Final samples breakdown: {actual_samples_collected}")
+ print("No valid tokens were generated from the sample data (or samples were empty after unidecode).")
+ else:
+ # This case should be caught earlier, but included for completeness
+ print("\n--- Test Summary ---")
+ print("No samples were collected or processed.")
+
+ print("--- Test finished ---")
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Train and test a Unigram tokenizer on combined datasets.")
+ parser.add_argument("--train", action="store_true", help="Train the tokenizer on datasets")
+ parser.add_argument("--test", action="store_true", help="Test the trained tokenizer")
+ parser.add_argument("--sample_size", type=int, default=1000, help="Number of items to sample for testing")
+ parser.add_argument("--wiki_fraction", type=float, default=0.05, help="Fraction of Wikipedia dataset to use (e.g., 0.05 for 5%)")
+ parser.add_argument("--subtitles_fraction", type=float,
+ default=0.05, help="Fraction of opensubtitles dataset to use (e.g., 0.05 for 5%)")
+ return parser.parse_args()
+
+if __name__ == "__main__":
+ args = parse_args()
+
+ # Load datasets and calculate counts once
+ print("--- Loading and Counting Datasets ---")
+ datasets, counts = load_and_count_datasets(wiki_fraction=args.wiki_fraction, subtitles_fraction=args.subtitles_fraction)
+ print("--- Dataset Loading Finished ---")
+
+ run_train = args.train
+ run_test = args.test
+
+ # If no arguments provided, run both by default
+ if not args.train and not args.test:
+ run_train = True
+ run_test = True
+
+ if run_train:
+ print("\n--- Starting Tokenizer Training ---")
+ train_tokenizer(MODEL_PREFIX, datasets, counts)
+ print("--- Training Finished ---")
+
+ if run_test:
+ # Ensure tokenizer file exists if training didn't just run
+ if not os.path.exists(MODEL_FILE):
+ print(f"\nError: SentencePiece model file {MODEL_FILE} not found. Cannot run test.")
+ print("Please run with --train first or ensure the file exists.")
+ else:
+ print("\n--- Starting Tokenizer Testing ---")
+ test_tokenizer(MODEL_FILE, datasets, counts, sample_size=args.sample_size)
+ print("--- Testing Finished ---")
+
+ print("\nScript finished.")