sunbv56 commited on
Commit
3fd1c1a
·
verified ·
1 Parent(s): 70d2c5c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py (Phiên bản cuối cùng đã sửa lỗi và cảnh báo)
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from PIL import Image
6
+ from transformers import AutoModelForImageTextToText, AutoProcessor
7
+ from gradio.events import SelectData
8
+ import warnings
9
+ import os
10
+ from urllib.request import urlretrieve
11
+
12
+ warnings.filterwarnings("ignore", category=UserWarning, message="Overriding torch_dtype=None")
13
+
14
+ # --- 1. Tải Model và Processor ---
15
+ MODEL_ID = "sunbv56/qwen2.5-vl-vqa-vibook-lora-merged"
16
+ print(f"🚀 Đang tải model '{MODEL_ID}' và processor...")
17
+ try:
18
+ dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
19
+ model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, torch_dtype=dtype, device_map="auto", trust_remote_code=True)
20
+ # SỬA LỖI 3: Thêm use_fast=True để tắt cảnh báo
21
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=True)
22
+ model.eval()
23
+ print(f"✅ Model và processor đã được tải thành công!")
24
+ except Exception as e:
25
+ print(f"❌ Lỗi khi tải model/processor: {e}")
26
+ exit()
27
+
28
+ # --- 2. Hàm Inference Cốt lõi ---
29
+ def process_vqa(image: Image.Image, question: str):
30
+ if image.mode != "RGB":
31
+ image = image.convert("RGB")
32
+ messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}]
33
+ prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
34
+ model_inputs = processor(text=[prompt_text], images=[image], return_tensors="pt").to(model.device)
35
+ generated_ids = model.generate(
36
+ **model_inputs,
37
+ max_new_tokens=1024,
38
+ do_sample=False,
39
+ eos_token_id=processor.tokenizer.eos_token_id,
40
+ pad_token_id=processor.tokenizer.pad_token_id
41
+ )
42
+ generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:]
43
+ response = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
44
+ return response
45
+
46
+ # --- 3. Logic Chatbot ---
47
+ # Hàm dành cho việc người dùng tự nhập câu hỏi
48
+ def manual_chat_responder(user_question: str, chat_history: list, uploaded_image: Image.Image):
49
+ if uploaded_image is None:
50
+ gr.Warning("Vui lòng tải ảnh lên trước để đặt câu hỏi về nó.")
51
+ return "", chat_history
52
+ if not user_question or not user_question.strip():
53
+ gr.Warning("Vui lòng nhập một câu hỏi.")
54
+ return "", chat_history
55
+
56
+ # SỬA LỖI 2: Sử dụng định dạng `messages` mới
57
+ chat_history.append({"role": "user", "content": user_question})
58
+ yield "", chat_history
59
+
60
+ bot_response = process_vqa(uploaded_image, user_question)
61
+ chat_history.append({"role": "assistant", "content": bot_response})
62
+ yield "", chat_history
63
+
64
+ # Hàm dành riêng cho việc xử lý khi nhấn vào ví dụ
65
+ def run_example(example_list: list, evt: SelectData):
66
+ selected_example = example_list[evt.index]
67
+ image_path, question = selected_example
68
+ gr.Info(f"Đang chạy ví dụ: \"{question}\"")
69
+ image = Image.open(image_path).convert("RGB")
70
+
71
+ # SỬA LỖI 2: Bắt đầu cuộc trò chuyện với định dạng `messages` mới
72
+ chat_history = [{"role": "user", "content": question}]
73
+
74
+ bot_response = process_vqa(image, question)
75
+ chat_history.append({"role": "assistant", "content": bot_response})
76
+
77
+ return image, question, chat_history
78
+
79
+ def clear_chat():
80
+ return []
81
+
82
+ # --- 4. Định nghĩa Giao diện Người dùng Gradio ---
83
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), title="Vibook VQA Chatbot") as demo:
84
+ gr.Markdown("# 🤖 Vibook VQA Chatbot")
85
+
86
+ example_list = [
87
+ ["./assets/book_example_1.jpg", "Đâu là tên đúng của cuốn sách này?"],
88
+ ["./assets/book_example_1.jpg", "Ai là người đã viết cuốn sách này?"],
89
+ ["./assets/book_example_2.jpg", "tác giả và tên của cuốn sách là gì?"],
90
+ ]
91
+
92
+ with gr.Row(equal_height=False):
93
+ with gr.Column(scale=1, min_width=350):
94
+ gr.Markdown("### Bảng điều khiển")
95
+ image_input = gr.Image(type="pil", label="Tải ảnh lên", sources=["upload", "clipboard"])
96
+ gr.Markdown("---")
97
+ gr.Markdown("### Ví dụ (Nhấn để chạy)")
98
+ example_dataset = gr.Dataset(
99
+ components=[gr.Image(visible=False), gr.Textbox(visible=False)],
100
+ samples=example_list,
101
+ label="Ví dụ",
102
+ type="index"
103
+ )
104
+ with gr.Column(scale=2):
105
+ # SỬA LỖI 2: Thêm type="messages" và khởi tạo giá trị
106
+ chatbot = gr.Chatbot(
107
+ label="Cuộc trò chuyện",
108
+ bubble_full_width=False,
109
+ height=600,
110
+ avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png"),
111
+ type="messages",
112
+ value=[]
113
+ )
114
+ question_input = gr.Textbox(label="Hoặc nhập câu hỏi về ảnh đã tải lên", placeholder="Nhập câu hỏi và nhấn Enter...", container=False, scale=7)
115
+
116
+ # --- 5. Xử lý Sự kiện ---
117
+ question_input.submit(fn=manual_chat_responder, inputs=[question_input, chatbot, image_input], outputs=[question_input, chatbot])
118
+ example_dataset.select(fn=run_example, inputs=[example_dataset], outputs=[image_input, question_input, chatbot], show_progress="full")
119
+ image_input.upload(fn=clear_chat, inputs=None, outputs=[chatbot])
120
+ image_input.clear(fn=clear_chat, inputs=None, outputs=[chatbot])
121
+
122
+ # --- Phần cuối ---
123
+ if __name__ == "__main__":
124
+ ASSETS_DIR = "assets"
125
+ if not os.path.exists(ASSETS_DIR):
126
+ os.makedirs(ASSETS_DIR)
127
+ print("Đã tạo thư mục 'assets' cho các hình ảnh ví dụ.")
128
+
129
+ # SỬA LỖI 1: Thêm định nghĩa EXAMPLE_FILES bị thiếu
130
+ EXAMPLE_FILES = {
131
+ "book_example_1.jpg": "https://huggingface.co/spaces/sunbv56/demo-qwen2.5-vl-vqa-vibook/resolve/main/assets/book_example_1.jpg",
132
+ "book_example_2.jpg": "https://huggingface.co/spaces/sunbv56/demo-qwen2.5-vl-vqa-vibook/resolve/main/assets/book_example_2.jpg"
133
+ }
134
+
135
+ for filename, url in EXAMPLE_FILES.items():
136
+ filepath = os.path.join(ASSETS_DIR, filename)
137
+ if not os.path.exists(filepath):
138
+ print(f"Đang tải xuống hình ảnh ví dụ: {filename}...")
139
+ # Sửa lỗi logic tải file
140
+ urlretrieve(url, filepath)
141
+ print("...Đã xong.")
142
+
143
+ demo.launch(debug=True)