|
import ast |
|
import contextlib |
|
import io |
|
import logging |
|
import os |
|
import re |
|
import signal |
|
import traceback |
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
from smolagents.tools import Tool |
|
|
|
|
|
# Configure the root logger at import time with INFO as the threshold level.
logging.basicConfig(level=logging.INFO)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
class CodeExecutionTool(Tool):
    """
    Executes Python code snippets with a static safety screen and a timeout.

    Useful for data processing, analysis, and transformation. A small set of
    text-processing helpers (``extract_pattern``, ``clean_text``,
    ``parse_table_text``, ``safe_float``) is prepended to every snippet so the
    executed code can call them directly.

    NOTE(review): the safety screen is a denylist over import names plus a
    check for direct ``exec``/``eval`` calls. It is a heuristic, not a sandbox;
    the snippet still runs in-process through ``exec``.
    """

    name = "python_executor"
    description = "Safely executes Python code with enhancements for data processing, parsing, and error recovery."

    inputs = {
        "code_string": {"type": "string", "description": "The Python code to execute.", "nullable": True},
        "filepath": {"type": "string", "description": "Path to a Python file to execute.", "nullable": True},
    }

    outputs = {
        "success": {"type": "boolean", "description": "Whether the code executed successfully."},
        "output": {"type": "string", "description": "The captured stdout or formatted result.", "nullable": True},
        "error": {"type": "string", "description": "Error message if execution failed.", "nullable": True},
        "result_value": {"type": "any", "description": "The final expression value if applicable.", "nullable": True},
    }

    output_type = "object"

    def __init__(self, timeout: int = 10, max_output_size: int = 20000, *args, **kwargs):
        """
        Create the tool.

        Args:
            timeout: Seconds passed to ``signal.alarm`` before execution is aborted.
            max_output_size: Maximum number of captured stdout characters returned.
            *args: Forwarded to the base ``Tool`` constructor.
            **kwargs: Forwarded to the base ``Tool`` constructor.
        """
        super().__init__(*args, **kwargs)
        self.timeout = timeout
        self.max_output_size = max_output_size
        # Names screened by _analyze_code_safety (matched as substrings of the
        # imported module path).
        self.banned_modules = [
            "os",
            "subprocess",
            "sys",
            "builtins",
            "importlib",
            "pickle",
            "requests",
            "socket",
            "shutil",
            "ctypes",
            "multiprocessing",
        ]
        self.is_initialized = True
        self._utility_functions = self._get_utility_functions()

    def _get_utility_functions(self) -> str:
        """
        Return the source of the helper functions prepended to executed code.

        The literal is a *raw* string on purpose: the helper source contains
        regex patterns and escaped quotes that must reach ``exec`` unchanged.
        In a non-raw literal ``\\'`` and ``\\"`` collapse to bare quotes, which
        terminates the inner pattern string early and makes the whole prelude
        a syntax error. The content starts at column 0 because it is executed
        as module-level code.
        """
        return r'''
def extract_pattern(text, pattern, group=0, all_matches=False):
    """
    Extract data using regex pattern from text.

    Args:
        text (str): Text to search in
        pattern (str): Regex pattern to use
        group (int): Capture group to return (default 0 - entire match)
        all_matches (bool): If True, return all matches, otherwise just first

    Returns:
        Matched string(s) or None if no match
    """
    import re
    if not text or not pattern:
        print("Warning: Empty text or pattern provided to extract_pattern")
        return None

    try:
        matches = re.finditer(pattern, text)
        results = [m.group(group) if group < len(m.groups())+1 else m.group(0) for m in matches]

        if not results:
            print(f"No matches found for pattern '{pattern}'")
            return None

        return results if all_matches else results[0]
    except Exception as e:
        print(f"Error in extract_pattern: {e}")
        return None


def clean_text(text, remove_extra_whitespace=True, remove_special_chars=False):
    """
    Clean text by removing extra whitespace and optionally special characters.

    Args:
        text (str): Text to clean
        remove_extra_whitespace (bool): If True, replace multiple spaces with single space
        remove_special_chars (bool): If True, remove special characters

    Returns:
        Cleaned string
    """
    import re
    if not text:
        return ""

    # Replace newlines and tabs with spaces
    text = re.sub(r"[\n\t\r]+", " ", text)

    if remove_special_chars:
        # Keep only alphanumeric, spaces, and basic punctuation
        text = re.sub(r"[^\w\s.,;:!?\'\"()-]", "", text)

    if remove_extra_whitespace:
        # Replace multiple spaces with single space
        text = re.sub(r"\s+", " ", text)

    return text.strip()


def parse_table_text(table_text):
    """
    Parse table-like text into list of rows.

    Args:
        table_text (str): Text containing table-like data

    Returns:
        List of rows (each row is a list of cells)
    """
    import re

    rows = []
    lines = table_text.strip().split("\n")

    for line in lines:
        # Skip empty lines
        if not line.strip():
            continue

        # Split by whitespace or common separators
        cells = re.split(r"\s{2,}|\t+|\|+", line.strip())
        # Clean up cells
        cells = [cell.strip() for cell in cells if cell.strip()]

        if cells:
            rows.append(cells)

    # Print parsing result for debugging
    print(f"Parsed {len(rows)} rows from table text")
    if rows:
        print(f"First row (columns: {len(rows[0])}): {rows[0]}")

    return rows


def safe_float(text):
    """
    Safely convert text to float, handling various formats.

    Args:
        text (str): Text to convert

    Returns:
        float or None if conversion fails
    """
    import re

    # Only None and the empty string mean "no value"; a numeric 0 is valid input.
    if text is None or text == "":
        return None

    # Remove currency symbols, commas in numbers, etc.
    text = re.sub(r"[^0-9.-]", "", str(text))

    try:
        return float(text)
    except ValueError:
        print(f"Warning: Could not convert '{text}' to float")
        return None
'''

    def _analyze_code_safety(self, code: str) -> Dict[str, Any]:
        """
        Statically screen ``code`` before it is executed.

        Args:
            code: Python source to inspect.

        Returns:
            ``{"safe": True}`` when nothing was flagged, otherwise
            ``{"safe": False, "reason": <str>}``. Flagged conditions, in
            priority order: unparsable source, an import whose module path
            contains a banned name, a direct call to ``exec`` or ``eval``.
        """
        try:
            parsed = ast.parse(code)
        except (SyntaxError, ValueError):
            # ValueError covers sources ast.parse rejects outright on some
            # Python versions (e.g. embedded NUL bytes).
            return {"safe": False, "reason": "Invalid Python syntax"}

        imported_modules: List[str] = []
        calls_exec_or_eval = False
        for node in ast.walk(parsed):
            if isinstance(node, ast.Import):
                imported_modules.extend(alias.name for alias in node.names)
            elif isinstance(node, ast.ImportFrom):
                # node.module is None for "from . import x".
                if node.module:
                    imported_modules.append(node.module)
            elif isinstance(node, ast.Call):
                if isinstance(node.func, ast.Name) and node.func.id in ("exec", "eval"):
                    calls_exec_or_eval = True

        # NOTE(review): substring matching is deliberately kept. It can reject
        # harmless modules whose name merely contains a banned name, but it
        # also blocks platform aliases such as "posix"; tightening it would
        # loosen the filter.
        dangerous_imports = [
            module
            for module in imported_modules
            if module and any(banned in module for banned in self.banned_modules)
        ]
        if dangerous_imports:
            return {
                "safe": False,
                "reason": f"Potentially harmful imports detected: {dangerous_imports}",
            }

        if calls_exec_or_eval:
            return {"safe": False, "reason": "Contains exec() or eval() calls"}

        return {"safe": True}

    def _timeout_handler(self, signum, frame):
        """Signal handler for SIGALRM: abort the running snippet with TimeoutError."""
        raise TimeoutError(f"Code execution timed out after {self.timeout} seconds")

    @staticmethod
    def _parse_number(text: str) -> Optional[Union[int, float]]:
        """Parse ``text`` as int, or as float when it has a '.' or an exponent; None on failure."""
        try:
            if any(marker in text for marker in ".eE"):
                return float(text)
            return int(text)
        except ValueError:
            return None

    def _extract_numeric_value(self, output: str) -> Optional[Union[int, float]]:
        """
        Extract the last numeric value printed in ``output``.

        Lines are scanned from the bottom up. A line that is a number on its
        own is returned directly; otherwise a number at the very end of the
        line is used. Exponent notation such as ``1e5`` is parsed as a float.

        Args:
            output: Captured stdout text.

        Returns:
            The number found, or None when ``output`` is empty or no line
            ends in a number.
        """
        if not output:
            return None

        for raw_line in reversed(output.strip().split("\n")):
            line = raw_line.strip()
            value = self._parse_number(line)
            if value is None:
                match = re.search(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?$", line)
                if match:
                    value = self._parse_number(match.group(0))
            if value is not None:
                return value
        return None

    def forward(self, code_string: Optional[str] = None, filepath: Optional[str] = None) -> Dict[str, Any]:
        """
        Main entry point: execute either a code string or a ``.py`` file.

        Args:
            code_string: Python source to execute.
            filepath: Path to a ``.py`` file to execute.

        Returns:
            On success ``{"success": True, "output": ..., "result_value": ...}``;
            on failure ``{"success": False, "error": ...}``. Exactly one of the
            two arguments must be provided.
        """
        if not code_string and not filepath:
            return {"success": False, "error": "No code string or filepath provided."}
        if code_string and filepath:
            return {"success": False, "error": "Provide either a code string or a filepath, not both."}

        if filepath:
            if not os.path.exists(filepath):
                return {"success": False, "error": f"File not found: {filepath}"}
            if not filepath.endswith(".py"):
                return {"success": False, "error": f"File is not a Python file: {filepath}"}
            try:
                # Python source files are UTF-8 by default; do not depend on the locale.
                with open(filepath, "r", encoding="utf-8") as file:
                    code_to_execute = file.read()
            except (OSError, UnicodeError) as e:
                return {"success": False, "error": f"Error reading file {filepath}: {str(e)}"}
        else:
            code_to_execute = code_string

        # Prepend the helper functions so the snippet can call them directly.
        enhanced_code = self._utility_functions + "\n\n" + code_to_execute
        return self._execute_actual_code(enhanced_code)

    def _execute_actual_code(self, code: str) -> Dict[str, Any]:
        """
        Execute ``code`` and capture its stdout, result value, or error.

        The code is first screened by ``_analyze_code_safety``. Execution is
        bounded by ``signal.alarm(self.timeout)``; where SIGALRM cannot be
        installed (it is Unix-only and must be set from the main thread) the
        failure is reported as an error result and the code is not run.

        Args:
            code: Full source to execute (helpers already prepended).

        Returns:
            ``{"success": True, "output": str, "result_value": Any}`` or
            ``{"success": False, "error": str}``.
        """
        safety_check = self._analyze_code_safety(code)
        if not safety_check["safe"]:
            return {"success": False, "error": f"Safety check failed: {safety_check['reason']}"}

        stdout_buffer = io.StringIO()
        result_value = None
        handler_installed = False
        previous_handler = None

        try:
            previous_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
            handler_installed = True
            signal.alarm(self.timeout)

            # A single namespace serves as both globals and locals. With two
            # separate dicts the code runs as if inside a class body, so
            # functions it defines cannot see other top-level names (helpers,
            # imports, sibling functions).
            # SECURITY: this runs caller-supplied code in-process; the static
            # screen above is best-effort only.
            namespace: Dict[str, Any] = {}
            with contextlib.redirect_stdout(stdout_buffer):
                exec(code, namespace)
            signal.alarm(0)

            # Pick up a conventionally named result variable, if the code set one.
            for var_name in ("result", "answer", "output", "value", "final_result", "data"):
                if var_name in namespace:
                    result_value = namespace[var_name]
                    break

            output = stdout_buffer.getvalue()
            if len(output) > self.max_output_size:
                output = output[: self.max_output_size] + f"\n... (output truncated, exceeded {self.max_output_size} characters)"

            # Fall back to the last number printed when no result variable was set.
            if result_value is None:
                result_value = self._extract_numeric_value(output)

            return {"success": True, "output": output, "result_value": result_value}

        except TimeoutError:
            return {"success": False, "error": f"Code execution timed out after {self.timeout} seconds"}
        except SystemExit as exc:
            # exit()/quit()/raise SystemExit in the snippet must not terminate the host.
            return {"success": False, "error": f"Error executing code: code attempted to exit (SystemExit: {exc.code})"}
        except Exception as e:
            trace = traceback.format_exc()
            error_msg = f"Error executing code: {str(e)}\n{trace}"
            return {"success": False, "error": error_msg}
        finally:
            if handler_installed:
                # Always cancel a pending alarm, then put back whatever handler
                # was active before this call.
                signal.alarm(0)
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

    def execute_file(self, filepath: str) -> Dict[str, Any]:
        """Helper to execute Python code from file."""
        return self.forward(filepath=filepath)

    def execute_code(self, code: str) -> Dict[str, Any]:
        """Helper to execute Python code from a string."""
        return self.forward(code_string=code)
|
|
|
|
|
def _run_tests():
    """
    Run a self-check of CodeExecutionTool and print a pass/fail summary.

    Each scenario prints the raw result dict and appends one boolean to
    ``test_results``; nothing is raised on failure.
    """
    tool = CodeExecutionTool(timeout=5)
    test_results = []

    # Test 1: plain snippet that prints and sets a result variable.
    safe_code = "print('Hello from safe code!'); result = 10 * 2; print(result)"
    print("\n--- Test 1: Safe Code String ---")
    result1 = tool.forward(code_string=safe_code)
    print(result1)
    test_results.append(result1["success"] and "Hello from safe code!" in result1["output"])

    # Test 2: runtime error is reported with its traceback.
    error_code = "print(1/0)"
    print("\n--- Test 2: Code with Error ---")
    result2 = tool.forward(code_string=error_code)
    print(result2)
    test_results.append(not result2["success"] and "ZeroDivisionError" in result2["error"])

    # Test 3: banned import is rejected by the static screen.
    unsafe_import_code = "import os; print(os.getcwd())"
    print("\n--- Test 3: Unsafe Import ---")
    result3 = tool.forward(code_string=unsafe_import_code)
    print(result3)
    test_results.append(not result3["success"] and "Safety check failed" in result3["error"])

    # Test 4: snippet sleeping longer than the 5 s timeout is aborted.
    timeout_code = "import time; time.sleep(10); print('Done sleeping')"
    print("\n--- Test 4: Timeout ---")
    result4 = tool.forward(code_string=timeout_code)
    print(result4)
    test_results.append(not result4["success"] and "timed out" in result4["error"])

    # Test 5: execution from a file on disk.
    test_file_content = "print('Hello from file!'); x = 5; y = 7; print(f'Sum: {x+y}')"
    test_filename = "temp_test_script.py"
    with open(test_filename, "w", encoding="utf-8") as f:
        f.write(test_file_content)
    print("\n--- Test 5: Execute from File ---")
    try:
        result5 = tool.forward(filepath=test_filename)
        print(result5)
        test_results.append(result5["success"] and "Hello from file!" in result5["output"])
    finally:
        # Remove the temporary script even if execution raised.
        os.remove(test_filename)

    # Test 6: missing file.
    print("\n--- Test 6: File Not Found ---")
    result6 = tool.forward(filepath="non_existent_script.py")
    print(result6)
    test_results.append(not result6["success"] and "File not found" in result6["error"])

    # Test 7: both inputs supplied.
    print("\n--- Test 7: Both code_string and filepath ---")
    result7 = tool.forward(code_string="print('hello')", filepath="dummy.py")
    print(result7)
    test_results.append(
        not result7["success"]
        and "Provide either a code string or a filepath, not both" in result7["error"]
    )

    # Test 8: no input supplied.
    print("\n--- Test 8: Neither code_string nor filepath ---")
    result8 = tool.forward()
    print(result8)
    test_results.append(not result8["success"] and "No code string or filepath provided" in result8["error"])

    # Test 9: define a function, then call it. The call must be on its own
    # line: written after "return ...;" on the def line it would be part of
    # the function body and never run.
    func_def_code = "def my_func(a, b):\n    return a + b\nprint(my_func(3, 4))"
    print("\n--- Test 9: Function Definition and Call ---")
    result9 = tool.forward(code_string=func_def_code)
    print(result9)
    test_results.append(result9["success"] and "7" in result9["output"])

    print(f"\nTests passed: {sum(test_results)}/{len(test_results)}")
    if all(test_results):
        print("All tests passed!")
    else:
        print("Some tests failed - check output for details.")
|
|
|
|
|
# Run the self-check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    _run_tests()