#!/usr/bin/env python
# encoding: utf-8
import spaces
import torch
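# Warm-up probe for the Hugging Face ZeroGPU runtime: the @spaces.GPU decorator
# requests a GPU for the call below, and moving a tensor to CUDA confirms it is usable.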
@spaces.GPU
def debug():
torch.randn(10).cuda()
debug()
import argparse
from transformers import AutoModel, AutoTokenizer
import gradio as gr
from PIL import Image
from decord import VideoReader, cpu
import io
import os
os.system("nvidia-smi")  # log GPU availability at startup
import copy
import requests
import base64
import json
import traceback
import re
import modelscope_studio as mgr
from modelscope.hub.snapshot_download import snapshot_download
# Configuration
model_dir = snapshot_download('iic/mPLUG-Owl3-7B-240728', cache_dir='./')
device_map = "auto"  # unused here; the model is placed explicitly with .to(device) below
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Argparser
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda, mps or cpu')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
args = parser.parse_args()
device = args.device
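# The model is loaded in bfloat16 (float32 for int4 checkpoints) and, on CUDA, with
# FlashAttention-2, which assumes the flash-attn package is available in the environment.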
# Load model and tokenizer
model_path = './iic/mPLUG-Owl3-7B-240728'
model = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16 if 'int4' not in model_path else torch.float32,
attn_implementation="flash_attention_2" if device == 'cuda' else None
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()
# Constants
ERROR_MSG = "Error occurred, please check inputs and try again"
MAX_NUM_FRAMES = 64
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
def get_file_extension(filename):
return os.path.splitext(filename)[1].lower()
def is_image(filename):
return get_file_extension(filename) in IMAGE_EXTENSIONS
def is_video(filename):
return get_file_extension(filename) in VIDEO_EXTENSIONS
def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
return mgr.MultimodalInput(
upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
submit_button_props={'label': 'Submit'}
)
@spaces.GPU
def chat(images, messages, params):
try:
response = model.chat(
images=images,
messages=messages,
tokenizer=tokenizer,
**params
)
return 0, response, None
except Exception as e:
print(f"Error in chat: {str(e)}")
traceback.print_exc()
return -1, ERROR_MSG, None
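# A minimal sketch of the call shape chat() receives from generate_response below,
# assuming the model's remote code accepts interleaved text/image/video content dicts:
#   chat(images=[pil_image],
#        messages=[{"role": "user", "content": [
#            {"type": "text", "content": "Describe this image."},
#            {"type": "image", "content": pil_image},
#        ]}],
#        params={"max_new_tokens": 2048})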
def encode_image(image):
try:
if not isinstance(image, Image.Image):
image = Image.open(image.file.path).convert("RGB")
max_size = 448 * 16
if max(image.size) > max_size:
ratio = max_size / max(image.size)
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
image = image.resize(new_size, Image.BICUBIC)
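            # e.g. a 10000x5000 px image exceeds max_size (448 * 16 = 7168), so it is
            # scaled by 7168/10000 down to 7168x3584 before being passed to the model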
return image
except Exception as e:
raise gr.Error(f"Image processing error: {str(e)}")
def encode_video(video):
try:
vr = VideoReader(video.file.path, ctx=cpu(0))
        # Sample roughly one frame per second; guard against clips reporting < 1 fps
        sample_fps = max(1, round(vr.get_avg_fps()))
        frame_idx = list(range(0, len(vr), sample_fps))
if len(frame_idx) > MAX_NUM_FRAMES:
frame_idx = frame_idx[:MAX_NUM_FRAMES]
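        # e.g. a 30 fps, 2-minute clip gives 120 one-per-second frames, truncated to the first 64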
frames = vr.get_batch(frame_idx).asnumpy()
return [Image.fromarray(frame.astype('uint8')) for frame in frames]
except Exception as e:
raise gr.Error(f"Video processing error: {str(e)}")
def process_inputs(_question, _app_cfg):
try:
files = _question.files
text = _question.text
pattern = r"\[mm_media\]\d+\[/mm_media\]"
matches = re.split(pattern, text)
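        # Example: the text "[mm_media]1[/mm_media] Describe this image" with one uploaded
        # file splits into ['', ' Describe this image'], so the file is interleaved
        # ahead of the remaining text in the loop below.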
if len(matches) != len(files) + 1:
raise gr.Error("Media placeholders don't match uploaded files count")
message = []
media_count = 0
for i, match in enumerate(matches):
if match.strip():
message.append({"type": "text", "content": match.strip()})
if i < len(files):
file = files[i]
if is_image(file.file.path):
message.append({"type": "image", "content": encode_image(file)})
elif is_video(file.file.path):
message.append({"type": "video", "content": encode_video(file)})
media_count += 1
return message, media_count
except Exception as e:
traceback.print_exc()
raise gr.Error(f"Input processing failed: {str(e)}")
def generate_response(_question, _chat_history, _app_cfg, params_form):
try:
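        # Decoding presets: 'Sampling' -> temperature 0.7 / top_p 0.8 with a single beam;
        # 'Beam Search' -> 3 beams with temperature 1.0 and no top_p.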
params = {
'max_new_tokens': 2048,
'temperature': 0.7 if params_form == 'Sampling' else 1.0,
'top_p': 0.8 if params_form == 'Sampling' else None,
'num_beams': 3 if params_form == 'Beam Search' else 1,
'repetition_penalty': 1.1
}
processed_input, media_count = process_inputs(_question, _app_cfg)
_app_cfg['media_count'] += media_count
code, response, _ = chat(
images=[item['content'] for item in processed_input if item['type'] == 'image'],
messages=[{"role": "user", "content": processed_input}],
params=params
)
if code != 0:
raise gr.Error("Model response generation failed")
_chat_history.append((_question, response))
return _chat_history, _app_cfg
except Exception as e:
traceback.print_exc()
raise gr.Error(f"Generation failed: {str(e)}")
def reset_chat():
return [], {'media_count': 0, 'ctx': []}
with gr.Blocks(css="video {height: auto !important;}") as demo:
with gr.Tab("mPLUG-Owl3"):
gr.Markdown("## mPLUG-Owl3 Multi-Modal Chat Interface")
# State management
app_state = gr.State({'media_count': 0, 'ctx': []})
# Chat interface
chatbot = mgr.Chatbot(height=600)
input_interface = create_multimodal_input()
# Controls
with gr.Row():
decode_type = gr.Radio(
choices=['Beam Search', 'Sampling'],
value='Sampling',
label="Decoding Strategy"
)
clear_btn = gr.Button("Clear History")
regenerate_btn = gr.Button("Regenerate")
# Event handlers
input_interface.submit(
generate_response,
[input_interface, chatbot, app_state, decode_type],
[chatbot, app_state]
)
clear_btn.click(
reset_chat,
outputs=[chatbot, app_state]
)
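        # Note: 'Regenerate' only drops the last exchange from the chat history;
        # resubmitting the input produces a fresh answer.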
regenerate_btn.click(
lambda history: history[:-1] if history else [],
inputs=[chatbot],
outputs=[chatbot]
)
if __name__ == "__main__":
demo.launch(
server_name=args.host,
server_port=args.port,
share=False,
debug=True
)