import time
import logging
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-large-vietnews-summarization")
model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-large-vietnews-summarization")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Warm-up model to reduce first-request latency
dummy_input = tokenizer("Tin nhanh: Đây là văn bản mẫu để warmup mô hình.", return_tensors="pt").to(device)
with torch.no_grad():
    _ = model.generate(**dummy_input, max_length=32)

class SummarizeRequest(BaseModel):
    text: str

@app.get("/")
async def root():
    return {"message": "Model is ready."}

@app.post("/summarize")
async def summarize(req: Request, body: SummarizeRequest):
    start_time = time.time()
    client_ip = req.client.host
    logger.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] 🔵 Received request from {client_ip}")

    text = body.text.strip()

    # Preprocessing: if the text doesn't look like a news article, prefix it with "Tin nhanh:"; otherwise use the "Vietnews:" task prefix
    if not text.lower().startswith(("theo", "trong khi", "bộ", "ngày", "việt nam", "công an")):
        text = "Tin nhanh: " + text
    else:
        text = "Vietnews: " + text

    input_text = text + " </s>"  # explicit EOS marker, following the model card's usage example
    encoding = tokenizer(input_text, return_tensors="pt", truncation=True)  # truncate overly long inputs to the tokenizer's model_max_length
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Generate the summary with a stable configuration: early_stopping removed
    # (it has no effect with a single beam) and greedy decoding used instead of beam search.
    # no_grad avoids building an autograd graph during inference.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=256,
            num_beams=1,  # greedy decoding
            no_repeat_ngram_size=2
        )

    summary = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

    end_time = time.time()
    logger.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] ✅ Response sent — total time: {end_time - start_time:.2f}s")

    return {"summary": summary}
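
# A minimal sketch of how this service might be run and called. The module name
# "main", the port 8000, and the sample payload are assumptions for illustration,
# not part of this file:
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Ngày 1/1, Bộ Y tế công bố số liệu mới về dịch bệnh ..."}'
#
# The response is a JSON object of the form {"summary": "..."}.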