stzhao commited on
Commit
8bfa5ee
·
verified ·
1 Parent(s): 33f1cbe

Update shared_vis_python_exe.py

Browse files
Files changed (1) hide show
  1. shared_vis_python_exe.py +299 -290
shared_vis_python_exe.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import io
3
  import regex
@@ -6,19 +8,21 @@ import traceback
6
  import copy
7
  import datetime
8
  import dateutil.relativedelta
9
- import multiprocess
10
- from multiprocess import Pool
11
  from typing import Any, Dict, Optional, Tuple, List, Union
12
- from pebble import ProcessPool
13
  from tqdm import tqdm
14
  from concurrent.futures import TimeoutError
15
- from functools import partial
16
- from timeout_decorator import timeout
17
  from contextlib import redirect_stdout
18
  import base64
19
  from io import BytesIO
20
  from PIL import Image
21
- import pdb
 
 
 
 
 
22
 
23
  def encode_image(image_path):
24
  with open(image_path, "rb") as image_file:
@@ -30,42 +34,212 @@ def base64_to_image(
30
  convert_mode: Optional[str] = "RGB"
31
  ) -> Union[Image.Image, None]:
32
  """
33
- Base64编码的图片字符串转换为PIL Image对象
34
-
35
  Args:
36
- base64_str: Base64编码的图片字符串(可带data:前缀)
37
- remove_prefix: 是否自动去除"data:image/..."前缀(默认True
38
- convert_mode: 转换为指定模式(如"RGB"/"RGBA"None表示不转换)
39
-
40
  Returns:
41
- PIL.Image.Image 对象,解码失败时返回None
42
 
43
  Examples:
44
  >>> img = base64_to_image("data:image/png;base64,iVBORw0KGg...")
45
  >>> img = base64_to_image("iVBORw0KGg...", remove_prefix=False)
46
  """
47
  try:
48
- # 1. 处理Base64前缀
49
  if remove_prefix and "," in base64_str:
50
  base64_str = base64_str.split(",")[1]
51
 
52
- # 2. 解码Base64
53
  image_data = base64.b64decode(base64_str)
54
 
55
- # 3. 转换为PIL Image
56
  image = Image.open(BytesIO(image_data))
57
 
58
- # 4. 可选模式转换
59
  if convert_mode:
60
  image = image.convert(convert_mode)
61
 
62
  return image
63
 
64
  except (base64.binascii.Error, OSError, Exception) as e:
65
- print(f"Base64解码失败: {str(e)}")
66
  return None
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  class GenericRuntime:
70
  GLOBAL_DICT = {}
71
  LOCAL_DICT = None
@@ -80,17 +254,14 @@ class GenericRuntime:
80
  self.exec_code(c)
81
 
82
  def exec_code(self, code_piece: str) -> None:
83
- if regex.search(r"(\s|^)?input\(", code_piece) or regex.search(
84
- r"(\s|^)?os.system\(", code_piece
85
- ):
86
  raise RuntimeError("Forbidden function calls detected")
87
-
88
 
89
-
90
- # 检测并修改plt.show()调用
91
  if "plt.show()" in code_piece:
92
  modified_code = code_piece.replace("plt.show()", """
93
- # 捕获当前图像
94
  buf = io.BytesIO()
95
  plt.savefig(buf, format='png')
96
  buf.seek(0)
@@ -98,7 +269,7 @@ _captured_image = base64.b64encode(buf.read()).decode('utf-8')
98
  _captured_figures.append(_captured_image)
99
  plt.close()
100
  """)
101
- # 确保_captured_figures变量存在
102
  if "_captured_figures" not in self._global_vars:
103
  self._global_vars["_captured_figures"] = []
104
 
@@ -123,46 +294,41 @@ plt.close()
123
 
124
 
125
  class ImageRuntime(GenericRuntime):
126
- # """支持图像处理的运行时环境"""
127
- # GLOBAL_DICT = {} # 不预加载模块,避免序列化问题
128
- # LOCAL_DICT = None
129
-
130
  HEADERS = [
131
  "import matplotlib",
132
- "matplotlib.use('Agg')", # 使用非交互式后端
133
  "import matplotlib.pyplot as plt",
134
  "from PIL import Image",
135
  "import io",
136
  "import base64",
137
  "import numpy as np",
138
- "_captured_figures = []", # 初始化图像捕获列表
139
  ]
140
 
141
  def __init__(self, messages):
142
  super().__init__()
143
- # pdb.set_trace()
144
-
145
- self._global_vars = copy.copy(self.GLOBAL_DICT)
146
- self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
147
- self._captured_figures = []
148
-
149
- for c in self.HEADERS:
150
- self.exec_code(c)
151
 
152
  image_var_dict = {}
153
  image_var_idx = 0
 
 
154
  for message_item in messages:
155
- content = message_item['content'] # {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
156
  for item in content:
157
- item_type = item['type']
158
- if item_type == "image_url":
159
- item_image_url = item['image_url']['url']
160
- image = base64_to_image(item_image_url)
161
- image_var_dict[f"image_clue_{image_var_idx}"] = image
162
- image_var_idx += 1
163
-
 
 
 
 
 
164
  self.inject(image_var_dict)
165
-
166
 
167
  class DateRuntime(GenericRuntime):
168
  GLOBAL_DICT = {}
@@ -190,97 +356,24 @@ class PythonExecutor:
190
  get_answer_expr: Optional[str] = None,
191
  get_answer_from_stdout: bool = True,
192
  timeout_length: int = 20,
 
193
  ) -> None:
194
  self.runtime_class = runtime_class if runtime_class else ImageRuntime
195
- print(self.runtime_class)
196
  self.answer_symbol = get_answer_symbol
197
  self.answer_expr = get_answer_expr
198
  self.get_answer_from_stdout = get_answer_from_stdout
199
  self.timeout_length = timeout_length
 
 
200
 
201
- # Create a persistent runtime instance if messages are provided
202
- self.persistent_runtime = None
 
 
203
 
204
  def process_generation_to_code(self, gens: str):
205
  return [g.split("\n") for g in gens]
206
 
207
- # def execute(
208
- # self,
209
- # code,
210
- # messages,
211
- # get_answer_from_stdout=True,
212
- # runtime_class=None,
213
- # # run_time_instance=None,
214
- # answer_symbol=None,
215
- # answer_expr=None,
216
- # # 移除 timeout_length 参数
217
- # ) -> Tuple[Union[str, Dict[str, Any]], str]:
218
- # # print("dome")
219
- # # try:
220
- # # 在每个进程中创建新的运行时实例
221
- # print(runtime_class)
222
- # # runtime = runtime_class(messages)
223
- # runtime = self.persistent_runtime
224
-
225
- # if get_answer_from_stdout:
226
- # program_io = io.StringIO()
227
- # with redirect_stdout(program_io):
228
- # # 移除 timeout 调用
229
- # runtime.exec_code("\n".join(code))
230
- # program_io.seek(0)
231
- # result = program_io.read()
232
- # elif answer_symbol:
233
- # # 移除 timeout 调用
234
- # runtime.exec_code("\n".join(code))
235
- # result = runtime._global_vars.get(answer_symbol, "")
236
- # elif answer_expr:
237
- # # 移除 timeout 调用
238
- # runtime.exec_code("\n".join(code))
239
- # # 移除 timeout 调用
240
- # result = runtime.eval_code(answer_expr)
241
- # else:
242
- # if len(code) > 1:
243
- # # 移除 timeout 调用
244
- # runtime.exec_code("\n".join(code[:-1]))
245
- # # 移除 timeout 调用
246
- # result = runtime.eval_code(code[-1])
247
- # else:
248
- # # 移除 timeout 调用
249
- # runtime.exec_code("\n".join(code))
250
- # result = ""
251
-
252
- # # 检查是否有捕获的图像
253
- # captured_figures = runtime._global_vars.get("_captured_figures", [])
254
- # if captured_figures:
255
- # # 如果有文本输出和图像,将它们组合
256
- # if result:
257
- # result = {
258
- # 'text': result,
259
- # 'images': captured_figures
260
- # }
261
- # else:
262
- # result = {'images': captured_figures}
263
- # else:
264
- # if result:
265
- # result = {
266
- # 'text': result,
267
- # }
268
-
269
-
270
- # report = "Done"
271
- # # except Exception as e:
272
- # # result = ""
273
- # # report = f"Error: {str(e)}\n{traceback.format_exc()}"
274
-
275
- # # 确保结果可序列化
276
- # try:
277
- # pickle.dumps(result)
278
- # except Exception as e:
279
- # result = f"Result serialization error: {str(e)}"
280
- # report = f"Serialization Error: {str(e)}"
281
-
282
- # return result, report
283
-
284
  def execute(
285
  self,
286
  code,
@@ -290,67 +383,83 @@ class PythonExecutor:
290
  answer_symbol=None,
291
  answer_expr=None,
292
  ) -> Tuple[Union[str, Dict[str, Any]], str]:
293
- print(runtime_class)
294
- runtime = self.persistent_runtime
295
-
296
- try:
297
- if get_answer_from_stdout:
298
- program_io = io.StringIO()
299
- with redirect_stdout(program_io):
300
- runtime.exec_code("\n".join(code))
301
- program_io.seek(0)
302
- result = program_io.read()
303
- elif answer_symbol:
304
- runtime.exec_code("\n".join(code))
305
- result = runtime._global_vars.get(answer_symbol, "")
306
- elif answer_expr:
307
- runtime.exec_code("\n".join(code))
308
- result = runtime.eval_code(answer_expr)
 
 
309
  else:
310
- if len(code) > 1:
311
- runtime.exec_code("\n".join(code[:-1]))
312
- result = runtime.eval_code(code[-1])
313
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  runtime.exec_code("\n".join(code))
315
- result = ""
316
-
317
- # Check for captured figures
318
- captured_figures = runtime._global_vars.get("_captured_figures", [])
319
- if captured_figures:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  result = {
321
- 'text': result,
322
- 'images': captured_figures
323
- } if result else {'images': captured_figures}
324
- else:
325
- result = {'text': result} if result else {}
326
-
327
- report = "Done"
328
-
329
- except Exception as e:
330
- result = {
331
- 'error': str(e),
332
- 'traceback': traceback.format_exc()
333
- }
334
- report = f"Error: {str(e)}"
335
-
336
- # Ensure result is serializable
337
- try:
338
- pickle.dumps(result)
339
- except Exception as e:
340
- result = f"Result serialization error: {str(e)}"
341
- report = f"Serialization Error: {str(e)}"
342
-
343
- return result, report
344
 
 
345
 
346
-
347
  def apply(self, code, messages):
348
  return self.batch_apply([code], messages)[0]
349
 
350
  @staticmethod
351
  def truncate(s, max_length=400):
352
  if isinstance(s, dict):
353
- # 如果是字典(包含图像),只截断文本部分
354
  if 'text' in s:
355
  half = max_length // 2
356
  if len(s['text']) > max_length:
@@ -362,21 +471,12 @@ class PythonExecutor:
362
  s = s[:half] + "..." + s[-half:]
363
  return s
364
 
365
- def update_persistent_runtime_with_messages():
366
- pass
367
-
368
- def get_persistent_runtime(self):
369
- return self.persistent_runtime
370
-
371
  def batch_apply(self, batch_code, messages):
372
- if not self.persistent_runtime and messages:
373
- self.persistent_runtime = self.runtime_class(messages)
374
  all_code_snippets = self.process_generation_to_code(batch_code)
375
 
376
  timeout_cnt = 0
377
  all_exec_results = []
378
 
379
- # 去掉 ProcessPool,改为单进程顺序执行
380
  if len(all_code_snippets) > 100:
381
  progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
382
  else:
@@ -384,7 +484,6 @@ class PythonExecutor:
384
 
385
  for code in all_code_snippets:
386
  try:
387
- # 直接调用 self.execute,而不是用 ProcessPool
388
  result = self.execute(
389
  code,
390
  messages=messages,
@@ -392,7 +491,6 @@ class PythonExecutor:
392
  runtime_class=self.runtime_class,
393
  answer_symbol=self.answer_symbol,
394
  answer_expr=self.answer_expr,
395
- # timeout_length=self.timeout_length,
396
  )
397
  all_exec_results.append(result)
398
  except TimeoutError as error:
@@ -411,16 +509,16 @@ class PythonExecutor:
411
 
412
  batch_results = []
413
  for code, (res, report) in zip(all_code_snippets, all_exec_results):
414
- # 处理结果
415
  if isinstance(res, dict):
416
- # 如果结果包含图像,特殊处理
417
  if 'text' in res:
418
  res['text'] = str(res['text']).strip()
419
  res['text'] = self.truncate(res['text'])
420
  report = str(report).strip()
421
  report = self.truncate(report)
422
  else:
423
- # 普通文本结果
424
  res = str(res).strip()
425
  res = self.truncate(res)
426
  report = str(report).strip()
@@ -428,101 +526,12 @@ class PythonExecutor:
428
  batch_results.append((res, report))
429
  return batch_results
430
 
 
 
 
 
431
 
432
- def _test():
433
- image_path = "/mnt/petrelfs/zhaoshitian/vis_tool_inference_engine/test_data/0.JPG"
434
- image_base64 = encode_image(image_path)
435
- messages = [
436
- {
437
- "role": "user",
438
- "content": [{"type": "text", "text": "From the information on that advertising board, what is the type of this shop?"}]
439
- },
440
- {
441
- "role": "user",
442
- "content": [{"type": "text", "text": "image_clue_0"}] + [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}]
443
- }
444
- ]
445
- # 测试普通计算
446
- math_code ="""
447
- a = 1
448
- b = 2
449
- c = a + b
450
- print(c)
451
- """
452
-
453
- batch_code = [math_code]
454
-
455
- executor = PythonExecutor()
456
- predictions = executor.apply(batch_code[0], messages)
457
- print("数学计算结果:", predictions)
458
-
459
- # 测试图像显示
460
- image_code = """
461
- import matplotlib.pyplot as plt
462
- import numpy as np
463
- from PIL import Image
464
- import io
465
-
466
- # 创建一个简单的图像
467
- x = np.linspace(0, 10, 100)
468
- y = np.sin(x)
469
-
470
- plt.figure(figsize=(8, 6))
471
- plt.plot(x, y, 'r-', linewidth=2)
472
- plt.title('Sine Wave')
473
- plt.grid(True)
474
- plt.show()
475
-
476
- # 也可以显示一个简单的图像
477
- # 创建一个彩色渐变图像
478
- arr = np.zeros((100, 100, 3), dtype=np.uint8)
479
- for i in range(100):
480
- for j in range(100):
481
- arr[i, j, 0] = i # 红色通道
482
- arr[i, j, 1] = j # 绿色通道
483
- arr[i, j, 2] = 100 # 蓝色通道
484
-
485
- img = Image.fromarray(arr)
486
- plt.figure()
487
- plt.imshow(img)
488
- plt.title('Gradient Image')
489
- plt.show()
490
-
491
- print("图像生成完成")
492
- """
493
-
494
- image_code = """
495
- import matplotlib.pyplot as plt
496
- import numpy as np
497
- from PIL import Image
498
- import io
499
-
500
- plt.imshow(image_clue_0)
501
- plt.title("Original Image - Locate Advertising Board")
502
- plt.show()
503
- """
504
-
505
- image_result = executor.apply(image_code, messages)
506
- print("\n图像结果类型:", type(image_result[0]))
507
- if isinstance(image_result[0], dict) and 'images' in image_result[0]:
508
- print(f"捕获到 {len(image_result[0]['images'])} 个图像")
509
- print("第一个图像的base64编码前20个字符:", image_result[0]['images'][0][:20])
510
-
511
- # 可选:保存图像到文件
512
- for i, img_data in enumerate(image_result[0]['images']):
513
- img_bytes = base64.b64decode(img_data)
514
- with open(f"captured_image_{i}.png", "wb") as f:
515
- f.write(img_bytes)
516
- print(f"图像已保存为 captured_image_{i}.png")
517
-
518
- if 'text' in image_result[0]:
519
- print("文本输出:", image_result[0]['text'])
520
- else:
521
- print("未捕获到图像")
522
- print("结果:", image_result[0])
523
-
524
- print("\n执行状态:", image_result[1])
525
-
526
-
527
- if __name__ == "__main__":
528
- _test()
 
1
+ # shared_vis_python_exe.py
2
+
3
  import os
4
  import io
5
  import regex
 
8
  import copy
9
  import datetime
10
  import dateutil.relativedelta
11
+ import multiprocessing
12
+ from multiprocessing import Queue, Process
13
  from typing import Any, Dict, Optional, Tuple, List, Union
 
14
  from tqdm import tqdm
15
  from concurrent.futures import TimeoutError
 
 
16
  from contextlib import redirect_stdout
17
  import base64
18
  from io import BytesIO
19
  from PIL import Image
20
+ import matplotlib
21
+ matplotlib.use('Agg')
22
+ import matplotlib.pyplot as plt
23
+ import numpy as np
24
+ import time
25
+ import queue
26
 
27
  def encode_image(image_path):
28
  with open(image_path, "rb") as image_file:
 
34
  convert_mode: Optional[str] = "RGB"
35
  ) -> Union[Image.Image, None]:
36
  """
37
+ Convert a Base64-encoded image string to a PIL Image object.
38
+
39
  Args:
40
+ base64_str: Base64-encoded image string (can include data: prefix)
41
+ remove_prefix: Whether to automatically remove the "data:image/..." prefix (default True)
42
+ convert_mode: Convert to the specified mode (such as "RGB"/"RGBA", None means no conversion)
43
+
44
  Returns:
45
+ PIL.Image.Image object, or None if decoding fails
46
 
47
  Examples:
48
  >>> img = base64_to_image("data:image/png;base64,iVBORw0KGg...")
49
  >>> img = base64_to_image("iVBORw0KGg...", remove_prefix=False)
50
  """
51
  try:
52
+ # 1. Handle Base64 prefix
53
  if remove_prefix and "," in base64_str:
54
  base64_str = base64_str.split(",")[1]
55
 
56
+ # 2. Decode Base64
57
  image_data = base64.b64decode(base64_str)
58
 
59
+ # 3. Convert to PIL Image
60
  image = Image.open(BytesIO(image_data))
61
 
62
+ # 4. Optional mode conversion
63
  if convert_mode:
64
  image = image.convert(convert_mode)
65
 
66
  return image
67
 
68
  except (base64.binascii.Error, OSError, Exception) as e:
69
+ print(f"Base64 decode failed: {str(e)}")
70
  return None
71
 
72
 
73
+ class PersistentWorker:
74
+ """Persistent worker process."""
75
+
76
+ def __init__(self):
77
+ self.input_queue = multiprocessing.Queue()
78
+ self.output_queue = multiprocessing.Queue()
79
+ self.process = None
80
+ self.start()
81
+
82
+ def start(self):
83
+ """Start the worker process."""
84
+ self.process = Process(target=self._worker_loop)
85
+ self.process.daemon = True
86
+ self.process.start()
87
+
88
+ def _worker_loop(self):
89
+ """Main loop for the worker process."""
90
+ runtime = None
91
+ runtime_class = None
92
+
93
+ while True:
94
+ try:
95
+ # Get task
96
+ task = self.input_queue.get()
97
+
98
+ if task is None: # Termination signal
99
+ break
100
+
101
+ task_type = task.get('type')
102
+
103
+ if task_type == 'init':
104
+ # Initialize runtime
105
+ messages = task.get('messages', [])
106
+ runtime_class = task.get('runtime_class', ImageRuntime)
107
+ runtime = runtime_class(messages)
108
+ self.output_queue.put({
109
+ 'status': 'success',
110
+ 'result': 'Initialized'
111
+ })
112
+
113
+ elif task_type == 'execute':
114
+ # Execute code
115
+ if runtime is None:
116
+ messages = task.get('messages', [])
117
+ runtime_class = task.get('runtime_class', ImageRuntime)
118
+ runtime = runtime_class(messages)
119
+
120
+ code = task.get('code')
121
+ get_answer_from_stdout = task.get('get_answer_from_stdout', True)
122
+ answer_symbol = task.get('answer_symbol')
123
+ answer_expr = task.get('answer_expr')
124
+
125
+ try:
126
+ # Record the number of images before execution
127
+ pre_figures_count = len(runtime._global_vars.get("_captured_figures", []))
128
+
129
+ if get_answer_from_stdout:
130
+ program_io = io.StringIO()
131
+ with redirect_stdout(program_io):
132
+ runtime.exec_code("\n".join(code))
133
+ program_io.seek(0)
134
+ result = program_io.read()
135
+ elif answer_symbol:
136
+ runtime.exec_code("\n".join(code))
137
+ result = runtime._global_vars.get(answer_symbol, "")
138
+ elif answer_expr:
139
+ runtime.exec_code("\n".join(code))
140
+ result = runtime.eval_code(answer_expr)
141
+ else:
142
+ if len(code) > 1:
143
+ runtime.exec_code("\n".join(code[:-1]))
144
+ result = runtime.eval_code(code[-1])
145
+ else:
146
+ runtime.exec_code("\n".join(code))
147
+ result = ""
148
+
149
+ # Get newly generated images
150
+ all_figures = runtime._global_vars.get("_captured_figures", [])
151
+ new_figures = all_figures[pre_figures_count:]
152
+
153
+ # Build result
154
+ if new_figures:
155
+ result = {
156
+ 'text': result,
157
+ 'images': new_figures
158
+ } if result else {'images': new_figures}
159
+ else:
160
+ result = {'text': result} if result else {}
161
+
162
+ self.output_queue.put({
163
+ 'status': 'success',
164
+ 'result': result,
165
+ 'report': 'Done'
166
+ })
167
+
168
+ except Exception as e:
169
+ self.output_queue.put({
170
+ 'status': 'error',
171
+ 'error': str(e),
172
+ 'traceback': traceback.format_exc(),
173
+ 'report': f'Error: {str(e)}'
174
+ })
175
+
176
+ elif task_type == 'reset':
177
+ # Reset runtime
178
+ messages = task.get('messages', [])
179
+ runtime_class = task.get('runtime_class', ImageRuntime)
180
+ runtime = runtime_class(messages)
181
+ self.output_queue.put({
182
+ 'status': 'success',
183
+ 'result': 'Reset'
184
+ })
185
+
186
+ except Exception as e:
187
+ self.output_queue.put({
188
+ 'status': 'error',
189
+ 'error': f'Worker error: {str(e)}',
190
+ 'traceback': traceback.format_exc()
191
+ })
192
+
193
+ def execute(self, code: List[str], messages: list = None, runtime_class=None,
194
+ get_answer_from_stdout=True, answer_symbol=None, answer_expr=None, timeout: int = 30):
195
+ """Execute code."""
196
+ self.input_queue.put({
197
+ 'type': 'execute',
198
+ 'code': code,
199
+ 'messages': messages,
200
+ 'runtime_class': runtime_class,
201
+ 'get_answer_from_stdout': get_answer_from_stdout,
202
+ 'answer_symbol': answer_symbol,
203
+ 'answer_expr': answer_expr
204
+ })
205
+
206
+ try:
207
+ result = self.output_queue.get(timeout=timeout)
208
+ return result
209
+ except queue.Empty:
210
+ return {
211
+ 'status': 'error',
212
+ 'error': 'Execution timeout',
213
+ 'report': 'Timeout Error'
214
+ }
215
+
216
+ def init_runtime(self, messages: list, runtime_class=None):
217
+ """Initialize runtime."""
218
+ self.input_queue.put({
219
+ 'type': 'init',
220
+ 'messages': messages,
221
+ 'runtime_class': runtime_class
222
+ })
223
+ return self.output_queue.get()
224
+
225
+ def reset_runtime(self, messages: list = None, runtime_class=None):
226
+ """Reset runtime."""
227
+ self.input_queue.put({
228
+ 'type': 'reset',
229
+ 'messages': messages,
230
+ 'runtime_class': runtime_class
231
+ })
232
+ return self.output_queue.get()
233
+
234
+ def terminate(self):
235
+ """Terminate the worker process."""
236
+ if self.process and self.process.is_alive():
237
+ self.input_queue.put(None)
238
+ self.process.join(timeout=5)
239
+ if self.process.is_alive():
240
+ self.process.terminate()
241
+
242
+
243
  class GenericRuntime:
244
  GLOBAL_DICT = {}
245
  LOCAL_DICT = None
 
254
  self.exec_code(c)
255
 
256
  def exec_code(self, code_piece: str) -> None:
257
+ # Security check
258
+ if regex.search(r"(\s|^)?(input|os\.system|subprocess)\(", code_piece):
 
259
  raise RuntimeError("Forbidden function calls detected")
 
260
 
261
+ # Detect and modify plt.show() calls
 
262
  if "plt.show()" in code_piece:
263
  modified_code = code_piece.replace("plt.show()", """
264
+ # Capture current figure
265
  buf = io.BytesIO()
266
  plt.savefig(buf, format='png')
267
  buf.seek(0)
 
269
  _captured_figures.append(_captured_image)
270
  plt.close()
271
  """)
272
+ # Ensure _captured_figures variable exists
273
  if "_captured_figures" not in self._global_vars:
274
  self._global_vars["_captured_figures"] = []
275
 
 
294
 
295
 
296
  class ImageRuntime(GenericRuntime):
 
 
 
 
297
  HEADERS = [
298
  "import matplotlib",
299
+ "matplotlib.use('Agg')", # Use non-interactive backend
300
  "import matplotlib.pyplot as plt",
301
  "from PIL import Image",
302
  "import io",
303
  "import base64",
304
  "import numpy as np",
305
+ "_captured_figures = []", # Initialize image capture list
306
  ]
307
 
308
  def __init__(self, messages):
309
  super().__init__()
 
 
 
 
 
 
 
 
310
 
311
  image_var_dict = {}
312
  image_var_idx = 0
313
+ init_captured_figures = []
314
+
315
  for message_item in messages:
316
+ content = message_item['content']
317
  for item in content:
318
+ if isinstance(item, dict):
319
+ item_type = item.get('type')
320
+ if item_type == "image_url":
321
+ item_image_url = item['image_url']['url']
322
+ image = base64_to_image(item_image_url)
323
+ if image:
324
+ image_var_dict[f"image_clue_{image_var_idx}"] = image
325
+ init_captured_figures.append(base64.b64encode(
326
+ BytesIO(image.tobytes()).getvalue()).decode('utf-8'))
327
+ image_var_idx += 1
328
+
329
+ image_var_dict["_captured_figures"] = init_captured_figures
330
  self.inject(image_var_dict)
331
+
332
 
333
  class DateRuntime(GenericRuntime):
334
  GLOBAL_DICT = {}
 
356
  get_answer_expr: Optional[str] = None,
357
  get_answer_from_stdout: bool = True,
358
  timeout_length: int = 20,
359
+ use_process_isolation: bool = True,
360
  ) -> None:
361
  self.runtime_class = runtime_class if runtime_class else ImageRuntime
 
362
  self.answer_symbol = get_answer_symbol
363
  self.answer_expr = get_answer_expr
364
  self.get_answer_from_stdout = get_answer_from_stdout
365
  self.timeout_length = timeout_length
366
+ self.use_process_isolation = use_process_isolation
367
+ self.persistent_worker = None
368
 
369
+ def _ensure_worker(self):
370
+ """Ensure the worker process exists."""
371
+ if self.persistent_worker is None:
372
+ self.persistent_worker = PersistentWorker()
373
 
374
  def process_generation_to_code(self, gens: str):
375
  return [g.split("\n") for g in gens]
376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  def execute(
378
  self,
379
  code,
 
383
  answer_symbol=None,
384
  answer_expr=None,
385
  ) -> Tuple[Union[str, Dict[str, Any]], str]:
386
+
387
+ if self.use_process_isolation:
388
+ # Ensure worker process exists
389
+ self._ensure_worker()
390
+
391
+ # Execute code
392
+ result = self.persistent_worker.execute(
393
+ code,
394
+ messages,
395
+ runtime_class or self.runtime_class,
396
+ get_answer_from_stdout,
397
+ answer_symbol,
398
+ answer_expr,
399
+ timeout=self.timeout_length
400
+ )
401
+
402
+ if result['status'] == 'success':
403
+ return result['result'], result.get('report', 'Done')
404
  else:
405
+ error_result = {
406
+ 'error': result.get('error', 'Unknown error'),
407
+ 'traceback': result.get('traceback', '')
408
+ }
409
+ return error_result, result.get('report', f"Error: {result.get('error', 'Unknown error')}")
410
+ else:
411
+ # Non-isolation mode (for backward compatibility)
412
+ runtime = runtime_class(messages) if runtime_class else self.runtime_class(messages)
413
+
414
+ try:
415
+ if get_answer_from_stdout:
416
+ program_io = io.StringIO()
417
+ with redirect_stdout(program_io):
418
+ runtime.exec_code("\n".join(code))
419
+ program_io.seek(0)
420
+ result = program_io.read()
421
+ elif answer_symbol:
422
  runtime.exec_code("\n".join(code))
423
+ result = runtime._global_vars.get(answer_symbol, "")
424
+ elif answer_expr:
425
+ runtime.exec_code("\n".join(code))
426
+ result = runtime.eval_code(answer_expr)
427
+ else:
428
+ if len(code) > 1:
429
+ runtime.exec_code("\n".join(code[:-1]))
430
+ result = runtime.eval_code(code[-1])
431
+ else:
432
+ runtime.exec_code("\n".join(code))
433
+ result = ""
434
+
435
+ # Check for captured figures
436
+ captured_figures = runtime.captured_figures
437
+ if captured_figures:
438
+ result = {
439
+ 'text': result,
440
+ 'images': captured_figures
441
+ } if result else {'images': captured_figures}
442
+ else:
443
+ result = {'text': result} if result else {}
444
+
445
+ report = "Done"
446
+
447
+ except Exception as e:
448
  result = {
449
+ 'error': str(e),
450
+ 'traceback': traceback.format_exc()
451
+ }
452
+ report = f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
 
454
+ return result, report
455
 
 
456
  def apply(self, code, messages):
457
  return self.batch_apply([code], messages)[0]
458
 
459
  @staticmethod
460
  def truncate(s, max_length=400):
461
  if isinstance(s, dict):
462
+ # If it is a dict (with images), truncate only the text part
463
  if 'text' in s:
464
  half = max_length // 2
465
  if len(s['text']) > max_length:
 
471
  s = s[:half] + "..." + s[-half:]
472
  return s
473
 
 
 
 
 
 
 
474
  def batch_apply(self, batch_code, messages):
 
 
475
  all_code_snippets = self.process_generation_to_code(batch_code)
476
 
477
  timeout_cnt = 0
478
  all_exec_results = []
479
 
 
480
  if len(all_code_snippets) > 100:
481
  progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
482
  else:
 
484
 
485
  for code in all_code_snippets:
486
  try:
 
487
  result = self.execute(
488
  code,
489
  messages=messages,
 
491
  runtime_class=self.runtime_class,
492
  answer_symbol=self.answer_symbol,
493
  answer_expr=self.answer_expr,
 
494
  )
495
  all_exec_results.append(result)
496
  except TimeoutError as error:
 
509
 
510
  batch_results = []
511
  for code, (res, report) in zip(all_code_snippets, all_exec_results):
512
+ # Handle results
513
  if isinstance(res, dict):
514
+ # If result contains images, special handling
515
  if 'text' in res:
516
  res['text'] = str(res['text']).strip()
517
  res['text'] = self.truncate(res['text'])
518
  report = str(report).strip()
519
  report = self.truncate(report)
520
  else:
521
+ # Normal text result
522
  res = str(res).strip()
523
  res = self.truncate(res)
524
  report = str(report).strip()
 
526
  batch_results.append((res, report))
527
  return batch_results
528
 
529
+ def reset(self, messages=None):
530
+ """Reset executor state."""
531
+ if self.use_process_isolation and self.persistent_worker:
532
+ self.persistent_worker.reset_runtime(messages, self.runtime_class)
533
 
534
+ def __del__(self):
535
+ """Clean up resources."""
536
+ if self.persistent_worker:
537
+ self.persistent_worker.terminate()