stzhao commited on
Commit
faf2a20
·
verified ·
1 Parent(s): 276916a

Update shared_vis_python_exe.py

Browse files
Files changed (1) hide show
  1. shared_vis_python_exe.py +508 -135
shared_vis_python_exe.py CHANGED
@@ -1,20 +1,192 @@
1
- import threading
2
- from typing import Dict, Any, List, Tuple, Optional, Union
3
  import io
4
- from contextlib import redirect_stdout
 
 
 
 
 
 
 
 
 
 
 
 
5
  from timeout_decorator import timeout
 
6
  import base64
 
7
  from PIL import Image
8
- from vis_python_exe import PythonExecutor, GenericRuntime
9
 
10
- class SharedRuntimeExecutor(PythonExecutor):
 
 
 
 
 
 
 
 
11
  """
12
- 支持变量共享的Python执行器,增强特性:
13
- 1. 当 var_whitelist="RETAIN_ALL_VARS" 时保留所有变量
14
- 2. 默认模式仅保留系统必要变量和白名单变量
15
- 3. 线程安全的运行时管理
 
 
 
 
 
 
 
 
 
16
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def __init__(
19
  self,
20
  runtime_class=None,
@@ -22,139 +194,340 @@ class SharedRuntimeExecutor(PythonExecutor):
22
  get_answer_expr: Optional[str] = None,
23
  get_answer_from_stdout: bool = True,
24
  timeout_length: int = 20,
25
- var_whitelist: Union[List[str], str, None] = None,
26
- ):
27
- """
28
- Args:
29
- var_whitelist:
30
- - 列表: 保留指定前缀的变量
31
- - "RETAIN_ALL_VARS": 保留所有变量
32
- - None: 仅保留系统变量
33
- """
34
- super().__init__(
35
- runtime_class=runtime_class,
36
- get_answer_symbol=get_answer_symbol,
37
- get_answer_expr=get_answer_expr,
38
- get_answer_from_stdout=get_answer_from_stdout,
39
- timeout_length=timeout_length,
40
- )
41
-
42
- # 变量保留策略
43
- self.retain_all_vars = (var_whitelist == "RETAIN_ALL_VARS")
44
- self.var_whitelist = [] if self.retain_all_vars else (var_whitelist or [])
45
-
46
- # 确保系统必要变量
47
- if '_captured_figures' not in self.var_whitelist:
48
- self.var_whitelist.append('_captured_figures')
49
-
50
- # 线程安全运行时存储
51
- self._runtime_pool: Dict[str, GenericRuntime] = {}
52
- self._lock = threading.Lock()
53
 
54
- def apply(self, code: str, messages: List[Dict], session_id: str = "default") -> Tuple[Any, str]:
55
- """执行代码并保持会话状态"""
56
- runtime = self._get_runtime(session_id, messages)
57
-
58
- try:
59
- # 执行代码
60
- result, report = self._execute_shared(code, runtime)
61
-
62
- # 清理变量(保留策略在此生效)
63
- self._clean_runtime_vars(runtime)
64
-
65
- return result, report
66
-
67
- except Exception as e:
68
- return None, f"Execution failed: {str(e)}"
69
-
70
- def _get_runtime(self, session_id: str, messages: List[Dict]) -> GenericRuntime:
71
- """线程安全地获取运行时实例"""
72
- with self._lock:
73
- if session_id not in self._runtime_pool:
74
- self._runtime_pool[session_id] = self.runtime_class(messages)
75
- return self._runtime_pool[session_id]
76
-
77
- def _execute_shared(self, code: str, runtime: GenericRuntime) -> Tuple[Any, str]:
78
- """使用共享运行时执行代码"""
79
- code_lines = [line for line in code.split('\n') if line.strip()]
 
 
 
 
80
 
81
- try:
82
- if self.get_answer_from_stdout:
83
- program_io = io.StringIO()
84
- with redirect_stdout(program_io):
85
- timeout(self.timeout_length)(runtime.exec_code)("\n".join(code_lines))
86
- program_io.seek(0)
87
- result = program_io.read()
88
- elif self.answer_symbol:
89
- timeout(self.timeout_length)(runtime.exec_code)("\n".join(code_lines))
90
- result = runtime._global_vars.get(self.answer_symbol, "")
91
- elif self.answer_expr:
92
- timeout(self.timeout_length)(runtime.exec_code)("\n".join(code_lines))
93
- result = timeout(self.timeout_length)(runtime.eval_code)(self.answer_expr)
 
 
 
94
  else:
95
- if len(code_lines) > 1:
96
- timeout(self.timeout_length)(runtime.exec_code)("\n".join(code_lines[:-1]))
97
- result = timeout(self.timeout_length)(runtime.eval_code)(code_lines[-1])
98
- else:
99
- timeout(self.timeout_length)(runtime.exec_code)("\n".join(code_lines))
100
- result = ""
101
-
102
- # 处理捕获的图像
103
- captured_figures = runtime._global_vars.get("_captured_figures", [])
104
- if captured_figures:
105
  result = {
106
- 'text': str(result).strip(),
107
  'images': captured_figures
108
  }
109
-
110
- return result, "Success"
111
-
 
 
 
 
 
 
 
 
112
  except Exception as e:
113
- return None, f"Error: {str(e)}"
114
-
115
- def _clean_runtime_vars(self, runtime: GenericRuntime):
116
- """实现变量保留策略"""
117
- if self.retain_all_vars:
118
- # 全保留模式:保留所有非系统变量
119
- persistent_vars = {
120
- k: self._serialize_var(v)
121
- for k, v in runtime._global_vars.items()
122
- if not k.startswith('_sys_') # 示例:排除真正的系统变量
123
- }
 
 
 
 
 
 
124
  else:
125
- # 正常模式:按白名单保留
126
- persistent_vars = {
127
- k: self._serialize_var(v)
128
- for k, v in runtime._global_vars.items()
129
- if (
130
- k.startswith('image_clue_') or # 保留注入的图像
131
- any(k.startswith(p) for p in self.var_whitelist) # 用户白名单
132
- )
133
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- # 重建变量空间
136
- runtime._global_vars.clear()
137
- runtime._global_vars.update(persistent_vars)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- # 确保必要的系统变量存在
140
- runtime._global_vars.setdefault('_captured_figures', [])
141
-
142
- def _serialize_var(self, var_value: Any) -> Any:
143
- """处理特殊对象的序列化"""
144
- if isinstance(var_value, Image.Image):
145
- # PIL图像转为base64
146
- buf = io.BytesIO()
147
- var_value.save(buf, format='PNG')
148
- return base64.b64encode(buf.getvalue()).decode('utf-8')
149
- return var_value
150
-
151
- def cleanup_session(self, session_id: str):
152
- """清理指定会话"""
153
- with self._lock:
154
- if session_id in self._runtime_pool:
155
- del self._runtime_pool[session_id]
156
-
157
- def cleanup_all(self):
158
- """清理所有会话"""
159
- with self._lock:
160
- self._runtime_pool.clear()
 
1
+ import os
 
2
  import io
3
+ import regex
4
+ import pickle
5
+ import traceback
6
+ import copy
7
+ import datetime
8
+ import dateutil.relativedelta
9
+ import multiprocess
10
+ from multiprocess import Pool
11
+ from typing import Any, Dict, Optional, Tuple, List, Union
12
+ from pebble import ProcessPool
13
+ from tqdm import tqdm
14
+ from concurrent.futures import TimeoutError
15
+ from functools import partial
16
  from timeout_decorator import timeout
17
+ from contextlib import redirect_stdout
18
  import base64
19
+ from io import BytesIO
20
  from PIL import Image
21
+ import pdb
22
 
23
+ def encode_image(image_path):
24
+ with open(image_path, "rb") as image_file:
25
+ return base64.b64encode(image_file.read()).decode('utf-8')
26
+
27
+ def base64_to_image(
28
+ base64_str: str,
29
+ remove_prefix: bool = True,
30
+ convert_mode: Optional[str] = "RGB"
31
+ ) -> Union[Image.Image, None]:
32
  """
33
+ 将Base64编码的图片字符串转换为PIL Image对象
34
+
35
+ Args:
36
+ base64_str: Base64编码的图片字符串(可带data:前缀)
37
+ remove_prefix: 是否自动去除"...")
45
+ >>> img = base64_to_image("iVBORw0KGg...", remove_prefix=False)
46
  """
47
+ try:
48
+ # 1. 处理Base64前缀
49
+ if remove_prefix and "," in base64_str:
50
+ base64_str = base64_str.split(",")[1]
51
+
52
+ # 2. 解码Base64
53
+ image_data = base64.b64decode(base64_str)
54
+
55
+ # 3. 转换为PIL Image
56
+ image = Image.open(BytesIO(image_data))
57
+
58
+ # 4. 可选模式转换
59
+ if convert_mode:
60
+ image = image.convert(convert_mode)
61
+
62
+ return image
63
+
64
+ except (base64.binascii.Error, OSError, Exception) as e:
65
+ print(f"Base64解码失败: {str(e)}")
66
+ return None
67
+
68
+
69
+ class GenericRuntime:
70
+ GLOBAL_DICT = {}
71
+ LOCAL_DICT = None
72
+ HEADERS = []
73
+
74
+ def __init__(self):
75
+ self._global_vars = copy.copy(self.GLOBAL_DICT)
76
+ self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
77
+ self._captured_figures = []
78
+
79
+ for c in self.HEADERS:
80
+ self.exec_code(c)
81
+
82
+ def exec_code(self, code_piece: str) -> None:
83
+ if regex.search(r"(\s|^)?input\(", code_piece) or regex.search(
84
+ r"(\s|^)?os.system\(", code_piece
85
+ ):
86
+ raise RuntimeError("Forbidden function calls detected")
87
+
88
+
89
+
90
+ # 检测并修改plt.show()调用
91
+ if "plt.show()" in code_piece:
92
+ modified_code = code_piece.replace("plt.show()", """
93
+ # 捕获当前图像
94
+ buf = io.BytesIO()
95
+ plt.savefig(buf, format='png')
96
+ buf.seek(0)
97
+ _captured_image = base64.b64encode(buf.read()).decode('utf-8')
98
+ _captured_figures.append(_captured_image)
99
+ plt.close()
100
+ """)
101
+ # 确保_captured_figures变量存在
102
+ if "_captured_figures" not in self._global_vars:
103
+ self._global_vars["_captured_figures"] = []
104
+
105
+ exec(modified_code, self._global_vars)
106
+ else:
107
+ print("###################################### I am excuting code. ##############################################")
108
+ exec(code_piece, self._global_vars)
109
+
110
+ def eval_code(self, expr: str) -> Any:
111
+ return eval(expr, self._global_vars)
112
+
113
+ def inject(self, var_dict: Dict[str, Any]) -> None:
114
+ for k, v in var_dict.items():
115
+ self._global_vars[k] = v
116
+
117
+ @property
118
+ def answer(self):
119
+ return self._global_vars.get("answer", None)
120
+
121
+ @property
122
+ def captured_figures(self):
123
+ return self._global_vars.get("_captured_figures", [])
124
+
125
+
126
+ class ImageRuntime(GenericRuntime):
127
+ # """支持图像处理的运行时环境"""
128
+ # GLOBAL_DICT = {} # 不预加载模块,避免序列化问题
129
+ # LOCAL_DICT = None
130
+
131
+ HEADERS = [
132
+ "import matplotlib",
133
+ "matplotlib.use('Agg')", # 使用非交互式后端
134
+ "import matplotlib.pyplot as plt",
135
+ "from PIL import Image",
136
+ "import io",
137
+ "import base64",
138
+ "import numpy as np",
139
+ "_captured_figures = []", # 初始化图像捕获列表
140
+ ]
141
+
142
+ def __init__(self, messages):
143
+ print("############################### I am initing image runtime ################################")
144
+ super().__init__()
145
+ # pdb.set_trace()
146
+
147
+ self._global_vars = copy.copy(self.GLOBAL_DICT)
148
+ self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
149
+ self._captured_figures = []
150
 
151
+ for c in self.HEADERS:
152
+ self.exec_code(c)
153
+
154
+ image_var_dict = {}
155
+ image_var_idx = 0
156
+ print("############################### I am initing image runtime ################################")
157
+ for message_item in messages:
158
+ content = message_item['content'] # {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
159
+ for item in content:
160
+ item_type = item['type']
161
+ if item_type == "image_url":
162
+ item_image_url = item['image_url']['url']
163
+ image = base64_to_image(item_image_url)
164
+ image_var_dict[f"image_clue_{image_var_idx}"] = image
165
+ image_var_idx += 1
166
+
167
+ self.inject(image_var_dict)
168
+ print("##################### Initialize ImageRuntime. ##########################")
169
+
170
+
171
+ class DateRuntime(GenericRuntime):
172
+ GLOBAL_DICT = {}
173
+ HEADERS = [
174
+ "import datetime",
175
+ "from dateutil.relativedelta import relativedelta",
176
+ "timedelta = relativedelta"
177
+ ]
178
+
179
+
180
+ class CustomDict(dict):
181
+ def __iter__(self):
182
+ return list(super().__iter__()).__iter__()
183
+
184
+
185
+ class ColorObjectRuntime(GenericRuntime):
186
+ GLOBAL_DICT = {"dict": CustomDict}
187
+
188
+
189
+ class PythonExecutor:
190
  def __init__(
191
  self,
192
  runtime_class=None,
 
194
  get_answer_expr: Optional[str] = None,
195
  get_answer_from_stdout: bool = True,
196
  timeout_length: int = 20,
197
+ ) -> None:
198
+ print(f"#################### When Init PythonExcutor, RunTime typel:, TimeOut Length: {timeout_length} #############################")
199
+ self.runtime_class = runtime_class if runtime_class else ImageRuntime
200
+ print(self.runtime_class)
201
+ self.answer_symbol = get_answer_symbol
202
+ self.answer_expr = get_answer_expr
203
+ self.get_answer_from_stdout = get_answer_from_stdout
204
+ self.pool = Pool(multiprocess.cpu_count())
205
+ self.timeout_length = timeout_length
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
+ # Create a persistent runtime instance if messages are provided
208
+ self.persistent_runtime = None
209
+ # if messages:
210
+ # self.persistent_runtime = self.runtime_class(messages)
211
+
212
+ def process_generation_to_code(self, gens: str):
213
+ return [g.split("\n") for g in gens]
214
+
215
+ # @staticmethod
216
+ def execute(
217
+ self,
218
+ code,
219
+ messages,
220
+ get_answer_from_stdout=True,
221
+ runtime_class=None,
222
+ # run_time_instance=None,
223
+ answer_symbol=None,
224
+ answer_expr=None,
225
+ timeout_length=20,
226
+ ) -> Tuple[Union[str, Dict[str, Any]], str]:
227
+ # print("dome")
228
+ # try:
229
+ # 在每个进程中创建新的运行时实例
230
+ print(f"################################################## I am excuting ! #############################################################")
231
+ print(str(messages)[0:500])
232
+ print(str(messages)[-500:])
233
+ print(runtime_class)
234
+ # runtime = runtime_class(messages)
235
+ runtime = self.persistent_runtime
236
+ print(f"################################################## I am excuting ! #############################################################")
237
 
238
+ if get_answer_from_stdout:
239
+ program_io = io.StringIO()
240
+ with redirect_stdout(program_io):
241
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code))
242
+ program_io.seek(0)
243
+ result = program_io.read()
244
+ elif answer_symbol:
245
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code))
246
+ result = runtime._global_vars.get(answer_symbol, "")
247
+ elif answer_expr:
248
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code))
249
+ result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
250
+ else:
251
+ if len(code) > 1:
252
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code[:-1]))
253
+ result = timeout(timeout_length)(runtime.eval_code)(code[-1])
254
  else:
255
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code))
256
+ result = ""
257
+
258
+ # 检查是否有捕获的图像
259
+ captured_figures = runtime._global_vars.get("_captured_figures", [])
260
+ if captured_figures:
261
+ # 如果有文本输出和图像,将它们组合
262
+ if result:
 
 
263
  result = {
264
+ 'text': result,
265
  'images': captured_figures
266
  }
267
+ else:
268
+ result = {'images': captured_figures}
269
+
270
+ report = "Done"
271
+ # except Exception as e:
272
+ # result = ""
273
+ # report = f"Error: {str(e)}\n{traceback.format_exc()}"
274
+
275
+ # 确保结果可序列化
276
+ try:
277
+ pickle.dumps(result)
278
  except Exception as e:
279
+ result = f"Result serialization error: {str(e)}"
280
+ report = f"Serialization Error: {str(e)}"
281
+
282
+ return result, report
283
+
284
+ def apply(self, code, messages):
285
+ return self.batch_apply([code], messages)[0]
286
+
287
+ @staticmethod
288
+ def truncate(s, max_length=400):
289
+ if isinstance(s, dict):
290
+ # 如果是字典(包含图像),只截断文本部分
291
+ if 'text' in s:
292
+ half = max_length // 2
293
+ if len(s['text']) > max_length:
294
+ s['text'] = s['text'][:half] + "..." + s['text'][-half:]
295
+ return s
296
  else:
297
+ half = max_length // 2
298
+ if isinstance(s, str) and len(s) > max_length:
299
+ s = s[:half] + "..." + s[-half:]
300
+ return s
301
+
302
+ def update_persistent_runtime_with_messages():
303
+ pass
304
+
305
+ def get_persistent_runtime(self):
306
+ return self.persistent_runtime
307
+
308
+ # def batch_apply(self, batch_code, messages):
309
+ # if not self.persistent_runtime and messages:
310
+ # self.persistent_runtime = self.runtime_class(messages)
311
+ # all_code_snippets = self.process_generation_to_code(batch_code)
312
+
313
+ # timeout_cnt = 0
314
+ # all_exec_results = []
315
+ # print(f"################################### num of cpu: {os.cpu_count()} ; len of code: {len(all_code_snippets)} ######################################")
316
+ # with ProcessPool(
317
+ # max_workers=min(len(all_code_snippets), os.cpu_count())
318
+ # ) as pool:
319
+ # executor = partial(
320
+ # self.execute,
321
+ # get_answer_from_stdout=self.get_answer_from_stdout,
322
+ # runtime_class=self.runtime_class,
323
+ # # run_time_instance=self.persistent_runtime,
324
+ # answer_symbol=self.answer_symbol,
325
+ # answer_expr=self.answer_expr,
326
+ # timeout_length=self.timeout_length,
327
+ # )
328
+ # future = pool.map(executor, all_code_snippets, [messages], timeout=self.timeout_length)
329
+ # iterator = future.result()
330
+
331
+ # if len(all_code_snippets) > 100:
332
+ # progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
333
+ # else:
334
+ # progress_bar = None
335
+
336
+ # while True:
337
+ # try:
338
+ # result = next(iterator)
339
+ # all_exec_results.append(result)
340
+ # except StopIteration:
341
+ # break
342
+ # except TimeoutError as error:
343
+ # print(error)
344
+ # all_exec_results.append(("", "Timeout Error"))
345
+ # timeout_cnt += 1
346
+ # except Exception as error:
347
+ # print(f"Error in batch_apply: {error}")
348
+ # all_exec_results.append(("", f"Error: {str(error)}"))
349
+ # if progress_bar is not None:
350
+ # progress_bar.update(1)
351
+
352
+ # if progress_bar is not None:
353
+ # progress_bar.close()
354
+
355
+ # batch_results = []
356
+ # for code, (res, report) in zip(all_code_snippets, all_exec_results):
357
+ # # 处理结果
358
+ # if isinstance(res, dict):
359
+ # # 如果结果包含图像,特殊处理
360
+ # if 'text' in res:
361
+ # res['text'] = str(res['text']).strip()
362
+ # res['text'] = self.truncate(res['text'])
363
+ # report = str(report).strip()
364
+ # report = self.truncate(report)
365
+ # else:
366
+ # # 普通文本结果
367
+ # res = str(res).strip()
368
+ # res = self.truncate(res)
369
+ # report = str(report).strip()
370
+ # report = self.truncate(report)
371
+ # batch_results.append((res, report))
372
+ # return batch_results
373
+
374
+ def batch_apply(self, batch_code, messages):
375
+ if not self.persistent_runtime and messages:
376
+ self.persistent_runtime = self.runtime_class(messages)
377
+ all_code_snippets = self.process_generation_to_code(batch_code)
378
+
379
+ timeout_cnt = 0
380
+ all_exec_results = []
381
 
382
+ print(f"################################### num of cpu: {os.cpu_count()} ; len of code: {len(all_code_snippets)} ######################################")
383
+
384
+ # 去掉 ProcessPool,改为单进程顺序执行
385
+ if len(all_code_snippets) > 100:
386
+ progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
387
+ else:
388
+ progress_bar = None
389
+
390
+ for code in all_code_snippets:
391
+ try:
392
+ # 直接调用 self.execute,而不是用 ProcessPool
393
+ result = self.execute(
394
+ code,
395
+ messages=messages,
396
+ get_answer_from_stdout=self.get_answer_from_stdout,
397
+ runtime_class=self.runtime_class,
398
+ answer_symbol=self.answer_symbol,
399
+ answer_expr=self.answer_expr,
400
+ timeout_length=self.timeout_length,
401
+ )
402
+ all_exec_results.append(result)
403
+ except TimeoutError as error:
404
+ print(error)
405
+ all_exec_results.append(("", "Timeout Error"))
406
+ timeout_cnt += 1
407
+ except Exception as error:
408
+ print(f"Error in batch_apply: {error}")
409
+ all_exec_results.append(("", f"Error: {str(error)}"))
410
+
411
+ if progress_bar is not None:
412
+ progress_bar.update(1)
413
+
414
+ if progress_bar is not None:
415
+ progress_bar.close()
416
+
417
+ batch_results = []
418
+ for code, (res, report) in zip(all_code_snippets, all_exec_results):
419
+ # 处理结果
420
+ if isinstance(res, dict):
421
+ # 如果结果包含图像,特殊处理
422
+ if 'text' in res:
423
+ res['text'] = str(res['text']).strip()
424
+ res['text'] = self.truncate(res['text'])
425
+ report = str(report).strip()
426
+ report = self.truncate(report)
427
+ else:
428
+ # 普通文本结果
429
+ res = str(res).strip()
430
+ res = self.truncate(res)
431
+ report = str(report).strip()
432
+ report = self.truncate(report)
433
+ batch_results.append((res, report))
434
+ return batch_results
435
+
436
+
437
+ def _test():
438
+ image_path = "/mnt/petrelfs/zhaoshitian/vis_tool_inference_engine/test_data/0.JPG"
439
+ image_base64 = encode_image(image_path)
440
+ messages = [
441
+ {
442
+ "role": "user",
443
+ "content": [{"type": "text", "text": "From the information on that advertising board, what is the type of this shop?"}]
444
+ },
445
+ {
446
+ "role": "user",
447
+ "content": [{"type": "text", "text": "image_clue_0"}] + [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}]
448
+ }
449
+ ]
450
+ # 测试普通计算
451
+ math_code ="""
452
+ a = 1
453
+ b = 2
454
+ c = a + b
455
+ print(c)
456
+ """
457
+
458
+ batch_code = [math_code]
459
+
460
+ executor = PythonExecutor()
461
+ predictions = executor.apply(batch_code[0], messages)
462
+ print("数学计算结果:", predictions)
463
+
464
+ # 测试图像显示
465
+ image_code = """
466
+ import matplotlib.pyplot as plt
467
+ import numpy as np
468
+ from PIL import Image
469
+ import io
470
+
471
+ # 创建一个简单的图像
472
+ x = np.linspace(0, 10, 100)
473
+ y = np.sin(x)
474
+
475
+ plt.figure(figsize=(8, 6))
476
+ plt.plot(x, y, 'r-', linewidth=2)
477
+ plt.title('Sine Wave')
478
+ plt.grid(True)
479
+ plt.show()
480
+
481
+ # 也可以显示一个简单的图像
482
+ # 创建一个彩色渐变图像
483
+ arr = np.zeros((100, 100, 3), dtype=np.uint8)
484
+ for i in range(100):
485
+ for j in range(100):
486
+ arr[i, j, 0] = i # 红色通道
487
+ arr[i, j, 1] = j # 绿色通道
488
+ arr[i, j, 2] = 100 # 蓝色通道
489
+
490
+ img = Image.fromarray(arr)
491
+ plt.figure()
492
+ plt.imshow(img)
493
+ plt.title('Gradient Image')
494
+ plt.show()
495
+
496
+ print("图像生成完成")
497
+ """
498
+
499
+ image_code = """
500
+ import matplotlib.pyplot as plt
501
+ import numpy as np
502
+ from PIL import Image
503
+ import io
504
+
505
+ plt.imshow(image_clue_0)
506
+ plt.title("Original Image - Locate Advertising Board")
507
+ plt.show()
508
+ """
509
+
510
+ image_result = executor.apply(image_code, messages)
511
+ print("\n图像结果类型:", type(image_result[0]))
512
+ if isinstance(image_result[0], dict) and 'images' in image_result[0]:
513
+ print(f"捕获到 {len(image_result[0]['images'])} 个图像")
514
+ print("第一个图像的base64编码前20个字符:", image_result[0]['images'][0][:20])
515
 
516
+ # 可选:保存图像到文件
517
+ for i, img_data in enumerate(image_result[0]['images']):
518
+ img_bytes = base64.b64decode(img_data)
519
+ with open(f"captured_image_{i}.png", "wb") as f:
520
+ f.write(img_bytes)
521
+ print(f"图像已保存为 captured_image_{i}.png")
522
+
523
+ if 'text' in image_result[0]:
524
+ print("文本输出:", image_result[0]['text'])
525
+ else:
526
+ print("未捕获到图像")
527
+ print("结果:", image_result[0])
528
+
529
+ print("\n执行状态:", image_result[1])
530
+
531
+
532
+ if __name__ == "__main__":
533
+ _test()