import gradio as gr
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import random
import spaces
import time
from PIL import Image
from threading import Thread
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
#####################################
# 1. Load Qwen2.5-VL Model & Processor
#####################################
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"  # or "Qwen/Qwen2.5-VL-3B-Instruct"

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda")
model.eval()
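# ZeroGPU note: this Space runs on ZeroGPU ("Running on Zero"), so `spaces` is
# imported before any CUDA work and the GPU-bound entry point below is decorated
# with `@spaces.GPU`; the module-level `.to("cuda")` follows the pattern shown in
# Hugging Face's ZeroGPU examples.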
#####################################
# 2. Helper Function: Downsample Video
#####################################
def downsample_video(video_path, num_frames=10):
    """
    Downsamples the video file to `num_frames` evenly spaced frames.
    Each frame is converted to a PIL Image along with its timestamp.
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    if total_frames <= 0 or fps <= 0:
        vidcap.release()
        return frames
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    for i in frame_indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
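
# --- Optional local sanity check (not part of the Space UI) ---
# A minimal sketch for eyeballing the sampler, assuming a local clip such as
# "sample.mp4" exists; both the file name and the helper name are illustrative.
def _debug_print_frames(video_path="sample.mp4", num_frames=4):
    for img, ts in downsample_video(video_path, num_frames=num_frames):
        print(f"{ts:>7.2f}s  {img.width}x{img.height} RGB frame")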
#####################################
# 3. The Inference Function
#####################################
@spaces.GPU  # request a ZeroGPU slot for the duration of this call
def video_inference(video_file, duration):
    """
    - Takes a recorded video file and the chosen duration (a string; informational only).
    - Downsamples the video and passes the frames to Qwen2.5-VL for inference.
    - Returns the model-generated text plus a dummy bar chart as example analytics.
    """
    if video_file is None:
        return "No video provided.", None

    # 3.1: Downsample the recorded video
    frames = downsample_video(video_file)
    if not frames:
        return "Could not read frames from video.", None

    # 3.2: Construct the Qwen2.5-VL prompt
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": "Please describe what's happening in this video."}]
        }
    ]
    # Add each frame (with its timestamp) to the message
    for (image, ts) in frames:
        messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
        messages[0]["content"].append({"type": "image", "image": image})

    # Prepare the final prompt for the model
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Qwen expects the images in the same order as they appear in the prompt
    frame_images = [img for (img, _) in frames]

    inputs = processor(
        text=[prompt],
        images=frame_images,
        return_tensors="pt",
        padding=True
    ).to("cuda")
    # 3.3: Generate text output (streaming)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        time.sleep(0.01)
    thread.join()  # make sure generation has fully finished before returning

    # 3.4: Dummy bar chart for demonstration
    fig, ax = plt.subplots()
    categories = ["Category A", "Category B", "Category C"]
    values = [random.randint(1, 10) for _ in categories]
    ax.bar(categories, values, color=["#4B0082", "#9370DB", "#4B0082"])
    ax.set_title("Example Analytics Chart")
    ax.set_ylabel("Value")
    ax.set_xlabel("Category")

    return generated_text, fig
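
# --- Optional: analytics derived from the video itself (sketch) ---
# The chart returned above is a random placeholder. A minimal alternative, shown
# here as a hypothetical helper that is not wired into the UI, plots the mean
# brightness of each sampled frame against its timestamp instead.
def brightness_chart(frames):
    labels = [f"{ts}s" for (_, ts) in frames]
    brightness = [float(np.asarray(img.convert("L")).mean()) for (img, _) in frames]
    fig, ax = plt.subplots()
    ax.bar(labels, brightness, color="#9370DB")
    ax.set_title("Mean Frame Brightness")
    ax.set_xlabel("Frame timestamp")
    ax.set_ylabel("Brightness (0-255)")
    return fig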
#####################################
# 4. Build a Professional Gradio UI
#####################################
def build_app():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # **Qwen2.5-VL-7B-Instruct Live Video Analysis**
        Record a video (from webcam or file), then click **Stop**.
        Next, click **Analyze** to run Qwen2.5-VL and see textual + chart outputs.
        """)
        with gr.Row():
            with gr.Column():
                duration = gr.Radio(
                    choices=["5", "10", "20", "30"],
                    value="5",
                    label="Suggested Recording Duration (seconds)",
                    info="Select how long you plan to record before pressing Stop."
                )
                # Note: 'source="webcam"' raises a TypeError on Gradio 4.x (the argument
                # became `sources`); omitting it lets the component accept both uploads
                # and webcam recordings.
                video = gr.Video(
                    label="Webcam Recording (press the Record button, then Stop)",
                    format="mp4"
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(label="Model Output")
                output_plot = gr.Plot(label="Analytics Chart")

        analyze_btn.click(
            fn=video_inference,
            inputs=[video, duration],
            outputs=[output_text, output_plot]
        )
    return demo
if __name__ == "__main__":
    app = build_app()
    app.launch(debug=True)