Yago Bolivar commited on
Commit
224d111
·
1 Parent(s): b09a8ba

feat: Refactor CodeExecutionTool for improved readability and maintainability

Browse files
Files changed (1) hide show
  1. src/python_tool.py +143 -152
src/python_tool.py CHANGED
@@ -1,36 +1,42 @@
1
  import ast
2
  import contextlib
3
  import io
4
- import signal
 
5
  import re
 
6
  import traceback
7
- from typing import Dict, Any, Optional, Union, List
 
8
  from smolagents.tools import Tool
9
- import os
10
- import logging
11
 
12
  # Set up logging
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
 
16
  class CodeExecutionTool(Tool):
17
  """
18
  Executes Python code snippets safely with timeout protection.
19
  Useful for data processing, analysis, and transformation.
20
  Includes special utilities for web data processing and robust error handling.
21
  """
 
22
  name = "python_executor"
23
  description = "Safely executes Python code with enhancements for data processing, parsing, and error recovery."
 
24
  inputs = {
25
- 'code_string': {'type': 'string', 'description': 'The Python code to execute.', 'nullable': True},
26
- 'filepath': {'type': 'string', 'description': 'Path to a Python file to execute.', 'nullable': True}
27
  }
 
28
  outputs = {
29
- 'success': {'type': 'boolean', 'description': 'Whether the code executed successfully.'},
30
- 'output': {'type': 'string', 'description': 'The captured stdout or formatted result.', 'nullable': True},
31
- 'error': {'type': 'string', 'description': 'Error message if execution failed.', 'nullable': True},
32
- 'result_value': {'type': 'any', 'description': 'The final expression value if applicable.', 'nullable': True}
33
  }
 
34
  output_type = "object"
35
 
36
  def __init__(self, timeout: int = 10, max_output_size: int = 20000, *args, **kwargs):
@@ -38,25 +44,34 @@ class CodeExecutionTool(Tool):
38
  self.timeout = timeout
39
  self.max_output_size = max_output_size
40
  self.banned_modules = [
41
- 'os', 'subprocess', 'sys', 'builtins', 'importlib',
42
- 'pickle', 'requests', 'socket', 'shutil', 'ctypes', 'multiprocessing'
 
 
 
 
 
 
 
 
 
43
  ]
44
  self.is_initialized = True
45
- # Add utility functions that will be available to executed code
46
  self._utility_functions = self._get_utility_functions()
47
 
48
- def _get_utility_functions(self):
49
- """Define utility functions that will be available in the executed code"""
50
- utility_code = """
51
- # Utility functions for web data processing
52
  def extract_pattern(text, pattern, group=0, all_matches=False):
53
  """
54
- "Extract data using regex pattern from text.
 
55
  Args:
56
  text (str): Text to search in
57
  pattern (str): Regex pattern to use
58
  group (int): Capture group to return (default 0 - entire match)
59
  all_matches (bool): If True, return all matches, otherwise just first
 
60
  Returns:
61
  Matched string(s) or None if no match
62
  """
@@ -64,108 +79,117 @@ def extract_pattern(text, pattern, group=0, all_matches=False):
64
  if not text or not pattern:
65
  print("Warning: Empty text or pattern provided to extract_pattern")
66
  return None
67
-
68
  try:
69
  matches = re.finditer(pattern, text)
70
  results = [m.group(group) if group < len(m.groups())+1 else m.group(0) for m in matches]
71
-
72
  if not results:
73
  print(f"No matches found for pattern '{pattern}'")
74
  return None
75
-
76
- if all_matches:
77
- return results
78
- else:
79
- return results[0]
80
  except Exception as e:
81
  print(f"Error in extract_pattern: {e}")
82
  return None
83
 
 
84
  def clean_text(text, remove_extra_whitespace=True, remove_special_chars=False):
85
  """
86
  Clean text by removing extra whitespace and optionally special characters.
 
87
  Args:
88
  text (str): Text to clean
89
  remove_extra_whitespace (bool): If True, replace multiple spaces with single space
90
  remove_special_chars (bool): If True, remove special characters
 
91
  Returns:
92
  Cleaned string
93
  """
94
  import re
95
  if not text:
96
  return ""
97
-
98
  # Replace newlines and tabs with spaces
99
- text = re.sub(r'[\\n\\t\\r]+', ' ', text)
100
-
101
  if remove_special_chars:
102
  # Keep only alphanumeric, spaces, and basic punctuation
103
- text = re.sub(r'[^\\w\\s.,;:!?\'"()-]', '', text)
104
-
105
  if remove_extra_whitespace:
106
  # Replace multiple spaces with single space
107
- text = re.sub(r'\\s+', ' ', text)
108
-
109
  return text.strip()
110
-
 
111
  def parse_table_text(table_text):
112
  """
113
- Parse table-like text into list of rows
 
114
  Args:
115
  table_text (str): Text containing table-like data
 
116
  Returns:
117
  List of rows (each row is a list of cells)
118
  """
 
 
119
  rows = []
120
- lines = table_text.strip().split('\\n')
121
-
122
  for line in lines:
123
  # Skip empty lines
124
  if not line.strip():
125
  continue
126
-
127
  # Split by whitespace or common separators
128
- cells = re.split(r'\\s{2,}|\\t+|\\|+', line.strip())
129
  # Clean up cells
130
  cells = [cell.strip() for cell in cells if cell.strip()]
131
-
132
  if cells:
133
  rows.append(cells)
134
-
135
  # Print parsing result for debugging
136
  print(f"Parsed {len(rows)} rows from table text")
137
  if rows and len(rows) > 0:
138
  print(f"First row (columns: {len(rows[0])}): {rows[0]}")
139
-
140
  return rows
141
 
 
142
  def safe_float(text):
143
  """
144
  Safely convert text to float, handling various formats.
 
145
  Args:
146
  text (str): Text to convert
 
147
  Returns:
148
  float or None if conversion fails
149
  """
 
 
150
  if not text:
151
  return None
152
-
153
  # Remove currency symbols, commas in numbers, etc.
154
- text = re.sub(r'[^0-9.-]', '', str(text))
155
-
156
  try:
157
  return float(text)
158
  except ValueError:
159
  print(f"Warning: Could not convert '{text}' to float")
160
  return None
161
- """
162
- return utility_code
163
 
164
  def _analyze_code_safety(self, code: str) -> Dict[str, Any]:
165
  """Perform static analysis to check for potentially harmful code."""
166
  try:
167
  parsed = ast.parse(code)
168
-
169
  # Check for banned imports
170
  imports = []
171
  for node in ast.walk(parsed):
@@ -175,26 +199,26 @@ def safe_float(text):
175
  # Ensure node.module is not None before attempting to check against banned_modules
176
  if node.module and any(banned in node.module for banned in self.banned_modules):
177
  imports.append(node.module)
178
-
179
- dangerous_imports = [imp for imp in imports if imp and any(
180
- banned in imp for banned in self.banned_modules)]
181
-
 
 
182
  if dangerous_imports:
183
  return {
184
- "safe": False,
185
- "reason": f"Potentially harmful imports detected: {dangerous_imports}"
186
  }
187
-
188
  # Check for exec/eval usage
189
  for node in ast.walk(parsed):
190
- if isinstance(node, ast.Call) and hasattr(node, 'func'):
191
- if isinstance(node.func, ast.Name) and node.func.id in ['exec', 'eval']:
192
- return {
193
- "safe": False,
194
- "reason": "Contains exec() or eval() calls"
195
- }
196
-
197
  return {"safe": True}
 
198
  except SyntaxError:
199
  return {"safe": False, "reason": "Invalid Python syntax"}
200
 
@@ -206,109 +230,97 @@ def safe_float(text):
206
  """Extract the final numeric value from output."""
207
  if not output:
208
  return None
209
-
210
  # Look for the last line that contains a number
211
- lines = output.strip().split('\n')
212
  for line in reversed(lines):
213
  # Try to interpret it as a pure number
214
  line = line.strip()
215
  try:
216
- if '.' in line:
217
- return float(line)
218
- else:
219
- return int(line)
220
  except ValueError:
221
  # Not a pure number, try to extract numbers with regex
222
- match = re.search(r'[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?$', line)
223
  if match:
224
  num_str = match.group(0)
225
  try:
226
- if '.' in num_str:
227
- return float(num_str)
228
- else:
229
- return int(num_str)
230
  except ValueError:
231
  pass
232
  return None
233
 
234
  def forward(self, code_string: Optional[str] = None, filepath: Optional[str] = None) -> Dict[str, Any]:
 
235
  if not code_string and not filepath:
236
  return {"success": False, "error": "No code string or filepath provided."}
237
  if code_string and filepath:
238
  return {"success": False, "error": "Provide either a code string or a filepath, not both."}
239
 
240
  code_to_execute = ""
 
241
  if filepath:
242
  if not os.path.exists(filepath):
243
- return {"success": False, "error": f"File not found: {filepath}"}
244
  if not filepath.endswith(".py"):
245
  return {"success": False, "error": f"File is not a Python file: {filepath}"}
246
  try:
247
- with open(filepath, 'r') as file:
248
  code_to_execute = file.read()
249
  except Exception as e:
250
  return {"success": False, "error": f"Error reading file {filepath}: {str(e)}"}
251
- elif code_string:
252
  code_to_execute = code_string
253
-
254
  # Inject utility functions
255
  enhanced_code = self._utility_functions + "\n\n" + code_to_execute
256
-
257
  return self._execute_actual_code(enhanced_code)
258
 
259
  def _execute_actual_code(self, code: str) -> Dict[str, Any]:
260
  """Execute Python code and capture the output or error."""
261
  safety_check = self._analyze_code_safety(code)
262
  if not safety_check["safe"]:
263
- return {
264
- "success": False,
265
- "error": f"Safety check failed: {safety_check['reason']}"
266
- }
267
-
268
  # Capture stdout and execute the code with a timeout
269
  stdout_buffer = io.StringIO()
270
  result_value = None
271
-
272
  try:
273
  # Set timeout handler
274
  signal.signal(signal.SIGALRM, self._timeout_handler)
275
  signal.alarm(self.timeout)
276
-
277
  # Execute code and capture stdout
278
  with contextlib.redirect_stdout(stdout_buffer):
279
  # Execute the code within a new dictionary for local variables
280
  local_vars = {}
281
  exec(code, {}, local_vars)
282
-
283
  # Try to extract the result from common variable names
284
- for var_name in ['result', 'answer', 'output', 'value', 'final_result', 'data']:
285
  if var_name in local_vars:
286
  result_value = local_vars[var_name]
287
  break
288
-
289
  # Reset the alarm
290
  signal.alarm(0)
291
-
292
  # Get the captured output
293
  output = stdout_buffer.getvalue()
294
  if len(output) > self.max_output_size:
295
- output = output[:self.max_output_size] + f"\n... (output truncated, exceeded {self.max_output_size} characters)"
296
-
297
  # If no result_value was found, try to extract a numeric value from the output
298
  if result_value is None:
299
  result_value = self._extract_numeric_value(output)
300
-
301
- return {
302
- "success": True,
303
- "output": output,
304
- "result_value": result_value
305
- }
306
-
307
- except TimeoutError as e:
308
- signal.alarm(0) # Reset the alarm
309
  return {"success": False, "error": f"Code execution timed out after {self.timeout} seconds"}
310
  except Exception as e:
311
- signal.alarm(0) # Reset the alarm
312
  trace = traceback.format_exc()
313
  error_msg = f"Error executing code: {str(e)}\n{trace}"
314
  return {"success": False, "error": error_msg}
@@ -316,64 +328,48 @@ def safe_float(text):
316
  # Ensure the alarm is reset
317
  signal.alarm(0)
318
 
319
- # Kept execute_file and execute_code as helper methods if direct access is ever needed,
320
- # but they now call the main _execute_actual_code method.
321
  def execute_file(self, filepath: str) -> Dict[str, Any]:
322
  """Helper to execute Python code from file."""
323
- if not os.path.exists(filepath):
324
- return {"success": False, "error": f"File not found: {filepath}"}
325
- if not filepath.endswith(".py"):
326
- return {"success": False, "error": f"File is not a Python file: {filepath}"}
327
- try:
328
- with open(filepath, 'r') as file:
329
- code = file.read()
330
- return self._execute_actual_code(code)
331
- except Exception as e:
332
- return {"success": False, "error": f"Error reading file {filepath}: {str(e)}"}
333
 
334
  def execute_code(self, code: str) -> Dict[str, Any]:
335
  """Helper to execute Python code from a string."""
336
- return self._execute_actual_code(code)
337
 
338
 
339
- if __name__ == '__main__':
 
340
  tool = CodeExecutionTool(timeout=5)
 
341
 
342
  # Test 1: Safe code string
343
  safe_code = "print('Hello from safe code!'); result = 10 * 2; print(result)"
344
  print("\n--- Test 1: Safe Code String ---")
345
  result1 = tool.forward(code_string=safe_code)
346
  print(result1)
347
- assert result1['success']
348
- assert "Hello from safe code!" in result1['output']
349
- assert "20" in result1['output']
350
 
351
  # Test 2: Code with an error
352
  error_code = "print(1/0)"
353
  print("\n--- Test 2: Code with Error ---")
354
  result2 = tool.forward(code_string=error_code)
355
  print(result2)
356
- assert not result2['success']
357
- assert "ZeroDivisionError" in result2['error']
358
 
359
  # Test 3: Code with a banned import
360
  unsafe_import_code = "import os; print(os.getcwd())"
361
  print("\n--- Test 3: Unsafe Import ---")
362
  result3 = tool.forward(code_string=unsafe_import_code)
363
  print(result3)
364
- assert not result3['success']
365
- assert "Safety check failed" in result3['error']
366
- assert "os" in result3['error']
367
 
368
  # Test 4: Timeout
369
  timeout_code = "import time; time.sleep(10); print('Done sleeping')"
370
  print("\n--- Test 4: Timeout ---")
371
- # tool_timeout_short = CodeExecutionTool(timeout=2) # For testing timeout specifically
372
- # result4 = tool_timeout_short.forward(code_string=timeout_code)
373
- result4 = tool.forward(code_string=timeout_code) # Using the main tool instance with its timeout
374
  print(result4)
375
- assert not result4['success']
376
- assert "timed out" in result4['error']
377
 
378
  # Test 5: Execute from file
379
  test_file_content = "print('Hello from file!'); x = 5; y = 7; print(f'Sum: {x+y}')"
@@ -383,48 +379,43 @@ if __name__ == '__main__':
383
  print("\n--- Test 5: Execute from File ---")
384
  result5 = tool.forward(filepath=test_filename)
385
  print(result5)
386
- assert result5['success']
387
- assert "Hello from file!" in result5['output']
388
- assert "Sum: 12" in result5['output']
389
  os.remove(test_filename)
390
 
391
  # Test 6: File not found
392
  print("\n--- Test 6: File Not Found ---")
393
  result6 = tool.forward(filepath="non_existent_script.py")
394
  print(result6)
395
- assert not result6['success']
396
- assert "File not found" in result6['error']
397
 
398
  # Test 7: Provide both code_string and filepath
399
  print("\n--- Test 7: Both code_string and filepath ---")
400
- result7 = tool.forward(code_string="print('hello')", filepath=test_filename)
401
  print(result7)
402
- assert not result7['success']
403
- assert "Provide either a code string or a filepath, not both" in result7['error']
 
 
404
 
405
  # Test 8: Provide neither
406
  print("\n--- Test 8: Neither code_string nor filepath ---")
407
  result8 = tool.forward()
408
  print(result8)
409
- assert not result8['success']
410
- assert "No code string or filepath provided" in result8['error']
411
 
412
- # Test 9: Code that defines a function and calls it
413
  func_def_code = "def my_func(a, b): return a + b; print(my_func(3,4))"
414
  print("\n--- Test 9: Function Definition and Call ---")
415
  result9 = tool.forward(code_string=func_def_code)
416
  print(result9)
417
- assert result9['success']
418
- assert "7" in result9['output']
419
-
420
- # Test 10: Max output size
421
- # tool_max_output = CodeExecutionTool(max_output_size=50)
422
- # long_output_code = "for i in range(20): print(f'Line {i}')"
423
- # print("\n--- Test 10: Max Output Size ---")
424
- # result10 = tool_max_output.forward(code_string=long_output_code)
425
- # print(result10)
426
- # assert result10['success']
427
- # assert "... [output truncated]" in result10['output']
428
- # assert len(result10['output']) <= 50 + len("... [output truncated]") + 5 # a bit of leeway
429
-
430
- print("\nAll tests seem to have passed (check output for details).")
 
1
  import ast
2
  import contextlib
3
  import io
4
+ import logging
5
+ import os
6
  import re
7
+ import signal
8
  import traceback
9
+ from typing import Any, Dict, List, Optional, Union
10
+
11
  from smolagents.tools import Tool
 
 
12
 
13
  # Set up logging
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
+
18
  class CodeExecutionTool(Tool):
19
  """
20
  Executes Python code snippets safely with timeout protection.
21
  Useful for data processing, analysis, and transformation.
22
  Includes special utilities for web data processing and robust error handling.
23
  """
24
+
25
  name = "python_executor"
26
  description = "Safely executes Python code with enhancements for data processing, parsing, and error recovery."
27
+
28
  inputs = {
29
+ "code_string": {"type": "string", "description": "The Python code to execute.", "nullable": True},
30
+ "filepath": {"type": "string", "description": "Path to a Python file to execute.", "nullable": True},
31
  }
32
+
33
  outputs = {
34
+ "success": {"type": "boolean", "description": "Whether the code executed successfully."},
35
+ "output": {"type": "string", "description": "The captured stdout or formatted result.", "nullable": True},
36
+ "error": {"type": "string", "description": "Error message if execution failed.", "nullable": True},
37
+ "result_value": {"type": "any", "description": "The final expression value if applicable.", "nullable": True},
38
  }
39
+
40
  output_type = "object"
41
 
42
  def __init__(self, timeout: int = 10, max_output_size: int = 20000, *args, **kwargs):
 
44
  self.timeout = timeout
45
  self.max_output_size = max_output_size
46
  self.banned_modules = [
47
+ "os",
48
+ "subprocess",
49
+ "sys",
50
+ "builtins",
51
+ "importlib",
52
+ "pickle",
53
+ "requests",
54
+ "socket",
55
+ "shutil",
56
+ "ctypes",
57
+ "multiprocessing",
58
  ]
59
  self.is_initialized = True
 
60
  self._utility_functions = self._get_utility_functions()
61
 
62
+ def _get_utility_functions(self) -> str:
63
+ """Define utility functions that will be available in the executed code."""
64
+ return '''
 
65
  def extract_pattern(text, pattern, group=0, all_matches=False):
66
  """
67
+ Extract data using regex pattern from text.
68
+
69
  Args:
70
  text (str): Text to search in
71
  pattern (str): Regex pattern to use
72
  group (int): Capture group to return (default 0 - entire match)
73
  all_matches (bool): If True, return all matches, otherwise just first
74
+
75
  Returns:
76
  Matched string(s) or None if no match
77
  """
 
79
  if not text or not pattern:
80
  print("Warning: Empty text or pattern provided to extract_pattern")
81
  return None
82
+
83
  try:
84
  matches = re.finditer(pattern, text)
85
  results = [m.group(group) if group < len(m.groups())+1 else m.group(0) for m in matches]
86
+
87
  if not results:
88
  print(f"No matches found for pattern '{pattern}'")
89
  return None
90
+
91
+ return results if all_matches else results[0]
 
 
 
92
  except Exception as e:
93
  print(f"Error in extract_pattern: {e}")
94
  return None
95
 
96
+
97
  def clean_text(text, remove_extra_whitespace=True, remove_special_chars=False):
98
  """
99
  Clean text by removing extra whitespace and optionally special characters.
100
+
101
  Args:
102
  text (str): Text to clean
103
  remove_extra_whitespace (bool): If True, replace multiple spaces with single space
104
  remove_special_chars (bool): If True, remove special characters
105
+
106
  Returns:
107
  Cleaned string
108
  """
109
  import re
110
  if not text:
111
  return ""
112
+
113
  # Replace newlines and tabs with spaces
114
+ text = re.sub(r"[\\n\\t\\r]+", " ", text)
115
+
116
  if remove_special_chars:
117
  # Keep only alphanumeric, spaces, and basic punctuation
118
+ text = re.sub(r"[^\w\s.,;:!?\'\"()-]", "", text)
119
+
120
  if remove_extra_whitespace:
121
  # Replace multiple spaces with single space
122
+ text = re.sub(r"\\s+", " ", text)
123
+
124
  return text.strip()
125
+
126
+
127
  def parse_table_text(table_text):
128
  """
129
+ Parse table-like text into list of rows.
130
+
131
  Args:
132
  table_text (str): Text containing table-like data
133
+
134
  Returns:
135
  List of rows (each row is a list of cells)
136
  """
137
+ import re
138
+
139
  rows = []
140
+ lines = table_text.strip().split("\\n")
141
+
142
  for line in lines:
143
  # Skip empty lines
144
  if not line.strip():
145
  continue
146
+
147
  # Split by whitespace or common separators
148
+ cells = re.split(r"\\s{2,}|\\t+|\\|+", line.strip())
149
  # Clean up cells
150
  cells = [cell.strip() for cell in cells if cell.strip()]
151
+
152
  if cells:
153
  rows.append(cells)
154
+
155
  # Print parsing result for debugging
156
  print(f"Parsed {len(rows)} rows from table text")
157
  if rows and len(rows) > 0:
158
  print(f"First row (columns: {len(rows[0])}): {rows[0]}")
159
+
160
  return rows
161
 
162
+
163
  def safe_float(text):
164
  """
165
  Safely convert text to float, handling various formats.
166
+
167
  Args:
168
  text (str): Text to convert
169
+
170
  Returns:
171
  float or None if conversion fails
172
  """
173
+ import re
174
+
175
  if not text:
176
  return None
177
+
178
  # Remove currency symbols, commas in numbers, etc.
179
+ text = re.sub(r"[^0-9.-]", "", str(text))
180
+
181
  try:
182
  return float(text)
183
  except ValueError:
184
  print(f"Warning: Could not convert '{text}' to float")
185
  return None
186
+ '''
 
187
 
188
  def _analyze_code_safety(self, code: str) -> Dict[str, Any]:
189
  """Perform static analysis to check for potentially harmful code."""
190
  try:
191
  parsed = ast.parse(code)
192
+
193
  # Check for banned imports
194
  imports = []
195
  for node in ast.walk(parsed):
 
199
  # Ensure node.module is not None before attempting to check against banned_modules
200
  if node.module and any(banned in node.module for banned in self.banned_modules):
201
  imports.append(node.module)
202
+
203
+ dangerous_imports = [
204
+ imp for imp in imports
205
+ if imp and any(banned in imp for banned in self.banned_modules)
206
+ ]
207
+
208
  if dangerous_imports:
209
  return {
210
+ "safe": False,
211
+ "reason": f"Potentially harmful imports detected: {dangerous_imports}",
212
  }
213
+
214
  # Check for exec/eval usage
215
  for node in ast.walk(parsed):
216
+ if isinstance(node, ast.Call) and hasattr(node, "func"):
217
+ if isinstance(node.func, ast.Name) and node.func.id in ["exec", "eval"]:
218
+ return {"safe": False, "reason": "Contains exec() or eval() calls"}
219
+
 
 
 
220
  return {"safe": True}
221
+
222
  except SyntaxError:
223
  return {"safe": False, "reason": "Invalid Python syntax"}
224
 
 
230
  """Extract the final numeric value from output."""
231
  if not output:
232
  return None
233
+
234
  # Look for the last line that contains a number
235
+ lines = output.strip().split("\n")
236
  for line in reversed(lines):
237
  # Try to interpret it as a pure number
238
  line = line.strip()
239
  try:
240
+ return float(line) if "." in line else int(line)
 
 
 
241
  except ValueError:
242
  # Not a pure number, try to extract numbers with regex
243
+ match = re.search(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?$", line)
244
  if match:
245
  num_str = match.group(0)
246
  try:
247
+ return float(num_str) if "." in num_str else int(num_str)
 
 
 
248
  except ValueError:
249
  pass
250
  return None
251
 
252
  def forward(self, code_string: Optional[str] = None, filepath: Optional[str] = None) -> Dict[str, Any]:
253
+ """Main entry point for code execution."""
254
  if not code_string and not filepath:
255
  return {"success": False, "error": "No code string or filepath provided."}
256
  if code_string and filepath:
257
  return {"success": False, "error": "Provide either a code string or a filepath, not both."}
258
 
259
  code_to_execute = ""
260
+
261
  if filepath:
262
  if not os.path.exists(filepath):
263
+ return {"success": False, "error": f"File not found: {filepath}"}
264
  if not filepath.endswith(".py"):
265
  return {"success": False, "error": f"File is not a Python file: {filepath}"}
266
  try:
267
+ with open(filepath, "r") as file:
268
  code_to_execute = file.read()
269
  except Exception as e:
270
  return {"success": False, "error": f"Error reading file {filepath}: {str(e)}"}
271
+ else:
272
  code_to_execute = code_string
273
+
274
  # Inject utility functions
275
  enhanced_code = self._utility_functions + "\n\n" + code_to_execute
 
276
  return self._execute_actual_code(enhanced_code)
277
 
278
  def _execute_actual_code(self, code: str) -> Dict[str, Any]:
279
  """Execute Python code and capture the output or error."""
280
  safety_check = self._analyze_code_safety(code)
281
  if not safety_check["safe"]:
282
+ return {"success": False, "error": f"Safety check failed: {safety_check['reason']}"}
283
+
 
 
 
284
  # Capture stdout and execute the code with a timeout
285
  stdout_buffer = io.StringIO()
286
  result_value = None
287
+
288
  try:
289
  # Set timeout handler
290
  signal.signal(signal.SIGALRM, self._timeout_handler)
291
  signal.alarm(self.timeout)
292
+
293
  # Execute code and capture stdout
294
  with contextlib.redirect_stdout(stdout_buffer):
295
  # Execute the code within a new dictionary for local variables
296
  local_vars = {}
297
  exec(code, {}, local_vars)
298
+
299
  # Try to extract the result from common variable names
300
+ for var_name in ["result", "answer", "output", "value", "final_result", "data"]:
301
  if var_name in local_vars:
302
  result_value = local_vars[var_name]
303
  break
304
+
305
  # Reset the alarm
306
  signal.alarm(0)
307
+
308
  # Get the captured output
309
  output = stdout_buffer.getvalue()
310
  if len(output) > self.max_output_size:
311
+ output = output[: self.max_output_size] + f"\n... (output truncated, exceeded {self.max_output_size} characters)"
312
+
313
  # If no result_value was found, try to extract a numeric value from the output
314
  if result_value is None:
315
  result_value = self._extract_numeric_value(output)
316
+
317
+ return {"success": True, "output": output, "result_value": result_value}
318
+
319
+ except TimeoutError:
320
+ signal.alarm(0)
 
 
 
 
321
  return {"success": False, "error": f"Code execution timed out after {self.timeout} seconds"}
322
  except Exception as e:
323
+ signal.alarm(0)
324
  trace = traceback.format_exc()
325
  error_msg = f"Error executing code: {str(e)}\n{trace}"
326
  return {"success": False, "error": error_msg}
 
328
  # Ensure the alarm is reset
329
  signal.alarm(0)
330
 
331
+ # Helper methods for backward compatibility
 
332
  def execute_file(self, filepath: str) -> Dict[str, Any]:
333
  """Helper to execute Python code from file."""
334
+ return self.forward(filepath=filepath)
 
 
 
 
 
 
 
 
 
335
 
336
  def execute_code(self, code: str) -> Dict[str, Any]:
337
  """Helper to execute Python code from a string."""
338
+ return self.forward(code_string=code)
339
 
340
 
341
+ def _run_tests():
342
+ """Run comprehensive tests for the CodeExecutionTool."""
343
  tool = CodeExecutionTool(timeout=5)
344
+ test_results = []
345
 
346
  # Test 1: Safe code string
347
  safe_code = "print('Hello from safe code!'); result = 10 * 2; print(result)"
348
  print("\n--- Test 1: Safe Code String ---")
349
  result1 = tool.forward(code_string=safe_code)
350
  print(result1)
351
+ test_results.append(result1["success"] and "Hello from safe code!" in result1["output"])
 
 
352
 
353
  # Test 2: Code with an error
354
  error_code = "print(1/0)"
355
  print("\n--- Test 2: Code with Error ---")
356
  result2 = tool.forward(code_string=error_code)
357
  print(result2)
358
+ test_results.append(not result2["success"] and "ZeroDivisionError" in result2["error"])
 
359
 
360
  # Test 3: Code with a banned import
361
  unsafe_import_code = "import os; print(os.getcwd())"
362
  print("\n--- Test 3: Unsafe Import ---")
363
  result3 = tool.forward(code_string=unsafe_import_code)
364
  print(result3)
365
+ test_results.append(not result3["success"] and "Safety check failed" in result3["error"])
 
 
366
 
367
  # Test 4: Timeout
368
  timeout_code = "import time; time.sleep(10); print('Done sleeping')"
369
  print("\n--- Test 4: Timeout ---")
370
+ result4 = tool.forward(code_string=timeout_code)
 
 
371
  print(result4)
372
+ test_results.append(not result4["success"] and "timed out" in result4["error"])
 
373
 
374
  # Test 5: Execute from file
375
  test_file_content = "print('Hello from file!'); x = 5; y = 7; print(f'Sum: {x+y}')"
 
379
  print("\n--- Test 5: Execute from File ---")
380
  result5 = tool.forward(filepath=test_filename)
381
  print(result5)
382
+ test_results.append(result5["success"] and "Hello from file!" in result5["output"])
 
 
383
  os.remove(test_filename)
384
 
385
  # Test 6: File not found
386
  print("\n--- Test 6: File Not Found ---")
387
  result6 = tool.forward(filepath="non_existent_script.py")
388
  print(result6)
389
+ test_results.append(not result6["success"] and "File not found" in result6["error"])
 
390
 
391
  # Test 7: Provide both code_string and filepath
392
  print("\n--- Test 7: Both code_string and filepath ---")
393
+ result7 = tool.forward(code_string="print('hello')", filepath="dummy.py")
394
  print(result7)
395
+ test_results.append(
396
+ not result7["success"]
397
+ and "Provide either a code string or a filepath, not both" in result7["error"]
398
+ )
399
 
400
  # Test 8: Provide neither
401
  print("\n--- Test 8: Neither code_string nor filepath ---")
402
  result8 = tool.forward()
403
  print(result8)
404
+ test_results.append(not result8["success"] and "No code string or filepath provided" in result8["error"])
 
405
 
406
+ # Test 9: Function definition and call
407
  func_def_code = "def my_func(a, b): return a + b; print(my_func(3,4))"
408
  print("\n--- Test 9: Function Definition and Call ---")
409
  result9 = tool.forward(code_string=func_def_code)
410
  print(result9)
411
+ test_results.append(result9["success"] and "7" in result9["output"])
412
+
413
+ print(f"\nTests passed: {sum(test_results)}/{len(test_results)}")
414
+ if all(test_results):
415
+ print("All tests passed!")
416
+ else:
417
+ print("Some tests failed - check output for details.")
418
+
419
+
420
+ if __name__ == "__main__":
421
+ _run_tests()