VietCat committed on
Commit
29e22ca
·
1 Parent(s): 6f583aa

Fix repeated text in generated summaries

Browse files
Files changed (1) hide show
  1. app.py +20 -9
app.py CHANGED
@@ -1,35 +1,46 @@
 
1
  from flask import Flask, request, jsonify
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
 
 
 
 
4
  app = Flask(__name__)
5
 
6
- # Load model
7
  model_name = "VietAI/vit5-base"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
 
11
  @app.route("/summarize", methods=["POST"])
12
  def summarize():
13
- data = request.json
14
- text = data.get("text", "")
15
- if not text.strip():
16
- return jsonify({"error": "Missing text"}), 400
 
17
 
 
18
  inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
 
 
19
  summary_ids = model.generate(
20
  inputs,
21
  max_length=100,
22
- min_length=30,
23
  num_beams=4,
24
- length_penalty=2.0,
 
 
25
  early_stopping=True
26
  )
27
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
28
  return jsonify({"summary": summary})
29
 
30
  @app.route("/", methods=["GET"])
31
- def root():
32
- return "ViT5 summarization API is running."
33
 
34
  if __name__ == "__main__":
35
  app.run(host="0.0.0.0", port=7860)
 
1
+ import os
2
  from flask import Flask, request, jsonify
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
+ # ⚙️ Khắc phục lỗi không ghi được cache khi deploy trên HFS
6
+ os.environ["HF_HOME"] = "/app/cache"
7
+ os.environ["TRANSFORMERS_CACHE"] = "/app/cache/transformers"
8
+
9
  app = Flask(__name__)
10
 
11
+ # 🚀 Load mô hình
12
  model_name = "VietAI/vit5-base"
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
15
 
16
  @app.route("/summarize", methods=["POST"])
17
  def summarize():
18
+ data = request.get_json()
19
+ text = data.get("text", "").strip()
20
+
21
+ if not text:
22
+ return jsonify({"error": "Missing 'text' field"}), 400
23
 
24
+ # ⚠️ Giới hạn đầu vào (ViT5-base tối đa 512 tokens)
25
  inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
26
+
27
+ # ✅ Tham số sinh văn bản chống lặp + chất lượng cao
28
  summary_ids = model.generate(
29
  inputs,
30
  max_length=100,
31
+ min_length=10,
32
  num_beams=4,
33
+ no_repeat_ngram_size=3,
34
+ repetition_penalty=2.5,
35
+ length_penalty=1.0,
36
  early_stopping=True
37
  )
38
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
39
  return jsonify({"summary": summary})
40
 
41
  @app.route("/", methods=["GET"])
42
+ def index():
43
+ return "ViT5 summarization API is running."
44
 
45
  if __name__ == "__main__":
46
  app.run(host="0.0.0.0", port=7860)