import json
from collections import Counter, defaultdict

import pandas as pd
import pyarrow.parquet as pq
import regex as re
from tqdm import tqdm

# The desired final vocabulary size. Also used below as the id of the
# <|endoftext|> special token, so learned merge ids must stay below it.
VOCAB_SIZE = 5000


def get_stats(ids, counts=None):
    """Count occurrences of every adjacent pair in ``ids``.

    Args:
        ids: list of integer token ids.
        counts: optional dict to accumulate into. It is updated in place and
            returned, which lets callers aggregate counts across many chunks.

    Returns:
        Dict mapping ``(id_a, id_b)`` pairs to their occurrence counts.

    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge(ids, pair, idx):
    """Replace every consecutive occurrence of ``pair`` in ``ids`` with ``idx``.

    Args:
        ids: list of integer token ids.
        pair: the ``(id_a, id_b)`` pair to merge.
        idx: the new token id that replaces each occurrence of the pair.

    Returns:
        A new list; ``ids`` is not modified.

    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


class HindiTokenizer:
    """Byte-level BPE tokenizer with a Devanagari-aware pre-tokenization regex.

    Text is first split into chunks with ``self.pattern``. Each chunk is
    encoded to UTF-8 bytes, and BPE merges are applied within a chunk, never
    across chunk boundaries. Models are saved and loaded in a "minbpe v1"
    text format.
    """

    def __init__(self):
        # Pre-tokenization pattern in `regex`-module syntax. It relies on
        # \p{...} classes, which the stdlib `re` module does not support.
        self.pattern = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{N}+| ?(?:[\u0904-\u0939\u093d-\u093d\u0950-\u0950\u0958-\u0961\u0970-\u097f\ua8f2-\ua8fe\U00011b00-\U00011b09\u1cd3-\u1cd3\u1ce9-\u1cec\u1cee-\u1cf3\u1cf5-\u1cf6\u1cfa-\u1cfa][\u0900-\u0903\u093a-\u093c\u093e-\u094f\u0951-\u0957\u0962-\u0963\ua8e0-\ua8f1\ua8ff-\ua8ff\u1cd0-\u1cd2\u1cd4-\u1ce8\u1ced-\u1ced\u1cf4-\u1cf4\u1cf7-\u1cf9]*)+| ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        # (p0, p1) -> merged token id
        self.merges = {}
        self.special_tokens = {
            '<|endoftext|>': VOCAB_SIZE
        }
        # Reverse lookup used by decode(); must exist before decode() is ever
        # called, otherwise an unknown id raises AttributeError.
        self.inverse_special_tokens = {v: k for k, v in self.special_tokens.items()}
        # token id -> bytes (256 raw bytes plus the special tokens)
        self.vocab = self._build_vocab()

    def _build_vocab(self):
        """Derive the id -> bytes vocabulary from the merges and special tokens.

        Ids 0-255 map to the single raw byte values. Each merge id maps to the
        concatenation of its two parts, and each special token maps to its
        UTF-8 encoding.
        """
        vocab = {idx: bytes([idx]) for idx in range(256)}
        # Process merges in increasing id order. A merge only refers to ids
        # smaller than its own, so both parts are already present.
        for (p0, p1), idx in sorted(self.merges.items(), key=lambda kv: kv[1]):
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def tokenize_hindi(self, text):
        """Split ``text`` into pre-tokenization chunks using ``self.pattern``.

        Returns:
            List of the pattern's matches in ``text``, in order.
        """
        pattern = re.compile(self.pattern)
        return pattern.findall(text)

    def learn_bpe_vocab(self, text, num_merges=50):
        """Train BPE merges on ``text`` from scratch.

        Any previously learned or loaded merges are discarded. Merge ``i``
        receives id ``256 + i``. Training stops early once no chunk has two
        or more tokens left. Prints the resulting compression ratio.

        Args:
            text: training corpus as a single string.
            num_merges: maximum number of merges to learn.

        Raises:
            ValueError: if a merge id would coincide with a special-token id.
        """
        merge_ids = range(256, 256 + num_merges)
        clashes = [tok for tok, tid in self.special_tokens.items() if tid in merge_ids]
        if clashes:
            raise ValueError(
                f"num_merges={num_merges} would reuse special token ids: {clashes}"
            )

        tokenized_text = self.tokenize_hindi(text)
        # Iterating bytes yields ints 0..255.
        tokens = [list(chunk.encode("utf-8")) for chunk in tokenized_text]
        # Total byte length before merging, for the compression ratio.
        input_len = sum(len(chunk_ids) for chunk_ids in tokens)

        # Start from a clean slate so old and new merges never share ids.
        self.merges = {}
        for i in tqdm(range(num_merges), desc="Merging pairs", unit="merge"):
            stats = {}
            for chunk_ids in tokens:
                stats = get_stats(chunk_ids, stats)
            if not stats:
                break  # every chunk is a single token; nothing left to merge
            pair = max(stats, key=stats.get)
            idx = 256 + i
            tokens = [merge(chunk_ids, pair, idx) for chunk_ids in tokens]
            self.merges[pair] = idx
        self.vocab = self._build_vocab()

        output_len = sum(len(chunk_ids) for chunk_ids in tokens)
        # Empty input: report no compression instead of dividing by zero.
        ratio = input_len / output_len if output_len else 1.0
        print(f"input_len: {input_len}, output_len: {output_len} compression ratio: {ratio:.2f}X")

    def save_bpe_vocab(self, model_file):
        """Write the tokenizer to ``model_file`` in "minbpe v1" text format.

        The file contains, in order:
            * the version line;
            * the pattern line;
            * the number of special tokens;
            * one ``<token> <id>`` line per special token;
            * one ``<id1> <id2>`` line per merge, ordered by merged id.

        load_bpe_vocab reassigns merge ids 256, 257, ... in file order. The
        vocab itself is not stored; it is rebuilt on load.
        """
        with open(model_file, "w", encoding="utf-8") as f:
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            for (idx1, idx2), _ in sorted(self.merges.items(), key=lambda kv: kv[1]):
                f.write(f"{idx1} {idx2}\n")

    def load_bpe_vocab(self, filepath):
        """Load a "minbpe v1" ``.model`` file written by save_bpe_vocab.

        The pattern, special tokens and merges are replaced only after the
        whole file has been parsed, and the vocab is then rebuilt.
        """
        assert filepath.endswith(".model")
        merges = {}
        special_tokens = {}
        idx = 256
        with open(filepath, 'r', encoding="utf-8") as f:
            version = f.readline().strip()
            assert version == "minbpe v1"
            # Only drop the newline; the pattern itself must stay intact.
            pattern = f.readline().rstrip("\n")
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                # Split on the last whitespace so a token may contain spaces.
                special, special_idx = f.readline().strip().rsplit(maxsplit=1)
                special_tokens[special] = int(special_idx)
            for line in f:
                if not line.strip():
                    continue  # tolerate blank or trailing lines
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.pattern = pattern
        self.merges = merges
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
        self.vocab = self._build_vocab()

    def register_special_tokens(self, special_tokens):
        """Replace the special tokens, e.g. ``{"<|endoftext|>": 100257}``.

        Args:
            special_tokens: dict of token string -> token id.
        """
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
        # Rebuild so vocab drops stale special-token entries.
        self.vocab = self._build_vocab()

    def decode(self, ids):
        """Convert a list of token ids back to a string.

        Invalid UTF-8 byte sequences are replaced with U+FFFD.

        Raises:
            ValueError: if an id is neither in the vocab nor a special token.
        """
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")
        text_bytes = b"".join(part_bytes)
        return text_bytes.decode("utf-8", errors="replace")

    def _encode_chunk(self, text_bytes):
        """Apply learned merges to one chunk's bytes and return its token ids."""
        # Start from the raw byte values 0..255.
        ids = list(text_bytes)
        while len(ids) >= 2:
            # Pick the pair that was learned earliest (lowest merge id).
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # If no pair is mergeable, every key is inf and min() returns an
            # arbitrary pair; the membership check detects that case.
            if pair not in self.merges:
                break
            ids = merge(ids, pair, self.merges[pair])
        return ids

    def encode_ordinary(self, text):
        """Encode ``text`` without recognising any special tokens."""
        ids = []
        # Each regex chunk is encoded independently; merges never span chunks.
        for chunk in self.tokenize_hindi(text):
            ids.extend(self._encode_chunk(chunk.encode("utf-8")))
        return ids

    def encode(self, text, allowed_special="none_raise"):
        """Encode ``text``, optionally recognising special tokens.

        Args:
            text: string to encode.
            allowed_special: one of the following.
                * "all": recognise every special token.
                * "none": treat special-token text as ordinary text.
                * "none_raise" (default): fail with AssertionError if any
                  special token appears in ``text``. This matches tiktoken's
                  default.
                * a set of token strings to recognise.

        Raises:
            ValueError: if ``allowed_special`` is not understood.
        """
        special = None
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens)
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:
            return self.encode_ordinary(text)
        # The capturing group makes re.split keep the special tokens in its
        # output. Longest tokens come first so a token that is a prefix of
        # another cannot shadow it.
        special_pattern = (
            "(" + "|".join(re.escape(k) for k in sorted(special, key=len, reverse=True)) + ")"
        )
        special_chunks = re.split(special_pattern, text)
        ids = []
        for part in special_chunks:
            if part in special:
                ids.append(special[part])
            else:
                ids.extend(self.encode_ordinary(part))
        return ids


# NOTE(review): this loads the model at import time, so importing this module
# requires hindi_bpe_vocab.model in the working directory. Consider moving it
# behind `if __name__ == "__main__":` once no importer relies on `tokenizer`.
tokenizer = HindiTokenizer()
tokenizer.load_bpe_vocab("hindi_bpe_vocab.model")