VietCat commited on
Commit
fd6737e
·
1 Parent(s): 831df6f

add time log and reduce processing time

Browse files
Files changed (1) hide show
  1. app.py +31 -37
app.py CHANGED
@@ -1,65 +1,59 @@
1
  import time
2
  import logging
3
- import torch
4
  from fastapi import FastAPI, Request
5
  from pydantic import BaseModel
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
- from concurrent.futures import ThreadPoolExecutor
8
- import asyncio
9
-
10
- # Khởi tạo app
11
- app = FastAPI()
12
 
13
- # Logging
14
  logging.basicConfig(level=logging.INFO)
 
15
 
16
- # Load model và tokenizer
17
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
- tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
19
- model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base").to(device)
20
 
21
- # Thread executor để xử lý blocking
22
- executor = ThreadPoolExecutor(max_workers=2)
 
 
 
23
 
24
- # Kiểu dữ liệu đầu vào
25
- class TextIn(BaseModel):
26
  text: str
27
 
28
- # -------------------------------
29
- # GET: kiểm tra API sẵn sàng
30
  @app.get("/")
31
- def read_root():
32
- return {"message": "API is ready."}
33
 
34
- # -------------------------------
35
- # Hàm tóm tắt (blocking)
36
- def summarize_text(text: str) -> str:
37
- prompt = "vietnews: " + text.strip() + " </s>"
38
- encoding = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
 
 
 
 
 
 
 
 
 
39
  input_ids = encoding["input_ids"].to(device)
40
  attention_mask = encoding["attention_mask"].to(device)
41
 
 
42
  outputs = model.generate(
43
  input_ids=input_ids,
44
  attention_mask=attention_mask,
45
  max_length=128,
46
  num_beams=2,
47
- early_stopping=True
 
 
48
  )
49
- return tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
50
-
51
- # -------------------------------
52
- # POST: async API tóm tắt
53
- @app.post("/summarize")
54
- async def summarize(request: Request, payload: TextIn):
55
- start_time = time.time()
56
- client_ip = request.client.host
57
- logging.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] 🔵 Received request from {client_ip}")
58
 
59
- summary = await asyncio.get_event_loop().run_in_executor(executor, summarize_text, payload.text)
60
 
61
  end_time = time.time()
62
- duration = end_time - start_time
63
- logging.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] ✅ Response sent — total time: {duration:.2f}s")
64
 
65
  return {"summary": summary}
 
1
  import time
2
  import logging
 
3
  from fastapi import FastAPI, Request
4
  from pydantic import BaseModel
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+ import torch
 
 
 
 
7
 
 
8
  logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
 
11
+ app = FastAPI()
 
 
 
12
 
13
+ # Load model and tokenizer
14
+ tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
15
+ model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base")
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ model.to(device)
18
 
19
+ class SummarizeRequest(BaseModel):
 
20
  text: str
21
 
 
 
22
  @app.get("/")
23
+ async def root():
24
+ return {"message": "Model is ready."}
25
 
26
+ @app.post("/summarize")
27
+ async def summarize(req: Request, body: SummarizeRequest):
28
+ start_time = time.time()
29
+ client_ip = req.client.host
30
+ logger.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] 🔵 Received request from {client_ip}")
31
+
32
+ text = body.text.strip()
33
+
34
+ # Tiền xử lý: nếu không giống tin tức thì thêm "Tin nhanh:"
35
+ if not text.lower().startswith(("theo", "trong khi", "bộ", "ngày", "việt nam", "công an")):
36
+ text = "Tin nhanh: " + text
37
+
38
+ input_text = text + " </s>"
39
+ encoding = tokenizer(input_text, return_tensors="pt")
40
  input_ids = encoding["input_ids"].to(device)
41
  attention_mask = encoding["attention_mask"].to(device)
42
 
43
+ # Sinh tóm tắt với cấu hình ổn định
44
  outputs = model.generate(
45
  input_ids=input_ids,
46
  attention_mask=attention_mask,
47
  max_length=128,
48
  num_beams=2,
49
+ early_stopping=True,
50
+ no_repeat_ngram_size=2,
51
+ num_return_sequences=1
52
  )
 
 
 
 
 
 
 
 
 
53
 
54
+ summary = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
55
 
56
  end_time = time.time()
57
+ logger.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] ✅ Response sent — total time: {end_time - start_time:.2f}s")
 
58
 
59
  return {"summary": summary}