# Source: VideoAnalyzer / app.py (author: Zeph27; commit 213c599 "update flash"; 2.8 kB)
import gradio as gr
from transformers import AutoModel, AutoTokenizer
import torch
from decord import VideoReader, cpu
import os
import spaces
import subprocess
# Install flash-attn at start-up. FLASH_ATTENTION_SKIP_CUDA_BUILD is passed to
# the installer through the environment. The variable is merged into a copy of
# the current environment: passing a bare dict as `env` would *replace* the
# child's environment and drop PATH, HOME, proxy settings, etc., which can
# make the `pip` lookup or the install itself fail.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# Load the model and tokenizer once at import time so every request reuses
# them. `trust_remote_code=True` executes the modelling code shipped in the
# model repository.
model_name = "openbmb/MiniCPM-V-2_6-int4"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map="auto")
model.eval()

# Upper bound on the number of frames sampled from one video.
MAX_NUM_FRAMES = 64

# File extensions (lower-case, with leading dot) accepted as video uploads.
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
def get_file_extension(filename):
    """Return the lower-cased extension of *filename*, including the dot.

    An empty string is returned when the name has no extension.
    """
    _, extension = os.path.splitext(filename)
    return extension.lower()
def is_video(filename):
    """Tell whether *filename* ends in one of the accepted video extensions."""
    extension = get_file_extension(filename)
    return extension in VIDEO_EXTENSIONS
def encode_video(video):
    """Decode a video and return at most ``MAX_NUM_FRAMES`` of its frames.

    Args:
        video: Either a filesystem path (``str`` / ``os.PathLike``, which is
            what ``gr.Video`` passes to callbacks) or an object exposing the
            path as ``video.path`` or ``video.file.path``.

    Returns:
        The array produced by ``VideoReader.get_batch(...).asnumpy()`` for the
        selected frame indices (presumably shaped (frames, height, width,
        channels) -- confirm against the decord documentation).

    Raises:
        AttributeError: if *video* is not a path and exposes neither ``path``
            nor ``file.path``.
    """

    def uniform_sample(items, count):
        # Pick `count` items at evenly spaced positions, taking the middle
        # of each of the `count` equal-width segments.
        gap = len(items) / count
        return [items[int(i * gap + gap / 2)] for i in range(count)]

    # Resolve the on-disk location. Plain paths are checked first because a
    # str has neither `.path` nor `.file`, so the attribute lookups below
    # would raise AttributeError for the value gr.Video actually provides.
    if isinstance(video, (str, os.PathLike)):
        video_path = os.fspath(video)
    elif hasattr(video, 'path'):
        video_path = video.path
    else:
        video_path = video.file.path

    vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(vr)

    # Short clips keep every frame; longer ones are sampled uniformly.
    if total_frames <= MAX_NUM_FRAMES:
        frame_idxs = list(range(total_frames))
    else:
        frame_idxs = uniform_sample(range(total_frames), MAX_NUM_FRAMES)

    return vr.get_batch(frame_idxs).asnumpy()
@spaces.GPU
def analyze_video(video, prompt):
    """Generate a text analysis of an uploaded video, guided by *prompt*.

    Args:
        video: The uploaded video. ``gr.Video`` passes a filepath string (or
            ``None`` when nothing was uploaded); objects exposing the file
            name as ``.name`` are accepted as well.
        prompt: User text forwarded to the model's ``generate`` call.

    Returns:
        The decoded model output, or a message asking for a valid video when
        the input is missing or does not have a known video extension.
    """
    # Work out the file name without assuming a particular input type:
    # a str path has no `.name`, and an empty upload arrives as None.
    if isinstance(video, (str, os.PathLike)):
        filename = os.fspath(video)
    else:
        filename = getattr(video, "name", None)

    if not filename or not is_video(filename):
        return "Please upload a valid video file."

    frames = encode_video(video)

    # Prepare the frames for the model.
    # NOTE(review): the MiniCPM-V-2_6 model card documents video input via
    # `model.chat(image=None, msgs=..., tokenizer=...)` with PIL frames;
    # confirm that `model.vpm` + `model.generate(inputs=..., prompt=...)`
    # is a supported path for this checkpoint.
    inputs = model.vpm(frames)

    # Generate the caption with the user's prompt; no gradients are needed
    # for inference.
    with torch.no_grad():
        outputs = model.generate(inputs=inputs, tokenizer=tokenizer, max_new_tokens=50, prompt=prompt)

    # Decode the first returned sequence into plain text.
    caption = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return caption
# Create the Gradio interface using Blocks.
# NOTE(review): indentation was lost in this copy of the file; the nesting
# below (video and prompt inputs inside the Row, output box and button
# beneath it) is a reconstruction -- confirm against the deployed app.
with gr.Blocks(title="Video Analyzer using MiniCPM-V-2.6-int4") as iface:
    # Header and short description shown at the top of the page.
    gr.Markdown("# Video Analyzer using MiniCPM-V-2.6-int4")
    gr.Markdown("Upload a video to get an analysis using the MiniCPM-V-2.6-int4 model.")
    gr.Markdown("This model uses 4-bit quantization for improved efficiency. [Learn more](https://huggingface.co/openbmb/MiniCPM-V-2_6-int4)")
    # Inputs: the video upload and the free-text prompt.
    with gr.Row():
        video_input = gr.Video()
        prompt_input = gr.Textbox(label="Prompt (optional)", placeholder="Enter a prompt to guide the analysis...")
    # Output box that receives the string returned by analyze_video.
    analysis_output = gr.Textbox(label="Video Analysis")
    analyze_button = gr.Button("Analyze Video")
    # Clicking the button calls analyze_video(video, prompt) and writes the
    # result into analysis_output.
    analyze_button.click(fn=analyze_video, inputs=[video_input, prompt_input], outputs=analysis_output)
# Launch the interface (runs at import time; there is no __main__ guard).
iface.launch()