import json

import torch


class UstaTokenizer:
    def __init__(self, vocab_file):
        # load the token -> id vocabulary and build the reverse id -> token map
        with open(vocab_file, "r") as f:
            self.vocab = json.load(f)
        self.reverse_vocab = {v: k for k, v in self.vocab.items()}

    def encode_batch(self, texts, context_length):
        # encode each text, then truncate or pad with <pad> to context_length
        sentences_tokens = []
        for text in texts:
            tokens = self.encode(text).tolist()
            if len(tokens) > context_length:
                tokens = tokens[:context_length]
            else:
                tokens = tokens + [self.vocab["<pad>"]] * (context_length - len(tokens))
            sentences_tokens.append(tokens)
        return torch.tensor(sentences_tokens)

    def encode(self, text):
        tokens = []
        for word in text.split():
            i = 0
            # greedy longest-match, e.g. "states" -> "state" (id 4) + "s" (id 58)
            while i < len(word):
                found_match = False
                # try the longest remaining substring first, then shrink it
                for j in range(len(word), i, -1):
                    sub_word = word[i:j]
                    if sub_word in self.vocab:
                        tokens.append(self.vocab[sub_word])
                        i = j
                        found_match = True
                        break
                if not found_match:
                    # no vocab entry matches this character; emit <unk> and advance
                    tokens.append(self.vocab["<unk>"])
                    i += 1
            tokens.append(self.vocab[" "])
        # drop the trailing space token unless the text really ends with a space
        if tokens and not text.endswith(" "):
            tokens.pop()
        return torch.tensor(tokens)

    def tokenize(self, text):
        # encode to ids, then map each id back to its token string
        token_ids = self.encode(text).tolist()
        return [self.reverse_vocab[token_id] for token_id in token_ids]

    def decode(self, ids):
        # concatenate the token strings for the given ids
        text = ""
        for token_id in ids:
            text += self.reverse_vocab[token_id]
        return text
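
# A minimal usage sketch, assuming a vocab.json file that maps token strings
# to integer ids and includes the "<pad>", "<unk>", and " " entries the
# methods above rely on; the path and sample texts are illustrative only.
if __name__ == "__main__":
    tokenizer = UstaTokenizer("vocab.json")
    batch = tokenizer.encode_batch(["the united states", "canada"], context_length=8)
    print(batch.shape)  # torch.Size([2, 8]): each row truncated/padded to context_length
    ids = tokenizer.encode("the united states")
    print(tokenizer.decode(ids.tolist()))  # round-trips tokens present in the vocab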