import argparse
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from serve.chat_utils import compress_video_to_base64
from serve.examples import get_examples
from serve.frontend import reload_javascript
from serve.gradio_utils import (
    cancel_outputing,
    delete_last_conversation,
    reset_state,
    reset_textbox,
    transfer_input,
    wrap_gen_fn,
)
from serve.utils import configure_logger
TITLE = """
Chat with Video-XL-2
"""
DESCRIPTION_TOP = """Video-XL-2, a better, faster, and high-frame-count model for long video understanding."""
DESCRIPTION = """"""
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DEPLOY_MODELS = {}
logger = configure_logger()
DEFAULT_IMAGE_TOKEN = "<image>"  # placeholder token for image inputs
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="Video-XL-2")
    parser.add_argument(
        "--local-path",
        type=str,
        help="optional path to a local checkpoint; defaults to the BAAI/Video-XL-2 Hugging Face repo",
    )
parser.add_argument("--ip", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
return parser.parse_args()
def fetch_model(model_name: str):
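    """Return the cached (model, tokenizer) pair for model_name, loading it on first use."""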
global DEPLOY_MODELS
    local_model_path = args.local_path if args.local_path else "BAAI/Video-XL-2"
if model_name in DEPLOY_MODELS:
model_info = DEPLOY_MODELS[model_name]
print(f"{model_name} has been loaded.")
else:
print(f"{model_name} is loading...")
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
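        # Load tokenizer and model through the repo's custom code (trust_remote_code=True).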
tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
local_model_path,
trust_remote_code=True,
device_map=device,
quantization_config=None,
attn_implementation="sdpa",
torch_dtype=torch.float16,
low_cpu_mem_usage=True
)
DEPLOY_MODELS[model_name] = (model, tokenizer)
print(f"Load {model_name} successfully...")
model_info = DEPLOY_MODELS[model_name]
return model_info
def preview_images(files) -> list[str]:
    if files is None:
        return []
    # Gradio's Files component yields tempfile wrappers; the gallery needs their paths.
    return [file.name for file in files]
@wrap_gen_fn
def predict(
text,
images,
chatbot,
history,
top_p,
temperature,
max_generate_length,
max_context_length_tokens,
video_nframes,
chunk_size: int = 512,
):
"""
Predict the response for the input text and images.
Args:
text (str): The input text.
images (list[PIL.Image.Image]): The input images.
chatbot (list): The chatbot.
history (list): The history.
top_p (float): The top-p value.
temperature (float): The temperature value.
repetition_penalty (float): The repetition penalty value.
max_generate_length (int): The max length tokens.
max_context_length_tokens (int): The max context length tokens.
chunk_size (int): The chunk size.
"""
    if images is None:
        pil_images = history.get("video_path")
    else:
        pil_images = images[0].name
print("running the prediction function")
try:
logger.info("fetching model")
model, tokenizer = fetch_model(args.model)
logger.info("model fetched")
if text == "":
yield chatbot, history, "Empty context."
return
except KeyError:
logger.info("no model found")
yield [[text, "No Model Found"]], [], "No Model Found"
return
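    # Build generation kwargs; sample only when the temperature is meaningfully above zero.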
gen_kwargs = {
"do_sample": True if temperature > 1e-2 else False,
"temperature": temperature,
"top_p": top_p,
"num_beams": 1,
"use_cache": True,
"max_new_tokens": max_generate_length,
}
# Check if this is the very first turn with an image
is_first_image_turn = (len(history) == 0 and pil_images)
if is_first_image_turn:
history["video_path"] = pil_images
history["context"] = None
response, temp_history = model.chat(
history["video_path"] if "video_path" in history else pil_images,
tokenizer,
text,
chat_history=history["context"],
return_history=True,
max_num_frames=video_nframes,
sample_fps=None,
max_sample_fps=None,
generation_config=gen_kwargs
)
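    # Build the chatbot-display version of the user message (with the video embedded on the first turn).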
text_for_history = text
    if is_first_image_turn:
        # Embed the uploaded video as a base64 data URI so it renders inline in the chat window.
        b64 = compress_video_to_base64(history.get("video_path", pil_images))
        media_str = f'<video src="data:video/mp4;base64,{b64}" controls style="max-width: 100%;"></video>'
        text_for_history = media_str + text_for_history
    chatbot.append([text_for_history, response])
history["context"] = (temp_history)
logger.info("flushed result to gradio")
print(
f"temperature: {temperature}, "
f"top_p: {top_p}, "
f"max_generate_length: {max_generate_length}"
)
yield chatbot, history, "Generate: Success"
def retry(
text, # This `text` is the current text box content, not the last user input
images,
chatbot,
full_history, # This is the full history
top_p,
temperature,
max_generate_length,
max_context_length_tokens,
video_nframes,
chunk_size: int = 512,
):
"""
Retry the response for the input text and images.
"""
    history = full_history.get("context") or []
    if len(history) == 0:
        yield (chatbot, history, "Empty context")
        return
    # The last two history entries are the user turn and the model reply: recover the
    # user input, then drop both so predict() can re-run the turn cleanly.
    last_user_input = history[-2]["content"]
    chatbot.pop()
    history.pop()  # model reply
    history.pop()  # user message
    full_history["context"] = history
# Now call predict with the last user input and the modified history
yield from predict(
last_user_input, # Pass the last user input as the current text
images, # Images should be the same as the last turn
chatbot, # Updated chatbot
full_history, # Updated history
top_p,
temperature,
max_generate_length,
max_context_length_tokens,
video_nframes,
chunk_size,
)
def build_demo(args: argparse.Namespace) -> gr.Blocks:
with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
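        # Cross-callback state: the conversation history dict plus the staged text/image inputs for the current turn.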
history = gr.State(dict())
input_text = gr.State()
input_images = gr.State()
with gr.Row():
gr.HTML(TITLE)
status_display = gr.Markdown("Success", elem_id="status_display")
gr.Markdown(DESCRIPTION_TOP)
with gr.Row(equal_height=True):
with gr.Column(scale=4):
with gr.Row():
chatbot = gr.Chatbot(
elem_id="Video-XL-2_Demo-chatbot",
show_share_button=True,
bubble_full_width=False,
height=600,
)
with gr.Row():
with gr.Column(scale=4):
text_box = gr.Textbox(show_label=False, placeholder="Enter text", container=False)
with gr.Column(min_width=70):
submit_btn = gr.Button("Send")
with gr.Column(min_width=70):
cancel_btn = gr.Button("Stop")
with gr.Row():
                    empty_btn = gr.Button("🧹 New Conversation")
                    retry_btn = gr.Button("🔄 Regenerate")
                    del_last_btn = gr.Button("🗑️ Remove Last Turn")
with gr.Column():
                # Upload panel: accepts images and videos; thumbnails preview in the gallery below.
                gr.Markdown("Note: you can upload images or videos!")
upload_images = gr.Files(file_types=["image", "video"], show_label=True)
gallery = gr.Gallery(columns=[3], height="200px", show_label=True)
upload_images.change(preview_images, inputs=upload_images, outputs=gallery)
# Parameter Setting Tab for control the generation parameters
with gr.Tab(label="Parameter Setting"):
                    top_p = gr.Slider(minimum=0, maximum=1.0, value=0.001, step=0.05, interactive=True, label="Top-p")
temperature = gr.Slider(
minimum=0, maximum=1.0, value=0.01, step=0.1, interactive=True, label="Temperature"
)
max_generate_length = gr.Slider(
minimum=512, maximum=8192, value=4096, step=64, interactive=True, label="Max Generate Length"
)
max_context_length_tokens = gr.Slider(
minimum=512, maximum=65536, value=16384, step=64, interactive=True, label="Max Context Length Tokens"
)
video_nframes = gr.Slider(
minimum=1, maximum=128, value=128, step=1, interactive=True, label="Video Nframes"
)
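        # Hidden HTML component used as an extra input slot by the examples below.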
show_images = gr.HTML(visible=False)
gr.Markdown("This demo is based on `moonshotai/Kimi-VL-A3B-Thinking` & `deepseek-ai/deepseek-vl2-small` and extends it by adding support for video input.")
gr.Examples(
examples=get_examples(ROOT_DIR),
inputs=[upload_images, show_images, text_box],
)
gr.Markdown()
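        # Inputs forwarded to predict()/retry(); the order must match the function signatures.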
input_widgets = [
input_text,
input_images,
chatbot,
history,
top_p,
temperature,
max_generate_length,
max_context_length_tokens,
video_nframes
]
output_widgets = [chatbot, history, status_display]
transfer_input_args = dict(
fn=transfer_input,
inputs=[text_box, upload_images],
outputs=[input_text, input_images, text_box, upload_images, submit_btn],
show_progress=True,
)
predict_args = dict(fn=predict, inputs=input_widgets, outputs=output_widgets, show_progress=True)
retry_args = dict(fn=retry, inputs=input_widgets, outputs=output_widgets, show_progress=True)
reset_args = dict(fn=reset_textbox, inputs=[], outputs=[text_box, status_display])
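        # Keep handles to the submit events so the Stop button can cancel in-flight generation.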
predict_events = [
text_box.submit(**transfer_input_args).then(**predict_args),
submit_btn.click(**transfer_input_args).then(**predict_args),
]
empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
empty_btn.click(**reset_args)
retry_btn.click(**retry_args)
del_last_btn.click(delete_last_conversation, [chatbot, history], output_widgets, show_progress=True)
cancel_btn.click(cancel_outputing, [], [status_display], cancels=predict_events)
demo.title = "Video-XL-2_Demo Chatbot"
return demo
def main(args: argparse.Namespace):
demo = build_demo(args)
reload_javascript()
    favicon_path = os.path.join("serve", "assets", "favicon.ico")
demo.queue().launch(
favicon_path=favicon_path if os.path.exists(favicon_path) else None,
server_name=args.ip,
server_port=args.port,
)
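# Entry point: parse CLI args, then build and launch the Gradio demo.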
if __name__ == "__main__":
args = parse_args()
main(args)