VietCat committed on
Commit a9b7eee · 1 Parent(s): 4d593bf

fix broken encoding text issue

Files changed (1)
  1. app.py +3 -4
app.py CHANGED
@@ -2,13 +2,12 @@ import os
 from flask import Flask, request, jsonify
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
-# Set a valid cache directory for Hugging Face
+# Declare a safe cache directory for Hugging Face
 os.environ["HF_HOME"] = "/app/cache"
 os.environ["TRANSFORMERS_CACHE"] = "/app/cache/transformers"
 
 app = Flask(__name__)
 
-# Load the model and tokenizer
 model_name = "VietAI/vit5-base"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
@@ -21,11 +20,10 @@ def summarize():
     if not text:
         return jsonify({"error": "Missing 'text' field"}), 400
 
-    # ✅ Add the prefix matching the training format
+    # ✅ Very important: add the 'summarize:' prefix
     prompt = f"summarize: {text}"
     inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
 
-    # Generate with optimized parameters
     summary_ids = model.generate(
         inputs,
         max_length=100,
@@ -36,6 +34,7 @@ def summarize():
         length_penalty=1.0,
         early_stopping=True
     )
+
     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
     return jsonify({"summary": summary})
 
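For reference, a minimal smoke test for the changed endpoint. It assumes, since neither appears in this diff, that summarize() is bound to POST /summarize, that the app runs on Flask's default port 5000, and that the request body is JSON carrying the 'text' field implied by the error message above:

import requests

# Hypothetical local test; the route path and port are assumptions, not part of the diff.
resp = requests.post(
    "http://localhost:5000/summarize",
    json={"text": "Văn bản tiếng Việt cần tóm tắt."},
    timeout=120,  # first call may be slow while the model downloads into /app/cache
)
print(resp.status_code)  # 200 on success, 400 when 'text' is missing
print(resp.json())       # {"summary": "..."}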