# usta-llm-demo / v2 / usta_tokenizer.py
# author: alibayram — "v2 implemented" (commit 6563ff2)
import json
import torch
class UstaTokenizer:
    """Greedy longest-match subword tokenizer backed by a JSON vocab file.

    The vocab maps token strings to integer ids and is assumed to contain
    the special entries ``"<pad>"``, ``"<unk>"`` and ``" "`` (single space).
    """

    def __init__(self, vocab_file):
        """Load the vocab from *vocab_file* (JSON: token -> id) and build
        the id -> token reverse map used by decode/tokenize.
        """
        with open(vocab_file, "r") as f:
            self.vocab = json.load(f)
        # NOTE(review): if two tokens share an id, the later one wins here.
        self.reverse_vocab = {v: k for k, v in self.vocab.items()}

    def encode_batch(self, texts, context_length):
        """Encode each text, truncating or right-padding with ``<pad>`` to
        exactly *context_length* ids; return a (len(texts), context_length)
        LongTensor.
        """
        pad_id = self.vocab["<pad>"]
        sentences_tokens = []
        for text in texts:
            tokens = self.encode(text).tolist()[:context_length]
            tokens += [pad_id] * (context_length - len(tokens))
            sentences_tokens.append(tokens)
        return torch.tensor(sentences_tokens)

    def encode(self, text):
        """Encode *text* into a 1-D tensor of token ids.

        Each whitespace-separated word is consumed greedily by the longest
        matching vocab substring (e.g. "states" -> "state" + "s"); any
        character with no match emits ``<unk>``. Words are re-joined with
        the ``" "`` token, which is dropped after the last word unless the
        input literally ends with a space.
        """
        tokens = []
        for word in text.split():
            i = 0
            while i < len(word):
                # Try the longest remaining substring first, shrinking
                # until a vocab entry matches.
                found_match = False
                for j in range(len(word), i, -1):
                    sub_word = word[i:j]
                    if sub_word in self.vocab:
                        tokens.append(self.vocab[sub_word])
                        i = j
                        found_match = True
                        break
                if not found_match:
                    # Unknown character: emit <unk> and move on by one.
                    tokens.append(self.vocab["<unk>"])
                    i += 1
            tokens.append(self.vocab[" "])
        # Drop the trailing space token unless the text ends with a space.
        # The `tokens` guard fixes an IndexError on empty/whitespace-only
        # input, where no word ever produced a token to pop.
        if tokens and not text.endswith(" "):
            tokens.pop()
        return torch.tensor(tokens)

    def tokenize(self, text):
        """Return the list of token *strings* for *text* (encode + reverse
        lookup).
        """
        # Tensor.tolist() converts directly; no numpy round-trip needed.
        token_ids = self.encode(text).tolist()
        return [self.reverse_vocab[id] for id in token_ids]

    def decode(self, ids):
        """Concatenate the token strings for *ids* back into text."""
        # str.join is linear; repeated += is quadratic in CPython's worst case.
        return "".join(self.reverse_vocab[id] for id in ids)