ViT5BaseNode / app.py
VietCat's picture
reduce processing time
fb4a646
import asyncio
import logging
import threading
import time

import torch
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()

# Hugging Face Hub id of the checkpoint; used for both tokenizer and model so they match.
_MODEL_ID = "VietAI/vit5-large-vietnews-summarization"

# Load tokenizer and seq2seq model once at import time; request handlers reuse them.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_ID)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)


def _warm_up_model():
    """Run one short generation at startup to reduce first-request latency."""
    sample = tokenizer(
        "Tin nhanh: Đây là văn bản mẫu để warmup mô hình.",
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        model.generate(**sample, max_length=32)


_warm_up_model()
class SummarizeRequest(BaseModel):
    # JSON body accepted by POST /summarize.
    # The raw input text to summarize; leading/trailing whitespace is stripped by the handler.
    text: str
@app.get("/")
async def root():
    # Health check. The tokenizer and model are loaded and warmed up by module-level
    # code before the app can serve anything, so a response here means the model is loaded.
    ready_payload = {"message": "Model is ready."}
    return ready_payload
# Opening words suggesting the input is already written as a news article. Such
# inputs get the "Vietnews: " prefix; everything else gets "Tin nhanh: ".
_NEWS_LEAD_WORDS = ("theo", "trong khi", "bộ", "ngày", "việt nam", "công an")

# Serializes model.generate(). Generation runs in a worker thread so the event loop
# stays responsive. This lock keeps it one-at-a-time, as it effectively was when
# generate() blocked the event loop directly.
_generate_lock = threading.Lock()


def _build_model_input(text):
    """Return the prompt string to tokenize for already-stripped ``text``.

    Prepends "Vietnews: " when ``text`` (case-insensitively) starts with one of
    ``_NEWS_LEAD_WORDS``, otherwise "Tin nhanh: ", and appends the literal " </s>".
    """
    prefix = "Vietnews: " if text.lower().startswith(_NEWS_LEAD_WORDS) else "Tin nhanh: "
    return prefix + text + " </s>"


def _generate_summary_ids(input_ids, attention_mask):
    """Run blocking greedy generation and return the output token ids.

    Meant to be called from a worker thread. Grad mode is thread-local in PyTorch,
    so ``no_grad`` is entered here rather than in the calling coroutine.
    """
    # Stable configuration: early_stopping removed, greedy decoding.
    with _generate_lock, torch.no_grad():
        return model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=256,
            num_beams=1,  # greedy decoding
            no_repeat_ngram_size=2,
        )


@app.post("/summarize")
async def summarize(req: Request, body: SummarizeRequest):
    """Summarize the submitted text and return it as ``{"summary": ...}``."""
    start = time.perf_counter()
    # Starlette sets request.client to None when the peer address is unknown.
    client_ip = req.client.host if req.client is not None else "unknown"
    logger.info("[%s] 🔵 Received request from %s", time.strftime('%Y-%m-%d %H:%M:%S'), client_ip)

    text = body.text.strip()
    if not text:
        # Nothing to summarize; don't let the model generate from a bare prefix.
        summary = ""
    else:
        encoding = tokenizer(_build_model_input(text), return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        # generate() is compute-bound and blocking. Run it in the default executor
        # so this coroutine does not stall other requests (including GET /).
        loop = asyncio.get_running_loop()
        outputs = await loop.run_in_executor(None, _generate_summary_ids, input_ids, attention_mask)
        summary = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

    elapsed = time.perf_counter() - start
    logger.info("[%s] ✅ Response sent — total time: %.2fs", time.strftime('%Y-%m-%d %H:%M:%S'), elapsed)
    return {"summary": summary}