# HF_Agents_Final_Project/src/video_processing_tool.py
# Author: Yago Bolivar
# Commit 87aa741: Refactor speech_to_text.py to implement a singleton ASR pipeline, enhance error handling, and introduce SpeechToTextTool for better integration. Update spreadsheet_tool.py to support querying and improve parsing functionality, including CSV support. Enhance video_processing_tool.py with new tasks for metadata extraction and frame extraction, while improving object detection capabilities and initialization checks.
import os
import yt_dlp
import cv2
import numpy as np
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound
import tempfile
import re
import shutil
import time
from smolagents.tools import Tool
class VideoProcessingTool(Tool):
    """
    Analyzes video content, extracting information such as frames, audio, or metadata.
    Useful for tasks like video summarization, frame extraction, transcript analysis, or content analysis.

    Supported tasks (dispatched by ``forward``): 'get_metadata', 'extract_frames',
    'get_transcript' and 'detect_objects'. Every task method returns a dict that either
    contains ``"success": True`` plus results, or an ``"error"`` key with a message.
    """
    name = "video_processor"
    description = "Analyzes video content from a file path or YouTube URL. Can extract frames, detect objects, get transcripts, and provide video metadata."
    inputs = {
        "file_path": {"type": "string", "description": "Path to the video file or YouTube URL.", "nullable": True},
        "task": {"type": "string", "description": "Specific task to perform (e.g., 'extract_frames', 'get_transcript', 'detect_objects', 'get_metadata').", "nullable": True},
        "task_parameters": {"type": "object", "description": "Parameters for the specific task (e.g., frame extraction interval, object detection confidence).", "nullable": True}
    }
    outputs = {"result": {"type": "object", "description": "The result of the video processing task, e.g., list of frame paths, transcript text, object detection results, or metadata dictionary."}}
    output_type = "object"

    def __init__(self, model_cfg_path=None, model_weights_path=None, class_names_path=None, temp_dir_base=None, *args, **kwargs):
        """
        Initializes the VideoProcessingTool.

        Args:
            model_cfg_path (str, optional): Path to the object detection model's configuration file.
            model_weights_path (str, optional): Path to the object detection model's weights file.
            class_names_path (str, optional): Path to the file containing class names for the model
                (one name per line).
            temp_dir_base (str, optional): Base directory for temporary files. Defaults to system temp.
        """
        super().__init__(*args, **kwargs)
        self.is_initialized = False  # Will be set to True after successful setup
        # Per-instance scratch directory for downloaded videos and extracted frames.
        if temp_dir_base:
            self.temp_dir = tempfile.mkdtemp(dir=temp_dir_base)
        else:
            self.temp_dir = tempfile.mkdtemp()
        self.object_detection_model = None
        self.class_names = []
        if model_cfg_path and model_weights_path and class_names_path:
            if os.path.exists(model_cfg_path) and os.path.exists(model_weights_path) and os.path.exists(class_names_path):
                try:
                    self.object_detection_model = cv2.dnn.readNetFromDarknet(model_cfg_path, model_weights_path)
                    self.object_detection_model.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
                    self.object_detection_model.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)
                    with open(class_names_path, "r", encoding="utf-8") as f:
                        self.class_names = [line.strip() for line in f.readlines()]
                    print("CV Model loaded successfully.")
                except Exception as e:
                    # Detection is optional: keep the tool usable for the other tasks.
                    print(f"Error loading CV model: {e}. Object detection will not be available.")
                    self.object_detection_model = None
            else:
                print("Warning: One or more CV model paths are invalid. Object detection will not be available.")
        else:
            print("CV model paths not provided. Object detection will not be available.")
        self.is_initialized = True

    def forward(self, file_path: str = None, task: str = "get_metadata", task_parameters: dict = None):
        """
        Main entry point for video processing tasks.

        Args:
            file_path (str, optional): Local video path or YouTube URL.
            task (str, optional): 'get_metadata' (default, also used when None), 'extract_frames',
                'get_transcript' or 'detect_objects'.
            task_parameters (dict, optional): Task options. Recognised keys:
                'resolution' (max download height for YouTube URLs, default "360p"),
                'interval_seconds' / 'max_frames' (extract_frames),
                'languages' (get_transcript),
                'confidence_threshold' / 'frames_to_process' / 'nms_threshold' (detect_objects).

        Returns:
            dict: The task result, or {"error": "..."}.
        """
        if not self.is_initialized:
            return {"error": "Tool not initialized properly."}
        if task is None:
            task = "get_metadata"
        if task_parameters is None:
            task_parameters = {}
        if not isinstance(task_parameters, dict):
            return {"error": "task_parameters must be a dictionary."}

        # Validate cheap preconditions before any (potentially slow) download happens.
        supported_tasks = ("get_metadata", "extract_frames", "get_transcript", "detect_objects")
        if task not in supported_tasks:
            return {"error": f"Unsupported task: {task}"}

        if task == "get_transcript":
            # Transcripts are fetched from YouTube by video ID; downloading the video is unnecessary.
            if not file_path:
                return {"error": "File path is required for this task."}
            if not self._is_youtube_url(file_path):
                return {"error": "Transcripts are only available for YouTube URLs."}
            return self.get_youtube_transcript(file_path, languages=task_parameters.get("languages"))

        if task == "detect_objects" and not self.object_detection_model:
            return {"error": "Object detection model not loaded."}

        video_source_path, error = self._resolve_video_source(file_path, task_parameters)
        if error:
            return error

        if task == "get_metadata":
            return self.get_video_metadata(video_source_path)
        if task == "extract_frames":
            return self.extract_frames_from_video(
                video_source_path,
                interval_seconds=task_parameters.get("interval_seconds", 5),
                max_frames=task_parameters.get("max_frames"),
            )
        # Only "detect_objects" remains after the supported_tasks check above.
        return self.detect_objects_in_video(
            video_source_path,
            confidence_threshold=task_parameters.get("confidence_threshold", 0.5),
            num_frames_to_sample=task_parameters.get("frames_to_process", 5),
            nms_threshold=task_parameters.get("nms_threshold", 0.4),
        )

    @staticmethod
    def _is_youtube_url(file_path):
        """Return True when file_path looks like a youtube.com or youtu.be URL."""
        return bool(file_path) and ("youtube.com/" in file_path or "youtu.be/" in file_path)

    def _resolve_video_source(self, file_path, task_parameters):
        """Return (local_video_path, None) on success or (None, error_dict), downloading YouTube URLs first."""
        if not file_path:
            return None, {"error": "File path is required for this task."}
        if self._is_youtube_url(file_path):
            download_result = self.download_video(file_path, resolution=task_parameters.get("resolution", "360p"))
            if download_result.get("error"):
                return None, download_result
            local_path = download_result.get("file_path")
            if not local_path or not os.path.exists(local_path):
                return None, {"error": f"Failed to download or locate video from URL: {file_path}"}
            return local_path, None
        if not os.path.exists(file_path):
            return None, {"error": f"Video file not found: {file_path}"}
        return file_path, None

    def _extract_video_id(self, youtube_url):
        """Extract the 11-character YouTube video ID from a URL; return None if there is none."""
        if not youtube_url:
            return None
        match = re.search(r"(?:v=|\/|embed\/|watch\?v=|youtu\.be\/)([0-9A-Za-z_-]{11})", youtube_url)
        return match.group(1) if match else None

    def download_video(self, youtube_url, resolution="360p"):
        """
        Download a YouTube video into this tool's temp directory.

        Args:
            youtube_url (str): YouTube URL.
            resolution (str | int): Maximum video height, e.g. "360p", "720" or 480. Values that
                are not digits (optionally followed by "p") fall back to 360.

        Returns:
            dict: {"success": True, "file_path": ...} or {"error": "..."}.
        """
        video_id = self._extract_video_id(youtube_url)
        if not video_id:
            return {"error": "Invalid YouTube URL or could not extract video ID."}
        output_file_path = os.path.join(self.temp_dir, f"{video_id}.mp4")
        if os.path.exists(output_file_path):  # Avoid re-downloading
            return {"success": True, "file_path": output_file_path, "message": "Video already downloaded."}

        # Accept "360p", "360" or 360; slicing off the last character broke inputs without a "p".
        max_height = str(resolution).strip().lower().rstrip("p")
        if not max_height.isdigit():
            max_height = "360"

        try:
            ydl_opts = {
                'format': f'bestvideo[height<={max_height}][ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
                'outtmpl': output_file_path,
                'noplaylist': True,
                'quiet': True,
                'no_warnings': True,
            }
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                ydl.download([youtube_url])

            if not os.path.exists(output_file_path):
                # Fallback for some formats if mp4 direct is not available
                ydl_opts['format'] = f'best[height<={max_height}]'  # more generic
                with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                    ydl.download([youtube_url])
                # yt-dlp might save with a different extension; look for the finished file,
                # ignoring partial downloads.
                candidates = sorted(
                    name for name in os.listdir(self.temp_dir)
                    if name.startswith(video_id) and not name.endswith(('.part', '.ytdl'))
                )
                # Containers OpenCV can typically read; others are reported as unusable.
                usable = [name for name in candidates if name.endswith(('.mp4', '.mkv', '.webm', '.flv'))]
                if usable:
                    output_file_path = os.path.join(self.temp_dir, usable[0])
                elif candidates:
                    return {"error": f"Downloaded video is not in a directly usable format: {candidates[0]}"}

            if os.path.exists(output_file_path):
                return {"success": True, "file_path": output_file_path}
            return {"error": "Video download failed, file not found after attempt."}
        except yt_dlp.utils.DownloadError as e:
            return {"error": f"yt-dlp download error: {str(e)}"}
        except Exception as e:
            return {"error": f"Failed to download video: {str(e)}"}

    def get_video_metadata(self, video_path):
        """
        Extract metadata from the video file.

        Returns:
            dict: {"success": True, "metadata": {frame_count, fps, width, height, duration}} or
                {"error": "..."}. ``duration`` (seconds) is None when OpenCV reports no usable FPS.
        """
        if not os.path.exists(video_path):
            return {"error": f"Video file not found: {video_path}"}
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return {"error": "Could not open video file."}
        try:
            frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
            fps = cap.get(cv2.CAP_PROP_FPS)
            metadata = {
                "frame_count": int(frame_count),
                "fps": fps,
                "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
                # Some containers report 0 FPS; avoid ZeroDivisionError and mark duration unknown.
                "duration": frame_count / fps if fps > 0 else None,
            }
        finally:
            cap.release()
        return {"success": True, "metadata": metadata}

    def extract_frames_from_video(self, video_path, interval_seconds=5, max_frames=None):
        """
        Extracts frames from the video at specified intervals.

        Args:
            video_path (str): Path to the video file.
            interval_seconds (int | float): Interval in seconds between frames. Intervals shorter
                than one frame save every frame.
            max_frames (int, optional): Maximum number of frames to extract (None or 0 = no limit).

        Returns:
            dict: {"success": True, "extracted_frame_paths": [...] } or {"error": "..."}
        """
        if not os.path.exists(video_path):
            return {"error": f"Video file not found: {video_path}"}
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return {"error": "Could not open video file."}
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:  # also rejects NaN
                return {"error": "Could not determine the video frame rate."}
            # max(1, ...) prevents a modulo-by-zero when fps * interval_seconds < 1.
            frame_interval = max(1, int(fps * interval_seconds))
            # Include the video name so frames from different videos in one temp dir do not collide.
            video_stem = os.path.splitext(os.path.basename(video_path))[0]
            extracted_frame_paths = []
            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % frame_interval == 0:
                    frame_id = frame_count // frame_interval
                    frame_file_path = os.path.join(self.temp_dir, f"frame_{video_stem}_{frame_id:04d}.jpg")
                    if cv2.imwrite(frame_file_path, frame):  # only report frames actually written
                        extracted_frame_paths.append(frame_file_path)
                    if max_frames and len(extracted_frame_paths) >= max_frames:
                        break
                frame_count += 1
        finally:
            cap.release()
        return {"success": True, "extracted_frame_paths": extracted_frame_paths}

    @staticmethod
    def _transcript_entry_text(entry):
        """Return an entry's text whether it is a dict (older youtube-transcript-api) or an object with `.text`."""
        if isinstance(entry, dict):
            return entry.get("text", "")
        return getattr(entry, "text", "")

    def get_youtube_transcript(self, youtube_url, languages=None):
        """
        Get the transcript/captions of a YouTube video.

        Args:
            youtube_url (str): YouTube URL.
            languages (list[str] | str, optional): Preferred language codes, in priority order.
                Defaults to ['en', 'en-US']; a single code string is also accepted.

        Returns:
            dict: {"success": True, "transcript": str, "transcript_entries": ...} or {"error": "..."}.
        """
        if languages is None:
            languages = ['en', 'en-US']  # Default to English
        elif isinstance(languages, str):
            languages = [languages]  # a bare string would otherwise be iterated character by character
        video_id = self._extract_video_id(youtube_url)
        if not video_id:
            return {"error": "Invalid YouTube URL or could not extract video ID."}
        try:
            # Reverting to list_transcripts due to issues with list() in the current env
            transcript_list_obj = YouTubeTranscriptApi.list_transcripts(video_id)
            # Prefer a manual transcript; fall back to an auto-generated one. If that also fails,
            # NoTranscriptFound propagates to the handler below.
            try:
                transcript = transcript_list_obj.find_manually_created_transcript(languages)
            except NoTranscriptFound:
                transcript = transcript_list_obj.find_generated_transcript(languages)

            # fetch() can fail transiently; retry a few times before giving up.
            max_attempts = 3
            last_fetch_exception = None
            for attempt in range(max_attempts):
                try:
                    fetched_transcript_entries = transcript.fetch()
                    break
                except Exception as e_fetch:
                    last_fetch_exception = e_fetch
                    if attempt < max_attempts - 1:
                        time.sleep(1)  # Wait 1 second before retrying
            else:
                raise last_fetch_exception  # all attempts failed

            full_transcript_text = " ".join(
                self._transcript_entry_text(entry) for entry in fetched_transcript_entries
            )
            return {
                "success": True,
                "transcript": full_transcript_text,
                "transcript_entries": fetched_transcript_entries
            }
        except TranscriptsDisabled:
            return {"error": "Transcripts are disabled for this video."}
        except NoTranscriptFound:  # neither manual nor generated transcript found for the languages
            return {"error": f"No transcript found for the video in languages: {languages}."}
        except Exception as e:
            # Catches other exceptions from YouTubeTranscriptApi calls or re-raised from fetch
            return {"error": f"Failed to get transcript: {str(e)}"}

    def _get_output_layer_names(self):
        """Return the names of the network's unconnected output layers."""
        layer_names = self.object_detection_model.getLayerNames()
        # getUnconnectedOutLayers() returns 1-based indices, shaped (N, 1) in some OpenCV versions
        # and (N,) in others; flattening handles both.
        indices = np.array(self.object_detection_model.getUnconnectedOutLayers()).flatten()
        return [layer_names[int(i) - 1] for i in indices]

    def _detect_frame(self, frame, output_layer_names, confidence_threshold, nms_threshold):
        """Run the detector on one frame and return {class_name: count} after per-class NMS."""
        height, width = frame.shape[:2]
        blob = cv2.dnn.blobFromImage(frame, 1/255.0, (416, 416), swapRB=True, crop=False)
        self.object_detection_model.setInput(blob)
        detections = self.object_detection_model.forward(output_layer_names)

        boxes_by_class = {}  # class_id -> ([boxes], [confidences])
        for detection_set in detections:  # Detections can come from multiple output layers
            for detection in detection_set:
                # Darknet YOLO rows as read through OpenCV: box (cx, cy, w, h, relative to the
                # image), objectness, then per-class scores from index 5.
                scores = detection[5:]
                class_id = int(np.argmax(scores))
                confidence = float(scores[class_id])
                if confidence <= confidence_threshold or class_id >= len(self.class_names):
                    continue
                cx, cy, w, h = (float(v) for v in detection[:4])
                box = [int((cx - w / 2) * width), int((cy - h / 2) * height), int(w * width), int(h * height)]
                boxes, confidences = boxes_by_class.setdefault(class_id, ([], []))
                boxes.append(box)
                confidences.append(confidence)

        counts = {}
        for class_id, (boxes, confidences) in boxes_by_class.items():
            # NMS merges the many overlapping boxes the network emits for one object.
            kept = cv2.dnn.NMSBoxes(boxes, confidences, confidence_threshold, nms_threshold)
            class_name = self.class_names[class_id]
            # NMSBoxes returns a tuple, a flat array or an (N, 1) array depending on the OpenCV version.
            counts[class_name] = counts.get(class_name, 0) + len(np.array(kept).flatten())
        return counts

    def detect_objects_in_video(self, video_path, confidence_threshold=0.5, num_frames_to_sample=5, target_fps=1, nms_threshold=0.4):
        """
        Detects objects in evenly spaced frames of the video and counts them per class.

        Args:
            video_path (str): Path to the video file.
            confidence_threshold (float): Minimum confidence for an object to be counted.
            num_frames_to_sample (int): Number of frames to sample for object detection (at least 1).
            target_fps (int): Currently unused; kept for backward compatibility.
            nms_threshold (float): IoU threshold for non-maximum suppression of overlapping boxes.

        Returns:
            dict: {"success": True, "object_counts": {...}, "max_counts_per_frame": {...},
                "frames_processed": int} or {"error": "..."}. ``object_counts`` sums detections over
                all sampled frames, so an object visible in several sampled frames is counted once per
                frame; ``max_counts_per_frame`` is the largest single-frame count for each class.
        """
        if not self.object_detection_model or not self.class_names:
            return {"error": "Object detection model not loaded or class names missing."}
        if not os.path.exists(video_path):
            return {"error": f"Video file not found: {video_path}"}
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return {"error": "Could not open video file."}

        num_frames_to_sample = max(1, int(num_frames_to_sample or 1))
        object_counts = {cls: 0 for cls in self.class_names}
        max_counts_per_frame = {cls: 0 for cls in self.class_names}
        frames_processed = 0
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # If the frame count is unknown (<= 0), this degrades to sampling consecutive frames.
            sample_interval = max(1, total_frames // num_frames_to_sample)
            output_layer_names = self._get_output_layer_names()  # loop-invariant
            frame_index = 0
            while frames_processed < num_frames_to_sample:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_index % sample_interval == 0:
                    frame_counts = self._detect_frame(frame, output_layer_names, confidence_threshold, nms_threshold)
                    for class_name, count in frame_counts.items():
                        object_counts[class_name] += count
                        max_counts_per_frame[class_name] = max(max_counts_per_frame[class_name], count)
                    frames_processed += 1
                frame_index += 1
        finally:
            cap.release()
        return {
            "success": True,
            "object_counts": object_counts,
            "max_counts_per_frame": max_counts_per_frame,
            "frames_processed": frames_processed,
        }

    def cleanup(self):
        """Remove temporary files and directory."""
        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir, ignore_errors=True)
# Example usage (manual smoke test; assumes model files are in ./models/cv/):
if __name__ == '__main__':
    # Create placeholder model files for local testing if they don't exist.
    # OpenCV will most likely fail to parse them, in which case the object detection test is skipped.
    os.makedirs("./models/cv", exist_ok=True)
    dummy_cfg = "./models/cv/dummy-yolov3-tiny.cfg"
    dummy_weights = "./models/cv/dummy-yolov3-tiny.weights"
    dummy_names = "./models/cv/dummy-coco.names"
    placeholder_files = {
        dummy_cfg: "# Dummy YOLOv3 tiny config",
        dummy_weights: "dummy weights",  # Actual weights file is binary
        dummy_names: "bird\ncat\ndog\nperson",  # one class name per line
    }
    for placeholder_path, placeholder_content in placeholder_files.items():
        if not os.path.exists(placeholder_path):
            with open(placeholder_path, "w", encoding="utf-8") as placeholder_file:
                placeholder_file.write(placeholder_content)

    # Initialize tool
    # Note: For real object detection, provide paths to actual .cfg, .weights, and .names files.
    # For example, from: https://pjreddie.com/darknet/yolo/
    video_tool = VideoProcessingTool(
        model_cfg_path=dummy_cfg,  # Replace with actual path to YOLOv3-tiny.cfg or similar
        model_weights_path=dummy_weights,  # Replace with actual path to YOLOv3-tiny.weights
        class_names_path=dummy_names,  # Replace with actual path to coco.names
    )
    try:
        # Test 1: Get Transcript (replace with a video that has transcripts)
        transcript_test_url = "https://www.youtube.com/watch?v=1htKBjuUWec"  # Stargate SG-1 clip
        print(f"--- Testing Transcript for: {transcript_test_url} ---")
        transcript_info = video_tool.forward(transcript_test_url, "get_transcript")
        if transcript_info.get("success"):
            print("Transcript (first 100 chars):", transcript_info.get("transcript", "")[:100])
        else:
            print("Transcript Error:", transcript_info.get("error"))
        print()

        # Test 2: Find the dialogue that follows a query phrase, using the transcript from Test 1.
        query_phrase = "Isn't that hot?"
        print(f"--- Testing Dialogue Response for: {transcript_test_url} ---")
        if transcript_info.get("success"):
            transcript_text = transcript_info.get("transcript", "")
            match_pos = transcript_text.lower().find(query_phrase.lower())
            if match_pos >= 0:
                response_text = transcript_text[match_pos + len(query_phrase):].strip()[:100]
                print(f"Query: '{query_phrase}', Response: '{response_text}'")
            else:
                print("Dialogue Error:", f"Phrase not found in transcript: {query_phrase}")
        else:
            print("Dialogue Error:", transcript_info.get("error"))
        print()

        # Test 3: Object Counting (needs real model files to produce meaningful counts).
        # This example downloads a short video.
        object_count_test_url = "https://www.youtube.com/watch?v=L1vXCYZAYYM"  # Birds video
        print(f"--- Testing Object Counting for: {object_count_test_url} ---")
        if video_tool.object_detection_model:
            count_info = video_tool.forward(
                object_count_test_url,
                "detect_objects",
                {"resolution": "360p", "confidence_threshold": 0.5},
            )
            if count_info.get("success"):
                print("Object Counts:", count_info)
            else:
                print("Object Counting Error:", count_info.get("error"))
        else:
            print("Object detection model not loaded, skipping object count test.")
    finally:
        # Always remove downloaded videos/frames. The placeholder model files are left in place on
        # purpose, because real model files may share these names.
        video_tool.cleanup()
    print("\nAll tests finished.")