Spaces:
Running
Running
import json | |
import torch | |
class UstaTokenizer:
    """Greedy longest-match subword tokenizer backed by a JSON vocabulary.

    The vocabulary file is a JSON object mapping token strings to integer
    ids. The methods below look up the special tokens ``"<pad>"``,
    ``"<unk>"`` and ``" "`` (the word separator), so the vocabulary is
    expected to contain them whenever those code paths are reached.
    """

    def __init__(self, vocab_file):
        """Load the vocabulary from *vocab_file* and build the id -> token map."""
        # JSON text is UTF-8 (RFC 8259); pass the encoding explicitly so
        # non-ASCII tokens load correctly regardless of the platform's
        # default locale encoding.
        with open(vocab_file, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)
        self.reverse_vocab = {v: k for k, v in self.vocab.items()}

    def encode_batch(self, texts, context_length):
        """Encode several texts into one fixed-width id tensor.

        Each text is encoded with :meth:`encode`, then truncated to
        *context_length* ids or right-padded with the ``"<pad>"`` id.

        Returns:
            A ``torch.long`` tensor with one row of *context_length* ids per
            text (for a non-empty *texts*).
        """
        sentences_tokens = []
        for text in texts:
            tokens = self.encode(text).tolist()
            if len(tokens) > context_length:
                tokens = tokens[:context_length]
            else:
                tokens = tokens + [self.vocab["<pad>"]] * (context_length - len(tokens))
            sentences_tokens.append(tokens)
        # Explicit dtype: without it an empty batch becomes a float tensor.
        return torch.tensor(sentences_tokens, dtype=torch.long)

    def encode(self, text):
        """Encode *text* into a 1-D ``torch.long`` tensor of token ids.

        The text is split on whitespace. Each word is tokenized by greedy
        longest-prefix matching against the vocabulary (e.g. ``"states"``
        becomes ``"state"`` + ``"s"`` when both are in the vocabulary).
        A character that starts no known sub-word becomes ``"<unk>"``.
        Words are separated by the ``" "`` token. A trailing ``" "`` token
        is kept only when *text* itself ends with a space. Runs of
        whitespace collapse to a single separator, and leading whitespace
        is dropped.
        """
        tokens = []
        for word in text.split():
            i = 0
            while i < len(word):
                # Try the longest candidate first, shrinking towards word[i].
                for j in range(len(word), i, -1):
                    sub_word = word[i:j]
                    if sub_word in self.vocab:
                        tokens.append(self.vocab[sub_word])
                        i = j
                        break
                else:
                    # No known sub-word starts at i: emit <unk> for one char.
                    tokens.append(self.vocab["<unk>"])
                    i += 1
            tokens.append(self.vocab[" "])
        # Drop the separator appended after the last word unless the text
        # really ends with a space. Guarding on `tokens` avoids popping from
        # an empty list for "" or whitespace-only text such as "\n".
        if tokens and not text.endswith(" "):
            tokens.pop()
        return torch.tensor(tokens, dtype=torch.long)

    def tokenize(self, text):
        """Return the token strings that :meth:`encode` produces for *text*."""
        token_ids = self.encode(text).tolist()
        return [self.reverse_vocab[token_id] for token_id in token_ids]

    def decode(self, ids):
        """Concatenate the token strings for *ids* back into text.

        *ids* may be an iterable of integer ids or a 1-D tensor such as the
        output of :meth:`encode`.
        """
        # Iterating a tensor yields 0-d tensors, which hash by identity and
        # so never match the int keys of reverse_vocab; convert them first.
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        return "".join(self.reverse_vocab[token_id] for token_id in ids)