VietCat commited on
Commit
69847a0
·
1 Parent(s): 9597384

update logic based on official example

Browse files
Files changed (1) hide show
  1. app.py +30 -25
app.py CHANGED
@@ -1,40 +1,45 @@
1
- from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
-
5
  import torch
6
 
 
7
  app = FastAPI()
8
 
9
- # Load model
10
  model_name = "VietAI/vit5-base"
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
13
 
14
- # Input format
15
- class TextInput(BaseModel):
 
 
 
 
16
  text: str
17
 
18
  @app.get("/")
19
- def read_root():
20
- return {"message": "ViT5 summarization API is running!"}
21
 
22
  @app.post("/summarize")
23
- def summarize(input: TextInput):
24
- try:
25
- input_text = f"summarize: {input.text}"
26
- inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
27
- summary_ids = model.generate(
28
- inputs,
29
- max_length=128,
30
- min_length=20,
31
- num_beams=4,
32
- no_repeat_ngram_size=3,
33
- repetition_penalty=2.5,
34
- length_penalty=1.0,
35
- early_stopping=True
36
- )
37
- output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
38
- return {"summary": output}
39
- except Exception as e:
40
- raise HTTPException(status_code=500, detail=str(e))
 
 
1
+ from fastapi import FastAPI, Request
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
4
  import torch
5
 
6
+ # Khởi tạo FastAPI app
7
  app = FastAPI()
8
 
9
+ # Tải model và tokenizer
10
  model_name = "VietAI/vit5-base"
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
13
 
14
+ # Thiết bị (GPU nếu có, nếu không dùng CPU)
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ model.to(device)
17
+
18
+ # Schema cho input
19
+ class SummarizeInput(BaseModel):
20
  text: str
21
 
22
  @app.get("/")
23
+ async def root():
24
+ return {"message": "VietAI vit5-base summarization API is running."}
25
 
26
  @app.post("/summarize")
27
+ async def summarize(input: SummarizeInput):
28
+ prefix = "vietnews: "
29
+ text = prefix + input.text.strip() + " </s>"
30
+
31
+ # Tokenize và chuyển sang device
32
+ encoding = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
33
+ input_ids = encoding["input_ids"].to(device)
34
+ attention_mask = encoding["attention_mask"].to(device)
35
+
36
+ # Sinh tóm tắt
37
+ summary_ids = model.generate(
38
+ input_ids=input_ids,
39
+ attention_mask=attention_mask,
40
+ max_length=256,
41
+ early_stopping=True
42
+ )
43
+
44
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
45
+ return {"summary": summary}