import ast
import contextlib
import io
import logging
import os
import re
import signal
import threading
import traceback
from typing import Any, Dict, List, Optional, Union

from smolagents.tools import Tool

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Matches a (possibly signed, possibly exponent-form) number at the end of a line.
_TRAILING_NUMBER_RE = re.compile(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?$")

# Built-in callables that are rejected outright, mapped to the reported reason.
_BANNED_CALLS = {
    "exec": "Contains exec() or eval() calls",
    "eval": "Contains exec() or eval() calls",
    # __import__("os") would otherwise sidestep the banned-import check.
    "__import__": "Contains __import__() calls",
}

# Variable names inspected, in order, to find the "result" of a snippet.
_RESULT_VARIABLE_NAMES = ("result", "answer", "output", "value", "final_result", "data")


class CodeExecutionTool(Tool):
    """
    Executes Python code snippets with a static safety check and a timeout.

    The snippet (given inline or read from a ``.py`` file) is prefixed with a
    small set of text-processing helpers (``extract_pattern``, ``clean_text``,
    ``parse_table_text``, ``safe_float``), statically screened for banned
    imports and ``exec``/``eval``/``__import__`` calls, and then run with
    stdout captured.

    NOTE: the screening is a deny-list over the AST, not a sandbox. It does not
    stop file access through ``open`` or other indirect routes, so do not rely
    on it to contain hostile code.
    """

    name = "python_executor"
    description = "Safely executes Python code with enhancements for data processing, parsing, and error recovery."
    inputs = {
        "code_string": {"type": "string", "description": "The Python code to execute.", "nullable": True},
        "filepath": {"type": "string", "description": "Path to a Python file to execute.", "nullable": True},
    }
    outputs = {
        "success": {"type": "boolean", "description": "Whether the code executed successfully."},
        "output": {"type": "string", "description": "The captured stdout or formatted result.", "nullable": True},
        "error": {"type": "string", "description": "Error message if execution failed.", "nullable": True},
        "result_value": {"type": "any", "description": "The final expression value if applicable.", "nullable": True},
    }
    output_type = "object"

    def __init__(self, timeout: int = 10, max_output_size: int = 20000, *args, **kwargs):
        """
        Args:
            timeout: Seconds a snippet may run before it is interrupted
                (enforced with SIGALRM, see ``_execute_actual_code``).
            max_output_size: Maximum number of captured stdout characters
                returned; longer output is truncated.
            *args, **kwargs: Forwarded to ``Tool.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.timeout = timeout
        self.max_output_size = max_output_size
        self.banned_modules = [
            "os",
            "subprocess",
            "sys",
            "builtins",
            "importlib",
            "pickle",
            "requests",
            "socket",
            "shutil",
            "ctypes",
            "multiprocessing",
        ]
        self.is_initialized = True
        self._utility_functions = self._get_utility_functions()

    def _get_utility_functions(self) -> str:
        """
        Return the source of the helper functions injected before user code.

        The literal is a *raw* string so that the backslashes and quotes in the
        embedded regexes reach the executed code unchanged. With a non-raw
        literal, ``\\'`` and ``\\"`` collapse to bare quotes and terminate the
        inner pattern early, which makes the whole prelude a syntax error.
        """
        return r'''
def extract_pattern(text, pattern, group=0, all_matches=False):
    """
    Extract data using regex pattern from text.

    Args:
        text (str): Text to search in
        pattern (str): Regex pattern to use
        group (int): Capture group to return (default 0 - entire match)
        all_matches (bool): If True, return all matches, otherwise just first

    Returns:
        Matched string(s) or None if no match
    """
    import re
    if not text or not pattern:
        print("Warning: Empty text or pattern provided to extract_pattern")
        return None

    try:
        matches = re.finditer(pattern, text)
        results = [m.group(group) if group < len(m.groups())+1 else m.group(0) for m in matches]

        if not results:
            print(f"No matches found for pattern '{pattern}'")
            return None

        return results if all_matches else results[0]
    except Exception as e:
        print(f"Error in extract_pattern: {e}")
        return None


def clean_text(text, remove_extra_whitespace=True, remove_special_chars=False):
    """
    Clean text by removing extra whitespace and optionally special characters.

    Args:
        text (str): Text to clean
        remove_extra_whitespace (bool): If True, replace multiple spaces with single space
        remove_special_chars (bool): If True, remove special characters

    Returns:
        Cleaned string
    """
    import re
    if not text:
        return ""

    # Replace newlines and tabs with spaces
    text = re.sub(r"[\n\t\r]+", " ", text)

    if remove_special_chars:
        # Keep only alphanumeric, spaces, and basic punctuation
        text = re.sub(r"[^\w\s.,;:!?\'\"()-]", "", text)

    if remove_extra_whitespace:
        # Replace multiple spaces with single space
        text = re.sub(r"\s+", " ", text)

    return text.strip()


def parse_table_text(table_text):
    """
    Parse table-like text into list of rows.

    Args:
        table_text (str): Text containing table-like data

    Returns:
        List of rows (each row is a list of cells)
    """
    import re
    rows = []
    lines = table_text.strip().split("\n")

    for line in lines:
        # Skip empty lines
        if not line.strip():
            continue

        # Split by whitespace or common separators
        cells = re.split(r"\s{2,}|\t+|\|+", line.strip())
        # Clean up cells
        cells = [cell.strip() for cell in cells if cell.strip()]

        if cells:
            rows.append(cells)

    # Print parsing result for debugging
    print(f"Parsed {len(rows)} rows from table text")
    if rows:
        print(f"First row (columns: {len(rows[0])}): {rows[0]}")

    return rows


def safe_float(text):
    """
    Safely convert text to float, handling various formats.

    Args:
        text (str): Text to convert

    Returns:
        float or None if conversion fails
    """
    import re
    if not text:
        return None

    # Remove currency symbols, commas in numbers, etc.
    text = re.sub(r"[^0-9.-]", "", str(text))

    try:
        return float(text)
    except ValueError:
        print(f"Warning: Could not convert '{text}' to float")
        return None
'''

    def _analyze_code_safety(self, code: str) -> Dict[str, Any]:
        """
        Statically screen ``code`` for banned imports and banned built-in calls.

        Args:
            code: Python source to inspect.

        Returns:
            ``{"safe": True}`` when nothing objectionable is found, otherwise
            ``{"safe": False, "reason": <str>}``. Source that does not parse
            is reported as unsafe with reason ``"Invalid Python syntax"``.
            Banned imports take precedence over banned calls in the reason.
        """
        try:
            parsed = ast.parse(code)
        except SyntaxError:
            return {"safe": False, "reason": "Invalid Python syntax"}

        imported_modules: List[str] = []
        banned_call_reason: Optional[str] = None

        for node in ast.walk(parsed):
            if isinstance(node, ast.Import):
                imported_modules.extend(alias.name for alias in node.names)
            elif isinstance(node, ast.ImportFrom):
                # node.module is None for "from . import x"
                if node.module:
                    imported_modules.append(node.module)
            elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                if banned_call_reason is None and node.func.id in _BANNED_CALLS:
                    banned_call_reason = _BANNED_CALLS[node.func.id]

        # NOTE(review): this is a substring match, so e.g. "cosmos" is rejected
        # because it contains "os". It is kept as-is because it also catches
        # relatives such as "posix"; narrow it only with care.
        dangerous_imports = [
            module
            for module in imported_modules
            if module and any(banned in module for banned in self.banned_modules)
        ]
        if dangerous_imports:
            return {
                "safe": False,
                "reason": f"Potentially harmful imports detected: {dangerous_imports}",
            }

        if banned_call_reason is not None:
            return {"safe": False, "reason": banned_call_reason}

        return {"safe": True}

    def _timeout_handler(self, signum, frame):
        """SIGALRM handler: abort the running snippet by raising ``TimeoutError``."""
        raise TimeoutError(f"Code execution timed out after {self.timeout} seconds")

    @staticmethod
    def _to_number(token: str) -> Union[int, float]:
        """
        Convert ``token`` to ``float`` if it has a dot or exponent, else ``int``.

        Raises:
            ValueError: If ``token`` is not a valid number.
        """
        looks_like_float = "." in token or "e" in token.lower()
        return float(token) if looks_like_float else int(token)

    def _extract_numeric_value(self, output: str) -> Optional[Union[int, float]]:
        """
        Extract the last numeric value printed in ``output``.

        Lines are scanned from the bottom up. A line that is a number by itself
        wins; otherwise a number at the very end of the line is used. Exponent
        notation without a dot (``1e5``) is parsed as a float.

        Returns:
            The number found, or ``None`` if ``output`` is empty or no line
            ends in a number.
        """
        if not output:
            return None

        for raw_line in reversed(output.strip().split("\n")):
            line = raw_line.strip()
            try:
                return self._to_number(line)
            except ValueError:
                pass  # not a bare number; look for a trailing one instead

            match = _TRAILING_NUMBER_RE.search(line)
            if match:
                try:
                    return self._to_number(match.group(0))
                except ValueError:
                    pass

        return None

    def forward(self, code_string: Optional[str] = None, filepath: Optional[str] = None) -> Dict[str, Any]:
        """
        Main entry point: run code given inline or read from a ``.py`` file.

        Exactly one of ``code_string`` and ``filepath`` must be provided.

        Args:
            code_string: Python source to execute.
            filepath: Path to an existing ``.py`` file to execute.

        Returns:
            On success ``{"success": True, "output": <stdout>, "result_value": <value>}``;
            on any failure ``{"success": False, "error": <message>}``.
        """
        if not code_string and not filepath:
            return {"success": False, "error": "No code string or filepath provided."}
        if code_string and filepath:
            return {"success": False, "error": "Provide either a code string or a filepath, not both."}

        if filepath:
            if not os.path.exists(filepath):
                return {"success": False, "error": f"File not found: {filepath}"}
            if not filepath.endswith(".py"):
                return {"success": False, "error": f"File is not a Python file: {filepath}"}
            try:
                # Python source is UTF-8 by default; do not depend on the locale.
                with open(filepath, "r", encoding="utf-8") as file:
                    code_to_execute = file.read()
            except Exception as e:
                return {"success": False, "error": f"Error reading file {filepath}: {str(e)}"}
        else:
            code_to_execute = code_string

        # Inject utility functions
        enhanced_code = self._utility_functions + "\n\n" + code_to_execute
        return self._execute_actual_code(enhanced_code)

    @staticmethod
    def _alarm_available() -> bool:
        """Return True when a SIGALRM-based timeout can be installed here.

        SIGALRM does not exist on Windows, and Python only allows installing
        signal handlers from the main thread.
        """
        return hasattr(signal, "SIGALRM") and threading.current_thread() is threading.main_thread()

    def _execute_actual_code(self, code: str) -> Dict[str, Any]:
        """
        Run ``code`` after the safety check, capturing stdout.

        The snippet runs in a single fresh namespace, so functions defined in
        it (including the injected helpers) can see each other. When SIGALRM is
        available the run is interrupted after ``self.timeout`` seconds;
        otherwise a warning is logged and the snippet runs without a timeout.

        Returns:
            ``{"success": True, "output": str, "result_value": Any}`` on
            success. ``result_value`` is the first of the variables
            result/answer/output/value/final_result/data defined by the
            snippet, falling back to the last number printed.
            ``{"success": False, "error": str}`` on safety-check failure,
            timeout, or any exception raised by the snippet.
        """
        safety_check = self._analyze_code_safety(code)
        if not safety_check["safe"]:
            return {"success": False, "error": f"Safety check failed: {safety_check['reason']}"}

        stdout_buffer = io.StringIO()
        # One dict for globals and locals: with separate dicts, top-level
        # definitions land in locals and are invisible from function bodies.
        namespace: Dict[str, Any] = {}

        use_alarm = self._alarm_available()
        if not use_alarm:
            logger.warning(
                "SIGALRM is unavailable (non-Unix platform or non-main thread); "
                "executing code without timeout protection."
            )
        previous_handler = None

        try:
            if use_alarm:
                previous_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
                signal.alarm(self.timeout)

            with contextlib.redirect_stdout(stdout_buffer):
                exec(code, namespace)
        except TimeoutError:
            return {"success": False, "error": f"Code execution timed out after {self.timeout} seconds"}
        except (Exception, SystemExit) as e:
            # SystemExit is caught so a snippet cannot terminate the host process.
            trace = traceback.format_exc()
            error_msg = f"Error executing code: {str(e)}\n{trace}"
            return {"success": False, "error": error_msg}
        finally:
            if use_alarm:
                # Cancel any pending alarm and put the caller's handler back.
                signal.alarm(0)
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

        # Try to extract the result from common variable names
        result_value = None
        for var_name in _RESULT_VARIABLE_NAMES:
            if var_name in namespace:
                result_value = namespace[var_name]
                break

        output = stdout_buffer.getvalue()
        if len(output) > self.max_output_size:
            output = (
                output[: self.max_output_size]
                + f"\n... (output truncated, exceeded {self.max_output_size} characters)"
            )

        # If no result_value was found, try to extract a numeric value from the output
        if result_value is None:
            result_value = self._extract_numeric_value(output)

        return {"success": True, "output": output, "result_value": result_value}

    # Helper methods for backward compatibility
    def execute_file(self, filepath: str) -> Dict[str, Any]:
        """Helper to execute Python code from file."""
        return self.forward(filepath=filepath)

    def execute_code(self, code: str) -> Dict[str, Any]:
        """Helper to execute Python code from a string."""
        return self.forward(code_string=code)


def _run_tests():
    """Run a manual smoke-test suite for CodeExecutionTool and print a summary."""
    tool = CodeExecutionTool(timeout=5)
    test_results = []

    # Test 1: Safe code string
    safe_code = "print('Hello from safe code!'); result = 10 * 2; print(result)"
    print("\n--- Test 1: Safe Code String ---")
    result1 = tool.forward(code_string=safe_code)
    print(result1)
    test_results.append(result1["success"] and "Hello from safe code!" in result1["output"])

    # Test 2: Code with an error
    error_code = "print(1/0)"
    print("\n--- Test 2: Code with Error ---")
    result2 = tool.forward(code_string=error_code)
    print(result2)
    test_results.append(not result2["success"] and "ZeroDivisionError" in result2["error"])

    # Test 3: Code with a banned import
    unsafe_import_code = "import os; print(os.getcwd())"
    print("\n--- Test 3: Unsafe Import ---")
    result3 = tool.forward(code_string=unsafe_import_code)
    print(result3)
    test_results.append(not result3["success"] and "Safety check failed" in result3["error"])

    # Test 4: Timeout
    timeout_code = "import time; time.sleep(10); print('Done sleeping')"
    print("\n--- Test 4: Timeout ---")
    result4 = tool.forward(code_string=timeout_code)
    print(result4)
    test_results.append(not result4["success"] and "timed out" in result4["error"])

    # Test 5: Execute from file
    test_file_content = "print('Hello from file!'); x = 5; y = 7; print(f'Sum: {x+y}')"
    test_filename = "temp_test_script.py"
    with open(test_filename, "w") as f:
        f.write(test_file_content)
    print("\n--- Test 5: Execute from File ---")
    try:
        result5 = tool.forward(filepath=test_filename)
        print(result5)
        test_results.append(result5["success"] and "Hello from file!" in result5["output"])
    finally:
        # Remove the scratch file even if the execution above blows up.
        os.remove(test_filename)

    # Test 6: File not found
    print("\n--- Test 6: File Not Found ---")
    result6 = tool.forward(filepath="non_existent_script.py")
    print(result6)
    test_results.append(not result6["success"] and "File not found" in result6["error"])

    # Test 7: Provide both code_string and filepath
    print("\n--- Test 7: Both code_string and filepath ---")
    result7 = tool.forward(code_string="print('hello')", filepath="dummy.py")
    print(result7)
    test_results.append(
        not result7["success"] and "Provide either a code string or a filepath, not both" in result7["error"]
    )

    # Test 8: Provide neither
    print("\n--- Test 8: Neither code_string nor filepath ---")
    result8 = tool.forward()
    print(result8)
    test_results.append(not result8["success"] and "No code string or filepath provided" in result8["error"])

    # Test 9: Function definition and call
    # (the call must be on its own line; after "return ...;" it would be dead
    # code inside the function body)
    func_def_code = "def my_func(a, b):\n    return a + b\n\nprint(my_func(3, 4))"
    print("\n--- Test 9: Function Definition and Call ---")
    result9 = tool.forward(code_string=func_def_code)
    print(result9)
    test_results.append(result9["success"] and "7" in result9["output"])

    # Test 10: A user function calling an injected helper (needs a shared namespace)
    helper_code = "def parse(s):\n    return safe_float(s)\n\nresult = parse('$1,234.50')\nprint(result)"
    print("\n--- Test 10: User Function Calling Utility ---")
    result10 = tool.forward(code_string=helper_code)
    print(result10)
    test_results.append(result10["success"] and result10.get("result_value") == 1234.5)

    print(f"\nTests passed: {sum(test_results)}/{len(test_results)}")
    if all(test_results):
        print("All tests passed!")
    else:
        print("Some tests failed - check output for details.")


if __name__ == "__main__":
    _run_tests()