VideoAnalyzer / app.py
Zeph27's picture
update
5bf65b0
raw
history blame
2.42 kB
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
from PIL import Image
from decord import VideoReader, cpu
import base64
import io
# Identifier of the pretrained checkpoint used for both model and tokenizer.
model_path = 'openbmb/MiniCPM-V-2_6'

# Upper bound on how many frames are taken from a single video.
MAX_NUM_FRAMES = 64

# Load the weights in bfloat16 and move them to the GPU in one step.
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device='cuda')
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Evaluation mode: the app only runs inference.
model.eval()
def encode_image(image, max_size=448 * 16):
    """Return *image* as a PIL image whose longer side is at most *max_size*.

    Args:
        image: A ``PIL.Image.Image``, or anything ``Image.open`` accepts; in
            the latter case it is opened and converted to RGB.
        max_size: Upper bound, in pixels, for the longer side. Defaults to
            ``448 * 16``.

    Returns:
        The image itself when it already fits, otherwise a bicubic-resized
        copy whose longer side equals ``max_size`` and whose shorter side is
        scaled by the same factor (truncated to an integer).
    """
    if not isinstance(image, Image.Image):
        image = Image.open(image).convert("RGB")

    width, height = image.size
    if max(width, height) <= max_size:
        return image

    # Clamp the shorter side to 1 px: for extreme aspect ratios the integer
    # truncation would otherwise yield 0, and resize() rejects empty sizes.
    if width > height:
        new_w = max_size
        new_h = max(1, int(height * max_size / width))
    else:
        new_h = max_size
        new_w = max(1, int(width * max_size / height))
    return image.resize((new_w, new_h), resample=Image.BICUBIC)
def encode_video(video_path):
    """Sample frames from a video and return them as PIL images.

    Args:
        video_path: Location of the video, handed to ``VideoReader``.

    Returns:
        A list of at most ``MAX_NUM_FRAMES`` PIL images, each passed through
        ``encode_image``.
    """
    vr = VideoReader(video_path, ctx=cpu(0))

    # Stepping by the rounded average FPS picks about one frame per second.
    # Clamp to 1: an average FPS below 0.5 rounds to 0, and a zero step
    # would make range() raise ValueError.
    step = max(1, round(vr.get_avg_fps()))

    # NOTE(review): only the first MAX_NUM_FRAMES sampled frames are kept, so
    # videos longer than that are analyzed from their beginning only.
    frame_idx = list(range(0, len(vr), step)[:MAX_NUM_FRAMES])

    frames = vr.get_batch(frame_idx).asnumpy()
    return [encode_image(Image.fromarray(f.astype('uint8'))) for f in frames]
def analyze_video(prompt, video):
    """Run the model on a prompt plus sampled video frames.

    Args:
        prompt: Text placed before the frames in the user message.
        video: The uploaded video, either as a file path string (the form in
            which ``gr.Video`` hands uploads to a callback) or as an object
            exposing the path through ``.name``.

    Returns:
        Whatever ``model.chat`` returns for the prompt and frames.

    Raises:
        ValueError: If no video was provided.
    """
    if not video:
        raise ValueError("No video provided; please upload a video first.")

    # Accept both a plain path and a file-like wrapper carrying the path.
    video_path = video if isinstance(video, str) else video.name
    encoded_video = encode_video(video_path)

    # Single user turn: the prompt followed by every sampled frame.
    context = [
        {"role": "user", "content": [prompt] + encoded_video}
    ]
    params = {
        'sampling': True,
        'top_p': 0.8,
        'top_k': 100,
        'temperature': 0.7,
        'repetition_penalty': 1.05,
        "max_new_tokens": 2048,
        "max_inp_length": 4352,
        "use_image_id": False,
        # Use fewer slices per frame once there are more than 16 frames.
        "max_slice_nums": 1 if len(encoded_video) > 16 else 2,
    }
    return model.chat(image=None, msgs=context, tokenizer=tokenizer, **params)
# Page layout: prompt and video inputs on the left, result on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Video Analyzer")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt")
            video_input = gr.Video(label="Upload Video")
        with gr.Column():
            output = gr.Textbox(label="Analysis Result")

    analyze_button = gr.Button("Analyze Video")

    # A click sends the prompt and video to analyze_video and writes its
    # return value into the result textbox.
    analyze_button.click(
        fn=analyze_video,
        inputs=[prompt_input, video_input],
        outputs=output,
    )

demo.launch()