File size: 2,164 Bytes
831df6f
 
69847a0
c5a0bf8
20688a8
fd6737e
20688a8
4814cd0
fd6737e
4814cd0
fd6737e
69847a0
fd6737e
e3dbead
 
fd6737e
 
831df6f
fd6737e
c5a0bf8
 
831df6f
fd6737e
 
4814cd0
fd6737e
 
 
 
 
 
 
 
 
 
 
29182c9
 
 
fd6737e
 
 
831df6f
 
4814cd0
fd6737e
8a05f36
 
 
 
 
 
 
 
 
4814cd0
08a672b
8a05f36
 
4814cd0
fd6737e
c3ffcdd
 
fd6737e
c3ffcdd
69847a0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import logging
import threading
import time

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Module-level setup: logging, the FastAPI application, and the
# summarization model/tokenizer shared by every request handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Hugging Face checkpoint used for both the tokenizer and the model.
MODEL_NAME = "VietAI/vit5-large-vietnews-summarization"

# Loaded once at import time; the handlers below reuse these objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)

class SummarizeRequest(BaseModel):
    """JSON request body accepted by the POST /summarize endpoint."""

    text: str  # raw text to summarize; the handler strips surrounding whitespace

@app.get("/")
async def root():
    """Readiness probe.

    The model and tokenizer are loaded at module import time, before the app
    starts serving, so reaching this handler means they are available.
    """
    status = {"message": "Model is ready."}
    return status

# Serializes all use of the shared tokenizer/model. Before, generation ran
# directly inside the async handler, which blocked the event loop and so
# handled one request at a time. Work now runs in a worker thread. This lock
# keeps model access one-at-a-time (bounding CPU/GPU load) and avoids
# concurrent use of the shared fast tokenizer from several threads.
_generate_lock = threading.Lock()

# Leading words (lower-cased) taken as a sign that the input already reads
# like a news article.
_NEWS_PREFIXES = ("theo", "trong khi", "bộ", "ngày", "việt nam", "công an")


def _generate_summary(text: str) -> str:
    """Tokenize *text*, run ``model.generate`` and return the decoded summary.

    This is blocking work. Call it from a worker thread, never directly on the
    event loop.
    """
    # The prompt ends with an explicit " </s>" marker, as the original
    # implementation did.
    input_text = text + " </s>"
    with _generate_lock:
        encoding = tokenizer(input_text, return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=256,
            # NOTE(review): early_stopping only affects beam search; confirm
            # whether this checkpoint's generation config enables num_beams > 1.
            early_stopping=True,
        )
        return tokenizer.decode(
            outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
        )


@app.post("/summarize")
async def summarize(req: Request, body: SummarizeRequest):
    """Summarize ``body.text`` and return ``{"summary": <str>}``.

    The text is stripped and given a prefix: "Vietnews: " when it starts with
    one of ``_NEWS_PREFIXES`` (case-insensitive), otherwise "Tin nhanh: "
    ("quick news"). Generation runs in a worker thread so the event loop stays
    responsive.

    Raises:
        HTTPException: 400 if ``body.text`` is empty or whitespace-only.
    """
    start_time = time.perf_counter()  # monotonic, unaffected by clock changes
    # Request.client may be None (e.g. when the ASGI scope has no client).
    client_ip = req.client.host if req.client is not None else "unknown"
    logger.info(
        "[%s] 🔵 Received request from %s",
        time.strftime("%Y-%m-%d %H:%M:%S"),
        client_ip,
    )

    text = body.text.strip()
    if not text:
        # Otherwise the model would be asked to summarize a bare prefix.
        raise HTTPException(status_code=400, detail="text must not be empty")

    # Preprocessing: input that doesn't look like a news article gets the
    # "Tin nhanh: " prefix; news-like input gets "Vietnews: ".
    prefix = "Vietnews: " if text.lower().startswith(_NEWS_PREFIXES) else "Tin nhanh: "
    summary = await run_in_threadpool(_generate_summary, prefix + text)

    elapsed = time.perf_counter() - start_time
    logger.info(
        "[%s] ✅ Response sent — total time: %.2fs",
        time.strftime("%Y-%m-%d %H:%M:%S"),
        elapsed,
    )

    return {"summary": summary}