|
import ast |
|
import contextlib |
|
import io |
|
import signal |
|
import re |
|
import traceback |
|
from typing import Dict, Any, Optional, Union, List |
|
from smolagents.tools import Tool |
|
import os |
|
import logging |
|
|
|
|
|
# Configure the root logger at INFO level as a side effect of importing this module.
# NOTE(review): this also changes logging for any application that imports this
# module; consider moving it under the `if __name__ == '__main__':` guard.
logging.basicConfig(level=logging.INFO)

# Logger for this module, named after the module itself (__name__).
logger = logging.getLogger(__name__)
|
|
|
class CodeExecutionTool(Tool):
    """
    Executes Python code snippets with timeout protection.

    Useful for data processing, analysis, and transformation.
    Includes special utilities for web data processing and robust error handling:
    the helpers ``extract_pattern``, ``clean_text``, ``parse_table_text`` and
    ``safe_float`` are pre-defined in the namespace of every execution started
    through ``forward``, ``execute_code`` or ``execute_file``.

    Security note: code runs inside this process via ``exec``. The safety check is
    a best-effort static scan (banned import names plus a few banned builtin
    calls); it is NOT a sandbox -- builtins such as ``open`` and ``getattr`` stay
    reachable -- so only execute code you are prepared to trust.
    """

    name = "python_executor"
    description = "Safely executes Python code with enhancements for data processing, parsing, and error recovery."
    inputs = {
        'code_string': {'type': 'string', 'description': 'The Python code to execute.', 'nullable': True},
        'filepath': {'type': 'string', 'description': 'Path to a Python file to execute.', 'nullable': True},
    }
    outputs = {
        'success': {'type': 'boolean', 'description': 'Whether the code executed successfully.'},
        'output': {'type': 'string', 'description': 'The captured stdout or formatted result.', 'nullable': True},
        'error': {'type': 'string', 'description': 'Error message if execution failed.', 'nullable': True},
        'result_value': {'type': 'any', 'description': 'The final expression value if applicable.', 'nullable': True},
    }
    output_type = "object"

    # Variable names inspected, in this order, for the "result" of executed code.
    _RESULT_VARIABLE_NAMES = ('result', 'answer', 'output', 'value', 'final_result', 'data')

    # Builtins whose direct calls are rejected. __import__ is listed because it
    # would otherwise sidestep the banned-import check entirely.
    _BANNED_CALLS = ('exec', 'eval', '__import__')

    # A number at the very end of a line, e.g. "Sum: 12" or "x = -1.5e3".
    _TRAILING_NUMBER_RE = re.compile(r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?$')

    def __init__(self, timeout: float = 10, max_output_size: int = 20000, *args, **kwargs):
        """
        Args:
            timeout: Wall-clock limit in seconds per execution; 0 disables it.
                Fractional values are accepted. Only enforced where SIGALRM is
                available (POSIX) and when running in the main thread.
            max_output_size: Maximum number of captured stdout characters returned.
            *args, **kwargs: Forwarded to ``Tool.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.timeout = timeout
        self.max_output_size = max_output_size
        # Matched as substrings of imported module names by _analyze_code_safety,
        # so 'os' also rejects e.g. 'os.path' and 'posixpath'.
        self.banned_modules = [
            'os', 'subprocess', 'sys', 'builtins', 'importlib',
            'pickle', 'requests', 'socket', 'shutil', 'ctypes', 'multiprocessing'
        ]
        # NOTE(review): presumably tells the smolagents Tool base class that no
        # lazy setup() is needed -- confirm against the installed smolagents.
        self.is_initialized = True

        self._utility_functions = self._get_utility_functions()

    def _get_utility_functions(self) -> str:
        """Return the source of the helper functions pre-loaded for executed code."""
        # Raw string: the backslashes below reach the executed source verbatim.
        # Inner docstrings use ''' so they cannot terminate this literal.
        utility_code = r"""
# Utility functions for web data processing
import re


def extract_pattern(text, pattern, group=0, all_matches=False):
    '''
    Extract data using regex pattern from text.

    Args:
        text (str): Text to search in
        pattern (str): Regex pattern to use
        group (int): Capture group to return (default 0 - entire match)
        all_matches (bool): If True, return all matches, otherwise just first

    Returns:
        Matched string(s) or None if no match
    '''
    if not text or not pattern:
        print("Warning: Empty text or pattern provided to extract_pattern")
        return None

    try:
        matches = re.finditer(pattern, text)
        # Fall back to the whole match when the requested group does not exist
        results = [m.group(group) if group < len(m.groups()) + 1 else m.group(0) for m in matches]

        if not results:
            print(f"No matches found for pattern '{pattern}'")
            return None

        if all_matches:
            return results
        return results[0]
    except Exception as e:
        print(f"Error in extract_pattern: {e}")
        return None


def clean_text(text, remove_extra_whitespace=True, remove_special_chars=False):
    '''
    Clean text by removing extra whitespace and optionally special characters.

    Args:
        text (str): Text to clean
        remove_extra_whitespace (bool): If True, replace multiple spaces with single space
        remove_special_chars (bool): If True, remove special characters

    Returns:
        Cleaned string
    '''
    if not text:
        return ""

    # Newlines, tabs and carriage returns become spaces
    text = re.sub(r'[\n\t\r]+', ' ', text)

    if remove_special_chars:
        # Keep word characters, whitespace and common punctuation only
        text = re.sub(r'[^\w\s.,;:!?\'"()-]', '', text)

    if remove_extra_whitespace:
        text = re.sub(r'\s+', ' ', text)

    return text.strip()


def parse_table_text(table_text):
    '''
    Parse table-like text into list of rows

    Args:
        table_text (str): Text containing table-like data

    Returns:
        List of rows (each row is a list of cells)
    '''
    rows = []
    lines = (table_text or "").strip().split('\n')

    for line in lines:
        if not line.strip():
            continue

        # Cells are separated by 2+ spaces, tabs, or pipe characters
        cells = re.split(r'\s{2,}|\t+|\|+', line.strip())
        cells = [cell.strip() for cell in cells if cell.strip()]
        if cells:
            rows.append(cells)

    print(f"Parsed {len(rows)} rows from table text")
    if rows:
        print(f"First row (columns: {len(rows[0])}): {rows[0]}")

    return rows


def safe_float(text):
    '''
    Safely convert text to float, handling various formats.

    Args:
        text (str): Text to convert

    Returns:
        float or None if conversion fails
    '''
    # Real numbers are converted directly: 0 is falsy and str(1e-05) contains
    # an 'e' that the cleanup below would strip.
    if isinstance(text, (int, float)) and not isinstance(text, bool):
        return float(text)
    if not text:
        return None

    # Drop currency symbols, thousands separators, units, ...
    text = re.sub(r'[^0-9.-]', '', str(text))

    try:
        return float(text)
    except ValueError:
        print(f"Warning: Could not convert '{text}' to float")
        return None
"""
        return utility_code

    def _analyze_code_safety(self, code: str) -> Dict[str, Any]:
        """
        Perform static analysis to check for potentially harmful code.

        Returns:
            ``{"safe": True}``; or ``{"safe": False, "reason": str}`` when the code
            does not parse, imports a module whose name contains an entry of
            ``self.banned_modules``, or directly calls a name in ``_BANNED_CALLS``.
        """
        try:
            parsed = ast.parse(code)
        except (SyntaxError, ValueError):
            # Older Pythons raise ValueError for source containing null bytes.
            return {"safe": False, "reason": "Invalid Python syntax"}

        imported_modules: List[str] = []
        for node in ast.walk(parsed):
            if isinstance(node, ast.Import):
                imported_modules.extend(alias.name for alias in node.names)
            elif isinstance(node, ast.ImportFrom) and node.module:
                # node.module is None for relative imports like "from . import x".
                imported_modules.append(node.module)

        dangerous_imports = [
            module for module in imported_modules
            if any(banned in module for banned in self.banned_modules)
        ]
        if dangerous_imports:
            return {
                "safe": False,
                "reason": f"Potentially harmful imports detected: {dangerous_imports}"
            }

        for node in ast.walk(parsed):
            if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
                    and node.func.id in self._BANNED_CALLS):
                if node.func.id in ('exec', 'eval'):
                    reason = "Contains exec() or eval() calls"
                else:
                    reason = f"Contains {node.func.id}() calls"
                return {"safe": False, "reason": reason}

        return {"safe": True}

    def _timeout_handler(self, signum, frame):
        """SIGALRM handler: abort the running code with TimeoutError."""
        raise TimeoutError(f"Code execution timed out after {self.timeout} seconds")

    @contextlib.contextmanager
    def _time_limit(self):
        """
        Make the enclosed block raise TimeoutError after ``self.timeout`` seconds.

        Uses SIGALRM with the ITIMER_REAL interval timer and restores the previous
        SIGALRM handler afterwards. No limit is applied when ``self.timeout`` is
        falsy or not positive, when the platform lacks SIGALRM (e.g. Windows), or
        outside the main thread (``signal.signal`` raises ValueError there).
        Signals are only handled between Python bytecodes, so a single
        long-running C call is not interrupted until it returns.
        """
        previous_handler = None
        handler_installed = False
        if (self.timeout and self.timeout > 0
                and hasattr(signal, "SIGALRM") and hasattr(signal, "setitimer")):
            try:
                previous_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
                handler_installed = True
            except ValueError:
                logger.warning("Timeout not enforced: signal handlers can only be installed from the main thread.")

        if not handler_installed:
            yield
            return

        signal.setitimer(signal.ITIMER_REAL, self.timeout)
        try:
            yield
        finally:
            # Cancel the timer first so it cannot fire while the handler is swapped back.
            signal.setitimer(signal.ITIMER_REAL, 0)
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

    @staticmethod
    def _parse_number(text: str) -> Optional[Union[int, float]]:
        """Parse ``text`` as an int, else as a finite float; return None if neither works."""
        import math

        try:
            return int(text)
        except ValueError:
            pass
        try:
            number = float(text)
        except ValueError:
            return None
        # float() also accepts words like "nan"/"inf"; those are not numeric results.
        return number if math.isfinite(number) else None

    def _extract_numeric_value(self, output: str) -> Optional[Union[int, float]]:
        """
        Extract the final numeric value from output.

        Scans lines from the last one upwards and returns the first number found,
        either a whole line that is a number or a number ending a line (e.g.
        "Sum: 12"). Integers come back as int, other numbers (including
        exponent forms like "1e5") as float. Returns None if no line has one.
        """
        if not output:
            return None

        for line in reversed(output.strip().splitlines()):
            line = line.strip()
            value = self._parse_number(line)
            if value is None:
                match = self._TRAILING_NUMBER_RE.search(line)
                if match:
                    value = self._parse_number(match.group(0))
            if value is not None:
                return value
        return None

    def _truncate_output(self, output: str) -> str:
        """Cap ``output`` at ``self.max_output_size`` characters, appending a notice when cut."""
        if len(output) > self.max_output_size:
            return output[:self.max_output_size] + f"\n... (output truncated, exceeded {self.max_output_size} characters)"
        return output

    def _failure(self, error: str, stdout_buffer: io.StringIO) -> Dict[str, Any]:
        """Build a failure result that also carries whatever was printed before the failure."""
        return {
            "success": False,
            "error": error,
            "output": self._truncate_output(stdout_buffer.getvalue()),
        }

    def _read_code_file(self, filepath: str) -> Union[str, Dict[str, Any]]:
        """Return the source of the .py file at ``filepath``, or a failure dict if unusable."""
        if not os.path.exists(filepath):
            return {"success": False, "error": f"File not found: {filepath}"}
        if not filepath.endswith(".py"):
            return {"success": False, "error": f"File is not a Python file: {filepath}"}
        try:
            # Python source files are UTF-8 by default (PEP 3120).
            with open(filepath, 'r', encoding='utf-8') as file:
                return file.read()
        except (OSError, ValueError) as e:
            return {"success": False, "error": f"Error reading file {filepath}: {str(e)}"}

    def forward(self, code_string: Optional[str] = None, filepath: Optional[str] = None) -> Dict[str, Any]:
        """
        Execute ``code_string`` or the Python file at ``filepath`` (exactly one of them).

        The utility helpers are defined in the execution namespace first.

        Returns:
            ``{"success": True, "output": ..., "result_value": ...}`` on success,
            otherwise ``{"success": False, "error": ...}`` (plus partial
            ``output`` if the code started running).
        """
        if not code_string and not filepath:
            return {"success": False, "error": "No code string or filepath provided."}
        if code_string and filepath:
            return {"success": False, "error": "Provide either a code string or a filepath, not both."}

        if filepath:
            loaded = self._read_code_file(filepath)
            if isinstance(loaded, dict):
                return loaded
            code_to_execute = loaded
        else:
            code_to_execute = code_string

        return self._execute_actual_code(code_to_execute, inject_utilities=True)

    def _execute_actual_code(self, code: str, *, inject_utilities: bool = False) -> Dict[str, Any]:
        """
        Execute Python code and capture the output or error.

        Args:
            code: Python source to run (safety-checked first).
            inject_utilities: If True, define the helpers from
                ``self._utility_functions`` before running ``code``. They are
                executed separately so traceback line numbers match ``code``.

        Returns:
            ``{"success": True, "output": str, "result_value": Any}`` on success.
            ``result_value`` is the first of ``_RESULT_VARIABLE_NAMES`` defined
            by the code, else a number parsed from the output, else None.
            On failure: ``{"success": False, "error": str}``, plus ``output``
            (captured so far) when execution had started.
        """
        safety_check = self._analyze_code_safety(code)
        if not safety_check["safe"]:
            return {
                "success": False,
                "error": f"Safety check failed: {safety_check['reason']}"
            }

        stdout_buffer = io.StringIO()
        # A single dict is used as both globals and locals. With separate dicts,
        # top-level imports/variables/functions would land in locals and be
        # invisible from inside functions defined by the executed code.
        # __name__ is set so scripts with a main guard actually run.
        namespace: Dict[str, Any] = {"__name__": "__main__"}

        try:
            # redirect_stdout is the outer manager so stdout is restored only after
            # the timer has been cancelled.
            with contextlib.redirect_stdout(stdout_buffer), self._time_limit():
                if inject_utilities:
                    exec(self._utility_functions, namespace)
                exec(code, namespace)
        except TimeoutError:
            return self._failure(f"Code execution timed out after {self.timeout} seconds", stdout_buffer)
        except SystemExit as exc:
            # exit()/quit()/raise SystemExit must not terminate the host process;
            # a zero/None status is treated as a normal finish.
            if exc.code not in (None, 0):
                return self._failure(f"Code exited with status {exc.code}", stdout_buffer)
        except Exception as e:
            trace = traceback.format_exc()
            return self._failure(f"Error executing code: {str(e)}\n{trace}", stdout_buffer)

        output = self._truncate_output(stdout_buffer.getvalue())

        result_value = next(
            (namespace[var_name] for var_name in self._RESULT_VARIABLE_NAMES if var_name in namespace),
            None,
        )
        # If no result variable was found, fall back to a number printed last.
        if result_value is None:
            result_value = self._extract_numeric_value(output)

        return {
            "success": True,
            "output": output,
            "result_value": result_value
        }

    # Direct-access helpers; like forward(), both pre-load the utility helpers.
    def execute_file(self, filepath: str) -> Dict[str, Any]:
        """Execute the Python file at ``filepath`` with the utility helpers pre-loaded."""
        loaded = self._read_code_file(filepath)
        if isinstance(loaded, dict):
            return loaded
        return self._execute_actual_code(loaded, inject_utilities=True)

    def execute_code(self, code: str) -> Dict[str, Any]:
        """Execute Python ``code`` from a string with the utility helpers pre-loaded."""
        return self._execute_actual_code(code, inject_utilities=True)
|
|
|
|
|
if __name__ == '__main__':
    # Self-test script; run this module directly to exercise the tool.
    import tempfile

    tool = CodeExecutionTool(timeout=5)

    # Test 1: Safe code string
    safe_code = "print('Hello from safe code!'); result = 10 * 2; print(result)"
    print("\n--- Test 1: Safe Code String ---")
    result1 = tool.forward(code_string=safe_code)
    print(result1)
    assert result1['success']
    assert "Hello from safe code!" in result1['output']
    assert "20" in result1['output']
    assert result1['result_value'] == 20

    # Test 2: Code with an error
    error_code = "print(1/0)"
    print("\n--- Test 2: Code with Error ---")
    result2 = tool.forward(code_string=error_code)
    print(result2)
    assert not result2['success']
    assert "ZeroDivisionError" in result2['error']

    # Test 3: Code with a banned import
    unsafe_import_code = "import os; print(os.getcwd())"
    print("\n--- Test 3: Unsafe Import ---")
    result3 = tool.forward(code_string=unsafe_import_code)
    print(result3)
    assert not result3['success']
    assert "Safety check failed" in result3['error']
    assert "os" in result3['error']

    # Test 4: Timeout (a dedicated short-timeout instance keeps the run fast)
    timeout_code = "import time; time.sleep(10); print('Done sleeping')"
    print("\n--- Test 4: Timeout ---")
    tool_timeout_short = CodeExecutionTool(timeout=1)
    result4 = tool_timeout_short.forward(code_string=timeout_code)
    print(result4)
    assert not result4['success']
    assert "timed out" in result4['error']

    # Test 5: Execute from file (in a temporary directory that is always cleaned up)
    test_file_content = "print('Hello from file!'); x = 5; y = 7; print(f'Sum: {x+y}')"
    print("\n--- Test 5: Execute from File ---")
    with tempfile.TemporaryDirectory() as tmp_dir:
        test_filename = os.path.join(tmp_dir, "temp_test_script.py")
        with open(test_filename, "w", encoding="utf-8") as f:
            f.write(test_file_content)
        result5 = tool.forward(filepath=test_filename)
    print(result5)
    assert result5['success']
    assert "Hello from file!" in result5['output']
    assert "Sum: 12" in result5['output']
    assert result5['result_value'] == 12

    # Test 6: File not found
    print("\n--- Test 6: File Not Found ---")
    result6 = tool.forward(filepath="non_existent_script.py")
    print(result6)
    assert not result6['success']
    assert "File not found" in result6['error']

    # Test 7: Provide both code_string and filepath (rejected before the path is used)
    print("\n--- Test 7: Both code_string and filepath ---")
    result7 = tool.forward(code_string="print('hello')", filepath=test_filename)
    print(result7)
    assert not result7['success']
    assert "Provide either a code string or a filepath, not both" in result7['error']

    # Test 8: Provide neither
    print("\n--- Test 8: Neither code_string nor filepath ---")
    result8 = tool.forward()
    print(result8)
    assert not result8['success']
    assert "No code string or filepath provided" in result8['error']

    # Test 9: Code that defines a function and calls it.
    # The call must be on its own line: after "def f(...): return ...;" it would be
    # part of the function body and never run.
    func_def_code = "def my_func(a, b):\n    return a + b\nprint(my_func(3, 4))"
    print("\n--- Test 9: Function Definition and Call ---")
    result9 = tool.forward(code_string=func_def_code)
    print(result9)
    assert result9['success']
    assert "7" in result9['output']

    # Test 10: Max output size
    truncation_notice = "\n... (output truncated, exceeded 50 characters)"
    tool_max_output = CodeExecutionTool(max_output_size=50)
    long_output_code = "for i in range(20): print(f'Line {i}')"
    print("\n--- Test 10: Max Output Size ---")
    result10 = tool_max_output.forward(code_string=long_output_code)
    print(result10)
    assert result10['success']
    assert result10['output'].endswith(truncation_notice)
    assert len(result10['output']) == 50 + len(truncation_notice)

    # Test 11: Helpers and top-level imports are visible inside user functions
    helper_code = (
        "import math\n"
        "def hyp(a, b):\n"
        "    return math.sqrt(a * a + b * b)\n"
        "rows = parse_table_text('a  b\\n1  2')\n"
        "result = hyp(3, 4) + safe_float('$1,000')\n"
    )
    print("\n--- Test 11: Utility Helpers and Globals ---")
    result11 = tool.forward(code_string=helper_code)
    print(result11)
    assert result11['success']
    assert result11['result_value'] == 1005.0

    # Test 12: __import__ cannot be used to bypass the banned-import check
    print("\n--- Test 12: __import__ Bypass ---")
    result12 = tool.forward(code_string="__import__('os').getcwd()")
    print(result12)
    assert not result12['success']
    assert "Safety check failed" in result12['error']

    # Test 13: SystemExit from user code is reported instead of killing the process
    print("\n--- Test 13: SystemExit ---")
    result13 = tool.forward(code_string="print('before')\nraise SystemExit(3)")
    print(result13)
    assert not result13['success']
    assert "status 3" in result13['error']
    assert "before" in result13['output']

    # Test 14: Numbers in exponent form are extracted from the output
    print("\n--- Test 14: Exponent Numeric Output ---")
    result14 = tool.forward(code_string="print('Total: 1e3')")
    print(result14)
    assert result14['success']
    assert result14['result_value'] == 1000.0

    print("\nAll tests seem to have passed (check output for details).")