VietCat committed
Commit 4159c4a · 1 Parent(s): 69847a0

update logic based on official example

Files changed (1)
  1. app.py (+16 -18)

app.py CHANGED
@@ -3,43 +3,41 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 
-# Initialize the FastAPI app
 app = FastAPI()
 
-# Load the model and tokenizer
+# Load the model and tokenizer
 model_name = "VietAI/vit5-base"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-
-# Device (GPU if available, otherwise CPU)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
-# Input schema
-class SummarizeInput(BaseModel):
+# Define the input schema
+class SummaryRequest(BaseModel):
     text: str
 
 @app.get("/")
-async def root():
-    return {"message": "VietAI vit5-base summarization API is running."}
+def read_root():
+    return {"message": "VietAI viT5 summarization API is running."}
 
 @app.post("/summarize")
-async def summarize(input: SummarizeInput):
-    prefix = "vietnews: "
-    text = prefix + input.text.strip() + " </s>"
+def summarize(request: SummaryRequest):
+    text = request.text.strip()
+    if not text:
+        return {"summary": ""}
 
-    # Tokenize and move to the device
-    encoding = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
+    prefix = "vietnews: " + text + " </s>"
+    encoding = tokenizer(prefix, return_tensors="pt", truncation=True, max_length=512)
     input_ids = encoding["input_ids"].to(device)
     attention_mask = encoding["attention_mask"].to(device)
 
-    # Generate the summary
-    summary_ids = model.generate(
+    outputs = model.generate(
         input_ids=input_ids,
         attention_mask=attention_mask,
-        max_length=256,
-        early_stopping=True
+        max_length=128,   # keep the summary concise
+        do_sample=False,  # no sampling
+        num_beams=1       # greedy decoding (fastest)
     )
 
-    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+    summary = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
     return {"summary": summary}
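For reference, a minimal client-side sketch of how the updated /summarize endpoint could be called once the app is running (for example served with uvicorn). The base URL, port, timeout, and sample text below are illustrative placeholders, not part of this commit:

import requests

# Placeholder base URL; point this at wherever the FastAPI app is actually served.
BASE_URL = "http://localhost:7860"

# Placeholder Vietnamese article text to summarize.
payload = {"text": "Nội dung bài báo tiếng Việt cần tóm tắt."}

resp = requests.post(f"{BASE_URL}/summarize", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["summary"])

Because the new code calls model.generate with num_beams=1 and do_sample=False, the endpoint returns a deterministic, greedy-decoded summary of at most 128 tokens, which keeps latency low on CPU.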