sunbv56's picture
Update app.py
d5bebb0 verified
raw
history blame
7.48 kB
# app.py (Phiên bản cuối cùng: Sửa lỗi cảnh báo và thêm tin nhắn "Thinking...")
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
from gradio.events import SelectData
import warnings
import os
import requests
warnings.filterwarnings("ignore", category=UserWarning, message="Overriding torch_dtype=None")
# --- 1. Tải Model và Processor ---
MODEL_ID = "sunbv56/qwen2.5-vl-vqa-vibook"
print(f"🚀 Đang tải model '{MODEL_ID}' và processor...")
try:
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, torch_dtype=dtype, device_map="auto", trust_remote_code=True)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=True)
model.eval()
print(f"✅ Model và processor đã được tải thành công!")
except Exception as e:
print(f"❌ Lỗi khi tải model/processor: {e}")
exit()
# --- 2. Hàm Inference Cốt lõi ---
def process_vqa(image: Image.Image, question: str):
if image.mode != "RGB":
image = image.convert("RGB")
messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}]
prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = processor(text=[prompt_text], images=[image], return_tensors="pt").to(model.device)
# SỬA LỖI 1: Ghi đè `temperature` để tắt cảnh báo.
# Đặt là 1.0 (trung tính) vì do_sample=False nên nó sẽ không được sử dụng.
generated_ids = model.generate(
**model_inputs,
max_new_tokens=1024,
do_sample=False,
temperature=1.0,
eos_token_id=processor.tokenizer.eos_token_id,
pad_token_id=processor.tokenizer.pad_token_id
)
generated_ids = generated_ids[:, model_inputs['input_ids'].shape[1]:]
response = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
return response
# --- 3. Logic Chatbot ---
# Hàm dành cho việc người dùng tự nhập câu hỏi
def manual_chat_responder(user_question: str, chat_history: list, uploaded_image: Image.Image):
if uploaded_image is None:
gr.Warning("Vui lòng tải ảnh lên trước để đặt câu hỏi về nó.")
return "", chat_history
if not user_question or not user_question.strip():
gr.Warning("Vui lòng nhập một câu hỏi.")
return "", chat_history
# THÊM TÍNH NĂNG 2: Hiển thị tin nhắn chờ
chat_history.append({"role": "user", "content": user_question})
chat_history.append({"role": "assistant", "content": "🤔 Thinking..."})
yield "", chat_history
bot_response = process_vqa(uploaded_image, user_question)
# THÊM TÍNH NĂNG 2: Cập nhật tin nhắn chờ bằng câu trả lời thật
chat_history[-1]["content"] = bot_response
yield "", chat_history
# Hàm dành riêng cho việc xử lý khi nhấn vào ví dụ
def run_example(evt: SelectData):
# Dùng list toàn cục đã được định nghĩa trong khối `with`
selected_example = example_list[evt.index]
image_path, question = selected_example
gr.Info(f"Đang chạy ví dụ: \"{question}\"")
image = Image.open(image_path).convert("RGB")
# THÊM TÍNH NĂNG 2: Hiển thị tin nhắn chờ
chat_history = [
{"role": "user", "content": question},
{"role": "assistant", "content": "🤔 Thinking..."}
]
# `yield` lần đầu để cập nhật UI ngay lập tức
yield image, question, chat_history
# Chạy xử lý và lấy câu trả lời thật
bot_response = process_vqa(image, question)
# THÊM TÍNH NĂNG 2: Cập nhật tin nhắn chờ bằng câu trả lời thật
chat_history[-1]["content"] = bot_response
# `yield` lần cuối để hiển thị kết quả cuối cùng
yield image, question, chat_history
def clear_chat():
return []
# --- 4. Định nghĩa Giao diện Người dùng Gradio ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), title="Vibook VQA Chatbot") as demo:
gr.Markdown("# 🤖 Vibook VQA Chatbot")
example_list = [
["./assets/book_example_1.jpg", "Đâu là tên đúng của cuốn sách này?"],
["./assets/book_example_1.jpg", "Ai là người đã viết cuốn sách này?"],
["./assets/book_example_2.jpg", "tác giả và tên của cuốn sách là gì?"],
]
with gr.Row(equal_height=False):
with gr.Column(scale=1, min_width=350):
gr.Markdown("### Bảng điều khiển")
image_input = gr.Image(type="pil", label="Tải ảnh lên", sources=["upload", "clipboard", "webcam"])
gr.Markdown("---")
gr.Markdown("### Ví dụ (Nhấn để chạy)")
example_dataset = gr.Dataset(components=[gr.Image(visible=False), gr.Textbox(visible=False)], samples=example_list, label="Ví dụ", type="index")
with gr.Column(scale=2):
chatbot = gr.Chatbot(label="Cuộc trò chuyện", height=600, avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png"), type="messages", value=[])
question_input = gr.Textbox(label="Hoặc nhập câu hỏi về ảnh đã tải lên", placeholder="Nhập câu hỏi và nhấn Enter...", container=False, scale=7)
# --- 5. Xử lý Sự kiện ---
question_input.submit(fn=manual_chat_responder, inputs=[question_input, chatbot, image_input], outputs=[question_input, chatbot])
# THÊM TÍNH NĂNG 2: Hàm `run_example` giờ là một generator, Gradio sẽ tự động xử lý các `yield`
example_dataset.select(fn=run_example, inputs=None, outputs=[image_input, question_input, chatbot], show_progress="full")
image_input.upload(fn=clear_chat, inputs=None, outputs=[chatbot])
image_input.clear(fn=clear_chat, inputs=None, outputs=[chatbot])
# --- Phần cuối ---
if __name__ == "__main__":
ASSETS_DIR = "assets"
if not os.path.exists(ASSETS_DIR):
os.makedirs(ASSETS_DIR)
print("Đã tạo thư mục 'assets' cho các hình ảnh ví dụ.")
EXAMPLE_FILES = {
"book_example_1.jpg": "https://cdn0.fahasa.com/media/catalog/product/d/i/dieu-ky-dieu-cua-tiem-tap-hoa-namiya---tai-ban-2020.jpg",
"book_example_2.jpg": "https://cdn0.fahasa.com/media/catalog/product/d/r/dr.-stone_bia_tap-26.jpg"
}
headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"}
for filename, url in EXAMPLE_FILES.items():
filepath = os.path.join(ASSETS_DIR, filename)
if not os.path.exists(filepath):
print(f"Đang tải xuống hình ảnh ví dụ: {filename}...")
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
with open(filepath, 'wb') as f:
f.write(response.content)
print("...Đã xong.")
except requests.exceptions.RequestException as e:
print(f" Lỗi khi tải {filename}: {e}")
demo.launch(debug=True)