#!/usr/bin/env python
# encoding: utf-8
import spaces
import torch
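
# Touch CUDA once inside a @spaces.GPU call as a sanity check that ZeroGPU can
# allocate a device for this Space before the demo starts serving requests.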
@spaces.GPU
def debug():
    torch.randn(10).cuda()

debug()
import argparse
from transformers import AutoModel, AutoTokenizer
import gradio as gr
from PIL import Image
from decord import VideoReader, cpu
import io
import os
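# Print GPU information to the Space logs for debugging.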
os.system("nvidia-smi")
import copy
import requests
import base64
import json
import traceback
import re
import modelscope_studio as mgr
from modelscope.hub.snapshot_download import snapshot_download
# Configuration
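# Download the mPLUG-Owl3-7B checkpoint from ModelScope into the local cache directory.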
model_dir = snapshot_download('iic/mPLUG-Owl3-7B-240728', cache_dir='./')
device_map = "auto"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Argparser
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda, mps or cpu')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
args = parser.parse_args()
device = args.device
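
# Load model and tokenizer. SDPA attention avoids the flash-attn dependency; bfloat16 is
# used unless the checkpoint path indicates an int4-quantized model.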
model = AutoModel.from_pretrained(
    model_dir,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if 'int4' not in model_dir else torch.float32,
    attn_implementation="sdpa"  # Use scaled dot-product attention instead of flash-attn
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model.eval()
# Constants
ERROR_MSG = "Error occurred, please check inputs and try again"
MAX_NUM_FRAMES = 64
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
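
# File-type helpers: classify uploads by their extension.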
def get_file_extension(filename):
    return os.path.splitext(filename)[1].lower()

def is_image(filename):
    return get_file_extension(filename) in IMAGE_EXTENSIONS

def is_video(filename):
    return get_file_extension(filename) in VIDEO_EXTENSIONS
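
# Build the modelscope_studio multimodal input box (text plus image/video upload buttons).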
def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
    return mgr.MultimodalInput(
        upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
        upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
        submit_button_props={'label': 'Submit'}
    )
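
# Run a single model.chat() call on the GPU; returns 0 and the response on success,
# or -1 and ERROR_MSG on failure.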
@spaces.GPU
def chat(images, messages, params):
    try:
        response = model.chat(
            images=images,
            messages=messages,
            tokenizer=tokenizer,
            **params
        )
        return 0, response, None
    except Exception as e:
        print(f"Error in chat: {str(e)}")
        traceback.print_exc()
        return -1, ERROR_MSG, None
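
# Open an uploaded image as RGB and downscale it so its longest side is at most 448 * 16 pixels.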
def encode_image(image):
    try:
        if not isinstance(image, Image.Image):
            image = Image.open(image.file.path).convert("RGB")
        max_size = 448 * 16
        if max(image.size) > max_size:
            ratio = max_size / max(image.size)
            new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
            image = image.resize(new_size, Image.BICUBIC)
        return image
    except Exception as e:
        raise gr.Error(f"Image processing error: {str(e)}")
def encode_video(video):
    try:
        vr = VideoReader(video.file.path, ctx=cpu(0))
        sample_fps = round(vr.get_avg_fps())
        frame_idx = list(range(0, len(vr), sample_fps))
        if len(frame_idx) > MAX_NUM_FRAMES:
            frame_idx = frame_idx[:MAX_NUM_FRAMES]
        frames = vr.get_batch(frame_idx).asnumpy()
        return [Image.fromarray(frame.astype('uint8')) for frame in frames]
    except Exception as e:
        raise gr.Error(f"Video processing error: {str(e)}")
def process_inputs(_question, _app_cfg):
    try:
        files = _question.files
        text = _question.text
        pattern = r"\[mm_media\]\d+\[/mm_media\]"
        matches = re.split(pattern, text)
        if len(matches) != len(files) + 1:
            raise gr.Error("Media placeholders don't match uploaded files count")
        message = []
        media_count = 0
        for i, match in enumerate(matches):
            if match.strip():
                message.append({"type": "text", "content": match.strip()})
            if i < len(files):
                file = files[i]
                if is_image(file.file.path):
                    message.append({"type": "image", "content": encode_image(file)})
                elif is_video(file.file.path):
                    message.append({"type": "video", "content": encode_video(file)})
                media_count += 1
        return message, media_count
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Input processing failed: {str(e)}")
def generate_response(_question, _chat_history, _app_cfg, params_form):
    try:
        params = {
            'max_new_tokens': 2048,
            'temperature': 0.7 if params_form == 'Sampling' else 1.0,
            'top_p': 0.8 if params_form == 'Sampling' else None,
            'num_beams': 3 if params_form == 'Beam Search' else 1,
            'repetition_penalty': 1.1
        }
        processed_input, media_count = process_inputs(_question, _app_cfg)
        _app_cfg['media_count'] += media_count
        code, response, _ = chat(
            images=[item['content'] for item in processed_input if item['type'] == 'image'],
            messages=[{"role": "user", "content": processed_input}],
            params=params
        )
        if code != 0:
            raise gr.Error("Model response generation failed")
        _chat_history.append((_question, response))
        return _chat_history, _app_cfg
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}")
def reset_chat():
    return [], {'media_count': 0, 'ctx': []}
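
# Gradio UI: one tab with the chatbot, the multimodal input box, the decoding-strategy
# selector, and clear/regenerate controls.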
with gr.Blocks(css="video {height: auto !important;}") as demo:
    with gr.Tab("mPLUG-Owl3"):
        gr.Markdown("## mPLUG-Owl3 Multi-Modal Chat Interface")

        # State management
        app_state = gr.State({'media_count': 0, 'ctx': []})

        # Chat interface
        chatbot = mgr.Chatbot(height=600)
        input_interface = create_multimodal_input()

        # Controls
        with gr.Row():
            decode_type = gr.Radio(
                choices=['Beam Search', 'Sampling'],
                value='Sampling',
                label="Decoding Strategy"
            )
            clear_btn = gr.Button("Clear History")
            regenerate_btn = gr.Button("Regenerate")

        # Event handlers
        input_interface.submit(
            generate_response,
            [input_interface, chatbot, app_state, decode_type],
            [chatbot, app_state]
        )
        clear_btn.click(
            reset_chat,
            outputs=[chatbot, app_state]
        )
        regenerate_btn.click(
            lambda history: history[:-1] if history else [],
            inputs=[chatbot],
            outputs=[chatbot]
        )
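
# Launch the demo on the host and port given on the command line.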
if __name__ == "__main__":
demo.launch(
server_name=args.host,
server_port=args.port,
share=False,
debug=True
) |