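"""Gradio demo: analyze Instagram short videos with LLaVA-Video-7B-Qwen2.

Frames are sampled from the uploaded clip, preprocessed, and passed to the model
together with the user's question; the model's answer is displayed in the UI.
"""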
import copy

import gradio as gr
import numpy as np
import torch
from decord import VideoReader, cpu

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import tokenizer_image_token
from llava.model.builder import load_pretrained_model

title = "# 🎥 Instagram Short Video Analyzer with LLaVA-Video"
description = """
This application uses the LLaVA-Video-7B-Qwen2 model to analyze Instagram short videos.
Upload a short video and ask a question about its content!
"""

def load_video(video_path, max_frames_num=64, fps=1):
    """Sample frames at roughly `fps` frames/second, capped at max_frames_num."""
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frame_num = len(vr)
    video_time = total_frame_num / vr.get_avg_fps()
    # Take one frame every `stride` native frames to approximate the target fps.
    stride = round(vr.get_avg_fps() / fps)
    frame_idx = list(range(0, total_frame_num, stride))
    # For long clips, fall back to max_frames_num evenly spaced frames.
    if len(frame_idx) > max_frames_num:
        frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
    frame_time = ",".join(f"{i / vr.get_avg_fps():.2f}s" for i in frame_idx)
    frames = vr.get_batch(frame_idx).asnumpy()  # shape: (num_frames, H, W, 3)
    return frames, frame_time, video_time
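
# Worked example (hypothetical clip): a 30 s video at 30 fps has 900 frames;
# with fps=1 the stride is round(30 / 1) = 30, giving indices [0, 30, ..., 870]
# (30 frames, under the 64-frame cap). A 120 s clip would yield 120 candidates,
# so np.linspace would pick 64 evenly spaced indices instead.
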
# Load the model once at startup.
pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
model_name = "llava_qwen"
device = "cuda" if torch.cuda.is_available() else "cpu"
device_map = "auto"

print("Loading model...")
tokenizer, model, image_processor, max_length = load_pretrained_model(
    pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map
)
model.eval()
print("Model loaded successfully!")

def process_instagram_short(video_path, question):
    """Run LLaVA-Video on the sampled frames and answer `question`."""
    max_frames_num = 64
    video, frame_time, video_time = load_video(video_path, max_frames_num)
    # Preprocess frames to pixel values; the model expects a list of video tensors.
    video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].to(device).bfloat16()
    video = [video]
    time_instruction = (
        f"This is an Instagram short video lasting {video_time:.2f} seconds. "
        f"{len(video[0])} frames were sampled at {frame_time}. "
        "Analyze this short video and answer the following question:"
    )
    full_question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\n{question}"
    # Build the prompt with the Qwen-1.5 conversation template.
    conv = copy.deepcopy(conv_templates["qwen_1_5"])
    conv.append_message(conv.roles[0], full_question)
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()
    input_ids = (
        tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(device)
    )
    with torch.no_grad():
        output = model.generate(
            input_ids,
            images=video,
            modalities=["video"],
            do_sample=False,  # greedy decoding; temperature is ignored when do_sample=False
            temperature=0,
            max_new_tokens=4096,
        )
    return tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
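
# Example usage (hypothetical path), assuming the model has loaded:
#   process_instagram_short("reel.mp4", "What product is shown?")
#   -> the model's answer as a plain string
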
def gradio_interface(video_file, question):
    if video_file is None:
        return "Please upload an Instagram short video."
    return process_instagram_short(video_file, question)

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Upload Instagram Short Video")
            question_input = gr.Textbox(
                label="Ask a question about the video",
                placeholder="What's happening in this Instagram short?",
            )
            submit_button = gr.Button("Analyze Short Video")
            output = gr.Textbox(label="Analysis Result")
    submit_button.click(
        fn=gradio_interface,
        inputs=[video_input, question_input],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch(show_error=True)