VietCat committed
Commit c3ffcdd · 1 Parent(s): b9e0b8c

add time log and reduce processing time

Files changed (1):
  1. app.py (+29 -18)
app.py CHANGED
@@ -2,42 +2,53 @@ from fastapi import FastAPI, Request
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
+from datetime import datetime
+import time
 
 app = FastAPI()
 
-# Load model tokenizer
+# Load model and tokenizer
 model_name = "VietAI/vit5-base-vietnews-summarization"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
+model = model.to(device)
 
-# Define the input schema
-class SummaryRequest(BaseModel):
+class TextInput(BaseModel):
     text: str
 
 @app.get("/")
 def read_root():
-    return {"message": "VietAI viT5 summarization API is running."}
+    return {"message": "Summarization API is running"}
 
 @app.post("/summarize")
-def summarize(request: SummaryRequest):
-    text = request.text.strip()
-    if not text:
-        return {"summary": ""}
+async def summarize(input_text: TextInput, request: Request):
+    start_time = time.time()
+    print(f"[{datetime.now()}] 🔵 Received request from {request.client.host}")
 
-    prefix = "vietnews: " + text + " </s>"
-    encoding = tokenizer(prefix, return_tensors="pt", truncation=True, max_length=512)
+    text = input_text.text.strip()
+    prefix = "vietnews: "
+    input_text_prefixed = prefix + text + " </s>"
+
+    # Tokenize
+    encoding = tokenizer(input_text_prefixed, return_tensors="pt", truncation=True, max_length=512)
     input_ids = encoding["input_ids"].to(device)
     attention_mask = encoding["attention_mask"].to(device)
 
-    outputs = model.generate(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        max_length=128,   # concise summary
-        do_sample=False,  # no sampling
-        num_beams=1       # greedy decoding (fastest)
-    )
+    # Generate summary with optimized settings
+    with torch.inference_mode():
+        outputs = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_length=96,            # shorter length for faster processing
+            num_beams=1,              # use greedy decoding
+            no_repeat_ngram_size=2,
+            early_stopping=True
+        )
 
     summary = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+    end_time = time.time()
+    print(f"[{datetime.now()}] ✅ Response sent — total time: {end_time - start_time:.2f}s")
+
     return {"summary": summary}