"""Vietnamese news summarization API built on VietAI's ViT5 model."""

import logging
import time

import torch
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

app = FastAPI()

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Model & tokenizer: loaded once at import time so every request reuses them.
MODEL_NAME = "VietAI/vit5-base-vietnews-summarization"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only: disables dropout and other train-mode behavior


class InputText(BaseModel):
    """Request body for /summarize: the raw text to summarize."""

    text: str


@app.post("/summarize")
def summarize(req: Request, payload: InputText):
    """Summarize Vietnamese text and return it as ``{"summary": ...}``.

    Declared as a plain ``def`` (not ``async def``) on purpose: FastAPI runs
    sync endpoints in a worker thread, so the CPU/GPU-bound ``model.generate``
    call no longer blocks the asyncio event loop for concurrent requests.
    (The single Pydantic body parameter keeps the same JSON schema regardless
    of its Python name, so renaming the shadowing ``input`` is safe for
    clients.)

    Raises:
        HTTPException: 400 if ``text`` is empty or whitespace-only.
    """
    start_time = time.time()
    # Lazy %-style args: no string formatting when the level filters the record.
    logger.info("\U0001F535 Received request from %s", req.client.host)

    text = payload.text.strip()
    if not text:
        # Reject empty input explicitly instead of "summarizing" nothing.
        raise HTTPException(status_code=400, detail="Field 'text' must be non-empty.")

    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512
    ).to(device)

    # inference_mode: skip autograd bookkeeping — faster and lower memory.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_length=128,
            num_beams=2,
            no_repeat_ngram_size=2,
            early_stopping=True,
        )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

    duration = time.time() - start_time
    logger.info("\u2705 Response sent — total time: %.2fs", duration)
    return {"summary": summary}


@app.get("/")
def root():
    """Liveness check."""
    return {"message": "Vietnamese Summarization API is up and running!"}