Byte-pair encoding from scratch

import random

random.seed(42)  # lorem samples from the random module, so seeding makes the corpus reproducible

In byte-pair encoding we build a vocabulary by iteratively merging the most frequent pair of adjacent tokens.

  1. Start with individual characters as tokens.
  2. Find the most frequent pair of adjacent tokens.
  3. Merge that pair into a new token.
  4. Repeat until we reach the desired size of our vocabulary.
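Before running this on real text, here is the loop above sketched on a made-up toy string (the string and the two-merge budget are arbitrary choices for illustration; tokens are plain characters rather than bytes):

```python
from collections import Counter

# 1. Start with individual characters as tokens.
tokens = list("aaabdaaabac")
merges = []

for _ in range(2):  # 4. repeat (here: just two merges)
    # 2. Find the most frequent pair of adjacent tokens.
    pair_counts = Counter(zip(tokens, tokens[1:]))
    (a, b), _ = pair_counts.most_common(1)[0]
    # 3. Merge that pair into a new token (string concatenation here).
    merged, i = [], 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == (a, b):
            merged.append(a + b)
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    tokens = merged
    merges.append(a + b)

print(merges)  # the first merge is 'aa', the most frequent pair
print(tokens)
```

The sequence gets shorter with every merge while still spelling out the original string.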

We start off by creating a small corpus of text.

import lorem

text = lorem.text()

print(text)
Eius eius dolor tempora consectetur sed. Amet quaerat modi aliquam adipisci amet dolorem eius. Quiquia adipisci porro dolorem sit quisquam sit porro. Eius neque quaerat est velit adipisci ut. Sit modi ipsum est dolor.

Consectetur amet magnam consectetur labore labore. Est velit aliquam tempora neque porro consectetur magnam. Porro etincidunt voluptatem quisquam. Labore quaerat dolorem sit amet aliquam sed eius. Amet eius consectetur magnam est neque. Dolore labore labore dolorem sed est.

Quiquia quisquam dolore porro. Dolore neque magnam est quisquam. Eius sed ipsum voluptatem ut ut aliquam eius. Velit ipsum magnam est. Dolorem quaerat sit ipsum. Quisquam non magnam quisquam neque. Est dolor eius tempora porro. Est tempora quaerat modi quaerat magnam labore eius. Numquam non amet ut aliquam. Dolor quisquam dolore velit.

Training

tokens = text.encode("utf-8")

print(tokens)
b'Eius eius dolor tempora consectetur sed. Amet quaerat modi aliquam adipisci amet dolorem eius. Quiquia adipisci porro dolorem sit quisquam sit porro. Eius neque quaerat est velit adipisci ut. Sit modi ipsum est dolor.\n\nConsectetur amet magnam consectetur labore labore. Est velit aliquam tempora neque porro consectetur magnam. Porro etincidunt voluptatem quisquam. Labore quaerat dolorem sit amet aliquam sed eius. Amet eius consectetur magnam est neque. Dolore labore labore dolorem sed est.\n\nQuiquia quisquam dolore porro. Dolore neque magnam est quisquam. Eius sed ipsum voluptatem ut ut aliquam eius. Velit ipsum magnam est. Dolorem quaerat sit ipsum. Quisquam non magnam quisquam neque. Est dolor eius tempora porro. Est tempora quaerat modi quaerat magnam labore eius. Numquam non amet ut aliquam. Dolor quisquam dolore velit.'
tokens = list(map(int, tokens))

print(tokens)
print(f"Length: {len(tokens)}")
[69, 105, 117, 115, 32, 101, 105, 117, 115, 32, 100, 111, 108, 111, 114, 32, 116, 101, 109, 112, 111, 114, 97, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 115, 101, 100, 46, 32, 65, 109, 101, 116, 32, 113, 117, 97, 101, 114, 97, 116, 32, 109, 111, 100, 105, 32, 97, 108, 105, 113, 117, 97, 109, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 97, 109, 101, 116, 32, 100, 111, 108, 111, 114, 101, 109, 32, 101, 105, 117, 115, 46, 32, 81, 117, 105, 113, 117, 105, 97, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 112, 111, 114, 114, 111, 32, 100, 111, 108, 111, 114, 101, 109, 32, 115, 105, 116, 32, 113, 117, 105, 115, 113, 117, 97, 109, 32, 115, 105, 116, 32, 112, 111, 114, 114, 111, 46, 32, 69, 105, 117, 115, 32, 110, 101, 113, 117, 101, 32, 113, 117, 97, 101, 114, 97, 116, 32, 101, 115, 116, 32, 118, 101, 108, 105, 116, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 117, 116, 46, 32, 83, 105, 116, 32, 109, 111, 100, 105, 32, 105, 112, 115, 117, 109, 32, 101, 115, 116, 32, 100, 111, 108, 111, 114, 46, 10, 10, 67, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 97, 109, 101, 116, 32, 109, 97, 103, 110, 97, 109, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 108, 97, 98, 111, 114, 101, 32, 108, 97, 98, 111, 114, 101, 46, 32, 69, 115, 116, 32, 118, 101, 108, 105, 116, 32, 97, 108, 105, 113, 117, 97, 109, 32, 116, 101, 109, 112, 111, 114, 97, 32, 110, 101, 113, 117, 101, 32, 112, 111, 114, 114, 111, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 109, 97, 103, 110, 97, 109, 46, 32, 80, 111, 114, 114, 111, 32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 116, 32, 118, 111, 108, 117, 112, 116, 97, 116, 101, 109, 32, 113, 117, 105, 115, 113, 117, 97, 109, 46, 32, 76, 97, 98, 111, 114, 101, 32, 113, 117, 97, 101, 114, 97, 116, 32, 100, 111, 108, 111, 114, 101, 109, 32, 115, 105, 116, 32, 97, 109, 101, 116, 32, 97, 108, 105, 113, 117, 97, 109, 32, 115, 101, 100, 32, 101, 105, 117, 115, 46, 32, 65, 109, 101, 116, 32, 101, 105, 117, 115, 
32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 109, 97, 103, 110, 97, 109, 32, 101, 115, 116, 32, 110, 101, 113, 117, 101, 46, 32, 68, 111, 108, 111, 114, 101, 32, 108, 97, 98, 111, 114, 101, 32, 108, 97, 98, 111, 114, 101, 32, 100, 111, 108, 111, 114, 101, 109, 32, 115, 101, 100, 32, 101, 115, 116, 46, 10, 10, 81, 117, 105, 113, 117, 105, 97, 32, 113, 117, 105, 115, 113, 117, 97, 109, 32, 100, 111, 108, 111, 114, 101, 32, 112, 111, 114, 114, 111, 46, 32, 68, 111, 108, 111, 114, 101, 32, 110, 101, 113, 117, 101, 32, 109, 97, 103, 110, 97, 109, 32, 101, 115, 116, 32, 113, 117, 105, 115, 113, 117, 97, 109, 46, 32, 69, 105, 117, 115, 32, 115, 101, 100, 32, 105, 112, 115, 117, 109, 32, 118, 111, 108, 117, 112, 116, 97, 116, 101, 109, 32, 117, 116, 32, 117, 116, 32, 97, 108, 105, 113, 117, 97, 109, 32, 101, 105, 117, 115, 46, 32, 86, 101, 108, 105, 116, 32, 105, 112, 115, 117, 109, 32, 109, 97, 103, 110, 97, 109, 32, 101, 115, 116, 46, 32, 68, 111, 108, 111, 114, 101, 109, 32, 113, 117, 97, 101, 114, 97, 116, 32, 115, 105, 116, 32, 105, 112, 115, 117, 109, 46, 32, 81, 117, 105, 115, 113, 117, 97, 109, 32, 110, 111, 110, 32, 109, 97, 103, 110, 97, 109, 32, 113, 117, 105, 115, 113, 117, 97, 109, 32, 110, 101, 113, 117, 101, 46, 32, 69, 115, 116, 32, 100, 111, 108, 111, 114, 32, 101, 105, 117, 115, 32, 116, 101, 109, 112, 111, 114, 97, 32, 112, 111, 114, 114, 111, 46, 32, 69, 115, 116, 32, 116, 101, 109, 112, 111, 114, 97, 32, 113, 117, 97, 101, 114, 97, 116, 32, 109, 111, 100, 105, 32, 113, 117, 97, 101, 114, 97, 116, 32, 109, 97, 103, 110, 97, 109, 32, 108, 97, 98, 111, 114, 101, 32, 101, 105, 117, 115, 46, 32, 78, 117, 109, 113, 117, 97, 109, 32, 110, 111, 110, 32, 97, 109, 101, 116, 32, 117, 116, 32, 97, 108, 105, 113, 117, 97, 109, 46, 32, 68, 111, 108, 111, 114, 32, 113, 117, 105, 115, 113, 117, 97, 109, 32, 100, 111, 108, 111, 114, 101, 32, 118, 101, 108, 105, 116, 46]
Length: 833
pair_counts: dict[tuple[int, int], int] = {}
for pair in zip(tokens, tokens[1:]):
    pair_counts[pair] = pair_counts.get(pair, 0) + 1

print(pair_counts)
{(69, 105): 3, (105, 117): 10, (117, 115): 10, (115, 32): 6, (32, 101): 14, (101, 105): 7, (32, 100): 9, (100, 111): 9, (111, 108): 15, (108, 111): 13, (111, 114): 29, (114, 32): 8, (32, 116): 4, (116, 101): 11, (101, 109): 11, (109, 112): 4, (112, 111): 9, (114, 97): 10, (97, 32): 6, (32, 99): 4, (99, 111): 4, (111, 110): 7, (110, 115): 5, (115, 101): 9, (101, 99): 5, (99, 116): 5, (101, 116): 12, (116, 117): 5, (117, 114): 5, (32, 115): 8, (101, 100): 4, (100, 46): 1, (46, 32): 18, (32, 65): 2, (65, 109): 2, (109, 101): 6, (116, 32): 31, (32, 113): 12, (113, 117): 32, (117, 97): 19, (97, 101): 6, (101, 114): 6, (97, 116): 8, (32, 109): 10, (109, 111): 3, (111, 100): 3, (100, 105): 6, (105, 32): 6, (32, 97): 12, (97, 108): 5, (108, 105): 9, (105, 113): 7, (97, 109): 24, (109, 32): 26, (97, 100): 3, (105, 112): 7, (112, 105): 3, (105, 115): 10, (115, 99): 3, (99, 105): 4, (114, 101): 15, (115, 46): 4, (32, 81): 2, (81, 117): 3, (117, 105): 11, (105, 97): 2, (32, 112): 5, (114, 114): 6, (114, 111): 6, (111, 32): 3, (115, 105): 4, (105, 116): 9, (115, 113): 7, (111, 46): 3, (32, 69): 5, (32, 110): 7, (110, 101): 5, (101, 113): 5, (117, 101): 5, (101, 32): 12, (101, 115): 6, (115, 116): 9, (32, 118): 5, (118, 101): 3, (101, 108): 4, (32, 117): 4, (117, 116): 4, (116, 46): 4, (32, 83): 1, (83, 105): 1, (32, 105): 4, (112, 115): 4, (115, 117): 4, (117, 109): 5, (114, 46): 1, (46, 10): 2, (10, 10): 2, (10, 67): 1, (67, 111): 1, (109, 97): 7, (97, 103): 7, (103, 110): 7, (110, 97): 7, (32, 108): 5, (108, 97): 5, (97, 98): 6, (98, 111): 6, (101, 46): 3, (69, 115): 3, (109, 46): 5, (32, 80): 1, (80, 111): 1, (116, 105): 1, (105, 110): 1, (110, 99): 1, (105, 100): 1, (100, 117): 1, (117, 110): 1, (110, 116): 1, (118, 111): 2, (108, 117): 2, (117, 112): 2, (112, 116): 2, (116, 97): 2, (32, 76): 1, (76, 97): 1, (100, 32): 3, (32, 68): 4, (68, 111): 4, (10, 81): 1, (32, 86): 1, (86, 101): 1, (110, 111): 2, (110, 32): 2, (32, 78): 1, (78, 117): 1, (109, 113): 1}

You can print the sorted pair counts like this:

print(sorted(pair_counts.items(), key=lambda kv: kv[1], reverse=True))
[((113, 117), 32), ((116, 32), 31), ((111, 114), 29), ((109, 32), 26), ((97, 109), 24), ((117, 97), 19), ((46, 32), 18), ((111, 108), 15), ((114, 101), 15), ((32, 101), 14), ((108, 111), 13), ((101, 116), 12), ((32, 113), 12), ((32, 97), 12), ((101, 32), 12), ((116, 101), 11), ((101, 109), 11), ((117, 105), 11), ((105, 117), 10), ((117, 115), 10), ((114, 97), 10), ((32, 109), 10), ((105, 115), 10), ((32, 100), 9), ((100, 111), 9), ((112, 111), 9), ((115, 101), 9), ((108, 105), 9), ((105, 116), 9), ((115, 116), 9), ((114, 32), 8), ((32, 115), 8), ((97, 116), 8), ((101, 105), 7), ((111, 110), 7), ((105, 113), 7), ((105, 112), 7), ((115, 113), 7), ((32, 110), 7), ((109, 97), 7), ((97, 103), 7), ((103, 110), 7), ((110, 97), 7), ((115, 32), 6), ((97, 32), 6), ((109, 101), 6), ((97, 101), 6), ((101, 114), 6), ((100, 105), 6), ((105, 32), 6), ((114, 114), 6), ((114, 111), 6), ((101, 115), 6), ((97, 98), 6), ((98, 111), 6), ((110, 115), 5), ((101, 99), 5), ((99, 116), 5), ((116, 117), 5), ((117, 114), 5), ((97, 108), 5), ((32, 112), 5), ((32, 69), 5), ((110, 101), 5), ((101, 113), 5), ((117, 101), 5), ((32, 118), 5), ((117, 109), 5), ((32, 108), 5), ((108, 97), 5), ((109, 46), 5), ((32, 116), 4), ((109, 112), 4), ((32, 99), 4), ((99, 111), 4), ((101, 100), 4), ((99, 105), 4), ((115, 46), 4), ((115, 105), 4), ((101, 108), 4), ((32, 117), 4), ((117, 116), 4), ((116, 46), 4), ((32, 105), 4), ((112, 115), 4), ((115, 117), 4), ((32, 68), 4), ((68, 111), 4), ((69, 105), 3), ((109, 111), 3), ((111, 100), 3), ((97, 100), 3), ((112, 105), 3), ((115, 99), 3), ((81, 117), 3), ((111, 32), 3), ((111, 46), 3), ((118, 101), 3), ((101, 46), 3), ((69, 115), 3), ((100, 32), 3), ((32, 65), 2), ((65, 109), 2), ((32, 81), 2), ((105, 97), 2), ((46, 10), 2), ((10, 10), 2), ((118, 111), 2), ((108, 117), 2), ((117, 112), 2), ((112, 116), 2), ((116, 97), 2), ((110, 111), 2), ((110, 32), 2), ((100, 46), 1), ((32, 83), 1), ((83, 105), 1), ((114, 46), 1), ((10, 67), 1), ((67, 111), 1), ((32, 80), 1), 
((80, 111), 1), ((116, 105), 1), ((105, 110), 1), ((110, 99), 1), ((105, 100), 1), ((100, 117), 1), ((117, 110), 1), ((110, 116), 1), ((32, 76), 1), ((76, 97), 1), ((10, 81), 1), ((32, 86), 1), ((86, 101), 1), ((32, 78), 1), ((78, 117), 1), ((109, 113), 1)]

Next we need to pick the pair with the highest count as our first merge.

most_common_pair = max(pair_counts, key=lambda p: pair_counts[p])
print(most_common_pair)
print(f"Count: {pair_counts[most_common_pair]}")
print(f"{chr(most_common_pair[0])}{chr(most_common_pair[1])}")
(113, 117)
Count: 32
qu

And merge it into a new token. For this first illustration we simply add the two byte values together (113 + 117 = 230). Note that this is arbitrary, and 230 collides with an existing byte value; in the training loop below we will instead assign fresh token ids starting at 256.

new_token = most_common_pair[0] + most_common_pair[1]
print(new_token)
230

We now rewrite the token sequence, replacing every occurrence of the pair with the new token.

new_tokens = []
i = 0
while i < len(tokens):
    if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == most_common_pair:
        new_tokens.append(new_token)
        i += 2
    else:
        new_tokens.append(tokens[i])
        i += 1

print(new_tokens)
print(f"Length: {len(new_tokens)}")
[69, 105, 117, 115, 32, 101, 105, 117, 115, 32, 100, 111, 108, 111, 114, 32, 116, 101, 109, 112, 111, 114, 97, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 115, 101, 100, 46, 32, 65, 109, 101, 116, 32, 230, 97, 101, 114, 97, 116, 32, 109, 111, 100, 105, 32, 97, 108, 105, 230, 97, 109, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 97, 109, 101, 116, 32, 100, 111, 108, 111, 114, 101, 109, 32, 101, 105, 117, 115, 46, 32, 81, 117, 105, 230, 105, 97, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 112, 111, 114, 114, 111, 32, 100, 111, 108, 111, 114, 101, 109, 32, 115, 105, 116, 32, 230, 105, 115, 230, 97, 109, 32, 115, 105, 116, 32, 112, 111, 114, 114, 111, 46, 32, 69, 105, 117, 115, 32, 110, 101, 230, 101, 32, 230, 97, 101, 114, 97, 116, 32, 101, 115, 116, 32, 118, 101, 108, 105, 116, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 117, 116, 46, 32, 83, 105, 116, 32, 109, 111, 100, 105, 32, 105, 112, 115, 117, 109, 32, 101, 115, 116, 32, 100, 111, 108, 111, 114, 46, 10, 10, 67, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 97, 109, 101, 116, 32, 109, 97, 103, 110, 97, 109, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 108, 97, 98, 111, 114, 101, 32, 108, 97, 98, 111, 114, 101, 46, 32, 69, 115, 116, 32, 118, 101, 108, 105, 116, 32, 97, 108, 105, 230, 97, 109, 32, 116, 101, 109, 112, 111, 114, 97, 32, 110, 101, 230, 101, 32, 112, 111, 114, 114, 111, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 109, 97, 103, 110, 97, 109, 46, 32, 80, 111, 114, 114, 111, 32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 116, 32, 118, 111, 108, 117, 112, 116, 97, 116, 101, 109, 32, 230, 105, 115, 230, 97, 109, 46, 32, 76, 97, 98, 111, 114, 101, 32, 230, 97, 101, 114, 97, 116, 32, 100, 111, 108, 111, 114, 101, 109, 32, 115, 105, 116, 32, 97, 109, 101, 116, 32, 97, 108, 105, 230, 97, 109, 32, 115, 101, 100, 32, 101, 105, 117, 115, 46, 32, 65, 109, 101, 116, 32, 101, 105, 117, 115, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 
109, 97, 103, 110, 97, 109, 32, 101, 115, 116, 32, 110, 101, 230, 101, 46, 32, 68, 111, 108, 111, 114, 101, 32, 108, 97, 98, 111, 114, 101, 32, 108, 97, 98, 111, 114, 101, 32, 100, 111, 108, 111, 114, 101, 109, 32, 115, 101, 100, 32, 101, 115, 116, 46, 10, 10, 81, 117, 105, 230, 105, 97, 32, 230, 105, 115, 230, 97, 109, 32, 100, 111, 108, 111, 114, 101, 32, 112, 111, 114, 114, 111, 46, 32, 68, 111, 108, 111, 114, 101, 32, 110, 101, 230, 101, 32, 109, 97, 103, 110, 97, 109, 32, 101, 115, 116, 32, 230, 105, 115, 230, 97, 109, 46, 32, 69, 105, 117, 115, 32, 115, 101, 100, 32, 105, 112, 115, 117, 109, 32, 118, 111, 108, 117, 112, 116, 97, 116, 101, 109, 32, 117, 116, 32, 117, 116, 32, 97, 108, 105, 230, 97, 109, 32, 101, 105, 117, 115, 46, 32, 86, 101, 108, 105, 116, 32, 105, 112, 115, 117, 109, 32, 109, 97, 103, 110, 97, 109, 32, 101, 115, 116, 46, 32, 68, 111, 108, 111, 114, 101, 109, 32, 230, 97, 101, 114, 97, 116, 32, 115, 105, 116, 32, 105, 112, 115, 117, 109, 46, 32, 81, 117, 105, 115, 230, 97, 109, 32, 110, 111, 110, 32, 109, 97, 103, 110, 97, 109, 32, 230, 105, 115, 230, 97, 109, 32, 110, 101, 230, 101, 46, 32, 69, 115, 116, 32, 100, 111, 108, 111, 114, 32, 101, 105, 117, 115, 32, 116, 101, 109, 112, 111, 114, 97, 32, 112, 111, 114, 114, 111, 46, 32, 69, 115, 116, 32, 116, 101, 109, 112, 111, 114, 97, 32, 230, 97, 101, 114, 97, 116, 32, 109, 111, 100, 105, 32, 230, 97, 101, 114, 97, 116, 32, 109, 97, 103, 110, 97, 109, 32, 108, 97, 98, 111, 114, 101, 32, 101, 105, 117, 115, 46, 32, 78, 117, 109, 230, 97, 109, 32, 110, 111, 110, 32, 97, 109, 101, 116, 32, 117, 116, 32, 97, 108, 105, 230, 97, 109, 46, 32, 68, 111, 108, 111, 114, 32, 230, 105, 115, 230, 97, 109, 32, 100, 111, 108, 111, 114, 101, 32, 118, 101, 108, 105, 116, 46]
Length: 801
assert not any(
    new_tokens[i] == most_common_pair[0] and new_tokens[i + 1] == most_common_pair[1]
    for i in range(len(new_tokens) - 1)
)

new_tokens.count(new_token)
32

We can see that the length of the token list has decreased by exactly the number of times we merged the pair (833 - 32 = 801). Meanwhile, our vocabulary has grown by one token.

Let’s wrap what we’ve done so far into some functions:

def get_tokens(text: str) -> list[int]:
    tokens: bytes = text.encode(encoding="utf-8")
    return list(map(int, tokens))


def get_pair_counts(tokens: list[int]) -> dict[tuple[int, int], int]:
    counts: dict[tuple[int, int], int] = {}
    for pair in zip(tokens, tokens[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge_pair(tokens: list[int], pair: tuple[int, int], idx: int) -> list[int]:
    new_tokens = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
            new_tokens.append(idx)
            i += 2
        else:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens
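As an aside, get_pair_counts can be written more compactly with collections.Counter. This is just a stylistic alternative, not part of the original code:

```python
from collections import Counter


def get_pair_counts_compact(tokens: list[int]) -> dict[tuple[int, int], int]:
    # Counting the zipped adjacent pairs does the same bookkeeping as the manual loop.
    return Counter(zip(tokens, tokens[1:]))


# Counters compare equal to plain dicts with the same contents.
assert get_pair_counts_compact([1, 2, 1, 2, 1]) == {(1, 2): 2, (2, 1): 2}
```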

And we’ll use them in a training loop:

VOCAB_SIZE = 260
merges: dict[tuple[int, int], int] = {}
tokens = get_tokens(text)

for i in range(256, VOCAB_SIZE):
    pair_counts = get_pair_counts(tokens)
    most_common_pair = max(pair_counts, key=lambda p: pair_counts[p])
    tokens = merge_pair(tokens, most_common_pair, i)
    merges[most_common_pair] = i
    print(f"Merged {most_common_pair} into {i}. Sequence length: {len(tokens)}")
Merged (113, 117) into 256. Sequence length: 801
Merged (116, 32) into 257. Sequence length: 770
Merged (111, 114) into 258. Sequence length: 741
Merged (109, 32) into 259. Sequence length: 715
print(merges)
{(113, 117): 256, (116, 32): 257, (111, 114): 258, (109, 32): 259}

And that is how we build up our vocabulary. Assembled from the 256 base bytes plus the merges, it looks like this:

vocab = {i: bytes([i]) for i in range(256)}
for pair, idx in merges.items():
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

print(vocab)
{0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: b'Y', 90: b'Z', 91: b'[', 92: b'\\', 93: b']', 94: b'^', 95: b'_', 96: b'`', 97: b'a', 98: b'b', 99: b'c', 100: b'd', 101: b'e', 102: b'f', 103: b'g', 104: b'h', 105: b'i', 106: b'j', 107: b'k', 108: b'l', 109: b'm', 110: b'n', 111: b'o', 112: b'p', 113: b'q', 114: b'r', 115: b's', 116: b't', 117: b'u', 118: b'v', 119: b'w', 120: b'x', 121: b'y', 122: b'z', 123: b'{', 124: b'|', 125: b'}', 126: b'~', 127: b'\x7f', 128: b'\x80', 129: b'\x81', 130: b'\x82', 131: b'\x83', 132: b'\x84', 133: b'\x85', 134: b'\x86', 135: b'\x87', 136: b'\x88', 137: b'\x89', 138: b'\x8a', 139: b'\x8b', 140: b'\x8c', 141: b'\x8d', 142: b'\x8e', 143: b'\x8f', 144: b'\x90', 145: b'\x91', 146: b'\x92', 147: b'\x93', 148: b'\x94', 149: b'\x95', 150: b'\x96', 151: b'\x97', 152: b'\x98', 153: b'\x99', 154: b'\x9a', 155: b'\x9b', 156: b'\x9c', 157: b'\x9d', 158: b'\x9e', 159: b'\x9f', 160: b'\xa0', 161: b'\xa1', 162: b'\xa2', 163: b'\xa3', 164: b'\xa4', 165: b'\xa5', 166: b'\xa6', 167: b'\xa7', 168: b'\xa8', 169: b'\xa9', 170: b'\xaa', 171: 
b'\xab', 172: b'\xac', 173: b'\xad', 174: b'\xae', 175: b'\xaf', 176: b'\xb0', 177: b'\xb1', 178: b'\xb2', 179: b'\xb3', 180: b'\xb4', 181: b'\xb5', 182: b'\xb6', 183: b'\xb7', 184: b'\xb8', 185: b'\xb9', 186: b'\xba', 187: b'\xbb', 188: b'\xbc', 189: b'\xbd', 190: b'\xbe', 191: b'\xbf', 192: b'\xc0', 193: b'\xc1', 194: b'\xc2', 195: b'\xc3', 196: b'\xc4', 197: b'\xc5', 198: b'\xc6', 199: b'\xc7', 200: b'\xc8', 201: b'\xc9', 202: b'\xca', 203: b'\xcb', 204: b'\xcc', 205: b'\xcd', 206: b'\xce', 207: b'\xcf', 208: b'\xd0', 209: b'\xd1', 210: b'\xd2', 211: b'\xd3', 212: b'\xd4', 213: b'\xd5', 214: b'\xd6', 215: b'\xd7', 216: b'\xd8', 217: b'\xd9', 218: b'\xda', 219: b'\xdb', 220: b'\xdc', 221: b'\xdd', 222: b'\xde', 223: b'\xdf', 224: b'\xe0', 225: b'\xe1', 226: b'\xe2', 227: b'\xe3', 228: b'\xe4', 229: b'\xe5', 230: b'\xe6', 231: b'\xe7', 232: b'\xe8', 233: b'\xe9', 234: b'\xea', 235: b'\xeb', 236: b'\xec', 237: b'\xed', 238: b'\xee', 239: b'\xef', 240: b'\xf0', 241: b'\xf1', 242: b'\xf2', 243: b'\xf3', 244: b'\xf4', 245: b'\xf5', 246: b'\xf6', 247: b'\xf7', 248: b'\xf8', 249: b'\xf9', 250: b'\xfa', 251: b'\xfb', 252: b'\xfc', 253: b'\xfd', 254: b'\xfe', 255: b'\xff', 256: b'qu', 257: b't ', 258: b'or', 259: b'm '}

Decoding

We need a way to decode our tokens back into text.

def decode(tokens: list[int]) -> str:
    token_bytes: bytes = b"".join(vocab[token] for token in tokens)
    return token_bytes.decode("utf-8", errors="replace")
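The errors="replace" argument matters because a token boundary can, in principle, fall inside a multi-byte UTF-8 sequence. A small sketch of the failure mode, using a hand-truncated encoding of "€" (this byte string is made up for illustration):

```python
# b"\xe2\x82\xac" is the UTF-8 encoding of "€"; drop the last byte
# and strict decoding fails on the incomplete sequence.
fragment = b"\xe2\x82"
try:
    fragment.decode("utf-8")
except UnicodeDecodeError:
    pass  # strict decoding raises

# With errors="replace" we get U+FFFD (the replacement character) instead.
print(fragment.decode("utf-8", errors="replace"))
```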


print(decode(tokens))
Eius eius dolor tempora consectetur sed. Amet quaerat modi aliquam adipisci amet dolorem eius. Quiquia adipisci porro dolorem sit quisquam sit porro. Eius neque quaerat est velit adipisci ut. Sit modi ipsum est dolor.

Consectetur amet magnam consectetur labore labore. Est velit aliquam tempora neque porro consectetur magnam. Porro etincidunt voluptatem quisquam. Labore quaerat dolorem sit amet aliquam sed eius. Amet eius consectetur magnam est neque. Dolore labore labore dolorem sed est.

Quiquia quisquam dolore porro. Dolore neque magnam est quisquam. Eius sed ipsum voluptatem ut ut aliquam eius. Velit ipsum magnam est. Dolorem quaerat sit ipsum. Quisquam non magnam quisquam neque. Est dolor eius tempora porro. Est tempora quaerat modi quaerat magnam labore eius. Numquam non amet ut aliquam. Dolor quisquam dolore velit.

Encoding

We also need a way to encode text into tokens, following the rules we built up in the training phase.

def encode(text: str) -> list[int]:
    token_bytes: bytes = text.encode(encoding="utf-8")
    tokens: list[int] = list(map(int, token_bytes))
    for pair, idx in merges.items():  # dicts keep insertion order, so merges replay in training order
        tokens = merge_pair(tokens, pair, idx)
    return tokens


print(encode(text))
print(f"Length: {len(encode(text))}")
[69, 105, 117, 115, 32, 101, 105, 117, 115, 32, 100, 111, 108, 258, 32, 116, 101, 109, 112, 258, 97, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 115, 101, 100, 46, 32, 65, 109, 101, 257, 256, 97, 101, 114, 97, 257, 109, 111, 100, 105, 32, 97, 108, 105, 256, 97, 259, 97, 100, 105, 112, 105, 115, 99, 105, 32, 97, 109, 101, 257, 100, 111, 108, 258, 101, 259, 101, 105, 117, 115, 46, 32, 81, 117, 105, 256, 105, 97, 32, 97, 100, 105, 112, 105, 115, 99, 105, 32, 112, 258, 114, 111, 32, 100, 111, 108, 258, 101, 259, 115, 105, 257, 256, 105, 115, 256, 97, 259, 115, 105, 257, 112, 258, 114, 111, 46, 32, 69, 105, 117, 115, 32, 110, 101, 256, 101, 32, 256, 97, 101, 114, 97, 257, 101, 115, 257, 118, 101, 108, 105, 257, 97, 100, 105, 112, 105, 115, 99, 105, 32, 117, 116, 46, 32, 83, 105, 257, 109, 111, 100, 105, 32, 105, 112, 115, 117, 259, 101, 115, 257, 100, 111, 108, 258, 46, 10, 10, 67, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 97, 109, 101, 257, 109, 97, 103, 110, 97, 259, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 108, 97, 98, 258, 101, 32, 108, 97, 98, 258, 101, 46, 32, 69, 115, 257, 118, 101, 108, 105, 257, 97, 108, 105, 256, 97, 259, 116, 101, 109, 112, 258, 97, 32, 110, 101, 256, 101, 32, 112, 258, 114, 111, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 109, 97, 103, 110, 97, 109, 46, 32, 80, 258, 114, 111, 32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 257, 118, 111, 108, 117, 112, 116, 97, 116, 101, 259, 256, 105, 115, 256, 97, 109, 46, 32, 76, 97, 98, 258, 101, 32, 256, 97, 101, 114, 97, 257, 100, 111, 108, 258, 101, 259, 115, 105, 257, 97, 109, 101, 257, 97, 108, 105, 256, 97, 259, 115, 101, 100, 32, 101, 105, 117, 115, 46, 32, 65, 109, 101, 257, 101, 105, 117, 115, 32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114, 32, 109, 97, 103, 110, 97, 259, 101, 115, 257, 110, 101, 256, 101, 46, 32, 68, 111, 108, 258, 101, 32, 108, 97, 98, 258, 101, 32, 108, 97, 98, 258, 101, 32, 100, 111, 108, 258, 101, 259, 
115, 101, 100, 32, 101, 115, 116, 46, 10, 10, 81, 117, 105, 256, 105, 97, 32, 256, 105, 115, 256, 97, 259, 100, 111, 108, 258, 101, 32, 112, 258, 114, 111, 46, 32, 68, 111, 108, 258, 101, 32, 110, 101, 256, 101, 32, 109, 97, 103, 110, 97, 259, 101, 115, 257, 256, 105, 115, 256, 97, 109, 46, 32, 69, 105, 117, 115, 32, 115, 101, 100, 32, 105, 112, 115, 117, 259, 118, 111, 108, 117, 112, 116, 97, 116, 101, 259, 117, 257, 117, 257, 97, 108, 105, 256, 97, 259, 101, 105, 117, 115, 46, 32, 86, 101, 108, 105, 257, 105, 112, 115, 117, 259, 109, 97, 103, 110, 97, 259, 101, 115, 116, 46, 32, 68, 111, 108, 258, 101, 259, 256, 97, 101, 114, 97, 257, 115, 105, 257, 105, 112, 115, 117, 109, 46, 32, 81, 117, 105, 115, 256, 97, 259, 110, 111, 110, 32, 109, 97, 103, 110, 97, 259, 256, 105, 115, 256, 97, 259, 110, 101, 256, 101, 46, 32, 69, 115, 257, 100, 111, 108, 258, 32, 101, 105, 117, 115, 32, 116, 101, 109, 112, 258, 97, 32, 112, 258, 114, 111, 46, 32, 69, 115, 257, 116, 101, 109, 112, 258, 97, 32, 256, 97, 101, 114, 97, 257, 109, 111, 100, 105, 32, 256, 97, 101, 114, 97, 257, 109, 97, 103, 110, 97, 259, 108, 97, 98, 258, 101, 32, 101, 105, 117, 115, 46, 32, 78, 117, 109, 256, 97, 259, 110, 111, 110, 32, 97, 109, 101, 257, 117, 257, 97, 108, 105, 256, 97, 109, 46, 32, 68, 111, 108, 258, 32, 256, 105, 115, 256, 97, 259, 100, 111, 108, 258, 101, 32, 118, 101, 108, 105, 116, 46]
Length: 715
assert decode(encode(text)) == text
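Note that encode replays the merges in the order they were learned, and that order matters: a later merge can involve a token id created by an earlier one. A self-contained sketch with made-up merges (the helper is repeated so the snippet runs on its own):

```python
def merge_pair(tokens: list[int], pair: tuple[int, int], idx: int) -> list[int]:
    # Replace every occurrence of `pair` with the single token `idx`.
    new_tokens, i = [], 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
            new_tokens.append(idx)
            i += 2
        else:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens


# Hypothetical merges: ('a', 'a') -> 256, then (256, 'b') -> 257.
toy_merges = {(97, 97): 256, (256, 98): 257}

tokens = list("aab".encode("utf-8"))  # [97, 97, 98]
for pair, idx in toy_merges.items():  # insertion order = training order
    tokens = merge_pair(tokens, pair, idx)
print(tokens)  # [257]: the second merge can only fire after the first created 256
```

Replaying the merges in the reverse order would leave the sequence as [256, 98], since the pair (256, 98) never appears until the first merge has run.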

As a class


We can bundle everything into a single NaiveBPE class, constructed as NaiveBPE(vocab_size: int), with train, encode and decode methods.
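The class body is not shown in this document. The sketch below is reconstructed from the training loop above and the interface exercised by the tests that follow (train, encode, decode, and the vocab and merges attributes); treat it as one plausible implementation rather than the original:

```python
class NaiveBPE:
    """Byte-level BPE without chunking, mirroring the loop above (a sketch)."""

    def __init__(self, vocab_size: int):
        if vocab_size < 256:
            raise ValueError("vocab_size must cover the 256 base bytes")
        self.vocab_size = vocab_size
        self.merges: dict[tuple[int, int], int] = {}
        self.vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}

    @staticmethod
    def _merge_pair(tokens: list[int], pair: tuple[int, int], idx: int) -> list[int]:
        # Replace every occurrence of `pair` with the single token `idx`.
        new_tokens, i = [], 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
                new_tokens.append(idx)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        return new_tokens

    def train(self, text: str) -> None:
        tokens = list(text.encode("utf-8"))
        for idx in range(256, self.vocab_size):
            # Count adjacent pairs, merge the most frequent one, record the merge.
            counts: dict[tuple[int, int], int] = {}
            for pair in zip(tokens, tokens[1:]):
                counts[pair] = counts.get(pair, 0) + 1
            best = max(counts, key=lambda p: counts[p])
            tokens = self._merge_pair(tokens, best, idx)
            self.merges[best] = idx
            self.vocab[idx] = self.vocab[best[0]] + self.vocab[best[1]]

    def encode(self, text: str) -> list[int]:
        tokens = list(text.encode("utf-8"))
        for pair, idx in self.merges.items():  # replay merges in training order
            tokens = self._merge_pair(tokens, pair, idx)
        return tokens

    def decode(self, tokens: list[int]) -> str:
        return b"".join(self.vocab[t] for t in tokens).decode("utf-8", errors="replace")
```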

naive_bpe = NaiveBPE(vocab_size=260)
naive_bpe.train(text)

Tests

# Vocab has the right size
assert len(naive_bpe.vocab) == 260, f"Expected 260, got {len(naive_bpe.vocab)}"

# Base vocab entries are single bytes
assert naive_bpe.vocab[65] == b"A"
assert naive_bpe.vocab[32] == b" "

# Merged vocab entries are bytes concatenations of their component tokens
for pair, idx in naive_bpe.merges.items():
    assert naive_bpe.vocab[idx] == naive_bpe.vocab[pair[0]] + naive_bpe.vocab[pair[1]], (
        f"vocab[{idx}] mismatch for pair {pair}"
    )

# Encoding reduces sequence length vs raw UTF-8
raw_len = len(list(map(int, text.encode("utf-8"))))
encoded = naive_bpe.encode(text)
assert len(encoded) < raw_len, "Encoding should shorten the sequence"

# Round-trip: decode(encode(text)) == text
assert naive_bpe.decode(naive_bpe.encode(text)) == text, "Round-trip failed"

# Encoding unseen text still produces a valid token sequence
unseen = "Hello world"
encoded_unseen = naive_bpe.encode(unseen)
assert isinstance(encoded_unseen, list)
assert all(isinstance(t, int) for t in encoded_unseen)
assert naive_bpe.decode(encoded_unseen) == unseen

Avoiding merging across word boundaries

We’ll use the split pattern from the GPT-4 tokeniser to break the text into word-like chunks. The pattern relies on \p{L}/\p{N} character classes and possessive quantifiers, which Python’s built-in re module does not support, so we use the third-party regex package.

import regex as re

GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
text_chunks = re.findall(GPT4_SPLIT_PATTERN, text)

print(text_chunks)
['Eius', ' eius', ' dolor', ' tempora', ' consectetur', ' sed', '.', ' Amet', ' quaerat', ' modi', ' aliquam', ' adipisci', ' amet', ' dolorem', ' eius', '.', ' Quiquia', ' adipisci', ' porro', ' dolorem', ' sit', ' quisquam', ' sit', ' porro', '.', ' Eius', ' neque', ' quaerat', ' est', ' velit', ' adipisci', ' ut', '.', ' Sit', ' modi', ' ipsum', ' est', ' dolor', '.\n\n', 'Consectetur', ' amet', ' magnam', ' consectetur', ' labore', ' labore', '.', ' Est', ' velit', ' aliquam', ' tempora', ' neque', ' porro', ' consectetur', ' magnam', '.', ' Porro', ' etincidunt', ' voluptatem', ' quisquam', '.', ' Labore', ' quaerat', ' dolorem', ' sit', ' amet', ' aliquam', ' sed', ' eius', '.', ' Amet', ' eius', ' consectetur', ' magnam', ' est', ' neque', '.', ' Dolore', ' labore', ' labore', ' dolorem', ' sed', ' est', '.\n\n', 'Quiquia', ' quisquam', ' dolore', ' porro', '.', ' Dolore', ' neque', ' magnam', ' est', ' quisquam', '.', ' Eius', ' sed', ' ipsum', ' voluptatem', ' ut', ' ut', ' aliquam', ' eius', '.', ' Velit', ' ipsum', ' magnam', ' est', '.', ' Dolorem', ' quaerat', ' sit', ' ipsum', '.', ' Quisquam', ' non', ' magnam', ' quisquam', ' neque', '.', ' Est', ' dolor', ' eius', ' tempora', ' porro', '.', ' Est', ' tempora', ' quaerat', ' modi', ' quaerat', ' magnam', ' labore', ' eius', '.', ' Numquam', ' non', ' amet', ' ut', ' aliquam', '.', ' Dolor', ' quisquam', ' dolore', ' velit', '.']

When we were happy to tokenise across any part of the text, we could encode it all at once. Now we have a list of “text chunks”, and each chunk is tokenised separately.

tokenised_chunks = [list(ch.encode("utf-8")) for ch in text_chunks]

print(tokenised_chunks)
[[69, 105, 117, 115], [32, 101, 105, 117, 115], [32, 100, 111, 108, 111, 114], [32, 116, 101, 109, 112, 111, 114, 97], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 115, 101, 100], [46], [32, 65, 109, 101, 116], [32, 113, 117, 97, 101, 114, 97, 116], [32, 109, 111, 100, 105], [32, 97, 108, 105, 113, 117, 97, 109], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 97, 109, 101, 116], [32, 100, 111, 108, 111, 114, 101, 109], [32, 101, 105, 117, 115], [46], [32, 81, 117, 105, 113, 117, 105, 97], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 112, 111, 114, 114, 111], [32, 100, 111, 108, 111, 114, 101, 109], [32, 115, 105, 116], [32, 113, 117, 105, 115, 113, 117, 97, 109], [32, 115, 105, 116], [32, 112, 111, 114, 114, 111], [46], [32, 69, 105, 117, 115], [32, 110, 101, 113, 117, 101], [32, 113, 117, 97, 101, 114, 97, 116], [32, 101, 115, 116], [32, 118, 101, 108, 105, 116], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 117, 116], [46], [32, 83, 105, 116], [32, 109, 111, 100, 105], [32, 105, 112, 115, 117, 109], [32, 101, 115, 116], [32, 100, 111, 108, 111, 114], [46, 10, 10], [67, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 97, 109, 101, 116], [32, 109, 97, 103, 110, 97, 109], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 108, 97, 98, 111, 114, 101], [32, 108, 97, 98, 111, 114, 101], [46], [32, 69, 115, 116], [32, 118, 101, 108, 105, 116], [32, 97, 108, 105, 113, 117, 97, 109], [32, 116, 101, 109, 112, 111, 114, 97], [32, 110, 101, 113, 117, 101], [32, 112, 111, 114, 114, 111], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 109, 97, 103, 110, 97, 109], [46], [32, 80, 111, 114, 114, 111], [32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 116], [32, 118, 111, 108, 117, 112, 116, 97, 116, 101, 109], [32, 113, 117, 105, 115, 113, 117, 97, 109], [46], [32, 76, 97, 98, 111, 114, 101], [32, 113, 117, 97, 101, 114, 97, 116], [32, 100, 111, 108, 111, 114, 101, 109], [32, 115, 105, 116], [32, 97, 109, 101, 116], 
[32, 97, 108, 105, 113, 117, 97, 109], [32, 115, 101, 100], [32, 101, 105, 117, 115], [46], [32, 65, 109, 101, 116], [32, 101, 105, 117, 115], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 109, 97, 103, 110, 97, 109], [32, 101, 115, 116], [32, 110, 101, 113, 117, 101], [46], [32, 68, 111, 108, 111, 114, 101], [32, 108, 97, 98, 111, 114, 101], [32, 108, 97, 98, 111, 114, 101], [32, 100, 111, 108, 111, 114, 101, 109], [32, 115, 101, 100], [32, 101, 115, 116], [46, 10, 10], [81, 117, 105, 113, 117, 105, 97], [32, 113, 117, 105, 115, 113, 117, 97, 109], [32, 100, 111, 108, 111, 114, 101], [32, 112, 111, 114, 114, 111], [46], [32, 68, 111, 108, 111, 114, 101], [32, 110, 101, 113, 117, 101], [32, 109, 97, 103, 110, 97, 109], [32, 101, 115, 116], [32, 113, 117, 105, 115, 113, 117, 97, 109], [46], [32, 69, 105, 117, 115], [32, 115, 101, 100], [32, 105, 112, 115, 117, 109], [32, 118, 111, 108, 117, 112, 116, 97, 116, 101, 109], [32, 117, 116], [32, 117, 116], [32, 97, 108, 105, 113, 117, 97, 109], [32, 101, 105, 117, 115], [46], [32, 86, 101, 108, 105, 116], [32, 105, 112, 115, 117, 109], [32, 109, 97, 103, 110, 97, 109], [32, 101, 115, 116], [46], [32, 68, 111, 108, 111, 114, 101, 109], [32, 113, 117, 97, 101, 114, 97, 116], [32, 115, 105, 116], [32, 105, 112, 115, 117, 109], [46], [32, 81, 117, 105, 115, 113, 117, 97, 109], [32, 110, 111, 110], [32, 109, 97, 103, 110, 97, 109], [32, 113, 117, 105, 115, 113, 117, 97, 109], [32, 110, 101, 113, 117, 101], [46], [32, 69, 115, 116], [32, 100, 111, 108, 111, 114], [32, 101, 105, 117, 115], [32, 116, 101, 109, 112, 111, 114, 97], [32, 112, 111, 114, 114, 111], [46], [32, 69, 115, 116], [32, 116, 101, 109, 112, 111, 114, 97], [32, 113, 117, 97, 101, 114, 97, 116], [32, 109, 111, 100, 105], [32, 113, 117, 97, 101, 114, 97, 116], [32, 109, 97, 103, 110, 97, 109], [32, 108, 97, 98, 111, 114, 101], [32, 101, 105, 117, 115], [46], [32, 78, 117, 109, 113, 117, 97, 109], [32, 110, 111, 110], [32, 97, 109, 101, 116], 
[32, 117, 116], [32, 97, 108, 105, 113, 117, 97, 109], [46], [32, 68, 111, 108, 111, 114], [32, 113, 117, 105, 115, 113, 117, 97, 109], [32, 100, 111, 108, 111, 114, 101], [32, 118, 101, 108, 105, 116], [46]]
print(f"Number of chunks: {len(tokenised_chunks)}")
print(f"Number of tokens: {sum([len(tc) for tc in tokenised_chunks])}")
Number of chunks: 145
Number of tokens: 833

We can still reuse the functions we created before. But we’ll add some tweaks to the training loop to handle the fact that we’re now operating on a list of chunks.
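The loop below relies on `get_pair_counts` and `merge_pair` from earlier in the post. For reference, here are minimal sketches consistent with how they are called here; the original implementations may differ in detail.

```python
def get_pair_counts(tokens: list[int]) -> dict[tuple[int, int], int]:
    """Count occurrences of each adjacent token pair."""
    counts: dict[tuple[int, int], int] = {}
    for pair in zip(tokens, tokens[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge_pair(tokens: list[int], pair: tuple[int, int], new_id: int) -> list[int]:
    """Replace every occurrence of `pair` with the single token `new_id`."""
    merged: list[int] = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
            merged.append(new_id)
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged
```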

VOCAB_SIZE = 260
merges: dict[tuple[int, int], int] = {}

for i in range(256, VOCAB_SIZE):
    pair_counts = {}
    for tc in tokenised_chunks:
        for pair, count in get_pair_counts(tc).items():
            pair_counts[pair] = pair_counts.get(pair, 0) + count
    most_common_pair = max(pair_counts, key=lambda p: pair_counts[p])
    tokenised_chunks = [merge_pair(tc, most_common_pair, i) for tc in tokenised_chunks]
    merges[most_common_pair] = i
print(f"Merges: {merges}")
print(f"Number of chunks: {len(tokenised_chunks)}")
print(f"Number of tokens: {sum([len(tc) for tc in tokenised_chunks])}")
Merges: {(113, 117): 256, (111, 114): 257, (97, 109): 258, (111, 108): 259}
Number of chunks: 145
Number of tokens: 733

We still get a good amount of compression, from 833 tokens down to 733, even with the restriction of only merging within chunks.

print(tokenised_chunks)
[[69, 105, 117, 115], [32, 101, 105, 117, 115], [32, 100, 259, 257], [32, 116, 101, 109, 112, 257, 97], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 115, 101, 100], [46], [32, 65, 109, 101, 116], [32, 256, 97, 101, 114, 97, 116], [32, 109, 111, 100, 105], [32, 97, 108, 105, 256, 258], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 258, 101, 116], [32, 100, 259, 257, 101, 109], [32, 101, 105, 117, 115], [46], [32, 81, 117, 105, 256, 105, 97], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 112, 257, 114, 111], [32, 100, 259, 257, 101, 109], [32, 115, 105, 116], [32, 256, 105, 115, 256, 258], [32, 115, 105, 116], [32, 112, 257, 114, 111], [46], [32, 69, 105, 117, 115], [32, 110, 101, 256, 101], [32, 256, 97, 101, 114, 97, 116], [32, 101, 115, 116], [32, 118, 101, 108, 105, 116], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 117, 116], [46], [32, 83, 105, 116], [32, 109, 111, 100, 105], [32, 105, 112, 115, 117, 109], [32, 101, 115, 116], [32, 100, 259, 257], [46, 10, 10], [67, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 258, 101, 116], [32, 109, 97, 103, 110, 258], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 108, 97, 98, 257, 101], [32, 108, 97, 98, 257, 101], [46], [32, 69, 115, 116], [32, 118, 101, 108, 105, 116], [32, 97, 108, 105, 256, 258], [32, 116, 101, 109, 112, 257, 97], [32, 110, 101, 256, 101], [32, 112, 257, 114, 111], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 109, 97, 103, 110, 258], [46], [32, 80, 257, 114, 111], [32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 116], [32, 118, 259, 117, 112, 116, 97, 116, 101, 109], [32, 256, 105, 115, 256, 258], [46], [32, 76, 97, 98, 257, 101], [32, 256, 97, 101, 114, 97, 116], [32, 100, 259, 257, 101, 109], [32, 115, 105, 116], [32, 258, 101, 116], [32, 97, 108, 105, 256, 258], [32, 115, 101, 100], [32, 101, 105, 117, 115], [46], [32, 65, 109, 101, 116], [32, 101, 105, 117, 115], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], 
[32, 109, 97, 103, 110, 258], [32, 101, 115, 116], [32, 110, 101, 256, 101], [46], [32, 68, 259, 257, 101], [32, 108, 97, 98, 257, 101], [32, 108, 97, 98, 257, 101], [32, 100, 259, 257, 101, 109], [32, 115, 101, 100], [32, 101, 115, 116], [46, 10, 10], [81, 117, 105, 256, 105, 97], [32, 256, 105, 115, 256, 258], [32, 100, 259, 257, 101], [32, 112, 257, 114, 111], [46], [32, 68, 259, 257, 101], [32, 110, 101, 256, 101], [32, 109, 97, 103, 110, 258], [32, 101, 115, 116], [32, 256, 105, 115, 256, 258], [46], [32, 69, 105, 117, 115], [32, 115, 101, 100], [32, 105, 112, 115, 117, 109], [32, 118, 259, 117, 112, 116, 97, 116, 101, 109], [32, 117, 116], [32, 117, 116], [32, 97, 108, 105, 256, 258], [32, 101, 105, 117, 115], [46], [32, 86, 101, 108, 105, 116], [32, 105, 112, 115, 117, 109], [32, 109, 97, 103, 110, 258], [32, 101, 115, 116], [46], [32, 68, 259, 257, 101, 109], [32, 256, 97, 101, 114, 97, 116], [32, 115, 105, 116], [32, 105, 112, 115, 117, 109], [46], [32, 81, 117, 105, 115, 256, 258], [32, 110, 111, 110], [32, 109, 97, 103, 110, 258], [32, 256, 105, 115, 256, 258], [32, 110, 101, 256, 101], [46], [32, 69, 115, 116], [32, 100, 259, 257], [32, 101, 105, 117, 115], [32, 116, 101, 109, 112, 257, 97], [32, 112, 257, 114, 111], [46], [32, 69, 115, 116], [32, 116, 101, 109, 112, 257, 97], [32, 256, 97, 101, 114, 97, 116], [32, 109, 111, 100, 105], [32, 256, 97, 101, 114, 97, 116], [32, 109, 97, 103, 110, 258], [32, 108, 97, 98, 257, 101], [32, 101, 105, 117, 115], [46], [32, 78, 117, 109, 256, 258], [32, 110, 111, 110], [32, 258, 101, 116], [32, 117, 116], [32, 97, 108, 105, 256, 258], [46], [32, 68, 259, 257], [32, 256, 105, 115, 256, 258], [32, 100, 259, 257, 101], [32, 118, 101, 108, 105, 116], [46]]

Let’s build our vocab from the merges again.

vocab = {i: bytes([i]) for i in range(256)}
for pair, idx in merges.items():
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

Encoding with chunks

Since merges never cross chunk boundaries, we also need to encode each chunk independently.
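The `encode` helper was defined earlier in the post. A self-contained sketch of what per-chunk encoding does, replaying the learned merges in training order, might look like this (`encode_chunk` is an illustrative name, not the original function):

```python
def encode_chunk(chunk: str, merges: dict[tuple[int, int], int]) -> list[int]:
    """Encode one text chunk by replaying learned merges in training order."""
    tokens = list(chunk.encode("utf-8"))
    # Dicts preserve insertion order, so iterating `merges` replays the
    # merges in the order they were learned during training.
    for pair, new_id in merges.items():
        merged: list[int] = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
                merged.append(new_id)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens
```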

encoded_tokens = []
for tc in text_chunks:
    encoded_tokens.append(encode(tc))
print(encoded_tokens)
[[69, 105, 117, 115], [32, 101, 105, 117, 115], [32, 100, 259, 257], [32, 116, 101, 109, 112, 257, 97], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 115, 101, 100], [46], [32, 65, 109, 101, 116], [32, 256, 97, 101, 114, 97, 116], [32, 109, 111, 100, 105], [32, 97, 108, 105, 256, 258], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 258, 101, 116], [32, 100, 259, 257, 101, 109], [32, 101, 105, 117, 115], [46], [32, 81, 117, 105, 256, 105, 97], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 112, 257, 114, 111], [32, 100, 259, 257, 101, 109], [32, 115, 105, 116], [32, 256, 105, 115, 256, 258], [32, 115, 105, 116], [32, 112, 257, 114, 111], [46], [32, 69, 105, 117, 115], [32, 110, 101, 256, 101], [32, 256, 97, 101, 114, 97, 116], [32, 101, 115, 116], [32, 118, 101, 108, 105, 116], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 117, 116], [46], [32, 83, 105, 116], [32, 109, 111, 100, 105], [32, 105, 112, 115, 117, 109], [32, 101, 115, 116], [32, 100, 259, 257], [46, 10, 10], [67, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 258, 101, 116], [32, 109, 97, 103, 110, 258], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 108, 97, 98, 257, 101], [32, 108, 97, 98, 257, 101], [46], [32, 69, 115, 116], [32, 118, 101, 108, 105, 116], [32, 97, 108, 105, 256, 258], [32, 116, 101, 109, 112, 257, 97], [32, 110, 101, 256, 101], [32, 112, 257, 114, 111], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 109, 97, 103, 110, 258], [46], [32, 80, 257, 114, 111], [32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 116], [32, 118, 259, 117, 112, 116, 97, 116, 101, 109], [32, 256, 105, 115, 256, 258], [46], [32, 76, 97, 98, 257, 101], [32, 256, 97, 101, 114, 97, 116], [32, 100, 259, 257, 101, 109], [32, 115, 105, 116], [32, 258, 101, 116], [32, 97, 108, 105, 256, 258], [32, 115, 101, 100], [32, 101, 105, 117, 115], [46], [32, 65, 109, 101, 116], [32, 101, 105, 117, 115], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], 
[32, 109, 97, 103, 110, 258], [32, 101, 115, 116], [32, 110, 101, 256, 101], [46], [32, 68, 259, 257, 101], [32, 108, 97, 98, 257, 101], [32, 108, 97, 98, 257, 101], [32, 100, 259, 257, 101, 109], [32, 115, 101, 100], [32, 101, 115, 116], [46, 10, 10], [81, 117, 105, 256, 105, 97], [32, 256, 105, 115, 256, 258], [32, 100, 259, 257, 101], [32, 112, 257, 114, 111], [46], [32, 68, 259, 257, 101], [32, 110, 101, 256, 101], [32, 109, 97, 103, 110, 258], [32, 101, 115, 116], [32, 256, 105, 115, 256, 258], [46], [32, 69, 105, 117, 115], [32, 115, 101, 100], [32, 105, 112, 115, 117, 109], [32, 118, 259, 117, 112, 116, 97, 116, 101, 109], [32, 117, 116], [32, 117, 116], [32, 97, 108, 105, 256, 258], [32, 101, 105, 117, 115], [46], [32, 86, 101, 108, 105, 116], [32, 105, 112, 115, 117, 109], [32, 109, 97, 103, 110, 258], [32, 101, 115, 116], [46], [32, 68, 259, 257, 101, 109], [32, 256, 97, 101, 114, 97, 116], [32, 115, 105, 116], [32, 105, 112, 115, 117, 109], [46], [32, 81, 117, 105, 115, 256, 258], [32, 110, 111, 110], [32, 109, 97, 103, 110, 258], [32, 256, 105, 115, 256, 258], [32, 110, 101, 256, 101], [46], [32, 69, 115, 116], [32, 100, 259, 257], [32, 101, 105, 117, 115], [32, 116, 101, 109, 112, 257, 97], [32, 112, 257, 114, 111], [46], [32, 69, 115, 116], [32, 116, 101, 109, 112, 257, 97], [32, 256, 97, 101, 114, 97, 116], [32, 109, 111, 100, 105], [32, 256, 97, 101, 114, 97, 116], [32, 109, 97, 103, 110, 258], [32, 108, 97, 98, 257, 101], [32, 101, 105, 117, 115], [46], [32, 78, 117, 109, 256, 258], [32, 110, 111, 110], [32, 258, 101, 116], [32, 117, 116], [32, 97, 108, 105, 256, 258], [46], [32, 68, 259, 257], [32, 256, 105, 115, 256, 258], [32, 100, 259, 257, 101], [32, 118, 101, 108, 105, 116], [46]]

Decoding with chunks

decoded_tokens = [decode(tc) for tc in encoded_tokens]
decoded_text = "".join(decoded_tokens)

print(decoded_text)
Eius eius dolor tempora consectetur sed. Amet quaerat modi aliquam adipisci amet dolorem eius. Quiquia adipisci porro dolorem sit quisquam sit porro. Eius neque quaerat est velit adipisci ut. Sit modi ipsum est dolor.

Consectetur amet magnam consectetur labore labore. Est velit aliquam tempora neque porro consectetur magnam. Porro etincidunt voluptatem quisquam. Labore quaerat dolorem sit amet aliquam sed eius. Amet eius consectetur magnam est neque. Dolore labore labore dolorem sed est.

Quiquia quisquam dolore porro. Dolore neque magnam est quisquam. Eius sed ipsum voluptatem ut ut aliquam eius. Velit ipsum magnam est. Dolorem quaerat sit ipsum. Quisquam non magnam quisquam neque. Est dolor eius tempora porro. Est tempora quaerat modi quaerat magnam labore eius. Numquam non amet ut aliquam. Dolor quisquam dolore velit.
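Decoding stays simple: each token id maps back to bytes through the vocab, and the chunks concatenate cleanly. A minimal sketch of the per-chunk helper, assuming the `vocab` built above (`decode_chunk` is an illustrative name):

```python
def decode_chunk(tokens: list[int], vocab: dict[int, bytes]) -> str:
    """Concatenate each token's bytes and decode the result as UTF-8."""
    return b"".join(vocab[t] for t in tokens).decode("utf-8")
```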

As a class


ChunkedBPE

ChunkedBPE(vocab_size: int)
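The class body isn't reproduced here. A minimal sketch consistent with the interface used below (`train`, `encode`, `decode`, plus the `vocab`, `merges`, and `split_pattern` attributes and the `_decode_single` helper exercised in the tests) could look like this; the split pattern is an assumption chosen to reproduce the word-with-leading-space chunks printed above, and it assumes single spaces between words:

```python
import re


class ChunkedBPE:
    """Byte-pair encoding that never merges across chunk boundaries."""

    # Assumption: words keep their leading space; punctuation and newlines
    # group together (e.g. ".\n\n" stays one chunk).
    split_pattern = r" ?[A-Za-z]+|[^A-Za-z ]+"

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size
        self.merges: dict[tuple[int, int], int] = {}
        self.vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}

    @staticmethod
    def _merge(tokens: list[int], pair: tuple[int, int], new_id: int) -> list[int]:
        # Replace every occurrence of `pair` with `new_id`.
        out: list[int] = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
                out.append(new_id)
                i += 2
            else:
                out.append(tokens[i])
                i += 1
        return out

    def train(self, text: str) -> None:
        chunks = [list(c.encode("utf-8")) for c in re.findall(self.split_pattern, text)]
        for new_id in range(256, self.vocab_size):
            # Sum pair counts over all chunks, never across chunk boundaries.
            pair_counts: dict[tuple[int, int], int] = {}
            for tc in chunks:
                for pair in zip(tc, tc[1:]):
                    pair_counts[pair] = pair_counts.get(pair, 0) + 1
            if not pair_counts:
                break
            best = max(pair_counts, key=pair_counts.get)
            chunks = [self._merge(tc, best, new_id) for tc in chunks]
            self.merges[best] = new_id
            self.vocab[new_id] = self.vocab[best[0]] + self.vocab[best[1]]

    def encode(self, text: str) -> list[list[int]]:
        encoded = []
        for chunk in re.findall(self.split_pattern, text):
            tokens = list(chunk.encode("utf-8"))
            for pair, new_id in self.merges.items():  # replay in training order
                tokens = self._merge(tokens, pair, new_id)
            encoded.append(tokens)
        return encoded

    def _decode_single(self, tokens: list[int]) -> str:
        return b"".join(self.vocab[t] for t in tokens).decode("utf-8")

    def decode(self, encoded: list[list[int]]) -> str:
        return "".join(self._decode_single(tc) for tc in encoded)
```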

text = lorem.text()

print(text)
Sit est neque etincidunt dolorem magnam non. Eius magnam quaerat labore. Etincidunt sit etincidunt adipisci voluptatem sed magnam est. Quaerat sed ut tempora. Quiquia tempora non voluptatem. Ut velit eius quiquia velit labore. Quisquam dolorem quiquia est sed. Quisquam consectetur quisquam quisquam aliquam. Velit neque aliquam quaerat labore tempora. Amet etincidunt ipsum tempora modi.

Dolor velit porro labore numquam. Est voluptatem dolore est voluptatem non velit etincidunt. Velit consectetur neque amet dolor ut. Sed tempora sed magnam velit. Labore amet velit magnam adipisci est porro consectetur. Labore sed tempora sed est quaerat magnam.

Consectetur sed eius non adipisci quiquia. Ipsum quiquia eius quisquam amet quisquam voluptatem neque. Quisquam modi consectetur dolor aliquam aliquam. Non consectetur consectetur eius porro dolor. Neque labore sed tempora sit porro modi. Tempora tempora dolor modi quisquam consectetur voluptatem non. Modi est aliquam sit labore dolorem neque neque. Labore consectetur sed labore porro. Aliquam magnam est dolorem consectetur voluptatem. Amet sed dolorem quisquam quisquam quiquia adipisci.

Eius dolor velit quaerat dolorem amet. Dolorem quaerat dolorem voluptatem eius ipsum ut dolor. Adipisci est dolor dolor porro est velit dolore. Sed adipisci dolor adipisci. Velit velit eius quaerat ipsum adipisci.

Aliquam dolor tempora modi numquam consectetur. Non neque ut labore. Quaerat consectetur neque numquam eius quiquia aliquam tempora. Etincidunt neque quisquam adipisci aliquam non magnam modi. Non sit neque amet. Ipsum quiquia dolor amet.
chunked_bpe = ChunkedBPE(vocab_size=260)

chunked_bpe.train(text)
encoded_text = chunked_bpe.encode(text)
print(encoded_text)
[[83, 105, 116], [32, 101, 115, 116], [32, 110, 101, 256, 101], [32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 116], [32, 100, 259, 257, 101, 109], [32, 109, 97, 103, 110, 258], [32, 110, 111, 110], [46], [32, 69, 105, 117, 115], [32, 109, 97, 103, 110, 258], [32, 256, 97, 101, 114, 97, 116], [32, 108, 97, 98, 257, 101], [46], [32, 69, 116, 105, 110, 99, 105, 100, 117, 110, 116], [32, 115, 105, 116], [32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 116], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 118, 259, 117, 112, 116, 97, 116, 101, 109], [32, 115, 101, 100], [32, 109, 97, 103, 110, 258], [32, 101, 115, 116], [46], [32, 81, 117, 97, 101, 114, 97, 116], [32, 115, 101, 100], [32, 117, 116], [32, 116, 101, 109, 112, 257, 97], [46], [32, 81, 117, 105, 256, 105, 97], [32, 116, 101, 109, 112, 257, 97], [32, 110, 111, 110], [32, 118, 259, 117, 112, 116, 97, 116, 101, 109], [46], [32, 85, 116], [32, 118, 101, 108, 105, 116], [32, 101, 105, 117, 115], [32, 256, 105, 256, 105, 97], [32, 118, 101, 108, 105, 116], [32, 108, 97, 98, 257, 101], [46], [32, 81, 117, 105, 115, 256, 258], [32, 100, 259, 257, 101, 109], [32, 256, 105, 256, 105, 97], [32, 101, 115, 116], [32, 115, 101, 100], [46], [32, 81, 117, 105, 115, 256, 258], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 256, 105, 115, 256, 258], [32, 256, 105, 115, 256, 258], [32, 97, 108, 105, 256, 258], [46], [32, 86, 101, 108, 105, 116], [32, 110, 101, 256, 101], [32, 97, 108, 105, 256, 258], [32, 256, 97, 101, 114, 97, 116], [32, 108, 97, 98, 257, 101], [32, 116, 101, 109, 112, 257, 97], [46], [32, 65, 109, 101, 116], [32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 116], [32, 105, 112, 115, 117, 109], [32, 116, 101, 109, 112, 257, 97], [32, 109, 111, 100, 105], [46, 10, 10], [68, 259, 257], [32, 118, 101, 108, 105, 116], [32, 112, 257, 114, 111], [32, 108, 97, 98, 257, 101], [32, 110, 117, 109, 256, 258], [46], [32, 69, 115, 116], [32, 118, 259, 117, 112, 116, 97, 116, 101, 109], [32, 100, 259, 
257, 101], [32, 101, 115, 116], [32, 118, 259, 117, 112, 116, 97, 116, 101, 109], [32, 110, 111, 110], [32, 118, 101, 108, 105, 116], [32, 101, 116, 105, 110, 99, 105, 100, 117, 110, 116], [46], [32, 86, 101, 108, 105, 116], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 110, 101, 256, 101], [32, 258, 101, 116], [32, 100, 259, 257], [32, 117, 116], [46], [32, 83, 101, 100], [32, 116, 101, 109, 112, 257, 97], [32, 115, 101, 100], [32, 109, 97, 103, 110, 258], [32, 118, 101, 108, 105, 116], [46], [32, 76, 97, 98, 257, 101], [32, 258, 101, 116], [32, 118, 101, 108, 105, 116], [32, 109, 97, 103, 110, 258], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 101, 115, 116], [32, 112, 257, 114, 111], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [46], [32, 76, 97, 98, 257, 101], [32, 115, 101, 100], [32, 116, 101, 109, 112, 257, 97], [32, 115, 101, 100], [32, 101, 115, 116], [32, 256, 97, 101, 114, 97, 116], [32, 109, 97, 103, 110, 258], [46, 10, 10], [67, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 115, 101, 100], [32, 101, 105, 117, 115], [32, 110, 111, 110], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 256, 105, 256, 105, 97], [46], [32, 73, 112, 115, 117, 109], [32, 256, 105, 256, 105, 97], [32, 101, 105, 117, 115], [32, 256, 105, 115, 256, 258], [32, 258, 101, 116], [32, 256, 105, 115, 256, 258], [32, 118, 259, 117, 112, 116, 97, 116, 101, 109], [32, 110, 101, 256, 101], [46], [32, 81, 117, 105, 115, 256, 258], [32, 109, 111, 100, 105], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 100, 259, 257], [32, 97, 108, 105, 256, 258], [32, 97, 108, 105, 256, 258], [46], [32, 78, 111, 110], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 101, 105, 117, 115], [32, 112, 257, 114, 111], [32, 100, 259, 257], [46], [32, 78, 101, 256, 101], [32, 108, 97, 98, 257, 101], [32, 115, 101, 100], [32, 116, 101, 109, 112, 257, 97], [32, 115, 105, 116], 
[32, 112, 257, 114, 111], [32, 109, 111, 100, 105], [46], [32, 84, 101, 109, 112, 257, 97], [32, 116, 101, 109, 112, 257, 97], [32, 100, 259, 257], [32, 109, 111, 100, 105], [32, 256, 105, 115, 256, 258], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 118, 259, 117, 112, 116, 97, 116, 101, 109], [32, 110, 111, 110], [46], [32, 77, 111, 100, 105], [32, 101, 115, 116], [32, 97, 108, 105, 256, 258], [32, 115, 105, 116], [32, 108, 97, 98, 257, 101], [32, 100, 259, 257, 101, 109], [32, 110, 101, 256, 101], [32, 110, 101, 256, 101], [46], [32, 76, 97, 98, 257, 101], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 115, 101, 100], [32, 108, 97, 98, 257, 101], [32, 112, 257, 114, 111], [46], [32, 65, 108, 105, 256, 258], [32, 109, 97, 103, 110, 258], [32, 101, 115, 116], [32, 100, 259, 257, 101, 109], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 118, 259, 117, 112, 116, 97, 116, 101, 109], [46], [32, 65, 109, 101, 116], [32, 115, 101, 100], [32, 100, 259, 257, 101, 109], [32, 256, 105, 115, 256, 258], [32, 256, 105, 115, 256, 258], [32, 256, 105, 256, 105, 97], [32, 97, 100, 105, 112, 105, 115, 99, 105], [46, 10, 10], [69, 105, 117, 115], [32, 100, 259, 257], [32, 118, 101, 108, 105, 116], [32, 256, 97, 101, 114, 97, 116], [32, 100, 259, 257, 101, 109], [32, 258, 101, 116], [46], [32, 68, 259, 257, 101, 109], [32, 256, 97, 101, 114, 97, 116], [32, 100, 259, 257, 101, 109], [32, 118, 259, 117, 112, 116, 97, 116, 101, 109], [32, 101, 105, 117, 115], [32, 105, 112, 115, 117, 109], [32, 117, 116], [32, 100, 259, 257], [46], [32, 65, 100, 105, 112, 105, 115, 99, 105], [32, 101, 115, 116], [32, 100, 259, 257], [32, 100, 259, 257], [32, 112, 257, 114, 111], [32, 101, 115, 116], [32, 118, 101, 108, 105, 116], [32, 100, 259, 257, 101], [46], [32, 83, 101, 100], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 100, 259, 257], [32, 97, 100, 105, 112, 105, 115, 99, 105], [46], [32, 86, 101, 108, 105, 116], [32, 118, 101, 108, 105, 
116], [32, 101, 105, 117, 115], [32, 256, 97, 101, 114, 97, 116], [32, 105, 112, 115, 117, 109], [32, 97, 100, 105, 112, 105, 115, 99, 105], [46, 10, 10], [65, 108, 105, 256, 258], [32, 100, 259, 257], [32, 116, 101, 109, 112, 257, 97], [32, 109, 111, 100, 105], [32, 110, 117, 109, 256, 258], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [46], [32, 78, 111, 110], [32, 110, 101, 256, 101], [32, 117, 116], [32, 108, 97, 98, 257, 101], [46], [32, 81, 117, 97, 101, 114, 97, 116], [32, 99, 111, 110, 115, 101, 99, 116, 101, 116, 117, 114], [32, 110, 101, 256, 101], [32, 110, 117, 109, 256, 258], [32, 101, 105, 117, 115], [32, 256, 105, 256, 105, 97], [32, 97, 108, 105, 256, 258], [32, 116, 101, 109, 112, 257, 97], [46], [32, 69, 116, 105, 110, 99, 105, 100, 117, 110, 116], [32, 110, 101, 256, 101], [32, 256, 105, 115, 256, 258], [32, 97, 100, 105, 112, 105, 115, 99, 105], [32, 97, 108, 105, 256, 258], [32, 110, 111, 110], [32, 109, 97, 103, 110, 258], [32, 109, 111, 100, 105], [46], [32, 78, 111, 110], [32, 115, 105, 116], [32, 110, 101, 256, 101], [32, 258, 101, 116], [46], [32, 73, 112, 115, 117, 109], [32, 256, 105, 256, 105, 97], [32, 100, 259, 257], [32, 258, 101, 116], [46]]
decoded_text = chunked_bpe.decode(encoded_text)
print(decoded_text)
Sit est neque etincidunt dolorem magnam non. Eius magnam quaerat labore. Etincidunt sit etincidunt adipisci voluptatem sed magnam est. Quaerat sed ut tempora. Quiquia tempora non voluptatem. Ut velit eius quiquia velit labore. Quisquam dolorem quiquia est sed. Quisquam consectetur quisquam quisquam aliquam. Velit neque aliquam quaerat labore tempora. Amet etincidunt ipsum tempora modi.

Dolor velit porro labore numquam. Est voluptatem dolore est voluptatem non velit etincidunt. Velit consectetur neque amet dolor ut. Sed tempora sed magnam velit. Labore amet velit magnam adipisci est porro consectetur. Labore sed tempora sed est quaerat magnam.

Consectetur sed eius non adipisci quiquia. Ipsum quiquia eius quisquam amet quisquam voluptatem neque. Quisquam modi consectetur dolor aliquam aliquam. Non consectetur consectetur eius porro dolor. Neque labore sed tempora sit porro modi. Tempora tempora dolor modi quisquam consectetur voluptatem non. Modi est aliquam sit labore dolorem neque neque. Labore consectetur sed labore porro. Aliquam magnam est dolorem consectetur voluptatem. Amet sed dolorem quisquam quisquam quiquia adipisci.

Eius dolor velit quaerat dolorem amet. Dolorem quaerat dolorem voluptatem eius ipsum ut dolor. Adipisci est dolor dolor porro est velit dolore. Sed adipisci dolor adipisci. Velit velit eius quaerat ipsum adipisci.

Aliquam dolor tempora modi numquam consectetur. Non neque ut labore. Quaerat consectetur neque numquam eius quiquia aliquam tempora. Etincidunt neque quisquam adipisci aliquam non magnam modi. Non sit neque amet. Ipsum quiquia dolor amet.

Tests

chunked_bpe_test = ChunkedBPE(vocab_size=280)
chunked_bpe_test.train(text)

# Vocab has the right size
assert len(chunked_bpe_test.vocab) == 280, f"Expected 280, got {len(chunked_bpe_test.vocab)}"

# Base vocab entries are single bytes
assert chunked_bpe_test.vocab[65] == b"A"
assert chunked_bpe_test.vocab[32] == b" "

# Merged vocab entries are byte concatenations of their component tokens
for pair, idx in chunked_bpe_test.merges.items():
    assert chunked_bpe_test.vocab[idx] == chunked_bpe_test.vocab[pair[0]] + chunked_bpe_test.vocab[pair[1]], (
        f"vocab[{idx}] mismatch for pair {pair}"
    )

# Encoding compresses: total tokens across chunks should be fewer than raw UTF-8 bytes
raw_len = len(text.encode("utf-8"))
encoded = chunked_bpe_test.encode(text)
total_encoded = sum(len(chunk) for chunk in encoded)
assert total_encoded < raw_len, f"Expected compression: {total_encoded} >= {raw_len}"

# Round-trip: decode(encode(text)) == text
assert chunked_bpe_test.decode(chunked_bpe_test.encode(text)) == text, "Round-trip failed"

# Merges never cross chunk boundaries: each encoded chunk independently decodes to its source chunk
split_pattern = chunked_bpe_test.split_pattern
text_chunks = re.findall(split_pattern, text)
for chunk_text, chunk_tokens in zip(text_chunks, chunked_bpe_test.encode(text)):
    assert chunked_bpe_test._decode_single(chunk_tokens) == chunk_text, f"Chunk round-trip failed for {chunk_text!r}"

# Encoding unseen text produces a valid, decodable token sequence
unseen = "Hello, world! This is a test."
encoded_unseen = chunked_bpe_test.encode(unseen)
assert isinstance(encoded_unseen, list)
assert all(isinstance(chunk, list) for chunk in encoded_unseen)
assert all(isinstance(t, int) for chunk in encoded_unseen for t in chunk)
assert chunked_bpe_test.decode(encoded_unseen) == unseen, "Unseen text round-trip failed"

# Empty string encodes and decodes cleanly
assert chunked_bpe_test.encode("") == []
assert chunked_bpe_test.decode([]) == ""

print("All ChunkedBPE tests passed.")
All ChunkedBPE tests passed.

Full BPE

The final piece we will add is a set of special tokens. We use these for things like marking the end of a sequence of text, or masking out parts of the text, which some training routines rely on. Some models, like the AWD-LSTM, also use special tokens to mark things like "the character after this is repeated n times".
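One common way to handle special tokens is to reserve ids above the learned vocabulary and split them out of the text before byte-level encoding, so they are never broken into merges. A minimal sketch of that idea (the function name, the example token string, and the `encode_chunk` callable are all illustrative, not the class's actual API):

```python
import re


def encode_with_specials(text: str,
                         special_tokens: dict[str, int],
                         encode_chunk) -> list[int]:
    """Split `text` on special-token strings, map those directly to their
    reserved ids, and BPE-encode everything in between via `encode_chunk`."""
    # A capturing group makes re.split keep the special tokens in the output.
    pattern = "(" + "|".join(re.escape(t) for t in special_tokens) + ")"
    out: list[int] = []
    for part in re.split(pattern, text):
        if part in special_tokens:
            out.append(special_tokens[part])
        elif part:
            out.extend(encode_chunk(part))
    return out
```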


BPE

BPE(vocab_size: int)

text = lorem.text()

print(text)
Velit dolor amet sit ut. Aliquam etincidunt dolorem aliquam velit dolorem aliquam. Etincidunt numquam magnam voluptatem porro non. Aliquam ut quisquam dolorem etincidunt labore. Quisquam ipsum est consectetur. Modi magnam tempora neque magnam ipsum. Non sit non labore velit. Est velit amet tempora modi amet modi quiquia. Porro etincidunt ipsum consectetur amet. Sed etincidunt etincidunt neque quiquia sit.

Sit neque labore neque aliquam. Quiquia modi est quisquam velit aliquam. Sed quisquam magnam labore. Velit tempora sed dolore adipisci dolor quiquia sed. Aliquam dolor amet eius ut quisquam labore. Magnam quaerat aliquam quiquia dolor sed. Labore labore neque ut amet quaerat dolor. Labore magnam ipsum quisquam est eius consectetur adipisci.

Non numquam magnam porro consectetur. Ut est sit neque dolorem quiquia. Sit non dolorem consectetur dolor amet. Dolore sit neque amet velit sed ipsum. Sit amet porro porro etincidunt etincidunt. Sit sit sit quisquam dolore. Numquam eius consectetur dolorem velit dolor. Velit non adipisci labore porro. Labore neque velit porro dolor quiquia amet amet. Magnam sit tempora non numquam modi.

Amet dolor ipsum quisquam amet neque neque sed. Labore dolor voluptatem ut porro quisquam quaerat dolore. Modi numquam aliquam voluptatem consectetur. Dolor etincidunt dolore dolore ipsum sit eius labore. Etincidunt amet est dolorem quisquam porro est dolor. Quiquia porro amet numquam quisquam dolore. Quaerat dolor dolore sed quiquia tempora quiquia ipsum. Aliquam voluptatem adipisci amet aliquam quisquam ut quaerat.

Ut quaerat modi quiquia quisquam. Non quisquam porro etincidunt. Non eius velit velit sed magnam. Neque amet sit aliquam dolore neque. Non neque dolorem ipsum quiquia dolor ipsum. Tempora labore magnam dolor ut labore. Porro consectetur ipsum eius neque consectetur est neque.

Consectetur aliquam etincidunt magnam quiquia. Eius dolore voluptatem ipsum quaerat tempora ipsum. Ut dolore non numquam neque. Etincidunt non adipisci amet magnam numquam neque. Dolorem quaerat labore aliquam aliquam. Non quiquia voluptatem quisquam sed non.
bpe = BPE(vocab_size=300)
bpe.train(text)
encoded_text = bpe.encode(text)
print(encoded_text)
[[86, 101, 267, 116], [265], [276], [272, 289], [32, 117, 116], [46], [32, 65, 282], [32, 261, 105, 110, 273, 100, 117, 110, 116], [265, 270], [280, 282], [283, 101, 267, 116], [265, 270], [280, 282], [46], [32, 69, 116, 105, 110, 273, 100, 117, 110, 116], [266, 271, 262], [284, 97, 103, 110, 258], [283, 259, 117, 112, 116, 290, 270], [32, 277, 114, 111], [266, 269], [46], [32, 65, 282], [32, 117, 116], [291], [265, 270], [32, 261, 105, 110, 273, 100, 117, 110, 116], [32, 108, 288], [46], [32, 81, 117, 105, 285], [32, 274, 115, 271], [32, 101, 115, 116], [32, 99, 269, 115, 101, 99, 116, 261, 117, 114], [46], [32, 77, 111, 100, 105], [284, 97, 103, 110, 258], [32, 116, 270, 277, 97], [281], [284, 97, 103, 110, 258], [32, 274, 115, 271], [46], [32, 78, 269], [272, 289], [266, 269], [32, 108, 288], [283, 101, 267, 116], [46], [32, 69, 115, 116], [283, 101, 267, 116], [276], [32, 116, 270, 277, 97], [284, 111, 100, 105], [276], [284, 111, 100, 105], [268, 260, 97], [46], [32, 80, 257, 114, 111], [32, 261, 105, 110, 273, 100, 117, 110, 116], [32, 274, 115, 271], [32, 99, 269, 115, 101, 99, 116, 261, 117, 114], [276], [46], [32, 83, 101, 100], [32, 261, 105, 110, 273, 100, 117, 110, 116], [32, 261, 105, 110, 273, 100, 117, 110, 116], [281], [268, 260, 97], [272, 289], [46, 10, 10], [83, 289], [281], [32, 108, 288], [281], [280, 282], [46], [32, 81, 117, 105, 260, 97], [284, 111, 100, 105], [32, 101, 115, 116], [291], [283, 101, 267, 116], [280, 282], [46], [32, 83, 101, 100], [291], [284, 97, 103, 110, 258], [32, 108, 288], [46], [32, 86, 101, 267, 116], [32, 116, 270, 277, 97], [272, 101, 100], [265, 101], [280, 100, 274, 105, 115, 273], [265], [268, 260, 97], [272, 101, 100], [46], [32, 65, 282], [265], [276], [32, 101, 105, 117, 115], [32, 117, 116], [291], [32, 108, 288], [46], [32, 77, 97, 103, 110, 258], [32, 256, 97, 101, 114, 290], [280, 282], [268, 260, 97], [265], [272, 101, 100], [46], [32, 76, 288], [32, 108, 288], [281], [32, 117, 116], [276], [32, 256, 97, 
101, 114, 290], [265], [46], [32, 76, 288], [284, 97, 103, 110, 258], [32, 274, 115, 271], [291], [32, 101, 115, 116], [32, 101, 105, 117, 115], [32, 99, 269, 115, 101, 99, 116, 261, 117, 114], [280, 100, 274, 105, 115, 273], [46, 10, 10], [78, 269], [266, 271, 262], [284, 97, 103, 110, 258], [32, 277, 114, 111], [32, 99, 269, 115, 101, 99, 116, 261, 117, 114], [46], [32, 85, 116], [32, 101, 115, 116], [272, 289], [281], [265, 270], [268, 260, 97], [46], [32, 83, 289], [266, 269], [265, 270], [32, 99, 269, 115, 101, 99, 116, 261, 117, 114], [265], [276], [46], [32, 68, 263, 101], [272, 289], [281], [276], [283, 101, 267, 116], [272, 101, 100], [32, 274, 115, 271], [46], [32, 83, 289], [276], [32, 277, 114, 111], [32, 277, 114, 111], [32, 261, 105, 110, 273, 100, 117, 110, 116], [32, 261, 105, 110, 273, 100, 117, 110, 116], [46], [32, 83, 289], [272, 289], [272, 289], [291], [265, 101], [46], [32, 78, 271, 262], [32, 101, 105, 117, 115], [32, 99, 269, 115, 101, 99, 116, 261, 117, 114], [265, 270], [283, 101, 267, 116], [265], [46], [32, 86, 101, 267, 116], [266, 269], [280, 100, 274, 105, 115, 273], [32, 108, 288], [32, 277, 114, 111], [46], [32, 76, 288], [281], [283, 101, 267, 116], [32, 277, 114, 111], [265], [268, 260, 97], [276], [276], [46], [32, 77, 97, 103, 110, 258], [272, 289], [32, 116, 270, 277, 97], [266, 269], [266, 271, 262], [284, 111, 100, 105], [46, 10, 10], [65, 109, 261], [265], [32, 274, 115, 271], [291], [276], [281], [281], [272, 101, 100], [46], [32, 76, 288], [265], [283, 259, 117, 112, 116, 290, 270], [32, 117, 116], [32, 277, 114, 111], [291], [32, 256, 97, 101, 114, 290], [265, 101], [46], [32, 77, 111, 100, 105], [266, 271, 262], [280, 282], [283, 259, 117, 112, 116, 290, 270], [32, 99, 269, 115, 101, 99, 116, 261, 117, 114], [46], [32, 68, 263], [32, 261, 105, 110, 273, 100, 117, 110, 116], [265, 101], [265, 101], [32, 274, 115, 271], [272, 289], [32, 101, 105, 117, 115], [32, 108, 288], [46], [32, 69, 116, 105, 110, 273, 100, 117, 110, 
116], [276], [32, 101, 115, 116], [265, 270], [291], [32, 277, 114, 111], [32, 101, 115, 116], [265], [46], [32, 81, 117, 105, 260, 97], [32, 277, 114, 111], [276], [266, 271, 262], [291], [265, 101], [46], [32, 81, 117, 97, 101, 114, 290], [265], [265, 101], [272, 101, 100], [268, 260, 97], [32, 116, 270, 277, 97], [268, 260, 97], [32, 274, 115, 271], [46], [32, 65, 282], [283, 259, 117, 112, 116, 290, 270], [280, 100, 274, 105, 115, 273], [276], [280, 282], [291], [32, 117, 116], [32, 256, 97, 101, 114, 290], [46, 10, 10], [85, 116], [32, 256, 97, 101, 114, 290], [284, 111, 100, 105], [268, 260, 97], [291], [46], [32, 78, 269], [291], [32, 277, 114, 111], [32, 261, 105, 110, 273, 100, 117, 110, 116], [46], [32, 78, 269], [32, 101, 105, 117, 115], [283, 101, 267, 116], [283, 101, 267, 116], [272, 101, 100], [284, 97, 103, 110, 258], [46], [32, 78, 279], [276], [272, 289], [280, 282], [265, 101], [281], [46], [32, 78, 269], [281], [265, 270], [32, 274, 115, 271], [268, 260, 97], [265], [32, 274, 115, 271], [46], [32, 84, 270, 277, 97], [32, 108, 288], [284, 97, 103, 110, 258], [265], [32, 117, 116], [32, 108, 288], [46], [32, 80, 257, 114, 111], [32, 99, 269, 115, 101, 99, 116, 261, 117, 114], [32, 274, 115, 271], [32, 101, 105, 117, 115], [281], [32, 99, 269, 115, 101, 99, 116, 261, 117, 114], [32, 101, 115, 116], [281], [46, 10, 10], [67, 269, 115, 101, 99, 116, 261, 117, 114], [280, 282], [32, 261, 105, 110, 273, 100, 117, 110, 116], [284, 97, 103, 110, 258], [268, 260, 97], [46], [32, 69, 105, 117, 115], [265, 101], [283, 259, 117, 112, 116, 290, 270], [32, 274, 115, 271], [32, 256, 97, 101, 114, 290], [32, 116, 270, 277, 97], [32, 274, 115, 271], [46], [32, 85, 116], [265, 101], [266, 269], [266, 271, 262], [281], [46], [32, 69, 116, 105, 110, 273, 100, 117, 110, 116], [266, 269], [280, 100, 274, 105, 115, 273], [276], [284, 97, 103, 110, 258], [266, 271, 262], [281], [46], [32, 68, 263, 270], [32, 256, 97, 101, 114, 290], [32, 108, 288], [280, 282], [280, 
282], [46], [32, 78, 269], [268, 260, 97], [283, 259, 117, 112, 116, 290, 270], [291], [272, 101, 100], [266, 269], [46]]
decoded_text = bpe.decode(encoded_text)
print(decoded_text)
Velit dolor amet sit ut. Aliquam etincidunt dolorem aliquam velit dolorem aliquam. Etincidunt numquam magnam voluptatem porro non. Aliquam ut quisquam dolorem etincidunt labore. Quisquam ipsum est consectetur. Modi magnam tempora neque magnam ipsum. Non sit non labore velit. Est velit amet tempora modi amet modi quiquia. Porro etincidunt ipsum consectetur amet. Sed etincidunt etincidunt neque quiquia sit.

Sit neque labore neque aliquam. Quiquia modi est quisquam velit aliquam. Sed quisquam magnam labore. Velit tempora sed dolore adipisci dolor quiquia sed. Aliquam dolor amet eius ut quisquam labore. Magnam quaerat aliquam quiquia dolor sed. Labore labore neque ut amet quaerat dolor. Labore magnam ipsum quisquam est eius consectetur adipisci.

Non numquam magnam porro consectetur. Ut est sit neque dolorem quiquia. Sit non dolorem consectetur dolor amet. Dolore sit neque amet velit sed ipsum. Sit amet porro porro etincidunt etincidunt. Sit sit sit quisquam dolore. Numquam eius consectetur dolorem velit dolor. Velit non adipisci labore porro. Labore neque velit porro dolor quiquia amet amet. Magnam sit tempora non numquam modi.

Amet dolor ipsum quisquam amet neque neque sed. Labore dolor voluptatem ut porro quisquam quaerat dolore. Modi numquam aliquam voluptatem consectetur. Dolor etincidunt dolore dolore ipsum sit eius labore. Etincidunt amet est dolorem quisquam porro est dolor. Quiquia porro amet numquam quisquam dolore. Quaerat dolor dolore sed quiquia tempora quiquia ipsum. Aliquam voluptatem adipisci amet aliquam quisquam ut quaerat.

Ut quaerat modi quiquia quisquam. Non quisquam porro etincidunt. Non eius velit velit sed magnam. Neque amet sit aliquam dolore neque. Non neque dolorem ipsum quiquia dolor ipsum. Tempora labore magnam dolor ut labore. Porro consectetur ipsum eius neque consectetur est neque.

Consectetur aliquam etincidunt magnam quiquia. Eius dolore voluptatem ipsum quaerat tempora ipsum. Ut dolore non numquam neque. Etincidunt non adipisci amet magnam numquam neque. Dolorem quaerat labore aliquam aliquam. Non quiquia voluptatem quisquam sed non.