import logging

import torch
from tokenizers import Tokenizer

# Special tokens
SOT = "[START]"
EOT = "[STOP]"
UNK = "[UNK]"
SPACE = "[SPACE]"
SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]

logger = logging.getLogger(__name__)
class EnTokenizer:
    def __init__(self, vocab_file_path):
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        # Sanity-check that the start/stop tokens are present in the vocab.
        voc = self.tokenizer.get_vocab()
        assert SOT in voc
        assert EOT in voc

    def text_to_tokens(self, text: str):
        # Encode and add a batch dimension: shape (1, seq_len).
        text_tokens = self.encode(text)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
        return text_tokens
    def encode(self, txt: str, verbose=False):
        """
        clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
        """
        # Spaces are represented by an explicit [SPACE] token before encoding.
        txt = txt.replace(' ', SPACE)
        code = self.tokenizer.encode(txt)
        ids = code.ids
        return ids
    def decode(self, seq):
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()

        txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
        # Drop the spaces the decoder inserts between tokens, then restore real
        # spaces from [SPACE] and strip the stop/unknown markers.
        txt = txt.replace(' ', '')
        txt = txt.replace(SPACE, ' ')
        txt = txt.replace(EOT, '')
        txt = txt.replace(UNK, '')
        return txt
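

# Example usage — a minimal sketch, assuming a trained tokenizer file exists at
# "tokenizer.json" (the path is hypothetical; substitute your own vocab file).
# It demonstrates the encode/decode round trip: spaces become [SPACE] tokens on
# the way in and are restored on the way out.
if __name__ == "__main__":
    tokenizer = EnTokenizer("tokenizer.json")
    ids = tokenizer.encode("hello world")
    print(ids)                    # token ids, with [SPACE] standing in for ' '
    print(tokenizer.decode(ids))  # -> "hello world"
    batch = tokenizer.text_to_tokens("hello world")
    print(batch.shape)            # torch.Size([1, seq_len])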