VietCat commited on
Commit
831df6f
·
1 Parent(s): 4814cd0

add time log and reduce processing time

Browse files
Files changed (1) hide show
  1. app.py +41 -25
app.py CHANGED
@@ -1,49 +1,65 @@
 
 
 
1
  from fastapi import FastAPI, Request
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
- import torch
5
- import time
6
- import logging
7
 
 
8
  app = FastAPI()
9
 
10
- # Logging setup
11
  logging.basicConfig(level=logging.INFO)
12
- logger = logging.getLogger("summarizer")
13
 
14
- # Model & tokenizer
15
- MODEL_NAME = "VietAI/vit5-base-vietnews-summarization"
16
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
17
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
- model.to(device)
 
20
 
21
- class InputText(BaseModel):
 
 
 
 
22
  text: str
23
 
24
- @app.post("/summarize")
25
- async def summarize(req: Request, input: InputText):
26
- start_time = time.time()
27
- logger.info(f"\U0001F535 Received request from {req.client.host}")
 
28
 
29
- text = input.text.strip()
30
- inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
 
 
 
 
 
31
 
32
  outputs = model.generate(
33
- **inputs,
 
34
  max_length=128,
35
  num_beams=2,
36
- no_repeat_ngram_size=2,
37
  early_stopping=True
38
  )
39
- summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
40
 
41
  end_time = time.time()
42
  duration = end_time - start_time
43
- logger.info(f"\u2705 Response sent — total time: {duration:.2f}s")
44
 
45
  return {"summary": summary}
46
-
47
- @app.get("/")
48
- def root():
49
- return {"message": "Vietnamese Summarization API is up and running!"}
 
1
+ import time
2
+ import logging
3
+ import torch
4
  from fastapi import FastAPI, Request
5
  from pydantic import BaseModel
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ import asyncio
 
9
 
10
+ # Khởi tạo app
11
  app = FastAPI()
12
 
13
+ # Logging
14
  logging.basicConfig(level=logging.INFO)
 
15
 
16
+ # Load model tokenizer
 
 
 
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
19
+ model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base").to(device)
20
 
21
+ # Thread executor để xử lý blocking
22
+ executor = ThreadPoolExecutor(max_workers=2)
23
+
24
+ # Kiểu dữ liệu đầu vào
25
+ class TextIn(BaseModel):
26
  text: str
27
 
28
+ # -------------------------------
29
+ # GET: kiểm tra API sẵn sàng
30
+ @app.get("/")
31
+ def read_root():
32
+ return {"message": "API is ready."}
33
 
34
+ # -------------------------------
35
+ # Hàm tóm tắt (blocking)
36
+ def summarize_text(text: str) -> str:
37
+ prompt = "vietnews: " + text.strip() + " </s>"
38
+ encoding = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
39
+ input_ids = encoding["input_ids"].to(device)
40
+ attention_mask = encoding["attention_mask"].to(device)
41
 
42
  outputs = model.generate(
43
+ input_ids=input_ids,
44
+ attention_mask=attention_mask,
45
  max_length=128,
46
  num_beams=2,
 
47
  early_stopping=True
48
  )
49
+ return tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
50
+
51
+ # -------------------------------
52
+ # POST: async API tóm tắt
53
+ @app.post("/summarize")
54
+ async def summarize(request: Request, payload: TextIn):
55
+ start_time = time.time()
56
+ client_ip = request.client.host
57
+ logging.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] 🔵 Received request from {client_ip}")
58
+
59
+ summary = await asyncio.get_event_loop().run_in_executor(executor, summarize_text, payload.text)
60
 
61
  end_time = time.time()
62
  duration = end_time - start_time
63
+ logging.info(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] ✅ Response sent — total time: {duration:.2f}s")
64
 
65
  return {"summary": summary}