Commit 87aa741 by Yago Bolivar
1 Parent(s): b70394c

Refactor speech_to_text.py to implement a singleton ASR pipeline, enhance error handling, and introduce SpeechToTextTool for better integration. Update spreadsheet_tool.py to support querying and improve parsing functionality, including CSV support. Enhance video_processing_tool.py with new tasks for metadata extraction and frame extraction, while improving object detection capabilities and initialization checks.
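At a glance: each module moves from free functions to a smolagents Tool subclass invoked through forward(). A minimal usage sketch of the new entry points, assuming only the class names and input schemas visible in the diffs below (file paths and printed fields are illustrative, not part of this commit):

    # Sketch only; paths are placeholders, not files in this repo.
    from src.markdown_table_parser import MarkdownTableParserTool
    from src.speech_to_text import SpeechToTextTool
    from src.spreadsheet_tool import SpreadsheetTool

    table = MarkdownTableParserTool().forward("| a | b |\n|---|---|\n| 1 | 2 |")
    print(table)  # headers as keys -> lists of cell values, or None if no table

    stt = SpeechToTextTool()          # builds (or reuses) the shared Whisper pipeline
    print(stt.forward("sample.mp3"))  # transcript string, or an "Error: ..." message

    sheets = SpreadsheetTool()
    print(sheets.forward(file_path="data.csv")["summary"])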

src/image_processing_tool.py CHANGED
@@ -46,7 +46,7 @@ class ImageProcessor(Tool):
     # For simplicity, let's assume a general 'process' action and specify task type in params
     inputs = {
         'image_filepath': {'type': 'string', 'description': 'Path to the image file.'},
-        'task': {'type': 'string', 'description': 'Specific task to perform (e.g., \'caption\', \'chess_analysis\').'}
+        'task': {'type': 'string', 'description': 'Specific task to perform (e.g., \'caption\', \'chess_analysis\').', 'nullable': True}  # Added nullable: True
     }
     outputs = {'result': {'type': 'object', 'description': 'The result of the image processing task (e.g., text caption, chess move, error message).'}}
     output_type = "object"
src/markdown_table_parser.py CHANGED
@@ -1,6 +1,9 @@
 import re
+from smolagents.tools import Tool
+from typing import Dict, List, Optional
 
-def parse_markdown_table(markdown_text: str) -> dict[str, list[str]] | None:
+# Original parsing function
+def _parse_markdown_table_string(markdown_text: str) -> Optional[Dict[str, List[str]]]:
     """
     Parses the first valid Markdown table found in a string.
     Returns a dictionary (headers as keys, lists of cell content as values)
@@ -48,15 +51,45 @@ def parse_markdown_table(markdown_text: str) -> dict[str, list[str]] | None:
                 # First cell is row label, rest are data
                 table[headers[0]].append(cells[0])
                 for k, h in enumerate(headers[1:], 1):
+                    # Ensure the key exists and is a list
+                    if h not in table or not isinstance(table[h], list):
+                        table[h] = []  # Initialize if not present or not a list
                     table[h].append(cells[k])
             else:
                 for k, h in enumerate(headers):
+                    if h not in table or not isinstance(table[h], list):
+                        table[h] = []
                     table[h].append(cells[k])
             j += 1
         return table
     return None
 
+
+class MarkdownTableParserTool(Tool):
+    """
+    Parses a Markdown table from a given text string.
+    Useful for converting markdown tables into Python data structures for further analysis.
+    """
+    name = "markdown_table_parser"
+    description = "Parses the first valid Markdown table found in a string and returns it as a dictionary."
+    inputs = {'markdown_text': {'type': 'string', 'description': 'The string containing the Markdown table.'}}
+    outputs = {'parsed_table': {'type': 'object', 'description': 'A dictionary representing the table (headers as keys, lists of cell content as values), or null if no table is found.'}}
+    output_type = "object"  # Or dict/None
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.is_initialized = True
+
+    def forward(self, markdown_text: str) -> Optional[Dict[str, List[str]]]:
+        """
+        Wrapper for the _parse_markdown_table_string function.
+        """
+        return _parse_markdown_table_string(markdown_text)
+
+
+# Expose the original function name if other parts of the system expect it (optional)
+parse_markdown_table = _parse_markdown_table_string
+
 if __name__ == '__main__':
+    tool_instance = MarkdownTableParserTool()
     example_table = """
     |*|a|b|c|d|e|
     |---|---|---|---|---|---|
@@ -66,7 +99,7 @@ if __name__ == '__main__':
     |d|b|e|b|e|d|
     |e|d|b|a|d|c|
     """
-    parsed = parse_markdown_table(example_table)
+    parsed = tool_instance.forward(example_table)
     print("Parsed GAIA example:")
     if parsed:
         for header, column_data in parsed.items():
@@ -83,7 +116,7 @@ if __name__ == '__main__':
     | Carol | 45 | London |
     Some text after
     """
-    parsed_2 = parse_markdown_table(example_table_2)
+    parsed_2 = tool_instance.forward(example_table_2)
     print("\nParsed Table 2 (with surrounding text):")
     if parsed_2:
         for header, column_data in parsed_2.items():
@@ -95,36 +128,31 @@ if __name__ == '__main__':
     | Header1 | Header2 |
     |---------|---------|
     """
-    parsed_empty = parse_markdown_table(empty_table_with_header)
+    parsed_empty = tool_instance.forward(empty_table_with_header)
     print("\nParsed Empty Table with Header:")
     if parsed_empty:
         for header, column_data in parsed_empty.items():
             print(f"Header: {header}, Data: {column_data}")
     else:
-        print("Failed to parse empty table with header.")
-
-    malformed_separator = """
-    | Header1 | Header2 |
-    |---foo---|---------|
-    | data1 | data2 |
-    """
-    parsed_mal_sep = parse_markdown_table(malformed_separator)
-    print("\nParsed table with malformed separator:")
-    if parsed_mal_sep:
-        print(parsed_mal_sep)
-    else:
-        print("Failed to parse (correctly).")
+        print("Failed to parse table (empty with header).")  # Corrected message
 
-    table_with_alignment = """
-    | Syntax | Description |
-    | :-------- | :-----------: |
-    | Header | Title |
-    | Paragraph | Text |
+    malformed_table = """
+    | Header1 | Header2
+    |--- ---|
+    | cell1 | cell2 |
     """
-    parsed_align = parse_markdown_table(table_with_alignment)
-    print("\nParsed table with alignment in separator:")
-    if parsed_align:
-        for header, column_data in parsed_align.items():
+    parsed_malformed = tool_instance.forward(malformed_table)
+    print("\nParsed Malformed Table:")
+    if parsed_malformed:
+        for header, column_data in parsed_malformed.items():
             print(f"Header: {header}, Data: {column_data}")
     else:
-        print("Failed to parse table with alignment.")
+        print("Failed to parse malformed table.")
+
+    no_table_text = "This is just some text without a table."
+    parsed_no_table = tool_instance.forward(no_table_text)
+    print("\nParsed Text Without Table:")
+    if parsed_no_table:
+        print("Error: Should not have parsed a table.")
+    else:
+        print("Correctly found no table.")
src/python_tool.py CHANGED
@@ -6,21 +6,31 @@ import re
 import traceback
 from typing import Dict, Any, Optional, Union, List
 from smolagents.tools import Tool
+import os
 
 class CodeExecutionTool(Tool):
     """
     Executes Python code in a controlled environment for safe code interpretation.
     Useful for evaluating code snippets and returning their output or errors.
     """
+    name = "python_code_executor"
+    description = "Executes a given Python code string or Python code from a file. Returns the output or error."
+    inputs = {
+        'code_string': {'type': 'string', 'description': 'The Python code to execute directly.', 'nullable': True},
+        'filepath': {'type': 'string', 'description': 'The path to a Python file to execute.', 'nullable': True}
+    }
+    outputs = {'result': {'type': 'object', 'description': 'A dictionary containing \'success\', \'output\', and/or \'error\'.'}}
+    output_type = "object"
 
-    def __init__(self, timeout: int = 5, max_output_size: int = 10000):
-        self.timeout = timeout  # Maximum execution time in seconds
+    def __init__(self, timeout: int = 10, max_output_size: int = 20000, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.timeout = timeout
         self.max_output_size = max_output_size
-        # Restricted imports - add more as needed
         self.banned_modules = [
-            'os', 'subprocess', 'sys', 'builtins', 'importlib', 'eval',
-            'pickle', 'requests', 'socket', 'shutil'
+            'os', 'subprocess', 'sys', 'builtins', 'importlib',
+            'pickle', 'requests', 'socket', 'shutil', 'ctypes', 'multiprocessing'
         ]
+        self.is_initialized = True
 
     def _analyze_code_safety(self, code: str) -> Dict[str, Any]:
         """Perform static analysis to check for potentially harmful code."""
@@ -33,9 +43,11 @@ class CodeExecutionTool(Tool):
             if isinstance(node, ast.Import):
                 imports.extend(n.name for n in node.names)
             elif isinstance(node, ast.ImportFrom):
-                imports.append(node.module)
+                # Ensure node.module is not None before attempting to check against banned_modules
+                if node.module and any(banned in node.module for banned in self.banned_modules):
+                    imports.append(node.module)
 
-        dangerous_imports = [imp for imp in imports if any(
+        dangerous_imports = [imp for imp in imports if imp and any(
             banned in imp for banned in self.banned_modules)]
 
         if dangerous_imports:
@@ -83,141 +95,187 @@ class CodeExecutionTool(Tool):
 
         return None
 
-    def execute_file(self, filepath: str) -> Dict[str, Any]:
-        """Execute Python code from file and capture the output."""
-        try:
-            with open(filepath, 'r') as file:
-                code = file.read()
-
-            return self.execute_code(code)
-
-        except FileNotFoundError:
-            return {"success": False, "error": f"File not found: {filepath}"}
-        except Exception as e:
-            return {
-                "success": False,
-                "error": f"Error reading file: {str(e)}"
-            }
-
-    def execute_code(self, code: str) -> Dict[str, Any]:
-        """Execute Python code string and capture the output."""
-        # Check code safety first
+    # Main entry point for the agent
+    def forward(self, code_string: Optional[str] = None, filepath: Optional[str] = None) -> Dict[str, Any]:
+        if not code_string and not filepath:
+            return {"success": False, "error": "No code string or filepath provided."}
+        if code_string and filepath:
+            return {"success": False, "error": "Provide either a code string or a filepath, not both."}
+
+        code_to_execute = ""
+        if filepath:
+            if not os.path.exists(filepath):
+                return {"success": False, "error": f"File not found: {filepath}"}
+            if not filepath.endswith(".py"):
+                return {"success": False, "error": f"File is not a Python file: {filepath}"}
+            try:
+                with open(filepath, 'r') as file:
+                    code_to_execute = file.read()
+            except Exception as e:
+                return {"success": False, "error": f"Error reading file {filepath}: {str(e)}"}
+        elif code_string:
+            code_to_execute = code_string
+
+        return self._execute_actual_code(code_to_execute)
+
+    # Renamed from execute_code to _execute_actual_code to be internal
+    def _execute_actual_code(self, code: str) -> Dict[str, Any]:
+        """Execute Python code and capture the output or error."""
         safety_check = self._analyze_code_safety(code)
         if not safety_check["safe"]:
-            return {
-                "success": False,
-                "error": f"Security check failed: {safety_check['reason']}"
-            }
-
-        # Prepare a clean globals dictionary with minimal safe functions
-        safe_globals = {
-            'abs': abs,
-            'all': all,
-            'any': any,
-            'bin': bin,
-            'bool': bool,
-            'chr': chr,
-            'complex': complex,
-            'dict': dict,
-            'divmod': divmod,
-            'enumerate': enumerate,
-            'filter': filter,
-            'float': float,
-            'format': format,
-            'frozenset': frozenset,
-            'hash': hash,
-            'hex': hex,
-            'int': int,
-            'isinstance': isinstance,
-            'issubclass': issubclass,
-            'len': len,
-            'list': list,
-            'map': map,
-            'max': max,
-            'min': min,
-            'oct': oct,
-            'ord': ord,
-            'pow': pow,
-            'print': print,
-            'range': range,
-            'reversed': reversed,
-            'round': round,
-            'set': set,
-            'sorted': sorted,
-            'str': str,
-            'sum': sum,
-            'tuple': tuple,
-            'zip': zip,
-            '__builtins__': {},  # Empty builtins for extra security
-        }
-
-        # Add math module functions, commonly needed
-        try:
-            import math
-            for name in dir(math):
-                if not name.startswith('_'):
-                    safe_globals[name] = getattr(math, name)
-        except ImportError:
-            pass
-
-        # Capture output using StringIO
-        output_buffer = io.StringIO()
-
-        # Set timeout handler
-        old_handler = signal.getsignal(signal.SIGALRM)
+            return {"success": False, "error": f"Safety check failed: {safety_check['reason']}"}
+
+        # Setup timeout
         signal.signal(signal.SIGALRM, self._timeout_handler)
         signal.alarm(self.timeout)
-
+
+        captured_output = io.StringIO()
+        # It's generally safer to execute in a restricted scope
+        # and not provide access to all globals/locals by default.
+        # However, for a tool that might need to define functions/classes and use them,
+        # a shared scope might be necessary. This needs careful consideration.
+        exec_globals = {}
+
         try:
-            # Execute code with stdout/stderr capture
-            with contextlib.redirect_stdout(output_buffer):
-                with contextlib.redirect_stderr(output_buffer):
-                    exec(code, safe_globals)
+            with contextlib.redirect_stdout(captured_output):
+                with contextlib.redirect_stderr(captured_output):  # Capture stderr as well
+                    exec(code, exec_globals)  # Execute in a controlled global scope
 
-            output = output_buffer.getvalue()
+            output = captured_output.getvalue()
             if len(output) > self.max_output_size:
-                truncation_message = f"\n... [output truncated to {self.max_output_size} characters]"
-                output = output[:self.max_output_size - len(truncation_message)] + truncation_message
-            else:
-                output = output.strip()
+                output = output[:self.max_output_size] + "... [output truncated]"
 
-            # Extract the numeric value
-            numeric_result = self._extract_numeric_value(output)
+            # Attempt to extract a final numeric value if applicable
+            # This might be specific to certain tasks, consider making it optional
+            # numeric_result = self._extract_numeric_value(output)
 
             return {
-                "success": True,
-                "raw_output": output,
-                "numeric_value": numeric_result,
-                "has_numeric_result": numeric_result is not None
+                "success": True,
+                "output": output,
+                # "numeric_value": numeric_result
             }
-
         except TimeoutError:
-            return {
-                "success": False,
-                "error": f"Code execution timed out after {self.timeout} seconds"
-            }
+            return {"success": False, "error": "Code execution timed out"}
         except Exception as e:
-            error_info = traceback.format_exc()
-            return {
-                "success": False,
-                "error": str(e),
-                "traceback": error_info,
-                "raw_output": output_buffer.getvalue()
-            }
+            # Get detailed traceback
+            tb_lines = traceback.format_exception(type(e), e, e.__traceback__)
+            error_details = "".join(tb_lines)
+            if len(error_details) > self.max_output_size:
+                error_details = error_details[:self.max_output_size] + "... [error truncated]"
+            return {"success": False, "error": f"Execution failed: {str(e)}\nTraceback:\n{error_details}"}
         finally:
-            # Reset alarm and signal handler
-            signal.alarm(0)
-            signal.signal(signal.SIGALRM, old_handler)
-
+            signal.alarm(0)  # Disable the alarm
+            captured_output.close()
+
+    # Kept execute_file and execute_code as helper methods if direct access is ever needed,
+    # but they now call the main _execute_actual_code method.
+    def execute_file(self, filepath: str) -> Dict[str, Any]:
+        """Helper to execute Python code from file."""
+        if not os.path.exists(filepath):
+            return {"success": False, "error": f"File not found: {filepath}"}
+        if not filepath.endswith(".py"):
+            return {"success": False, "error": f"File is not a Python file: {filepath}"}
+        try:
+            with open(filepath, 'r') as file:
+                code = file.read()
+            return self._execute_actual_code(code)
+        except Exception as e:
+            return {"success": False, "error": f"Error reading file {filepath}: {str(e)}"}
+
+    def execute_code(self, code: str) -> Dict[str, Any]:
+        """Helper to execute Python code from a string."""
+        return self._execute_actual_code(code)
 
-# Example usage
-if __name__ == "__main__":
-    executor = CodeExecutionTool()
-    result = executor.execute_code("""
-# Example code that calculates a value
-total = 0
-for i in range(10):
-    total += i * 2
-print(f"The result is {total}")
-""")
-    print(result)
+
+if __name__ == '__main__':
+    tool = CodeExecutionTool(timeout=5)
+
+    # Test 1: Safe code string
+    safe_code = "print('Hello from safe code!'); result = 10 * 2; print(result)"
+    print("\n--- Test 1: Safe Code String ---")
+    result1 = tool.forward(code_string=safe_code)
+    print(result1)
+    assert result1['success']
+    assert "Hello from safe code!" in result1['output']
+    assert "20" in result1['output']
+
+    # Test 2: Code with an error
+    error_code = "print(1/0)"
+    print("\n--- Test 2: Code with Error ---")
+    result2 = tool.forward(code_string=error_code)
+    print(result2)
+    assert not result2['success']
+    assert "ZeroDivisionError" in result2['error']
+
+    # Test 3: Code with a banned import
+    unsafe_import_code = "import os; print(os.getcwd())"
+    print("\n--- Test 3: Unsafe Import ---")
+    result3 = tool.forward(code_string=unsafe_import_code)
+    print(result3)
+    assert not result3['success']
+    assert "Safety check failed" in result3['error']
+    assert "os" in result3['error']
+
+    # Test 4: Timeout
+    timeout_code = "import time; time.sleep(10); print('Done sleeping')"
+    print("\n--- Test 4: Timeout ---")
+    # tool_timeout_short = CodeExecutionTool(timeout=2)  # For testing timeout specifically
+    # result4 = tool_timeout_short.forward(code_string=timeout_code)
+    result4 = tool.forward(code_string=timeout_code)  # Using the main tool instance with its timeout
+    print(result4)
+    assert not result4['success']
+    assert "timed out" in result4['error']
+
+    # Test 5: Execute from file
+    test_file_content = "print('Hello from file!'); x = 5; y = 7; print(f'Sum: {x+y}')"
+    test_filename = "temp_test_script.py"
+    with open(test_filename, "w") as f:
+        f.write(test_file_content)
+    print("\n--- Test 5: Execute from File ---")
+    result5 = tool.forward(filepath=test_filename)
+    print(result5)
+    assert result5['success']
+    assert "Hello from file!" in result5['output']
+    assert "Sum: 12" in result5['output']
+    os.remove(test_filename)
+
+    # Test 6: File not found
+    print("\n--- Test 6: File Not Found ---")
+    result6 = tool.forward(filepath="non_existent_script.py")
+    print(result6)
+    assert not result6['success']
+    assert "File not found" in result6['error']
+
+    # Test 7: Provide both code_string and filepath
+    print("\n--- Test 7: Both code_string and filepath ---")
+    result7 = tool.forward(code_string="print('hello')", filepath=test_filename)
+    print(result7)
+    assert not result7['success']
+    assert "Provide either a code string or a filepath, not both" in result7['error']
+
+    # Test 8: Provide neither
+    print("\n--- Test 8: Neither code_string nor filepath ---")
+    result8 = tool.forward()
+    print(result8)
+    assert not result8['success']
+    assert "No code string or filepath provided" in result8['error']
+
+    # Test 9: Code that defines a function and calls it
+    func_def_code = "def my_func(a, b): return a + b; print(my_func(3,4))"
+    print("\n--- Test 9: Function Definition and Call ---")
+    result9 = tool.forward(code_string=func_def_code)
+    print(result9)
+    assert result9['success']
+    assert "7" in result9['output']
+
+    # Test 10: Max output size
+    # tool_max_output = CodeExecutionTool(max_output_size=50)
+    # long_output_code = "for i in range(20): print(f'Line {i}')"
+    # print("\n--- Test 10: Max Output Size ---")
+    # result10 = tool_max_output.forward(code_string=long_output_code)
+    # print(result10)
+    # assert result10['success']
+    # assert "... [output truncated]" in result10['output']
+    # assert len(result10['output']) <= 50 + len("... [output truncated]") + 5  # a bit of leeway
 
+    print("\nAll tests seem to have passed (check output for details).")
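A note on the timeout mechanism retained above: signal.SIGALRM is POSIX-only and only fires in the main thread, so the timeout setting is not enforced on Windows or when the tool runs in a worker thread. A portable alternative (a sketch, not part of this commit) is to isolate the snippet in a subprocess with a wall-clock limit:

    # Illustrative alternative to SIGALRM-based timeouts: run the snippet in a
    # subprocess and let the OS enforce the wall-clock limit.
    import subprocess
    import sys

    def run_with_timeout(code: str, timeout_s: int = 5) -> dict:
        try:
            proc = subprocess.run(
                [sys.executable, "-c", code],
                capture_output=True, text=True, timeout=timeout_s,
            )
            return {"success": proc.returncode == 0,
                    "output": proc.stdout, "error": proc.stderr}
        except subprocess.TimeoutExpired:
            return {"success": False, "error": f"Code execution timed out after {timeout_s}s"}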
src/speech_to_text.py CHANGED
@@ -1,35 +1,134 @@
 from transformers import pipeline
-import librosa # Or soundfile
+import librosa # Or soundfile
 import os
+from smolagents.tools import Tool  # Added import
+from typing import Optional  # Added for type hinting
 
-# Initialize the ASR pipeline with a specific model
-# Using a smaller Whisper model for quicker setup, but larger models offer better accuracy
-asr_pipeline = pipeline(
-    "automatic-speech-recognition",
-    model="openai/whisper-tiny.en",
-)
+# Initialize the ASR pipeline once
+_asr_pipeline_instance = None
 
-def transcribe_audio(audio_filepath):
+
+def get_asr_pipeline():
+    global _asr_pipeline_instance
+    if _asr_pipeline_instance is None:
+        try:
+            # Using a smaller Whisper model for quicker setup, but larger models offer better accuracy
+            _asr_pipeline_instance = pipeline(
+                "automatic-speech-recognition",
+                model="openai/whisper-tiny.en",  # Consider making model configurable
+            )
+            print("ASR pipeline initialized.")  # For feedback
+        except Exception as e:
+            print(f"Error initializing ASR pipeline: {e}")
+            # Handle error appropriately, e.g., raise or log
+    return _asr_pipeline_instance
+
+
+# Original transcription function, renamed to be internal
+def _transcribe_audio_file(audio_filepath: str, asr_pipeline_instance) -> str:
     """
-    Converts speech in an audio file (e.g., .mp3) to text using speech recognition.
+    Converts speech in an audio file to text using the provided ASR pipeline.
     Args:
         audio_filepath (str): Path to the audio file.
+        asr_pipeline_instance: The initialized ASR pipeline.
     Returns:
-        str: Transcribed text from the audio.
+        str: Transcribed text from the audio or an error message.
     """
+    if not asr_pipeline_instance:
+        return "Error: ASR pipeline is not available."
+    if not os.path.exists(audio_filepath):
+        return f"Error: Audio file not found at {audio_filepath}"
     try:
-        transcription = asr_pipeline(audio_filepath, return_timestamps=True)
-        return transcription["text"]
+        # Ensure the file can be loaded by librosa (or your chosen audio library)
+        # This step can help catch corrupted or unsupported audio formats early.
+        y, sr = librosa.load(audio_filepath, sr=None)  # Load with original sample rate
+        if sr != 16000:  # Whisper models expect 16kHz
+            y = librosa.resample(y, orig_sr=sr, target_sr=16000)
+
+        # Pass the numpy array to the pipeline
+        transcription_result = asr_pipeline_instance(
+            {"raw": y, "sampling_rate": 16000}, return_timestamps=False
+        )  # Changed to False for simplicity
+        return transcription_result["text"]
     except Exception as e:
-        return f"Error during transcription: {e}"
+        return f"Error during transcription of {audio_filepath}: {e}"
+
+
+class SpeechToTextTool(Tool):
+    """
+    Transcribes audio from a given audio file path to text.
+    """
+
+    name = "speech_to_text_transcriber"
+    description = "Converts speech in an audio file (e.g., .mp3, .wav) to text using speech recognition."
+    inputs = {
+        "audio_filepath": {"type": "string", "description": "Path to the audio file to transcribe."}
+    }
+    outputs = {
+        "transcribed_text": {
+            "type": "string",
+            "description": "The transcribed text from the audio, or an error message.",
+        }
+    }
+    output_type = "string"
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.asr_pipeline = get_asr_pipeline()  # Initialize or get the shared pipeline
+        self.is_initialized = True if self.asr_pipeline else False
+
+    def forward(self, audio_filepath: str) -> str:
+        """
+        Wrapper for the _transcribe_audio_file function.
+        """
+        if not self.is_initialized or not self.asr_pipeline:
+            return "Error: SpeechToTextTool was not initialized properly (ASR pipeline missing)."
+        return _transcribe_audio_file(audio_filepath, self.asr_pipeline)
+
+
+# Expose the original function name if needed by other parts of the system (optional)
+# transcribe_audio = _transcribe_audio_file  # This would need adjustment if it expects the pipeline passed in
 
 # Example usage:
 if __name__ == "__main__":
-    audio_file = "./downloaded_files/1f975693-876d-457b-a649-393859e79bf3.mp3"
+    tool_instance = SpeechToTextTool()
+
+    # Create a dummy MP3 file for testing (requires ffmpeg to be installed for pydub to work)
+    # This part is tricky to make universally runnable without external dependencies for audio creation.
+    # For a simple test, we'll assume a file exists or skip this part if it doesn't.
+
+    # Path to a test audio file (replace with an actual .mp3 or .wav file for testing)
+    # You might need to download a short sample audio file and place it in your project.
+    # e.g., create a `test_data` directory and put `sample.mp3` there.
+    test_audio_file = "./data/downloaded_files/1f975693-876d-457b-a649-393859e79bf3.mp3"  # GAIA example
+    # test_audio_file_2 = "./data/downloaded_files/99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3.mp3"  # GAIA example
 
-    if os.path.exists(audio_file): # Check if the (placeholder or real) file exists
-        print(f"Attempting to transcribe: {audio_file}")
-        transcribed_text = transcribe_audio(audio_file)
-        print(f"Transcription:\n{transcribed_text}")
+    if tool_instance.is_initialized:
+        if os.path.exists(test_audio_file):
+            print(f"Attempting to transcribe: {test_audio_file}")
+            transcribed_text = tool_instance.forward(test_audio_file)
+            print(f"Transcription:\n{transcribed_text}")
+        else:
+            print(
+                f"Test audio file not found: {test_audio_file}. Skipping transcription test."
+            )
+            print("Please place a sample .mp3 or .wav file at that location for testing.")
+
+        # if os.path.exists(test_audio_file_2):
+        #     print(f"\nAttempting to transcribe: {test_audio_file_2}")
+        #     transcribed_text_2 = tool_instance.forward(test_audio_file_2)
+        #     print(f"Transcription 2:\n{transcribed_text_2}")
+        # else:
+        #     print(f"Test audio file 2 not found: {test_audio_file_2}. Skipping.")
     else:
-        print(f"File not found: {audio_file}. Please provide a valid audio file.")
+        print(
+            "SpeechToTextTool could not be initialized (ASR pipeline missing). Transcription test skipped."
+        )
+
+    # Test with a non-existent file
+    non_existent_file = "./non_existent_audio.mp3"
+    print(f"\nAttempting to transcribe non-existent file: {non_existent_file}")
+    error_text = tool_instance.forward(non_existent_file)
+    print(f"Result for non-existent file:\n{error_text}")
+    assert "Error:" in error_text  # Expect an error message
src/spreadsheet_tool.py CHANGED
@@ -1,59 +1,83 @@
 import os
 import pandas as pd
-from typing import Dict, List, Union, Tuple, Any
+from typing import Dict, List, Union, Tuple, Any, Optional
 import numpy as np
 from smolagents.tools import Tool
 
-
-
 class SpreadsheetTool(Tool):
     """
-    Parses spreadsheet files (e.g., .xlsx) and extracts tabular data for analysis.
+    Parses spreadsheet files (e.g., .xlsx) and extracts tabular data for analysis or allows querying.
     Useful for reading, processing, and converting spreadsheet content to Python data structures.
     """
+    name = "spreadsheet_processor"
+    description = "Parses a spreadsheet file (e.g., .xlsx, .xls, .csv) and can perform queries. Returns extracted data or query results."
+    inputs = {
+        'file_path': {'type': 'string', 'description': 'Path to the spreadsheet file.'},
+        'query_instructions': {'type': 'string', 'description': 'Optional. Instructions for querying the data (e.g., "Sum column A"). If None, parses the whole sheet.', 'nullable': True}
+    }
+    outputs = {'result': {'type': 'object', 'description': 'A dictionary containing parsed sheet data, query results, or an error message.'}}
+    output_type = "object"
 
-    def __init__(self):
+    def __init__(self, *args, **kwargs):
         """Initialize the SpreadsheetTool."""
-        pass
+        super().__init__(*args, **kwargs)
+        self.is_initialized = True
 
-    def parse_spreadsheet(self, file_path: str) -> Dict[str, Any]:
-        """
-        Parse an Excel spreadsheet and extract useful information.
-
-        Args:
-            file_path: Path to the .xlsx file
-
-        Returns:
-            Dictionary containing:
-            - sheets: Dictionary of sheet names and their DataFrames
-            - sheet_names: List of sheet names
-            - summary: Basic spreadsheet summary
-            - error: Error message if any
-        """
+    # Main entry point for the agent
+    def forward(self, file_path: str, query_instructions: Optional[str] = None) -> Dict[str, Any]:
         if not os.path.exists(file_path):
             return {"error": f"File not found: {file_path}"}
-
+
+        # Determine file type for appropriate parsing
+        _, file_extension = os.path.splitext(file_path)
+        file_extension = file_extension.lower()
+
+        parsed_data = None
+        if file_extension in ['.xlsx', '.xls']:
+            parsed_data = self._parse_excel(file_path)
+        elif file_extension == '.csv':
+            parsed_data = self._parse_csv(file_path)
+        else:
+            return {"error": f"Unsupported file type: {file_extension}. Supported types: .xlsx, .xls, .csv"}
+
+        if parsed_data.get("error"):
+            return parsed_data  # Return error from parsing step
+
+        if query_instructions:
+            return self._query_data(parsed_data, query_instructions)
+        else:
+            # If no query, return the parsed data and summary
+            return {
+                "parsed_sheets": parsed_data.get("sheets"),
+                "summary": parsed_data.get("summary"),
+                "message": "Spreadsheet parsed successfully."
+            }
+
+    def _parse_excel(self, file_path: str) -> Dict[str, Any]:
+        """Parse an Excel spreadsheet and extract useful information."""
         try:
-            # Read all sheets in the Excel file
             excel_file = pd.ExcelFile(file_path)
             sheet_names = excel_file.sheet_names
             sheets = {}
-
             for sheet_name in sheet_names:
                 sheets[sheet_name] = pd.read_excel(excel_file, sheet_name=sheet_name)
-
-            # Create a summary of the spreadsheet
             summary = self._create_summary(sheets)
-
-            return {
-                "sheets": sheets,
-                "sheet_names": sheet_names,
-                "summary": summary,
-                "error": None
-            }
+            return {"sheets": sheets, "sheet_names": sheet_names, "summary": summary, "error": None}
         except Exception as e:
-            return {"error": f"Error parsing spreadsheet: {str(e)}"}
-
+            return {"error": f"Error parsing Excel spreadsheet: {str(e)}"}
+
+    def _parse_csv(self, file_path: str) -> Dict[str, Any]:
+        """Parse a CSV file."""
+        try:
+            df = pd.read_csv(file_path)
+            # CSVs don't have multiple sheets, so we adapt the structure
+            sheet_name = os.path.splitext(os.path.basename(file_path))[0]
+            sheets = {sheet_name: df}
+            summary = self._create_summary(sheets)
+            return {"sheets": sheets, "sheet_names": [sheet_name], "summary": summary, "error": None}
+        except Exception as e:
+            return {"error": f"Error parsing CSV file: {str(e)}"}
+
     def _create_summary(self, sheets_dict: Dict[str, pd.DataFrame]) -> Dict[str, Any]:
         """Create a summary of the spreadsheet contents."""
         summary = {}
@@ -70,179 +94,113 @@ class SpreadsheetTool(Tool):
 
         return summary
 
-    def query_data(self, data: Dict[str, Any], query_instructions: str) -> Dict[str, Any]:
+    # Renamed from query_data to _query_data and adjusted arguments
+    def _query_data(self, parsed_data_dict: Dict[str, Any], query_instructions: str) -> Dict[str, Any]:
         """
         Execute a query on the spreadsheet data based on instructions.
-
-        Args:
-            data: The parsed spreadsheet data (from parse_spreadsheet)
-            query_instructions: Instructions for querying the data (e.g., "Sum column A")
-
-        Returns:
-            Dictionary with query results and potential explanation
+        This is a simplified placeholder. Real implementation would need robust query parsing.
         """
-        if data.get("error"):
-            return {"error": data["error"]}
+        if parsed_data_dict.get("error"):
+            return {"error": parsed_data_dict["error"]}
 
-        try:
-            # This is where you'd implement more sophisticated query logic
-            # For now, we'll implement some basic operations
-
-            sheets = data["sheets"]
-            result = {}
-
-            # Handle common operations based on query_instructions
-            if "sum" in query_instructions.lower():
-                # Extract column or range to sum
-                # This is a simple implementation - a more robust one would use regex or NLP
-                for sheet_name, df in sheets.items():
-                    numeric_cols = df.select_dtypes(include=[np.number]).columns
-                    if not numeric_cols.empty:
-                        result[f"{sheet_name}_sums"] = {
-                            col: df[col].sum() for col in numeric_cols
-                        }
-
-            elif "average" in query_instructions.lower() or "mean" in query_instructions.lower():
-                for sheet_name, df in sheets.items():
-                    numeric_cols = df.select_dtypes(include=[np.number]).columns
-                    if not numeric_cols.empty:
-                        result[f"{sheet_name}_averages"] = {
-                            col: df[col].mean() for col in numeric_cols
-                        }
-
-            elif "count" in query_instructions.lower():
-                for sheet_name, df in sheets.items():
-                    result[f"{sheet_name}_counts"] = {
-                        "rows": len(df),
-                        "non_null_counts": df.count().to_dict()
-                    }
-
-            # Add the raw data structure for more custom processing by the agent
-            result["data_structure"] = {
-                sheet_name: {
-                    "columns": df.columns.tolist(),
-                    "dtypes": df.dtypes.astype(str).to_dict()
-                } for sheet_name, df in sheets.items()
-            }
-
-            return result
-
-        except Exception as e:
-            return {"error": f"Error querying data: {str(e)}"}
-
-    def extract_specific_data(self, data: Dict[str, Any], sheet_name: str = None,
-                              column_names: List[str] = None,
-                              row_indices: List[int] = None) -> Dict[str, Any]:
-        """
-        Extract specific data from the spreadsheet.
-
-        Args:
-            data: The parsed spreadsheet data
-            sheet_name: Name of the sheet to extract from (default: first sheet)
-            column_names: List of column names to extract (default: all columns)
-            row_indices: List of row indices to extract (default: all rows)
-
-        Returns:
-            Dictionary with extracted data
-        """
-        if data.get("error"):
-            return {"error": data["error"]}
-
-        try:
-            sheets = data["sheets"]
-
-            # Default to the first sheet if not specified
-            if sheet_name is None:
-                sheet_name = data["sheet_names"][0]
-
-            if sheet_name not in sheets:
-                return {"error": f"Sheet '{sheet_name}' not found"}
-
-            df = sheets[sheet_name]
-
-            # Filter columns if specified
-            if column_names:
-                # Check if all requested columns exist
-                missing_columns = [col for col in column_names if col not in df.columns]
-                if missing_columns:
-                    return {"error": f"Columns not found: {missing_columns}"}
-                df = df[column_names]
-
-            # Filter rows if specified
-            if row_indices:
-                # Check if indices are in range
-                max_index = len(df) - 1
-                invalid_indices = [i for i in row_indices if i < 0 or i > max_index]
-                if invalid_indices:
-                    return {"error": f"Row indices out of range: {invalid_indices}. Valid range: 0-{max_index}"}
-                df = df.iloc[row_indices]
-
-            return {
-                "data": df.to_dict('records'),
-                "shape": df.shape
-            }
-
-        except Exception as e:
-            return {"error": f"Error extracting specific data: {str(e)}"}
-
-
-    # Example usage (if this script is run directly)
-    if __name__ == "__main__":
-        # Create a simple test spreadsheet for demonstration
-        test_dir = "spreadsheet_test"
-        os.makedirs(test_dir, exist_ok=True)
-
-        # Create a test DataFrame
-        test_data = {
-            'Product': ['Apple', 'Orange', 'Banana', 'Mango'],
-            'Price': [1.2, 0.8, 0.5, 1.5],
-            'Quantity': [100, 80, 200, 50],
-            'Revenue': [120, 64, 100, 75]
-        }
-
-        df = pd.DataFrame(test_data)
-        test_file_path = os.path.join(test_dir, "test_spreadsheet.xlsx")
-
-        # Save to Excel
-        with pd.ExcelWriter(test_file_path) as writer:
-            df.to_excel(writer, sheet_name='Sales', index=False)
-            # Create a second sheet with different data
-            pd.DataFrame({
-                'Month': ['Jan', 'Feb', 'Mar', 'Apr'],
-                'Expenses': [50, 60, 55, 70]
-            }).to_excel(writer, sheet_name='Expenses', index=False)
-
-        print(f"Created test spreadsheet at {test_file_path}")
-
-        # Test the tool
-        spreadsheet_tool = SpreadsheetTool()
-
-        # Parse the spreadsheet
-        print("\nParsing spreadsheet...")
-        parsed_data = spreadsheet_tool.parse_spreadsheet(test_file_path)
-
-        if parsed_data.get("error"):
-            print(f"Error: {parsed_data['error']}")
-        else:
-            print(f"Successfully parsed {len(parsed_data['sheet_names'])} sheets:")
-            print(f"Sheet names: {parsed_data['sheet_names']}")
-
-            # Show a sample of the first sheet
-            first_sheet_name = parsed_data['sheet_names'][0]
-            first_sheet = parsed_data['sheets'][first_sheet_name]
-            print(f"\nFirst few rows of '{first_sheet_name}':")
-            print(first_sheet.head())
-
-            # Test query
-            print("\nQuerying data (sum operation)...")
-            query_result = spreadsheet_tool.query_data(parsed_data, "sum")
-            print(f"Query result: {query_result}")
-
-            # Test specific data extraction
-            print("\nExtracting specific data...")
-            extract_result = spreadsheet_tool.extract_specific_data(
-                parsed_data,
-                sheet_name='Sales',
-                column_names=['Product', 'Revenue']
-            )
-            print(f"Extracted data: {extract_result}")
+        sheets = parsed_data_dict.get("sheets")
+        if not sheets:
+            return {"error": "No sheets data available for querying."}
+
+        # Placeholder for actual query logic.
+        # This would involve parsing `query_instructions` (e.g., using regex, NLP, or a DSL)
+        # and applying pandas operations.
+        # For now, let's return a message indicating the query was received and basic info.
+
+        results = {}
+        explanation = f"Query instruction received: '{query_instructions}'. Advanced query execution is not fully implemented. " \
+                      f"Returning summary of available sheets: {list(sheets.keys())}."
+
+        # Example: if query asks for sum, try to sum first numeric column of first sheet
+        if "sum" in query_instructions.lower():
+            first_sheet_name = next(iter(sheets))
+            df = sheets[first_sheet_name]
+            numeric_cols = df.select_dtypes(include=[np.number]).columns
+            if not numeric_cols.empty:
+                col_to_sum = numeric_cols[0]
+                try:
+                    total_sum = df[col_to_sum].sum()
+                    results[f'{first_sheet_name}_{col_to_sum}_sum'] = total_sum
+                    explanation += f" Example sum of column '{col_to_sum}' in sheet '{first_sheet_name}': {total_sum}."
+                except Exception as e:
+                    explanation += f" Could not perform example sum: {e}."
+            else:
+                explanation += " No numeric columns found for example sum."
+
+        return {"query_results": results, "explanation": explanation, "original_query": query_instructions}
+
+# Example usage (for direct testing)
+if __name__ == '__main__':
+    tool = SpreadsheetTool()
+
+    # Create dummy files for testing
+    dummy_excel_file = "dummy_test.xlsx"
+    dummy_csv_file = "dummy_test.csv"
+
+    # Create a dummy Excel file
+    df_excel = pd.DataFrame({
+        'colA': [1, 2, 3, 4, 5],
+        'colB': ['apple', 'banana', 'cherry', 'date', 'elderberry'],
+        'colC': [10.1, 20.2, 30.3, 40.4, 50.5]
+    })
+    with pd.ExcelWriter(dummy_excel_file) as writer:
+        df_excel.to_excel(writer, sheet_name='Sheet1', index=False)
+        df_excel.head(2).to_excel(writer, sheet_name='Sheet2', index=False)
+
+    # Create a dummy CSV file
+    df_csv = pd.DataFrame({
+        'id': [101, 102, 103],
+        'product': ['widget', 'gadget', 'gizmo'],
+        'price': [19.99, 29.50, 15.00]
+    })
+    df_csv.to_csv(dummy_csv_file, index=False)
+
+    print("--- Test 1: Parse Excel file (no query) ---")
+    result1 = tool.forward(file_path=dummy_excel_file)
+    print(result1)
+    assert "error" not in result1 or result1["error"] is None
+    assert "Sheet1" in result1["parsed_sheets"]
+
+    print("\n--- Test 2: Parse CSV file (no query) ---")
+    result2 = tool.forward(file_path=dummy_csv_file)
+    print(result2)
+    assert "error" not in result2 or result2["error"] is None
+    assert dummy_csv_file.split('.')[0] in result2["parsed_sheets"]
+
+    print("\n--- Test 3: Query Excel file (simple sum example) ---")
+    result3 = tool.forward(file_path=dummy_excel_file, query_instructions="sum colA from Sheet1")
+    print(result3)
+    assert "error" not in result3 or result3["error"] is None
+    assert "query_results" in result3
+    if result3.get("query_results"):
+        assert "Sheet1_colA_sum" in result3["query_results"]
+        assert result3["query_results"]["Sheet1_colA_sum"] == 15
+
+    print("\n--- Test 4: File not found ---")
+    result4 = tool.forward(file_path="non_existent_file.xlsx")
+    print(result4)
+    assert result4["error"] is not None
+    assert "File not found" in result4["error"]
+
+    print("\n--- Test 5: Unsupported file type ---")
+    dummy_txt_file = "dummy_test.txt"
+    with open(dummy_txt_file, "w") as f:
+        f.write("hello")
+    result5 = tool.forward(file_path=dummy_txt_file)
+    print(result5)
+    assert result5["error"] is not None
+    assert "Unsupported file type" in result5["error"]
+    os.remove(dummy_txt_file)
+
+    # Clean up dummy files
+    if os.path.exists(dummy_excel_file):
+        os.remove(dummy_excel_file)
+    if os.path.exists(dummy_csv_file):
+        os.remove(dummy_csv_file)
+
+    print("\nSpreadsheetTool tests completed.")
src/video_processing_tool.py CHANGED
@@ -14,8 +14,18 @@ class VideoProcessingTool(Tool):
     Analyzes video content, extracting information such as frames, audio, or metadata.
     Useful for tasks like video summarization, frame extraction, transcript analysis, or content analysis.
     """
-
-    def __init__(self, model_cfg_path=None, model_weights_path=None, class_names_path=None, temp_dir_base=None):
+    name = "video_processor"
+    description = "Analyzes video content from a file path or YouTube URL. Can extract frames, detect objects, get transcripts, and provide video metadata."
+    inputs = {
+        "file_path": {"type": "string", "description": "Path to the video file or YouTube URL.", "nullable": True},
+        "task": {"type": "string", "description": "Specific task to perform (e.g., 'extract_frames', 'get_transcript', 'detect_objects', 'get_metadata').", "nullable": True},
+        "task_parameters": {"type": "object", "description": "Parameters for the specific task (e.g., frame extraction interval, object detection confidence).", "nullable": True}
+    }
+    outputs = {"result": {"type": "object", "description": "The result of the video processing task, e.g., list of frame paths, transcript text, object detection results, or metadata dictionary."}}
+    output_type = "object"
+
+
+    def __init__(self, model_cfg_path=None, model_weights_path=None, class_names_path=None, temp_dir_base=None, *args, **kwargs):
     """
     Initializes the VideoProcessingTool.
 
@@ -25,6 +35,9 @@ class VideoProcessingTool(Tool):
         class_names_path (str, optional): Path to the file containing class names for the model.
         temp_dir_base (str, optional): Base directory for temporary files. Defaults to system temp.
         """
+        […added lines collapsed in this view…]
         if temp_dir_base:
             self.temp_dir = tempfile.mkdtemp(dir=temp_dir_base)
         else:
@@ -37,16 +50,67 @@ class VideoProcessingTool(Tool):
         if os.path.exists(model_cfg_path) and os.path.exists(model_weights_path) and os.path.exists(class_names_path):
             try:
                 self.object_detection_model = cv2.dnn.readNetFromDarknet(model_cfg_path, model_weights_path)
-                # Set preferable backend and target
                 self.object_detection_model.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
                 self.object_detection_model.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)
                 with open(class_names_path, "r") as f:
                     self.class_names = [line.strip() for line in f.readlines()]
             except Exception as e:
                 print(f"Error loading CV model: {e}. Object detection will not be available.")
                 self.object_detection_model = None
         else:
             print("Warning: One or more CV model paths are invalid. Object detection will not be available.")
+        […added lines collapsed in this view…]
 
     def _extract_video_id(self, youtube_url):
         """Extract the YouTube video ID from a URL."""
@@ -105,7 +169,69 @@ class VideoProcessingTool(Tool):
         except Exception as e:
             return {"error": f"Failed to download video: {str(e)}"}
 
-    def get_video_transcript(self, youtube_url, languages=None):
+        […added lines collapsed in this view…]
         """Get the transcript/captions of a YouTube video."""
         if languages is None:
             languages = ['en', 'en-US']  # Default to English
@@ -161,20 +287,18 @@ class VideoProcessingTool(Tool):
             # Catches other exceptions from YouTubeTranscriptApi calls or re-raised from fetch
             return {"error": f"Failed to get transcript: {str(e)}"}
 
-    def count_objects_in_video(self, video_path, target_classes=None, confidence_threshold=0.5, frame_skip=5):
         """
-        Counts specified objects appearing in the video using the loaded DNN model.
-        Determines the maximum number of target objects appearing simultaneously in any single frame.
         Args:
             video_path (str): Path to the video file.
-            target_classes (list, optional): A list of object classes (strings) to count (e.g., ["bird", "cat"]).
-                If None, counts all detected objects.
            confidence_threshold (float): Minimum confidence for an object to be counted.
-            frame_skip (int): Process every Nth frame to speed up analysis.
+        […added lines collapsed in this view…]
        Returns:
-            dict: A dictionary with counts or an error message.
-                e.g., {"success": True, "max_simultaneous_birds": 3, "max_simultaneous_cats": 1}
-                or {"error": "Object detection model not loaded."}
        """
        if not self.object_detection_model or not self.class_names:
            return {"error": "Object detection model not loaded or class names missing."}
@@ -185,168 +309,45 @@ class VideoProcessingTool(Tool):
 
         if not cap.isOpened():
             return {"error": "Could not open video file."}
 
-        max_counts_per_class = {cls: 0 for cls in target_classes} if target_classes else {}
-        # If target_classes is None, we'd need to initialize for all detected classes,
-        # but for simplicity, let's require target_classes for now or adjust later.
-        if not target_classes:
-            # Defaulting to a common class if none specified, e.g. 'person'
-            # Or, one could count all unique classes detected. For GAIA, specific targets are better.
-            return {"error": "target_classes must be specified for counting."}
-
-
         frame_count = 0
+        […added lines collapsed in this view…]
         while cap.isOpened():
             ret, frame = cap.read()
             if not ret:
                 break
+        […added lines collapsed in this view…]
             frame_count += 1
-            if frame_count % frame_skip != 0:
-                continue
 
-            height, width = frame.shape[:2]
-            blob = cv2.dnn.blobFromImage(frame, 1/255.0, (416, 416), swapRB=True, crop=False)
-            self.object_detection_model.setInput(blob)
-
-            layer_names = self.object_detection_model.getLayerNames()
-            # Handle potential differences in getUnconnectedOutLayers() return value
-            unconnected_out_layers_indices = self.object_detection_model.getUnconnectedOutLayers()
-            if isinstance(unconnected_out_layers_indices, np.ndarray) and unconnected_out_layers_indices.ndim > 1:  # For some OpenCV versions
-                output_layer_names = [layer_names[i[0] - 1] for i in unconnected_out_layers_indices]
-            else:  # For typical cases
-                output_layer_names = [layer_names[i - 1] for i in unconnected_out_layers_indices]
-
-            detections = self.object_detection_model.forward(output_layer_names)
-
-            current_frame_counts = {cls: 0 for cls in target_classes}
-
-            for detection_set in detections:  # Detections can come from multiple output layers
-                for detection in detection_set:
-                    scores = detection[5:]
-                    class_id = np.argmax(scores)
-                    confidence = scores[class_id]
-
-                    if confidence > confidence_threshold:
-                        detected_class_name = self.class_names[class_id]
-                        if detected_class_name in target_classes:
-                            current_frame_counts[detected_class_name] += 1
-
-            for cls in target_classes:
-                if current_frame_counts[cls] > max_counts_per_class[cls]:
-                    max_counts_per_class[cls] = current_frame_counts[cls]
-
         cap.release()
-        result = {"success": True}
-        for cls, count in max_counts_per_class.items():
-            result[f"max_simultaneous_{cls.replace(' ', '_')}"] = count  # e.g. "max_simultaneous_bird"
-        return result
-
-    def find_dialogue_response(self, transcript_entries, query_phrase, max_entries_gap=2, max_time_gap_s=5.0):
-        """
-        Finds what is said in response to a given query phrase in transcript entries.
-        Looks for the query phrase and then captures the text from subsequent entries.
-
-        Args:
-            transcript_entries (list): List of transcript dictionaries (from get_video_transcript).
-            query_phrase (str): The phrase to find (e.g., a question).
-            max_entries_gap (int): How many transcript entries to look ahead for a response.
-            max_time_gap_s (float): Maximum time in seconds after the query phrase to consider for a response.
-
-        Returns:
-            dict: {"success": True, "response_text": "...", "found_at_entry": {...}} or {"error": "..."}
-        """
-        if not transcript_entries:
-            return {"error": "Transcript entries are empty."}
-
-        query_phrase_lower = query_phrase.lower().rstrip('?.!,;')  # Strip common trailing punctuation
-
-        for i, entry in enumerate(transcript_entries):
-            # Correctly access attributes: .text, .start, .duration
-            if query_phrase_lower in entry.text.lower():
-                # Found the query phrase, now look for the response
-                response_parts = []
-                start_time_of_query = entry.start + entry.duration  # End time of query entry
-
-                for j in range(i + 1, min(i + 1 + max_entries_gap + 1, len(transcript_entries))):
-                    next_entry = transcript_entries[j]
-                    # Check if the next entry is within the time gap
-                    if next_entry.start - start_time_of_query > max_time_gap_s:
-                        break  # Too much time has passed
-
-                    # Add text if it's not just noise or very short (heuristic)
-                    if next_entry.text.strip() and len(next_entry.text.strip()) > 1:
-                        response_parts.append(next_entry.text)
-
-                    # If we have collected some response, and the next entry is significantly later, stop.
-                    if response_parts and (j + 1 < len(transcript_entries)):
-                        if transcript_entries[j+1].start - (next_entry.start + next_entry.duration) > 1.0:  # If gap > 1s
-                            break
-
-                if response_parts:
-                    return {
-                        "success": True,
-                        "response_text": " ".join(response_parts),
-                        "query_entry": entry,
-                        "response_start_entry_index": i + 1
-                    }
-                # If no response found immediately after, but query was found
-                return {"error": f"Query phrase '{query_phrase}' found, but no subsequent dialogue captured as response within gap."}
-
-        return {"error": f"Query phrase '{query_phrase}' not found in transcript."}
-
-
-    def process_video(self, youtube_url, query_type, query_params=None):
-        """
-        Main method to process a video based on the type of query.
-
-        Args:
-            youtube_url (str): URL of the YouTube video.
-            query_type (str): Type of processing: "transcript", "object_count", "dialogue_response".
-            query_params (dict, optional): Additional parameters for the specific query type.
-                For "object_count": {"target_classes": ["bird"], "confidence_threshold": 0.5, "resolution": "360p"}
-                For "dialogue_response": {"query_phrase": "Isn't that hot?", "languages": ['en']}
-        """
-        if query_params is None:
-            query_params = {}
-
-        if query_type == "transcript":
-            return self.get_video_transcript(youtube_url, languages=query_params.get("languages"))
-
-        elif query_type == "object_count":
-            if not self.object_detection_model:
-                return {"error": "Object detection model not initialized. Cannot count objects."}
-
-            resolution = query_params.get("resolution", "360p")
-            download_result = self.download_video(youtube_url, resolution=resolution)
-            if "error" in download_result:
-                return download_result
-
-            video_path = download_result["file_path"]
-            target_classes = query_params.get("target_classes")
-            if not target_classes or not isinstance(target_classes, list):
-                return {"error": "query_params must include 'target_classes' as a list for object_count."}
-
-            confidence = query_params.get("confidence_threshold", 0.5)
-            frame_skip = query_params.get("frame_skip", 5)
-            return self.count_objects_in_video(video_path, target_classes, confidence, frame_skip)
-
-        elif query_type == "dialogue_response":
-            transcript_result = self.get_video_transcript(youtube_url, languages=query_params.get("languages"))
-            if "error" in transcript_result:
-                return transcript_result
-
-            query_phrase = query_params.get("query_phrase")
-            if not query_phrase:
-                return {"error": "query_params must include 'query_phrase' for dialogue_response."}
-
-            return self.find_dialogue_response(
-                transcript_result["transcript_entries"],
-                query_phrase,
-                max_entries_gap=query_params.get("max_entries_gap", 2),
-                max_time_gap_s=query_params.get("max_time_gap_s", 5.0)
-            )
-
-        return {"error": f"Unsupported query type: {query_type}"}
 
     def cleanup(self):
         """Remove temporary files and directory."""
29
  """
30
  Initializes the VideoProcessingTool.
31
 
 
35
  class_names_path (str, optional): Path to the file containing class names for the model.
36
  temp_dir_base (str, optional): Base directory for temporary files. Defaults to system temp.
37
  """
38
+ super().__init__(*args, **kwargs)
39
+ self.is_initialized = False # Will be set to True after successful setup
40
+
41
  if temp_dir_base:
42
  self.temp_dir = tempfile.mkdtemp(dir=temp_dir_base)
43
  else:
 
50
  if os.path.exists(model_cfg_path) and os.path.exists(model_weights_path) and os.path.exists(class_names_path):
51
  try:
52
  self.object_detection_model = cv2.dnn.readNetFromDarknet(model_cfg_path, model_weights_path)
53
  self.object_detection_model.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
54
  self.object_detection_model.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)
55
  with open(class_names_path, "r") as f:
56
  self.class_names = [line.strip() for line in f.readlines()]
57
+ print("CV Model loaded successfully.")
58
  except Exception as e:
59
  print(f"Error loading CV model: {e}. Object detection will not be available.")
60
  self.object_detection_model = None
61
  else:
62
  print("Warning: One or more CV model paths are invalid. Object detection will not be available.")
63
+ else:
64
+ print("CV model paths not provided. Object detection will not be available.")
65
+
66
+ self.is_initialized = True
67
+
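For context on the three paths consumed above: cv2.dnn.readNetFromDarknet expects a Darknet .cfg/.weights pair, and the class-names file is plain text with one label per line (the stock YOLOv3 release ships COCO's 80 classes). A minimal illustrative parse, mirroring the loop above (the path is a placeholder, not a file shipped with this repo):

    # coco.names begins: person, bicycle, car, ...
    class_names = [line.strip() for line in open("models/coco.names", encoding="utf-8")]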
68
+ def forward(self, file_path: str | None = None, task: str = "get_metadata", task_parameters: dict | None = None):
69
+ """
70
+ Main entry point for video processing tasks.
71
+ """
72
+ if not self.is_initialized:
73
+ return {"error": "Tool not initialized properly."}
74
+
75
+ if task_parameters is None:
76
+ task_parameters = {}
77
+
78
+ is_youtube_url = file_path and ("youtube.com/" in file_path or "youtu.be/" in file_path)
79
+ video_source_path = file_path
80
+
81
+ if is_youtube_url:
82
+ download_resolution = task_parameters.get("resolution", "360p")
83
+ download_result = self.download_video(file_path, resolution=download_resolution)
84
+ if download_result.get("error"):
85
+ return download_result
86
+ video_source_path = download_result.get("file_path")
87
+ if not video_source_path or not os.path.exists(video_source_path):
88
+ return {"error": f"Failed to download or locate video from URL: {file_path}"}
89
+
90
+ elif file_path and not os.path.exists(file_path):
91
+ return {"error": f"Video file not found: {file_path}"}
92
+ elif not file_path: # even 'get_transcript' needs the URL passed via file_path
93
+ return {"error": "A file path or YouTube URL is required for this task."}
94
+
95
+
96
+ if task == "get_metadata":
97
+ return self.get_video_metadata(video_source_path)
98
+ elif task == "extract_frames":
99
+ interval_seconds = task_parameters.get("interval_seconds", 5)
100
+ max_frames = task_parameters.get("max_frames")
101
+ return self.extract_frames_from_video(video_source_path, interval_seconds=interval_seconds, max_frames=max_frames)
102
+ elif task == "get_transcript":
103
+ # Use original file_path which might be the URL
104
+ return self.get_youtube_transcript(file_path)
105
+ elif task == "detect_objects":
106
+ if not self.object_detection_model:
107
+ return {"error": "Object detection model not loaded."}
108
+ confidence_threshold = task_parameters.get("confidence_threshold", 0.5)
109
+ frames_to_process = task_parameters.get("frames_to_process", 5) # Process N frames
110
+ return self.detect_objects_in_video(video_source_path, confidence_threshold=confidence_threshold, num_frames_to_sample=frames_to_process)
111
+ # Add more tasks as needed, e.g., extract_audio
112
+ else:
113
+ return {"error": f"Unsupported task: {task}"}
114
 
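For orientation, a minimal usage sketch of the dispatcher above. The model paths, clip path, and video ID below are placeholders, and forward() is called directly rather than through an agent framework:

    tool = VideoProcessingTool(
        model_cfg_path="models/yolov3.cfg",        # placeholder Darknet files
        model_weights_path="models/yolov3.weights",
        class_names_path="models/coco.names",
    )

    # Local file: probe basic properties.
    meta = tool.forward(file_path="clips/demo.mp4", task="get_metadata")

    # YouTube URL: forward() downloads the video first, then samples frames.
    frames = tool.forward(
        file_path="https://youtu.be/<video_id>",
        task="extract_frames",
        task_parameters={"interval_seconds": 10, "max_frames": 4, "resolution": "360p"},
    )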
115
  def _extract_video_id(self, youtube_url):
116
  """Extract the YouTube video ID from a URL."""
 
169
  except Exception as e:
170
  return {"error": f"Failed to download video: {str(e)}"}
171
 
172
+ def get_video_metadata(self, video_path):
173
+ """Extract metadata from the video file."""
174
+ if not os.path.exists(video_path):
175
+ return {"error": f"Video file not found: {video_path}"}
176
+
177
+ cap = cv2.VideoCapture(video_path)
178
+ if not cap.isOpened():
179
+ return {"error": "Could not open video file."}
180
+
181
+ metadata = {
182
+ "frame_count": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
183
+ "fps": cap.get(cv2.CAP_PROP_FPS),
184
+ "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
185
+ "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
186
+ "duration": cap.get(cv2.CAP_PROP_FRAME_COUNT) / cap.get(cv2.CAP_PROP_FPS)
187
+ }
188
+
189
+ cap.release()
190
+ return {"success": True, "metadata": metadata}
191
+
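A quick sanity check on the duration arithmetic above, since OpenCV reports 0.0 fps for some containers; the guard keeps the call from raising ZeroDivisionError (values below are made up):

    frame_count, fps = 900, 30.0
    duration = frame_count / fps if fps > 0 else None  # -> 30.0 seconds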
192
+ def extract_frames_from_video(self, video_path, interval_seconds=5, max_frames=None):
193
+ """
194
+ Extracts frames from the video at specified intervals.
195
+
196
+ Args:
197
+ video_path (str): Path to the video file.
198
+ interval_seconds (int): Interval in seconds between frames.
199
+ max_frames (int, optional): Maximum number of frames to extract.
200
+
201
+ Returns:
202
+ dict: {"success": True, "extracted_frame_paths": [...] } or {"error": "..."}
203
+ """
204
+ if not os.path.exists(video_path):
205
+ return {"error": f"Video file not found: {video_path}"}
206
+
207
+ cap = cv2.VideoCapture(video_path)
208
+ if not cap.isOpened():
209
+ return {"error": "Could not open video file."}
210
+
211
+ fps = cap.get(cv2.CAP_PROP_FPS)
212
+ frame_interval = int(fps * interval_seconds)
213
+ extracted_frame_paths = []
214
+ frame_count = 0
215
+
216
+ while cap.isOpened():
217
+ ret, frame = cap.read()
218
+ if not ret:
219
+ break
220
+
221
+ if frame_count % frame_interval == 0:
222
+ frame_id = int(frame_count / frame_interval)
223
+ frame_file_path = os.path.join(self.temp_dir, f"frame_{frame_id:04d}.jpg")
224
+ cv2.imwrite(frame_file_path, frame)
225
+ extracted_frame_paths.append(frame_file_path)
226
+ if max_frames and len(extracted_frame_paths) >= max_frames:
227
+ break
228
+
229
+ frame_count += 1
230
+
231
+ cap.release()
232
+ return {"success": True, "extracted_frame_paths": extracted_frame_paths}
233
+
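To make the sampling arithmetic concrete: at 25 fps with interval_seconds=5, frame_interval is 125, so frames 0, 125, 250, ... are written. A standalone sketch of the same selection rule, assuming nothing beyond the standard library:

    def sampled_indices(total_frames: int, fps: float, interval_seconds: int) -> list[int]:
        frame_interval = max(1, int(fps * interval_seconds))  # same guard as in the method above
        return [i for i in range(total_frames) if i % frame_interval == 0]

    print(sampled_indices(total_frames=300, fps=25.0, interval_seconds=5))  # [0, 125, 250]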
234
+ def get_youtube_transcript(self, youtube_url, languages=None):
235
  """Get the transcript/captions of a YouTube video."""
236
  if languages is None:
237
  languages = ['en', 'en-US'] # Default to English
 
287
  # Catches other exceptions from YouTubeTranscriptApi calls or re-raised from fetch
288
  return {"error": f"Failed to get transcript: {str(e)}"}
289
 
290
+ def detect_objects_in_video(self, video_path, confidence_threshold=0.5, num_frames_to_sample=5):
291
  """
292
+ Detects objects in sampled frames and returns cumulative per-class detection counts.
293
+
294
  Args:
295
  video_path (str): Path to the video file.
296
  confidence_threshold (float): Minimum confidence for an object to be counted.
297
+ num_frames_to_sample (int): Number of frames to sample for object detection.
299
+
300
  Returns:
301
+ dict: {"success": True, "object_counts": {...}} or {"error": "..."}
302
  """
303
  if not self.object_detection_model or not self.class_names:
304
  return {"error": "Object detection model not loaded or class names missing."}
 
309
  if not cap.isOpened():
310
  return {"error": "Could not open video file."}
311
 
312
+ object_counts = {cls: 0 for cls in self.class_names}
313
  frame_count = 0
314
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
315
+ sample_interval = max(1, total_frames // num_frames_to_sample)
316
+
317
  while cap.isOpened():
318
  ret, frame = cap.read()
319
  if not ret:
320
  break
321
 
322
+ if frame_count % sample_interval == 0:
324
+ blob = cv2.dnn.blobFromImage(frame, 1/255.0, (416, 416), swapRB=True, crop=False)
325
+ self.object_detection_model.setInput(blob)
326
+
327
+ layer_names = self.object_detection_model.getLayerNames()
328
+ # Handle potential differences in getUnconnectedOutLayers() return value
329
+ unconnected_out_layers_indices = self.object_detection_model.getUnconnectedOutLayers()
330
+ if isinstance(unconnected_out_layers_indices, np.ndarray) and unconnected_out_layers_indices.ndim > 1: # older OpenCV builds return shape (N, 1)
331
+ output_layer_names = [layer_names[i[0] - 1] for i in unconnected_out_layers_indices]
332
+ else: # newer builds return a flat array of indices
333
+ output_layer_names = [layer_names[i - 1] for i in unconnected_out_layers_indices]
334
+
335
+ detections = self.object_detection_model.forward(output_layer_names)
336
+
337
+ for detection_set in detections: # Detections can come from multiple output layers
338
+ for detection in detection_set:
339
+ scores = detection[5:]
340
+ class_id = np.argmax(scores)
341
+ confidence = scores[class_id]
342
+
343
+ if confidence > confidence_threshold:
344
+ detected_class_name = self.class_names[class_id]
345
+ object_counts[detected_class_name] += 1
346
+
347
  frame_count += 1
 
 
348
 
349
  cap.release()
350
+ return {"success": True, "object_counts": object_counts}
351
 
352
  def cleanup(self):
353
  """Remove temporary files and directory."""