stzhao commited on
Commit
b3f97e9
·
verified ·
1 Parent(s): 52031b8

Create vis_python_exe.py

Browse files
Files changed (1) hide show
  1. vis_python_exe.py +454 -0
vis_python_exe.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import regex
4
+ import pickle
5
+ import traceback
6
+ import copy
7
+ import datetime
8
+ import dateutil.relativedelta
9
+ import multiprocess
10
+ from multiprocess import Pool
11
+ from typing import Any, Dict, Optional, Tuple, List, Union
12
+ from pebble import ProcessPool
13
+ from tqdm import tqdm
14
+ from concurrent.futures import TimeoutError
15
+ from functools import partial
16
+ from timeout_decorator import timeout
17
+ from contextlib import redirect_stdout
18
+ import base64
19
+ from io import BytesIO
20
+ from PIL import Image
21
+ import pdb
22
+
23
+ def encode_image(image_path):
24
+ with open(image_path, "rb") as image_file:
25
+ return base64.b64encode(image_file.read()).decode('utf-8')
26
+
27
+ def base64_to_image(
28
+ base64_str: str,
29
+ remove_prefix: bool = True,
30
+ convert_mode: Optional[str] = "RGB"
31
+ ) -> Union[Image.Image, None]:
32
+ """
33
+ 将Base64编码的图片字符串转换为PIL Image对象
34
+
35
+ Args:
36
+ base64_str: Base64编码的图片字符串(可带data:前缀)
37
+ remove_prefix: 是否自动去除"data:image/..."前缀(默认True)
38
+ convert_mode: 转换为指定模式(如"RGB"/"RGBA",None表示不转换)
39
+
40
+ Returns:
41
+ PIL.Image.Image 对象,解码失败时返回None
42
+
43
+ Examples:
44
+ >>> img = base64_to_image("data:image/png;base64,iVBORw0KGg...")
45
+ >>> img = base64_to_image("iVBORw0KGg...", remove_prefix=False)
46
+ """
47
+ try:
48
+ # 1. 处理Base64前缀
49
+ if remove_prefix and "," in base64_str:
50
+ base64_str = base64_str.split(",")[1]
51
+
52
+ # 2. 解码Base64
53
+ image_data = base64.b64decode(base64_str)
54
+
55
+ # 3. 转换为PIL Image
56
+ image = Image.open(BytesIO(image_data))
57
+
58
+ # 4. 可选模式转换
59
+ if convert_mode:
60
+ image = image.convert(convert_mode)
61
+
62
+ return image
63
+
64
+ except (base64.binascii.Error, OSError, Exception) as e:
65
+ print(f"Base64解码失败: {str(e)}")
66
+ return None
67
+
68
+
69
+ class GenericRuntime:
70
+ GLOBAL_DICT = {}
71
+ LOCAL_DICT = None
72
+ HEADERS = []
73
+
74
+ def __init__(self):
75
+ self._global_vars = copy.copy(self.GLOBAL_DICT)
76
+ self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
77
+ self._captured_figures = []
78
+
79
+ for c in self.HEADERS:
80
+ self.exec_code(c)
81
+
82
+ def exec_code(self, code_piece: str) -> None:
83
+ if regex.search(r"(\s|^)?input\(", code_piece) or regex.search(
84
+ r"(\s|^)?os.system\(", code_piece
85
+ ):
86
+ raise RuntimeError("Forbidden function calls detected")
87
+
88
+
89
+
90
+ # 检测并修改plt.show()调用
91
+ if "plt.show()" in code_piece:
92
+ modified_code = code_piece.replace("plt.show()", """
93
+ # 捕获当前图像
94
+ buf = io.BytesIO()
95
+ plt.savefig(buf, format='png')
96
+ buf.seek(0)
97
+ _captured_image = base64.b64encode(buf.read()).decode('utf-8')
98
+ _captured_figures.append(_captured_image)
99
+ plt.close()
100
+ """)
101
+ # 确保_captured_figures变量存在
102
+ if "_captured_figures" not in self._global_vars:
103
+ self._global_vars["_captured_figures"] = []
104
+
105
+ exec(modified_code, self._global_vars)
106
+ else:
107
+ print("###################################### I am excuting code. ##############################################")
108
+ exec(code_piece, self._global_vars)
109
+
110
+ def eval_code(self, expr: str) -> Any:
111
+ return eval(expr, self._global_vars)
112
+
113
+ def inject(self, var_dict: Dict[str, Any]) -> None:
114
+ for k, v in var_dict.items():
115
+ self._global_vars[k] = v
116
+
117
+ @property
118
+ def answer(self):
119
+ return self._global_vars.get("answer", None)
120
+
121
+ @property
122
+ def captured_figures(self):
123
+ return self._global_vars.get("_captured_figures", [])
124
+
125
+
126
+ class ImageRuntime(GenericRuntime):
127
+ """支持图像处理的运行时环境"""
128
+ GLOBAL_DICT = {} # 不预加载模块,避免序列化问题
129
+ LOCAL_DICT = None
130
+
131
+ HEADERS = [
132
+ "import matplotlib",
133
+ "matplotlib.use('Agg')", # 使用非交互式后端
134
+ "import matplotlib.pyplot as plt",
135
+ "from PIL import Image",
136
+ "import io",
137
+ "import base64",
138
+ "import numpy as np",
139
+ "_captured_figures = []", # 初始化图像捕获列表
140
+ ]
141
+
142
+ def __init__(self, messages):
143
+ print("############################### I am initing image runtime ################################")
144
+ # super().__init__()
145
+ # pdb.set_trace()
146
+
147
+ self._global_vars = copy.copy(self.GLOBAL_DICT)
148
+ self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
149
+ self._captured_figures = []
150
+
151
+ # for c in self.HEADERS:
152
+ # self.exec_code(c)
153
+
154
+ image_var_dict = {}
155
+ image_var_idx = 0
156
+ print("############################### I am initing image runtime ################################")
157
+ for message_item in messages:
158
+ content = message_item['content'] # {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
159
+ for item in content:
160
+ item_type = item['type']
161
+ if item_type == "image_url":
162
+ item_image_url = item['image_url']['url']
163
+ image = base64_to_image(item_image_url)
164
+ image_var_dict[f"image_clue_{image_var_idx}"] = image
165
+ image_var_idx += 1
166
+
167
+ self.inject(image_var_dict)
168
+ print("##################### Initialize ImageRuntime. ##########################")
169
+
170
+
171
+ class DateRuntime(GenericRuntime):
172
+ GLOBAL_DICT = {}
173
+ HEADERS = [
174
+ "import datetime",
175
+ "from dateutil.relativedelta import relativedelta",
176
+ "timedelta = relativedelta"
177
+ ]
178
+
179
+
180
+ class CustomDict(dict):
181
+ def __iter__(self):
182
+ return list(super().__iter__()).__iter__()
183
+
184
+
185
+ class ColorObjectRuntime(GenericRuntime):
186
+ GLOBAL_DICT = {"dict": CustomDict}
187
+
188
+
189
+ class PythonExecutor:
190
+ def __init__(
191
+ self,
192
+ runtime_class=None,
193
+ get_answer_symbol: Optional[str] = None,
194
+ get_answer_expr: Optional[str] = None,
195
+ get_answer_from_stdout: bool = True,
196
+ timeout_length: int = 20,
197
+ ) -> None:
198
+ print(f"#################### When Init PythonExcutor, RunTime typel:, TimeOut Length: {timeout_length} #############################")
199
+ self.runtime_class = runtime_class if runtime_class else ImageRuntime
200
+ print(self.runtime_class)
201
+ self.answer_symbol = get_answer_symbol
202
+ self.answer_expr = get_answer_expr
203
+ self.get_answer_from_stdout = get_answer_from_stdout
204
+ self.pool = Pool(multiprocess.cpu_count())
205
+ self.timeout_length = timeout_length
206
+
207
+ def process_generation_to_code(self, gens: str):
208
+ return [g.split("\n") for g in gens]
209
+
210
+ @staticmethod
211
+ def execute(
212
+ code,
213
+ messages,
214
+ get_answer_from_stdout=True,
215
+ runtime_class=None,
216
+ answer_symbol=None,
217
+ answer_expr=None,
218
+ timeout_length=20,
219
+ ) -> Tuple[Union[str, Dict[str, Any]], str]:
220
+ # print("dome")
221
+ try:
222
+ # 在每个进程中创建新的运行时实例
223
+ print(f"################################################## I am excuting ! #############################################################")
224
+ print(str(messages)[0:500])
225
+ print(str(messages)[-500:])
226
+ print(runtime_class)
227
+ runtime = runtime_class(messages)
228
+ print(f"################################################## I am excuting ! #############################################################")
229
+
230
+ if get_answer_from_stdout:
231
+ program_io = io.StringIO()
232
+ with redirect_stdout(program_io):
233
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code))
234
+ program_io.seek(0)
235
+ result = program_io.read()
236
+ elif answer_symbol:
237
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code))
238
+ result = runtime._global_vars.get(answer_symbol, "")
239
+ elif answer_expr:
240
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code))
241
+ result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
242
+ else:
243
+ if len(code) > 1:
244
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code[:-1]))
245
+ result = timeout(timeout_length)(runtime.eval_code)(code[-1])
246
+ else:
247
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code))
248
+ result = ""
249
+
250
+ # 检查是否有捕获的图像
251
+ captured_figures = runtime._global_vars.get("_captured_figures", [])
252
+ if captured_figures:
253
+ # 如果有文本输出和图像,将它们组合
254
+ if result:
255
+ result = {
256
+ 'text': result,
257
+ 'images': captured_figures
258
+ }
259
+ else:
260
+ result = {'images': captured_figures}
261
+
262
+ report = "Done"
263
+ except Exception as e:
264
+ result = ""
265
+ report = f"Error: {str(e)}\n{traceback.format_exc()}"
266
+
267
+ # 确保结果可序列化
268
+ try:
269
+ pickle.dumps(result)
270
+ except Exception as e:
271
+ result = f"Result serialization error: {str(e)}"
272
+ report = f"Serialization Error: {str(e)}"
273
+
274
+ return result, report
275
+
276
+ def apply(self, code, messages):
277
+ return self.batch_apply([code], messages)[0]
278
+
279
+ @staticmethod
280
+ def truncate(s, max_length=400):
281
+ if isinstance(s, dict):
282
+ # 如果是字典(包含图像),只截断文本部分
283
+ if 'text' in s:
284
+ half = max_length // 2
285
+ if len(s['text']) > max_length:
286
+ s['text'] = s['text'][:half] + "..." + s['text'][-half:]
287
+ return s
288
+ else:
289
+ half = max_length // 2
290
+ if isinstance(s, str) and len(s) > max_length:
291
+ s = s[:half] + "..." + s[-half:]
292
+ return s
293
+
294
+ def batch_apply(self, batch_code, messages):
295
+ all_code_snippets = self.process_generation_to_code(batch_code)
296
+
297
+ timeout_cnt = 0
298
+ all_exec_results = []
299
+ print(f"################################### num of cpu: {os.cpu_count()} ; len of code: {len(all_code_snippets)} ######################################")
300
+ with ProcessPool(
301
+ max_workers=min(len(all_code_snippets), os.cpu_count())
302
+ ) as pool:
303
+ executor = partial(
304
+ self.execute,
305
+ get_answer_from_stdout=self.get_answer_from_stdout,
306
+ runtime_class=self.runtime_class,
307
+ answer_symbol=self.answer_symbol,
308
+ answer_expr=self.answer_expr,
309
+ timeout_length=self.timeout_length,
310
+ )
311
+ future = pool.map(executor, all_code_snippets, [messages], timeout=self.timeout_length)
312
+ iterator = future.result()
313
+
314
+ if len(all_code_snippets) > 100:
315
+ progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
316
+ else:
317
+ progress_bar = None
318
+
319
+ while True:
320
+ try:
321
+ result = next(iterator)
322
+ all_exec_results.append(result)
323
+ except StopIteration:
324
+ break
325
+ except TimeoutError as error:
326
+ print(error)
327
+ all_exec_results.append(("", "Timeout Error"))
328
+ timeout_cnt += 1
329
+ except Exception as error:
330
+ print(f"Error in batch_apply: {error}")
331
+ all_exec_results.append(("", f"Error: {str(error)}"))
332
+ if progress_bar is not None:
333
+ progress_bar.update(1)
334
+
335
+ if progress_bar is not None:
336
+ progress_bar.close()
337
+
338
+ batch_results = []
339
+ for code, (res, report) in zip(all_code_snippets, all_exec_results):
340
+ # 处理结果
341
+ if isinstance(res, dict):
342
+ # 如果结果包含图像,特殊处理
343
+ if 'text' in res:
344
+ res['text'] = str(res['text']).strip()
345
+ res['text'] = self.truncate(res['text'])
346
+ report = str(report).strip()
347
+ report = self.truncate(report)
348
+ else:
349
+ # 普通文本结果
350
+ res = str(res).strip()
351
+ res = self.truncate(res)
352
+ report = str(report).strip()
353
+ report = self.truncate(report)
354
+ batch_results.append((res, report))
355
+ return batch_results
356
+
357
+
358
+ def _test():
359
+ image_path = "/mnt/petrelfs/zhaoshitian/vis_tool_inference_engine/test_data/0.JPG"
360
+ image_base64 = encode_image(image_path)
361
+ messages = [
362
+ {
363
+ "role": "user",
364
+ "content": [{"type": "text", "text": "From the information on that advertising board, what is the type of this shop?"}]
365
+ },
366
+ {
367
+ "role": "user",
368
+ "content": [{"type": "text", "text": "image_clue_0"}] + [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}]
369
+ }
370
+ ]
371
+ # 测试普通计算
372
+ math_code ="""
373
+ a = 1
374
+ b = 2
375
+ c = a + b
376
+ print(c)
377
+ """
378
+
379
+ batch_code = [math_code]
380
+
381
+ executor = PythonExecutor()
382
+ predictions = executor.apply(batch_code[0], messages)
383
+ print("数学计算结果:", predictions)
384
+
385
+ # 测试图像显示
386
+ image_code = """
387
+ import matplotlib.pyplot as plt
388
+ import numpy as np
389
+ from PIL import Image
390
+ import io
391
+
392
+ # 创建一个简单的图像
393
+ x = np.linspace(0, 10, 100)
394
+ y = np.sin(x)
395
+
396
+ plt.figure(figsize=(8, 6))
397
+ plt.plot(x, y, 'r-', linewidth=2)
398
+ plt.title('Sine Wave')
399
+ plt.grid(True)
400
+ plt.show()
401
+
402
+ # 也可以显示一个简单的图像
403
+ # 创建一个彩色渐变图像
404
+ arr = np.zeros((100, 100, 3), dtype=np.uint8)
405
+ for i in range(100):
406
+ for j in range(100):
407
+ arr[i, j, 0] = i # 红色通道
408
+ arr[i, j, 1] = j # 绿色通道
409
+ arr[i, j, 2] = 100 # 蓝色通道
410
+
411
+ img = Image.fromarray(arr)
412
+ plt.figure()
413
+ plt.imshow(img)
414
+ plt.title('Gradient Image')
415
+ plt.show()
416
+
417
+ print("图像生成完成")
418
+ """
419
+
420
+ image_code = """
421
+ import matplotlib.pyplot as plt
422
+ import numpy as np
423
+ from PIL import Image
424
+ import io
425
+
426
+ plt.imshow(image_clue_0)
427
+ plt.title("Original Image - Locate Advertising Board")
428
+ plt.show()
429
+ """
430
+
431
+ image_result = executor.apply(image_code, messages)
432
+ print("\n图像结果类型:", type(image_result[0]))
433
+ if isinstance(image_result[0], dict) and 'images' in image_result[0]:
434
+ print(f"捕获到 {len(image_result[0]['images'])} 个图像")
435
+ print("第一个图像��base64编码前20个字符:", image_result[0]['images'][0][:20])
436
+
437
+ # 可选:保存图像到文件
438
+ for i, img_data in enumerate(image_result[0]['images']):
439
+ img_bytes = base64.b64decode(img_data)
440
+ with open(f"captured_image_{i}.png", "wb") as f:
441
+ f.write(img_bytes)
442
+ print(f"图像已保存为 captured_image_{i}.png")
443
+
444
+ if 'text' in image_result[0]:
445
+ print("文本输出:", image_result[0]['text'])
446
+ else:
447
+ print("未捕获到图像")
448
+ print("结果:", image_result[0])
449
+
450
+ print("\n执行状态:", image_result[1])
451
+
452
+
453
+ if __name__ == "__main__":
454
+ _test()