Spaces:
Runtime error
Runtime error
Yago Bolivar
commited on
Commit
·
87aa741
1
Parent(s):
b70394c
Refactor speech_to_text.py to implement a singleton ASR pipeline, enhance error handling, and introduce SpeechToTextTool for better integration. Update spreadsheet_tool.py to support querying and improve parsing functionality, including CSV support. Enhance video_processing_tool.py with new tasks for metadata extraction and frame extraction, while improving object detection capabilities and initialization checks.
Browse files- src/image_processing_tool.py +1 -1
- src/markdown_table_parser.py +55 -27
- src/python_tool.py +188 -130
- src/speech_to_text.py +118 -19
- src/spreadsheet_tool.py +159 -201
- src/video_processing_tool.py +167 -166
src/image_processing_tool.py
CHANGED
|
@@ -46,7 +46,7 @@ class ImageProcessor(Tool):
|
|
| 46 |
# For simplicity, let's assume a general 'process' action and specify task type in params
|
| 47 |
inputs = {
|
| 48 |
'image_filepath': {'type': 'string', 'description': 'Path to the image file.'},
|
| 49 |
-
'task': {'type': 'string', 'description': 'Specific task to perform (e.g., \'caption\', \'chess_analysis\').'}
|
| 50 |
}
|
| 51 |
outputs = {'result': {'type': 'object', 'description': 'The result of the image processing task (e.g., text caption, chess move, error message).'}}
|
| 52 |
output_type = "object"
|
|
|
|
| 46 |
# For simplicity, let's assume a general 'process' action and specify task type in params
|
| 47 |
inputs = {
|
| 48 |
'image_filepath': {'type': 'string', 'description': 'Path to the image file.'},
|
| 49 |
+
'task': {'type': 'string', 'description': 'Specific task to perform (e.g., \'caption\', \'chess_analysis\').', 'nullable': True} # Added nullable: True
|
| 50 |
}
|
| 51 |
outputs = {'result': {'type': 'object', 'description': 'The result of the image processing task (e.g., text caption, chess move, error message).'}}
|
| 52 |
output_type = "object"
|
src/markdown_table_parser.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
import re
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
|
|
|
| 4 |
"""
|
| 5 |
Parses the first valid Markdown table found in a string.
|
| 6 |
Returns a dictionary (headers as keys, lists of cell content as values)
|
|
@@ -48,15 +51,45 @@ def parse_markdown_table(markdown_text: str) -> dict[str, list[str]] | None:
|
|
| 48 |
# First cell is row label, rest are data
|
| 49 |
table[headers[0]].append(cells[0])
|
| 50 |
for k, h in enumerate(headers[1:], 1):
|
|
|
|
|
|
|
|
|
|
| 51 |
table[h].append(cells[k])
|
| 52 |
else:
|
| 53 |
for k, h in enumerate(headers):
|
|
|
|
|
|
|
| 54 |
table[h].append(cells[k])
|
| 55 |
j += 1
|
| 56 |
return table
|
| 57 |
return None
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
if __name__ == '__main__':
|
|
|
|
| 60 |
example_table = """
|
| 61 |
|*|a|b|c|d|e|
|
| 62 |
|---|---|---|---|---|---|
|
|
@@ -66,7 +99,7 @@ if __name__ == '__main__':
|
|
| 66 |
|d|b|e|b|e|d|
|
| 67 |
|e|d|b|a|d|c|
|
| 68 |
"""
|
| 69 |
-
parsed =
|
| 70 |
print("Parsed GAIA example:")
|
| 71 |
if parsed:
|
| 72 |
for header, column_data in parsed.items():
|
|
@@ -83,7 +116,7 @@ if __name__ == '__main__':
|
|
| 83 |
| Carol | 45 | London |
|
| 84 |
Some text after
|
| 85 |
"""
|
| 86 |
-
parsed_2 =
|
| 87 |
print("\\nParsed Table 2 (with surrounding text):")
|
| 88 |
if parsed_2:
|
| 89 |
for header, column_data in parsed_2.items():
|
|
@@ -95,36 +128,31 @@ if __name__ == '__main__':
|
|
| 95 |
| Header1 | Header2 |
|
| 96 |
|---------|---------|
|
| 97 |
"""
|
| 98 |
-
parsed_empty =
|
| 99 |
print("\\nParsed Empty Table with Header:")
|
| 100 |
if parsed_empty:
|
| 101 |
for header, column_data in parsed_empty.items():
|
| 102 |
print(f"Header: {header}, Data: {column_data}")
|
| 103 |
else:
|
| 104 |
-
print("Failed to parse empty
|
| 105 |
-
|
| 106 |
-
malformed_separator = """
|
| 107 |
-
| Header1 | Header2 |
|
| 108 |
-
|---foo---|---------|
|
| 109 |
-
| data1 | data2 |
|
| 110 |
-
"""
|
| 111 |
-
parsed_mal_sep = parse_markdown_table(malformed_separator)
|
| 112 |
-
print("\\nParsed table with malformed separator:")
|
| 113 |
-
if parsed_mal_sep:
|
| 114 |
-
print(parsed_mal_sep)
|
| 115 |
-
else:
|
| 116 |
-
print("Failed to parse (correctly).")
|
| 117 |
|
| 118 |
-
|
| 119 |
-
|
|
| 120 |
-
|
| 121 |
-
|
|
| 122 |
-
| Paragraph | Text |
|
| 123 |
"""
|
| 124 |
-
|
| 125 |
-
print("\\nParsed
|
| 126 |
-
if
|
| 127 |
-
for header, column_data in
|
| 128 |
print(f"Header: {header}, Data: {column_data}")
|
| 129 |
else:
|
| 130 |
-
print("Failed to parse table
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import re
|
| 2 |
+
from smolagents.tools import Tool
|
| 3 |
+
from typing import Dict, List, Optional
|
| 4 |
|
| 5 |
+
# Original parsing function
|
| 6 |
+
def _parse_markdown_table_string(markdown_text: str) -> Optional[Dict[str, List[str]]]:
|
| 7 |
"""
|
| 8 |
Parses the first valid Markdown table found in a string.
|
| 9 |
Returns a dictionary (headers as keys, lists of cell content as values)
|
|
|
|
| 51 |
# First cell is row label, rest are data
|
| 52 |
table[headers[0]].append(cells[0])
|
| 53 |
for k, h in enumerate(headers[1:], 1):
|
| 54 |
+
# Ensure the key exists and is a list
|
| 55 |
+
if h not in table or not isinstance(table[h], list):
|
| 56 |
+
table[h] = [] # Initialize if not present or not a list
|
| 57 |
table[h].append(cells[k])
|
| 58 |
else:
|
| 59 |
for k, h in enumerate(headers):
|
| 60 |
+
if h not in table or not isinstance(table[h], list):
|
| 61 |
+
table[h] = []
|
| 62 |
table[h].append(cells[k])
|
| 63 |
j += 1
|
| 64 |
return table
|
| 65 |
return None
|
| 66 |
|
| 67 |
+
class MarkdownTableParserTool(Tool):
|
| 68 |
+
"""
|
| 69 |
+
Parses a Markdown table from a given text string.
|
| 70 |
+
Useful for converting markdown tables into Python data structures for further analysis.
|
| 71 |
+
"""
|
| 72 |
+
name = "markdown_table_parser"
|
| 73 |
+
description = "Parses the first valid Markdown table found in a string and returns it as a dictionary."
|
| 74 |
+
inputs = {'markdown_text': {'type': 'string', 'description': 'The string containing the Markdown table.'}}
|
| 75 |
+
outputs = {'parsed_table': {'type': 'object', 'description': 'A dictionary representing the table (headers as keys, lists of cell content as values), or null if no table is found.'}}
|
| 76 |
+
output_type = "object" # Or dict/None
|
| 77 |
+
|
| 78 |
+
def __init__(self, *args, **kwargs):
|
| 79 |
+
super().__init__(*args, **kwargs)
|
| 80 |
+
self.is_initialized = True
|
| 81 |
+
|
| 82 |
+
def forward(self, markdown_text: str) -> Optional[Dict[str, List[str]]]:
|
| 83 |
+
"""
|
| 84 |
+
Wrapper for the _parse_markdown_table_string function.
|
| 85 |
+
"""
|
| 86 |
+
return _parse_markdown_table_string(markdown_text)
|
| 87 |
+
|
| 88 |
+
# Expose the original function name if other parts of the system expect it (optional)
|
| 89 |
+
parse_markdown_table = _parse_markdown_table_string
|
| 90 |
+
|
| 91 |
if __name__ == '__main__':
|
| 92 |
+
tool_instance = MarkdownTableParserTool()
|
| 93 |
example_table = """
|
| 94 |
|*|a|b|c|d|e|
|
| 95 |
|---|---|---|---|---|---|
|
|
|
|
| 99 |
|d|b|e|b|e|d|
|
| 100 |
|e|d|b|a|d|c|
|
| 101 |
"""
|
| 102 |
+
parsed = tool_instance.forward(example_table)
|
| 103 |
print("Parsed GAIA example:")
|
| 104 |
if parsed:
|
| 105 |
for header, column_data in parsed.items():
|
|
|
|
| 116 |
| Carol | 45 | London |
|
| 117 |
Some text after
|
| 118 |
"""
|
| 119 |
+
parsed_2 = tool_instance.forward(example_table_2)
|
| 120 |
print("\\nParsed Table 2 (with surrounding text):")
|
| 121 |
if parsed_2:
|
| 122 |
for header, column_data in parsed_2.items():
|
|
|
|
| 128 |
| Header1 | Header2 |
|
| 129 |
|---------|---------|
|
| 130 |
"""
|
| 131 |
+
parsed_empty = tool_instance.forward(empty_table_with_header)
|
| 132 |
print("\\nParsed Empty Table with Header:")
|
| 133 |
if parsed_empty:
|
| 134 |
for header, column_data in parsed_empty.items():
|
| 135 |
print(f"Header: {header}, Data: {column_data}")
|
| 136 |
else:
|
| 137 |
+
print("Failed to parse table (empty with header).") # Corrected message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
+
malformed_table = """
|
| 140 |
+
| Header1 | Header2
|
| 141 |
+
|--- ---|
|
| 142 |
+
| cell1 | cell2 |
|
|
|
|
| 143 |
"""
|
| 144 |
+
parsed_malformed = tool_instance.forward(malformed_table)
|
| 145 |
+
print("\\nParsed Malformed Table:")
|
| 146 |
+
if parsed_malformed:
|
| 147 |
+
for header, column_data in parsed_malformed.items():
|
| 148 |
print(f"Header: {header}, Data: {column_data}")
|
| 149 |
else:
|
| 150 |
+
print("Failed to parse malformed table.")
|
| 151 |
+
|
| 152 |
+
no_table_text = "This is just some text without a table."
|
| 153 |
+
parsed_no_table = tool_instance.forward(no_table_text)
|
| 154 |
+
print("\\nParsed Text Without Table:")
|
| 155 |
+
if parsed_no_table:
|
| 156 |
+
print("Error: Should not have parsed a table.")
|
| 157 |
+
else:
|
| 158 |
+
print("Correctly found no table.")
|
src/python_tool.py
CHANGED
|
@@ -6,21 +6,31 @@ import re
|
|
| 6 |
import traceback
|
| 7 |
from typing import Dict, Any, Optional, Union, List
|
| 8 |
from smolagents.tools import Tool
|
|
|
|
| 9 |
|
| 10 |
class CodeExecutionTool(Tool):
|
| 11 |
"""
|
| 12 |
Executes Python code in a controlled environment for safe code interpretation.
|
| 13 |
Useful for evaluating code snippets and returning their output or errors.
|
| 14 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
def __init__(self, timeout: int =
|
| 17 |
-
|
|
|
|
| 18 |
self.max_output_size = max_output_size
|
| 19 |
-
# Restricted imports - add more as needed
|
| 20 |
self.banned_modules = [
|
| 21 |
-
'os', 'subprocess', 'sys', 'builtins', 'importlib',
|
| 22 |
-
'pickle', 'requests', 'socket', 'shutil'
|
| 23 |
]
|
|
|
|
| 24 |
|
| 25 |
def _analyze_code_safety(self, code: str) -> Dict[str, Any]:
|
| 26 |
"""Perform static analysis to check for potentially harmful code."""
|
|
@@ -33,9 +43,11 @@ class CodeExecutionTool(Tool):
|
|
| 33 |
if isinstance(node, ast.Import):
|
| 34 |
imports.extend(n.name for n in node.names)
|
| 35 |
elif isinstance(node, ast.ImportFrom):
|
| 36 |
-
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
dangerous_imports = [imp for imp in imports if any(
|
| 39 |
banned in imp for banned in self.banned_modules)]
|
| 40 |
|
| 41 |
if dangerous_imports:
|
|
@@ -83,141 +95,187 @@ class CodeExecutionTool(Tool):
|
|
| 83 |
|
| 84 |
return None
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
"success": False,
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
safety_check = self._analyze_code_safety(code)
|
| 106 |
if not safety_check["safe"]:
|
| 107 |
-
return {
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
}
|
| 111 |
-
|
| 112 |
-
# Prepare a clean globals dictionary with minimal safe functions
|
| 113 |
-
safe_globals = {
|
| 114 |
-
'abs': abs,
|
| 115 |
-
'all': all,
|
| 116 |
-
'any': any,
|
| 117 |
-
'bin': bin,
|
| 118 |
-
'bool': bool,
|
| 119 |
-
'chr': chr,
|
| 120 |
-
'complex': complex,
|
| 121 |
-
'dict': dict,
|
| 122 |
-
'divmod': divmod,
|
| 123 |
-
'enumerate': enumerate,
|
| 124 |
-
'filter': filter,
|
| 125 |
-
'float': float,
|
| 126 |
-
'format': format,
|
| 127 |
-
'frozenset': frozenset,
|
| 128 |
-
'hash': hash,
|
| 129 |
-
'hex': hex,
|
| 130 |
-
'int': int,
|
| 131 |
-
'isinstance': isinstance,
|
| 132 |
-
'issubclass': issubclass,
|
| 133 |
-
'len': len,
|
| 134 |
-
'list': list,
|
| 135 |
-
'map': map,
|
| 136 |
-
'max': max,
|
| 137 |
-
'min': min,
|
| 138 |
-
'oct': oct,
|
| 139 |
-
'ord': ord,
|
| 140 |
-
'pow': pow,
|
| 141 |
-
'print': print,
|
| 142 |
-
'range': range,
|
| 143 |
-
'reversed': reversed,
|
| 144 |
-
'round': round,
|
| 145 |
-
'set': set,
|
| 146 |
-
'sorted': sorted,
|
| 147 |
-
'str': str,
|
| 148 |
-
'sum': sum,
|
| 149 |
-
'tuple': tuple,
|
| 150 |
-
'zip': zip,
|
| 151 |
-
'__builtins__': {}, # Empty builtins for extra security
|
| 152 |
-
}
|
| 153 |
-
|
| 154 |
-
# Add math module functions, commonly needed
|
| 155 |
-
try:
|
| 156 |
-
import math
|
| 157 |
-
for name in dir(math):
|
| 158 |
-
if not name.startswith('_'):
|
| 159 |
-
safe_globals[name] = getattr(math, name)
|
| 160 |
-
except ImportError:
|
| 161 |
-
pass
|
| 162 |
-
|
| 163 |
-
# Capture output using StringIO
|
| 164 |
-
output_buffer = io.StringIO()
|
| 165 |
-
|
| 166 |
-
# Set timeout handler
|
| 167 |
-
old_handler = signal.getsignal(signal.SIGALRM)
|
| 168 |
signal.signal(signal.SIGALRM, self._timeout_handler)
|
| 169 |
signal.alarm(self.timeout)
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
try:
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
exec(code, safe_globals)
|
| 176 |
|
| 177 |
-
output =
|
| 178 |
if len(output) > self.max_output_size:
|
| 179 |
-
|
| 180 |
-
output = output[:self.max_output_size - len(truncation_message)] + truncation_message
|
| 181 |
-
else:
|
| 182 |
-
output = output.strip()
|
| 183 |
|
| 184 |
-
#
|
| 185 |
-
|
|
|
|
| 186 |
|
| 187 |
return {
|
| 188 |
-
"success": True,
|
| 189 |
-
"
|
| 190 |
-
"numeric_value": numeric_result
|
| 191 |
-
"has_numeric_result": numeric_result is not None
|
| 192 |
}
|
| 193 |
-
|
| 194 |
except TimeoutError:
|
| 195 |
-
return {
|
| 196 |
-
"success": False,
|
| 197 |
-
"error": f"Code execution timed out after {self.timeout} seconds"
|
| 198 |
-
}
|
| 199 |
except Exception as e:
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
"
|
| 205 |
-
|
| 206 |
-
}
|
| 207 |
finally:
|
| 208 |
-
#
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
-
|
| 214 |
-
if __name__ == "__main__":
|
| 215 |
-
executor = CodeExecutionTool()
|
| 216 |
-
result = executor.execute_code("""
|
| 217 |
-
# Example code that calculates a value
|
| 218 |
-
total = 0
|
| 219 |
-
for i in range(10):
|
| 220 |
-
total += i * 2
|
| 221 |
-
print(f"The result is {total}")
|
| 222 |
-
""")
|
| 223 |
-
print(result)
|
|
|
|
| 6 |
import traceback
|
| 7 |
from typing import Dict, Any, Optional, Union, List
|
| 8 |
from smolagents.tools import Tool
|
| 9 |
+
import os
|
| 10 |
|
| 11 |
class CodeExecutionTool(Tool):
|
| 12 |
"""
|
| 13 |
Executes Python code in a controlled environment for safe code interpretation.
|
| 14 |
Useful for evaluating code snippets and returning their output or errors.
|
| 15 |
"""
|
| 16 |
+
name = "python_code_executor"
|
| 17 |
+
description = "Executes a given Python code string or Python code from a file. Returns the output or error."
|
| 18 |
+
inputs = {
|
| 19 |
+
'code_string': {'type': 'string', 'description': 'The Python code to execute directly.', 'nullable': True},
|
| 20 |
+
'filepath': {'type': 'string', 'description': 'The path to a Python file to execute.', 'nullable': True}
|
| 21 |
+
}
|
| 22 |
+
outputs = {'result': {'type': 'object', 'description': 'A dictionary containing \'success\', \'output\', and/or \'error\'.'}}
|
| 23 |
+
output_type = "object"
|
| 24 |
|
| 25 |
+
def __init__(self, timeout: int = 10, max_output_size: int = 20000, *args, **kwargs):
|
| 26 |
+
super().__init__(*args, **kwargs)
|
| 27 |
+
self.timeout = timeout
|
| 28 |
self.max_output_size = max_output_size
|
|
|
|
| 29 |
self.banned_modules = [
|
| 30 |
+
'os', 'subprocess', 'sys', 'builtins', 'importlib',
|
| 31 |
+
'pickle', 'requests', 'socket', 'shutil', 'ctypes', 'multiprocessing'
|
| 32 |
]
|
| 33 |
+
self.is_initialized = True
|
| 34 |
|
| 35 |
def _analyze_code_safety(self, code: str) -> Dict[str, Any]:
|
| 36 |
"""Perform static analysis to check for potentially harmful code."""
|
|
|
|
| 43 |
if isinstance(node, ast.Import):
|
| 44 |
imports.extend(n.name for n in node.names)
|
| 45 |
elif isinstance(node, ast.ImportFrom):
|
| 46 |
+
# Ensure node.module is not None before attempting to check against banned_modules
|
| 47 |
+
if node.module and any(banned in node.module for banned in self.banned_modules):
|
| 48 |
+
imports.append(node.module)
|
| 49 |
|
| 50 |
+
dangerous_imports = [imp for imp in imports if imp and any(
|
| 51 |
banned in imp for banned in self.banned_modules)]
|
| 52 |
|
| 53 |
if dangerous_imports:
|
|
|
|
| 95 |
|
| 96 |
return None
|
| 97 |
|
| 98 |
+
# Main entry point for the agent
|
| 99 |
+
def forward(self, code_string: Optional[str] = None, filepath: Optional[str] = None) -> Dict[str, Any]:
|
| 100 |
+
if not code_string and not filepath:
|
| 101 |
+
return {"success": False, "error": "No code string or filepath provided."}
|
| 102 |
+
if code_string and filepath:
|
| 103 |
+
return {"success": False, "error": "Provide either a code string or a filepath, not both."}
|
| 104 |
+
|
| 105 |
+
code_to_execute = ""
|
| 106 |
+
if filepath:
|
| 107 |
+
if not os.path.exists(filepath):
|
| 108 |
+
return {"success": False, "error": f"File not found: {filepath}"}
|
| 109 |
+
if not filepath.endswith(".py"):
|
| 110 |
+
return {"success": False, "error": f"File is not a Python file: {filepath}"}
|
| 111 |
+
try:
|
| 112 |
+
with open(filepath, 'r') as file:
|
| 113 |
+
code_to_execute = file.read()
|
| 114 |
+
except Exception as e:
|
| 115 |
+
return {"success": False, "error": f"Error reading file {filepath}: {str(e)}"}
|
| 116 |
+
elif code_string:
|
| 117 |
+
code_to_execute = code_string
|
| 118 |
+
|
| 119 |
+
return self._execute_actual_code(code_to_execute)
|
| 120 |
+
|
| 121 |
+
# Renamed from execute_code to _execute_actual_code to be internal
|
| 122 |
+
def _execute_actual_code(self, code: str) -> Dict[str, Any]:
|
| 123 |
+
"""Execute Python code and capture the output or error."""
|
| 124 |
safety_check = self._analyze_code_safety(code)
|
| 125 |
if not safety_check["safe"]:
|
| 126 |
+
return {"success": False, "error": f"Safety check failed: {safety_check['reason']}"}
|
| 127 |
+
|
| 128 |
+
# Setup timeout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
signal.signal(signal.SIGALRM, self._timeout_handler)
|
| 130 |
signal.alarm(self.timeout)
|
| 131 |
+
|
| 132 |
+
captured_output = io.StringIO()
|
| 133 |
+
# It's generally safer to execute in a restricted scope
|
| 134 |
+
# and not provide access to all globals/locals by default.
|
| 135 |
+
# However, for a tool that might need to define functions/classes and use them,
|
| 136 |
+
# a shared scope might be necessary. This needs careful consideration.
|
| 137 |
+
exec_globals = {}
|
| 138 |
+
|
| 139 |
try:
|
| 140 |
+
with contextlib.redirect_stdout(captured_output):
|
| 141 |
+
with contextlib.redirect_stderr(captured_output): # Capture stderr as well
|
| 142 |
+
exec(code, exec_globals) # Execute in a controlled global scope
|
|
|
|
| 143 |
|
| 144 |
+
output = captured_output.getvalue()
|
| 145 |
if len(output) > self.max_output_size:
|
| 146 |
+
output = output[:self.max_output_size] + "... [output truncated]"
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
+
# Attempt to extract a final numeric value if applicable
|
| 149 |
+
# This might be specific to certain tasks, consider making it optional
|
| 150 |
+
# numeric_result = self._extract_numeric_value(output)
|
| 151 |
|
| 152 |
return {
|
| 153 |
+
"success": True,
|
| 154 |
+
"output": output,
|
| 155 |
+
# "numeric_value": numeric_result
|
|
|
|
| 156 |
}
|
|
|
|
| 157 |
except TimeoutError:
|
| 158 |
+
return {"success": False, "error": "Code execution timed out"}
|
|
|
|
|
|
|
|
|
|
| 159 |
except Exception as e:
|
| 160 |
+
# Get detailed traceback
|
| 161 |
+
tb_lines = traceback.format_exception(type(e), e, e.__traceback__)
|
| 162 |
+
error_details = "".join(tb_lines)
|
| 163 |
+
if len(error_details) > self.max_output_size:
|
| 164 |
+
error_details = error_details[:self.max_output_size] + "... [error truncated]"
|
| 165 |
+
return {"success": False, "error": f"Execution failed: {str(e)}\nTraceback:\n{error_details}"}
|
|
|
|
| 166 |
finally:
|
| 167 |
+
signal.alarm(0) # Disable the alarm
|
| 168 |
+
captured_output.close()
|
| 169 |
+
|
| 170 |
+
# Kept execute_file and execute_code as helper methods if direct access is ever needed,
|
| 171 |
+
# but they now call the main _execute_actual_code method.
|
| 172 |
+
def execute_file(self, filepath: str) -> Dict[str, Any]:
|
| 173 |
+
"""Helper to execute Python code from file."""
|
| 174 |
+
if not os.path.exists(filepath):
|
| 175 |
+
return {"success": False, "error": f"File not found: {filepath}"}
|
| 176 |
+
if not filepath.endswith(".py"):
|
| 177 |
+
return {"success": False, "error": f"File is not a Python file: {filepath}"}
|
| 178 |
+
try:
|
| 179 |
+
with open(filepath, 'r') as file:
|
| 180 |
+
code = file.read()
|
| 181 |
+
return self._execute_actual_code(code)
|
| 182 |
+
except Exception as e:
|
| 183 |
+
return {"success": False, "error": f"Error reading file {filepath}: {str(e)}"}
|
| 184 |
+
|
| 185 |
+
def execute_code(self, code: str) -> Dict[str, Any]:
|
| 186 |
+
"""Helper to execute Python code from a string."""
|
| 187 |
+
return self._execute_actual_code(code)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
if __name__ == '__main__':
|
| 191 |
+
tool = CodeExecutionTool(timeout=5)
|
| 192 |
+
|
| 193 |
+
# Test 1: Safe code string
|
| 194 |
+
safe_code = "print('Hello from safe code!'); result = 10 * 2; print(result)"
|
| 195 |
+
print("\n--- Test 1: Safe Code String ---")
|
| 196 |
+
result1 = tool.forward(code_string=safe_code)
|
| 197 |
+
print(result1)
|
| 198 |
+
assert result1['success']
|
| 199 |
+
assert "Hello from safe code!" in result1['output']
|
| 200 |
+
assert "20" in result1['output']
|
| 201 |
+
|
| 202 |
+
# Test 2: Code with an error
|
| 203 |
+
error_code = "print(1/0)"
|
| 204 |
+
print("\n--- Test 2: Code with Error ---")
|
| 205 |
+
result2 = tool.forward(code_string=error_code)
|
| 206 |
+
print(result2)
|
| 207 |
+
assert not result2['success']
|
| 208 |
+
assert "ZeroDivisionError" in result2['error']
|
| 209 |
+
|
| 210 |
+
# Test 3: Code with a banned import
|
| 211 |
+
unsafe_import_code = "import os; print(os.getcwd())"
|
| 212 |
+
print("\n--- Test 3: Unsafe Import ---")
|
| 213 |
+
result3 = tool.forward(code_string=unsafe_import_code)
|
| 214 |
+
print(result3)
|
| 215 |
+
assert not result3['success']
|
| 216 |
+
assert "Safety check failed" in result3['error']
|
| 217 |
+
assert "os" in result3['error']
|
| 218 |
+
|
| 219 |
+
# Test 4: Timeout
|
| 220 |
+
timeout_code = "import time; time.sleep(10); print('Done sleeping')"
|
| 221 |
+
print("\n--- Test 4: Timeout ---")
|
| 222 |
+
# tool_timeout_short = CodeExecutionTool(timeout=2) # For testing timeout specifically
|
| 223 |
+
# result4 = tool_timeout_short.forward(code_string=timeout_code)
|
| 224 |
+
result4 = tool.forward(code_string=timeout_code) # Using the main tool instance with its timeout
|
| 225 |
+
print(result4)
|
| 226 |
+
assert not result4['success']
|
| 227 |
+
assert "timed out" in result4['error']
|
| 228 |
+
|
| 229 |
+
# Test 5: Execute from file
|
| 230 |
+
test_file_content = "print('Hello from file!'); x = 5; y = 7; print(f'Sum: {x+y}')"
|
| 231 |
+
test_filename = "temp_test_script.py"
|
| 232 |
+
with open(test_filename, "w") as f:
|
| 233 |
+
f.write(test_file_content)
|
| 234 |
+
print("\n--- Test 5: Execute from File ---")
|
| 235 |
+
result5 = tool.forward(filepath=test_filename)
|
| 236 |
+
print(result5)
|
| 237 |
+
assert result5['success']
|
| 238 |
+
assert "Hello from file!" in result5['output']
|
| 239 |
+
assert "Sum: 12" in result5['output']
|
| 240 |
+
os.remove(test_filename)
|
| 241 |
+
|
| 242 |
+
# Test 6: File not found
|
| 243 |
+
print("\n--- Test 6: File Not Found ---")
|
| 244 |
+
result6 = tool.forward(filepath="non_existent_script.py")
|
| 245 |
+
print(result6)
|
| 246 |
+
assert not result6['success']
|
| 247 |
+
assert "File not found" in result6['error']
|
| 248 |
+
|
| 249 |
+
# Test 7: Provide both code_string and filepath
|
| 250 |
+
print("\n--- Test 7: Both code_string and filepath ---")
|
| 251 |
+
result7 = tool.forward(code_string="print('hello')", filepath=test_filename)
|
| 252 |
+
print(result7)
|
| 253 |
+
assert not result7['success']
|
| 254 |
+
assert "Provide either a code string or a filepath, not both" in result7['error']
|
| 255 |
+
|
| 256 |
+
# Test 8: Provide neither
|
| 257 |
+
print("\n--- Test 8: Neither code_string nor filepath ---")
|
| 258 |
+
result8 = tool.forward()
|
| 259 |
+
print(result8)
|
| 260 |
+
assert not result8['success']
|
| 261 |
+
assert "No code string or filepath provided" in result8['error']
|
| 262 |
+
|
| 263 |
+
# Test 9: Code that defines a function and calls it
|
| 264 |
+
func_def_code = "def my_func(a, b): return a + b; print(my_func(3,4))"
|
| 265 |
+
print("\n--- Test 9: Function Definition and Call ---")
|
| 266 |
+
result9 = tool.forward(code_string=func_def_code)
|
| 267 |
+
print(result9)
|
| 268 |
+
assert result9['success']
|
| 269 |
+
assert "7" in result9['output']
|
| 270 |
+
|
| 271 |
+
# Test 10: Max output size
|
| 272 |
+
# tool_max_output = CodeExecutionTool(max_output_size=50)
|
| 273 |
+
# long_output_code = "for i in range(20): print(f'Line {i}')"
|
| 274 |
+
# print("\n--- Test 10: Max Output Size ---")
|
| 275 |
+
# result10 = tool_max_output.forward(code_string=long_output_code)
|
| 276 |
+
# print(result10)
|
| 277 |
+
# assert result10['success']
|
| 278 |
+
# assert "... [output truncated]" in result10['output']
|
| 279 |
+
# assert len(result10['output']) <= 50 + len("... [output truncated]") + 5 # a bit of leeway
|
| 280 |
|
| 281 |
+
print("\nAll tests seem to have passed (check output for details).")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/speech_to_text.py
CHANGED
|
@@ -1,35 +1,134 @@
|
|
| 1 |
from transformers import pipeline
|
| 2 |
-
import librosa
|
| 3 |
import os
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
# Initialize the ASR pipeline
|
| 6 |
-
|
| 7 |
-
asr_pipeline = pipeline(
|
| 8 |
-
"automatic-speech-recognition",
|
| 9 |
-
model="openai/whisper-tiny.en",
|
| 10 |
-
)
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
"""
|
| 14 |
-
Converts speech in an audio file
|
| 15 |
Args:
|
| 16 |
audio_filepath (str): Path to the audio file.
|
|
|
|
| 17 |
Returns:
|
| 18 |
-
str: Transcribed text from the audio.
|
| 19 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
try:
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
except Exception as e:
|
| 24 |
-
return f"Error during transcription: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# Example usage:
|
| 27 |
if __name__ == "__main__":
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
if os.path.exists(audio_file): # Check if the (placeholder or real) file exists
|
| 31 |
-
print(f"Attempting to transcribe: {audio_file}")
|
| 32 |
-
transcribed_text = transcribe_audio(audio_file)
|
| 33 |
-
print(f"Transcription:\n{transcribed_text}")
|
| 34 |
else:
|
| 35 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from transformers import pipeline
|
| 2 |
+
import librosa # Or soundfile
|
| 3 |
import os
|
| 4 |
+
from smolagents.tools import Tool # Added import
|
| 5 |
+
from typing import Optional # Added for type hinting
|
| 6 |
|
| 7 |
+
# Initialize the ASR pipeline once
|
| 8 |
+
_asr_pipeline_instance = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
|
| 11 |
+
def get_asr_pipeline():
|
| 12 |
+
global _asr_pipeline_instance
|
| 13 |
+
if _asr_pipeline_instance is None:
|
| 14 |
+
try:
|
| 15 |
+
# Using a smaller Whisper model for quicker setup, but larger models offer better accuracy
|
| 16 |
+
_asr_pipeline_instance = pipeline(
|
| 17 |
+
"automatic-speech-recognition",
|
| 18 |
+
model="openai/whisper-tiny.en", # Consider making model configurable
|
| 19 |
+
)
|
| 20 |
+
print("ASR pipeline initialized.") # For feedback
|
| 21 |
+
except Exception as e:
|
| 22 |
+
print(f"Error initializing ASR pipeline: {e}")
|
| 23 |
+
# Handle error appropriately, e.g., raise or log
|
| 24 |
+
return _asr_pipeline_instance
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Original transcription function, renamed to be internal
|
| 28 |
+
def _transcribe_audio_file(audio_filepath: str, asr_pipeline_instance) -> str:
|
| 29 |
"""
|
| 30 |
+
Converts speech in an audio file to text using the provided ASR pipeline.
|
| 31 |
Args:
|
| 32 |
audio_filepath (str): Path to the audio file.
|
| 33 |
+
asr_pipeline_instance: The initialized ASR pipeline.
|
| 34 |
Returns:
|
| 35 |
+
str: Transcribed text from the audio or an error message.
|
| 36 |
"""
|
| 37 |
+
if not asr_pipeline_instance:
|
| 38 |
+
return "Error: ASR pipeline is not available."
|
| 39 |
+
if not os.path.exists(audio_filepath):
|
| 40 |
+
return f"Error: Audio file not found at {audio_filepath}"
|
| 41 |
try:
|
| 42 |
+
# Ensure the file can be loaded by librosa (or your chosen audio library)
|
| 43 |
+
# This step can help catch corrupted or unsupported audio formats early.
|
| 44 |
+
y, sr = librosa.load(audio_filepath, sr=None) # Load with original sample rate
|
| 45 |
+
if sr != 16000: # Whisper models expect 16kHz
|
| 46 |
+
y = librosa.resample(y, orig_sr=sr, target_sr=16000)
|
| 47 |
+
|
| 48 |
+
# Pass the numpy array to the pipeline
|
| 49 |
+
transcription_result = asr_pipeline_instance(
|
| 50 |
+
{"raw": y, "sampling_rate": 16000}, return_timestamps=False
|
| 51 |
+
) # Changed to False for simplicity
|
| 52 |
+
return transcription_result["text"]
|
| 53 |
except Exception as e:
|
| 54 |
+
return f"Error during transcription of {audio_filepath}: {e}"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class SpeechToTextTool(Tool):
|
| 58 |
+
"""
|
| 59 |
+
Transcribes audio from a given audio file path to text.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
name = "speech_to_text_transcriber"
|
| 63 |
+
description = "Converts speech in an audio file (e.g., .mp3, .wav) to text using speech recognition."
|
| 64 |
+
inputs = {
|
| 65 |
+
"audio_filepath": {"type": "string", "description": "Path to the audio file to transcribe."}
|
| 66 |
+
}
|
| 67 |
+
outputs = {
|
| 68 |
+
"transcribed_text": {
|
| 69 |
+
"type": "string",
|
| 70 |
+
"description": "The transcribed text from the audio, or an error message.",
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
output_type = "string"
|
| 74 |
+
|
| 75 |
+
def __init__(self, *args, **kwargs):
|
| 76 |
+
super().__init__(*args, **kwargs)
|
| 77 |
+
self.asr_pipeline = get_asr_pipeline() # Initialize or get the shared pipeline
|
| 78 |
+
self.is_initialized = True if self.asr_pipeline else False
|
| 79 |
+
|
| 80 |
+
def forward(self, audio_filepath: str) -> str:
|
| 81 |
+
"""
|
| 82 |
+
Wrapper for the _transcribe_audio_file function.
|
| 83 |
+
"""
|
| 84 |
+
if not self.is_initialized or not self.asr_pipeline:
|
| 85 |
+
return "Error: SpeechToTextTool was not initialized properly (ASR pipeline missing)."
|
| 86 |
+
return _transcribe_audio_file(audio_filepath, self.asr_pipeline)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Expose the original function name if needed by other parts of the system (optional)
|
| 90 |
+
# transcribe_audio = _transcribe_audio_file # This would need adjustment if it expects the pipeline passed in
|
| 91 |
|
| 92 |
# Example usage:
|
| 93 |
if __name__ == "__main__":
|
| 94 |
+
tool_instance = SpeechToTextTool()
|
| 95 |
+
|
| 96 |
+
# Create a dummy MP3 file for testing (requires ffmpeg to be installed for pydub to work)
|
| 97 |
+
# This part is tricky to make universally runnable without external dependencies for audio creation.
|
| 98 |
+
# For a simple test, we'll assume a file exists or skip this part if it doesn't.
|
| 99 |
+
|
| 100 |
+
# Path to a test audio file (replace with an actual .mp3 or .wav file for testing)
|
| 101 |
+
# You might need to download a short sample audio file and place it in your project.
|
| 102 |
+
# e.g., create a `test_data` directory and put `sample.mp3` there.
|
| 103 |
+
test_audio_file = "./data/downloaded_files/1f975693-876d-457b-a649-393859e79bf3.mp3" # GAIA example
|
| 104 |
+
# test_audio_file_2 = "./data/downloaded_files/99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3.mp3" # GAIA example
|
| 105 |
+
|
| 106 |
+
if tool_instance.is_initialized:
|
| 107 |
+
if os.path.exists(test_audio_file):
|
| 108 |
+
print(f"Attempting to transcribe: {test_audio_file}")
|
| 109 |
+
transcribed_text = tool_instance.forward(test_audio_file)
|
| 110 |
+
print(f"Transcription:\n{transcribed_text}")
|
| 111 |
+
else:
|
| 112 |
+
print(
|
| 113 |
+
f"Test audio file not found: {test_audio_file}. Skipping transcription test."
|
| 114 |
+
)
|
| 115 |
+
print("Please place a sample .mp3 or .wav file at that location for testing.")
|
| 116 |
+
|
| 117 |
+
# if os.path.exists(test_audio_file_2):
|
| 118 |
+
# print(f"\nAttempting to transcribe: {test_audio_file_2}")
|
| 119 |
+
# transcribed_text_2 = tool_instance.forward(test_audio_file_2)
|
| 120 |
+
# print(f"Transcription 2:\n{transcribed_text_2}")
|
| 121 |
+
# else:
|
| 122 |
+
# print(f"Test audio file 2 not found: {test_audio_file_2}. Skipping.")
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
else:
|
| 125 |
+
print(
|
| 126 |
+
"SpeechToTextTool could not be initialized (ASR pipeline missing). Transcription test skipped."
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# Test with a non-existent file
|
| 130 |
+
non_existent_file = "./non_existent_audio.mp3"
|
| 131 |
+
print(f"\nAttempting to transcribe non-existent file: {non_existent_file}")
|
| 132 |
+
error_text = tool_instance.forward(non_existent_file)
|
| 133 |
+
print(f"Result for non-existent file:\n{error_text}")
|
| 134 |
+
assert "Error:" in error_text # Expect an error message
|
src/spreadsheet_tool.py
CHANGED
|
@@ -1,59 +1,83 @@
|
|
| 1 |
import os
|
| 2 |
import pandas as pd
|
| 3 |
-
from typing import Dict, List, Union, Tuple, Any
|
| 4 |
import numpy as np
|
| 5 |
from smolagents.tools import Tool
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
class SpreadsheetTool(Tool):
|
| 10 |
"""
|
| 11 |
-
Parses spreadsheet files (e.g., .xlsx) and extracts tabular data for analysis.
|
| 12 |
Useful for reading, processing, and converting spreadsheet content to Python data structures.
|
| 13 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
def __init__(self):
|
| 16 |
"""Initialize the SpreadsheetTool."""
|
| 17 |
-
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
Parse an Excel spreadsheet and extract useful information.
|
| 22 |
-
|
| 23 |
-
Args:
|
| 24 |
-
file_path: Path to the .xlsx file
|
| 25 |
-
|
| 26 |
-
Returns:
|
| 27 |
-
Dictionary containing:
|
| 28 |
-
- sheets: Dictionary of sheet names and their DataFrames
|
| 29 |
-
- sheet_names: List of sheet names
|
| 30 |
-
- summary: Basic spreadsheet summary
|
| 31 |
-
- error: Error message if any
|
| 32 |
-
"""
|
| 33 |
if not os.path.exists(file_path):
|
| 34 |
return {"error": f"File not found: {file_path}"}
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
try:
|
| 37 |
-
# Read all sheets in the Excel file
|
| 38 |
excel_file = pd.ExcelFile(file_path)
|
| 39 |
sheet_names = excel_file.sheet_names
|
| 40 |
sheets = {}
|
| 41 |
-
|
| 42 |
for sheet_name in sheet_names:
|
| 43 |
sheets[sheet_name] = pd.read_excel(excel_file, sheet_name=sheet_name)
|
| 44 |
-
|
| 45 |
-
# Create a summary of the spreadsheet
|
| 46 |
summary = self._create_summary(sheets)
|
| 47 |
-
|
| 48 |
-
return {
|
| 49 |
-
"sheets": sheets,
|
| 50 |
-
"sheet_names": sheet_names,
|
| 51 |
-
"summary": summary,
|
| 52 |
-
"error": None
|
| 53 |
-
}
|
| 54 |
except Exception as e:
|
| 55 |
-
return {"error": f"Error parsing spreadsheet: {str(e)}"}
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
def _create_summary(self, sheets_dict: Dict[str, pd.DataFrame]) -> Dict[str, Any]:
|
| 58 |
"""Create a summary of the spreadsheet contents."""
|
| 59 |
summary = {}
|
|
@@ -70,179 +94,113 @@ class SpreadsheetTool(Tool):
|
|
| 70 |
|
| 71 |
return summary
|
| 72 |
|
| 73 |
-
|
|
|
|
| 74 |
"""
|
| 75 |
Execute a query on the spreadsheet data based on instructions.
|
| 76 |
-
|
| 77 |
-
Args:
|
| 78 |
-
data: The parsed spreadsheet data (from parse_spreadsheet)
|
| 79 |
-
query_instructions: Instructions for querying the data (e.g., "Sum column A")
|
| 80 |
-
|
| 81 |
-
Returns:
|
| 82 |
-
Dictionary with query results and potential explanation
|
| 83 |
"""
|
| 84 |
-
if
|
| 85 |
-
return {"error":
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
if "sum" in query_instructions.lower():
|
| 96 |
-
# Extract column or range to sum
|
| 97 |
-
# This is a simple implementation - a more robust one would use regex or NLP
|
| 98 |
-
for sheet_name, df in sheets.items():
|
| 99 |
-
numeric_cols = df.select_dtypes(include=[np.number]).columns
|
| 100 |
-
if not numeric_cols.empty:
|
| 101 |
-
result[f"{sheet_name}_sums"] = {
|
| 102 |
-
col: df[col].sum() for col in numeric_cols
|
| 103 |
-
}
|
| 104 |
-
|
| 105 |
-
elif "average" in query_instructions.lower() or "mean" in query_instructions.lower():
|
| 106 |
-
for sheet_name, df in sheets.items():
|
| 107 |
-
numeric_cols = df.select_dtypes(include=[np.number]).columns
|
| 108 |
-
if not numeric_cols.empty:
|
| 109 |
-
result[f"{sheet_name}_averages"] = {
|
| 110 |
-
col: df[col].mean() for col in numeric_cols
|
| 111 |
-
}
|
| 112 |
-
|
| 113 |
-
elif "count" in query_instructions.lower():
|
| 114 |
-
for sheet_name, df in sheets.items():
|
| 115 |
-
result[f"{sheet_name}_counts"] = {
|
| 116 |
-
"rows": len(df),
|
| 117 |
-
"non_null_counts": df.count().to_dict()
|
| 118 |
-
}
|
| 119 |
-
|
| 120 |
-
# Add the raw data structure for more custom processing by the agent
|
| 121 |
-
result["data_structure"] = {
|
| 122 |
-
sheet_name: {
|
| 123 |
-
"columns": df.columns.tolist(),
|
| 124 |
-
"dtypes": df.dtypes.astype(str).to_dict()
|
| 125 |
-
} for sheet_name, df in sheets.items()
|
| 126 |
-
}
|
| 127 |
-
|
| 128 |
-
return result
|
| 129 |
-
|
| 130 |
-
except Exception as e:
|
| 131 |
-
return {"error": f"Error querying data: {str(e)}"}
|
| 132 |
-
|
| 133 |
-
def extract_specific_data(self, data: Dict[str, Any], sheet_name: str = None,
|
| 134 |
-
column_names: List[str] = None,
|
| 135 |
-
row_indices: List[int] = None) -> Dict[str, Any]:
|
| 136 |
-
"""
|
| 137 |
-
Extract specific data from the spreadsheet.
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
column_names: List of column names to extract (default: all columns)
|
| 143 |
-
row_indices: List of row indices to extract (default: all rows)
|
| 144 |
-
|
| 145 |
-
Returns:
|
| 146 |
-
Dictionary with extracted data
|
| 147 |
-
"""
|
| 148 |
-
if data.get("error"):
|
| 149 |
-
return {"error": data["error"]}
|
| 150 |
|
| 151 |
-
try
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
missing_columns = [col for col in column_names if col not in df.columns]
|
| 167 |
-
if missing_columns:
|
| 168 |
-
return {"error": f"Columns not found: {missing_columns}"}
|
| 169 |
-
df = df[column_names]
|
| 170 |
-
|
| 171 |
-
# Filter rows if specified
|
| 172 |
-
if row_indices:
|
| 173 |
-
# Check if indices are in range
|
| 174 |
-
max_index = len(df) - 1
|
| 175 |
-
invalid_indices = [i for i in row_indices if i < 0 or i > max_index]
|
| 176 |
-
if invalid_indices:
|
| 177 |
-
return {"error": f"Row indices out of range: {invalid_indices}. Valid range: 0-{max_index}"}
|
| 178 |
-
df = df.iloc[row_indices]
|
| 179 |
-
|
| 180 |
-
return {
|
| 181 |
-
"data": df.to_dict('records'),
|
| 182 |
-
"shape": df.shape
|
| 183 |
-
}
|
| 184 |
-
|
| 185 |
-
except Exception as e:
|
| 186 |
-
return {"error": f"Error extracting specific data: {str(e)}"}
|
| 187 |
|
|
|
|
| 188 |
|
| 189 |
-
# Example usage (
|
| 190 |
-
if __name__ ==
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
'
|
| 200 |
-
'
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
print(
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
print("\
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
# Test specific data extraction
|
| 242 |
-
print("\nExtracting specific data...")
|
| 243 |
-
extract_result = spreadsheet_tool.extract_specific_data(
|
| 244 |
-
parsed_data,
|
| 245 |
-
sheet_name='Sales',
|
| 246 |
-
column_names=['Product', 'Revenue']
|
| 247 |
-
)
|
| 248 |
-
print(f"Extracted data: {extract_result}")
|
|
|
|
| 1 |
import os
|
| 2 |
import pandas as pd
|
| 3 |
+
from typing import Dict, List, Union, Tuple, Any, Optional
|
| 4 |
import numpy as np
|
| 5 |
from smolagents.tools import Tool
|
| 6 |
|
|
|
|
|
|
|
| 7 |
class SpreadsheetTool(Tool):
|
| 8 |
"""
|
| 9 |
+
Parses spreadsheet files (e.g., .xlsx) and extracts tabular data for analysis or allows querying.
|
| 10 |
Useful for reading, processing, and converting spreadsheet content to Python data structures.
|
| 11 |
"""
|
| 12 |
+
name = "spreadsheet_processor"
|
| 13 |
+
description = "Parses a spreadsheet file (e.g., .xlsx, .xls, .csv) and can perform queries. Returns extracted data or query results."
|
| 14 |
+
inputs = {
|
| 15 |
+
'file_path': {'type': 'string', 'description': 'Path to the spreadsheet file.'},
|
| 16 |
+
'query_instructions': {'type': 'string', 'description': 'Optional. Instructions for querying the data (e.g., "Sum column A"). If None, parses the whole sheet.', 'nullable': True}
|
| 17 |
+
}
|
| 18 |
+
outputs = {'result': {'type': 'object', 'description': 'A dictionary containing parsed sheet data, query results, or an error message.'}}
|
| 19 |
+
output_type = "object"
|
| 20 |
|
| 21 |
+
def __init__(self, *args, **kwargs):
|
| 22 |
"""Initialize the SpreadsheetTool."""
|
| 23 |
+
super().__init__(*args, **kwargs)
|
| 24 |
+
self.is_initialized = True
|
| 25 |
|
| 26 |
+
# Main entry point for the agent
|
| 27 |
+
def forward(self, file_path: str, query_instructions: Optional[str] = None) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
if not os.path.exists(file_path):
|
| 29 |
return {"error": f"File not found: {file_path}"}
|
| 30 |
+
|
| 31 |
+
# Determine file type for appropriate parsing
|
| 32 |
+
_, file_extension = os.path.splitext(file_path)
|
| 33 |
+
file_extension = file_extension.lower()
|
| 34 |
+
|
| 35 |
+
parsed_data = None
|
| 36 |
+
if file_extension in ['.xlsx', '.xls']:
|
| 37 |
+
parsed_data = self._parse_excel(file_path)
|
| 38 |
+
elif file_extension == '.csv':
|
| 39 |
+
parsed_data = self._parse_csv(file_path)
|
| 40 |
+
else:
|
| 41 |
+
return {"error": f"Unsupported file type: {file_extension}. Supported types: .xlsx, .xls, .csv"}
|
| 42 |
+
|
| 43 |
+
if parsed_data.get("error"):
|
| 44 |
+
return parsed_data # Return error from parsing step
|
| 45 |
+
|
| 46 |
+
if query_instructions:
|
| 47 |
+
return self._query_data(parsed_data, query_instructions)
|
| 48 |
+
else:
|
| 49 |
+
# If no query, return the parsed data and summary
|
| 50 |
+
return {
|
| 51 |
+
"parsed_sheets": parsed_data.get("sheets"),
|
| 52 |
+
"summary": parsed_data.get("summary"),
|
| 53 |
+
"message": "Spreadsheet parsed successfully."
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
def _parse_excel(self, file_path: str) -> Dict[str, Any]:
|
| 57 |
+
"""Parse an Excel spreadsheet and extract useful information."""
|
| 58 |
try:
|
|
|
|
| 59 |
excel_file = pd.ExcelFile(file_path)
|
| 60 |
sheet_names = excel_file.sheet_names
|
| 61 |
sheets = {}
|
|
|
|
| 62 |
for sheet_name in sheet_names:
|
| 63 |
sheets[sheet_name] = pd.read_excel(excel_file, sheet_name=sheet_name)
|
|
|
|
|
|
|
| 64 |
summary = self._create_summary(sheets)
|
| 65 |
+
return {"sheets": sheets, "sheet_names": sheet_names, "summary": summary, "error": None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
except Exception as e:
|
| 67 |
+
return {"error": f"Error parsing Excel spreadsheet: {str(e)}"}
|
| 68 |
+
|
| 69 |
+
def _parse_csv(self, file_path: str) -> Dict[str, Any]:
|
| 70 |
+
"""Parse a CSV file."""
|
| 71 |
+
try:
|
| 72 |
+
df = pd.read_csv(file_path)
|
| 73 |
+
# CSVs don't have multiple sheets, so we adapt the structure
|
| 74 |
+
sheet_name = os.path.splitext(os.path.basename(file_path))[0]
|
| 75 |
+
sheets = {sheet_name: df}
|
| 76 |
+
summary = self._create_summary(sheets)
|
| 77 |
+
return {"sheets": sheets, "sheet_names": [sheet_name], "summary": summary, "error": None}
|
| 78 |
+
except Exception as e:
|
| 79 |
+
return {"error": f"Error parsing CSV file: {str(e)}"}
|
| 80 |
+
|
| 81 |
def _create_summary(self, sheets_dict: Dict[str, pd.DataFrame]) -> Dict[str, Any]:
|
| 82 |
"""Create a summary of the spreadsheet contents."""
|
| 83 |
summary = {}
|
|
|
|
| 94 |
|
| 95 |
return summary
|
| 96 |
|
| 97 |
+
# Renamed from query_data to _query_data and adjusted arguments
|
| 98 |
+
def _query_data(self, parsed_data_dict: Dict[str, Any], query_instructions: str) -> Dict[str, Any]:
|
| 99 |
"""
|
| 100 |
Execute a query on the spreadsheet data based on instructions.
|
| 101 |
+
This is a simplified placeholder. Real implementation would need robust query parsing.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
"""
|
| 103 |
+
if parsed_data_dict.get("error"):
|
| 104 |
+
return {"error": parsed_data_dict["error"]}
|
| 105 |
|
| 106 |
+
sheets = parsed_data_dict.get("sheets")
|
| 107 |
+
if not sheets:
|
| 108 |
+
return {"error": "No sheets data available for querying."}
|
| 109 |
+
|
| 110 |
+
# Placeholder for actual query logic.
|
| 111 |
+
# This would involve parsing `query_instructions` (e.g., using regex, NLP, or a DSL)
|
| 112 |
+
# and applying pandas operations.
|
| 113 |
+
# For now, let's return a message indicating the query was received and basic info.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
+
results = {}
|
| 116 |
+
explanation = f"Query instruction received: '{query_instructions}'. Advanced query execution is not fully implemented. " \
|
| 117 |
+
f"Returning summary of available sheets: {list(sheets.keys())}."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
+
# Example: if query asks for sum, try to sum first numeric column of first sheet
|
| 120 |
+
if "sum" in query_instructions.lower():
|
| 121 |
+
first_sheet_name = next(iter(sheets))
|
| 122 |
+
df = sheets[first_sheet_name]
|
| 123 |
+
numeric_cols = df.select_dtypes(include=[np.number]).columns
|
| 124 |
+
if not numeric_cols.empty:
|
| 125 |
+
col_to_sum = numeric_cols[0]
|
| 126 |
+
try:
|
| 127 |
+
total_sum = df[col_to_sum].sum()
|
| 128 |
+
results[f'{first_sheet_name}_{col_to_sum}_sum'] = total_sum
|
| 129 |
+
explanation += f" Example sum of column '{col_to_sum}' in sheet '{first_sheet_name}': {total_sum}."
|
| 130 |
+
except Exception as e:
|
| 131 |
+
explanation += f" Could not perform example sum: {e}."
|
| 132 |
+
else:
|
| 133 |
+
explanation += " No numeric columns found for example sum."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
+
return {"query_results": results, "explanation": explanation, "original_query": query_instructions}
|
| 136 |
|
| 137 |
+
# Example usage (for direct testing)
|
| 138 |
+
if __name__ == '__main__':
|
| 139 |
+
tool = SpreadsheetTool()
|
| 140 |
+
|
| 141 |
+
# Create dummy files for testing
|
| 142 |
+
dummy_excel_file = "dummy_test.xlsx"
|
| 143 |
+
dummy_csv_file = "dummy_test.csv"
|
| 144 |
+
|
| 145 |
+
# Create a dummy Excel file
|
| 146 |
+
df_excel = pd.DataFrame({
|
| 147 |
+
'colA': [1, 2, 3, 4, 5],
|
| 148 |
+
'colB': ['apple', 'banana', 'cherry', 'date', 'elderberry'],
|
| 149 |
+
'colC': [10.1, 20.2, 30.3, 40.4, 50.5]
|
| 150 |
+
})
|
| 151 |
+
with pd.ExcelWriter(dummy_excel_file) as writer:
|
| 152 |
+
df_excel.to_excel(writer, sheet_name='Sheet1', index=False)
|
| 153 |
+
df_excel.head(2).to_excel(writer, sheet_name='Sheet2', index=False)
|
| 154 |
+
|
| 155 |
+
# Create a dummy CSV file
|
| 156 |
+
df_csv = pd.DataFrame({
|
| 157 |
+
'id': [101, 102, 103],
|
| 158 |
+
'product': ['widget', 'gadget', 'gizmo'],
|
| 159 |
+
'price': [19.99, 29.50, 15.00]
|
| 160 |
+
})
|
| 161 |
+
df_csv.to_csv(dummy_csv_file, index=False)
|
| 162 |
+
|
| 163 |
+
print("--- Test 1: Parse Excel file (no query) ---")
|
| 164 |
+
result1 = tool.forward(file_path=dummy_excel_file)
|
| 165 |
+
print(result1)
|
| 166 |
+
assert "error" not in result1 or result1["error"] is None
|
| 167 |
+
assert "Sheet1" in result1["parsed_sheets"]
|
| 168 |
+
|
| 169 |
+
print("\n--- Test 2: Parse CSV file (no query) ---")
|
| 170 |
+
result2 = tool.forward(file_path=dummy_csv_file)
|
| 171 |
+
print(result2)
|
| 172 |
+
assert "error" not in result2 or result2["error"] is None
|
| 173 |
+
assert dummy_csv_file.split('.')[0] in result2["parsed_sheets"]
|
| 174 |
+
|
| 175 |
+
print("\n--- Test 3: Query Excel file (simple sum example) ---")
|
| 176 |
+
result3 = tool.forward(file_path=dummy_excel_file, query_instructions="sum colA from Sheet1")
|
| 177 |
+
print(result3)
|
| 178 |
+
assert "error" not in result3 or result3["error"] is None
|
| 179 |
+
assert "query_results" in result3
|
| 180 |
+
if result3.get("query_results"):
|
| 181 |
+
assert "Sheet1_colA_sum" in result3["query_results"]
|
| 182 |
+
assert result3["query_results"]["Sheet1_colA_sum"] == 15
|
| 183 |
+
|
| 184 |
+
print("\n--- Test 4: File not found ---")
|
| 185 |
+
result4 = tool.forward(file_path="non_existent_file.xlsx")
|
| 186 |
+
print(result4)
|
| 187 |
+
assert result4["error"] is not None
|
| 188 |
+
assert "File not found" in result4["error"]
|
| 189 |
|
| 190 |
+
print("\n--- Test 5: Unsupported file type ---")
|
| 191 |
+
dummy_txt_file = "dummy_test.txt"
|
| 192 |
+
with open(dummy_txt_file, "w") as f:
|
| 193 |
+
f.write("hello")
|
| 194 |
+
result5 = tool.forward(file_path=dummy_txt_file)
|
| 195 |
+
print(result5)
|
| 196 |
+
assert result5["error"] is not None
|
| 197 |
+
assert "Unsupported file type" in result5["error"]
|
| 198 |
+
os.remove(dummy_txt_file)
|
| 199 |
+
|
| 200 |
+
# Clean up dummy files
|
| 201 |
+
if os.path.exists(dummy_excel_file):
|
| 202 |
+
os.remove(dummy_excel_file)
|
| 203 |
+
if os.path.exists(dummy_csv_file):
|
| 204 |
+
os.remove(dummy_csv_file)
|
| 205 |
+
|
| 206 |
+
print("\nSpreadsheetTool tests completed.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/video_processing_tool.py
CHANGED
|
@@ -14,8 +14,18 @@ class VideoProcessingTool(Tool):
|
|
| 14 |
Analyzes video content, extracting information such as frames, audio, or metadata.
|
| 15 |
Useful for tasks like video summarization, frame extraction, transcript analysis, or content analysis.
|
| 16 |
"""
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
"""
|
| 20 |
Initializes the VideoProcessingTool.
|
| 21 |
|
|
@@ -25,6 +35,9 @@ class VideoProcessingTool(Tool):
|
|
| 25 |
class_names_path (str, optional): Path to the file containing class names for the model.
|
| 26 |
temp_dir_base (str, optional): Base directory for temporary files. Defaults to system temp.
|
| 27 |
"""
|
|
|
|
|
|
|
|
|
|
| 28 |
if temp_dir_base:
|
| 29 |
self.temp_dir = tempfile.mkdtemp(dir=temp_dir_base)
|
| 30 |
else:
|
|
@@ -37,16 +50,67 @@ class VideoProcessingTool(Tool):
|
|
| 37 |
if os.path.exists(model_cfg_path) and os.path.exists(model_weights_path) and os.path.exists(class_names_path):
|
| 38 |
try:
|
| 39 |
self.object_detection_model = cv2.dnn.readNetFromDarknet(model_cfg_path, model_weights_path)
|
| 40 |
-
# Set preferable backend and target
|
| 41 |
self.object_detection_model.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
|
| 42 |
self.object_detection_model.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)
|
| 43 |
with open(class_names_path, "r") as f:
|
| 44 |
self.class_names = [line.strip() for line in f.readlines()]
|
|
|
|
| 45 |
except Exception as e:
|
| 46 |
print(f"Error loading CV model: {e}. Object detection will not be available.")
|
| 47 |
self.object_detection_model = None
|
| 48 |
else:
|
| 49 |
print("Warning: One or more CV model paths are invalid. Object detection will not be available.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def _extract_video_id(self, youtube_url):
|
| 52 |
"""Extract the YouTube video ID from a URL."""
|
|
@@ -105,7 +169,69 @@ class VideoProcessingTool(Tool):
|
|
| 105 |
except Exception as e:
|
| 106 |
return {"error": f"Failed to download video: {str(e)}"}
|
| 107 |
|
| 108 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
"""Get the transcript/captions of a YouTube video."""
|
| 110 |
if languages is None:
|
| 111 |
languages = ['en', 'en-US'] # Default to English
|
|
@@ -161,20 +287,18 @@ class VideoProcessingTool(Tool):
|
|
| 161 |
# Catches other exceptions from YouTubeTranscriptApi calls or re-raised from fetch
|
| 162 |
return {"error": f"Failed to get transcript: {str(e)}"}
|
| 163 |
|
| 164 |
-
def
|
| 165 |
"""
|
| 166 |
-
|
| 167 |
-
|
| 168 |
Args:
|
| 169 |
video_path (str): Path to the video file.
|
| 170 |
-
target_classes (list, optional): A list of object classes (strings) to count (e.g., ["bird", "cat"]).
|
| 171 |
-
If None, counts all detected objects.
|
| 172 |
confidence_threshold (float): Minimum confidence for an object to be counted.
|
| 173 |
-
|
|
|
|
|
|
|
| 174 |
Returns:
|
| 175 |
-
dict:
|
| 176 |
-
e.g., {"success": True, "max_simultaneous_birds": 3, "max_simultaneous_cats": 1}
|
| 177 |
-
or {"error": "Object detection model not loaded."}
|
| 178 |
"""
|
| 179 |
if not self.object_detection_model or not self.class_names:
|
| 180 |
return {"error": "Object detection model not loaded or class names missing."}
|
|
@@ -185,168 +309,45 @@ class VideoProcessingTool(Tool):
|
|
| 185 |
if not cap.isOpened():
|
| 186 |
return {"error": "Could not open video file."}
|
| 187 |
|
| 188 |
-
|
| 189 |
-
# If target_classes is None, we'd need to initialize for all detected classes,
|
| 190 |
-
# but for simplicity, let's require target_classes for now or adjust later.
|
| 191 |
-
if not target_classes:
|
| 192 |
-
# Defaulting to a common class if none specified, e.g. 'person'
|
| 193 |
-
# Or, one could count all unique classes detected. For GAIA, specific targets are better.
|
| 194 |
-
return {"error": "target_classes must be specified for counting."}
|
| 195 |
-
|
| 196 |
-
|
| 197 |
frame_count = 0
|
|
|
|
|
|
|
|
|
|
| 198 |
while cap.isOpened():
|
| 199 |
ret, frame = cap.read()
|
| 200 |
if not ret:
|
| 201 |
break
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
frame_count += 1
|
| 204 |
-
if frame_count % frame_skip != 0:
|
| 205 |
-
continue
|
| 206 |
|
| 207 |
-
height, width = frame.shape[:2]
|
| 208 |
-
blob = cv2.dnn.blobFromImage(frame, 1/255.0, (416, 416), swapRB=True, crop=False)
|
| 209 |
-
self.object_detection_model.setInput(blob)
|
| 210 |
-
|
| 211 |
-
layer_names = self.object_detection_model.getLayerNames()
|
| 212 |
-
# Handle potential differences in getUnconnectedOutLayers() return value
|
| 213 |
-
unconnected_out_layers_indices = self.object_detection_model.getUnconnectedOutLayers()
|
| 214 |
-
if isinstance(unconnected_out_layers_indices, np.ndarray) and unconnected_out_layers_indices.ndim > 1 : # For some OpenCV versions
|
| 215 |
-
output_layer_names = [layer_names[i[0] - 1] for i in unconnected_out_layers_indices]
|
| 216 |
-
else: # For typical cases
|
| 217 |
-
output_layer_names = [layer_names[i - 1] for i in unconnected_out_layers_indices]
|
| 218 |
-
|
| 219 |
-
detections = self.object_detection_model.forward(output_layer_names)
|
| 220 |
-
|
| 221 |
-
current_frame_counts = {cls: 0 for cls in target_classes}
|
| 222 |
-
|
| 223 |
-
for detection_set in detections: # Detections can come from multiple output layers
|
| 224 |
-
for detection in detection_set:
|
| 225 |
-
scores = detection[5:]
|
| 226 |
-
class_id = np.argmax(scores)
|
| 227 |
-
confidence = scores[class_id]
|
| 228 |
-
|
| 229 |
-
if confidence > confidence_threshold:
|
| 230 |
-
detected_class_name = self.class_names[class_id]
|
| 231 |
-
if detected_class_name in target_classes:
|
| 232 |
-
current_frame_counts[detected_class_name] += 1
|
| 233 |
-
|
| 234 |
-
for cls in target_classes:
|
| 235 |
-
if current_frame_counts[cls] > max_counts_per_class[cls]:
|
| 236 |
-
max_counts_per_class[cls] = current_frame_counts[cls]
|
| 237 |
-
|
| 238 |
cap.release()
|
| 239 |
-
|
| 240 |
-
for cls, count in max_counts_per_class.items():
|
| 241 |
-
result[f"max_simultaneous_{cls.replace(' ', '_')}"] = count # e.g. "max_simultaneous_bird"
|
| 242 |
-
return result
|
| 243 |
-
|
| 244 |
-
def find_dialogue_response(self, transcript_entries, query_phrase, max_entries_gap=2, max_time_gap_s=5.0):
|
| 245 |
-
"""
|
| 246 |
-
Finds what is said in response to a given query phrase in transcript entries.
|
| 247 |
-
Looks for the query phrase and then captures the text from subsequent entries.
|
| 248 |
-
|
| 249 |
-
Args:
|
| 250 |
-
transcript_entries (list): List of transcript dictionaries (from get_video_transcript).
|
| 251 |
-
query_phrase (str): The phrase to find (e.g., a question).
|
| 252 |
-
max_entries_gap (int): How many transcript entries to look ahead for a response.
|
| 253 |
-
max_time_gap_s (float): Maximum time in seconds after the query phrase to consider for a response.
|
| 254 |
-
|
| 255 |
-
Returns:
|
| 256 |
-
dict: {"success": True, "response_text": "...", "found_at_entry": {...}} or {"error": "..."}
|
| 257 |
-
"""
|
| 258 |
-
if not transcript_entries:
|
| 259 |
-
return {"error": "Transcript entries are empty."}
|
| 260 |
-
|
| 261 |
-
query_phrase_lower = query_phrase.lower().rstrip('?.!,;') # Strip common trailing punctuation
|
| 262 |
-
|
| 263 |
-
for i, entry in enumerate(transcript_entries):
|
| 264 |
-
# Correctly access attributes: .text, .start, .duration
|
| 265 |
-
if query_phrase_lower in entry.text.lower():
|
| 266 |
-
# Found the query phrase, now look for the response
|
| 267 |
-
response_parts = []
|
| 268 |
-
start_time_of_query = entry.start + entry.duration # End time of query entry
|
| 269 |
-
|
| 270 |
-
for j in range(i + 1, min(i + 1 + max_entries_gap + 1, len(transcript_entries))):
|
| 271 |
-
next_entry = transcript_entries[j]
|
| 272 |
-
# Check if the next entry is within the time gap
|
| 273 |
-
if next_entry.start - start_time_of_query > max_time_gap_s:
|
| 274 |
-
break # Too much time has passed
|
| 275 |
-
|
| 276 |
-
# Add text if it's not just noise or very short (heuristic)
|
| 277 |
-
if next_entry.text.strip() and len(next_entry.text.strip()) > 1:
|
| 278 |
-
response_parts.append(next_entry.text)
|
| 279 |
-
|
| 280 |
-
# If we have collected some response, and the next entry is significantly later, stop.
|
| 281 |
-
if response_parts and (j + 1 < len(transcript_entries)):
|
| 282 |
-
if transcript_entries[j+1].start - (next_entry.start + next_entry.duration) > 1.0: # If gap > 1s
|
| 283 |
-
break
|
| 284 |
-
|
| 285 |
-
if response_parts:
|
| 286 |
-
return {
|
| 287 |
-
"success": True,
|
| 288 |
-
"response_text": " ".join(response_parts),
|
| 289 |
-
"query_entry": entry,
|
| 290 |
-
"response_start_entry_index": i + 1
|
| 291 |
-
}
|
| 292 |
-
# If no response found immediately after, but query was found
|
| 293 |
-
return {"error": f"Query phrase '{query_phrase}' found, but no subsequent dialogue captured as response within gap."}
|
| 294 |
-
|
| 295 |
-
return {"error": f"Query phrase '{query_phrase}' not found in transcript."}
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
def process_video(self, youtube_url, query_type, query_params=None):
|
| 299 |
-
"""
|
| 300 |
-
Main method to process a video based on the type of query.
|
| 301 |
-
|
| 302 |
-
Args:
|
| 303 |
-
youtube_url (str): URL of the YouTube video.
|
| 304 |
-
query_type (str): Type of processing: "transcript", "object_count", "dialogue_response".
|
| 305 |
-
query_params (dict, optional): Additional parameters for the specific query type.
|
| 306 |
-
For "object_count": {"target_classes": ["bird"], "confidence_threshold": 0.5, "resolution": "360p"}
|
| 307 |
-
For "dialogue_response": {"query_phrase": "Isn't that hot?", "languages": ['en']}
|
| 308 |
-
"""
|
| 309 |
-
if query_params is None:
|
| 310 |
-
query_params = {}
|
| 311 |
-
|
| 312 |
-
if query_type == "transcript":
|
| 313 |
-
return self.get_video_transcript(youtube_url, languages=query_params.get("languages"))
|
| 314 |
-
|
| 315 |
-
elif query_type == "object_count":
|
| 316 |
-
if not self.object_detection_model:
|
| 317 |
-
return {"error": "Object detection model not initialized. Cannot count objects."}
|
| 318 |
-
|
| 319 |
-
resolution = query_params.get("resolution", "360p")
|
| 320 |
-
download_result = self.download_video(youtube_url, resolution=resolution)
|
| 321 |
-
if "error" in download_result:
|
| 322 |
-
return download_result
|
| 323 |
-
|
| 324 |
-
video_path = download_result["file_path"]
|
| 325 |
-
target_classes = query_params.get("target_classes")
|
| 326 |
-
if not target_classes or not isinstance(target_classes, list):
|
| 327 |
-
return {"error": "query_params must include 'target_classes' as a list for object_count."}
|
| 328 |
-
|
| 329 |
-
confidence = query_params.get("confidence_threshold", 0.5)
|
| 330 |
-
frame_skip = query_params.get("frame_skip", 5)
|
| 331 |
-
return self.count_objects_in_video(video_path, target_classes, confidence, frame_skip)
|
| 332 |
-
|
| 333 |
-
elif query_type == "dialogue_response":
|
| 334 |
-
transcript_result = self.get_video_transcript(youtube_url, languages=query_params.get("languages"))
|
| 335 |
-
if "error" in transcript_result:
|
| 336 |
-
return transcript_result
|
| 337 |
-
|
| 338 |
-
query_phrase = query_params.get("query_phrase")
|
| 339 |
-
if not query_phrase:
|
| 340 |
-
return {"error": "query_params must include 'query_phrase' for dialogue_response."}
|
| 341 |
-
|
| 342 |
-
return self.find_dialogue_response(
|
| 343 |
-
transcript_result["transcript_entries"],
|
| 344 |
-
query_phrase,
|
| 345 |
-
max_entries_gap=query_params.get("max_entries_gap", 2),
|
| 346 |
-
max_time_gap_s=query_params.get("max_time_gap_s", 5.0)
|
| 347 |
-
)
|
| 348 |
-
|
| 349 |
-
return {"error": f"Unsupported query type: {query_type}"}
|
| 350 |
|
| 351 |
def cleanup(self):
|
| 352 |
"""Remove temporary files and directory."""
|
|
|
|
| 14 |
Analyzes video content, extracting information such as frames, audio, or metadata.
|
| 15 |
Useful for tasks like video summarization, frame extraction, transcript analysis, or content analysis.
|
| 16 |
"""
|
| 17 |
+
name = "video_processor"
|
| 18 |
+
description = "Analyzes video content from a file path or YouTube URL. Can extract frames, detect objects, get transcripts, and provide video metadata."
|
| 19 |
+
inputs = {
|
| 20 |
+
"file_path": {"type": "string", "description": "Path to the video file or YouTube URL.", "nullable": True},
|
| 21 |
+
"task": {"type": "string", "description": "Specific task to perform (e.g., 'extract_frames', 'get_transcript', 'detect_objects', 'get_metadata').", "nullable": True},
|
| 22 |
+
"task_parameters": {"type": "object", "description": "Parameters for the specific task (e.g., frame extraction interval, object detection confidence).", "nullable": True}
|
| 23 |
+
}
|
| 24 |
+
outputs = {"result": {"type": "object", "description": "The result of the video processing task, e.g., list of frame paths, transcript text, object detection results, or metadata dictionary."}}
|
| 25 |
+
output_type = "object"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def __init__(self, model_cfg_path=None, model_weights_path=None, class_names_path=None, temp_dir_base=None, *args, **kwargs):
|
| 29 |
"""
|
| 30 |
Initializes the VideoProcessingTool.
|
| 31 |
|
|
|
|
| 35 |
class_names_path (str, optional): Path to the file containing class names for the model.
|
| 36 |
temp_dir_base (str, optional): Base directory for temporary files. Defaults to system temp.
|
| 37 |
"""
|
| 38 |
+
super().__init__(*args, **kwargs)
|
| 39 |
+
self.is_initialized = False # Will be set to True after successful setup
|
| 40 |
+
|
| 41 |
if temp_dir_base:
|
| 42 |
self.temp_dir = tempfile.mkdtemp(dir=temp_dir_base)
|
| 43 |
else:
|
|
|
|
| 50 |
if os.path.exists(model_cfg_path) and os.path.exists(model_weights_path) and os.path.exists(class_names_path):
|
| 51 |
try:
|
| 52 |
self.object_detection_model = cv2.dnn.readNetFromDarknet(model_cfg_path, model_weights_path)
|
|
|
|
| 53 |
self.object_detection_model.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
|
| 54 |
self.object_detection_model.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)
|
| 55 |
with open(class_names_path, "r") as f:
|
| 56 |
self.class_names = [line.strip() for line in f.readlines()]
|
| 57 |
+
print("CV Model loaded successfully.")
|
| 58 |
except Exception as e:
|
| 59 |
print(f"Error loading CV model: {e}. Object detection will not be available.")
|
| 60 |
self.object_detection_model = None
|
| 61 |
else:
|
| 62 |
print("Warning: One or more CV model paths are invalid. Object detection will not be available.")
|
| 63 |
+
else:
|
| 64 |
+
print("CV model paths not provided. Object detection will not be available.")
|
| 65 |
+
|
| 66 |
+
self.is_initialized = True
|
| 67 |
+
|
| 68 |
+
def forward(self, file_path: str = None, task: str = "get_metadata", task_parameters: dict = None):
|
| 69 |
+
"""
|
| 70 |
+
Main entry point for video processing tasks.
|
| 71 |
+
"""
|
| 72 |
+
if not self.is_initialized:
|
| 73 |
+
return {"error": "Tool not initialized properly."}
|
| 74 |
+
|
| 75 |
+
if task_parameters is None:
|
| 76 |
+
task_parameters = {}
|
| 77 |
+
|
| 78 |
+
is_youtube_url = file_path and ("youtube.com/" in file_path or "youtu.be/" in file_path)
|
| 79 |
+
video_source_path = file_path
|
| 80 |
+
|
| 81 |
+
if is_youtube_url:
|
| 82 |
+
download_resolution = task_parameters.get("resolution", "360p")
|
| 83 |
+
download_result = self.download_video(file_path, resolution=download_resolution)
|
| 84 |
+
if download_result.get("error"):
|
| 85 |
+
return download_result
|
| 86 |
+
video_source_path = download_result.get("file_path")
|
| 87 |
+
if not video_source_path or not os.path.exists(video_source_path):
|
| 88 |
+
return {"error": f"Failed to download or locate video from URL: {file_path}"}
|
| 89 |
+
|
| 90 |
+
elif file_path and not os.path.exists(file_path):
|
| 91 |
+
return {"error": f"Video file not found: {file_path}"}
|
| 92 |
+
elif not file_path and task not in ['get_transcript']: # transcript can work with URL directly
|
| 93 |
+
return {"error": "File path is required for this task."}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if task == "get_metadata":
|
| 97 |
+
return self.get_video_metadata(video_source_path)
|
| 98 |
+
elif task == "extract_frames":
|
| 99 |
+
interval_seconds = task_parameters.get("interval_seconds", 5)
|
| 100 |
+
max_frames = task_parameters.get("max_frames")
|
| 101 |
+
return self.extract_frames_from_video(video_source_path, interval_seconds=interval_seconds, max_frames=max_frames)
|
| 102 |
+
elif task == "get_transcript":
|
| 103 |
+
# Use original file_path which might be the URL
|
| 104 |
+
return self.get_youtube_transcript(file_path)
|
| 105 |
+
elif task == "detect_objects":
|
| 106 |
+
if not self.object_detection_model:
|
| 107 |
+
return {"error": "Object detection model not loaded."}
|
| 108 |
+
confidence_threshold = task_parameters.get("confidence_threshold", 0.5)
|
| 109 |
+
frames_to_process = task_parameters.get("frames_to_process", 5) # Process N frames
|
| 110 |
+
return self.detect_objects_in_video(video_source_path, confidence_threshold=confidence_threshold, num_frames_to_sample=frames_to_process)
|
| 111 |
+
# Add more tasks as needed, e.g., extract_audio
|
| 112 |
+
else:
|
| 113 |
+
return {"error": f"Unsupported task: {task}"}
|
| 114 |
|
| 115 |
def _extract_video_id(self, youtube_url):
|
| 116 |
"""Extract the YouTube video ID from a URL."""
|
|
|
|
| 169 |
except Exception as e:
|
| 170 |
return {"error": f"Failed to download video: {str(e)}"}
|
| 171 |
|
| 172 |
+
def get_video_metadata(self, video_path):
|
| 173 |
+
"""Extract metadata from the video file."""
|
| 174 |
+
if not os.path.exists(video_path):
|
| 175 |
+
return {"error": f"Video file not found: {video_path}"}
|
| 176 |
+
|
| 177 |
+
cap = cv2.VideoCapture(video_path)
|
| 178 |
+
if not cap.isOpened():
|
| 179 |
+
return {"error": "Could not open video file."}
|
| 180 |
+
|
| 181 |
+
metadata = {
|
| 182 |
+
"frame_count": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
|
| 183 |
+
"fps": cap.get(cv2.CAP_PROP_FPS),
|
| 184 |
+
"width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
|
| 185 |
+
"height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
|
| 186 |
+
"duration": cap.get(cv2.CAP_PROP_FRAME_COUNT) / cap.get(cv2.CAP_PROP_FPS)
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
cap.release()
|
| 190 |
+
return {"success": True, "metadata": metadata}
|
| 191 |
+
|
| 192 |
+
def extract_frames_from_video(self, video_path, interval_seconds=5, max_frames=None):
|
| 193 |
+
"""
|
| 194 |
+
Extracts frames from the video at specified intervals.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
video_path (str): Path to the video file.
|
| 198 |
+
interval_seconds (int): Interval in seconds between frames.
|
| 199 |
+
max_frames (int, optional): Maximum number of frames to extract.
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
dict: {"success": True, "extracted_frame_paths": [...] } or {"error": "..."}
|
| 203 |
+
"""
|
| 204 |
+
if not os.path.exists(video_path):
|
| 205 |
+
return {"error": f"Video file not found: {video_path}"}
|
| 206 |
+
|
| 207 |
+
cap = cv2.VideoCapture(video_path)
|
| 208 |
+
if not cap.isOpened():
|
| 209 |
+
return {"error": "Could not open video file."}
|
| 210 |
+
|
| 211 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 212 |
+
frame_interval = int(fps * interval_seconds)
|
| 213 |
+
extracted_frame_paths = []
|
| 214 |
+
frame_count = 0
|
| 215 |
+
|
| 216 |
+
while cap.isOpened():
|
| 217 |
+
ret, frame = cap.read()
|
| 218 |
+
if not ret:
|
| 219 |
+
break
|
| 220 |
+
|
| 221 |
+
if frame_count % frame_interval == 0:
|
| 222 |
+
frame_id = int(frame_count / frame_interval)
|
| 223 |
+
frame_file_path = os.path.join(self.temp_dir, f"frame_{frame_id:04d}.jpg")
|
| 224 |
+
cv2.imwrite(frame_file_path, frame)
|
| 225 |
+
extracted_frame_paths.append(frame_file_path)
|
| 226 |
+
if max_frames and len(extracted_frame_paths) >= max_frames:
|
| 227 |
+
break
|
| 228 |
+
|
| 229 |
+
frame_count += 1
|
| 230 |
+
|
| 231 |
+
cap.release()
|
| 232 |
+
return {"success": True, "extracted_frame_paths": extracted_frame_paths}
|
| 233 |
+
|
| 234 |
+
def get_youtube_transcript(self, youtube_url, languages=None):
|
| 235 |
"""Get the transcript/captions of a YouTube video."""
|
| 236 |
if languages is None:
|
| 237 |
languages = ['en', 'en-US'] # Default to English
|
|
|
|
| 287 |
# Catches other exceptions from YouTubeTranscriptApi calls or re-raised from fetch
|
| 288 |
return {"error": f"Failed to get transcript: {str(e)}"}
|
| 289 |
|
| 290 |
+
def detect_objects_in_video(self, video_path, confidence_threshold=0.5, num_frames_to_sample=5, target_fps=1):
|
| 291 |
"""
|
| 292 |
+
Detects objects in the video and returns the count of specified objects.
|
| 293 |
+
|
| 294 |
Args:
|
| 295 |
video_path (str): Path to the video file.
|
|
|
|
|
|
|
| 296 |
confidence_threshold (float): Minimum confidence for an object to be counted.
|
| 297 |
+
num_frames_to_sample (int): Number of frames to sample for object detection.
|
| 298 |
+
target_fps (int): Target frames per second for processing.
|
| 299 |
+
|
| 300 |
Returns:
|
| 301 |
+
dict: {"success": True, "object_counts": {...}} or {"error": "..."}
|
|
|
|
|
|
|
| 302 |
"""
|
| 303 |
if not self.object_detection_model or not self.class_names:
|
| 304 |
return {"error": "Object detection model not loaded or class names missing."}
|
|
|
|
| 309 |
if not cap.isOpened():
|
| 310 |
return {"error": "Could not open video file."}
|
| 311 |
|
| 312 |
+
object_counts = {cls: 0 for cls in self.class_names}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
frame_count = 0
|
| 314 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 315 |
+
sample_interval = max(1, total_frames // num_frames_to_sample)
|
| 316 |
+
|
| 317 |
while cap.isOpened():
|
| 318 |
ret, frame = cap.read()
|
| 319 |
if not ret:
|
| 320 |
break
|
| 321 |
|
| 322 |
+
if frame_count % sample_interval == 0:
|
| 323 |
+
height, width = frame.shape[:2]
|
| 324 |
+
blob = cv2.dnn.blobFromImage(frame, 1/255.0, (416, 416), swapRB=True, crop=False)
|
| 325 |
+
self.object_detection_model.setInput(blob)
|
| 326 |
+
|
| 327 |
+
layer_names = self.object_detection_model.getLayerNames()
|
| 328 |
+
# Handle potential differences in getUnconnectedOutLayers() return value
|
| 329 |
+
unconnected_out_layers_indices = self.object_detection_model.getUnconnectedOutLayers()
|
| 330 |
+
if isinstance(unconnected_out_layers_indices, np.ndarray) and unconnected_out_layers_indices.ndim > 1 : # For some OpenCV versions
|
| 331 |
+
output_layer_names = [layer_names[i[0] - 1] for i in unconnected_out_layers_indices]
|
| 332 |
+
else: # For typical cases
|
| 333 |
+
output_layer_names = [layer_names[i - 1] for i in unconnected_out_layers_indices]
|
| 334 |
+
|
| 335 |
+
detections = self.object_detection_model.forward(output_layer_names)
|
| 336 |
+
|
| 337 |
+
for detection_set in detections: # Detections can come from multiple output layers
|
| 338 |
+
for detection in detection_set:
|
| 339 |
+
scores = detection[5:]
|
| 340 |
+
class_id = np.argmax(scores)
|
| 341 |
+
confidence = scores[class_id]
|
| 342 |
+
|
| 343 |
+
if confidence > confidence_threshold:
|
| 344 |
+
detected_class_name = self.class_names[class_id]
|
| 345 |
+
object_counts[detected_class_name] += 1
|
| 346 |
+
|
| 347 |
frame_count += 1
|
|
|
|
|
|
|
| 348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
cap.release()
|
| 350 |
+
return {"success": True, "object_counts": object_counts}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
|
| 352 |
def cleanup(self):
|
| 353 |
"""Remove temporary files and directory."""
|