# Scraped from a Hugging Face Space page (Space status: Sleeping).
# File size: 10,543 bytes; commit 6582b49.
import pandas as pd
import pyarrow.parquet as pq
import regex as re
from collections import Counter, defaultdict
import json
from tqdm import tqdm
# Desired final vocabulary size. HindiTokenizer also reuses this value as the
# token id of its '<|endoftext|>' special token.
VOCAB_SIZE = 5_000
def get_stats(ids, counts=None):
    """Count how often each adjacent pair of ids occurs in ``ids``.

    Args:
        ids: sequence of integer token ids.
        counts: optional dict to accumulate into; it is updated in place and
            returned, so counts can be summed across several chunks.

    Returns:
        dict mapping ``(left_id, right_id)`` to its occurrence count.
    """
    if counts is None:
        counts = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        counts[key] = counts.get(key, 0) + 1
    return counts
def merge(ids, pair, idx):
    """Replace every non-overlapping occurrence of ``pair`` in ``ids`` by ``idx``.

    The scan goes left to right; once a pair is merged, both of its elements
    are consumed, so ``[1, 1, 1]`` with pair ``(1, 1)`` becomes ``[idx, 1]``.

    Args:
        ids: list of integer token ids.
        pair: the two ids ``(left, right)`` to merge.
        idx: the new id that replaces each occurrence of ``pair``.

    Returns:
        A new list; ``ids`` is not modified.

    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    merged = []
    total = len(ids)
    pos = 0
    while pos < total:
        has_next = pos + 1 < total
        if has_next and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged
class HindiTokenizer:
    """Byte-level BPE tokenizer with a Devanagari-aware pre-tokenization regex.

    Text is split into chunks with ``self.pattern`` (compiled with the module
    imported as ``re``). Each chunk is encoded to UTF-8 bytes, and learned
    merges are applied within each chunk only.

    Attributes:
        pattern: regex used to split text into chunks before BPE.
        merges: dict ``(left_id, right_id) -> merged_id``; insertion order is
            the merge order, and ids start at 256.
        vocab: dict ``token_id -> bytes``.
        special_tokens: dict ``special_string -> token_id``.
        inverse_special_tokens: dict ``token_id -> special_string``; decode()
            falls back to it for ids that are not in ``vocab``.
    """

    def __init__(self):
        # Pre-tokenization pattern. Its alternatives, in order:
        # - English contractions;
        # - optionally space-prefixed numbers;
        # - runs of clusters made of one base character from the Devanagari-
        #   related ranges followed by any combining marks from those ranges;
        # - other letters;
        # - punctuation/symbol runs;
        # - whitespace.
        self.pattern = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{N}+| ?(?:[\u0904-\u0939\u093d-\u093d\u0950-\u0950\u0958-\u0961\u0970-\u097f\ua8f2-\ua8fe\U00011b00-\U00011b09\u1cd3-\u1cd3\u1ce9-\u1cec\u1cee-\u1cf3\u1cf5-\u1cf6\u1cfa-\u1cfa][\u0900-\u0903\u093a-\u093c\u093e-\u094f\u0951-\u0957\u0962-\u0963\ua8e0-\ua8f1\ua8ff-\ua8ff\u1cd0-\u1cd2\u1cd4-\u1ce8\u1ced-\u1ced\u1cf4-\u1cf4\u1cf7-\u1cf9]*)+| ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        self.merges = {}
        # Ids 0..255 are the raw single bytes.
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        # The end-of-text token reuses the VOCAB_SIZE constant as its id.
        self.special_tokens = {
            '<|endoftext|>': VOCAB_SIZE
        }
        # Must exist from the start: decode() consults it for any id that is
        # missing from vocab (e.g. the special token after learn_bpe_vocab).
        self.inverse_special_tokens = {
            idx: special for special, idx in self.special_tokens.items()
        }

    def _build_vocab(self):
        """Rebuild ``id -> bytes`` from the base bytes, merges and special tokens.

        The vocab is derived deterministically from ``self.merges``, which is
        iterated in insertion order so each merged id's parts already exist.
        """
        vocab = {idx: bytes([idx]) for idx in range(256)}  # the 256 single bytes
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def tokenize_hindi(self, text):
        """Split ``text`` into pre-tokenization chunks with ``self.pattern``.

        The pattern has no capturing groups, so ``findall`` returns the full
        matched substrings in order.
        """
        return re.compile(self.pattern).findall(text)

    def learn_bpe_vocab(self, text, num_merges=50):
        """Learn up to ``num_merges`` BPE merges from ``text``.

        Merge number ``i`` gets id ``256 + i`` and is recorded in
        ``self.merges`` and ``self.vocab``. Training stops early when no
        adjacent pair is left to merge (e.g. very short text). The input and
        output lengths and the compression ratio are printed at the end.
        """
        tokens = [list(chunk.encode("utf-8")) for chunk in self.tokenize_hindi(text)]
        # Total byte count over all chunks, for the compression ratio.
        input_len = sum(len(chunk_ids) for chunk_ids in tokens)
        for i in tqdm(range(num_merges), desc="Merging pairs", unit="merge"):
            stats = {}
            for chunk_ids in tokens:
                stats = get_stats(chunk_ids, stats)
            if not stats:
                # max() raises on an empty dict; nothing is left to merge.
                break
            pair = max(stats, key=stats.get)
            idx = 256 + i
            tokens = [merge(chunk_ids, pair, idx) for chunk_ids in tokens]
            self.merges[pair] = idx
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
        output_len = sum(len(chunk_ids) for chunk_ids in tokens)
        if output_len:
            print(f"input_len: {input_len}, output_len: {output_len} compression ratio: {input_len / output_len:.2f}X")
        else:
            # Empty input: the ratio is undefined (avoid ZeroDivisionError).
            print(f"input_len: {input_len}, output_len: {output_len} compression ratio: n/a")

    def save_bpe_vocab(self, model_file):
        """Write the model in the "minbpe v1" text format.

        The file contains, one item per line:
        - the version string;
        - the pattern;
        - the number of special tokens, then one ``<token> <id>`` line each;
        - one ``<id1> <id2>`` line per merge, in merge order.

        Merged ids themselves are not stored; load_bpe_vocab reassigns them
        sequentially from 256.
        """
        # Explicit utf-8 so non-ASCII special tokens round-trip with
        # load_bpe_vocab, which reads utf-8, regardless of platform locale.
        with open(model_file, 'w', encoding="utf-8") as f:
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")

    def load_bpe_vocab(self, filepath):
        """Load a model written by :meth:`save_bpe_vocab` and rebuild the vocab.

        Merge lines are assigned ids 256, 257, ... in file order. Blank lines
        in the merges section are ignored. Instance state is only replaced
        after the whole file has been parsed.
        """
        assert filepath.endswith(".model")
        merges = {}
        special_tokens = {}
        idx = 256
        with open(filepath, 'r', encoding="utf-8") as f:
            version = f.readline().strip()
            assert version == "minbpe v1"
            pattern = f.readline().strip()
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                # Split off only the trailing id, so a special token that
                # itself contains spaces still parses.
                special, special_idx = f.readline().strip().rsplit(maxsplit=1)
                special_tokens[special] = int(special_idx)
            for line in f:
                if not line.strip():
                    continue  # tolerate blank/trailing lines
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.pattern = pattern
        self.merges = merges
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
        self.vocab = self._build_vocab()

    def register_special_tokens(self, special_tokens):
        """Replace the special tokens with ``special_tokens``.

        Args:
            special_tokens: dict of str -> int,
                e.g. ``{"<|endoftext|>": 100257}``.

        ``self.vocab`` is not updated; decode() reaches these ids through
        ``inverse_special_tokens``.
        """
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    def decode(self, ids):
        """Convert a list of token ids back into a Python string.

        Ids are looked up in ``vocab`` first, then in
        ``inverse_special_tokens``. Invalid UTF-8 sequences are replaced with
        U+FFFD.

        Raises:
            ValueError: if an id is in neither map.
        """
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")
        return b"".join(part_bytes).decode("utf-8", errors="replace")

    def _encode_chunk(self, text_bytes):
        """Apply learned merges to one chunk of bytes and return its token ids.

        Each step applies the mergeable pair with the lowest merge id, i.e.
        the merges are replayed in the order they were learned.
        """
        ids = list(text_bytes)  # bytes -> ints in 0..255
        while len(ids) >= 2:
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # If no pair is mergeable, every key is inf and min() just returns
            # an arbitrary pair; the membership check detects that and stops.
            if pair not in self.merges:
                break
            ids = merge(ids, pair, self.merges[pair])
        return ids

    def encode_ordinary(self, text):
        """Encode ``text`` ignoring special tokens.

        The text is split into chunks by ``self.pattern``. Each chunk is
        encoded separately and the results are concatenated.
        """
        ids = []
        for chunk in self.tokenize_hindi(text):
            ids.extend(self._encode_chunk(chunk.encode("utf-8")))
        return ids

    def encode(self, text, allowed_special="none_raise"):
        """Encode ``text`` to token ids, optionally honoring special tokens.

        Unlike :meth:`encode_ordinary`, this can map exact occurrences of
        special-token strings to their reserved ids.

        Args:
            text: the string to encode.
            allowed_special: one of the following.
                - ``"all"``: every registered special token is recognized.
                - ``"none"``: special tokens are encoded as ordinary text.
                - ``"none_raise"`` (default): fails via ``assert`` if any
                  registered special token appears in ``text``, which is
                  also tiktoken's default. The check is skipped under
                  ``python -O``.
                - a ``set`` of special-token strings to recognize.

        Returns:
            list of int token ids.

        Raises:
            ValueError: if ``allowed_special`` is not one of the above.
        """
        special = None
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens)
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:
            # No special tokens to honor: plain encoding.
            return self.encode_ordinary(text)
        # Split on exact special-token matches. The capturing group keeps the
        # matched tokens in the split output. Longer tokens come first, so a
        # token that is a prefix of another cannot shadow it.
        alternatives = sorted(special, key=len, reverse=True)
        special_pattern = "(" + "|".join(re.escape(k) for k in alternatives) + ")"
        ids = []
        for part in re.split(special_pattern, text):
            if part in special:
                ids.append(special[part])
            else:
                ids.extend(self.encode_ordinary(part))
        return ids
# Module-level tokenizer, loaded at import time from the trained model file in
# the current working directory.
tokenizer = HindiTokenizer()
tokenizer.load_bpe_vocab("hindi_bpe_vocab.model")