Understanding Byte Pair Encoding: Part 3: the Algorithm
I wrote about encodings and the basics of tokenization in my two earlier posts, so in this post I will dig into the actual algorithm of byte-pair encoding (BPE). In the paper Language Models are Unsupervised Multitask Learners, which introduces GPT2, the authors note that they use BPE at the byte level and that some preprocessing improves results by explicitly preventing merges across character categories. It would seem, then, that interpreting text as a sequence of bytes, rather than as a sequence of Unicode code points, is at the heart of the BPE method as GPT2 uses it.
I’ll get into some of these preprocessing details in another post, but, for now, I just want to get an idea of how this works.
The paper that introduced BPE states that “The algorithm compresses data by finding the most frequently occurring pairs of adjacent bytes in the data and replacing all instances of the pair with a byte that was not in the original data. The algorithm repeats this process until no further compression is possible, …”
Character level
Using the example text ababcabcd given in the paper noted above, I will outline the basic process. To do this, however, I need a way to find the most frequently occurring pairs:
def get_stats(chars):
    # Count every pair of adjacent items (characters or byte values)
    stats = {}
    for i in range(len(chars) - 1):
        pair = (chars[i], chars[i + 1])
        stats[pair] = stats.get(pair, 0) + 1
    # Sort the pairs from most to least frequent
    stats = dict(sorted(stats.items(), key=lambda item: item[1], reverse=True))
    return stats
I will work at the character level for now, since it is easier to see how this works. Later, I will switch over to using bytes. To begin, I’ll look at the number of unique tokens and note that no pairs have yet been merged:
text = 'ababcabcd'
tokens = list(sorted(set(text)))
merges = {}
print(f"Current text: {text}")
print(f"Number of characters in current text: {len(text)}")
print(f"Tokens: {tokens}")
print(f"Number of tokens: {len(tokens)}")
print(f"Pairs merged: {merges}")
Current text: ababcabcd
Number of characters in current text: 9
Tokens: ['a', 'b', 'c', 'd']
Number of tokens: 4
Pairs merged: {}
Now, I determine the frequency of occurrence for all the pairs of adjacent characters in the text:
stats = get_stats(text)
stats
{('a', 'b'): 3, ('b', 'c'): 2, ('b', 'a'): 1, ('c', 'a'): 1, ('c', 'd'): 1}
Since a and b occur together most frequently, I will merge them. To do this, I create a new character (one that doesn’t already exist in my set of unique tokens), replace all occurrences of ab with the new character, and keep track of the merge in the merges dictionary:
text_1 = text.replace('ab', 'X')
tokens = list(sorted(set(text + text_1)))
merges[('a', 'b')] = 'X'
print(f"Current text: {text_1}")
print(f"Number of characters in current text: {len(text_1)}")
print(f"Tokens: {tokens}")
print(f"Number of tokens: {len(tokens)}")
print(f"Pairs merged: {merges}")
Current text: XXcXcd
Number of characters in current text: 6
Tokens: ['X', 'a', 'b', 'c', 'd']
Number of tokens: 5
Pairs merged: {('a', 'b'): 'X'}
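One subtlety is worth checking here: get_stats counts overlapping pairs, while str.replace substitutes non-overlapping occurrences from left to right. For ababcabcd the two agree, but for a run of repeated characters they do not. The string aaab below is my own hypothetical example, and get_stats is restated so the snippet runs on its own:

```python
def get_stats(chars):
    # Same pair counter as above: counts overlapping adjacent pairs
    stats = {}
    for i in range(len(chars) - 1):
        pair = (chars[i], chars[i + 1])
        stats[pair] = stats.get(pair, 0) + 1
    return dict(sorted(stats.items(), key=lambda item: item[1], reverse=True))


sample = 'aaab'  # hypothetical text containing a run of identical characters
print(get_stats(sample))          # {('a', 'a'): 2, ('a', 'b'): 1}
print(sample.replace('aa', 'X'))  # Xab: only one 'aa' is actually replaced
```

Two occurrences of a pair can only overlap when both halves of the pair are the same, so for pairs like ('a', 'b') the count equals the number of replacements, while for pairs like ('a', 'a') it is only an upper bound.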
I now repeat the process:
stats = get_stats(text_1)
stats
{('X', 'c'): 2, ('X', 'X'): 1, ('c', 'X'): 1, ('c', 'd'): 1}
This time, X and c are the most frequent pair.
text_2 = text_1.replace('Xc', 'Y')
tokens = list(sorted(set(text + text_1 + text_2)))
merges[('X', 'c')] = 'Y'
print(f"Current text: {text_2}")
print(f"Number of characters in current text: {len(text_2)}")
print(f"Tokens: {tokens}")
print(f"Number of tokens: {len(tokens)}")
print(f"Pairs merged: {merges}")
Current text: XYYd
Number of characters in current text: 4
Tokens: ['X', 'Y', 'a', 'b', 'c', 'd']
Number of tokens: 6
Pairs merged: {('a', 'b'): 'X', ('X', 'c'): 'Y'}
And again, I check the stats:
stats = get_stats(text_2)
stats
{('X', 'Y'): 1, ('Y', 'Y'): 1, ('Y', 'd'): 1}
No further compression can happen now through merging pairs, because every remaining pair occurs only once. If I merge, say, X and Y, I also have to add a new token, say R = XY, to my list of unique tokens. The number of characters in the text goes down by 1, but the number of tokens goes up by 1, so there is no benefit to merging a pair that has a frequency of 1. So, I stop the process.
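The character-level steps, including this stopping rule, can be folded into a single loop. This is a minimal sketch rather than a reference implementation: bpe_characters and its new_symbols parameter are my own hypothetical names, and it assumes the replacement symbols never appear in the input text (it raises an error if one does).

```python
def get_stats(chars):
    # Count adjacent pairs, most frequent first (same logic as get_stats above)
    stats = {}
    for i in range(len(chars) - 1):
        pair = (chars[i], chars[i + 1])
        stats[pair] = stats.get(pair, 0) + 1
    return dict(sorted(stats.items(), key=lambda item: item[1], reverse=True))


def bpe_characters(text, new_symbols):
    """Merge the most frequent pair until every remaining pair occurs only once
    (or new_symbols runs out). new_symbols supplies the replacement characters."""
    merges = {}
    for symbol in new_symbols:
        stats = get_stats(text)
        if not stats:
            break  # fewer than two characters left
        pair, count = next(iter(stats.items()))  # most frequent pair
        if count < 2:
            break  # merging a pair seen only once saves nothing
        if symbol in text:
            raise ValueError(f"{symbol!r} already appears in the text")
        text = text.replace(pair[0] + pair[1], symbol)
        merges[pair] = symbol
    return text, merges


compressed, merges = bpe_characters('ababcabcd', 'XYZW')
print(compressed)  # XYYd
print(merges)      # {('a', 'b'): 'X', ('X', 'c'): 'Y'}
```

The loop reaches the same XYYd and the same two merges as the manual steps, and stops on its own once the highest count drops to 1.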
Byte level
For clarity, at least to start, I kept this process at the character level. To get the method that was used for GPT2, I will repeat a slightly modified version of what I did above, but at the byte level. I discussed UTF-8 encoding in Part 1, but to make it easier to remember that we are working at the byte level, I have added some emojis to the original text:
text = '😄😄 ababcabcd 😄😄'
byte_text = text.encode('utf-8')
print(f"Text: {text}")
print(f"Number of characters in original text: {len(text)}")
print()
print(f"Raw bytes version of text: {byte_text}")
print(f"Number of bytes in raw bytes version of text: {len(byte_text)}")
Text: 😄😄 ababcabcd 😄😄
Number of characters in original text: 15
Raw bytes version of text: b'\xf0\x9f\x98\x84\xf0\x9f\x98\x84 ababcabcd \xf0\x9f\x98\x84\xf0\x9f\x98\x84'
Number of bytes in raw bytes version of text: 27
The b in front of the string lets me know that this is a Python bytes object:
type(byte_text)
bytes
The UTF-8 encoding scheme represents any character as a sequence of 1, 2, 3, or 4 bytes. In the text here, each emoji requires 4 bytes, but each lowercase English letter (and each space) needs only 1 byte. Since each byte has a value between 0 and 255, it will be easier to work with these integer values instead of the raw bytes:
text_values = list(map(int, byte_text))
print(text_values)
[240, 159, 152, 132, 240, 159, 152, 132, 32, 97, 98, 97, 98, 99, 97, 98, 99, 100, 32, 240, 159, 152, 132, 240, 159, 152, 132]
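To see the full 1-to-4-byte range, I can encode a few sample characters and count their bytes. Apart from a and 😄, the characters here are my own picks for illustration:

```python
# Sample characters: an ASCII letter, an accented Latin letter,
# the euro sign, and the emoji used above
samples = ['a', 'é', '€', '😄']
byte_lengths = {}
for ch in samples:
    encoded = ch.encode('utf-8')
    byte_lengths[ch] = len(encoded)
    # character, code point, byte values, number of bytes
    print(ch, ord(ch), list(encoded), len(encoded))

# a 97 [97] 1
# é 233 [195, 169] 2
# € 8364 [226, 130, 172] 3
# 😄 128516 [240, 159, 152, 132] 4
```

Note that é has a code point below 256 yet still takes 2 bytes; only code points below 128 are encoded as a single byte with the same value.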
It is important that I remember that the values here are not Unicode code points; they are byte values. Characters with code points below 128 (the ASCII range) have the same value in both cases simply because UTF-8 encodes each of them as a single byte equal to its code point. From 128 upward this no longer holds: é, for example, has code point 233 but is encoded as the two bytes 195 and 169. To remember this, I look at the emoji: it takes 4 bytes, and its Unicode code point, 128516, is not equal to the ‘240, 159, 152, 132’ values seen in text_values:
ord('a'), ord('😄')
(97, 128516)
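To see where 240, 159, 152, 132 come from, the 4-byte UTF-8 rule can be applied by hand. The code point's 21 bits are split into groups of 3, 6, 6, and 6 bits, which are placed after the bit prefixes 11110, 10, 10, and 10. The helper name encode_4_byte_utf8 is my own, and this sketch only handles code points from 0x10000 to 0x10FFFF, the range that needs 4 bytes:

```python
def encode_4_byte_utf8(code_point):
    # Only valid for code points that use the 4-byte form of UTF-8
    if not 0x10000 <= code_point <= 0x10FFFF:
        raise ValueError("code point does not use the 4-byte form")
    return [
        0b11110000 | (code_point >> 18),           # 11110xxx: top 3 bits
        0b10000000 | ((code_point >> 12) & 0x3F),  # 10xxxxxx: next 6 bits
        0b10000000 | ((code_point >> 6) & 0x3F),   # 10xxxxxx: next 6 bits
        0b10000000 | (code_point & 0x3F),          # 10xxxxxx: low 6 bits
    ]


print(encode_4_byte_utf8(ord('😄')))  # [240, 159, 152, 132]
print(list('😄'.encode('utf-8')))     # [240, 159, 152, 132]
```

For 128516 the four bit groups are 0, 31, 24, and 4, which the prefixes turn into 240, 159, 152, and 132.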
I will now repeat the process I did above:
tokens = list(sorted(set(text_values)))
merges = {}
print(f"Current text as byte values: {text_values}")
print(f"Tokens: {tokens}")
print(f"Pairs merged: {merges}")
Current text as byte values: [240, 159, 152, 132, 240, 159, 152, 132, 32, 97, 98, 97, 98, 99, 97, 98, 99, 100, 32, 240, 159, 152, 132, 240, 159, 152, 132]
Tokens: [32, 97, 98, 99, 100, 132, 152, 159, 240]
Pairs merged: {}
And get the pair frequencies:
stats = get_stats(text_values)
stats
{(240, 159): 4,
(159, 152): 4,
(152, 132): 4,
(97, 98): 3,
(132, 240): 2,
(98, 99): 2,
(132, 32): 1,
(32, 97): 1,
(98, 97): 1,
(99, 97): 1,
(99, 100): 1,
(100, 32): 1,
(32, 240): 1}
I have 3 pairs tied for the highest frequency, so I will pick the first one that occurs, (240, 159), and merge it. Since I am dealing with byte values, instead of creating a new character not in my current set of tokens, I will create a new value for this merged pair. Byte values only run from 0 to 255, so the first available value is 256.
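Picking the first of the tied pairs falls out of get_stats itself: Python dictionaries keep insertion order and sorted is stable, so pairs with equal counts stay in the order they first appear in the text. As a cross-check, here is the same count using collections.Counter, which I am swapping in for illustration only; its most_common method orders equal counts by first encounter:

```python
from collections import Counter

# Rebuild the byte values of the emoji example so this runs on its own
text_values = list('😄😄 ababcabcd 😄😄'.encode('utf-8'))

# Count adjacent pairs; Counter remembers the order in which pairs first appear
pair_counts = Counter(zip(text_values, text_values[1:]))

# Equal counts are listed in first-encountered order,
# so the three tied pairs come out in the order they occur in the text
print(pair_counts.most_common(3))
# [((240, 159), 4), ((159, 152), 4), ((152, 132), 4)]

# Byte values occupy 0 to 255, so the first ID free for a merged pair is 256
next_id = 256
```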
To make this easier, I’ll introduce a function to do the merging:
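def replace_pairs(text, pair, idx):
    # Scan left to right, replacing each non-overlapping occurrence of pair with idx
    new_text = []
    i = 0
    while i < len(text):
        # Check the bounds first so text[i + 1] is never out of range
        if i < len(text) - 1 and text[i] == pair[0] and text[i + 1] == pair[1]:
            new_text.append(idx)
            i += 2
        else:
            new_text.append(text[i])
            i += 1
    return new_text
And now I do the merge: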
text_values = replace_pairs(text_values, (240, 159), 256)
tokens.append(256)
merges[(240, 159)] = 256
print(f"Current text as byte values: {text_values}")
print(f"Tokens: {tokens}")
print(f"Pairs merged: {merges}")
Current text as byte values: [256, 152, 132, 256, 152, 132, 32, 97, 98, 97, 98, 99, 97, 98, 99, 100, 32, 256, 152, 132, 256, 152, 132]
Tokens: [32, 97, 98, 99, 100, 132, 152, 159, 240, 256]
Pairs merged: {(240, 159): 256}
Check stats:
stats = get_stats(text_values)
stats
{(256, 152): 4,
(152, 132): 4,
(97, 98): 3,
(132, 256): 2,
(98, 99): 2,
(132, 32): 1,
(32, 97): 1,
(98, 97): 1,
(99, 97): 1,
(99, 100): 1,
(100, 32): 1,
(32, 256): 1}
Merge:
text_values = replace_pairs(text_values, (256, 152), 257)
tokens.append(257)
merges[(256, 152)] = 257
print(f"Current text as byte values: {text_values}")
print(f"Tokens: {tokens}")
print(f"Pairs merged: {merges}")
Current text as byte values: [257, 132, 257, 132, 32, 97, 98, 97, 98, 99, 97, 98, 99, 100, 32, 257, 132, 257, 132]
Tokens: [32, 97, 98, 99, 100, 132, 152, 159, 240, 256, 257]
Pairs merged: {(240, 159): 256, (256, 152): 257}
Check stats:
stats = get_stats(text_values)
stats
{(257, 132): 4,
(97, 98): 3,
(132, 257): 2,
(98, 99): 2,
(132, 32): 1,
(32, 97): 1,
(98, 97): 1,
(99, 97): 1,
(99, 100): 1,
(100, 32): 1,
(32, 257): 1}
Merge:
text_values = replace_pairs(text_values, (257, 132), 258)
tokens.append(258)
merges[(257, 132)] = 258
print(f"Current text as byte values: {text_values}")
print(f"Tokens: {tokens}")
print(f"Pairs merged: {merges}")
Current text as byte values: [258, 258, 32, 97, 98, 97, 98, 99, 97, 98, 99, 100, 32, 258, 258]
Tokens: [32, 97, 98, 99, 100, 132, 152, 159, 240, 256, 257, 258]
Pairs merged: {(240, 159): 256, (256, 152): 257, (257, 132): 258}
Check stats:
stats = get_stats(text_values)
stats
{(97, 98): 3,
(258, 258): 2,
(98, 99): 2,
(258, 32): 1,
(32, 97): 1,
(98, 97): 1,
(99, 97): 1,
(99, 100): 1,
(100, 32): 1,
(32, 258): 1}
Merge:
text_values = replace_pairs(text_values, (97, 98), 259)
tokens.append(259)
merges[(97, 98)] = 259
print(f"Current text as byte values: {text_values}")
print(f"Tokens: {tokens}")
print(f"Pairs merged: {merges}")
Current text as byte values: [258, 258, 32, 259, 259, 99, 259, 99, 100, 32, 258, 258]
Tokens: [32, 97, 98, 99, 100, 132, 152, 159, 240, 256, 257, 258, 259]
Pairs merged: {(240, 159): 256, (256, 152): 257, (257, 132): 258, (97, 98): 259}
I will stop here, even though there are a couple more merges I could do (the pairs (258, 258) and (259, 99) each still occur twice), as I think the process is pretty clear now. All of this could, of course, be cleaned up code-wise, but I wanted to go step by step with a semi-manual process so that the algorithm would sink in.
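As one possible shape for that cleanup, the semi-manual byte-level steps can be wrapped in a loop. This is a sketch under my own assumptions: train_bpe and num_merges are hypothetical names, the two helpers are restated so the block runs standalone, and ties go to the first pair seen, matching the choices made above.

```python
def get_stats(ids):
    # Count adjacent pairs, most frequent first (ties keep first-seen order)
    stats = {}
    for i in range(len(ids) - 1):
        pair = (ids[i], ids[i + 1])
        stats[pair] = stats.get(pair, 0) + 1
    return dict(sorted(stats.items(), key=lambda item: item[1], reverse=True))


def replace_pairs(ids, pair, idx):
    # Replace non-overlapping occurrences of pair with idx, left to right
    new_ids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids


def train_bpe(text, num_merges):
    ids = list(text.encode('utf-8'))  # start from raw byte values
    merges = {}
    for _ in range(num_merges):
        stats = get_stats(ids)
        if not stats:
            break
        pair, count = next(iter(stats.items()))
        if count < 2:
            break  # no compression left to gain
        new_id = 256 + len(merges)  # next ID after the 256 byte values
        ids = replace_pairs(ids, pair, new_id)
        merges[pair] = new_id
    return ids, merges


ids, merges = train_bpe('😄😄 ababcabcd 😄😄', num_merges=4)
print(ids)     # [258, 258, 32, 259, 259, 99, 259, 99, 100, 32, 258, 258]
print(merges)  # {(240, 159): 256, (256, 152): 257, (257, 132): 258, (97, 98): 259}
```

With num_merges=4 it reproduces the four manual merges; given a larger budget it also performs the two remaining merges and then stops, because every pair left occurs only once.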
What comes out of this BPE process is a set of tokens, which would be the vocabulary, and an ordered list of merges. Given enough starting text to learn useful merges, and a base vocabulary that includes all 256 possible byte values, the vocabulary and merges are enough to tokenize any text for input into a language model.
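To make that concrete, here is a minimal sketch of encoding new text with an ordered merges dictionary and decoding it back. It assumes the four merges learned above (hard-coded here) and applies them in the order they were learned, which Python dictionaries preserve. encode and decode are my own hypothetical names; this is not GPT2's actual tokenizer, which also applies the preprocessing mentioned at the start.

```python
def replace_pairs(ids, pair, idx):
    # Replace non-overlapping occurrences of pair with idx, left to right
    new_ids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids


# Merges learned above, in the order they were created
merges = {(240, 159): 256, (256, 152): 257, (257, 132): 258, (97, 98): 259}


def encode(text, merges):
    ids = list(text.encode('utf-8'))
    for pair, idx in merges.items():  # apply merges in learned order
        ids = replace_pairs(ids, pair, idx)
    return ids


def decode(ids, merges):
    # Every single byte value is its own token; merged IDs expand to byte strings
    vocab = {i: bytes([i]) for i in range(256)}
    for (a, b), idx in merges.items():
        vocab[idx] = vocab[a] + vocab[b]
    return b''.join(vocab[i] for i in ids).decode('utf-8', errors='replace')


ids = encode('abc 😄', merges)
print(ids)                  # [259, 99, 32, 258]
print(decode(ids, merges))  # abc 😄
```

Because every byte value from 0 to 255 has an entry in vocab, any UTF-8 text can be encoded this way, even text containing bytes that never appeared during training; those bytes simply stay as single-byte tokens.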
That’s it for now. In the next post I will go into some of the more nuanced modifications of BPE as done for GPT2.