import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union

import numpy as np
import sentencepiece as spm


class BaseTokenizer(ABC):
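    """Convert text to a sequence of token ids and back via a fixed vocabulary.

    Subclasses implement ``text2tokens`` and ``tokens2text``; ``encode`` and
    ``decode`` are derived from them together with the token/id maps.
    """
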
    def __init__(
        self,
        token_list: Optional[Union[Path, str, Iterable[str]]] = None,
        unk_symbol: str = "<unk>",
        **kwargs,
    ):
        if token_list is not None:
            if isinstance(token_list, (Path, str)) and str(token_list).endswith(".txt"):
                # Plain-text vocabulary: one token per line; the line number is
                # the token id. str() guards against Path, which has no
                # endswith() method.
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                self.token_list: List[str] = []
                with token_list.open("r", encoding="utf-8") as f:
                    for line in f:
                        self.token_list.append(line.rstrip())
            elif isinstance(token_list, (Path, str)) and str(token_list).endswith(".json"):
                # JSON vocabulary: a single array of token strings.
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                with token_list.open("r", encoding="utf-8") as f:
                    self.token_list: List[str] = json.load(f)
            else:
                self.token_list: List[str] = list(token_list)
                # Preview only the first three tokens in the repr.
                self.token_list_repr = ", ".join(self.token_list[:3])
                self.token_list_repr += f", ... (NVocab={len(self.token_list)})"

            # These maps exist only when an explicit token list is given;
            # SentencepiecesTokenizer, for instance, constructs the base
            # class without one.
            self.token2id: Dict[str, int] = {}
            for i, t in enumerate(self.token_list):
                if t in self.token2id:
                    raise RuntimeError(f'Symbol "{t}" is duplicated')
                self.token2id[t] = i

            self.unk_symbol = unk_symbol
            if self.unk_symbol not in self.token2id:
                raise RuntimeError(
                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
                )
            self.unk_id = self.token2id[self.unk_symbol]

    def encode(self, text, **kwargs):
        tokens = self.text2tokens(text)
        return self.tokens2ids(tokens)

    def decode(self, text_ints):
        tokens = self.ids2tokens(text_ints)
        return self.tokens2text(tokens)

    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        # Tokens missing from the vocabulary fall back to the <unk> id.
        return [self.token2id.get(t, self.unk_id) for t in tokens]

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        raise NotImplementedError

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        raise NotImplementedError
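

# A minimal sketch of a concrete subclass (hypothetical; not part of this
# module): a whitespace tokenizer showing how text2tokens/tokens2text plug
# into the encode/decode round trip inherited from BaseTokenizer.
#
#     class WhitespaceTokenizer(BaseTokenizer):
#         def text2tokens(self, line: str) -> List[str]:
#             return line.split()
#
#         def tokens2text(self, tokens: Iterable[str]) -> str:
#             return " ".join(tokens)
#
#     tok = WhitespaceTokenizer(token_list=["<blank>", "<unk>", "hello", "world"])
#     tok.encode("hello world")  # -> [2, 3]
#     tok.decode([2, 3])         # -> "hello world"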


class SentencepiecesTokenizer(BaseTokenizer):
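    """Tokenizer backed by a trained SentencePiece model file."""
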
    def __init__(self, bpemodel: Union[Path, str], **kwargs):
        super().__init__(**kwargs)
        self.bpemodel = str(bpemodel)
        # _build_sentence_piece_processor() loads the model only while
        # self.sp is None, so the repeated calls below are cheap no-ops.
        self.sp = None
        self._build_sentence_piece_processor()

    def __repr__(self):
        return f'{self.__class__.__name__}(model="{self.bpemodel}")'

    def _build_sentence_piece_processor(self):
        # Load the SentencePiece model on first use; no-op afterwards.
        if self.sp is None:
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(self.bpemodel)

    def text2tokens(self, line: str) -> List[str]:
        self._build_sentence_piece_processor()
        return self.sp.EncodeAsPieces(line)

    def tokens2text(self, tokens: Iterable[str]) -> str:
        self._build_sentence_piece_processor()
        return self.sp.DecodePieces(list(tokens))

    def encode(self, line: str, **kwargs) -> List[int]:
        # Encode straight to ids instead of going through pieces.
        self._build_sentence_piece_processor()
        return self.sp.EncodeAsIds(line)

    def decode(self, line: List[int], **kwargs):
        self._build_sentence_piece_processor()
        return self.sp.DecodeIds(line)

    def get_vocab_size(self):
        self._build_sentence_piece_processor()
        return self.sp.GetPieceSize()

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        # Honor the base-class contract (ids -> token strings); decode()
        # is the method that returns detokenized text.
        self._build_sentence_piece_processor()
        return [self.sp.IdToPiece(int(i)) for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        # Unknown pieces map to the model's unk id via PieceToId.
        self._build_sentence_piece_processor()
        return [self.sp.PieceToId(t) for t in tokens]
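

# A quick smoke test; a sketch only, assuming a trained SentencePiece model
# exists at the (hypothetical) path below.
if __name__ == "__main__":
    tokenizer = SentencepiecesTokenizer(bpemodel="bpe.model")  # hypothetical path
    line = "hello world"
    print(tokenizer.text2tokens(line))  # pieces, e.g. ['▁hello', '▁world']
    ids = tokenizer.encode(line)
    print(ids)                          # the corresponding piece ids
    print(tokenizer.decode(ids))        # detokenized text: "hello world"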