# ViT5BaseNode / app.py
import time
import logging

import torch
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Load model and tokenizer once at startup so requests don't pay the load cost
tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-large-vietnews-summarization")
model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-large-vietnews-summarization")

# Use the GPU when available; otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

class SummarizeRequest(BaseModel):
    text: str
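
# Illustrative request body for /summarize (the payload text is a placeholder):
#   {"text": "<Vietnamese news text to summarize>"}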

@app.get("/")
async def root():
    return {"message": "Model is ready."}

@app.post("/summarize")
async def summarize(req: Request, body: SummarizeRequest):
    start_time = time.time()
    client_ip = req.client.host
    logger.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] 🔵 Received request from {client_ip}")

    text = body.text.strip()
    # Preprocessing: if the text doesn't open like a news article, prepend
    # "Tin nhanh:" ("quick news"); otherwise prepend "Vietnews:".
    if not text.lower().startswith(("theo", "trong khi", "bộ", "ngày", "việt nam", "công an")):
        text = "Tin nhanh: " + text
    else:
        text = "Vietnews: " + text
    input_text = text + " </s>"  # explicit end-of-sequence marker, as in the VietAI usage example
    # Tokenize; truncate very long inputs (1024 tokens is a conservative cap for ViT5)
    encoding = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024)
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)
    # Generate the summary with a stable configuration. The earlier settings are
    # kept below as a commented-out alternative; the live call was simplified to
    # reduce processing time.
    # outputs = model.generate(
    #     input_ids=input_ids,
    #     attention_mask=attention_mask,
    #     max_length=128,
    #     num_beams=1,
    #     early_stopping=True,
    #     no_repeat_ngram_size=2,
    #     num_return_sequences=1,
    # )
    # Greedy decoding (no beam search) keeps generation fast
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=256,
    )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

    end_time = time.time()
    logger.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] ✅ Response sent — total time: {end_time - start_time:.2f}s")
    return {"summary": summary}