sunbv56 commited on
Commit
7936364
·
verified ·
1 Parent(s): b8ddda5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -102
app.py CHANGED
@@ -1,133 +1,124 @@
1
- # app.py
2
 
3
  import gradio as gr
4
  import torch
5
  from PIL import Image
6
- from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
7
  from gradio.events import SelectData
8
  import warnings
9
  import os
10
  import requests
 
11
 
12
  warnings.filterwarnings("ignore", category=UserWarning, message="Overriding torch_dtype=None")
13
 
14
- # --- 1. Tải Model và Processor (ĐÃ TỐI ƯU) ---
15
  MODEL_ID = "sunbv56/qwen2.5-vl-vqa-vibook"
16
- print(f"🚀 Đang tải model '{MODEL_ID}' và processor với các tối ưu hóa...")
17
-
18
- # ### THAY ĐỔI TỐI ƯU 1: Cấu hình Lượng tử hóa 4-bit (Quantization) ###
19
- # Sử dụng 4-bit quantization để tăng tốc độ inference giảm VRAM
20
- quantization_config = BitsAndBytesConfig(
21
- load_in_4bit=True,
22
- bnb_4bit_quant_type="nf4",
23
- bnb_4bit_compute_dtype=torch.bfloat16
24
- )
25
-
26
  try:
27
- # ### THAY ĐỔI TỐI ƯU 2: Tải model với Quantization và Flash Attention 2 ###
28
  model = AutoModelForImageTextToText.from_pretrained(
29
- MODEL_ID,
30
- device_map="auto",
31
  trust_remote_code=True,
32
- quantization_config=quantization_config,
33
- # attn_implementation="flash_attention_2" # Bỏ comment dòng này nếu bạn có GPU tương thích (NVIDIA 30xx/40xx) và đã cài flash-attn thành công
 
 
34
  )
35
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=True)
36
 
37
- # ### THAY ĐỔI TỐI ƯU 3: Biên dịch model với torch.compile ###
38
- # Lần chạy đầu tiên sẽ mất chút thời gian để biên dịch, nhưng các lần sau sẽ rất nhanh.
39
- # Chỉ hoạt động trên Linux/MacOS với PyTorch 2.0+GPU.
40
- try:
41
- print("🚀 Đang cố gắng biên dịch model với torch.compile()...")
42
- model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
43
- print("✅ Biên dịch model thành công!")
44
- except Exception as e:
45
- print(f"⚠️ Không thể biên dịch model: {e}. Chạy ở chế độ thông thường.")
46
 
47
  model.eval()
48
- print(f"✅ Model và processor đã được tải và tối ưu thành công!")
49
-
50
  except Exception as e:
51
- print(f"❌ Lỗi khi tải model/processor: {e}")
52
  exit()
53
 
54
- # --- 2. Hàm Inference Cốt lõi (Không cần thay đổi) ---
55
- # Các tối ưu đã được áp dụng ở tầng model, nên hàm này sẽ tự động chạy nhanh hơn.
56
- def process_vqa(image: Image.Image, question: str):
57
- if image.mode != "RGB":
58
- image = image.convert("RGB")
59
- messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}]
60
- prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
61
- model_inputs = processor(text=[prompt_text], images=[image], return_tensors="pt").to(model.device)
62
-
63
- # Sử dụng torch.no_grad() để tắt việc tính toán gradient, giúp tiết kiệm bộ nhớ và tăng tốc độ
64
- with torch.no_grad():
65
- generated_ids = model.generate(
66
- **model_inputs,
67
- max_new_tokens=128,
68
- do_sample=False,
69
- temperature=1.0,
70
- eos_token_id=processor.tokenizer.eos_token_id,
71
- pad_token_id=processor.tokenizer.pad_token_id
72
- )
 
 
 
 
 
 
73
 
 
74
  generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:]
75
- response = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
76
- return response
77
-
78
- # --- 3. Logic Chatbot (Giữ nguyên) ---
79
- THINKING_HTML = """
80
- <div class="typing-indicator">
81
- <span></span>
82
- <span></span>
83
- <span></span>
84
- </div>
85
- """
86
- CUSTOM_CSS = """
87
- @keyframes blink {
88
- 0% { opacity: .2; }
89
- 20% { opacity: 1; }
90
- 100% { opacity: .2; }
91
- }
92
- .typing-indicator {
93
- display: flex;
94
- align-items: center;
95
- justify-content: flex-start;
96
- padding: 8px 0;
97
- }
98
- .typing-indicator span {
99
- height: 10px;
100
- width: 10px;
101
- margin: 0 2px;
102
- background-color: #9E9E9E;
103
- border-radius: 50%;
104
- animation: blink 1.4s infinite both;
105
- }
106
- .typing-indicator span:nth-child(2) {
107
- animation-delay: .2s;
108
- }
109
- .typing-indicator span:nth-child(3) {
110
- animation-delay: .4s;
111
- }
112
- """
113
-
114
- def manual_chat_responder(user_question: str, chat_history: list, uploaded_image: Image.Image):
115
  if uploaded_image is None:
116
  gr.Warning("Vui lòng tải ảnh lên trước để đặt câu hỏi về nó.")
117
- return "", chat_history
118
  if not user_question or not user_question.strip():
119
  gr.Warning("Vui lòng nhập một câu hỏi.")
120
- return "", chat_history
121
-
122
  chat_history.append({"role": "user", "content": user_question})
123
  chat_history.append({"role": "assistant", "content": THINKING_HTML})
124
- yield "", chat_history
125
-
126
- bot_response = process_vqa(uploaded_image, user_question)
 
 
 
 
127
 
128
  chat_history[-1]["content"] = bot_response
129
- yield "", chat_history
 
 
130
 
 
 
131
  def run_example(evt: SelectData):
132
  selected_example = example_list[evt.index]
133
  image_path, question = selected_example
@@ -140,7 +131,7 @@ def run_example(evt: SelectData):
140
  ]
141
  yield image, question, chat_history
142
 
143
- bot_response = process_vqa(image, question)
144
 
145
  chat_history[-1]["content"] = bot_response
146
  yield image, question, chat_history
@@ -148,9 +139,9 @@ def run_example(evt: SelectData):
148
  def clear_chat():
149
  return []
150
 
151
- # --- 4. Giao diện Người dùng Gradio (Giữ nguyên) ---
152
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), title="Vibook VQA Chatbot", css=CUSTOM_CSS) as demo:
153
- gr.Markdown("# 🤖 Vibook VQA Chatbot")
154
 
155
  example_list = [
156
  ["./assets/book_example_1.jpg", "Đâu là tên đúng của cuốn sách này?"],
@@ -169,13 +160,20 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), ti
169
  chatbot = gr.Chatbot(label="Cuộc trò chuyện", height=600, avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png"), type="messages", value=[])
170
  question_input = gr.Textbox(label="Hoặc nhập câu hỏi về ảnh đã tải lên", placeholder="Nhập câu hỏi và nhấn Enter...", container=False, scale=7)
171
 
172
- question_input.submit(fn=manual_chat_responder, inputs=[question_input, chatbot, image_input], outputs=[question_input, chatbot])
 
 
 
 
 
 
173
  example_dataset.select(fn=run_example, inputs=None, outputs=[image_input, question_input, chatbot], show_progress="full")
 
174
  image_input.upload(fn=clear_chat, inputs=None, outputs=[chatbot])
175
  image_input.clear(fn=clear_chat, inputs=None, outputs=[chatbot])
176
 
177
- # --- Phần cuối (Giữ nguyên) ---
178
- if __name__ == "__main__":
179
  ASSETS_DIR = "assets"
180
  if not os.path.exists(ASSETS_DIR):
181
  os.makedirs(ASSETS_DIR)
@@ -200,4 +198,6 @@ if __name__ == "__main__":
200
  except requests.exceptions.RequestException as e:
201
  print(f" Lỗi khi tải {filename}: {e}")
202
 
 
 
203
  demo.launch(debug=True)
 
1
+ # app_optimized.py
2
 
3
  import gradio as gr
4
  import torch
5
  from PIL import Image
6
+ from transformers import AutoModelForImageTextToText, AutoProcessor
7
  from gradio.events import SelectData
8
  import warnings
9
  import os
10
  import requests
11
+ from typing import List
12
 
13
  warnings.filterwarnings("ignore", category=UserWarning, message="Overriding torch_dtype=None")
14
 
15
+ # --- 1. Tải Model và Processor với TỐI ƯU HÓA ---
16
  MODEL_ID = "sunbv56/qwen2.5-vl-vqa-vibook"
17
+ print(f"🚀 Đang tải model '{MODEL_ID}' và processor...")
18
+
19
+ # *** TỐI ƯU HÓA 1: Lượng tử hóa (Quantization) ***
20
+ # Sử dụng `load_in_8bit=True` để tăng tốc đáng kể trên CPU.
21
+ # Yêu cầu `pip install bitsandbytes accelerate`
22
+ # Lưu ý: Lượng tử hóa sẽ không dùng `torch_dtype` vì nó hoạt động trên các kiểu dữ liệu khác.
23
+ # `device_map="auto"` sẽ tự động xử lý việc đặt model lên thiết bị.
24
+ use_gpu = torch.cuda.is_available()
 
 
25
  try:
 
26
  model = AutoModelForImageTextToText.from_pretrained(
27
+ MODEL_ID,
28
+ device_map="auto",
29
  trust_remote_code=True,
30
+ # Chỉ lượng tử hóa khi chạy trên CPU để tiết kiệm tài nguyên và tăng tốc
31
+ load_in_8bit=not use_gpu,
32
+ # Nếu có GPU, sử dụng bfloat16/float16 để có hiệu năng tốt nhất
33
+ torch_dtype=torch.bfloat16 if use_gpu and torch.cuda.is_bf16_supported() else torch.float16
34
  )
35
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=True)
36
 
37
+ # *** TỐI ƯU HÓA 2: Sử dụng torch.compile() (cho PyTorch 2.0+) ***
38
+ # Biên dịch model để tăng tốc độ inference sau lần chạy đầu tiên.
39
+ # Chế độ 'reduce-overhead' tốt cho các input nhỏgiảm gánh nặng của framework.
40
+ print("🚀 Đang biên dịch model với torch.compile()...")
41
+ model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
42
+ print("✅ Model đã được biên dịch.")
 
 
 
43
 
44
  model.eval()
45
+ print(f"✅ Model và processor đã được tải và tối ưu hóa thành công!")
 
46
  except Exception as e:
47
+ print(f"❌ Lỗi khi tải/tối ưu hóa model/processor: {e}")
48
  exit()
49
 
50
+ # --- 2. Hàm Inference Cốt lõi đã được sửa đổi để xử lý BATCH ---
51
+ def process_vqa_batch(images: List[Image.Image], questions: List[str]):
52
+ prompts = []
53
+ processed_images = []
54
+
55
+ # Chuẩn bị prompt ảnh cho từng item trong batch
56
+ for image, question in zip(images, questions):
57
+ if image.mode != "RGB":
58
+ image = image.convert("RGB")
59
+ messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}]
60
+ prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
61
+ prompts.append(prompt_text)
62
+ processed_images.append(image)
63
+
64
+ # Xử lý cả batch cùng một lúc
65
+ model_inputs = processor(text=prompts, images=processed_images, return_tensors="pt", padding=True).to(model.device)
66
+
67
+ generated_ids = model.generate(
68
+ **model_inputs,
69
+ max_new_tokens=128,
70
+ do_sample=False,
71
+ temperature=1.0,
72
+ eos_token_id=processor.tokenizer.eos_token_id,
73
+ pad_token_id=processor.tokenizer.pad_token_id
74
+ )
75
 
76
+ # Decode kết quả cho cả batch
77
  generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:]
78
+ responses = processor.batch_decode(generated_ids, skip_special_tokens=True)
79
+
80
+ # Strip() cho mỗi response trong list
81
+ return [res.strip() for res in responses]
82
+
83
+
84
+ # --- 3. Logic Chatbot ---
85
+ THINKING_HTML = """<div class="typing-indicator"><span></span><span></span><span></span></div>"""
86
+ CUSTOM_CSS = """@keyframes blink{0%{opacity:.2}20%{opacity:1}100%{opacity:.2}}.typing-indicator{display:flex;align-items:center;justify-content:flex-start;padding:8px 0}.typing-indicator span{height:10px;width:10px;margin:0 2px;background-color:#9E9E9E;border-radius:50%;animation:blink 1.4s infinite both}.typing-indicator span:nth-child(2){animation-delay:.2s}.typing-indicator span:nth-child(3){animation-delay:.4s}"""
87
+
88
+ # *** TỐI ƯU HÓA 3: Sửa đổi hàm để tương thích với BATCHING của Gradio ***
89
+ # Hàm này giờ nhận vào một list các câu hỏi và trả về một list các câu trả lời
90
+ def manual_chat_responder(user_questions: List[str], chat_histories: List[list], uploaded_images: List[Image.Image]):
91
+ # Do cách Gradio batching hoạt động, chúng ta chỉ lấy item đầu tiên
92
+ # mỗi người dùng có một giao diện riêng biệt.
93
+ # Tuy nhiên, hàm process_vqa_batch vẫn được thiết kế để xử lý batch thực sự.
94
+ user_question = user_questions[0]
95
+ chat_history = chat_histories[0]
96
+ uploaded_image = uploaded_images[0]
97
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  if uploaded_image is None:
99
  gr.Warning("Vui lòng tải ảnh lên trước để đặt câu hỏi về nó.")
100
+ return [("", chat_history)] # Phải trả về list
101
  if not user_question or not user_question.strip():
102
  gr.Warning("Vui lòng nhập một câu hỏi.")
103
+ return [("", chat_history)] # Phải trả về list
104
+
105
  chat_history.append({"role": "user", "content": user_question})
106
  chat_history.append({"role": "assistant", "content": THINKING_HTML})
107
+
108
+ # Tạm thời yield để cập nhật UI
109
+ # Gradio batching không hỗ trợ yield trực tiếp, nên chúng ta sẽ bỏ qua bước này
110
+ # và trả về kết quả cuối cùng. Người dùng sẽ thấy indicator trong một khoảng thời gian ngắn.
111
+
112
+ # Gọi hàm xử lý batch (dù chỉ có 1 item)
113
+ bot_response = process_vqa_batch([uploaded_image], [user_question])[0]
114
 
115
  chat_history[-1]["content"] = bot_response
116
+
117
+ # Phải trả về một list các kết quả, tương ứng với batch đầu vào
118
+ return [("", chat_history)]
119
 
120
+
121
+ # Hàm chạy ví dụ không cần batching vì nó chỉ là một hành động đơn lẻ
122
  def run_example(evt: SelectData):
123
  selected_example = example_list[evt.index]
124
  image_path, question = selected_example
 
131
  ]
132
  yield image, question, chat_history
133
 
134
+ bot_response = process_vqa_batch([image], [question])[0]
135
 
136
  chat_history[-1]["content"] = bot_response
137
  yield image, question, chat_history
 
139
  def clear_chat():
140
  return []
141
 
142
+ # --- 4. Định nghĩa Giao diện Người dùng Gradio ---
143
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), title="Vibook VQA Chatbot", css=CUSTOM_CSS) as demo:
144
+ gr.Markdown("# 🤖 Vibook VQA Chatbot (Optimized)")
145
 
146
  example_list = [
147
  ["./assets/book_example_1.jpg", "Đâu là tên đúng của cuốn sách này?"],
 
160
  chatbot = gr.Chatbot(label="Cuộc trò chuyện", height=600, avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png"), type="messages", value=[])
161
  question_input = gr.Textbox(label="Hoặc nhập câu hỏi về ảnh đã tải lên", placeholder="Nhập câu hỏi và nhấn Enter...", container=False, scale=7)
162
 
163
+ # --- 5. Xử Sự kiện với TỐI ƯU HÓA BATCHING ---
164
+ question_input.submit(
165
+ fn=manual_chat_responder,
166
+ inputs=[question_input, chatbot, image_input],
167
+ outputs=[question_input, chatbot]
168
+ ).batch(batch_size=4, max_latency=0.1) # Gom tối đa 4 request, hoặc xử lý sau mỗi 0.1 giây
169
+
170
  example_dataset.select(fn=run_example, inputs=None, outputs=[image_input, question_input, chatbot], show_progress="full")
171
+
172
  image_input.upload(fn=clear_chat, inputs=None, outputs=[chatbot])
173
  image_input.clear(fn=clear_chat, inputs=None, outputs=[chatbot])
174
 
175
+ # --- Phần cuối ---
176
+ def setup_examples():
177
  ASSETS_DIR = "assets"
178
  if not os.path.exists(ASSETS_DIR):
179
  os.makedirs(ASSETS_DIR)
 
198
  except requests.exceptions.RequestException as e:
199
  print(f" Lỗi khi tải {filename}: {e}")
200
 
201
+ if __name__ == "__main__":
202
+ setup_examples()
203
  demo.launch(debug=True)