import logging

import torch
from tokenizers import Tokenizer

# Special tokens
SOT = "[START]"
EOT = "[STOP]"
UNK = "[UNK]"
SPACE = "[SPACE]"
SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]

logger = logging.getLogger(__name__)


class EnTokenizer:
    def __init__(self, vocab_file_path):
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        # Ensure the loaded vocabulary contains the start/stop sentinels.
        voc = self.tokenizer.get_vocab()
        assert SOT in voc, f"missing {SOT} in vocabulary"
        assert EOT in voc, f"missing {EOT} in vocabulary"

    def text_to_tokens(self, text: str):
        # Encode and wrap as a (1, seq_len) int tensor for batched model input.
        text_tokens = self.encode(text)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
        return text_tokens

    def encode(self, txt: str):
        """
        Replace literal spaces with the SPACE token, then encode the text
        with the underlying `tokenizers.Tokenizer`.
        """
        txt = txt.replace(' ', SPACE)
        code = self.tokenizer.encode(txt)
        return code.ids

    def decode(self, seq):
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()

        txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
        # Undo the space handling from `encode`: drop the joiner spaces the
        # decoder inserts between tokens, restore real spaces from SPACE,
        # then strip the remaining sentinel tokens.
        txt = txt.replace(' ', '')
        txt = txt.replace(SPACE, ' ')
        txt = txt.replace(EOT, '')
        txt = txt.replace(UNK, '')
        return txt
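

# --- Usage sketch (not part of the original module) ---
# A minimal round-trip example, assuming a trained tokenizer JSON saved at
# "tokenizer.json" (hypothetical path) whose vocabulary includes the special
# tokens defined above. It encodes a sentence to a batched tensor and
# decodes it back to the original text.
if __name__ == "__main__":
    tokenizer = EnTokenizer("tokenizer.json")  # hypothetical vocab file path
    tokens = tokenizer.text_to_tokens("hello world")  # shape: (1, seq_len)
    print(tokens)
    print(tokenizer.decode(tokens.squeeze(0)))  # expected: "hello world"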