import random
import re
from collections import Counter
import lorem
from tqdm import tqdmByte-pair encoding from scratch
# Seed Python's global RNG so re-running the notebook produces the same corpus
# (presumably lorem.text() below draws from the `random` module — TODO confirm).
random.seed(42)In byte-pair encoding we build a vocabulary by iteratively merging the most frequent pair of adjacent tokens.
- Start with the individual bytes of the UTF-8-encoded text as tokens.
- Find the most frequent pair of adjacent tokens.
- Merge that pair into a new token.
- Repeat until we reach the desired size of our vocabulary.
We start off by creating a small corpus of text.
# Training corpus: a few paragraphs of lorem-ipsum placeholder text.
text: str = lorem.text()
print(text)Eius eius dolor tempora consectetur sed. Amet quaerat modi aliquam adipisci amet dolorem eius. Quiquia adipisci porro dolorem sit quisquam sit porro. Eius neque quaerat est velit adipisci ut. Sit modi ipsum est dolor.
Consectetur amet magnam consectetur labore labore. Est velit aliquam tempora neque porro consectetur magnam. Porro etincidunt voluptatem quisquam. Labore quaerat dolorem sit amet aliquam sed eius. Amet eius consectetur magnam est neque. Dolore labore labore dolorem sed est.
Quiquia quisquam dolore porro. Dolore neque magnam est quisquam. Eius sed ipsum voluptatem ut ut aliquam eius. Velit ipsum magnam est. Dolorem quaerat sit ipsum. Quisquam non magnam quisquam neque. Est dolor eius tempora porro. Est tempora quaerat modi quaerat magnam labore eius. Numquam non amet ut aliquam. Dolor quisquam dolore velit.
Training
# BPE works on raw bytes, not characters, so start from the UTF-8 encoding.
tokens = bytes(text, "utf-8")
print(tokens)b'Eius eius dolor tempora consectetur sed. Amet quaerat modi aliquam adipisci amet dolorem eius. Quiquia adipisci porro dolorem sit quisquam sit porro. Eius neque quaerat est velit adipisci ut. Sit modi ipsum est dolor.\n\nConsectetur amet magnam consectetur labore labore. Est velit aliquam tempora neque porro consectetur magnam. Porro etincidunt voluptatem quisquam. Labore quaerat dolorem sit amet aliquam sed eius. Amet eius consectetur magnam est neque. Dolore labore labore dolorem sed est.\n\nQuiquia quisquam dolore porro. Dolore neque magnam est quisquam. Eius sed ipsum voluptatem ut ut aliquam eius. Velit ipsum magnam est. Dolorem quaerat sit ipsum. Quisquam non magnam quisquam neque. Est dolor eius tempora porro. Est tempora quaerat modi quaerat magnam labore eius. Numquam non amet ut aliquam. Dolor quisquam dolore velit.'
# Iterating a bytes object already yields ints in 0-255; materialise them as a
# list so pairs can later be swapped out for merged token IDs.
tokens = list(tokens)
print(tokens)
print(f"Length: {len(tokens)}")[69, 105, 117, 115, 32, 101, 105, 117, 115, 32, 100, 111, 108, 111, 114, 32, 116, 101, 109, 112, 111, 114, 97, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 115, 101, 100, 46, 32, 65, 109, 101, 116, 32, 113, 117, 97, 101, 114, 97, 116, 32, 109, 111, 100, 105, 32, 97, 108, 105, 113, 117, 97, 109, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 97, 109, 101, 116, 32, 100, 111, 108, 111, 114, 101, 109, 32, 101, 105, 117, 115, 46, 32, 81, 117, 105, 113, 117, 105, 97, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 112, 111, 114, 114, 111, 32, 100, 111, 108, 111, 114, 101, 109, 32, 115, 105, 116, 32, 113, 117, 105, 115, 113, 117, 97, 109, 32, 115, 105, 116, 32, 112, 111, 114, 114, 111, 46, 32, 69, 105, 117, 115, 32, 110, 101, 113, 117, 101, 32, 113, 117, 97, 101, 114, 97, 116, 32, 101, 115, 116, 32, 118, 101, 108, 105, 116, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 117, 116, 46, 32, 83, 105, 116, 32, 109, 111, 100, 105, 32, 105, 112, 115, 117, 109, 32, 101, 115, 116, 32, 100, 111, 108, 111, 114, 46, 10, 10, 67, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 97, 109, 101, 116, 32, 109, 97, 103, 110, 97, 109, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 108, 97, 98, 111, 114, 101, 32, 108, 97, 98, 111, 114, 101, 46, 32, 69, 115, 116, 32, 118, 101, 108, 105, 116, 32, 97, 108, 105, 113, 117, 97, 109, 32, 116, 101, 109, 112, 111, 114, 97, 32, 110, 101, 113, 117, 101, 32, 112, 111, 114, 114, 111, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 109, 97, 103, 110, 97, 109, 46, 32, 80, 111, 114, 114, 111, 32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 116, 32, 118, 111, 108, 117, 112, 116, 97, 116, 101, 109, 32, 113, 117, 105, 115, 113, 117, 97, 109, 46, 32, 76, 97, 98, 111, 114, 101, 32, 113, 117, 97, 101, 114, 97, 116, 32, 100, 111, 108, 111, 114, 101, 109, 32, 115, 105, 116, 32, 97, 109, 101, 116, 32, 97, 108, 105, 113, 117, 97, 109, 32, 115, 101, 100, 32, 101, 105, 117, 115, 46, 32, 65, 109, 
101, 116, 32, 101, 105, 117, 115, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 109, 97, 103, 110, 97, 109, 32, 101, 115, 116, 32, 110, 101, 113, 117, 101, 46, 32, 68, 111, 108, 111, 114, 101, 32, 108, 97, 98, 111, 114, 101, 32, 108, 97, 98, 111, 114, 101, 32, 100, 111, 108, 111, 114, 101, 109, 32, 115, 101, 100, 32, 101, 115, 116, 46, 10, 10, 81, 117, 105, 113, 117, 105, 97, 32, 113, 117, 105, 115, 113, 117, 97, 109, 32, 100, 111, 108, 111, 114, 101, 32, 112, 111, 114, 114, 111, 46, 32, 68, 111, 108, 111, 114, 101, 32, 110, 101, 113, 117, 101, 32, 109, 97, 103, 110, 97, 109, 32, 101, 115, 116, 32, 113, 117, 105, 115, 113, 117, 97, 109, 46, 32, 69, 105, 117, 115, 32, 115, 101, 100, 32, 105, 112, 115, 117, 109, 32, 118, 111, 108, 117, 112, 116, 97, 116, 101, 109, 32, 117, 116, 32, 117, 116, 32, 97, 108, 105, 113, 117, 97, 109, 32, 101, 105, 117, 115, 46, 32, 86, 101, 108, 105, 116, 32, 105, 112, 115, 117, 109, 32, 109, 97, 103, 110, 97, 109, 32, 101, 115, 116, 46, 32, 68, 111, 108, 111, 114, 101, 109, 32, 113, 117, 97, 101, 114, 97, 116, 32, 115, 105, 116, 32, 105, 112, 115, 117, 109, 46, 32, 81, 117, 105, 115, 113, 117, 97, 109, 32, 110, 111, 110, 32, 109, 97, 103, 110, 97, 109, 32, 113, 117, 105, 115, 113, 117, 97, 109, 32, 110, 101, 113, 117, 101, 46, 32, 69, 115, 116, 32, 100, 111, 108, 111, 114, 32, 101, 105, 117, 115, 32, 116, 101, 109, 112, 111, 114, 97, 32, 112, 111, 114, 114, 111, 46, 32, 69, 115, 116, 32, 116, 101, 109, 112, 111, 114, 97, 32, 113, 117, 97, 101, 114, 97, 116, 32, 109, 111, 100, 105, 32, 113, 117, 97, 101, 114, 97, 116, 32, 109, 97, 103, 110, 97, 109, 32, 108, 97, 98, 111, 114, 101, 32, 101, 105, 117, 115, 46, 32, 78, 117, 109, 113, 117, 97, 109, 32, 110, 111, 110, 32, 97, 109, 101, 116, 32, 117, 116, 32, 97, 108, 105, 113, 117, 97, 109, 46, 32, 68, 111, 108, 111, 114, 32, 113, 117, 105, 115, 113, 117, 97, 109, 32, 100, 111, 108, 111, 114, 101, 32, 118, 101, 108, 105, 116, 46]
Length: 833
# Slide a width-2 window over the sequence and tally each adjacent pair.
# Keys are inserted in the order each pair is first seen.
pair_counts: dict[tuple[int, int], int] = {}
for left, right in zip(tokens, tokens[1:]):
    pair_counts[(left, right)] = pair_counts.get((left, right), 0) + 1
print(pair_counts){(69, 105): 3, (105, 117): 10, (117, 115): 10, (115, 32): 6, (32, 101): 14, (101, 105): 7, (32, 100): 9, (100, 111): 9, (111, 108): 15, (108, 111): 13, (111, 114): 29, (114, 32): 8, (32, 116): 4, (116, 101): 11, (101, 109): 11, (109, 112): 4, (112, 111): 9, (114, 97): 10, (97, 32): 6, (32, 99): 4, (99, 111): 4, (111, 110): 7, (110, 115): 5, (115, 101): 9, (101, 99): 5, (99, 116): 5, (101, 116): 12, (116, 117): 5, (117, 114): 5, (32, 115): 8, (101, 100): 4, (100, 46): 1, (46, 32): 18, (32, 65): 2, (65, 109): 2, (109, 101): 6, (116, 32): 31, (32, 113): 12, (113, 117): 32, (117, 97): 19, (97, 101): 6, (101, 114): 6, (97, 116): 8, (32, 109): 10, (109, 111): 3, (111, 100): 3, (100, 105): 6, (105, 32): 6, (32, 97): 12, (97, 108): 5, (108, 105): 9, (105, 113): 7, (97, 109): 24, (109, 32): 26, (97, 100): 3, (105, 112): 7, (112, 105): 3, (105, 115): 10, (115, 99): 3, (99, 105): 4, (114, 101): 15, (115, 46): 4, (32, 81): 2, (81, 117): 3, (117, 105): 11, (105, 97): 2, (32, 112): 5, (114, 114): 6, (114, 111): 6, (111, 32): 3, (115, 105): 4, (105, 116): 9, (115, 113): 7, (111, 46): 3, (32, 69): 5, (32, 110): 7, (110, 101): 5, (101, 113): 5, (117, 101): 5, (101, 32): 12, (101, 115): 6, (115, 116): 9, (32, 118): 5, (118, 101): 3, (101, 108): 4, (32, 117): 4, (117, 116): 4, (116, 46): 4, (32, 83): 1, (83, 105): 1, (32, 105): 4, (112, 115): 4, (115, 117): 4, (117, 109): 5, (114, 46): 1, (46, 10): 2, (10, 10): 2, (10, 67): 1, (67, 111): 1, (109, 97): 7, (97, 103): 7, (103, 110): 7, (110, 97): 7, (32, 108): 5, (108, 97): 5, (97, 98): 6, (98, 111): 6, (101, 46): 3, (69, 115): 3, (109, 46): 5, (32, 80): 1, (80, 111): 1, (116, 105): 1, (105, 110): 1, (110, 99): 1, (105, 100): 1, (100, 117): 1, (117, 110): 1, (110, 116): 1, (118, 111): 2, (108, 117): 2, (117, 112): 2, (112, 116): 2, (116, 97): 2, (32, 76): 1, (76, 97): 1, (100, 32): 3, (32, 68): 4, (68, 111): 4, (10, 81): 1, (32, 86): 1, (86, 101): 1, (110, 111): 2, (110, 32): 2, (32, 78): 1, (78, 117): 1, (109, 113): 1}
You can print the sorted pair counts like this:
print(sorted(pair_counts.items(), key=lambda kv: kv[1], reverse=True))[((113, 117), 32), ((116, 32), 31), ((111, 114), 29), ((109, 32), 26), ((97, 109), 24), ((117, 97), 19), ((46, 32), 18), ((111, 108), 15), ((114, 101), 15), ((32, 101), 14), ((108, 111), 13), ((101, 116), 12), ((32, 113), 12), ((32, 97), 12), ((101, 32), 12), ((116, 101), 11), ((101, 109), 11), ((117, 105), 11), ((105, 117), 10), ((117, 115), 10), ((114, 97), 10), ((32, 109), 10), ((105, 115), 10), ((32, 100), 9), ((100, 111), 9), ((112, 111), 9), ((115, 101), 9), ((108, 105), 9), ((105, 116), 9), ((115, 116), 9), ((114, 32), 8), ((32, 115), 8), ((97, 116), 8), ((101, 105), 7), ((111, 110), 7), ((105, 113), 7), ((105, 112), 7), ((115, 113), 7), ((32, 110), 7), ((109, 97), 7), ((97, 103), 7), ((103, 110), 7), ((110, 97), 7), ((115, 32), 6), ((97, 32), 6), ((109, 101), 6), ((97, 101), 6), ((101, 114), 6), ((100, 105), 6), ((105, 32), 6), ((114, 114), 6), ((114, 111), 6), ((101, 115), 6), ((97, 98), 6), ((98, 111), 6), ((110, 115), 5), ((101, 99), 5), ((99, 116), 5), ((116, 117), 5), ((117, 114), 5), ((97, 108), 5), ((32, 112), 5), ((32, 69), 5), ((110, 101), 5), ((101, 113), 5), ((117, 101), 5), ((32, 118), 5), ((117, 109), 5), ((32, 108), 5), ((108, 97), 5), ((109, 46), 5), ((32, 116), 4), ((109, 112), 4), ((32, 99), 4), ((99, 111), 4), ((101, 100), 4), ((99, 105), 4), ((115, 46), 4), ((115, 105), 4), ((101, 108), 4), ((32, 117), 4), ((117, 116), 4), ((116, 46), 4), ((32, 105), 4), ((112, 115), 4), ((115, 117), 4), ((32, 68), 4), ((68, 111), 4), ((69, 105), 3), ((109, 111), 3), ((111, 100), 3), ((97, 100), 3), ((112, 105), 3), ((115, 99), 3), ((81, 117), 3), ((111, 32), 3), ((111, 46), 3), ((118, 101), 3), ((101, 46), 3), ((69, 115), 3), ((100, 32), 3), ((32, 65), 2), ((65, 109), 2), ((32, 81), 2), ((105, 97), 2), ((46, 10), 2), ((10, 10), 2), ((118, 111), 2), ((108, 117), 2), ((117, 112), 2), ((112, 116), 2), ((116, 97), 2), ((110, 111), 2), ((110, 32), 2), ((100, 46), 1), ((32, 83), 1), ((83, 
105), 1), ((114, 46), 1), ((10, 67), 1), ((67, 111), 1), ((32, 80), 1), ((80, 111), 1), ((116, 105), 1), ((105, 110), 1), ((110, 99), 1), ((105, 100), 1), ((100, 117), 1), ((117, 110), 1), ((110, 116), 1), ((32, 76), 1), ((76, 97), 1), ((10, 81), 1), ((32, 86), 1), ((86, 101), 1), ((32, 78), 1), ((78, 117), 1), ((109, 113), 1)]
Next we need to pick the pair with the highest count as our first merge.
# Pick the most frequent pair. On ties, max() returns the first maximal key,
# i.e. the pair that occurred earliest in the text.
most_common_pair = max(pair_counts, key=pair_counts.get)
print(most_common_pair)
print(f"Count: {pair_counts[most_common_pair]}")
print(f"{chr(most_common_pair[0])}{chr(most_common_pair[1])}")(113, 117)
Count: 32
qu
And merge it into a new token.
# The merged token needs a fresh ID. IDs 0-255 are already taken by the
# single-byte tokens, so the first merge gets ID 256 (the same numbering the
# training loop further down uses). Adding the pair's byte values together
# (113 + 117 = 230) would reuse the ID of byte 230 (b"\xe6"), so decoding
# could not tell the merged token "qu" apart from that raw byte.
new_token = 256
print(new_token)
We now rewrite the token list, replacing every occurrence of the pair with the new token.
# Rebuild the sequence left to right, replacing each non-overlapping
# occurrence of most_common_pair with new_token (a match consumes two tokens).
new_tokens = []
pos = 0
while pos < len(tokens):
    is_match = pos + 1 < len(tokens) and (tokens[pos], tokens[pos + 1]) == most_common_pair
    new_tokens.append(new_token if is_match else tokens[pos])
    pos += 2 if is_match else 1
print(new_tokens)
print(f"Length: {len(new_tokens)}")[69, 105, 117, 115, 32, 101, 105, 117, 115, 32, 100, 111, 108, 111, 114, 32, 116, 101, 109, 112, 111, 114, 97, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 115, 101, 100, 46, 32, 65, 109, 101, 116, 32, 230, 97, 101, 114, 97, 116, 32, 109, 111, 100, 105, 32, 97, 108, 105, 230, 97, 109, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 97, 109, 101, 116, 32, 100, 111, 108, 111, 114, 101, 109, 32, 101, 105, 117, 115, 46, 32, 81, 117, 105, 230, 105, 97, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 112, 111, 114, 114, 111, 32, 100, 111, 108, 111, 114, 101, 109, 32, 115, 105, 116, 32, 230, 105, 115, 230, 97, 109, 32, 115, 105, 116, 32, 112, 111, 114, 114, 111, 46, 32, 69, 105, 117, 115, 32, 110, 101, 230, 101, 32, 230, 97, 101, 114, 97, 116, 32, 101, 115, 116, 32, 118, 101, 108, 105, 116, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 117, 116, 46, 32, 83, 105, 116, 32, 109, 111, 100, 105, 32, 105, 112, 115, 117, 109, 32, 101, 115, 116, 32, 100, 111, 108, 111, 114, 46, 10, 10, 67, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 97, 109, 101, 116, 32, 109, 97, 103, 110, 97, 109, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 108, 97, 98, 111, 114, 101, 32, 108, 97, 98, 111, 114, 101, 46, 32, 69, 115, 116, 32, 118, 101, 108, 105, 116, 32, 97, 108, 105, 230, 97, 109, 32, 116, 101, 109, 112, 111, 114, 97, 32, 110, 101, 230, 101, 32, 112, 111, 114, 114, 111, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 109, 97, 103, 110, 97, 109, 46, 32, 80, 111, 114, 114, 111, 32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 116, 32, 118, 111, 108, 117, 112, 116, 97, 116, 101, 109, 32, 230, 105, 115, 230, 97, 109, 46, 32, 76, 97, 98, 111, 114, 101, 32, 230, 97, 101, 114, 97, 116, 32, 100, 111, 108, 111, 114, 101, 109, 32, 115, 105, 116, 32, 97, 109, 101, 116, 32, 97, 108, 105, 230, 97, 109, 32, 115, 101, 100, 32, 101, 105, 117, 115, 46, 32, 65, 109, 101, 116, 32, 101, 105, 117, 115, 32, 99, 111, 110, 115, 101, 
99, 116, 101, 116, 117, 114, 32, 109, 97, 103, 110, 97, 109, 32, 101, 115, 116, 32, 110, 101, 230, 101, 46, 32, 68, 111, 108, 111, 114, 101, 32, 108, 97, 98, 111, 114, 101, 32, 108, 97, 98, 111, 114, 101, 32, 100, 111, 108, 111, 114, 101, 109, 32, 115, 101, 100, 32, 101, 115, 116, 46, 10, 10, 81, 117, 105, 230, 105, 97, 32, 230, 105, 115, 230, 97, 109, 32, 100, 111, 108, 111, 114, 101, 32, 112, 111, 114, 114, 111, 46, 32, 68, 111, 108, 111, 114, 101, 32, 110, 101, 230, 101, 32, 109, 97, 103, 110, 97, 109, 32, 101, 115, 116, 32, 230, 105, 115, 230, 97, 109, 46, 32, 69, 105, 117, 115, 32, 115, 101, 100, 32, 105, 112, 115, 117, 109, 32, 118, 111, 108, 117, 112, 116, 97, 116, 101, 109, 32, 117, 116, 32, 117, 116, 32, 97, 108, 105, 230, 97, 109, 32, 101, 105, 117, 115, 46, 32, 86, 101, 108, 105, 116, 32, 105, 112, 115, 117, 109, 32, 109, 97, 103, 110, 97, 109, 32, 101, 115, 116, 46, 32, 68, 111, 108, 111, 114, 101, 109, 32, 230, 97, 101, 114, 97, 116, 32, 115, 105, 116, 32, 105, 112, 115, 117, 109, 46, 32, 81, 117, 105, 115, 230, 97, 109, 32, 110, 111, 110, 32, 109, 97, 103, 110, 97, 109, 32, 230, 105, 115, 230, 97, 109, 32, 110, 101, 230, 101, 46, 32, 69, 115, 116, 32, 100, 111, 108, 111, 114, 32, 101, 105, 117, 115, 32, 116, 101, 109, 112, 111, 114, 97, 32, 112, 111, 114, 114, 111, 46, 32, 69, 115, 116, 32, 116, 101, 109, 112, 111, 114, 97, 32, 230, 97, 101, 114, 97, 116, 32, 109, 111, 100, 105, 32, 230, 97, 101, 114, 97, 116, 32, 109, 97, 103, 110, 97, 109, 32, 108, 97, 98, 111, 114, 101, 32, 101, 105, 117, 115, 46, 32, 78, 117, 109, 230, 97, 109, 32, 110, 111, 110, 32, 97, 109, 101, 116, 32, 117, 116, 32, 97, 108, 105, 230, 97, 109, 46, 32, 68, 111, 108, 111, 114, 32, 230, 105, 115, 230, 97, 109, 32, 100, 111, 108, 111, 114, 101, 32, 118, 101, 108, 105, 116, 46]
Length: 801
# Sanity check: after merging, the pair no longer appears anywhere.
assert all(
    (left, right) != most_common_pair
    for left, right in zip(new_tokens, new_tokens[1:])
)
new_tokens.count(new_token)32
The token list has shrunk by exactly the number of times the pair was merged (833 -> 801, i.e. 32 merges), while our vocabulary has grown by one token.
Let’s wrap what we’ve done so far into some functions:
def get_tokens(text: str) -> list[int]:
    """Return the UTF-8 bytes of ``text`` as a list of integer token IDs (0-255)."""
    return list(text.encode("utf-8"))
def get_pair_counts(tokens: list[int]) -> dict[tuple[int, int], int]:
    """Count how often each pair of adjacent tokens occurs in ``tokens``.

    Returns a plain dict whose keys appear in the order each pair is first
    seen, so ``max`` over the result breaks ties in favour of the earliest
    pair. Sequences shorter than two tokens yield an empty dict.
    """
    return dict(Counter(zip(tokens, tokens[1:])))
def merge_pair(tokens: list[int], pair: tuple[int, int], idx: int) -> list[int]:
    """Replace every occurrence of ``pair`` in ``tokens`` with the token ``idx``.

    The scan is greedy and left to right, so overlapping matches resolve in
    favour of the earlier one: merging (1, 1) in [1, 1, 1] gives [idx, 1].
    The input list is not modified; a new list is returned.
    """
    merged: list[int] = []
    last = len(tokens) - 1
    pos = 0
    while pos <= last:
        if pos < last and (tokens[pos], tokens[pos + 1]) == pair:
            merged.append(idx)
            pos += 2
            continue
        merged.append(tokens[pos])
        pos += 1
    return merged


# And we’ll use them in a loop
# Train: repeatedly merge the most frequent adjacent pair, giving each merge
# the next unused ID (256, 257, ...) until the vocabulary has VOCAB_SIZE entries.
VOCAB_SIZE = 260
merges: dict[tuple[int, int], int] = {}
tokens = get_tokens(text)
for i in range(256, VOCAB_SIZE):
    pair_counts = get_pair_counts(tokens)
    if not pair_counts:
        # Fewer than two tokens are left, so there is nothing more to merge;
        # stop early instead of letting max() raise on an empty dict.
        break
    most_common_pair = max(pair_counts, key=lambda p: pair_counts[p])
    tokens = merge_pair(tokens, most_common_pair, i)
    merges[most_common_pair] = i
    # Prints e.g. "Merged (113, 117) into 256. Sequence length: 801".
    print(f"Merged {most_common_pair} into {i}. Sequence length: {len(tokens)}")
Merged (116, 32) into 257. Sequence length: 770
Merged (111, 114) into 258. Sequence length: 741
Merged (109, 32) into 259. Sequence length: 715
print(merges){(113, 117): 256, (116, 32): 257, (111, 114): 258, (109, 32): 259}
That is how we build up our vocabulary. Putting it together — the 256 single-byte tokens plus one entry per merge — gives:
# Base vocabulary: one single-byte token per possible byte value.
vocab = {byte_id: bytes([byte_id]) for byte_id in range(256)}
# Each merged token's bytes are its two parts concatenated. merges is in
# training order, so both parts are already in vocab when a merge is added.
for (left, right), merged_id in merges.items():
    vocab[merged_id] = vocab[left] + vocab[right]
print(vocab){0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: b'Y', 90: b'Z', 91: b'[', 92: b'\\', 93: b']', 94: b'^', 95: b'_', 96: b'`', 97: b'a', 98: b'b', 99: b'c', 100: b'd', 101: b'e', 102: b'f', 103: b'g', 104: b'h', 105: b'i', 106: b'j', 107: b'k', 108: b'l', 109: b'm', 110: b'n', 111: b'o', 112: b'p', 113: b'q', 114: b'r', 115: b's', 116: b't', 117: b'u', 118: b'v', 119: b'w', 120: b'x', 121: b'y', 122: b'z', 123: b'{', 124: b'|', 125: b'}', 126: b'~', 127: b'\x7f', 128: b'\x80', 129: b'\x81', 130: b'\x82', 131: b'\x83', 132: b'\x84', 133: b'\x85', 134: b'\x86', 135: b'\x87', 136: b'\x88', 137: b'\x89', 138: b'\x8a', 139: b'\x8b', 140: b'\x8c', 141: b'\x8d', 142: b'\x8e', 143: b'\x8f', 144: b'\x90', 145: b'\x91', 146: b'\x92', 147: b'\x93', 148: b'\x94', 149: b'\x95', 150: b'\x96', 151: b'\x97', 152: b'\x98', 153: b'\x99', 154: b'\x9a', 155: b'\x9b', 156: b'\x9c', 157: b'\x9d', 158: b'\x9e', 159: b'\x9f', 160: b'\xa0', 161: b'\xa1', 162: b'\xa2', 163: b'\xa3', 164: b'\xa4', 165: b'\xa5', 166: b'\xa6', 167: b'\xa7', 168: b'\xa8', 169: b'\xa9', 170: 
b'\xaa', 171: b'\xab', 172: b'\xac', 173: b'\xad', 174: b'\xae', 175: b'\xaf', 176: b'\xb0', 177: b'\xb1', 178: b'\xb2', 179: b'\xb3', 180: b'\xb4', 181: b'\xb5', 182: b'\xb6', 183: b'\xb7', 184: b'\xb8', 185: b'\xb9', 186: b'\xba', 187: b'\xbb', 188: b'\xbc', 189: b'\xbd', 190: b'\xbe', 191: b'\xbf', 192: b'\xc0', 193: b'\xc1', 194: b'\xc2', 195: b'\xc3', 196: b'\xc4', 197: b'\xc5', 198: b'\xc6', 199: b'\xc7', 200: b'\xc8', 201: b'\xc9', 202: b'\xca', 203: b'\xcb', 204: b'\xcc', 205: b'\xcd', 206: b'\xce', 207: b'\xcf', 208: b'\xd0', 209: b'\xd1', 210: b'\xd2', 211: b'\xd3', 212: b'\xd4', 213: b'\xd5', 214: b'\xd6', 215: b'\xd7', 216: b'\xd8', 217: b'\xd9', 218: b'\xda', 219: b'\xdb', 220: b'\xdc', 221: b'\xdd', 222: b'\xde', 223: b'\xdf', 224: b'\xe0', 225: b'\xe1', 226: b'\xe2', 227: b'\xe3', 228: b'\xe4', 229: b'\xe5', 230: b'\xe6', 231: b'\xe7', 232: b'\xe8', 233: b'\xe9', 234: b'\xea', 235: b'\xeb', 236: b'\xec', 237: b'\xed', 238: b'\xee', 239: b'\xef', 240: b'\xf0', 241: b'\xf1', 242: b'\xf2', 243: b'\xf3', 244: b'\xf4', 245: b'\xf5', 246: b'\xf6', 247: b'\xf7', 248: b'\xf8', 249: b'\xf9', 250: b'\xfa', 251: b'\xfb', 252: b'\xfc', 253: b'\xfd', 254: b'\xfe', 255: b'\xff', 256: b'qu', 257: b't ', 258: b'or', 259: b'm '}
Decoding
We need a way to decode our tokens back into text.
def decode(tokens: list[int]) -> str:
    """Convert a sequence of token IDs back into text.

    Each ID is looked up in the module-level ``vocab`` (an unknown ID raises
    ``KeyError``). The concatenated bytes are decoded as UTF-8, with invalid
    sequences replaced by U+FFFD instead of raising.
    """
    pieces = [vocab[token_id] for token_id in tokens]
    return b"".join(pieces).decode("utf-8", errors="replace")
print(decode(tokens))Eius eius dolor tempora consectetur sed. Amet quaerat modi aliquam adipisci amet dolorem eius. Quiquia adipisci porro dolorem sit quisquam sit porro. Eius neque quaerat est velit adipisci ut. Sit modi ipsum est dolor.
Consectetur amet magnam consectetur labore labore. Est velit aliquam tempora neque porro consectetur magnam. Porro etincidunt voluptatem quisquam. Labore quaerat dolorem sit amet aliquam sed eius. Amet eius consectetur magnam est neque. Dolore labore labore dolorem sed est.
Quiquia quisquam dolore porro. Dolore neque magnam est quisquam. Eius sed ipsum voluptatem ut ut aliquam eius. Velit ipsum magnam est. Dolorem quaerat sit ipsum. Quisquam non magnam quisquam neque. Est dolor eius tempora porro. Est tempora quaerat modi quaerat magnam labore eius. Numquam non amet ut aliquam. Dolor quisquam dolore velit.
Encoding
We also need a way to encode text into tokens by replaying the merges learned during training, in the order they were learned.
def encode(text: str) -> list[int]:
    """Convert text into token IDs using the learned ``merges``.

    Starts from the UTF-8 bytes of ``text`` and applies every merge in the
    order it was learned (``merges`` is filled in training order), mirroring
    how training transformed the corpus.
    """
    ids: list[int] = list(text.encode("utf-8"))
    for learned_pair, new_id in merges.items():
        ids = merge_pair(ids, learned_pair, new_id)
    return ids
print(encode(text))
print(f"Length: {len(encode(text))}")[69, 105, 117, 115, 32, 101, 105, 117, 115, 32, 100, 111, 108, 258, 32, 116, 101, 109, 112, 258, 97, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 115, 101, 100, 46, 32, 65, 109, 101, 257, 256, 97, 101, 114, 97, 257, 109, 111, 100, 105, 32, 97, 108, 105, 256, 97, 259, 97, 100, 105, 112, 105, 115, 99, 105, 32, 97, 109, 101, 257, 100, 111, 108, 258, 101, 259, 101, 105, 117, 115, 46, 32, 81, 117, 105, 256, 105, 97, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 112, 258, 114, 111, 32, 100, 111, 108, 258, 101, 259, 115, 105, 257, 256, 105, 115, 256, 97, 259, 115, 105, 257, 112, 258, 114, 111, 46, 32, 69, 105, 117, 115, 32, 110, 101, 256, 101, 32, 256, 97, 101, 114, 97, 257, 101, 115, 257, 118, 101, 108, 105, 257, 97, 100, 105, 112, 105, 115, 99, 105, 32, 117, 116, 46, 32, 83, 105, 257, 109, 111, 100, 105, 32, 105, 112, 115, 117, 259, 101, 115, 257, 100, 111, 108, 258, 46, 10, 10, 67, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 97, 109, 101, 257, 109, 97, 103, 110, 97, 259, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 108, 97, 98, 258, 101, 32, 108, 97, 98, 258, 101, 46, 32, 69, 115, 257, 118, 101, 108, 105, 257, 97, 108, 105, 256, 97, 259, 116, 101, 109, 112, 258, 97, 32, 110, 101, 256, 101, 32, 112, 258, 114, 111, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 109, 97, 103, 110, 97, 109, 46, 32, 80, 258, 114, 111, 32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 257, 118, 111, 108, 117, 112, 116, 97, 116, 101, 259, 256, 105, 115, 256, 97, 109, 46, 32, 76, 97, 98, 258, 101, 32, 256, 97, 101, 114, 97, 257, 100, 111, 108, 258, 101, 259, 115, 105, 257, 97, 109, 101, 257, 97, 108, 105, 256, 97, 259, 115, 101, 100, 32, 101, 105, 117, 115, 46, 32, 65, 109, 101, 257, 101, 105, 117, 115, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 109, 97, 103, 110, 97, 259, 101, 115, 257, 110, 101, 256, 101, 46, 32, 68, 111, 108, 258, 101, 32, 108, 97, 98, 258, 101, 32, 108, 97, 98, 258, 101, 
32, 100, 111, 108, 258, 101, 259, 115, 101, 100, 32, 101, 115, 116, 46, 10, 10, 81, 117, 105, 256, 105, 97, 32, 256, 105, 115, 256, 97, 259, 100, 111, 108, 258, 101, 32, 112, 258, 114, 111, 46, 32, 68, 111, 108, 258, 101, 32, 110, 101, 256, 101, 32, 109, 97, 103, 110, 97, 259, 101, 115, 257, 256, 105, 115, 256, 97, 109, 46, 32, 69, 105, 117, 115, 32, 115, 101, 100, 32, 105, 112, 115, 117, 259, 118, 111, 108, 117, 112, 116, 97, 116, 101, 259, 117, 257, 117, 257, 97, 108, 105, 256, 97, 259, 101, 105, 117, 115, 46, 32, 86, 101, 108, 105, 257, 105, 112, 115, 117, 259, 109, 97, 103, 110, 97, 259, 101, 115, 116, 46, 32, 68, 111, 108, 258, 101, 259, 256, 97, 101, 114, 97, 257, 115, 105, 257, 105, 112, 115, 117, 109, 46, 32, 81, 117, 105, 115, 256, 97, 259, 110, 111, 110, 32, 109, 97, 103, 110, 97, 259, 256, 105, 115, 256, 97, 259, 110, 101, 256, 101, 46, 32, 69, 115, 257, 100, 111, 108, 258, 32, 101, 105, 117, 115, 32, 116, 101, 109, 112, 258, 97, 32, 112, 258, 114, 111, 46, 32, 69, 115, 257, 116, 101, 109, 112, 258, 97, 32, 256, 97, 101, 114, 97, 257, 109, 111, 100, 105, 32, 256, 97, 101, 114, 97, 257, 109, 97, 103, 110, 97, 259, 108, 97, 98, 258, 101, 32, 101, 105, 117, 115, 46, 32, 78, 117, 109, 256, 97, 259, 110, 111, 110, 32, 97, 109, 101, 257, 117, 257, 97, 108, 105, 256, 97, 109, 46, 32, 68, 111, 108, 258, 32, 256, 105, 115, 256, 97, 259, 100, 111, 108, 258, 101, 32, 118, 101, 108, 105, 116, 46]
Length: 715
# Round trip: decoding the encoded corpus must reproduce it exactly.
assert decode(encode(text)) == textAs a class
BPE
def BPE(
vocab_size:int
):
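Create a byte-pair encoding tokenizer with a vocabulary of `vocab_size` tokens. After `train(text)`, `vocab` maps token IDs to their bytes and `merges` maps each merged pair to its new ID; `encode` and `decode` convert between text and token IDs.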
# Build the class-based tokenizer and train it on the corpus. vocab_size=260
# presumably means 256 byte tokens plus 4 learned merges, mirroring the
# function-based loop above (the class is defined elsewhere — confirm).
bpe = BPE(vocab_size=260)bpe.train(text)Tests
# Tests for the class-based tokenizer `bpe` trained in the previous cell.
# Vocab has the right size: exactly vocab_size (260) entries.
assert len(bpe.vocab) == 260, f"Expected 260, got {len(bpe.vocab)}"
# Base vocab entries are single bytes (spot-checked: ID 65 is b"A", ID 32 is b" ").
assert bpe.vocab[65] == b"A"
assert bpe.vocab[32] == b" "
# Merged vocab entries are the byte concatenation of their two component tokens.
for pair, idx in bpe.merges.items():
assert bpe.vocab[idx] == bpe.vocab[pair[0]] + bpe.vocab[pair[1]], f"vocab[{idx}] mismatch for pair {pair}"
# Encoding reduces sequence length vs raw UTF-8.
# raw_len is the untrained sequence length: one token per UTF-8 byte.
raw_len = len(list(map(int, text.encode("utf-8"))))
encoded = bpe.encode(text)
assert len(encoded) < raw_len, "Encoding should shorten the sequence"
# Round-trip: decode(encode(text)) == text
assert bpe.decode(bpe.encode(text)) == text, "Round-trip failed"
# Text that does not appear in the training corpus still encodes to a list of
# int token IDs that decodes back to the original string.
unseen = "Hello world"
encoded_unseen = bpe.encode(unseen)
assert isinstance(encoded_unseen, list)
assert all(isinstance(t, int) for t in encoded_unseen)
assert bpe.decode(encoded_unseen) == unseenAll tests passed.