sunbv56 commited on
Commit
21219d5
·
verified ·
1 Parent(s): 8f3e0b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -84
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py
2
 
3
  import gradio as gr
4
  import torch
@@ -8,66 +8,29 @@ from gradio.events import SelectData
8
  import warnings
9
  import os
10
  import requests
11
- from typing import List
12
 
13
  warnings.filterwarnings("ignore", category=UserWarning, message="Overriding torch_dtype=None")
14
 
15
- # --- 1. Tải Model và Processor với TỐI ƯU HÓA (KHÔNG LƯỢNG TỬ HÓA) ---
16
  MODEL_ID = "sunbv56/qwen2.5-vl-vqa-vibook"
17
  print(f"🚀 Đang tải model '{MODEL_ID}' và processor...")
18
-
19
- use_gpu = torch.cuda.is_available()
20
-
21
- # *** THAY ĐỔI: Không sử dụng lượng tử hóa. Chọn dtype phù hợp. ***
22
- if use_gpu:
23
- # Nếu có GPU, sử dụng bfloat16/float16 để có hiệu năng tốt nhất
24
- dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
25
- print(f"Sử dụng GPU với dtype: {dtype}")
26
- else:
27
- # Trên CPU, chạy với độ chính xác đầy đủ float32.
28
- # Lượng tử hóa đã bị loại bỏ theo yêu cầu.
29
- print("⚠️ Cảnh báo: Chạy model ở độ chính xác float32 trên CPU.")
30
- print(" -> Tốc độ sẽ chậm hơn so với phiên bản lượng tử hóa.")
31
- print(" -> Các tối ưu hóa khác (torch.compile, batching) vẫn được áp dụng.")
32
- dtype = torch.float32
33
-
34
  try:
35
- model = AutoModelForImageTextToText.from_pretrained(
36
- MODEL_ID,
37
- torch_dtype=dtype, # Sử dụng dtype đã chọn
38
- device_map="auto", # `accelerate` sẽ xử lý việc đặt model lên thiết bị
39
- trust_remote_code=True,
40
- )
41
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=True)
42
-
43
- # *** TỐI ƯU HÓA 1: Sử dụng torch.compile() (cho PyTorch 2.0+) ***
44
- # Biên dịch model để tăng tốc độ inference sau lần chạy đầu tiên.
45
- print("🚀 Đang biên dịch model với torch.compile()... (có thể mất một chút thời gian cho lần đầu)")
46
- model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
47
- print("✅ Model đã được biên dịch.")
48
-
49
  model.eval()
50
- print(f"✅ Model và processor đã được tải và tối ưu hóa thành công!")
51
  except Exception as e:
52
- print(f"❌ Lỗi khi tải/tối ưu hóa model/processor: {e}")
53
  exit()
54
 
55
- # --- 2. Hàm Inference Cốt lõi đã được sửa đổi để xử lý BATCH ---
56
- def process_vqa_batch(images: List[Image.Image], questions: List[str]):
57
- prompts = []
58
- processed_images = []
59
-
60
- # Chuẩn bị prompt và ảnh cho từng item trong batch
61
- for image, question in zip(images, questions):
62
- if image.mode != "RGB":
63
- image = image.convert("RGB")
64
- messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}]
65
- prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
66
- prompts.append(prompt_text)
67
- processed_images.append(image)
68
-
69
- # Xử lý cả batch cùng một lúc
70
- model_inputs = processor(text=prompts, images=processed_images, return_tensors="pt", padding=True).to(model.device)
71
 
72
  generated_ids = model.generate(
73
  **model_inputs,
@@ -78,54 +41,83 @@ def process_vqa_batch(images: List[Image.Image], questions: List[str]):
78
  pad_token_id=processor.tokenizer.pad_token_id
79
  )
80
 
81
- # Decode kết quả cho cả batch
82
  generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:]
83
- responses = processor.batch_decode(generated_ids, skip_special_tokens=True)
84
-
85
- # Strip() cho mỗi response trong list
86
- return [res.strip() for res in responses]
87
-
88
 
89
  # --- 3. Logic Chatbot ---
90
- THINKING_HTML = """<div class="typing-indicator"><span></span><span></span><span></span></div>"""
91
- CUSTOM_CSS = """@keyframes blink{0%{opacity:.2}20%{opacity:1}100%{opacity:.2}}.typing-indicator{display:flex;align-items:center;justify-content:flex-start;padding:8px 0}.typing-indicator span{height:10px;width:10px;margin:0 2px;background-color:#9E9E9E;border-radius:50%;animation:blink 1.4s infinite both}.typing-indicator span:nth-child(2){animation-delay:.2s}.typing-indicator span:nth-child(3){animation-delay:.4s}"""
92
-
93
- # *** TỐI ƯU HÓA 2: Sửa đổi hàm để tương thích với BATCHING của Gradio ***
94
- def manual_chat_responder(user_questions: List[str], chat_histories: List[list], uploaded_images: List[Image.Image]):
95
- user_question = user_questions[0]
96
- chat_history = chat_histories[0]
97
- uploaded_image = uploaded_images[0]
98
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  if uploaded_image is None:
100
  gr.Warning("Vui lòng tải ảnh lên trước để đặt câu hỏi về nó.")
101
- return [("", chat_history)]
102
  if not user_question or not user_question.strip():
103
  gr.Warning("Vui lòng nhập một câu hỏi.")
104
- return [("", chat_history)]
105
-
106
  chat_history.append({"role": "user", "content": user_question})
 
107
  chat_history.append({"role": "assistant", "content": THINKING_HTML})
108
-
109
- # Gọi hàm xử lý batch (dù chỉ có 1 item trong trường hợp này)
110
- bot_response = process_vqa_batch([uploaded_image], [user_question])[0]
111
 
112
  chat_history[-1]["content"] = bot_response
113
-
114
- return [("", chat_history)]
115
 
 
116
  def run_example(evt: SelectData):
117
  selected_example = example_list[evt.index]
118
  image_path, question = selected_example
119
  gr.Info(f"Đang chạy ví dụ: \"{question}\"")
120
  image = Image.open(image_path).convert("RGB")
121
 
 
122
  chat_history = [
123
  {"role": "user", "content": question},
124
  {"role": "assistant", "content": THINKING_HTML}
125
  ]
126
  yield image, question, chat_history
127
 
128
- bot_response = process_vqa_batch([image], [question])[0]
129
 
130
  chat_history[-1]["content"] = bot_response
131
  yield image, question, chat_history
@@ -134,8 +126,9 @@ def clear_chat():
134
  return []
135
 
136
  # --- 4. Định nghĩa Giao diện Người dùng Gradio ---
 
137
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), title="Vibook VQA Chatbot", css=CUSTOM_CSS) as demo:
138
- gr.Markdown("# 🤖 Vibook VQA Chatbot (Optimized - No Quantization)")
139
 
140
  example_list = [
141
  ["./assets/book_example_1.jpg", "Đâu là tên đúng của cuốn sách này?"],
@@ -154,12 +147,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), ti
154
  chatbot = gr.Chatbot(label="Cuộc trò chuyện", height=600, avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png"), type="messages", value=[])
155
  question_input = gr.Textbox(label="Hoặc nhập câu hỏi về ảnh đã tải lên", placeholder="Nhập câu hỏi và nhấn Enter...", container=False, scale=7)
156
 
157
- # --- 5. Xử lý Sự kiện với TỐI ƯU HÓA BATCHING ---
158
- question_input.submit(
159
- fn=manual_chat_responder,
160
- inputs=[question_input, chatbot, image_input],
161
- outputs=[question_input, chatbot]
162
- ).batch(batch_size=4, max_latency=0.1)
163
 
164
  example_dataset.select(fn=run_example, inputs=None, outputs=[image_input, question_input, chatbot], show_progress="full")
165
 
@@ -167,7 +156,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), ti
167
  image_input.clear(fn=clear_chat, inputs=None, outputs=[chatbot])
168
 
169
  # --- Phần cuối ---
170
- def setup_examples():
171
  ASSETS_DIR = "assets"
172
  if not os.path.exists(ASSETS_DIR):
173
  os.makedirs(ASSETS_DIR)
@@ -192,6 +181,4 @@ def setup_examples():
192
  except requests.exceptions.RequestException as e:
193
  print(f" Lỗi khi tải {filename}: {e}")
194
 
195
- if __name__ == "__main__":
196
- setup_examples()
197
  demo.launch(debug=True)
 
1
+ # app.py
2
 
3
  import gradio as gr
4
  import torch
 
8
  import warnings
9
  import os
10
  import requests
 
11
 
12
  warnings.filterwarnings("ignore", category=UserWarning, message="Overriding torch_dtype=None")
13
 
14
+ # --- 1. Tải Model và Processor ---
15
  MODEL_ID = "sunbv56/qwen2.5-vl-vqa-vibook"
16
  print(f"🚀 Đang tải model '{MODEL_ID}' và processor...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  try:
18
+ dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
19
+ model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, torch_dtype=dtype, device_map="auto", trust_remote_code=True)
 
 
 
 
20
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=True)
 
 
 
 
 
 
 
21
  model.eval()
22
+ print(f"✅ Model và processor đã được tải thành công!")
23
  except Exception as e:
24
+ print(f"❌ Lỗi khi tải model/processor: {e}")
25
  exit()
26
 
27
+ # --- 2. Hàm Inference Cốt lõi ---
28
+ def process_vqa(image: Image.Image, question: str):
29
+ if image.mode != "RGB":
30
+ image = image.convert("RGB")
31
+ messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}]
32
+ prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
33
+ model_inputs = processor(text=[prompt_text], images=[image], return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
 
 
34
 
35
  generated_ids = model.generate(
36
  **model_inputs,
 
41
  pad_token_id=processor.tokenizer.pad_token_id
42
  )
43
 
 
44
  generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:]
45
+ response = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
46
+ return response
 
 
 
47
 
48
  # --- 3. Logic Chatbot ---
49
+ # ### THAY ĐỔI MỚI 1: Định nghĩa HTML và CSS cho hiệu ứng động ###
50
+ # HTML cho hiệu ứng "đang "
51
+ THINKING_HTML = """
52
+ <div class="typing-indicator">
53
+ <span></span>
54
+ <span></span>
55
+ <span></span>
56
+ </div>
57
+ """
58
+ # CSS để tạo hiệu ứng
59
+ CUSTOM_CSS = """
60
+ @keyframes blink {
61
+ 0% { opacity: .2; }
62
+ 20% { opacity: 1; }
63
+ 100% { opacity: .2; }
64
+ }
65
+ .typing-indicator {
66
+ display: flex;
67
+ align-items: center;
68
+ justify-content: flex-start; /* Căn trái */
69
+ padding: 8px 0; /* Thêm chút khoảng đệm */
70
+ }
71
+ .typing-indicator span {
72
+ height: 10px;
73
+ width: 10px;
74
+ margin: 0 2px;
75
+ background-color: #9E9E9E; /* Màu xám */
76
+ border-radius: 50%;
77
+ animation: blink 1.4s infinite both;
78
+ }
79
+ .typing-indicator span:nth-child(2) {
80
+ animation-delay: .2s;
81
+ }
82
+ .typing-indicator span:nth-child(3) {
83
+ animation-delay: .4s;
84
+ }
85
+ """
86
+
87
+ # Hàm dành cho việc người dùng tự nhập câu hỏi
88
+ def manual_chat_responder(user_question: str, chat_history: list, uploaded_image: Image.Image):
89
  if uploaded_image is None:
90
  gr.Warning("Vui lòng tải ảnh lên trước để đặt câu hỏi về nó.")
91
+ return "", chat_history
92
  if not user_question or not user_question.strip():
93
  gr.Warning("Vui lòng nhập một câu hỏi.")
94
+ return "", chat_history
95
+
96
  chat_history.append({"role": "user", "content": user_question})
97
+ # ### THAY ĐỔI MỚI 2: Sử dụng HTML động thay cho text tĩnh ###
98
  chat_history.append({"role": "assistant", "content": THINKING_HTML})
99
+ yield "", chat_history
100
+
101
+ bot_response = process_vqa(uploaded_image, user_question)
102
 
103
  chat_history[-1]["content"] = bot_response
104
+ yield "", chat_history
 
105
 
106
+ # Hàm dành riêng cho việc xử lý khi nhấn vào ví dụ
107
  def run_example(evt: SelectData):
108
  selected_example = example_list[evt.index]
109
  image_path, question = selected_example
110
  gr.Info(f"Đang chạy ví dụ: \"{question}\"")
111
  image = Image.open(image_path).convert("RGB")
112
 
113
+ # ### THAY ĐỔI MỚI 3: Sử dụng HTML động thay cho text tĩnh ###
114
  chat_history = [
115
  {"role": "user", "content": question},
116
  {"role": "assistant", "content": THINKING_HTML}
117
  ]
118
  yield image, question, chat_history
119
 
120
+ bot_response = process_vqa(image, question)
121
 
122
  chat_history[-1]["content"] = bot_response
123
  yield image, question, chat_history
 
126
  return []
127
 
128
  # --- 4. Định nghĩa Giao diện Người dùng Gradio ---
129
+ # ### THAY ĐỔI MỚI 4: Thêm CSS vào Blocks ###
130
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), title="Vibook VQA Chatbot", css=CUSTOM_CSS) as demo:
131
+ gr.Markdown("# 🤖 Vibook VQA Chatbot")
132
 
133
  example_list = [
134
  ["./assets/book_example_1.jpg", "Đâu là tên đúng của cuốn sách này?"],
 
147
  chatbot = gr.Chatbot(label="Cuộc trò chuyện", height=600, avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png"), type="messages", value=[])
148
  question_input = gr.Textbox(label="Hoặc nhập câu hỏi về ảnh đã tải lên", placeholder="Nhập câu hỏi và nhấn Enter...", container=False, scale=7)
149
 
150
+ # --- 5. Xử lý Sự kiện ---
151
+ question_input.submit(fn=manual_chat_responder, inputs=[question_input, chatbot, image_input], outputs=[question_input, chatbot])
 
 
 
 
152
 
153
  example_dataset.select(fn=run_example, inputs=None, outputs=[image_input, question_input, chatbot], show_progress="full")
154
 
 
156
  image_input.clear(fn=clear_chat, inputs=None, outputs=[chatbot])
157
 
158
  # --- Phần cuối ---
159
+ if __name__ == "__main__":
160
  ASSETS_DIR = "assets"
161
  if not os.path.exists(ASSETS_DIR):
162
  os.makedirs(ASSETS_DIR)
 
181
  except requests.exceptions.RequestException as e:
182
  print(f" Lỗi khi tải {filename}: {e}")
183
 
 
 
184
  demo.launch(debug=True)