import argparse
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM

from serve.frontend import reload_javascript
from serve.utils import configure_logger
from serve.gradio_utils import (
    cancel_outputing,
    delete_last_conversation,
    reset_state,
    reset_textbox,
    transfer_input,
    wrap_gen_fn,
)
from serve.chat_utils import compress_video_to_base64
from serve.examples import get_examples

# TITLE is rendered with gr.HTML; the original markup around the heading was
# lost, so a plain <h1> is assumed here.
TITLE = """<h1>Chat with Video-XL-2</h1>"""
DESCRIPTION_TOP = """Video-XL-2, a better, faster, higher-frame-count model for long video understanding."""
DESCRIPTION = """"""
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DEPLOY_MODELS = dict()
logger = configure_logger()
DEFAULT_IMAGE_TOKEN = ""


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="Video-XL-2")
    parser.add_argument(
        "--local-path",
        type=str,
        help="huggingface ckpt, optional",
    )
    parser.add_argument("--ip", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    return parser.parse_args()


def fetch_model(model_name: str):
    """Load the model and tokenizer for `model_name`, caching them in DEPLOY_MODELS."""
    global DEPLOY_MODELS
    if args.local_path:
        local_model_path = args.local_path
    else:
        local_model_path = "BAAI/Video-XL-2"
    if model_name in DEPLOY_MODELS:
        model_info = DEPLOY_MODELS[model_name]
        print(f"{model_name} has been loaded.")
    else:
        print(f"{model_name} is loading...")
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            local_model_path,
            trust_remote_code=True,
            device_map=device,
            quantization_config=None,
            attn_implementation="sdpa",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        DEPLOY_MODELS[model_name] = (model, tokenizer)
        print(f"Loaded {model_name} successfully.")
        model_info = DEPLOY_MODELS[model_name]
    return model_info


def preview_images(files) -> list[str]:
    """Return the paths of uploaded files for the gallery preview."""
    if files is None:
        return []
    return [file.name for file in files]
@wrap_gen_fn
def predict(
    text,
    images,
    chatbot,
    history,
    top_p,
    temperature,
    max_generate_length,
    max_context_length_tokens,
    video_nframes,
    chunk_size: int = 512,
):
    """
    Predict the response for the input text and images.

    Args:
        text (str): The input text.
        images (list[PIL.Image.Image]): The input images.
        chatbot (list): The chatbot.
        history (dict): The conversation state.
        top_p (float): The top-p value.
        temperature (float): The temperature value.
        max_generate_length (int): The max number of tokens to generate.
        max_context_length_tokens (int): The max context length in tokens.
        video_nframes (int): The max number of video frames to sample.
        chunk_size (int): The chunk size.
    """
    if images is None:
        pil_images = history.get("video_path")
    else:
        pil_images = images[0].name
    print("running the prediction function")
    try:
        logger.info("fetching model")
        model, tokenizer = fetch_model(args.model)
        logger.info("model fetched")

        if text == "":
            yield chatbot, history, "Empty context."
            return
    except KeyError:
        logger.info("no model found")
        yield [[text, "No Model Found"]], [], "No Model Found"
        return

    gen_kwargs = {
        "do_sample": True if temperature > 1e-2 else False,
        "temperature": temperature,
        "top_p": top_p,
        "num_beams": 1,
        "use_cache": True,
        "max_new_tokens": max_generate_length,
    }

    # Check whether this is the very first turn with an image/video.
    is_first_image_turn = len(history) == 0 and pil_images
    if is_first_image_turn:
        history["video_path"] = pil_images
        history["context"] = None

    response, temp_history = model.chat(
        history.get("video_path", pil_images),
        tokenizer,
        text,
        chat_history=history.get("context"),
        return_history=True,
        max_num_frames=video_nframes,
        sample_fps=None,
        max_sample_fps=None,
        generation_config=gen_kwargs,
    )

    text_for_history = text
    if is_first_image_turn:
        # Embed the uploaded video in the chat entry so the chatbot renders it.
        # NOTE: the original tag was lost; an inline base64 <video> element is
        # assumed here.
        b64 = compress_video_to_base64(history.get("video_path", pil_images))
        media_str = f'<video controls src="data:video/mp4;base64,{b64}"></video>'
        text_for_history = media_str + text_for_history
    chatbot.append([text_for_history, response])
    history["context"] = temp_history
    logger.info("flushed result to gradio")
    print(
        f"temperature: {temperature}, "
        f"top_p: {top_p}, "
        f"max_generate_length: {max_generate_length}"
    )
    yield chatbot, history, "Generate: Success"


def retry(
    text,  # current text-box content, not the last user input
    images,
    chatbot,
    full_history,  # the full conversation state dict
    top_p,
    temperature,
    max_generate_length,
    max_context_length_tokens,
    video_nframes,
    chunk_size: int = 512,
):
    """Regenerate the response for the last user input."""
    history = full_history.get("context") or []
    if len(history) < 2:
        yield chatbot, full_history, "Empty context"
        return

    # Grab the last user input before popping: history[-1] is the assistant
    # reply, history[-2] the user message that produced it.
    last_user_input = history[-2]["content"]

    # Remove the last turn from the chatbot (one entry per turn) and from the
    # history (two messages per turn: user and assistant).
    chatbot.pop()
    history.pop()  # assistant message
    history.pop()  # user message
    full_history["context"] = history

    # Re-run predict with the last user input and the trimmed history.
    yield from predict(
        last_user_input,
        images,  # images are unchanged from the last turn
        chatbot,
        full_history,
        top_p,
        temperature,
        max_generate_length,
        max_context_length_tokens,
        video_nframes,
        chunk_size,
    )
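# `model.chat` is provided by the checkpoint's remote code (loaded with
# trust_remote_code=True). The interface assumed by this file, inferred from
# the calls in `predict` above rather than from official documentation:
#
#   response, new_history = model.chat(
#       video_path,              # path to the uploaded video/image
#       tokenizer,
#       text,                    # current user message
#       chat_history=...,        # previous turns, or None for a fresh chat
#       return_history=True,     # also return the updated history
#       max_num_frames=...,      # cap on the number of sampled frames
#       sample_fps=None,
#       max_sample_fps=None,
#       generation_config=...,   # HF-style generation kwargs
#   )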
def build_demo(args: argparse.Namespace) -> gr.Blocks:
    with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
        history = gr.State(dict())
        input_text = gr.State()
        input_images = gr.State()

        with gr.Row():
            gr.HTML(TITLE)
            status_display = gr.Markdown("Success", elem_id="status_display")
        gr.Markdown(DESCRIPTION_TOP)

        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="Video-XL-2_Demo-chatbot",
                        show_share_button=True,
                        bubble_full_width=False,
                        height=600,
                    )
                with gr.Row():
                    with gr.Column(scale=4):
                        text_box = gr.Textbox(show_label=False, placeholder="Enter text", container=False)
                    with gr.Column(min_width=70):
                        submit_btn = gr.Button("Send")
                    with gr.Column(min_width=70):
                        cancel_btn = gr.Button("Stop")
                with gr.Row():
                    empty_btn = gr.Button("🧹 New Conversation")
                    retry_btn = gr.Button("🔄 Regenerate")
                    del_last_btn = gr.Button("🗑️ Remove Last Turn")

            with gr.Column():
                gr.Markdown("Note: you can upload images or videos!")
                upload_images = gr.Files(file_types=["image", "video"], show_label=True)
                gallery = gr.Gallery(columns=[3], height="200px", show_label=True)
                upload_images.change(preview_images, inputs=upload_images, outputs=gallery)

                # Parameter-setting tab for controlling generation.
                with gr.Tab(label="Parameter Setting"):
                    top_p = gr.Slider(
                        minimum=0, maximum=1.0, value=0.001, step=0.05, interactive=True, label="Top-p"
                    )
                    temperature = gr.Slider(
                        minimum=0, maximum=1.0, value=0.01, step=0.1, interactive=True, label="Temperature"
                    )
                    max_generate_length = gr.Slider(
                        minimum=512, maximum=8192, value=4096, step=64, interactive=True, label="Max Generate Length"
                    )
                    max_context_length_tokens = gr.Slider(
                        minimum=512, maximum=65536, value=16384, step=64,
                        interactive=True, label="Max Context Length Tokens"
                    )
                    video_nframes = gr.Slider(
                        minimum=1, maximum=128, value=128, step=1, interactive=True, label="Video Nframes"
                    )

        show_images = gr.HTML(visible=False)

        gr.Markdown(
            "This demo is based on `moonshotai/Kimi-VL-A3B-Thinking` & `deepseek-ai/deepseek-vl2-small` "
            "and extends them by adding support for video input."
        )
        gr.Examples(
            examples=get_examples(ROOT_DIR),
            inputs=[upload_images, show_images, text_box],
        )
        gr.Markdown()

        input_widgets = [
            input_text,
            input_images,
            chatbot,
            history,
            top_p,
            temperature,
            max_generate_length,
            max_context_length_tokens,
            video_nframes,
        ]
        output_widgets = [chatbot, history, status_display]

        transfer_input_args = dict(
            fn=transfer_input,
            inputs=[text_box, upload_images],
            outputs=[input_text, input_images, text_box, upload_images, submit_btn],
            show_progress=True,
        )
        predict_args = dict(fn=predict, inputs=input_widgets, outputs=output_widgets, show_progress=True)
        retry_args = dict(fn=retry, inputs=input_widgets, outputs=output_widgets, show_progress=True)
        reset_args = dict(fn=reset_textbox, inputs=[], outputs=[text_box, status_display])

        predict_events = [
            text_box.submit(**transfer_input_args).then(**predict_args),
            submit_btn.click(**transfer_input_args).then(**predict_args),
        ]

        empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
        empty_btn.click(**reset_args)
        retry_btn.click(**retry_args)
        del_last_btn.click(delete_last_conversation, [chatbot, history], output_widgets, show_progress=True)
        cancel_btn.click(cancel_outputing, [], [status_display], cancels=predict_events)

    demo.title = "Video-XL-2_Demo Chatbot"
    return demo


def main(args: argparse.Namespace):
    demo = build_demo(args)
    reload_javascript()

    # concurrency_count=CONCURRENT_COUNT, max_size=MAX_EVENTS
    favicon_path = os.path.join(ROOT_DIR, "serve", "assets", "favicon.ico")
    demo.queue().launch(
        favicon_path=favicon_path if os.path.exists(favicon_path) else None,
        server_name=args.ip,
        server_port=args.port,
    )


if __name__ == "__main__":
    args = parse_args()
    main(args)
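# Usage sketch (the entry-point filename below is an assumption; the flags are
# the ones defined in parse_args):
#
#   python app.py                              # load BAAI/Video-XL-2 from the Hub
#   python app.py --local-path /path/to/ckpt   # use a local checkpoint instead
#   python app.py --ip 127.0.0.1 --port 8080   # custom bind address and port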