Spaces:
Sleeping
Sleeping
File size: 5,459 Bytes
41dfb3a 5021135 41dfb3a c563cac 41dfb3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import unicodedata
def get_stats(ids, counts=None):
    """
    Tally how often each pair of adjacent ids occurs in `ids`.

    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}

    If `counts` is given, it is updated in place and returned, which lets a
    caller accumulate statistics over several id sequences; otherwise a new
    dict is created.
    """
    if counts is None:
        counts = {}
    # zip of the sequence with itself shifted by one yields every adjacent pair
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        counts[key] = counts.get(key, 0) + 1
    return counts
def merge(ids, pair, idx):
    """
    Return a new list in which every consecutive occurrence of `pair`
    (two ids) in `ids` is replaced by the single token `idx`.

    The scan runs left to right and matches do not overlap.
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    out = []
    total = len(ids)
    pos = 0
    while pos < total:
        # a match needs a following element, so the last position never matches
        matched = (
            ids[pos] == pair[0]
            and pos + 1 < total
            and ids[pos + 1] == pair[1]
        )
        if matched:
            out.append(idx)
            pos += 2
            continue
        out.append(ids[pos])
        pos += 1
    return out
def merge_hindi(ids, pair, idx):
    """
    Three-element variant of merge().

    Despite its name, `pair` is indexed at positions 0, 1 and 2 here: every
    run of three consecutive ids equal to pair[0], pair[1], pair[2] is
    replaced by the single token `idx`. The scan runs left to right and
    matches do not overlap.
    Example: ids=[1, 2, 3, 4, 1, 2, 3], pair=(1, 2, 3), idx=9 -> [9, 4, 9]
    """
    out = []
    total = len(ids)
    pos = 0
    while pos < total:
        # a match needs two following elements, so the last two positions
        # can never start one
        matched = (
            ids[pos] == pair[0]
            and pos + 2 < total
            and ids[pos + 1] == pair[1]
            and ids[pos + 2] == pair[2]
        )
        if matched:
            out.append(idx)
            pos += 3
            continue
        out.append(ids[pos])
        pos += 1
    return out
def replace_control_characters(s):
    """
    Return `s` with every character whose Unicode general category starts
    with "C" replaced by a literal backslash-u escape of its code point
    (at least four lowercase hex digits). All other characters are kept.

    This keeps characters such as newlines from distorting printed output.
    Example: "a\\nb" (with a real newline) -> the text a\\u000ab
    """
    return "".join(
        f"\\u{ord(ch):04x}" if unicodedata.category(ch)[0] == "C" else ch
        for ch in s
    )
def render_token(t):
    """
    Turn the bytes of a token into a printable string.

    The bytes are decoded as UTF-8 with errors='replace' (so invalid or
    partial sequences become the replacement character) and the result is
    passed through replace_control_characters().
    """
    return replace_control_characters(t.decode('utf-8', errors='replace'))
class Tokenizer:
    """
    Base class for tokenizers.

    Holds the shared state (merges, split pattern, special tokens, vocab) and
    the save()/load() persistence logic. build(), encode() and decode() are
    left for subclasses to implement.
    """

    # Layout of the base (non-merged) token ids used by _build_vocab()/load():
    #   0..255    one token per raw byte value
    #   256..383  one token per code point in U+0900..U+097F (the Unicode
    #             Devanagari block), stored as its UTF-8 encoding
    #   384..     ids handed out to merges, in the order they are listed
    _NUM_BYTE_TOKENS = 256
    _DEVANAGARI_FIRST = 0x0900  # 2304
    _DEVANAGARI_COUNT = 128
    _FIRST_MERGE_ID = _NUM_BYTE_TOKENS + _DEVANAGARI_COUNT  # 384

    def __init__(self):
        self.merges = {}  # (int, int) -> int
        self.pattern = ""  # str
        self.special_tokens = {}  # str -> int
        self.vocab = self._build_vocab()  # int -> bytes

    def build(self, text, vocab_size, verbose=False):
        """Train the tokenizer on `text`; must be implemented by subclasses."""
        raise NotImplementedError

    def encode(self, text):
        """Encode a string into token ids; must be implemented by subclasses."""
        raise NotImplementedError

    def decode(self, ids):
        """Decode token ids into a string; must be implemented by subclasses."""
        raise NotImplementedError

    def _build_vocab(self):
        """
        Derive the id -> bytes vocabulary deterministically from the base
        tokens, self.merges and self.special_tokens.
        """
        vocab = {idx: bytes([idx]) for idx in range(self._NUM_BYTE_TOKENS)}
        for offset in range(self._DEVANAGARI_COUNT):
            char = chr(self._DEVANAGARI_FIRST + offset)
            vocab[self._NUM_BYTE_TOKENS + offset] = char.encode("utf-8")
        # relies on dict insertion order: both children of a merge must
        # already be present in vocab when the merge is reached
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired by (but not equivalent to!) sentencepiece's model saving:
        - model file is the critical one, intended for load()
        - vocab file is just a pretty printed version for human inspection only
        """
        # write the model: to be used in load() later.
        # Explicit utf-8 so the file matches what load() reads, independent
        # of the platform's default encoding (the pattern and special tokens
        # may contain non-ASCII text).
        model_file = file_prefix + ".model"
        with open(model_file, "w", encoding="utf-8") as f:
            # write the version, pattern and merges, that's all that's needed
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            # write the special tokens, first the number of them, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # the merges dict: only the pairs are written, load() re-derives
            # the ids from the line order
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # note: many tokens may be partial utf-8 sequences
                # and cannot be decoded into valid strings. Here we're using
                # errors='replace' to replace them with the replacement char.
                # this also means that we couldn't possibly use .vocab in load()
                # because decoding in this way is a lossy operation!
                s = render_token(token)
                # find the children of this token, if any
                if idx in inverted_merges:
                    # if this token has children, render it nicely as a merge
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # otherwise this is a leaf token, just print it
                    # (a byte token, a Devanagari token or a special token)
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """
        Inverse of save() but only for the model file.

        Replaces self.pattern, self.merges and self.special_tokens with the
        file's contents and rebuilds self.vocab. Raises AssertionError if the
        file name does not end in ".model" or the version line is not
        "minbpe v1".
        """
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = self._FIRST_MERGE_ID
        with open(model_file, "r", encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "minbpe v1"
            # read the pattern
            self.pattern = f.readline().strip()
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                # split on the LAST run of whitespace only, so a special
                # token that itself contains spaces is read back intact
                special, special_idx = f.readline().strip().rsplit(maxsplit=1)
                special_tokens[special] = int(special_idx)
            # read the merges; ids are assigned by line order
            for line in f:
                if not line.strip():
                    # tolerate blank lines (e.g. a trailing newline)
                    continue
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()