# noob/tokenizer.py
from transformers import PreTrainedTokenizer
from typing import List, Optional
import os
import json


class NoobTokenizer(PreTrainedTokenizer):
    """A minimal character-level tokenizer: every character is one token."""

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_file=None, **kwargs):
        if vocab_file is None:
            # Default vocabulary, used only for initialization: the set of
            # unique characters found in a local input.txt.
            with open("input.txt", "r", encoding="utf-8") as f:
                text = f.read()
            chars = sorted(set(text))
            self.stoi = {ch: i for i, ch in enumerate(chars)}
            self.itos = {i: ch for i, ch in enumerate(chars)}
        else:
            # Load the vocabulary from a JSON file mapping char -> id.
            with open(vocab_file, "r", encoding="utf-8") as f:
                self.stoi = json.load(f)
            self.itos = {int(i): ch for ch, i in self.stoi.items()}
        # The base __init__ queries vocab_size, so call it only after
        # stoi/itos exist.
        super().__init__(**kwargs)

    @property
    def vocab_size(self) -> int:
        return len(self.stoi)

    def get_vocab(self):
        return dict(self.stoi)

    def _tokenize(self, text: str) -> List[str]:
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        # Characters missing from the vocabulary fall back to id 0.
        return self.stoi.get(token, 0)

    def _convert_id_to_token(self, index: int) -> str:
        return self.itos[index]

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return "".join(tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        # Honor filename_prefix, since save_pretrained may pass one.
        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.stoi, f, ensure_ascii=False)
        return (vocab_file,)
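

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal round trip, assuming a local input.txt exists to build the default
# vocabulary and that it contains the characters used below; characters absent
# from the vocabulary would silently map to id 0.
if __name__ == "__main__":
    import tempfile

    tok = NoobTokenizer()  # builds the vocab from input.txt
    ids = tok("hello")["input_ids"]
    print(ids, "->", tok.decode(ids))

    # Persist the vocabulary, then reload it without needing input.txt.
    with tempfile.TemporaryDirectory() as tmp:
        (vocab_path,) = tok.save_vocabulary(tmp)
        tok2 = NoobTokenizer(vocab_file=vocab_path)
        assert tok2.decode(tok2("hello")["input_ids"]) == "hello"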