File size: 7,484 Bytes
9bc5cc9
3fd1c1a
 
 
 
 
 
 
 
aa2e87f
3fd1c1a
 
 
 
d5bebb0
3fd1c1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bc5cc9
 
 
 
 
 
 
 
 
 
 
 
3fd1c1a
 
 
 
 
9bc5cc9
3fd1c1a
 
 
 
 
 
 
9bc5cc9
 
3fd1c1a
9bc5cc9
3fd1c1a
9bc5cc9
3fd1c1a
9bc5cc9
 
 
3fd1c1a
 
9bc5cc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fd1c1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5bebb0
3fd1c1a
 
aa2e87f
3fd1c1a
aa2e87f
3fd1c1a
 
 
 
aa2e87f
9bc5cc9
 
aa2e87f
3fd1c1a
 
 
 
 
 
 
 
 
 
 
aa2e87f
 
3fd1c1a
 
aa2e87f
3fd1c1a
 
 
 
aa2e87f
 
 
 
 
 
 
 
3fd1c1a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# app.py (Phiên bản cuối cùng: Sửa lỗi cảnh báo và thêm tin nhắn "Thinking...")

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
from gradio.events import SelectData
import warnings
import os
import requests

warnings.filterwarnings("ignore", category=UserWarning, message="Overriding torch_dtype=None")

# --- 1. Tải Model và Processor ---
MODEL_ID = "sunbv56/qwen2.5-vl-vqa-vibook"
print(f"🚀 Đang tải model '{MODEL_ID}' và processor...")
try:
    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
    model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, torch_dtype=dtype, device_map="auto", trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=True)
    model.eval()
    print(f"✅ Model và processor đã được tải thành công!")
except Exception as e:
    print(f"❌ Lỗi khi tải model/processor: {e}")
    exit()

# --- 2. Hàm Inference Cốt lõi ---
def process_vqa(image: Image.Image, question: str):
    if image.mode != "RGB":
        image = image.convert("RGB")
    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}]
    prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = processor(text=[prompt_text], images=[image], return_tensors="pt").to(model.device)
    
    # SỬA LỖI 1: Ghi đè `temperature` để tắt cảnh báo.
    # Đặt là 1.0 (trung tính) vì do_sample=False nên nó sẽ không được sử dụng.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=1024,
        do_sample=False,
        temperature=1.0, 
        eos_token_id=processor.tokenizer.eos_token_id,
        pad_token_id=processor.tokenizer.pad_token_id
    )
    
    generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:]
    response = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
    return response

# --- 3. Logic Chatbot ---
# Hàm dành cho việc người dùng tự nhập câu hỏi
def manual_chat_responder(user_question: str, chat_history: list, uploaded_image: Image.Image):
    if uploaded_image is None:
        gr.Warning("Vui lòng tải ảnh lên trước để đặt câu hỏi về nó.")
        return "", chat_history
    if not user_question or not user_question.strip():
        gr.Warning("Vui lòng nhập một câu hỏi.")
        return "", chat_history
    
    # THÊM TÍNH NĂNG 2: Hiển thị tin nhắn chờ
    chat_history.append({"role": "user", "content": user_question})
    chat_history.append({"role": "assistant", "content": "🤔 Thinking..."})
    yield "", chat_history

    bot_response = process_vqa(uploaded_image, user_question)
    
    # THÊM TÍNH NĂNG 2: Cập nhật tin nhắn chờ bằng câu trả lời thật
    chat_history[-1]["content"] = bot_response
    yield "", chat_history

# Hàm dành riêng cho việc xử lý khi nhấn vào ví dụ
def run_example(evt: SelectData):
    # Dùng list toàn cục đã được định nghĩa trong khối `with`
    selected_example = example_list[evt.index]
    image_path, question = selected_example
    gr.Info(f"Đang chạy ví dụ: \"{question}\"")
    image = Image.open(image_path).convert("RGB")
    
    # THÊM TÍNH NĂNG 2: Hiển thị tin nhắn chờ
    chat_history = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": "🤔 Thinking..."}
    ]
    # `yield` lần đầu để cập nhật UI ngay lập tức
    yield image, question, chat_history

    # Chạy xử lý và lấy câu trả lời thật
    bot_response = process_vqa(image, question)
    
    # THÊM TÍNH NĂNG 2: Cập nhật tin nhắn chờ bằng câu trả lời thật
    chat_history[-1]["content"] = bot_response
    # `yield` lần cuối để hiển thị kết quả cuối cùng
    yield image, question, chat_history

def clear_chat():
    return []

# --- 4. Định nghĩa Giao diện Người dùng Gradio ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), title="Vibook VQA Chatbot") as demo:
    gr.Markdown("# 🤖 Vibook VQA Chatbot")
    
    example_list = [
        ["./assets/book_example_1.jpg", "Đâu là tên đúng của cuốn sách này?"],
        ["./assets/book_example_1.jpg", "Ai là người đã viết cuốn sách này?"],
        ["./assets/book_example_2.jpg", "tác giả và tên của cuốn sách là gì?"],
    ]
    
    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=350):
            gr.Markdown("### Bảng điều khiển")
            image_input = gr.Image(type="pil", label="Tải ảnh lên", sources=["upload", "clipboard", "webcam"])
            gr.Markdown("---")
            gr.Markdown("### Ví dụ (Nhấn để chạy)")
            example_dataset = gr.Dataset(components=[gr.Image(visible=False), gr.Textbox(visible=False)], samples=example_list, label="Ví dụ", type="index")
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Cuộc trò chuyện", height=600, avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png"), type="messages", value=[])
            question_input = gr.Textbox(label="Hoặc nhập câu hỏi về ảnh đã tải lên", placeholder="Nhập câu hỏi và nhấn Enter...", container=False, scale=7)

    # --- 5. Xử lý Sự kiện ---
    question_input.submit(fn=manual_chat_responder, inputs=[question_input, chatbot, image_input], outputs=[question_input, chatbot])
    
    # THÊM TÍNH NĂNG 2: Hàm `run_example` giờ là một generator, Gradio sẽ tự động xử lý các `yield`
    example_dataset.select(fn=run_example, inputs=None, outputs=[image_input, question_input, chatbot], show_progress="full")
    
    image_input.upload(fn=clear_chat, inputs=None, outputs=[chatbot])
    image_input.clear(fn=clear_chat, inputs=None, outputs=[chatbot])

# --- Phần cuối ---
if __name__ == "__main__":
    ASSETS_DIR = "assets"
    if not os.path.exists(ASSETS_DIR):
        os.makedirs(ASSETS_DIR)
        print("Đã tạo thư mục 'assets' cho các hình ảnh ví dụ.")
    
    EXAMPLE_FILES = {
        "book_example_1.jpg": "https://cdn0.fahasa.com/media/catalog/product/d/i/dieu-ky-dieu-cua-tiem-tap-hoa-namiya---tai-ban-2020.jpg",
        "book_example_2.jpg": "https://cdn0.fahasa.com/media/catalog/product/d/r/dr.-stone_bia_tap-26.jpg"
    }

    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"}
    for filename, url in EXAMPLE_FILES.items():
        filepath = os.path.join(ASSETS_DIR, filename)
        if not os.path.exists(filepath):
            print(f"Đang tải xuống hình ảnh ví dụ: {filename}...")
            try:
                response = requests.get(url, headers=headers, timeout=10)
                response.raise_for_status() 
                with open(filepath, 'wb') as f:
                    f.write(response.content)
                print("...Đã xong.")
            except requests.exceptions.RequestException as e:
                print(f" Lỗi khi tải {filename}: {e}")

    demo.launch(debug=True)