#!/usr/bin/env python
# encoding: utf-8
import spaces
import torch


@spaces.GPU
def debug():
    # Warm-up call to confirm a CUDA device is reachable inside the Spaces runtime.
    torch.randn(10).cuda()


debug()

import argparse
from transformers import AutoModel, AutoTokenizer
import gradio as gr
from PIL import Image
from decord import VideoReader, cpu
import io
import os

os.system("nvidia-smi")

import copy
import requests
import base64
import json
import traceback
import re
import modelscope_studio as mgr
from modelscope.hub.snapshot_download import snapshot_download

# Configuration
model_dir = snapshot_download('iic/mPLUG-Owl3-7B-240728', cache_dir='./')
device_map = "auto"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Argument parsing
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda, mps or cpu')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
args = parser.parse_args()
device = args.device

# Model loading
model = AutoModel.from_pretrained(
    model_dir,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if 'int4' not in model_dir else torch.float32,
    attn_implementation="sdpa"  # Use scaled dot-product attention instead of flash-attn
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model.eval()

# Constants
ERROR_MSG = "Error occurred, please check inputs and try again"
MAX_NUM_FRAMES = 64
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}


def get_file_extension(filename):
    return os.path.splitext(filename)[1].lower()


def is_image(filename):
    return get_file_extension(filename) in IMAGE_EXTENSIONS


def is_video(filename):
    return get_file_extension(filename) in VIDEO_EXTENSIONS


def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
    return mgr.MultimodalInput(
        upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
        upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
        submit_button_props={'label': 'Submit'}
    )


@spaces.GPU
def chat(images, messages, params):
    """Run a single model.chat call; returns (status_code, response, extra)."""
    try:
        response = model.chat(
            images=images,
            messages=messages,
            tokenizer=tokenizer,
            **params
        )
        return 0, response, None
    except Exception as e:
        print(f"Error in chat: {str(e)}")
        traceback.print_exc()
        return -1, ERROR_MSG, None


def encode_image(image):
    """Load an uploaded image and downscale it if its longer side exceeds 448 * 16 px."""
    try:
        if not isinstance(image, Image.Image):
            image = Image.open(image.file.path).convert("RGB")
        max_size = 448 * 16
        if max(image.size) > max_size:
            ratio = max_size / max(image.size)
            new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
            image = image.resize(new_size, Image.BICUBIC)
        return image
    except Exception as e:
        raise gr.Error(f"Image processing error: {str(e)}")


def encode_video(video):
    """Sample roughly one frame per second from the video, capped at MAX_NUM_FRAMES."""
    try:
        vr = VideoReader(video.file.path, ctx=cpu(0))
        sample_fps = round(vr.get_avg_fps() / 1)  # frame stride of about one second
        frame_idx = [i for i in range(0, len(vr), sample_fps)]
        if len(frame_idx) > MAX_NUM_FRAMES:
            frame_idx = frame_idx[:MAX_NUM_FRAMES]
        frames = vr.get_batch(frame_idx).asnumpy()
        return [Image.fromarray(frame.astype('uint8')) for frame in frames]
    except Exception as e:
        raise gr.Error(f"Video processing error: {str(e)}")


def process_inputs(_question, _app_cfg):
    """Split the prompt on [mm_media]N[/mm_media] placeholders and interleave the uploaded files."""
    try:
        files = _question.files
        text = _question.text
        pattern = r"\[mm_media\]\d+\[/mm_media\]"
        matches = re.split(pattern, text)
        if len(matches) != len(files) + 1:
            raise gr.Error("Media placeholders don't match uploaded files count")
        message = []
        media_count = 0
        for i, match in enumerate(matches):
            if match.strip():
                message.append({"type": "text", "content": match.strip()})
            if i < len(files):
                file = files[i]
                if is_image(file.file.path):
                    message.append({"type": "image", "content": encode_image(file)})
                elif is_video(file.file.path):
                    message.append({"type": "video", "content": encode_video(file)})
                media_count += 1
        return message, media_count
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Input processing failed: {str(e)}")


def generate_response(_question, _chat_history, _app_cfg, params_form):
    try:
        params = {
            'max_new_tokens': 2048,
            'temperature': 0.7 if params_form == 'Sampling' else 1.0,
            'top_p': 0.8 if params_form == 'Sampling' else None,
            'num_beams': 3 if params_form == 'Beam Search' else 1,
            'repetition_penalty': 1.1
        }
        processed_input, media_count = process_inputs(_question, _app_cfg)
        _app_cfg['media_count'] += media_count
        code, response, _ = chat(
            images=[item['content'] for item in processed_input if item['type'] == 'image'],
            messages=[{"role": "user", "content": processed_input}],
            params=params
        )
        if code != 0:
            raise gr.Error("Model response generation failed")
        _chat_history.append((_question, response))
        return _chat_history, _app_cfg
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}")


def reset_chat():
    return [], {'media_count': 0, 'ctx': []}


with gr.Blocks(css="video {height: auto !important;}") as demo:
    with gr.Tab("mPLUG-Owl3"):
        gr.Markdown("## mPLUG-Owl3 Multi-Modal Chat Interface")

        # State management
        app_state = gr.State({'media_count': 0, 'ctx': []})

        # Chat interface
        chatbot = mgr.Chatbot(height=600)
        input_interface = create_multimodal_input()

        # Controls
        with gr.Row():
            decode_type = gr.Radio(
                choices=['Beam Search', 'Sampling'],
                value='Sampling',
                label="Decoding Strategy"
            )
            clear_btn = gr.Button("Clear History")
            regenerate_btn = gr.Button("Regenerate")

        # Event handlers
        input_interface.submit(
            generate_response,
            [input_interface, chatbot, app_state, decode_type],
            [chatbot, app_state]
        )
        clear_btn.click(
            reset_chat,
            outputs=[chatbot, app_state]
        )
        regenerate_btn.click(
            lambda history: history[:-1] if history else [],
            inputs=[chatbot],
            outputs=[chatbot]
        )

if __name__ == "__main__":
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=False,
        debug=True
    )
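# Example launch command, using the flags defined by the argparser above. The
# filename app.py is an assumption (this script is not named in the source):
#
#   python app.py --device cuda --host 0.0.0.0 --port 7860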