import time
import logging
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-large-vietnews-summarization")
model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-large-vietnews-summarization")
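# The checkpoint is a ViT5-large model fine-tuned by VietAI for Vietnamese news summarization.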
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Warm-up model to reduce first-request latency
dummy_input = tokenizer("Tin nhanh: Đây là văn bản mẫu để warmup mô hình.", return_tensors="pt").to(device)
with torch.no_grad():
    _ = model.generate(**dummy_input, max_length=32)

class SummarizeRequest(BaseModel):
    text: str

@app.get("/")
async def root():
    return {"message": "Model is ready."}

@app.post("/summarize")
async def summarize(req: Request, body: SummarizeRequest):
    start_time = time.time()
    client_ip = req.client.host
    logger.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] 🔵 Received request from {client_ip}")

    text = body.text.strip()
    # Preprocessing: if the text does not look like a news article, prefix it with "Tin nhanh:"
    if not text.lower().startswith(("theo", "trong khi", "bộ", "ngày", "việt nam", "công an")):
        text = "Tin nhanh: " + text
    else:
        text = "Vietnews: " + text

    input_text = text + " </s>"
    encoding = tokenizer(input_text, return_tensors="pt")
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Generate the summary with a stable configuration (early_stopping removed, greedy decoding)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=256,
            num_beams=1,  # greedy decoding
            no_repeat_ngram_size=2
        )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

    end_time = time.time()
    logger.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] ✅ Response sent — total time: {end_time - start_time:.2f}s")
    return {"summary": summary}
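
# Usage sketch: assuming this file is saved as app.py and uvicorn is installed,
# the server can be started with:
#   uvicorn app:app --host 0.0.0.0 --port 8000
# and the endpoint exercised with a request such as (hypothetical local URL and
# made-up example payload):
#   curl -X POST http://localhost:8000/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Theo dự báo, ngày mai miền Bắc trời rét đậm."}'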