# tokenGPT-2 / app.py
from fastapi import FastAPI
from transformers import PreTrainedTokenizerFast
from tokenizers import ByteLevelBPETokenizer
from datasets import load_dataset
from contextlib import asynccontextmanager
from itertools import chain
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Application starting...")
    await train_tokenizer()
    yield
    logger.info("Application shutting down...")


app = FastAPI(lifespan=lifespan)

async def train_tokenizer():
    vocab_size = 50000
    min_frequency = 2

    # The 20231101 dumps are published under wikimedia/wikipedia on the Hub;
    # the legacy "wikipedia" dataset does not carry this config.
    # dataset_greek = load_dataset("oscar", "unshuffled_deduplicated_el", split="train", streaming=True)
    dataset_greek = load_dataset("wikimedia/wikipedia", "20231101.el", split="train", streaming=True)
    dataset_english = load_dataset("wikimedia/wikipedia", "20231101.en", split="train", streaming=True)

    # bigcode/the-stack is gated; fall back to the Wikipedia streams if it fails to load.
    try:
        dataset_code = load_dataset("bigcode/the-stack", split="train", streaming=True)
        datasets_list = [dataset_greek, dataset_english, dataset_code]
    except Exception as e:
        logger.warning("Could not load bigcode/the-stack (%s); training without code data.", e)
        datasets_list = [dataset_greek, dataset_english]

    def preprocess_data(dataset):
        for item in dataset:
            # Wikipedia rows store text under "text"; the-stack rows use "content".
            text = item.get("text") or item.get("content") or ""
            text = text.strip().lower()
            if text:
                yield text

    # train_from_iterator expects an iterator of strings (or batches of strings),
    # so flatten the per-dataset generators instead of yielding generators.
    # Only the first 1000 examples of each stream are used to keep startup fast.
    combined_data = chain.from_iterable(
        preprocess_data(dataset.take(1000)) for dataset in datasets_list
    )

    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train_from_iterator(
        combined_data,
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
    )
    tokenizer.save_model(".")  # writes vocab.json and merges.txt to the working directory
    logger.info("Tokenizer training completed!")


@app.get("/")
async def root():
    return {"message": "Custom Tokenizer Training Completed and Saved"}