VietCat commited on
Commit
4814cd0
·
1 Parent(s): c3ffcdd

add time log and reduce processing time

Browse files
Files changed (1) hide show
  1. app.py +31 -36
app.py CHANGED
@@ -2,53 +2,48 @@ from fastapi import FastAPI, Request
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import torch
5
- from datetime import datetime
6
  import time
 
7
 
8
  app = FastAPI()
9
 
10
- # Load model and tokenizer
11
- model_name = "VietAI/vit5-base-vietnews-summarization"
12
- tokenizer = AutoTokenizer.from_pretrained(model_name)
13
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
 
 
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- model = model.to(device)
16
 
17
- class TextInput(BaseModel):
18
  text: str
19
 
20
- @app.get("/")
21
- def read_root():
22
- return {"message": "Summarization API is running"}
23
-
24
  @app.post("/summarize")
25
- async def summarize(input_text: TextInput, request: Request):
26
  start_time = time.time()
27
- print(f"[{datetime.now()}] 🔵 Received request from {request.client.host}")
28
-
29
- text = input_text.text.strip()
30
- prefix = "vietnews: "
31
- input_text_prefixed = prefix + text + " </s>"
32
-
33
- # Tokenize
34
- encoding = tokenizer(input_text_prefixed, return_tensors="pt", truncation=True, max_length=512)
35
- input_ids = encoding["input_ids"].to(device)
36
- attention_mask = encoding["attention_mask"].to(device)
37
-
38
- # Generate summary with optimized settings
39
- with torch.inference_mode():
40
- outputs = model.generate(
41
- input_ids=input_ids,
42
- attention_mask=attention_mask,
43
- max_length=96, # giảm độ dài để xử lý nhanh hơn
44
- num_beams=1, # dùng greedy decoding
45
- no_repeat_ngram_size=2,
46
- early_stopping=True
47
- )
48
-
49
- summary = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
50
 
51
  end_time = time.time()
52
- print(f"[{datetime.now()}] Response sent — total time: {end_time - start_time:.2f}s")
 
53
 
54
  return {"summary": summary}
 
 
 
 
 
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import torch
 
5
  import time
6
+ import logging
7
 
8
  app = FastAPI()
9
 
10
+ # Logging setup
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger("summarizer")
13
+
14
+ # Model & tokenizer
15
+ MODEL_NAME = "VietAI/vit5-base-vietnews-summarization"
16
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
17
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ model.to(device)
20
 
21
+ class InputText(BaseModel):
22
  text: str
23
 
 
 
 
 
24
  @app.post("/summarize")
25
+ async def summarize(req: Request, input: InputText):
26
  start_time = time.time()
27
+ logger.info(f"\U0001F535 Received request from {req.client.host}")
28
+
29
+ text = input.text.strip()
30
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
31
+
32
+ outputs = model.generate(
33
+ **inputs,
34
+ max_length=128,
35
+ num_beams=2,
36
+ no_repeat_ngram_size=2,
37
+ early_stopping=True
38
+ )
39
+ summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
40
 
41
  end_time = time.time()
42
+ duration = end_time - start_time
43
+ logger.info(f"\u2705 Response sent — total time: {duration:.2f}s")
44
 
45
  return {"summary": summary}
46
+
47
+ @app.get("/")
48
+ def root():
49
+ return {"message": "Vietnamese Summarization API is up and running!"}