import sys
import regex as re
from tqdm import tqdm
from .base import Tokenizer, get_stats, merge, merge_hindi
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
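# For intuition (illustrative example, not from the original file): the pattern
# chunks text into word-, number-, and punctuation-like pieces, e.g.
#   re.findall(GPT4_SPLIT_PATTERN, "Hello world123!!")
#   -> ["Hello", " world", "123", "!!"]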
class BPETokenizer(Tokenizer):

    def __init__(self, pattern=None, word_pattern=None):
        super().__init__()
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        self.word_pattern = None
        self.compiled_pattern_word = None
        if word_pattern:
            self.word_pattern = word_pattern
            self.compiled_pattern_word = re.compile(self.word_pattern)
        self.special_tokens = {}
        self.inverse_special_tokens = {}
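    # Illustrative constructor call; the word_pattern is caller-supplied and the
    # Devanagari character-class pattern below is a hypothetical choice, not
    # part of this module:
    #   tok = BPETokenizer(word_pattern=r"[\u0900-\u097F]+")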
    def build(self, text, vocab_size, verbose=False):
        text_chunks = re.findall(self.compiled_pattern, text)
        if self.compiled_pattern_word:
            print("Splitting Hindi words")
            text_chunks_words = []
            for chunk in tqdm(text_chunks):
                element_chunks = re.findall(self.compiled_pattern_word, chunk)
                if element_chunks == []:
                    text_chunks_words.append(chunk)
                else:
                    # extend with the individual characters of the first match
                    text_chunks_words.extend(element_chunks[0])
            text_chunks = text_chunks_words
        # input text preprocessing
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]
        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}
        # reserve ids 256..383 for the 128 code points of the Devanagari block
        # (U+0900..U+097F, i.e. 2304..2431)
        vocab.update({idx: chr(value).encode("utf-8") for idx, value in zip(range(256, 384), range(2304, 2432))})
        print("Merging Hindi characters into single tokens")
        for index in range(256, 384):
            pair = list(vocab[index])
            ids = [merge_hindi(chunk_ids, pair, index) for chunk_ids in ids]
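        # e.g. "क" (U+0915) encodes to the UTF-8 bytes [224, 164, 149]; the pass
        # above collapses that triple into the single token id 277
        # (256 + 0x0915 - 0x0900) before ordinary BPE merging starts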
        num_merges = vocab_size - 384
        original_length = len([x for xs in ids for x in xs])
        print("Building BPE")
        for i in tqdm(range(num_merges), file=sys.stdout):
            # count the number of times every consecutive pair appears
            stats = {}
            for chunk_ids in ids:
                # passing in stats will update it in place, adding up counts
                get_stats(chunk_ids, stats)
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 384 + i
            # replace all occurrences of pair in ids with idx
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                try:
                    tqdm.write(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx].decode('utf-8')}) had {stats[pair]} occurrences")
                except UnicodeDecodeError:
                    tqdm.write(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
        length_after_merging = len([x for xs in ids for x in xs])
        print(f"Compression ratio: {original_length / length_after_merging}")
        # save class variables
        self.merges = merges  # used in encode()
        self.vocab = vocab    # used in decode()
    def register_special_tokens(self, special_tokens):
        # special_tokens is a dictionary of str -> int
        # example: {"<|endoftext|>": 100257}
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
    def decode(self, ids):
        # given ids (list of integers), return Python string
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")
        text_bytes = b"".join(part_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text
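    # Round-trip sketch (assumes a built tokenizer and no word_pattern, since
    # the word-splitting path keeps only the first match per chunk):
    #   tok.decode(tok.encode_ordinary("नमस्ते")) == "नमस्ते"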
    def _encode_chunk(self, ids):
        # given the byte/token ids of one chunk, apply the learned merges
        # in the order they were learned and return the resulting token ids
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
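    # e.g. with self.merges == {(101, 102): 384} (hypothetical), the call
    # self._encode_chunk([101, 102, 103]) merges (101, 102) into 384, finds no
    # further merges, and returns [384, 103]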
    def encode_ordinary(self, text):
        """Encoding that ignores any special tokens."""
        # split text into chunks of text by categories defined in regex pattern
        text_chunks = re.findall(self.compiled_pattern, text)
        if self.compiled_pattern_word:
            print("Splitting Hindi words")
            text_chunks_words = []
            for chunk in tqdm(text_chunks):
                element_chunks = re.findall(self.compiled_pattern_word, chunk)
                if element_chunks == []:
                    text_chunks_words.append(chunk)
                else:
                    # extend with the individual characters of the first match
                    text_chunks_words.extend(element_chunks[0])
            text_chunks = text_chunks_words
        # build the base vocab once, not once per chunk
        vocab = {idx: bytes([idx]) for idx in range(256)}
        vocab.update({idx: chr(value).encode("utf-8") for idx, value in zip(range(256, 384), range(2304, 2432))})
        # all chunks of text are encoded separately, then results are joined
        ids_list = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8")  # raw bytes
            ids = list(chunk_bytes)
            # pre-merge each Devanagari character's multi-byte sequence
            for index in range(256, 384):
                pair = list(vocab[index])
                ids = merge_hindi(ids, pair, index)
            chunk_ids = self._encode_chunk(ids)
            ids_list.extend(chunk_ids)
        return ids_list
    def encode(self, text, allowed_special="none_raise"):
        """
        Unlike encode_ordinary, this function handles special tokens.
        allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens;
        if none_raise, then an error is raised if any special token is encountered in text.
        This is the default tiktoken behavior right now as well;
        any other behavior is either annoying, or a major footgun.
        """
        # decode the user desire w.r.t. handling of special tokens
        special = None
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens)
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:
            # shortcut: if no special tokens, just use the ordinary encoding
            return self.encode_ordinary(text)
        # otherwise, we have to be careful with potential special tokens in text
        # we handle special tokens by splitting the text
        # based on the occurrence of any exact match with any of the special tokens
        # we can use re.split for this. note that surrounding the pattern with ()
        # makes it into a capturing group, so the special tokens will be included
        special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
        special_chunks = re.split(special_pattern, text)
        # now all the special characters are separated from the rest of the text
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for part in special_chunks:
            if part in special:
                # this is a special token, encode it separately as a special case
                ids.append(special[part])
            else:
                # this is an ordinary sequence, encode it normally
                ids.extend(self.encode_ordinary(part))
        return ids
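

# Minimal usage sketch, not part of the original module: the corpus path and
# vocab size are assumptions for illustration. Because of the relative import
# above, this only runs as part of the package (python -m <package>.<module>).
if __name__ == "__main__":
    with open("hindi_corpus.txt", encoding="utf-8") as f:  # hypothetical file
        text = f.read()
    tok = BPETokenizer()
    tok.build(text, vocab_size=1000, verbose=False)  # ids 384..999 are merges
    tok.register_special_tokens({"<|endoftext|>": 1000})  # first id past the vocab
    ids = tok.encode("नमस्ते<|endoftext|>", allowed_special="all")
    print(ids)
    print(tok.decode(ids))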