import ast
import contextlib
import io
import logging
import os
import re
import signal
import threading
import traceback
from typing import Any, Dict, List, Optional, Tuple, Union

from smolagents.tools import Tool

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Variable names inspected (in this order) after execution to pick up a result value.
_RESULT_VARIABLE_NAMES = ('result', 'answer', 'output', 'value', 'final_result', 'data')

# A number at the very end of a line, e.g. the "42" in "Total: 42" or "-1.5e3".
_TRAILING_NUMBER_RE = re.compile(r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?$')

# Helper source executed in the same namespace before the user's code in `forward()`.
# It is a raw string delimited by triple single quotes so the helpers can use
# triple-double-quoted docstrings and regex escapes without a second layer of escaping.
_UTILITY_CODE = r'''
# Utility functions for web data processing
import re


def extract_pattern(text, pattern, group=0, all_matches=False):
    """
    Extract data using regex pattern from text.

    Args:
        text (str): Text to search in
        pattern (str): Regex pattern to use
        group (int): Capture group to return (default 0 - entire match)
        all_matches (bool): If True, return all matches, otherwise just first

    Returns:
        Matched string(s) or None if no match
    """
    if not text or not pattern:
        print("Warning: Empty text or pattern provided to extract_pattern")
        return None
    try:
        matches = re.finditer(pattern, text)
        # Fall back to the whole match when the requested group does not exist.
        results = [m.group(group) if group < len(m.groups()) + 1 else m.group(0) for m in matches]
        if not results:
            print(f"No matches found for pattern '{pattern}'")
            return None
        return results if all_matches else results[0]
    except Exception as e:
        print(f"Error in extract_pattern: {e}")
        return None


def clean_text(text, remove_extra_whitespace=True, remove_special_chars=False):
    """
    Clean text by removing extra whitespace and optionally special characters.

    Args:
        text (str): Text to clean
        remove_extra_whitespace (bool): If True, replace multiple spaces with single space
        remove_special_chars (bool): If True, remove special characters

    Returns:
        Cleaned string
    """
    if not text:
        return ""
    # Replace newlines and tabs with spaces
    text = re.sub(r'[\n\t\r]+', ' ', text)
    if remove_special_chars:
        # Keep only alphanumeric, spaces, and basic punctuation
        text = re.sub(r'[^\w\s.,;:!?\'"()-]', '', text)
    if remove_extra_whitespace:
        # Replace multiple spaces with single space
        text = re.sub(r'\s+', ' ', text)
    return text.strip()


def parse_table_text(table_text):
    """
    Parse table-like text into list of rows.

    Args:
        table_text (str): Text containing table-like data

    Returns:
        List of rows (each row is a list of cells)
    """
    rows = []
    for line in (table_text or "").strip().split('\n'):
        # Skip empty lines
        if not line.strip():
            continue
        # Split by runs of 2+ spaces, tabs, or pipe separators
        cells = re.split(r'\s{2,}|\t+|\|+', line.strip())
        cells = [cell.strip() for cell in cells if cell.strip()]
        if cells:
            rows.append(cells)
    # Print parsing result for debugging
    print(f"Parsed {len(rows)} rows from table text")
    if rows:
        print(f"First row (columns: {len(rows[0])}): {rows[0]}")
    return rows


def safe_float(text):
    """
    Safely convert text to float, handling various formats.

    Args:
        text (str): Text to convert

    Returns:
        float or None if conversion fails
    """
    # Only None and the empty string are "no value"; 0 must convert to 0.0.
    if text is None or (isinstance(text, str) and not text):
        return None
    # Remove currency symbols, commas in numbers, etc.
    text = re.sub(r'[^0-9.-]', '', str(text))
    try:
        return float(text)
    except ValueError:
        print(f"Warning: Could not convert '{text}' to float")
        return None
'''


class _ExecutionTimeout(BaseException):
    """Raised from the SIGALRM handler to abort user code.

    Derives from BaseException (not Exception) so that a user's
    `except Exception:` block cannot swallow the timeout.
    """


class CodeExecutionTool(Tool):
    """
    Executes Python code snippets with timeout protection.
    Useful for data processing, analysis, and transformation.
    Includes special utilities for web data processing and robust error handling.

    The safety checks are best-effort static analysis (banned imports and
    exec/eval/__import__ calls). The code still runs in this process with the
    normal builtins (for example `open`), so this is not a security sandbox.
    The timeout relies on SIGALRM and is only enforced on platforms that have it,
    when called from the main thread.
    """
    name = "python_executor"
    description = "Safely executes Python code with enhancements for data processing, parsing, and error recovery."
    inputs = {
        'code_string': {'type': 'string', 'description': 'The Python code to execute.', 'nullable': True},
        'filepath': {'type': 'string', 'description': 'Path to a Python file to execute.', 'nullable': True}
    }
    outputs = {
        'success': {'type': 'boolean', 'description': 'Whether the code executed successfully.'},
        'output': {'type': 'string', 'description': 'The captured stdout or formatted result.', 'nullable': True},
        'error': {'type': 'string', 'description': 'Error message if execution failed.', 'nullable': True},
        'result_value': {'type': 'any', 'description': 'The final expression value if applicable.', 'nullable': True}
    }
    output_type = "object"

    def __init__(self, timeout: int = 10, max_output_size: int = 20000, *args, **kwargs):
        """
        Args:
            timeout: Execution limit in whole seconds (passed to `signal.alarm`).
            max_output_size: Maximum number of captured stdout characters returned.
        """
        super().__init__(*args, **kwargs)
        self.timeout = timeout
        self.max_output_size = max_output_size
        # Matched as substrings of imported module names (see _analyze_code_safety).
        self.banned_modules = [
            'os', 'subprocess', 'sys', 'builtins', 'importlib',
            'pickle', 'requests', 'socket', 'shutil', 'ctypes', 'multiprocessing'
        ]
        self.is_initialized = True
        # Helper source that forward() runs before the user's code.
        self._utility_functions = self._get_utility_functions()

    def _get_utility_functions(self) -> str:
        """Return the helper source that `forward()` executes before the user's code."""
        return _UTILITY_CODE

    def _analyze_code_safety(self, code: str) -> Dict[str, Any]:
        """
        Statically check `code` for banned imports and dynamic-code calls.

        Returns:
            {"safe": True}, or {"safe": False, "reason": <str>}.
        """
        try:
            parsed = ast.parse(code)
        except (SyntaxError, ValueError) as exc:
            # ValueError: some Python versions raise it for source containing null bytes.
            return {"safe": False, "reason": f"Invalid Python syntax: {exc}"}

        imports: List[str] = []
        for node in ast.walk(parsed):
            if isinstance(node, ast.Import):
                imports.extend(alias.name for alias in node.names)
            elif isinstance(node, ast.ImportFrom) and node.module:
                # node.module is None for relative imports such as `from . import x`.
                imports.append(node.module)

        # Substring matching: catches dotted names such as "os.path".
        dangerous_imports = [imp for imp in imports
                             if any(banned in imp for banned in self.banned_modules)]
        if dangerous_imports:
            return {
                "safe": False,
                "reason": f"Potentially harmful imports detected: {dangerous_imports}"
            }

        for node in ast.walk(parsed):
            if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                if node.func.id in ('exec', 'eval'):
                    return {"safe": False, "reason": "Contains exec() or eval() calls"}
                if node.func.id == '__import__':
                    # __import__('os') would otherwise bypass the import check above.
                    return {"safe": False, "reason": "Contains __import__() calls"}

        return {"safe": True}

    def _timeout_handler(self, signum, frame):
        """SIGALRM handler that aborts the running user code."""
        raise _ExecutionTimeout(f"Code execution timed out after {self.timeout} seconds")

    @staticmethod
    def _parse_number(text: str) -> Optional[Union[int, float]]:
        """
        Convert `text` to an int if possible, otherwise to a float.

        Returns None if neither conversion succeeds. Text without any digit is
        rejected so that words float() accepts ("nan", "inf") are not taken as results.
        """
        try:
            return int(text)
        except ValueError:
            pass
        if not any(ch.isdigit() for ch in text):
            return None
        try:
            return float(text)
        except ValueError:
            return None

    def _extract_numeric_value(self, output: str) -> Optional[Union[int, float]]:
        """
        Return the last number printed in `output`, or None.

        Lines are scanned from the bottom. A line that is itself a number wins;
        otherwise a number at the very end of the line is used.
        """
        if not output:
            return None
        for line in reversed(output.strip().split('\n')):
            line = line.strip()
            value = self._parse_number(line)
            if value is not None:
                return value
            match = _TRAILING_NUMBER_RE.search(line)
            if match:
                value = self._parse_number(match.group(0))
                if value is not None:
                    return value
        return None

    def _truncate_output(self, output: str) -> str:
        """Cap `output` at `max_output_size` characters, appending a truncation notice."""
        if len(output) > self.max_output_size:
            return (output[:self.max_output_size]
                    + f"\n... (output truncated, exceeded {self.max_output_size} characters)")
        return output

    @staticmethod
    def _read_code_file(filepath: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Read a `.py` file.

        Returns:
            (source, None) on success, or (None, error_message) on failure.
        """
        if not os.path.exists(filepath):
            return None, f"File not found: {filepath}"
        if not filepath.endswith(".py"):
            return None, f"File is not a Python file: {filepath}"
        try:
            # UTF-8 is Python's default source encoding.
            with open(filepath, 'r', encoding='utf-8') as file:
                return file.read(), None
        except (OSError, ValueError) as e:  # ValueError covers UnicodeDecodeError
            return None, f"Error reading file {filepath}: {str(e)}"

    def forward(self, code_string: Optional[str] = None, filepath: Optional[str] = None) -> Dict[str, Any]:
        """
        Run `code_string` or the contents of `filepath` (exactly one must be given),
        with the utility helpers pre-loaded into the execution namespace.
        """
        if not code_string and not filepath:
            return {"success": False, "error": "No code string or filepath provided."}
        if code_string and filepath:
            return {"success": False, "error": "Provide either a code string or a filepath, not both."}

        if filepath:
            code_to_execute, read_error = self._read_code_file(filepath)
            if read_error is not None:
                return {"success": False, "error": read_error}
        else:
            code_to_execute = code_string

        return self._execute_actual_code(code_to_execute, prelude=self._utility_functions)

    def _execute_actual_code(self, code: str, prelude: Optional[str] = None) -> Dict[str, Any]:
        """
        Statically vet `code`, then execute it and capture its stdout.

        Args:
            code: Python source to check and run.
            prelude: Optional trusted source (the utility helpers) executed first in
                the same namespace. It skips the safety check. Keeping it separate
                means line numbers in tracebacks refer to `code` itself.

        Returns:
            On success: {"success": True, "output": <stdout>, "result_value": <value>}.
            On failure: {"success": False, "error": <message>}, plus any partial
            "output" captured before the failure.
        """
        safety_check = self._analyze_code_safety(code)
        if not safety_check["safe"]:
            return {
                "success": False,
                "error": f"Safety check failed: {safety_check['reason']}"
            }

        stdout_buffer = io.StringIO()
        # A single dict serves as both globals and locals, so functions defined by
        # the code can see its top-level imports, helpers, and variables.
        # "__main__" lets scripts with an `if __name__ == "__main__":` guard run.
        namespace: Dict[str, Any] = {"__name__": "__main__"}

        use_alarm = (hasattr(signal, "SIGALRM")
                     and threading.current_thread() is threading.main_thread())
        previous_handler = None
        if use_alarm:
            previous_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
            signal.alarm(self.timeout)
        else:
            logger.warning("SIGALRM timeout unavailable (platform without SIGALRM or "
                           "non-main thread); executing without a timeout.")

        try:
            try:
                with contextlib.redirect_stdout(stdout_buffer):
                    if prelude:
                        exec(compile(prelude, "<utilities>", "exec"), namespace)
                    exec(compile(code, "<user_code>", "exec"), namespace)
            finally:
                # Cancel the alarm before any handler below runs. An alarm that fires
                # before this line surfaces as _ExecutionTimeout in the outer try.
                if use_alarm:
                    signal.alarm(0)
        except _ExecutionTimeout:
            return {
                "success": False,
                "error": f"Code execution timed out after {self.timeout} seconds",
                "output": self._truncate_output(stdout_buffer.getvalue()),
            }
        except SystemExit as exc:
            # exit()/quit()/raise SystemExit must not terminate the host process.
            return {
                "success": False,
                "error": f"Code raised SystemExit({exc.code!r}); execution stopped.",
                "output": self._truncate_output(stdout_buffer.getvalue()),
            }
        except Exception as exc:
            trace = traceback.format_exc()
            return {
                "success": False,
                "error": f"Error executing code: {str(exc)}\n{trace}",
                "output": self._truncate_output(stdout_buffer.getvalue()),
            }
        finally:
            if use_alarm:
                # signal.signal() returns None when the previous handler was not set from Python.
                signal.signal(signal.SIGALRM,
                              previous_handler if previous_handler is not None else signal.SIG_DFL)

        # First conventional result variable the code defined, if any.
        result_value = next(
            (namespace[var_name] for var_name in _RESULT_VARIABLE_NAMES if var_name in namespace),
            None,
        )

        raw_output = stdout_buffer.getvalue()
        # Extract from the full output so truncation cannot hide the final printed number.
        if result_value is None:
            result_value = self._extract_numeric_value(raw_output)

        return {
            "success": True,
            "output": self._truncate_output(raw_output),
            "result_value": result_value
        }

    # Helpers for direct access. Unlike forward(), they do NOT pre-load the utility
    # functions into the execution namespace.
    def execute_file(self, filepath: str) -> Dict[str, Any]:
        """Execute Python code read from `filepath` (no utility prelude)."""
        code, read_error = self._read_code_file(filepath)
        if read_error is not None:
            return {"success": False, "error": read_error}
        return self._execute_actual_code(code)

    def execute_code(self, code: str) -> Dict[str, Any]:
        """Execute Python code from a string (no utility prelude)."""
        return self._execute_actual_code(code)


if __name__ == '__main__':
    tool = CodeExecutionTool(timeout=5)

    # Test 1: Safe code string
    safe_code = "print('Hello from safe code!'); result = 10 * 2; print(result)"
    print("\n--- Test 1: Safe Code String ---")
    result1 = tool.forward(code_string=safe_code)
    print(result1)
    assert result1['success']
    assert "Hello from safe code!" in result1['output']
    assert "20" in result1['output']

    # Test 2: Code with an error
    error_code = "print(1/0)"
    print("\n--- Test 2: Code with Error ---")
    result2 = tool.forward(code_string=error_code)
    print(result2)
    assert not result2['success']
    assert "ZeroDivisionError" in result2['error']

    # Test 3: Code with a banned import
    unsafe_import_code = "import os; print(os.getcwd())"
    print("\n--- Test 3: Unsafe Import ---")
    result3 = tool.forward(code_string=unsafe_import_code)
    print(result3)
    assert not result3['success']
    assert "Safety check failed" in result3['error']
    assert "os" in result3['error']

    # Test 4: Timeout
    timeout_code = "import time; time.sleep(10); print('Done sleeping')"
    print("\n--- Test 4: Timeout ---")
    result4 = tool.forward(code_string=timeout_code)  # Using the main tool instance with its timeout
    print(result4)
    assert not result4['success']
    assert "timed out" in result4['error']

    # Test 5: Execute from file
    test_file_content = "print('Hello from file!'); x = 5; y = 7; print(f'Sum: {x+y}')"
    test_filename = "temp_test_script.py"
    with open(test_filename, "w", encoding="utf-8") as f:
        f.write(test_file_content)
    print("\n--- Test 5: Execute from File ---")
    try:
        result5 = tool.forward(filepath=test_filename)
    finally:
        os.remove(test_filename)
    print(result5)
    assert result5['success']
    assert "Hello from file!" in result5['output']
    assert "Sum: 12" in result5['output']

    # Test 6: File not found
    print("\n--- Test 6: File Not Found ---")
    result6 = tool.forward(filepath="non_existent_script.py")
    print(result6)
    assert not result6['success']
    assert "File not found" in result6['error']

    # Test 7: Provide both code_string and filepath
    print("\n--- Test 7: Both code_string and filepath ---")
    result7 = tool.forward(code_string="print('hello')", filepath=test_filename)
    print(result7)
    assert not result7['success']
    assert "Provide either a code string or a filepath, not both" in result7['error']

    # Test 8: Provide neither
    print("\n--- Test 8: Neither code_string nor filepath ---")
    result8 = tool.forward()
    print(result8)
    assert not result8['success']
    assert "No code string or filepath provided" in result8['error']

    # Test 9: Code that defines a function and calls it
    # (multi-line: on one line the print would be part of the function body, after return)
    func_def_code = "def my_func(a, b):\n    return a + b\nprint(my_func(3, 4))"
    print("\n--- Test 9: Function Definition and Call ---")
    result9 = tool.forward(code_string=func_def_code)
    print(result9)
    assert result9['success']
    assert "7" in result9['output']

    # Test 10: Max output size
    tool_max_output = CodeExecutionTool(max_output_size=50)
    long_output_code = "for i in range(20): print(f'Line {i}')"
    print("\n--- Test 10: Max Output Size ---")
    result10 = tool_max_output.forward(code_string=long_output_code)
    print(result10)
    assert result10['success']
    assert "output truncated" in result10['output']
    assert len(result10['output']) <= 50 + len("\n... (output truncated, exceeded 50 characters)")

    # Test 11: Utility helpers are callable (they rely on the shared namespace for `re`)
    utility_code = ("rows = parse_table_text('a  b\\nc  d')\n"
                    "result = (len(rows), safe_float('$1,234.5'), safe_float(0))")
    print("\n--- Test 11: Utility Functions ---")
    result11 = tool.forward(code_string=utility_code)
    print(result11)
    assert result11['success']
    assert result11['result_value'] == (2, 1234.5, 0.0)

    # Test 12: Timeout cannot be swallowed by `except Exception`
    swallow_code = "while True:\n    try:\n        pass\n    except Exception:\n        pass"
    print("\n--- Test 12: Timeout Not Swallowed ---")
    result12 = CodeExecutionTool(timeout=1).forward(code_string=swallow_code)
    print(result12)
    assert not result12['success']
    assert "timed out" in result12['error']

    # Test 13: SystemExit is reported instead of terminating the process
    print("\n--- Test 13: SystemExit ---")
    result13 = tool.forward(code_string="print('bye'); raise SystemExit(3)")
    print(result13)
    assert not result13['success']
    assert "SystemExit" in result13['error']

    print("\nAll tests seem to have passed (check output for details).")