sunbv56 commited on
Commit
e4dfcbd
·
verified ·
1 Parent(s): a2a8d8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -32
app.py CHANGED
@@ -3,7 +3,7 @@
3
  import gradio as gr
4
  import torch
5
  from PIL import Image
6
- from transformers import AutoModelForImageTextToText, AutoProcessor
7
  from gradio.events import SelectData
8
  import warnings
9
  import os
@@ -11,20 +11,48 @@ import requests
11
 
12
  warnings.filterwarnings("ignore", category=UserWarning, message="Overriding torch_dtype=None")
13
 
14
- # --- 1. Tải Model và Processor ---
15
  MODEL_ID = "sunbv56/qwen2.5-vl-vqa-vibook"
16
- print(f"🚀 Đang tải model '{MODEL_ID}' và processor...")
 
 
 
 
 
 
 
 
 
17
  try:
18
- dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
19
- model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, torch_dtype=dtype, device_map="auto", trust_remote_code=True)
 
 
 
 
 
 
20
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=True)
 
 
 
 
 
 
 
 
 
 
 
21
  model.eval()
22
- print(f"✅ Model và processor đã được tải thành công!")
 
23
  except Exception as e:
24
  print(f"❌ Lỗi khi tải model/processor: {e}")
25
  exit()
26
 
27
- # --- 2. Hàm Inference Cốt lõi ---
 
28
  def process_vqa(image: Image.Image, question: str):
29
  if image.mode != "RGB":
30
  image = image.convert("RGB")
@@ -32,22 +60,22 @@ def process_vqa(image: Image.Image, question: str):
32
  prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
33
  model_inputs = processor(text=[prompt_text], images=[image], return_tensors="pt").to(model.device)
34
 
35
- generated_ids = model.generate(
36
- **model_inputs,
37
- max_new_tokens=128,
38
- do_sample=False,
39
- temperature=1.0,
40
- eos_token_id=processor.tokenizer.eos_token_id,
41
- pad_token_id=processor.tokenizer.pad_token_id
42
- )
 
 
43
 
44
  generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:]
45
  response = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
46
  return response
47
 
48
- # --- 3. Logic Chatbot ---
49
- # ### THAY ĐỔI MỚI 1: Định nghĩa HTML và CSS cho hiệu ứng động ###
50
- # HTML cho hiệu ứng "đang gõ"
51
  THINKING_HTML = """
52
  <div class="typing-indicator">
53
  <span></span>
@@ -55,7 +83,6 @@ THINKING_HTML = """
55
  <span></span>
56
  </div>
57
  """
58
- # CSS để tạo hiệu ứng
59
  CUSTOM_CSS = """
60
  @keyframes blink {
61
  0% { opacity: .2; }
@@ -65,14 +92,14 @@ CUSTOM_CSS = """
65
  .typing-indicator {
66
  display: flex;
67
  align-items: center;
68
- justify-content: flex-start; /* Căn trái */
69
- padding: 8px 0; /* Thêm chút khoảng đệm */
70
  }
71
  .typing-indicator span {
72
  height: 10px;
73
  width: 10px;
74
  margin: 0 2px;
75
- background-color: #9E9E9E; /* Màu xám */
76
  border-radius: 50%;
77
  animation: blink 1.4s infinite both;
78
  }
@@ -84,7 +111,6 @@ CUSTOM_CSS = """
84
  }
85
  """
86
 
87
- # Hàm dành cho việc người dùng tự nhập câu hỏi
88
  def manual_chat_responder(user_question: str, chat_history: list, uploaded_image: Image.Image):
89
  if uploaded_image is None:
90
  gr.Warning("Vui lòng tải ảnh lên trước để đặt câu hỏi về nó.")
@@ -94,7 +120,6 @@ def manual_chat_responder(user_question: str, chat_history: list, uploaded_image
94
  return "", chat_history
95
 
96
  chat_history.append({"role": "user", "content": user_question})
97
- # ### THAY ĐỔI MỚI 2: Sử dụng HTML động thay cho text tĩnh ###
98
  chat_history.append({"role": "assistant", "content": THINKING_HTML})
99
  yield "", chat_history
100
 
@@ -103,14 +128,12 @@ def manual_chat_responder(user_question: str, chat_history: list, uploaded_image
103
  chat_history[-1]["content"] = bot_response
104
  yield "", chat_history
105
 
106
- # Hàm dành riêng cho việc xử lý khi nhấn vào ví dụ
107
  def run_example(evt: SelectData):
108
  selected_example = example_list[evt.index]
109
  image_path, question = selected_example
110
  gr.Info(f"Đang chạy ví dụ: \"{question}\"")
111
  image = Image.open(image_path).convert("RGB")
112
 
113
- # ### THAY ĐỔI MỚI 3: Sử dụng HTML động thay cho text tĩnh ###
114
  chat_history = [
115
  {"role": "user", "content": question},
116
  {"role": "assistant", "content": THINKING_HTML}
@@ -125,8 +148,7 @@ def run_example(evt: SelectData):
125
  def clear_chat():
126
  return []
127
 
128
- # --- 4. Định nghĩa Giao diện Người dùng Gradio ---
129
- # ### THAY ĐỔI MỚI 4: Thêm CSS vào Blocks ###
130
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), title="Vibook VQA Chatbot", css=CUSTOM_CSS) as demo:
131
  gr.Markdown("# 🤖 Vibook VQA Chatbot")
132
 
@@ -147,15 +169,12 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), ti
147
  chatbot = gr.Chatbot(label="Cuộc trò chuyện", height=600, avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png"), type="messages", value=[])
148
  question_input = gr.Textbox(label="Hoặc nhập câu hỏi về ảnh đã tải lên", placeholder="Nhập câu hỏi và nhấn Enter...", container=False, scale=7)
149
 
150
- # --- 5. Xử lý Sự kiện ---
151
  question_input.submit(fn=manual_chat_responder, inputs=[question_input, chatbot, image_input], outputs=[question_input, chatbot])
152
-
153
  example_dataset.select(fn=run_example, inputs=None, outputs=[image_input, question_input, chatbot], show_progress="full")
154
-
155
  image_input.upload(fn=clear_chat, inputs=None, outputs=[chatbot])
156
  image_input.clear(fn=clear_chat, inputs=None, outputs=[chatbot])
157
 
158
- # --- Phần cuối ---
159
  if __name__ == "__main__":
160
  ASSETS_DIR = "assets"
161
  if not os.path.exists(ASSETS_DIR):
 
3
  import gradio as gr
4
  import torch
5
  from PIL import Image
6
+ from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
7
  from gradio.events import SelectData
8
  import warnings
9
  import os
 
11
 
12
  warnings.filterwarnings("ignore", category=UserWarning, message="Overriding torch_dtype=None")
13
 
14
+ # --- 1. Tải Model và Processor (ĐÃ TỐI ƯU) ---
15
  MODEL_ID = "sunbv56/qwen2.5-vl-vqa-vibook"
16
+ print(f"🚀 Đang tải model '{MODEL_ID}' và processor với các tối ưu hóa...")
17
+
18
+ # ### THAY ĐỔI TỐI ƯU 1: Cấu hình Lượng tử hóa 4-bit (Quantization) ###
19
+ # Sử dụng 4-bit quantization để tăng tốc độ inference và giảm VRAM
20
+ quantization_config = BitsAndBytesConfig(
21
+ load_in_4bit=True,
22
+ bnb_4bit_quant_type="nf4",
23
+ bnb_4bit_compute_dtype=torch.bfloat16
24
+ )
25
+
26
  try:
27
+ # ### THAY ĐỔI TỐI ƯU 2: Tải model với Quantization và Flash Attention 2 ###
28
+ model = AutoModelForImageTextToText.from_pretrained(
29
+ MODEL_ID,
30
+ device_map="auto",
31
+ trust_remote_code=True,
32
+ quantization_config=quantization_config,
33
+ # attn_implementation="flash_attention_2" # Bỏ comment dòng này nếu bạn có GPU tương thích (NVIDIA 30xx/40xx) và đã cài flash-attn thành công
34
+ )
35
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=True)
36
+
37
+ # ### THAY ĐỔI TỐI ƯU 3: Biên dịch model với torch.compile ###
38
+ # Lần chạy đầu tiên sẽ mất chút thời gian để biên dịch, nhưng các lần sau sẽ rất nhanh.
39
+ # Chỉ hoạt động trên Linux/MacOS với PyTorch 2.0+ và GPU.
40
+ try:
41
+ print("🚀 Đang cố gắng biên dịch model với torch.compile()...")
42
+ model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
43
+ print("✅ Biên dịch model thành công!")
44
+ except Exception as e:
45
+ print(f"⚠️ Không thể biên dịch model: {e}. Chạy ở chế độ thông thường.")
46
+
47
  model.eval()
48
+ print(f"✅ Model và processor đã được tải và tối ưu thành công!")
49
+
50
  except Exception as e:
51
  print(f"❌ Lỗi khi tải model/processor: {e}")
52
  exit()
53
 
54
+ # --- 2. Hàm Inference Cốt lõi (Không cần thay đổi) ---
55
+ # Các tối ưu đã được áp dụng ở tầng model, nên hàm này sẽ tự động chạy nhanh hơn.
56
  def process_vqa(image: Image.Image, question: str):
57
  if image.mode != "RGB":
58
  image = image.convert("RGB")
 
60
  prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
61
  model_inputs = processor(text=[prompt_text], images=[image], return_tensors="pt").to(model.device)
62
 
63
+ # Sử dụng torch.no_grad() để tắt việc tính toán gradient, giúp tiết kiệm bộ nhớ và tăng tốc độ
64
+ with torch.no_grad():
65
+ generated_ids = model.generate(
66
+ **model_inputs,
67
+ max_new_tokens=128,
68
+ do_sample=False,
69
+ temperature=1.0,
70
+ eos_token_id=processor.tokenizer.eos_token_id,
71
+ pad_token_id=processor.tokenizer.pad_token_id
72
+ )
73
 
74
  generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:]
75
  response = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
76
  return response
77
 
78
+ # --- 3. Logic Chatbot (Giữ nguyên) ---
 
 
79
  THINKING_HTML = """
80
  <div class="typing-indicator">
81
  <span></span>
 
83
  <span></span>
84
  </div>
85
  """
 
86
  CUSTOM_CSS = """
87
  @keyframes blink {
88
  0% { opacity: .2; }
 
92
  .typing-indicator {
93
  display: flex;
94
  align-items: center;
95
+ justify-content: flex-start;
96
+ padding: 8px 0;
97
  }
98
  .typing-indicator span {
99
  height: 10px;
100
  width: 10px;
101
  margin: 0 2px;
102
+ background-color: #9E9E9E;
103
  border-radius: 50%;
104
  animation: blink 1.4s infinite both;
105
  }
 
111
  }
112
  """
113
 
 
114
  def manual_chat_responder(user_question: str, chat_history: list, uploaded_image: Image.Image):
115
  if uploaded_image is None:
116
  gr.Warning("Vui lòng tải ảnh lên trước để đặt câu hỏi về nó.")
 
120
  return "", chat_history
121
 
122
  chat_history.append({"role": "user", "content": user_question})
 
123
  chat_history.append({"role": "assistant", "content": THINKING_HTML})
124
  yield "", chat_history
125
 
 
128
  chat_history[-1]["content"] = bot_response
129
  yield "", chat_history
130
 
 
131
  def run_example(evt: SelectData):
132
  selected_example = example_list[evt.index]
133
  image_path, question = selected_example
134
  gr.Info(f"Đang chạy ví dụ: \"{question}\"")
135
  image = Image.open(image_path).convert("RGB")
136
 
 
137
  chat_history = [
138
  {"role": "user", "content": question},
139
  {"role": "assistant", "content": THINKING_HTML}
 
148
  def clear_chat():
149
  return []
150
 
151
+ # --- 4. Giao diện Người dùng Gradio (Giữ nguyên) ---
 
152
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), title="Vibook VQA Chatbot", css=CUSTOM_CSS) as demo:
153
  gr.Markdown("# 🤖 Vibook VQA Chatbot")
154
 
 
169
  chatbot = gr.Chatbot(label="Cuộc trò chuyện", height=600, avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png"), type="messages", value=[])
170
  question_input = gr.Textbox(label="Hoặc nhập câu hỏi về ảnh đã tải lên", placeholder="Nhập câu hỏi và nhấn Enter...", container=False, scale=7)
171
 
 
172
  question_input.submit(fn=manual_chat_responder, inputs=[question_input, chatbot, image_input], outputs=[question_input, chatbot])
 
173
  example_dataset.select(fn=run_example, inputs=None, outputs=[image_input, question_input, chatbot], show_progress="full")
 
174
  image_input.upload(fn=clear_chat, inputs=None, outputs=[chatbot])
175
  image_input.clear(fn=clear_chat, inputs=None, outputs=[chatbot])
176
 
177
+ # --- Phần cuối (Giữ nguyên) ---
178
  if __name__ == "__main__":
179
  ASSETS_DIR = "assets"
180
  if not os.path.exists(ASSETS_DIR):