VietCat committed
Commit c5a0bf8 · 1 Parent(s): a9b7eee

switch to fastapi

Files changed (3)
  1. Dockerfile +1 -2
  2. app.py +32 -38
  3. requirements.txt +3 -3
Dockerfile CHANGED
@@ -18,5 +18,4 @@ COPY app.py .
  # Expose the default HFS port (7860)
  EXPOSE 7860

- # Run Flask
- CMD ["python", "app.py"]
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
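
For local debugging outside the container, the same server start can be reproduced from Python. A minimal sketch, assuming app.py exposes the FastAPI instance as app (the helper file name local_run.py is hypothetical, not part of this commit):

# local_run.py - hypothetical helper, not part of this commit.
# Mirrors the Dockerfile CMD: serve the FastAPI app from app.py on port 7860.
import uvicorn

if __name__ == "__main__":
    # "app:app" = module app.py, attribute app; 7860 matches the EXPOSE above.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)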
 
app.py CHANGED
@@ -1,46 +1,40 @@
- import os
- from flask import Flask, request, jsonify
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

- # Declare a safe cache directory for Hugging Face
- os.environ["HF_HOME"] = "/app/cache"
- os.environ["TRANSFORMERS_CACHE"] = "/app/cache/transformers"
+ import torch

- app = Flask(__name__)
+ app = FastAPI()

+ # Load model
  model_name = "VietAI/vit5-base"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

- @app.route("/summarize", methods=["POST"])
- def summarize():
-     data = request.get_json()
-     text = data.get("text", "").strip()
-
-     if not text:
-         return jsonify({"error": "Missing 'text' field"}), 400
-
-     # ✅ Very important: add the 'summarize:' prefix
-     prompt = f"summarize: {text}"
-     inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
-
-     summary_ids = model.generate(
-         inputs,
-         max_length=100,
-         min_length=10,
-         num_beams=4,
-         no_repeat_ngram_size=3,
-         repetition_penalty=2.5,
-         length_penalty=1.0,
-         early_stopping=True
-     )
-
-     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-     return jsonify({"summary": summary})
-
- @app.route("/", methods=["GET"])
- def index():
-     return "✅ ViT5 summarization API is running."
-
- if __name__ == "__main__":
-     app.run(host="0.0.0.0", port=7860)
+ # Input format
+ class TextInput(BaseModel):
+     text: str
+
+ @app.get("/")
+ def read_root():
+     return {"message": "ViT5 summarization API is running!"}
+
+ @app.post("/summarize")
+ def summarize(input: TextInput):
+     try:
+         input_text = f"summarize: {input.text}"
+         inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
+         summary_ids = model.generate(
+             inputs,
+             max_length=128,
+             min_length=20,
+             num_beams=4,
+             no_repeat_ngram_size=3,
+             repetition_penalty=2.5,
+             length_penalty=1.0,
+             early_stopping=True
+         )
+         output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+         return {"summary": output}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
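
Once the Space is up, the new endpoint can be exercised with a small client. A minimal sketch, assuming the service is reachable at http://localhost:7860 (the URL and sample text are placeholders, not part of this commit):

# client_example.py - hypothetical client, not part of this commit.
# Sends one document to POST /summarize and prints the returned summary.
import requests

resp = requests.post(
    "http://localhost:7860/summarize",  # placeholder URL; use the Space URL in practice
    json={"text": "Paste the text to summarize here."},  # must match the TextInput schema
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["summary"])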
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,4 @@
- flask
- transformers
+ transformers==4.41.2
  torch
- sentencepiece
+ fastapi
+ uvicorn
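
Before rebuilding the image, a quick import check confirms the pinned stack resolves together. A small sketch (the script name and version printout are illustrative only):

# check_env.py - hypothetical smoke test, not part of this commit.
# Verifies the requirements.txt packages import and reports their versions.
import fastapi
import torch
import transformers
import uvicorn

print("transformers:", transformers.__version__)  # expected 4.41.2 per the pin above
print("torch:", torch.__version__)
print("fastapi:", fastapi.__version__)
print("uvicorn:", uvicorn.__version__)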