|
import os |
|
import io |
|
import regex |
|
import pickle |
|
import traceback |
|
import copy |
|
import datetime |
|
import dateutil.relativedelta |
|
import multiprocess |
|
from multiprocess import Pool |
|
from typing import Any, Dict, Optional, Tuple, List, Union |
|
from pebble import ProcessPool |
|
from tqdm import tqdm |
|
from concurrent.futures import TimeoutError |
|
from functools import partial |
|
from timeout_decorator import timeout |
|
from contextlib import redirect_stdout |
|
import base64 |
|
from io import BytesIO |
|
from PIL import Image |
|
import pdb |
|
|
|
def encode_image(image_path):
    """Return the content of the file at ``image_path`` as a Base64 string.

    Args:
        image_path: Path of the file to read; it is opened in binary mode.

    Returns:
        The Base64 encoding of the file's bytes, decoded to ``str`` as UTF-8.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
|
|
|
def base64_to_image(
    base64_str: str,
    remove_prefix: bool = True,
    convert_mode: Optional[str] = "RGB"
) -> Union[Image.Image, None]:
    """Convert a Base64-encoded image string into a PIL ``Image`` object.

    Args:
        base64_str: Base64-encoded image data (may carry a ``data:`` prefix).
        remove_prefix: When true and the string contains a comma, only the
            part between the first and second comma is decoded, which drops
            a leading ``data:image/...;base64`` header (default ``True``).
        convert_mode: Mode to convert the image to (e.g. ``"RGB"`` or
            ``"RGBA"``); ``None`` leaves the image unconverted.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` when anything fails.

    Examples:
        >>> img = base64_to_image("data:image/png;base64,iVBORw0KGg...")
        >>> img = base64_to_image("iVBORw0KGg...", remove_prefix=False)
    """
    try:
        payload = base64_str
        if remove_prefix and "," in payload:
            payload = payload.split(",")[1]

        decoded = Image.open(BytesIO(base64.b64decode(payload)))
        return decoded.convert(convert_mode) if convert_mode else decoded
    except Exception as e:
        # Every failure (bad Base64, unreadable image data, wrong argument
        # type, ...) is reported on stdout and turned into a ``None`` result.
        print(f"Base64解码失败: {str(e)}")
        return None
|
|
|
|
|
class GenericRuntime:
    """Execution namespace for code snippets run through ``exec``/``eval``.

    Each instance owns a private globals dictionary (``_global_vars``) that
    starts as a shallow copy of ``GLOBAL_DICT``.  Every entry of ``HEADERS``
    is executed in that namespace when the instance is created.  Calls to
    ``plt.show()`` inside executed code are redirected so that the current
    figure is stored, as a Base64-encoded PNG string, in the namespace's
    ``_captured_figures`` list.
    """

    # Initial globals, shallow-copied into every new instance.
    GLOBAL_DICT = {}
    # Optional initial locals; copied into ``_local_vars`` when truthy.  No
    # method of this class reads ``_local_vars`` afterwards.
    LOCAL_DICT = None
    # Code pieces executed in order by ``__init__``.
    HEADERS = []

    # Name and source of the helper that stands in for ``plt.show()``.  The
    # helper is compiled inside the runtime namespace, so ``plt`` and
    # ``_captured_figures`` are resolved there at call time.
    _SHOW_HOOK_NAME = "_capture_plt_figure"
    _SHOW_HOOK_SOURCE = "\n".join([
        "def _capture_plt_figure():",
        "    import base64, io",
        "    buf = io.BytesIO()",
        "    plt.savefig(buf, format='png')",
        "    buf.seek(0)",
        "    _captured_figures.append(base64.b64encode(buf.read()).decode('utf-8'))",
        "    plt.close()",
        "",
    ])

    def __init__(self):
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
        # Kept for compatibility; the ``captured_figures`` property reads the
        # list stored in ``_global_vars`` instead of this attribute.
        self._captured_figures = []

        for c in self.HEADERS:
            self.exec_code(c)

    @classmethod
    def _instrument_show_calls(cls, code_piece: str) -> str:
        """Return ``code_piece`` with each ``plt.show()`` replaced by the capture helper call."""
        # A single call expression is valid wherever ``plt.show()`` was valid,
        # including indented blocks and one-line ``if x: plt.show()`` forms.
        return code_piece.replace("plt.show()", f"{cls._SHOW_HOOK_NAME}()")

    def exec_code(self, code_piece: str) -> None:
        """Execute ``code_piece`` in this runtime's global namespace.

        Args:
            code_piece: Python source to execute.

        Raises:
            RuntimeError: If the source matches one of the forbidden-call
                patterns (``input(`` / ``os.system(``).
            Exception: Anything raised by the executed code propagates.
        """
        import sys  # local import: only needed for the diagnostic banner

        # NOTE(review): the leading group in both patterns is optional, so
        # they match the call text anywhere in the snippet (e.g. also
        # ``my_input(``), and the ``.`` in the second pattern is unescaped.
        # Left unchanged to keep the set of rejected snippets identical.
        if regex.search(r"(\s|^)?input\(", code_piece) or regex.search(
            r"(\s|^)?os.system\(", code_piece
        ):
            raise RuntimeError("Forbidden function calls detected")

        if "plt.show()" in code_piece:
            self._global_vars.setdefault("_captured_figures", [])
            # Install the helper lazily so it is present even if a subclass
            # rebuilt ``_global_vars`` after ``__init__`` ran.
            if self._SHOW_HOOK_NAME not in self._global_vars:
                exec(self._SHOW_HOOK_SOURCE, self._global_vars)
            exec(self._instrument_show_calls(code_piece), self._global_vars)
        else:
            # Diagnostics go to stderr: callers capture stdout as the
            # snippet's answer, and the banner must not end up in it.
            print(
                "###################################### I am excuting code. ##############################################",
                file=sys.stderr,
            )
            exec(code_piece, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        """Evaluate ``expr`` in this runtime's global namespace and return the value."""
        return eval(expr, self._global_vars)

    def inject(self, var_dict: Dict[str, Any]) -> None:
        """Bind every key of ``var_dict`` to its value in the runtime globals."""
        for k, v in var_dict.items():
            self._global_vars[k] = v

    @property
    def answer(self):
        """Value bound to the global name ``answer``, or ``None`` if unset."""
        return self._global_vars.get("answer", None)

    @property
    def captured_figures(self):
        """The ``_captured_figures`` list in the runtime globals (``[]`` if absent)."""
        return self._global_vars.get("_captured_figures", [])
|
|
|
|
|
class ImageRuntime(GenericRuntime):
    """Runtime preloaded with plotting/imaging imports and the chat's images.

    The header snippets import matplotlib (with the ``Agg`` backend selected),
    PIL, ``io``, ``base64`` and numpy, and create an empty
    ``_captured_figures`` list.  Every ``image_url`` item found in the given
    messages is decoded and injected as ``image_clue_0``, ``image_clue_1``,
    ... in order of appearance.
    """

    HEADERS = [
        "import matplotlib",
        "matplotlib.use('Agg')",
        "import matplotlib.pyplot as plt",
        "from PIL import Image",
        "import io",
        "import base64",
        "import numpy as np",
        "_captured_figures = []",
    ]

    def __init__(self, messages):
        """Build the namespace and inject the images referenced by ``messages``.

        Args:
            messages: Iterable of message dicts.  A message whose ``content``
                is a list of items is scanned for items of type
                ``"image_url"``; messages whose content is not a list/tuple
                (for example a plain string) are skipped.
        """
        print("############################### I am initing image runtime ################################")
        # The base initializer already copies GLOBAL_DICT and executes HEADERS,
        # so that work is not repeated here.
        super().__init__()

        image_var_dict = {}
        image_var_idx = 0
        print("############################### I am initing image runtime ################################")
        for message_item in messages or []:
            content = message_item.get('content')
            if not isinstance(content, (list, tuple)):
                # Text-only messages may carry a bare string (or nothing);
                # iterating them would fail on ``item['type']``.
                continue
            for item in content:
                if item['type'] != "image_url":
                    continue
                # ``base64_to_image`` returns None on decode failure; the
                # variable is still created so later indices stay aligned.
                image = base64_to_image(item['image_url']['url'])
                image_var_dict[f"image_clue_{image_var_idx}"] = image
                image_var_idx += 1

        self.inject(image_var_dict)
        print("##################### Initialize ImageRuntime. ##########################")
|
|
|
|
|
class DateRuntime(GenericRuntime):
    """Runtime whose namespace is primed for date arithmetic.

    The header snippets import ``datetime`` and bind both ``relativedelta``
    and the alias ``timedelta`` to ``dateutil.relativedelta.relativedelta``.
    """

    # Starts from an empty set of globals.
    GLOBAL_DICT = dict()

    # Executed in order when the runtime is created.
    HEADERS = [
        "import datetime",
        "from dateutil.relativedelta import relativedelta",
        "timedelta = relativedelta",
    ]
|
|
|
|
|
class CustomDict(dict):
    """``dict`` subclass whose iteration walks a snapshot of its keys."""

    def __iter__(self):
        # Copy the keys into a list first and iterate over that copy rather
        # than over the live dictionary.
        keys_snapshot = list(super().__iter__())
        return iter(keys_snapshot)
|
|
|
|
|
class ColorObjectRuntime(GenericRuntime):
    """Runtime in which the global name ``dict`` is bound to ``CustomDict``."""

    GLOBAL_DICT = dict(dict=CustomDict)
|
|
|
|
|
class PythonExecutor:
    """Run generated Python snippets in one persistent runtime and collect results.

    The runtime is created lazily by ``batch_apply`` (from ``runtime_class``
    and the first non-empty ``messages``) and then reused, so variables
    defined by one snippet remain visible to later ones.
    """

    def __init__(
        self,
        runtime_class=None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = True,
        timeout_length: int = 20,
    ) -> None:
        """Store the execution settings.

        Args:
            runtime_class: Class instantiated as ``runtime_class(messages)``
                to build the runtime; defaults to ``ImageRuntime``.
            get_answer_symbol: Global variable name to read the answer from.
            get_answer_expr: Expression evaluated to obtain the answer.
            get_answer_from_stdout: Use the snippet's captured stdout as the
                answer; takes precedence over the two options above.
            timeout_length: Stored on the instance; no code in this class
                applies it.
        """
        print(f"#################### When Init PythonExcutor, RunTime typel:, TimeOut Length: {timeout_length} #############################")
        self.runtime_class = runtime_class if runtime_class else ImageRuntime
        print(self.runtime_class)
        self.answer_symbol = get_answer_symbol
        self.answer_expr = get_answer_expr
        self.get_answer_from_stdout = get_answer_from_stdout
        self.timeout_length = timeout_length

        # Created on the first ``batch_apply`` call that receives messages.
        self.persistent_runtime = None

    def process_generation_to_code(self, gens: List[str]) -> List[List[str]]:
        """Split every snippet in ``gens`` into its list of lines."""
        return [g.split("\n") for g in gens]

    def execute(
        self,
        code,
        messages,
        get_answer_from_stdout=True,
        runtime_class=None,
        answer_symbol=None,
        answer_expr=None,
    ) -> Tuple[Union[str, Dict[str, Any]], str]:
        """Execute one snippet in the persistent runtime.

        Args:
            code: The snippet as a list of source lines.
            messages: Accepted for interface compatibility; not used here.
            get_answer_from_stdout: Return the captured stdout as the answer.
            runtime_class: Only printed; the persistent runtime is used.
            answer_symbol: Global name to read the answer from (used when
                stdout capture is off).
            answer_expr: Expression to evaluate for the answer (used when the
                two options above are off).  With none of the three set, the
                last line of a multi-line snippet is evaluated as the answer.

        Returns:
            ``(result, report)``.  ``result`` is a dict with ``'text'`` and/or
            ``'images'`` when there is text output or captured figures,
            otherwise the raw (falsy) result.  ``report`` is ``"Done"`` or a
            serialization error message.

        Raises:
            RuntimeError: If no persistent runtime has been created yet.
            Exception: Anything raised by the executed snippet propagates.
        """
        print(runtime_class)

        runtime = self.persistent_runtime
        if runtime is None:
            raise RuntimeError(
                "No persistent runtime is available; call apply/batch_apply "
                "with non-empty messages first."
            )

        # Start every execution with an empty figure list so the result only
        # contains figures produced by this snippet.  The name is rebound
        # (not cleared in place) so lists handed out earlier stay untouched.
        if runtime._global_vars.get("_captured_figures"):
            runtime._global_vars["_captured_figures"] = []

        if get_answer_from_stdout:
            program_io = io.StringIO()
            with redirect_stdout(program_io):
                runtime.exec_code("\n".join(code))
            program_io.seek(0)
            result = program_io.read()
        elif answer_symbol:
            runtime.exec_code("\n".join(code))
            result = runtime._global_vars.get(answer_symbol, "")
        elif answer_expr:
            runtime.exec_code("\n".join(code))
            result = runtime.eval_code(answer_expr)
        elif len(code) > 1:
            runtime.exec_code("\n".join(code[:-1]))
            result = runtime.eval_code(code[-1])
        else:
            runtime.exec_code("\n".join(code))
            result = ""

        # Copy the list: the runtime keeps mutating its own list on later
        # executions, which must not alter results already returned.
        captured_figures = list(runtime._global_vars.get("_captured_figures", []))
        if captured_figures:
            if result:
                result = {'text': result, 'images': captured_figures}
            else:
                result = {'images': captured_figures}
        elif result:
            result = {'text': result}

        report = "Done"

        # Make sure the result can be serialized; replace it with an error
        # message otherwise.
        try:
            pickle.dumps(result)
        except Exception as e:
            result = f"Result serialization error: {str(e)}"
            report = f"Serialization Error: {str(e)}"

        return result, report

    def apply(self, code, messages):
        """Execute a single snippet and return its ``(result, report)`` pair."""
        return self.batch_apply([code], messages)[0]

    @staticmethod
    def truncate(s, max_length=400):
        """Shorten over-long text to its head and tail joined by ``"..."``.

        A dict is handled through its ``'text'`` entry (modified in place);
        a string is shortened directly; any other value is returned as is.
        """
        half = max_length // 2

        def _shorten(text):
            # ``len(text) - half`` instead of ``-half`` so that ``half == 0``
            # yields an empty tail rather than the whole string.
            return text[:half] + "..." + text[len(text) - half:]

        if isinstance(s, dict):
            if 'text' in s and len(s['text']) > max_length:
                s['text'] = _shorten(s['text'])
            return s
        if isinstance(s, str) and len(s) > max_length:
            s = _shorten(s)
        return s

    @staticmethod
    def update_persistent_runtime_with_messages():
        """Placeholder; currently does nothing.

        Declared static so it can be called on the class or on an instance
        without arguments.
        """
        pass

    def get_persistent_runtime(self):
        """Return the persistent runtime, or ``None`` if none was created yet."""
        return self.persistent_runtime

    def batch_apply(self, batch_code, messages):
        """Execute every snippet of ``batch_code`` in order.

        Args:
            batch_code: List of source strings.
            messages: Passed to ``runtime_class`` to create the persistent
                runtime when none exists yet and ``messages`` is non-empty.

        Returns:
            One ``(result, report)`` tuple per snippet.  Text results and
            reports are stripped and truncated; a snippet that raised yields
            ``("", "Error: ...")`` (or ``("", "Timeout Error")``).
        """
        if self.persistent_runtime is None and messages:
            self.persistent_runtime = self.runtime_class(messages)
        all_code_snippets = self.process_generation_to_code(batch_code)

        timeout_cnt = 0
        all_exec_results = []

        print(f"################################### num of cpu: {os.cpu_count()} ; len of code: {len(all_code_snippets)} ######################################")

        # Only show a progress bar for large batches.
        if len(all_code_snippets) > 100:
            progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
        else:
            progress_bar = None

        for code in all_code_snippets:
            try:
                result = self.execute(
                    code,
                    messages=messages,
                    get_answer_from_stdout=self.get_answer_from_stdout,
                    runtime_class=self.runtime_class,
                    answer_symbol=self.answer_symbol,
                    answer_expr=self.answer_expr,
                )
                all_exec_results.append(result)
            except TimeoutError as error:
                print(error)
                all_exec_results.append(("", "Timeout Error"))
                timeout_cnt += 1
            except Exception as error:
                print(f"Error in batch_apply: {error}")
                all_exec_results.append(("", f"Error: {str(error)}"))

            if progress_bar is not None:
                progress_bar.update(1)

        if progress_bar is not None:
            progress_bar.close()

        batch_results = []
        for res, report in all_exec_results:
            if isinstance(res, dict):
                # Only the text part is normalized; images are left untouched.
                if 'text' in res:
                    res['text'] = self.truncate(str(res['text']).strip())
            else:
                res = self.truncate(str(res).strip())
            report = self.truncate(str(report).strip())
            batch_results.append((res, report))
        return batch_results
|
|
|
|
|
def _test(image_path="/mnt/petrelfs/zhaoshitian/vis_tool_inference_engine/test_data/0.JPG"):
    """Smoke-test ``PythonExecutor`` with a text snippet and a plotting snippet.

    Args:
        image_path: Image file that is Base64-encoded and attached to the
            test conversation; defaults to the path originally hard-coded
            here.

    The function prints its findings and writes every captured figure to
    ``captured_image_<i>.png`` in the current working directory.
    """
    image_base64 = encode_image(image_path)
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": "From the information on that advertising board, what is the type of this shop?"}]
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": "image_clue_0"}] + [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}]
        }
    ]

    # Plain computation whose answer is read from stdout.
    math_code = """
a = 1
b = 2
c = a + b
print(c)
"""

    batch_code = [math_code]

    executor = PythonExecutor()
    predictions = executor.apply(batch_code[0], messages)
    print("数学计算结果:", predictions)

    # Plotting snippet that displays the image injected from the messages.
    image_code = """
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import io

plt.imshow(image_clue_0)
plt.title("Original Image - Locate Advertising Board")
plt.show()
"""

    image_result = executor.apply(image_code, messages)
    print("\n图像结果类型:", type(image_result[0]))
    if isinstance(image_result[0], dict) and 'images' in image_result[0]:
        print(f"捕获到 {len(image_result[0]['images'])} 个图像")
        print("第一个图像的base64编码前20个字符:", image_result[0]['images'][0][:20])

        # Save each captured figure so it can be inspected by hand.
        for i, img_data in enumerate(image_result[0]['images']):
            img_bytes = base64.b64decode(img_data)
            with open(f"captured_image_{i}.png", "wb") as f:
                f.write(img_bytes)
            print(f"图像已保存为 captured_image_{i}.png")

        if 'text' in image_result[0]:
            print("文本输出:", image_result[0]['text'])
    else:
        print("未捕获到图像")
        print("结果:", image_result[0])

    print("\n执行状态:", image_result[1])
|
|
|
|
|
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    _test()
|
|