import time
import logging

import torch
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-large-vietnews-summarization")
model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-large-vietnews-summarization")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Warm up the model to reduce first-request latency
dummy_input = tokenizer(
    "Tin nhanh: Đây là văn bản mẫu để warmup mô hình.", return_tensors="pt"
).to(device)
with torch.no_grad():
    _ = model.generate(**dummy_input, max_length=32)


class SummarizeRequest(BaseModel):
    text: str


@app.get("/")
async def root():
    return {"message": "Model is ready."}


# Declared as a plain (sync) function so FastAPI runs it in a threadpool:
# model.generate is blocking, and running it inside an async endpoint would
# stall the event loop for the duration of generation.
@app.post("/summarize")
def summarize(req: Request, body: SummarizeRequest):
    start_time = time.time()
    client_ip = req.client.host
    logger.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] 🔵 Received request from {client_ip}")

    text = body.text.strip()

    # Preprocessing: if the text does not start like a news article, prepend
    # the "Tin nhanh:" prefix; otherwise use the "Vietnews:" prefix.
    if not text.lower().startswith(("theo", "trong khi", "bộ", "ngày", "việt nam", "công an")):
        text = "Tin nhanh: " + text
    else:
        text = "Vietnews: " + text

    input_text = text + " "
    encoding = tokenizer(input_text, return_tensors="pt")
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Generate the summary with a stable configuration
    # (early_stopping removed; greedy decoding is used).
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=256,
            num_beams=1,  # greedy decoding
            no_repeat_ngram_size=2,
        )
    summary = tokenizer.decode(
        outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    end_time = time.time()
    logger.info(
        f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] ✅ Response sent - total time: {end_time - start_time:.2f}s"
    )
    return {"summary": summary}
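
# A minimal sketch of how this service could be launched locally, assuming
# uvicorn is installed (pip install uvicorn); the host, port, and the example
# request below are illustrative assumptions, not part of the original app.
if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app defined above on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the hypothetical local URL:
#   curl -X POST http://localhost:8000/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Ngày 12/5, ..."}'
# Expected response shape: {"summary": "..."}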