# ShapLLM-Omni / app.py
# File-viewer header captured together with the source (not part of the program):
#   yejunliang23's picture
#   Update app.py
#   6af08b9 unverified
#   raw
#   history blame
#   6.94 kB
import os
import torch
from threading import Thread
import gradio as gr
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor,TextIteratorStreamer,AutoTokenizer
from qwen_vl_utils import process_vision_info
import trimesh
from trimesh.exchange.gltf import export_glb
import numpy as np
import tempfile
def predict(_chatbot, task_history):
    """Stream the model's reply to the newest user turn into the chatbot.

    This is a generator used as a Gradio event handler whose only output is
    the chatbot component, so every value it yields is the chatbot history.

    Args:
        _chatbot: Display history, a list of ``(user, assistant)`` pairs. The
            last pair is the turn being answered; it is updated in place.
        task_history: Raw history with the same layout. A user entry is
            either a text string or a tuple/list whose first element is a
            file path. The last pair is overwritten with the final reply.

    Yields:
        ``_chatbot`` after each streamed chunk, with the assistant slot of
        the last pair holding the reply accumulated so far.
    """
    # Imported locally: the module-level imports do not bring `copy` in.
    import copy

    chat_query = _chatbot[-1][0]
    query = task_history[-1][0]
    if len(chat_query) == 0:
        # Nothing was entered: drop the empty turn from both histories.
        # A `return value` inside a generator never reaches the consumer,
        # so the trimmed history is yielded before returning.
        _chatbot.pop()
        task_history.pop()
        yield _chatbot
        return

    # NOTE(review): `_parse_text` and `is_video_file` are not defined in the
    # visible part of this file -- confirm they are provided elsewhere.
    print("User: " + _parse_text(query))

    # Work on a copy so building the prompt never touches the live state.
    history_cp = copy.deepcopy(task_history)
    messages = []
    content = []
    for q, a in history_cp:
        if isinstance(q, (tuple, list)):
            # File entries are queued and attached to the next text turn.
            kind = 'video' if is_video_file(q[0]) else 'image'
            content.append({kind: f'file://{q[0]}'})
        else:
            content.append({'text': q})
            messages.append({'role': 'user', 'content': content})
            messages.append({'role': 'assistant', 'content': [{'text': a}]})
            content = []
    # The trailing assistant entry belongs to the turn that is about to be
    # generated, so it is removed from the prompt.
    messages.pop()

    messages = _transform_messages(messages)
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs,
                       videos=video_inputs, padding=True, return_tensors='pt')
    inputs = inputs.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {'max_new_tokens': 512, 'streamer': streamer, **inputs}
    # Generation runs in a worker thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    pieces = []
    for chunk in streamer:
        pieces.append(chunk)
        _chatbot[-1] = (chat_query, "".join(pieces))
        yield _chatbot
    thread.join()

    full_response = "".join(pieces)
    # Record the finished reply so later turns include it in the prompt and
    # `regenerate` can tell that this turn has been answered.
    task_history[-1] = (query, full_response)
    _chatbot[-1] = (chat_query, full_response)
    yield _chatbot
def regenerate(_chatbot, task_history):
    """Discard the last assistant reply and stream a replacement.

    Args:
        _chatbot: Display history, a list of ``(user, assistant)`` pairs.
            Modified in place.
        task_history: Raw history with the same layout. The reply of its
            last pair is reset to ``None`` before generation restarts.

    Yields:
        The chatbot history. When there is nothing to regenerate (no history,
        or the last turn has no reply yet) the unchanged history is yielded
        once; otherwise every value produced by ``predict`` is passed on.
    """
    # This function is a generator, so `return _chatbot` would be silently
    # discarded; the unchanged history has to be yielded instead.
    if not task_history or task_history[-1][1] is None:
        yield _chatbot
        return

    last_query = task_history[-1][0]
    task_history[-1] = (last_query, None)

    removed = _chatbot.pop(-1)
    if removed[0] is None:
        # The removed row held only a reply: clear the reply of the row
        # before it, which carries the user turn.
        _chatbot[-1] = (_chatbot[-1][0], None)
    else:
        # Put the user turn back with an empty reply slot.
        _chatbot.append((removed[0], None))

    yield from predict(_chatbot, task_history)
def add_text(history, task_history, text):
    """Append a new, unanswered text turn to both histories.

    Args:
        history: Display history (list of pairs) or ``None``.
        task_history: Raw history (list of pairs) or ``None``.
        text: The text entered by the user.

    Returns:
        A 3-tuple: a new display history ending with
        ``(_parse_text(text), None)``, a new raw history ending with
        ``(text, None)``, and an empty string. The input lists are not
        modified.
    """
    shown = [] if history is None else history
    raw = [] if task_history is None else task_history

    shown_turn = (_parse_text(text), None)
    raw_turn = (text, None)
    return shown + [shown_turn], raw + [raw_turn], ""
def add_file(history, task_history, file):
    """Append an uploaded file as a new, unanswered turn to both histories.

    Args:
        history: Display history (list of pairs) or ``None``.
        task_history: Raw history (list of pairs) or ``None``.
        file: Object exposing the uploaded file's path as ``file.name``.

    Returns:
        A 2-tuple of new lists, each ending with ``((file.name,), None)``.
        The input lists are not modified.
    """
    shown = [] if history is None else history
    raw = [] if task_history is None else task_history

    # The one-element tuple distinguishes a file entry from a text entry.
    return shown + [((file.name,), None)], raw + [((file.name,), None)]
def reset_user_input():
    """Return a Gradio update that sets the target component's value to ""."""
    cleared = gr.update(value="")
    return cleared
def reset_state(task_history):
    """Empty the raw history in place and return a fresh empty list.

    Args:
        task_history: The raw history list; cleared in place.

    Returns:
        A new empty list.
    """
    task_history.clear()
    empty_chat = []
    return empty_chat
def _transform_messages(original_messages):
transformed_messages = []
for message in original_messages:
new_content = []
for item in message['content']:
if 'image' in item:
new_item = {'type': 'image', 'image': item['image']}
elif 'text' in item:
new_item = {'type': 'text', 'text': item['text']}
elif 'video' in item:
new_item = {'type': 'video', 'video': item['video']}
else:
continue
new_content.append(new_item)
new_message = {'role': message['role'], 'content': new_content}
transformed_messages.append(new_message)
return transformed_messages
# --------- Configuration & Model Loading ---------
# Identifier of the checkpoint passed to every `from_pretrained` call below.
MODEL_DIR = "Qwen/Qwen2.5-VL-3B-Instruct"

# Qwen2.5-VL model, requested in float16 with automatic device mapping.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_DIR,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# The processor builds the model inputs; its tokenizer is kept under its own
# name because the text streamers are constructed from it.
processor = AutoProcessor.from_pretrained(MODEL_DIR)
tokenizer = processor.tokenizer
# terminators = [tokenizer.eos_token_id]
def chat_qwen_vl(messages: str, history: list, temperature: float = 0.1, max_new_tokens: int = 1024):
    """Stream the model's reply to a single text prompt.

    Args:
        messages: The user's prompt text.
        history: Accepted for interface compatibility; not used, so every
            call is answered without earlier turns.
        temperature: Accepted for interface compatibility; not forwarded to
            generation.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The reply accumulated so far, growing with each streamed chunk.
    """
    # Kept under a separate name so the `messages` string parameter is not
    # rebound to a list.
    chat_messages = _transform_messages([
        {
            "role": "user",
            "content": [
                {"type": "text", "text": messages},
            ],
        },
    ])
    text = processor.apply_chat_template(
        chat_messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(chat_messages)
    inputs = processor(text=[text], images=image_inputs,
                       videos=video_inputs, padding=True, return_tensors='pt')
    inputs = inputs.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    # Honour the caller's limit; it was previously ignored in favour of a
    # hard-coded 512.
    gen_kwargs = {'max_new_tokens': max_new_tokens, 'streamer': streamer, **inputs}
    # Generation runs in a worker thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    pieces = []
    for chunk in streamer:
        pieces.append(chunk)
        yield "".join(pieces)
    thread.join()
# Stylesheet snippet: centers <h1> headings.
css = "\nh1 { text-align: center; }\n"

# HTML fragment for an empty chat area, assembled from its nested elements.
PLACEHOLDER = "".join((
    "<div style='padding:30px;text-align:center;display:flex;flex-direction:column;align-items:center;'>",
    "<h1 style='font-size:28px;opacity:0.55;'>Qwen2.5-VL Local Chat</h1>",
    "<p style='font-size:18px;opacity:0.65;'>Ask anything or generate images!</p>",
    "</div>",
))
# --------- Gradio UI ---------
with gr.Blocks() as demo:
    gr.Markdown("""<center><font size=3> ShapeLLM-7B Demo </center>""")
    chatbot = gr.Chatbot(label='ShapeLLM-4o', elem_classes="control-height", height=500)
    query = gr.Textbox(lines=2, label='Input')
    # Raw conversation history; unlike the chatbot it keeps the user's
    # unparsed text.
    task_history = gr.State([])

    with gr.Row():
        # Button labels restored from mis-decoded (mojibake) text to the
        # intended emoji + bilingual captions.
        addfile_btn = gr.UploadButton("📁 Upload (上传文件)", file_types=["image", "video"])
        submit_btn = gr.Button("🚀 Submit (发送)")
        regen_btn = gr.Button("🤔️ Regenerate (重试)")
        empty_bin = gr.Button("🧹 Clear History (清除历史)")

    # Submit: record the new turn first, then stream the reply into the chat.
    submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
        predict, [chatbot, task_history], [chatbot], show_progress=True
    )
    # A second listener on the same click clears the input box.
    submit_btn.click(reset_user_input, [], [query])
    empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
    regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

if __name__ == "__main__":
    demo.launch()