Spaces:
Running
Running
import os | |
import time | |
import json | |
import gradio as gr | |
import cv2 | |
from google import genai | |
from google.genai import types | |
# Retrieve API key from environment variables. Surrounding whitespace is
# stripped so a key stored with a trailing newline still works, and a
# whitespace-only value is rejected like a missing one.
GOOGLE_API_KEY = (os.environ.get("GOOGLE_API_KEY") or "").strip()
if not GOOGLE_API_KEY:
    raise ValueError("Please set the GOOGLE_API_KEY environment variable with your Google Cloud API key.")
# Initialize the Gemini API client
client = genai.Client(api_key=GOOGLE_API_KEY)
# Model supporting video analysis. Set GEMINI_MODEL_NAME to use a different
# model without editing the code; an unset or empty value keeps the default.
# NOTE(review): "-exp-" model IDs are preview releases that may be withdrawn;
# confirm the default is still served.
MODEL_NAME = os.environ.get("GEMINI_MODEL_NAME") or "gemini-2.5-pro-exp-03-25"
def upload_and_process_video(video_file: str, timeout: int = 300, poll_interval: float = 10.0) -> types.File:
    """
    Upload a video file to the Gemini API and wait until it leaves the PROCESSING state.

    Args:
        video_file (str): Path to the video file
        timeout (int): Maximum time to wait for processing in seconds (default: 5 minutes)
        poll_interval (float): Seconds to wait between status checks (default: 10)
    Returns:
        types.File: Processed video file object
    Raises:
        Exception: Wraps any upload error, processing timeout (TimeoutError) or
            FAILED processing state (ValueError); the underlying error is kept
            as ``__cause__`` for debugging.
    """
    try:
        video_file_obj = client.files.upload(file=video_file)
        # monotonic() is immune to wall-clock adjustments while we wait.
        start_time = time.monotonic()
        while video_file_obj.state == "PROCESSING":
            elapsed_time = time.monotonic() - start_time
            if elapsed_time > timeout:
                raise TimeoutError(f"Video processing timed out after {timeout} seconds.")
            print(f"Processing {video_file}... ({int(elapsed_time)}s elapsed)")
            # Cap the sleep at the remaining budget so we never overshoot the timeout.
            time.sleep(max(0.0, min(poll_interval, timeout - elapsed_time)))
            video_file_obj = client.files.get(name=video_file_obj.name)
        if video_file_obj.state == "FAILED":
            # Include the service-provided error detail when the object exposes one.
            detail = getattr(video_file_obj, "error", None)
            suffix = f" ({detail})" if detail else ""
            raise ValueError(f"Video processing failed: {video_file_obj.state}{suffix}")
        print(f"Video processing complete: {video_file_obj.uri}")
        return video_file_obj
    except Exception as e:
        # Chain the original error so its type and traceback are not lost.
        raise Exception(f"Error uploading video: {str(e)}") from e
def hhmmss_to_seconds(timestamp: str) -> float:
    """
    Convert an HH:MM:SS (or MM:SS / SS) timestamp to seconds.

    Model replies do not always follow the requested HH:MM:SS format, so one
    to three colon-separated fields are accepted. They are read right-to-left
    as seconds, minutes and hours. Fractional values are allowed
    (e.g. "00:01:02.5").

    Args:
        timestamp (str): Time such as "01:02:03", "02:03" or "7.5";
            non-string values are converted with str() first
    Returns:
        float: Time in seconds, or 0.0 if the value cannot be parsed, has more
        than three fields, or contains a negative/NaN field
    """
    try:
        fields = [float(part) for part in str(timestamp).strip().split(":")]
    except ValueError:
        return 0.0  # Default to 0 if parsing fails
    # `field >= 0` is False for negatives and for NaN alike.
    if len(fields) > 3 or not all(field >= 0 for field in fields):
        return 0.0
    total = 0.0
    for field in fields:  # ((h * 60) + m) * 60 + s, for however many fields exist
        total = total * 60 + field
    return total
def _strip_code_fence(text: str) -> str:
    """Remove a Markdown code fence (```json ... ``` or ``` ... ```) wrapping the whole text."""
    cleaned = text.strip()
    if cleaned.startswith("```json") and cleaned.endswith("```"):
        return cleaned[7:-3].strip()
    if cleaned.startswith("```") and cleaned.endswith("```"):
        return cleaned[3:-3].strip()
    return cleaned


def _coerce_frame_list(data) -> list:
    """
    Return the list of key-frame entries held by a parsed JSON value.

    A JSON array is returned unchanged. A JSON object is accepted when it wraps
    the array under some key (e.g. {"key_frames": [...]}). Anything else yields
    an empty list, so a malformed reply cannot abort the whole analysis.
    """
    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        for value in data.values():
            if isinstance(value, list):
                return value
    print("Key frames response is not a JSON list; ignoring it.")
    return []


def _parse_timecode_lines(text: str) -> list:
    """Fallback parser for plain-text replies like "00:00:03 - Scene", "00:00:03: Scene" or "00:00:03"."""
    frames = []
    for raw_line in text.strip().splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if " - " in line:
            timestamp, title = line.split(" - ", 1)
        elif ": " in line and len(line.split(":")[0]) == 2:  # Check for HH:MM:SS format
            timestamp, title = line.split(": ", 1)
        elif len(line.split(":")) == 3:  # Rough check for standalone HH:MM:SS
            timestamp, title = line, "Untitled"
        else:
            continue
        frames.append({"timecode": timestamp.strip(), "title": title.strip()})
    return frames


def _is_zero_timecode(timestamp: str) -> bool:
    """Return True when *timestamp* spells an explicit zero time such as "00:00:00" or "0"."""
    return any(ch.isdigit() for ch in timestamp) and not timestamp.strip("0:.")


def extract_key_frames(video_file: str, key_frames_response: str) -> list:
    """
    Extract key frames from the video based on Gemini API response.

    The response is expected to be a JSON array of objects with a "timecode"
    (or "timestamp") key and a "title" (or "caption") key. It may be wrapped in
    a Markdown code fence. If it is not valid JSON, plain-text lines such as
    "00:00:03 - Scene" are parsed instead.

    Args:
        video_file (str): Path to the video file
        key_frames_response (str): Raw response text from Gemini API
            (None is treated as an empty response)
    Returns:
        list: List of (RGB image array, caption) tuples. Entries without a
        parseable timestamp, or whose frame cannot be read, are skipped.
    """
    extracted_frames = []
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        print("Error: Could not open video file.")
        return extracted_frames
    # try/finally guarantees the capture is released even if parsing or decoding raises.
    try:
        cleaned_response = _strip_code_fence(key_frames_response or "")
        print(f"Cleaned key frames response: {cleaned_response}")  # Debug output
        try:
            key_frames = _coerce_frame_list(json.loads(cleaned_response))
        except json.JSONDecodeError as e:
            print(f"JSON parsing failed: {str(e)}. Falling back to text parsing.")
            key_frames = _parse_timecode_lines(cleaned_response)

        for entry in key_frames:
            if not isinstance(entry, dict):
                continue  # e.g. a bare string inside the JSON array
            timestamp = entry.get("timecode", entry.get("timestamp", ""))
            title = entry.get("title", entry.get("caption", "Untitled"))
            if timestamp is None:
                continue
            # JSON may carry numbers; normalise to text for parsing and captions.
            timestamp = str(timestamp).strip()
            if not timestamp:
                continue
            seconds = hhmmss_to_seconds(timestamp)
            # hhmmss_to_seconds returns 0.0 on parse failure, so distinguish that
            # from a genuine "00:00:00" key frame before skipping.
            if seconds == 0.0 and not _is_zero_timecode(timestamp):
                continue
            cap.set(cv2.CAP_PROP_POS_MSEC, seconds * 1000)
            ret, frame_img = cap.read()
            if ret:
                frame_rgb = cv2.cvtColor(frame_img, cv2.COLOR_BGR2RGB)
                extracted_frames.append((frame_rgb, f"{timestamp}: {title}"))
    finally:
        cap.release()
    return extracted_frames
def _delete_uploaded_file(file_obj) -> None:
    """
    Best-effort removal of an uploaded video from the Gemini Files API.

    Errors are printed rather than raised so cleanup never replaces the
    report (or error message) returned to the UI.
    """
    try:
        client.files.delete(name=file_obj.name)
    except Exception as e:
        print(f"Warning: could not delete uploaded file {file_obj.name}: {e}")


def analyze_video(video_file: str, user_query: str) -> tuple[str, list]:
    """
    Analyze the video using the Gemini API and extract key frames.

    The video is uploaded once and used for two requests: a summary request and
    a key-frame request in JSON form. The uploaded file is deleted afterwards
    whether or not the analysis succeeded.

    Args:
        video_file (str): Path to the video file (must exist and end in .mp4)
        user_query (str): Optional query to guide the analysis; blank or
            whitespace-only values are ignored
    Returns:
        tuple: (Markdown report, list of key frames as (image, caption) tuples).
        On any failure the report describes the error and the list is empty.
    """
    # Validate input
    if not video_file or not os.path.exists(video_file):
        return "Please upload a valid video file.", []
    if not video_file.lower().endswith('.mp4'):
        return "Please upload an MP4 video file.", []

    query = (user_query or "").strip()
    focus = f" Focus on: {query}" if query else ""
    video_file_obj = None
    try:
        # Upload and process the video
        video_file_obj = upload_and_process_video(video_file)

        # Step 1: Generate detailed summary
        summary_prompt = "Provide a detailed summary of this video with timestamps for key sections." + focus
        summary_response = client.models.generate_content(
            model=MODEL_NAME,
            contents=[video_file_obj, summary_prompt]
        )
        # .text can be None (e.g. an empty or blocked reply); never render "None".
        summary = summary_response.text or "*The model returned no summary text.*"

        # Step 2: Extract key frames with few-shot examples
        key_frames_prompt = (
            "Identify key frames in this video and return them as a JSON array. "
            "Each object must have 'timecode' (in HH:MM:SS format) and 'title' describing the scene. "
            "Ensure the response is valid JSON. Here are examples of the expected format:\n"
            "Example 1: For a video of a car chase:\n"
            "```json\n"
            "[\n"
            " {\"timecode\": \"00:00:00\", \"title\": \"Car chase begins on highway\"},\n"
            " {\"timecode\": \"00:00:10\", \"title\": \"Police car joins pursuit\"}\n"
            "]\n"
            "```\n"
            "Example 2: For a nature video:\n"
            "```json\n"
            "[\n"
            " {\"timecode\": \"00:00:05\", \"title\": \"Bird flies across screen\"},\n"
            " {\"timecode\": \"00:00:15\", \"title\": \"Deer appears in forest\"}\n"
            "]\n"
            "```\n"
            "Now, provide the key frames for this video in the same JSON format."
        ) + focus
        key_frames_response = client.models.generate_content(
            model=MODEL_NAME,
            contents=[video_file_obj, key_frames_prompt]
        )
        # Guard against a None .text so a missing reply doesn't discard the summary.
        key_frames = extract_key_frames(video_file, key_frames_response.text or "")

        # Generate Markdown report
        markdown_report = (
            "## Video Analysis Report\n\n"
            f"**Summary:**\n{summary}\n"
        )
        if key_frames:
            markdown_report += "\n**Key Frames Identified:**\n"
            for i, (_, caption) in enumerate(key_frames, 1):
                markdown_report += f"- Frame {i}: {caption}\n"
        else:
            markdown_report += "\n*No key frames extracted. Check the console for the raw response.*\n"
        return markdown_report, key_frames
    except Exception as e:
        error_msg = (
            "## Video Analysis Report\n\n"
            f"**Error:** Unable to analyze video.\n"
            f"Details: {str(e)}\n"
            "Please check your API key, ensure the video is valid, or try again later."
        )
        return error_msg, []
    finally:
        # Uploaded files otherwise linger in the project's Files API storage.
        if video_file_obj is not None:
            _delete_uploaded_file(video_file_obj)
def _build_interface():
    """
    Assemble the Gradio UI around analyze_video.

    Inputs: an uploaded video and an optional free-text query.
    Outputs: the Markdown report and a two-column gallery of captioned key frames.
    """
    video_input = gr.Video(label="Upload Video File (MP4)")
    query_input = gr.Textbox(
        label="Analysis Query (optional)",
        placeholder="e.g., focus on main events or themes",
    )
    report_output = gr.Markdown(label="Video Analysis Report")
    frames_output = gr.Gallery(label="Key Frames", columns=2)
    blurb = (
        "Upload an MP4 video to get a detailed summary and key frames using Google's Gemini API. "
        "This tool analyzes the video content directly and extracts key moments as images. "
        "Optionally, provide a query to guide the analysis."
    )
    return gr.Interface(
        fn=analyze_video,
        inputs=[video_input, query_input],
        outputs=[report_output, frames_output],
        title="AI Video Analysis Agent with Gemini",
        description=blurb,
    )


# Module-level Gradio interface object.
iface = _build_interface()

if __name__ == "__main__":
    # share=True requests a public Gradio share link in addition to the local server.
    iface.launch(share=True)