Spaces:
Running
Running
File size: 6,684 Bytes
f8aaa9d 7c2c622 f8aaa9d c43728b f8aaa9d c137e5c f8aaa9d c137e5c f8aaa9d c137e5c f8aaa9d c43728b f8aaa9d c137e5c f8aaa9d 001b623 f8aaa9d 001b623 0425992 c43728b f8aaa9d d638712 c137e5c d638712 c137e5c d638712 7c2c622 f8aaa9d 7c2c622 c137e5c 0f96bc2 7c2c622 cba459f 7c2c622 cba459f 7c2c622 63595a8 7c2c622 c137e5c 7c2c622 c137e5c 7c2c622 0425992 c137e5c 7c2c622 63595a8 7c2c622 63595a8 7c2c622 d638712 c137e5c d638712 0f96bc2 7c2c622 001b623 7c2c622 001b623 0f96bc2 c137e5c 7c2c622 c137e5c f8aaa9d 0425992 63595a8 d638712 0425992 d638712 63595a8 7c2c622 001b623 f8aaa9d c137e5c f8aaa9d c137e5c 001b623 c137e5c 001b623 f8aaa9d c137e5c f8aaa9d 0f96bc2 03c6357 c137e5c 0f96bc2 3f2c22a 0f96bc2 f8aaa9d c137e5c 63595a8 f8aaa9d c137e5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
import os
import gradio as gr
import cv2
from google import genai
from google.genai.types import Part
from tenacity import retry, stop_after_attempt, wait_random_exponential
# The Gemini API key is supplied through the environment; refuse to start
# without it so the failure is visible immediately rather than on first use.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise ValueError("Please set the GOOGLE_API_KEY environment variable.")

# Name of the Gemini model used for every request made by this module.
MODEL_NAME = "gemini-2.0-flash"

# Shared Gemini API client, authenticated with the key read above.
client = genai.Client(api_key=GOOGLE_API_KEY)
@retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(3))
def call_gemini(video_file: str, prompt: str) -> str:
    """
    Call the Gemini model with a video file and prompt.

    The whole file is read into memory and sent inline with the request,
    labelled as "video/mp4". The call is wrapped in a tenacity retry policy:
    random exponential backoff (capped at 60 seconds) and at most 3 attempts.

    Args:
        video_file (str): Path to the video file
        prompt (str): Text prompt to guide the analysis
    Returns:
        str: Response text from the Gemini API ("" if the response has no text)
    Raises:
        Exception: If the file cannot be read or every attempt fails.
    """
    with open(video_file, "rb") as f:
        file_bytes = f.read()

    # Raw bytes have to be attached as inline data. `Part(file_data=...)` is
    # meant for a reference to an already-uploaded file, and `mime_type` is not
    # a field of `Part` itself, so the bytes are wrapped with `Part.from_bytes`.
    # NOTE(review): inline payloads are size-limited by the API; large videos
    # probably need the Files upload API instead -- confirm against the docs.
    video_part = Part.from_bytes(data=file_bytes, mime_type="video/mp4")

    response = client.models.generate_content(
        model=MODEL_NAME,
        contents=[video_part, Part(text=prompt)],
    )
    # Keep the declared `str` return type even if the response carries no text.
    return response.text or ""
def safe_call_gemini(video_file: str, prompt: str) -> str:
    """
    Invoke call_gemini without letting any exception escape.

    On failure the error is printed and returned as a string that starts with
    "Gemini call failed: ", so callers always receive text.

    Args:
        video_file (str): Path to the video file
        prompt (str): Text prompt for the API
    Returns:
        str: API response or error message
    """
    try:
        result = call_gemini(video_file, prompt)
    except Exception as exc:
        failure = "Gemini call failed: " + str(exc)
        print(failure)
        return failure
    else:
        return result
def hhmmss_to_seconds(time_str: str) -> float:
    """
    Convert a HH:MM:SS formatted string into seconds.

    Shorter forms are accepted too: "MM:SS" and plain "SS". Each field is
    parsed with float(), so fractional values such as "00:00:10.5" work.
    Surrounding whitespace is ignored.

    Args:
        time_str (str): Time string in HH:MM:SS, MM:SS or SS format
    Returns:
        float: Time in seconds
    Raises:
        ValueError: If a field is not a number, or there are more than
            three colon-separated fields.
    """
    fields = [float(p) for p in time_str.strip().split(":")]
    if len(fields) > 3:
        # Previously such input silently returned only the first field.
        raise ValueError(f"Invalid time string (too many fields): {time_str!r}")

    # Fold left to right: each step promotes the running total by one unit
    # (hours -> minutes -> seconds) before adding the next field.
    seconds = 0.0
    for value in fields:
        seconds = seconds * 60 + value
    return seconds
def get_key_frames(video_file: str, summary: str, user_query: str) -> list:
    """
    Extract key frames from the video based on timestamps provided by Gemini.

    Gemini is asked for lines of the form "HH:MM:SS - description". Each
    parseable timestamp is used to seek into the video with OpenCV, and the
    frame read at that position is returned in RGB order with a caption.
    Lines without " - ", unparseable timestamps and failed reads are skipped.

    Args:
        video_file (str): Path to the video file
        summary (str): Video summary to provide context
        user_query (str): Optional user query to focus the analysis
    Returns:
        list: List of tuples (image_array, caption); empty if the Gemini call
            failed, nothing could be parsed, or the video could not be opened
    """
    prompt = (
        "List the key timestamps in the video and a brief description of the event at that time. "
        "Output one line per event in the format: HH:MM:SS - description. Do not include any extra text."
    )
    prompt += f" Video Summary: {summary}"
    if user_query:
        prompt += f" Focus on: {user_query}"

    key_frames_response = safe_call_gemini(video_file, prompt)
    # A missing/empty reply is treated like a failed call instead of crashing
    # on the substring check below.
    if not key_frames_response or "Gemini call failed" in key_frames_response:
        return []

    # Parse "timestamp - description" lines; anything else is ignored.
    key_frames = []
    for line in key_frames_response.strip().split("\n"):
        timestamp, separator, description = line.partition(" - ")
        if separator:
            key_frames.append((timestamp.strip(), description.strip()))
    if not key_frames:
        # Nothing to seek to, so there is no need to open the video at all.
        return []

    extracted_frames = []
    cap = cv2.VideoCapture(video_file)
    try:
        if not cap.isOpened():
            print("Error: Could not open the uploaded video file.")
            return extracted_frames
        for ts, description in key_frames:
            try:
                seconds = hhmmss_to_seconds(ts)
            except ValueError:
                # Not a usable timestamp (e.g. stray text before " - ").
                continue
            cap.set(cv2.CAP_PROP_POS_MSEC, seconds * 1000)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; convert to RGB for display.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                extracted_frames.append((frame_rgb, f"{ts}: {description}"))
    finally:
        # Always free the capture handle, including on early return or error.
        cap.release()
    return extracted_frames
def analyze_video(video_file: str, user_query: str) -> tuple[str, list]:
    """
    Analyze the video and generate a summary and key frames.

    Makes one Gemini call for the summary, then asks get_key_frames for
    frames, and assembles a Markdown report listing the summary and the
    caption of every extracted frame.

    Args:
        video_file (str): Path to the video file
        user_query (str): Optional query to guide the analysis
    Returns:
        tuple: (Markdown report, list of key frames)
    """
    summary_prompt = "Summarize this video."
    if user_query:
        summary_prompt += f" Also focus on: {user_query}"
    summary = safe_call_gemini(video_file, summary_prompt)

    report_parts = [f"## Video Analysis Report\n\n**Summary:**\n\n{summary}\n"]

    # If the summary call failed, `summary` holds an error message. It is still
    # shown in the report, but must not be fed to the model as video context.
    summary_failed = not summary or summary.startswith("Gemini call failed")
    summary_context = "" if summary_failed else summary
    key_frames_gallery = get_key_frames(video_file, summary_context, user_query)

    if not key_frames_gallery:
        report_parts.append("\n*No key frames were extracted.*\n")
    else:
        report_parts.append("\n**Key Frames Extracted:**\n")
        report_parts.extend(
            f"- **Frame {idx}:** {caption}\n"
            for idx, (_, caption) in enumerate(key_frames_gallery, start=1)
        )
    return "".join(report_parts), key_frames_gallery
def gradio_interface(video_file, user_query: str) -> tuple[str, list]:
    """
    Gradio interface function to process video and return results.

    Validates the upload before any analysis is attempted: the path must
    point to an existing regular file whose name ends in ".mp4"
    (case-insensitive). Invalid input yields a message and an empty gallery.

    Args:
        video_file (str): Path to the uploaded video file
        user_query (str): Optional query to guide analysis
    Returns:
        tuple: (Markdown report, gallery of key frames)
    """
    # isfile (rather than exists) also rejects directories, which could not be
    # opened as a video later on.
    if not video_file or not os.path.isfile(video_file):
        return "Please upload a valid video file.", []
    if not video_file.lower().endswith(".mp4"):
        return "Please upload an MP4 video file.", []

    # A missing or whitespace-only query means "no particular focus".
    user_query = (user_query or "").strip()
    return analyze_video(video_file, user_query)
# Define the Gradio interface: gradio_interface is the callback, fed by a
# video upload and an optional free-text query, and its two return values are
# shown as a Markdown report and an image gallery.
iface = gr.Interface(
fn=gradio_interface,
# Order matches the callback's parameters: (video_file, user_query).
inputs=[
gr.Video(label="Upload Video File"),
gr.Textbox(label="Analysis Query (optional): guide the focus of the analysis",
placeholder="e.g., focus on unusual movements near the entrance")
],
# Order matches the callback's return tuple: (markdown report, key frames).
# NOTE(review): the Markdown label mentions security/surveillance while the
# title and description are generic -- confirm which wording is intended.
outputs=[
gr.Markdown(label="Security & Surveillance Analysis Report"),
gr.Gallery(label="Extracted Key Frames", columns=2)
],
title="AI Video Analysis and Summariser Agent",
description=(
"This tool uses Google's Gemini 2.0 Flash model to analyze an uploaded video. "
"It returns a brief summary and extracts key frames based on that summary. "
"Provide a video file and, optionally, a query to guide the analysis."
)
)
# Launch the Gradio app only when this file is executed as a script, so that
# importing the module does not start it.
if __name__ == "__main__":
    iface.launch()