Yago Bolivar commited on
Commit
b1939df
·
1 Parent(s): 8818829

add smolagents docstrings

Browse files
src/file_processing_tool.py CHANGED
@@ -5,6 +5,10 @@ import os
5
  import mimetypes
6
 
7
  class FileIdentifier:
 
 
 
 
8
  def __init__(self):
9
  mimetypes.init()
10
  # Mapping from simple type to action and common extensions
 
5
  import mimetypes
6
 
7
  class FileIdentifier:
8
+ """
9
+ Identifies file types and maps them to the appropriate processing tool based on file extension.
10
+ Useful for routing files to specialized tools such as speech-to-text, spreadsheet parser, image processor, etc.
11
+ """
12
  def __init__(self):
13
  mimetypes.init()
14
  # Mapping from simple type to action and common extensions
src/image_processing_tool.py CHANGED
@@ -20,6 +20,11 @@ vision_pipeline = pipeline(
20
  )
21
 
22
  class ImageProcessor:
 
 
 
 
 
23
  def __init__(self):
24
  self.vision_pipeline = vision_pipeline
25
 
 
20
  )
21
 
22
  class ImageProcessor:
23
+ """
24
+ Processes image files, including OCR, vision reasoning, and chessboard analysis.
25
+ Integrates computer vision and chess engines for advanced image-based tasks.
26
+ Useful for extracting text, analyzing chess positions, and general image understanding.
27
+ """
28
  def __init__(self):
29
  self.vision_pipeline = vision_pipeline
30
 
src/markdown_table_parser.py CHANGED
@@ -6,6 +6,7 @@ def parse_markdown_table(markdown_text: str) -> dict[str, list[str]] | None:
6
  Parses the first valid Markdown table found in a string.
7
  Returns a dictionary (headers as keys, lists of cell content as values)
8
  or None if no valid table is found.
 
9
  """
10
  lines = [line.rstrip() for line in markdown_text.split('\n') if line.strip()]
11
  n = len(lines)
 
6
  Parses the first valid Markdown table found in a string.
7
  Returns a dictionary (headers as keys, lists of cell content as values)
8
  or None if no valid table is found.
9
+ Useful for converting markdown tables into Python data structures for further analysis.
10
  """
11
  lines = [line.rstrip() for line in markdown_text.split('\n') if line.strip()]
12
  n = len(lines)
src/python_tool.py CHANGED
@@ -7,7 +7,10 @@ import traceback
7
  from typing import Dict, Any, Optional, Union, List
8
 
9
  class CodeExecutionTool:
10
- """Tool to safely execute Python code files and extract numeric outputs."""
 
 
 
11
 
12
  def __init__(self, timeout: int = 5, max_output_size: int = 10000):
13
  self.timeout = timeout # Maximum execution time in seconds
 
7
  from typing import Dict, Any, Optional, Union, List
8
 
9
  class CodeExecutionTool:
10
+ """
11
+ Executes Python code in a controlled environment for safe code interpretation.
12
+ Useful for evaluating code snippets and returning their output or errors.
13
+ """
14
 
15
  def __init__(self, timeout: int = 5, max_output_size: int = 10000):
16
  self.timeout = timeout # Maximum execution time in seconds
src/speech_to_text.py CHANGED
@@ -11,7 +11,11 @@ asr_pipeline = pipeline(
11
 
12
  def transcribe_audio(audio_filepath):
13
  """
14
- Transcribes an audio file using the Hugging Face ASR pipeline.
 
 
 
 
15
  """
16
  try:
17
  transcription = asr_pipeline(audio_filepath, return_timestamps=True)
 
11
 
12
  def transcribe_audio(audio_filepath):
13
  """
14
+ Converts speech in an audio file (e.g., .mp3) to text using speech recognition.
15
+ Args:
16
+ audio_filepath (str): Path to the audio file.
17
+ Returns:
18
+ str: Transcribed text from the audio.
19
  """
20
  try:
21
  transcription = asr_pipeline(audio_filepath, return_timestamps=True)
src/spreadsheet_tool.py CHANGED
@@ -5,7 +5,10 @@ import numpy as np
5
 
6
 
7
  class SpreadsheetTool:
8
- """Tool for parsing and extracting data from Excel (.xlsx) files."""
 
 
 
9
 
10
  def __init__(self):
11
  """Initialize the SpreadsheetTool."""
 
5
 
6
 
7
  class SpreadsheetTool:
8
+ """
9
+ Parses spreadsheet files (e.g., .xlsx) and extracts tabular data for analysis.
10
+ Useful for reading, processing, and converting spreadsheet content to Python data structures.
11
+ """
12
 
13
  def __init__(self):
14
  """Initialize the SpreadsheetTool."""
src/text_reversal_tool.py CHANGED
@@ -1,5 +1,7 @@
1
  def reverse_text(text: str) -> str:
2
- """Reverses the input string."""
 
 
3
  return text[::-1]
4
 
5
  if __name__ == '__main__':
 
1
  def reverse_text(text: str) -> str:
2
+ """
3
+ Reverses or processes reversed text, useful for decoding or analyzing reversed strings.
4
+ """
5
  return text[::-1]
6
 
7
  if __name__ == '__main__':
src/video_processing_tool.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yt_dlp
3
+ import cv2
4
+ import numpy as np
5
+ from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound
6
+ import tempfile
7
+ import re
8
+ import shutil
9
+ import time # Added for retry logic
10
+
11
+ class VideoProcessingTool:
12
+ """
13
+ Analyzes video content, extracting information such as frames, audio, or metadata.
14
+ Useful for tasks like video summarization, frame extraction, transcript analysis, or content analysis.
15
+ """
16
+
17
+ def __init__(self, model_cfg_path=None, model_weights_path=None, class_names_path=None, temp_dir_base=None):
18
+ """
19
+ Initializes the VideoProcessingTool.
20
+
21
+ Args:
22
+ model_cfg_path (str, optional): Path to the object detection model's configuration file.
23
+ model_weights_path (str, optional): Path to the object detection model's weights file.
24
+ class_names_path (str, optional): Path to the file containing class names for the model.
25
+ temp_dir_base (str, optional): Base directory for temporary files. Defaults to system temp.
26
+ """
27
+ if temp_dir_base:
28
+ self.temp_dir = tempfile.mkdtemp(dir=temp_dir_base)
29
+ else:
30
+ self.temp_dir = tempfile.mkdtemp()
31
+
32
+ self.object_detection_model = None
33
+ self.class_names = []
34
+
35
+ if model_cfg_path and model_weights_path and class_names_path:
36
+ if os.path.exists(model_cfg_path) and os.path.exists(model_weights_path) and os.path.exists(class_names_path):
37
+ try:
38
+ self.object_detection_model = cv2.dnn.readNetFromDarknet(model_cfg_path, model_weights_path)
39
+ # Set preferable backend and target
40
+ self.object_detection_model.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
41
+ self.object_detection_model.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)
42
+ with open(class_names_path, "r") as f:
43
+ self.class_names = [line.strip() for line in f.readlines()]
44
+ except Exception as e:
45
+ print(f"Error loading CV model: {e}. Object detection will not be available.")
46
+ self.object_detection_model = None
47
+ else:
48
+ print("Warning: One or more CV model paths are invalid. Object detection will not be available.")
49
+
50
+ def _extract_video_id(self, youtube_url):
51
+ """Extract the YouTube video ID from a URL."""
52
+ match = re.search(r"(?:v=|\/|embed\/|watch\?v=|youtu\.be\/)([0-9A-Za-z_-]{11})", youtube_url)
53
+ if match:
54
+ return match.group(1)
55
+ return None
56
+
57
+ def download_video(self, youtube_url, resolution="360p"):
58
+ """Download YouTube video for processing."""
59
+ video_id = self._extract_video_id(youtube_url)
60
+ if not video_id:
61
+ return {"error": "Invalid YouTube URL or could not extract video ID."}
62
+
63
+ output_file_name = f"{video_id}.mp4"
64
+ output_file_path = os.path.join(self.temp_dir, output_file_name)
65
+
66
+ if os.path.exists(output_file_path): # Avoid re-downloading
67
+ return {"success": True, "file_path": output_file_path, "message": "Video already downloaded."}
68
+
69
+ try:
70
+ ydl_opts = {
71
+ 'format': f'bestvideo[height<={resolution[:-1]}][ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
72
+ 'outtmpl': output_file_path,
73
+ 'noplaylist': True,
74
+ 'quiet': True,
75
+ 'no_warnings': True,
76
+ }
77
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
78
+ ydl.download([youtube_url])
79
+
80
+ if not os.path.exists(output_file_path): # Check if download actually created the file
81
+ # Fallback for some formats if mp4 direct is not available
82
+ ydl_opts['format'] = f'best[height<={resolution[:-1]}]' # more generic
83
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
84
+ info_dict = ydl.extract_info(youtube_url, download=True)
85
+ # yt-dlp might save with a different extension, find the downloaded file
86
+ downloaded_files = [f for f in os.listdir(self.temp_dir) if f.startswith(video_id)]
87
+ if downloaded_files:
88
+ actual_file_path = os.path.join(self.temp_dir, downloaded_files[0])
89
+ if actual_file_path != output_file_path and actual_file_path.endswith(('.mkv', '.webm', '.flv')):
90
+ # Minimal conversion to mp4 if needed, or just use the downloaded format if cv2 supports it
91
+ # For simplicity, we'll assume cv2 can handle common formats or user ensures mp4 compatible download
92
+ output_file_path = actual_file_path # Use the actual downloaded file
93
+ elif not actual_file_path.endswith('.mp4'): # if it's not mp4 and not handled above
94
+ return {"error": f"Downloaded video is not in a directly usable format: {downloaded_files[0]}"}
95
+
96
+
97
+ if os.path.exists(output_file_path):
98
+ return {"success": True, "file_path": output_file_path}
99
+ else:
100
+ return {"error": "Video download failed, file not found after attempt."}
101
+
102
+ except yt_dlp.utils.DownloadError as e:
103
+ return {"error": f"yt-dlp download error: {str(e)}"}
104
+ except Exception as e:
105
+ return {"error": f"Failed to download video: {str(e)}"}
106
+
107
+ def get_video_transcript(self, youtube_url, languages=None):
108
+ """Get the transcript/captions of a YouTube video."""
109
+ if languages is None:
110
+ languages = ['en', 'en-US'] # Default to English
111
+ video_id = self._extract_video_id(youtube_url)
112
+ if not video_id:
113
+ return {"error": "Invalid YouTube URL or could not extract video ID."}
114
+
115
+ try:
116
+ # Reverting to list_transcripts due to issues with list() in the current env
117
+ transcript_list_obj = YouTubeTranscriptApi.list_transcripts(video_id)
118
+
119
+ transcript = None
120
+ # Try to find a manual transcript first in the specified languages
121
+ try:
122
+ transcript = transcript_list_obj.find_manually_created_transcript(languages)
123
+ except NoTranscriptFound:
124
+ # If no manual transcript, try to find a generated one
125
+ # This will raise NoTranscriptFound if it also fails, which is caught below.
126
+ transcript = transcript_list_obj.find_generated_transcript(languages)
127
+
128
+ # Retry logic for transcript.fetch()
129
+ fetched_transcript_entries = None
130
+ max_attempts = 3 # Total attempts
131
+ last_fetch_exception = None
132
+
133
+ for attempt in range(max_attempts):
134
+ try:
135
+ fetched_transcript_entries = transcript.fetch()
136
+ last_fetch_exception = None # Clear exception on success
137
+ break # Successful fetch
138
+ except Exception as e_fetch:
139
+ last_fetch_exception = e_fetch
140
+ if attempt < max_attempts - 1:
141
+ time.sleep(1) # Wait 1 second before retrying
142
+ # If it's the last attempt, the loop will end, and last_fetch_exception will be set.
143
+
144
+ if last_fetch_exception: # If all attempts failed
145
+ raise last_fetch_exception # Re-raise the last exception from fetch()
146
+
147
+ # Correctly access the 'text' attribute
148
+ full_transcript_text = " ".join([entry.text for entry in fetched_transcript_entries])
149
+
150
+ return {
151
+ "success": True,
152
+ "transcript": full_transcript_text,
153
+ "transcript_entries": fetched_transcript_entries
154
+ }
155
+ except TranscriptsDisabled:
156
+ return {"error": "Transcripts are disabled for this video."}
157
+ except NoTranscriptFound: # This will catch if neither manual nor generated is found for the languages
158
+ return {"error": f"No transcript found for the video in languages: {languages}."}
159
+ except Exception as e:
160
+ # Catches other exceptions from YouTubeTranscriptApi calls or re-raised from fetch
161
+ return {"error": f"Failed to get transcript: {str(e)}"}
162
+
163
+ def count_objects_in_video(self, video_path, target_classes=None, confidence_threshold=0.5, frame_skip=5):
164
+ """
165
+ Counts specified objects appearing in the video using the loaded DNN model.
166
+ Determines the maximum number of target objects appearing simultaneously in any single frame.
167
+ Args:
168
+ video_path (str): Path to the video file.
169
+ target_classes (list, optional): A list of object classes (strings) to count (e.g., ["bird", "cat"]).
170
+ If None, counts all detected objects.
171
+ confidence_threshold (float): Minimum confidence for an object to be counted.
172
+ frame_skip (int): Process every Nth frame to speed up analysis.
173
+ Returns:
174
+ dict: A dictionary with counts or an error message.
175
+ e.g., {"success": True, "max_simultaneous_birds": 3, "max_simultaneous_cats": 1}
176
+ or {"error": "Object detection model not loaded."}
177
+ """
178
+ if not self.object_detection_model or not self.class_names:
179
+ return {"error": "Object detection model not loaded or class names missing."}
180
+ if not os.path.exists(video_path):
181
+ return {"error": f"Video file not found: {video_path}"}
182
+
183
+ cap = cv2.VideoCapture(video_path)
184
+ if not cap.isOpened():
185
+ return {"error": "Could not open video file."}
186
+
187
+ max_counts_per_class = {cls: 0 for cls in target_classes} if target_classes else {}
188
+ # If target_classes is None, we'd need to initialize for all detected classes,
189
+ # but for simplicity, let's require target_classes for now or adjust later.
190
+ if not target_classes:
191
+ # Defaulting to a common class if none specified, e.g. 'person'
192
+ # Or, one could count all unique classes detected. For GAIA, specific targets are better.
193
+ return {"error": "target_classes must be specified for counting."}
194
+
195
+
196
+ frame_count = 0
197
+ while cap.isOpened():
198
+ ret, frame = cap.read()
199
+ if not ret:
200
+ break
201
+
202
+ frame_count += 1
203
+ if frame_count % frame_skip != 0:
204
+ continue
205
+
206
+ height, width = frame.shape[:2]
207
+ blob = cv2.dnn.blobFromImage(frame, 1/255.0, (416, 416), swapRB=True, crop=False)
208
+ self.object_detection_model.setInput(blob)
209
+
210
+ layer_names = self.object_detection_model.getLayerNames()
211
+ # Handle potential differences in getUnconnectedOutLayers() return value
212
+ unconnected_out_layers_indices = self.object_detection_model.getUnconnectedOutLayers()
213
+ if isinstance(unconnected_out_layers_indices, np.ndarray) and unconnected_out_layers_indices.ndim > 1 : # For some OpenCV versions
214
+ output_layer_names = [layer_names[i[0] - 1] for i in unconnected_out_layers_indices]
215
+ else: # For typical cases
216
+ output_layer_names = [layer_names[i - 1] for i in unconnected_out_layers_indices]
217
+
218
+ detections = self.object_detection_model.forward(output_layer_names)
219
+
220
+ current_frame_counts = {cls: 0 for cls in target_classes}
221
+
222
+ for detection_set in detections: # Detections can come from multiple output layers
223
+ for detection in detection_set:
224
+ scores = detection[5:]
225
+ class_id = np.argmax(scores)
226
+ confidence = scores[class_id]
227
+
228
+ if confidence > confidence_threshold:
229
+ detected_class_name = self.class_names[class_id]
230
+ if detected_class_name in target_classes:
231
+ current_frame_counts[detected_class_name] += 1
232
+
233
+ for cls in target_classes:
234
+ if current_frame_counts[cls] > max_counts_per_class[cls]:
235
+ max_counts_per_class[cls] = current_frame_counts[cls]
236
+
237
+ cap.release()
238
+ result = {"success": True}
239
+ for cls, count in max_counts_per_class.items():
240
+ result[f"max_simultaneous_{cls.replace(' ', '_')}"] = count # e.g. "max_simultaneous_bird"
241
+ return result
242
+
243
+ def find_dialogue_response(self, transcript_entries, query_phrase, max_entries_gap=2, max_time_gap_s=5.0):
244
+ """
245
+ Finds what is said in response to a given query phrase in transcript entries.
246
+ Looks for the query phrase and then captures the text from subsequent entries.
247
+
248
+ Args:
249
+ transcript_entries (list): List of transcript dictionaries (from get_video_transcript).
250
+ query_phrase (str): The phrase to find (e.g., a question).
251
+ max_entries_gap (int): How many transcript entries to look ahead for a response.
252
+ max_time_gap_s (float): Maximum time in seconds after the query phrase to consider for a response.
253
+
254
+ Returns:
255
+ dict: {"success": True, "response_text": "...", "found_at_entry": {...}} or {"error": "..."}
256
+ """
257
+ if not transcript_entries:
258
+ return {"error": "Transcript entries are empty."}
259
+
260
+ query_phrase_lower = query_phrase.lower().rstrip('?.!,;') # Strip common trailing punctuation
261
+
262
+ for i, entry in enumerate(transcript_entries):
263
+ # Correctly access attributes: .text, .start, .duration
264
+ if query_phrase_lower in entry.text.lower():
265
+ # Found the query phrase, now look for the response
266
+ response_parts = []
267
+ start_time_of_query = entry.start + entry.duration # End time of query entry
268
+
269
+ for j in range(i + 1, min(i + 1 + max_entries_gap + 1, len(transcript_entries))):
270
+ next_entry = transcript_entries[j]
271
+ # Check if the next entry is within the time gap
272
+ if next_entry.start - start_time_of_query > max_time_gap_s:
273
+ break # Too much time has passed
274
+
275
+ # Add text if it's not just noise or very short (heuristic)
276
+ if next_entry.text.strip() and len(next_entry.text.strip()) > 1:
277
+ response_parts.append(next_entry.text)
278
+
279
+ # If we have collected some response, and the next entry is significantly later, stop.
280
+ if response_parts and (j + 1 < len(transcript_entries)):
281
+ if transcript_entries[j+1].start - (next_entry.start + next_entry.duration) > 1.0: # If gap > 1s
282
+ break
283
+
284
+ if response_parts:
285
+ return {
286
+ "success": True,
287
+ "response_text": " ".join(response_parts),
288
+ "query_entry": entry,
289
+ "response_start_entry_index": i + 1
290
+ }
291
+ # If no response found immediately after, but query was found
292
+ return {"error": f"Query phrase '{query_phrase}' found, but no subsequent dialogue captured as response within gap."}
293
+
294
+ return {"error": f"Query phrase '{query_phrase}' not found in transcript."}
295
+
296
+
297
+ def process_video(self, youtube_url, query_type, query_params=None):
298
+ """
299
+ Main method to process a video based on the type of query.
300
+
301
+ Args:
302
+ youtube_url (str): URL of the YouTube video.
303
+ query_type (str): Type of processing: "transcript", "object_count", "dialogue_response".
304
+ query_params (dict, optional): Additional parameters for the specific query type.
305
+ For "object_count": {"target_classes": ["bird"], "confidence_threshold": 0.5, "resolution": "360p"}
306
+ For "dialogue_response": {"query_phrase": "Isn't that hot?", "languages": ['en']}
307
+ """
308
+ if query_params is None:
309
+ query_params = {}
310
+
311
+ if query_type == "transcript":
312
+ return self.get_video_transcript(youtube_url, languages=query_params.get("languages"))
313
+
314
+ elif query_type == "object_count":
315
+ if not self.object_detection_model:
316
+ return {"error": "Object detection model not initialized. Cannot count objects."}
317
+
318
+ resolution = query_params.get("resolution", "360p")
319
+ download_result = self.download_video(youtube_url, resolution=resolution)
320
+ if "error" in download_result:
321
+ return download_result
322
+
323
+ video_path = download_result["file_path"]
324
+ target_classes = query_params.get("target_classes")
325
+ if not target_classes or not isinstance(target_classes, list):
326
+ return {"error": "query_params must include 'target_classes' as a list for object_count."}
327
+
328
+ confidence = query_params.get("confidence_threshold", 0.5)
329
+ frame_skip = query_params.get("frame_skip", 5)
330
+ return self.count_objects_in_video(video_path, target_classes, confidence, frame_skip)
331
+
332
+ elif query_type == "dialogue_response":
333
+ transcript_result = self.get_video_transcript(youtube_url, languages=query_params.get("languages"))
334
+ if "error" in transcript_result:
335
+ return transcript_result
336
+
337
+ query_phrase = query_params.get("query_phrase")
338
+ if not query_phrase:
339
+ return {"error": "query_params must include 'query_phrase' for dialogue_response."}
340
+
341
+ return self.find_dialogue_response(
342
+ transcript_result["transcript_entries"],
343
+ query_phrase,
344
+ max_entries_gap=query_params.get("max_entries_gap", 2),
345
+ max_time_gap_s=query_params.get("max_time_gap_s", 5.0)
346
+ )
347
+
348
+ return {"error": f"Unsupported query type: {query_type}"}
349
+
350
+ def cleanup(self):
351
+ """Remove temporary files and directory."""
352
+ if os.path.exists(self.temp_dir):
353
+ shutil.rmtree(self.temp_dir, ignore_errors=True)
354
+ # print(f"Cleaned up temp directory: {self.temp_dir}")
355
+
356
+ # Example Usage (for testing purposes, assuming model files are in ./models/cv/):
357
+ if __name__ == '__main__':
358
+ # Create dummy model files for local testing if they don't exist
359
+ os.makedirs("./models/cv", exist_ok=True)
360
+ dummy_cfg = "./models/cv/dummy-yolov3-tiny.cfg"
361
+ dummy_weights = "./models/cv/dummy-yolov3-tiny.weights"
362
+ dummy_names = "./models/cv/dummy-coco.names"
363
+
364
+ if not os.path.exists(dummy_cfg): open(dummy_cfg, 'w').write("# Dummy YOLOv3 tiny config")
365
+ if not os.path.exists(dummy_weights): open(dummy_weights, 'w').write("dummy weights") # Actual weights file is binary
366
+ if not os.path.exists(dummy_names): open(dummy_names, 'w').write("bird\\ncat\\ndog\\nperson")
367
+
368
+ # Initialize tool
369
+ # Note: For real object detection, provide paths to actual .cfg, .weights, and .names files.
370
+ # For example, from: https://pjreddie.com/darknet/yolo/
371
+ video_tool = VideoProcessingTool(
372
+ model_cfg_path=dummy_cfg, # Replace with actual path to YOLOv3-tiny.cfg or similar
373
+ model_weights_path=dummy_weights, # Replace with actual path to YOLOv3-tiny.weights
374
+ class_names_path=dummy_names # Replace with actual path to coco.names
375
+ )
376
+
377
+ # Test 1: Get Transcript
378
+ # Replace with a video that has transcripts
379
+ transcript_test_url = "https://www.youtube.com/watch?v=1htKBjuUWec" # Stargate SG-1 clip
380
+ print(f"--- Testing Transcript for: {transcript_test_url} ---")
381
+ transcript_info = video_tool.process_video(transcript_test_url, "transcript")
382
+ if transcript_info.get("success"):
383
+ print("Transcript (first 100 chars):", transcript_info.get("transcript", "")[:100])
384
+ else:
385
+ print("Transcript Error:", transcript_info.get("error"))
386
+ print("\\n")
387
+
388
+ # Test 2: Find Dialogue Response
389
+ dialogue_test_url = "https://www.youtube.com/watch?v=1htKBjuUWec" # Stargate SG-1 clip
390
+ print(f"--- Testing Dialogue Response for: {dialogue_test_url} ---")
391
+ dialogue_info = video_tool.process_video(
392
+ dialogue_test_url,
393
+ "dialogue_response",
394
+ query_params={"query_phrase": "Isn't that hot?"}
395
+ )
396
+ if dialogue_info.get("success"):
397
+ print(f"Query: 'Isn't that hot?', Response: '{dialogue_info.get('response_text')}'")
398
+ else:
399
+ print("Dialogue Error:", dialogue_info.get("error"))
400
+ print("\\n")
401
+
402
+ # Test 3: Object Counting (will likely use dummy model and might not detect much without real video/model)
403
+ # Replace with a video URL that you want to test object counting on.
404
+ # This example will download a short video.
405
+ object_count_test_url = "https://www.youtube.com/watch?v=L1vXCYZAYYM" # Birds video
406
+ print(f"--- Testing Object Counting for: {object_count_test_url} ---")
407
+ # Ensure you have actual model files for this to work meaningfully.
408
+ # The dummy model files will likely result in zero counts or errors if OpenCV can't parse them.
409
+ # For this example, we expect it to run through, but actual detection depends on valid models.
410
+ if video_tool.object_detection_model:
411
+ count_info = video_tool.process_video(
412
+ object_count_test_url,
413
+ "object_count",
414
+ query_params={"target_classes": ["bird"], "resolution": "360p"}
415
+ )
416
+ if count_info.get("success"):
417
+ print("Object Counts:", count_info)
418
+ else:
419
+ print("Object Counting Error:", count_info.get("error"))
420
+ else:
421
+ print("Object detection model not loaded, skipping object count test.")
422
+
423
+ # Cleanup
424
+ video_tool.cleanup()
425
+ # Clean up dummy model files if they were created by this script
426
+ # (Be careful if you have real files with these names)
427
+ # if os.path.exists(dummy_cfg) and "dummy-yolov3-tiny.cfg" in dummy_cfg : os.remove(dummy_cfg)
428
+ # if os.path.exists(dummy_weights) and "dummy-yolov3-tiny.weights" in dummy_weights: os.remove(dummy_weights)
429
+ # if os.path.exists(dummy_names) and "dummy-coco.names" in dummy_names: os.remove(dummy_names)
430
+ # if os.path.exists("./models/cv") and not os.listdir("./models/cv"): os.rmdir("./models/cv")
431
+ # if os.path.exists("./models") and not os.listdir("./models"): os.rmdir("./models")
432
+
433
+ print("\\nAll tests finished.")
src/web_browsing_tool.py CHANGED
@@ -3,7 +3,8 @@ from bs4 import BeautifulSoup
3
 
4
  class WebBrowser:
5
  """
6
- A simple web browser tool to fetch and parse content from URLs.
 
7
  """
8
 
9
  def __init__(self, user_agent="GAIA-Agent/1.0"):
 
3
 
4
  class WebBrowser:
5
  """
6
+ Retrieves information from online sources by browsing web pages or performing web searches.
7
+ Useful for extracting or summarizing web content.
8
  """
9
 
10
  def __init__(self, user_agent="GAIA-Agent/1.0"):
tests/test_video_processing_tool.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import os
3
+ import shutil
4
+
5
+ import os
6
+ import sys
7
+
8
+ # Add the parent directory to sys.path to find the src module
9
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
+ from src.video_processing_tool import VideoProcessingTool
11
+
12
+
13
+
14
+ # --- Test Configuration ---
15
+ # URLs for testing different functionalities
16
+ # Ensure these videos are publicly accessible and have expected features (transcripts, specific objects)
17
+ # Using videos from the common_questions.json for relevance
18
+ VIDEO_URL_TRANSCRIPT_DIALOGUE = "https://www.youtube.com/watch?v=1htKBjuUWec" # Stargate SG-1 "Isn't that hot?"
19
+ VIDEO_URL_OBJECT_COUNT = "https://www.youtube.com/watch?v=L1vXCYZAYYM" # Birds video
20
+ VIDEO_URL_NO_TRANSCRIPT = "https://www.youtube.com/watch?v=dQw4w9WgXcQ" # Rick Astley (often no official transcript)
21
+ VIDEO_URL_SHORT_GENERAL = "https://www.youtube.com/watch?v=jNQXAC9IVRw" # Short Creative Commons video (Big Buck Bunny)
22
+
23
+ # --- Fixtures ---
24
+ @pytest.fixture(scope="session")
25
+ def model_files():
26
+ """Creates dummy model files for testing CV functionality if real ones aren't provided."""
27
+ # Using a sub-directory within the test directory for these files
28
+ test_model_dir = os.path.join(os.path.dirname(__file__), "test_cv_models")
29
+ os.makedirs(test_model_dir, exist_ok=True)
30
+
31
+ cfg_path = os.path.join(test_model_dir, "dummy-yolov3-tiny.cfg")
32
+ weights_path = os.path.join(test_model_dir, "dummy-yolov3-tiny.weights")
33
+ names_path = os.path.join(test_model_dir, "dummy-coco.names")
34
+
35
+ # Create minimal dummy files if they don't exist
36
+ # These won't make OpenCV's DNN module load a functional model but allow testing the file handling logic.
37
+ # For actual DNN model loading, valid model files are required.
38
+ if not os.path.exists(cfg_path): open(cfg_path, 'w').write("[net]\nwidth=416\nheight=416")
39
+ if not os.path.exists(weights_path): open(weights_path, 'wb').write(b'dummyweights')
40
+ if not os.path.exists(names_path): open(names_path, 'w').write("bird\ncat\ndog\nperson\n")
41
+
42
+ yield {
43
+ "cfg": cfg_path,
44
+ "weights": weights_path,
45
+ "names": names_path,
46
+ "dir": test_model_dir
47
+ }
48
+ # Cleanup: remove the dummy model directory after tests
49
+ # shutil.rmtree(test_model_dir, ignore_errors=True) # Keep for inspection if tests fail
50
+
51
+ @pytest.fixture
52
+ def video_tool(model_files):
53
+ """Initializes VideoProcessingTool with dummy model paths for testing."""
54
+ # Using a specific temp directory for test artifacts
55
+ test_temp_dir_base = os.path.join(os.path.dirname(__file__), "test_temp_videos")
56
+ os.makedirs(test_temp_dir_base, exist_ok=True)
57
+
58
+ tool = VideoProcessingTool(
59
+ model_cfg_path=model_files["cfg"],
60
+ model_weights_path=model_files["weights"],
61
+ class_names_path=model_files["names"],
62
+ temp_dir_base=test_temp_dir_base
63
+ )
64
+ yield tool
65
+ tool.cleanup() # Ensure temp files for this tool instance are removed
66
+ # Optional: Clean up the base test_temp_dir if it's empty or after all tests
67
+ # if os.path.exists(test_temp_dir_base) and not os.listdir(test_temp_dir_base):
68
+ # shutil.rmtree(test_temp_dir_base)
69
+
70
+ # --- Test Cases ---
71
+
72
+ def test_extract_video_id(video_tool):
73
+ assert video_tool._extract_video_id("https://www.youtube.com/watch?v=1htKBjuUWec") == "1htKBjuUWec"
74
+ assert video_tool._extract_video_id("https://youtu.be/1htKBjuUWec") == "1htKBjuUWec"
75
+ assert video_tool._extract_video_id("https://www.youtube.com/embed/1htKBjuUWec") == "1htKBjuUWec"
76
+ assert video_tool._extract_video_id("https://www.youtube.com/watch?v=1htKBjuUWec&t=10s") == "1htKBjuUWec"
77
+ assert video_tool._extract_video_id("invalid_url") is None
78
+
79
+ @pytest.mark.integration # Marks as integration test (requires network)
80
+ def test_download_video(video_tool):
81
+ result = video_tool.download_video(VIDEO_URL_SHORT_GENERAL, resolution="240p")
82
+ assert result.get("success"), f"Download failed: {result.get('error')}"
83
+ assert "file_path" in result
84
+ assert os.path.exists(result["file_path"])
85
+ assert result["file_path"].endswith(".mp4") or result["file_path"].startswith(video_tool._extract_video_id(VIDEO_URL_SHORT_GENERAL))
86
+
87
+ @pytest.mark.integration
88
+ def test_get_video_transcript_success(video_tool):
89
+ result = video_tool.get_video_transcript(VIDEO_URL_TRANSCRIPT_DIALOGUE)
90
+ assert result.get("success"), f"Transcript fetch failed: {result.get('error')}"
91
+ assert "transcript" in result and len(result["transcript"]) > 0
92
+ assert "transcript_entries" in result and len(result["transcript_entries"]) > 0
93
+ # Making the check case-insensitive and more flexible
94
+ assert "isn't that hot" in result["transcript"].lower() # Check for expected content (removed ?)
95
+
96
+ @pytest.mark.integration
97
+ def test_get_video_transcript_no_transcript(video_tool):
98
+ # This video is unlikely to have official transcripts in many languages
99
+ # However, YouTube might auto-generate. The API should handle it gracefully.
100
+ result = video_tool.get_video_transcript(VIDEO_URL_NO_TRANSCRIPT, languages=['xx-YY']) # Non-existent language
101
+ assert not result.get("success")
102
+ assert "error" in result
103
+ assert "No transcript found" in result["error"] or "Transcripts are disabled" in result["error"]
104
+
105
+ @pytest.mark.integration
106
+ def test_find_dialogue_response_success(video_tool):
107
+ transcript_data = video_tool.get_video_transcript(VIDEO_URL_TRANSCRIPT_DIALOGUE)
108
+ assert transcript_data.get("success"), f"Transcript fetch failed for dialogue test: {transcript_data.get('error')}"
109
+
110
+ result = video_tool.find_dialogue_response(transcript_data["transcript_entries"], "Isn't that hot?")
111
+ assert result.get("success"), f"Dialogue search failed: {result.get('error')}"
112
+ assert "response_text" in result
113
+ # The expected response is "Extremely" but can vary slightly with transcript generation
114
+ assert "Extremely".lower() in result["response_text"].lower()
115
+
116
+ @pytest.mark.integration
117
+ def test_find_dialogue_response_not_found(video_tool):
118
+ transcript_data = video_tool.get_video_transcript(VIDEO_URL_TRANSCRIPT_DIALOGUE)
119
+ assert transcript_data.get("success")
120
+
121
+ result = video_tool.find_dialogue_response(transcript_data["transcript_entries"], "This phrase is not in the video")
122
+ assert not result.get("success")
123
+ assert "not found in transcript" in result.get("error", "")
124
+
125
+ @pytest.mark.integration
126
+ @pytest.mark.cv_dependent # Marks tests that rely on (even dummy) CV model setup
127
+ def test_object_counting_interface(video_tool):
128
+ """Tests the object counting call, expecting it to run with dummy models even if counts are zero."""
129
+ if not video_tool.object_detection_model:
130
+ pytest.skip("CV model not loaded, skipping object count test.")
131
+
132
+ download_result = video_tool.download_video(VIDEO_URL_OBJECT_COUNT, resolution="240p") # Use a short video
133
+ assert download_result.get("success"), f"Video download failed for object counting: {download_result.get('error')}"
134
+ video_path = download_result["file_path"]
135
+
136
+ result = video_tool.count_objects_in_video(video_path, target_classes=["bird"], confidence_threshold=0.1, frame_skip=30)
137
+
138
+ # With dummy models, we don't expect actual detections, but the function should complete.
139
+ assert result.get("success"), f"Object counting failed: {result.get('error')}"
140
+ assert "max_simultaneous_bird" in result # Even if it's 0
141
+ # If using real models and a video with birds, you would assert result["max_simultaneous_bird"] > 0
142
+
143
+ @pytest.mark.integration
144
+ @pytest.mark.cv_dependent
145
+ def test_process_video_object_count_flow(video_tool):
146
+ if not video_tool.object_detection_model:
147
+ pytest.skip("CV model not loaded, skipping process_video object count test.")
148
+
149
+ query_params = {
150
+ "target_classes": ["bird"],
151
+ "resolution": "240p",
152
+ "confidence_threshold": 0.1,
153
+ "frame_skip": 30 # Process fewer frames for faster test
154
+ }
155
+ result = video_tool.process_video(VIDEO_URL_OBJECT_COUNT, "object_count", query_params=query_params)
156
+ assert result.get("success"), f"process_video for object_count failed: {result.get('error')}"
157
+ assert "max_simultaneous_bird" in result
158
+
159
+ @pytest.mark.integration
160
+ def test_process_video_dialogue_flow(video_tool):
161
+ query_params = {"query_phrase": "Isn't that hot?"}
162
+ result = video_tool.process_video(VIDEO_URL_TRANSCRIPT_DIALOGUE, "dialogue_response", query_params=query_params)
163
+ assert result.get("success"), f"process_video for dialogue_response failed: {result.get('error')}"
164
+ assert "extremely" in result.get("response_text", "").lower()
165
+
166
+ @pytest.mark.integration
167
+ def test_process_video_transcript_flow(video_tool):
168
+ result = video_tool.process_video(VIDEO_URL_TRANSCRIPT_DIALOGUE, "transcript")
169
+ assert result.get("success"), f"process_video for transcript failed: {result.get('error')}"
170
+ assert "transcript" in result and len(result["transcript"]) > 0
171
+
172
+ def test_cleanup_removes_temp_dir(model_files): # Test cleanup more directly
173
+ test_temp_dir_base = os.path.join(os.path.dirname(__file__), "test_temp_cleanup")
174
+ os.makedirs(test_temp_dir_base, exist_ok=True)
175
+ tool = VideoProcessingTool(
176
+ model_cfg_path=model_files["cfg"],
177
+ model_weights_path=model_files["weights"],
178
+ class_names_path=model_files["names"],
179
+ temp_dir_base=test_temp_dir_base
180
+ )
181
+ # Create a dummy file in its temp dir
182
+ temp_file_in_tool_dir = os.path.join(tool.temp_dir, "dummy.txt")
183
+ open(temp_file_in_tool_dir, 'w').write("test")
184
+ assert os.path.exists(tool.temp_dir)
185
+ assert os.path.exists(temp_file_in_tool_dir)
186
+
187
+ tool_temp_dir_path = tool.temp_dir # Store path before cleanup
188
+ tool.cleanup()
189
+ assert not os.path.exists(tool_temp_dir_path), f"Temp directory {tool_temp_dir_path} was not removed."
190
+ # shutil.rmtree(test_temp_dir_base, ignore_errors=True) # Clean up the base for this specific test
191
+
192
+ # To run these tests:
193
+ # 1. Ensure you have pytest installed (`pip install pytest`).
194
+ # 2. Ensure required libraries for VideoProcessingTool are installed (yt_dlp, youtube_transcript_api, opencv-python).
195
+ # 3. Navigate to the directory containing this test file and `src` directory.
196
+ # 4. Run `pytest` or `python -m pytest` in your terminal.
197
+ # 5. For tests requiring network (integration), ensure you have an internet connection.
198
+ # 6. For CV dependent tests to be meaningful beyond interface checks, replace dummy model files with actual ones.