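"""Gradio app that analyzes Instagram short videos with LLaVA-Video-7B-Qwen2."""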
import gradio as gr
import torch
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
import copy
from decord import VideoReader, cpu
import numpy as np

title = "# 🎥 Instagram Short Video Analyzer with LLaVA-Video"
description = """
This application uses the LLaVA-Video-7B-Qwen2 model to analyze Instagram short videos.
Upload your Instagram short video and ask questions about its content!
"""

def load_video(video_path, max_frames_num=64, fps=1):
    """Sample frames from a video at roughly `fps` frames per second,
    capped at `max_frames_num` uniformly spaced frames for longer clips."""
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frame_num = len(vr)
    video_time = total_frame_num / vr.get_avg_fps()
    # Step between kept frames so we retain roughly `fps` frames per second.
    step = round(vr.get_avg_fps() / fps)
    frame_idx = list(range(0, total_frame_num, step))
    # If the clip is long enough to exceed the cap, resample uniformly across it.
    if len(frame_idx) > max_frames_num:
        frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
    frame_time = ",".join([f"{i / vr.get_avg_fps():.2f}s" for i in frame_idx])
    sampled_frames = vr.get_batch(frame_idx).asnumpy()
    return sampled_frames, frame_time, video_time

# Load the model
pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
model_name = "llava_qwen"
device = "cuda" if torch.cuda.is_available() else "cpu"
device_map = "auto"

print("Loading model...")
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)
model.eval()
print("Model loaded successfully!")

def process_instagram_short(video_path, question):
    max_frames_num = 64
    # Sample frames and convert them to the model's expected pixel format.
    video, frame_time, video_time = load_video(video_path, max_frames_num)
    video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].to(device).bfloat16()
    video = [video]

    # Describe the sampling so the model can reason about timing within the clip.
    time_instruction = f"This is an Instagram short video lasting {video_time:.2f} seconds. {len(video[0])} frames were sampled at {frame_time}. Analyze this short video and answer the following question:"

    full_question = DEFAULT_IMAGE_TOKEN + f"\n{time_instruction}\n{question}"
    
    # Build the prompt with the Qwen-1.5 conversation template used by LLaVA-Video.
    conv = copy.deepcopy(conv_templates["qwen_1_5"])
    conv.append_message(conv.roles[0], full_question)
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()

    # Tokenize, mapping DEFAULT_IMAGE_TOKEN to the IMAGE_TOKEN_INDEX placeholder.
    input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
    
    with torch.no_grad():
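        # Greedy decoding: with do_sample=False, the temperature setting has no effect.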
        output = model.generate(
            input_ids,
            images=video,
            modalities=["video"],
            do_sample=False,
            temperature=0,
            max_new_tokens=4096,
        )
    
    response = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
    return response

def gradio_interface(video_file, question):
    if video_file is None:
        return "Please upload an Instagram short video."
    if not question:
        return "Please enter a question about the video."
    return process_instagram_short(video_file, question)

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Upload Instagram Short Video")
            question_input = gr.Textbox(label="Ask a question about the video", placeholder="What's happening in this Instagram short?")
            submit_button = gr.Button("Analyze Short Video")
        output = gr.Textbox(label="Analysis Result")
    
    submit_button.click(
        fn=gradio_interface,
        inputs=[video_input, question_input],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch(show_error=True)