File size: 6,289 Bytes
490df47 95b7fd7 490df47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import threading
from typing import Dict, Any, List, Tuple, Optional, Union
import io
from contextlib import redirect_stdout
from timeout_decorator import timeout
import base64
from PIL import Image
from vis_python_exe import PythonExecutor, GenericRuntime
class SharedRuntimeExecutor(PythonExecutor):
    """
    Python executor that keeps one persistent runtime per session so that
    variables can be shared between successive ``apply`` calls.

    Retention policy, applied after every call:
    1. ``var_whitelist="RETAIN_ALL_VARS"``: keep every global whose name does
       not start with ``_sys_``.
    2. Otherwise: keep only ``image_clue_*`` globals, ``_captured_figures`` and
       globals whose name starts with a whitelisted prefix.
    3. Creating and removing session runtimes is guarded by a lock.

    NOTE(review): only the runtime pool is lock-protected. Concurrent ``apply``
    calls on the same session share one runtime without synchronisation, and
    ``redirect_stdout`` replaces ``sys.stdout`` process-wide, so concurrent
    calls from different threads can capture each other's output. The default
    (signal-based) ``timeout_decorator`` timeout presumably also requires the
    main thread -- confirm before calling from worker threads.
    """

    def __init__(
        self,
        runtime_class=None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = True,
        timeout_length: int = 20,
        var_whitelist: Union[List[str], str, None] = None,
        serialize_images: bool = False,
    ):
        """
        Args:
            runtime_class, get_answer_symbol, get_answer_expr,
            get_answer_from_stdout, timeout_length:
                Forwarded unchanged to ``PythonExecutor.__init__``.
            var_whitelist:
                - list of str: keep globals whose name starts with any of
                  these prefixes (the list is copied, never mutated).
                - ``"RETAIN_ALL_VARS"``: keep all globals (except ``_sys_*``).
                - any other non-empty str: treated as a single prefix.
                - None / empty: keep only the system variables.
            serialize_images:
                If True, PIL images that survive a call are replaced by
                base64-encoded PNG strings. Off by default: the runtime's
                globals stay in this process, and converting them would turn
                an image into a ``str`` for the session's next call.
        """
        super().__init__(
            runtime_class=runtime_class,
            get_answer_symbol=get_answer_symbol,
            get_answer_expr=get_answer_expr,
            get_answer_from_stdout=get_answer_from_stdout,
            timeout_length=timeout_length,
        )
        # Variable-retention policy.
        self.retain_all_vars = var_whitelist == "RETAIN_ALL_VARS"
        if self.retain_all_vars or not var_whitelist:
            whitelist: List[str] = []
        elif isinstance(var_whitelist, str):
            # A bare string other than the sentinel is a single prefix.
            whitelist = [var_whitelist]
        else:
            # Copy so the append below never mutates the caller's list.
            whitelist = list(var_whitelist)
        # The figure buffer is a required system variable.
        if '_captured_figures' not in whitelist:
            whitelist.append('_captured_figures')
        self.var_whitelist = whitelist
        self.serialize_images = serialize_images
        # Session id -> runtime; every access goes through ``self._lock``.
        self._runtime_pool: Dict[str, GenericRuntime] = {}
        self._lock = threading.Lock()

    def apply(self, code: str, messages: List[Dict], session_id: str = "default") -> Tuple[Any, str]:
        """
        Run ``code`` in the session's persistent runtime, then apply the
        retention policy to that runtime.

        Args:
            code: Python source to execute.
            messages: Passed to ``runtime_class`` only when the session's
                runtime is first created; ignored for existing sessions.
            session_id: Key of the runtime to create or reuse.

        Returns:
            ``(result, report)`` as produced by ``_execute_shared``, or
            ``(None, "Execution failed: <message>")`` if applying the
            retention policy raises.

        Raises:
            Whatever ``runtime_class`` raises when a new session runtime is
            created (runtime creation happens outside the ``try``).
        """
        runtime = self._get_runtime(session_id, messages)
        try:
            result, report = self._execute_shared(code, runtime)
            # The retention policy runs even when execution reported an error.
            self._clean_runtime_vars(runtime)
            return result, report
        except Exception as e:
            return None, f"Execution failed: {str(e)}"

    def _get_runtime(self, session_id: str, messages: List[Dict]) -> GenericRuntime:
        """Return the runtime for ``session_id``, creating it from ``messages`` under the pool lock if absent."""
        with self._lock:
            runtime = self._runtime_pool.get(session_id)
            if runtime is None:
                runtime = self.runtime_class(messages)
                self._runtime_pool[session_id] = runtime
            return runtime

    def _execute_shared(self, code: str, runtime: GenericRuntime) -> Tuple[Any, str]:
        """
        Execute ``code`` in ``runtime`` and collect its answer.

        The answer is chosen by the executor's configuration, in this order:
        captured stdout (``get_answer_from_stdout``), the global named by
        ``answer_symbol``, the value of ``answer_expr``, or (for multi-line
        code) the value of the last non-blank line; single-line code yields
        ``""``. If the run added entries to ``_captured_figures``, the answer
        is wrapped as ``{'text': str(answer).strip(), 'images': [...]}``.

        Returns:
            ``(answer, "Success")``, or ``(None, "Error: <message>")`` if
            execution, evaluation, or a timeout raises.
        """
        lines = code.split('\n')
        # Drop only leading/trailing blank lines; interior blank lines may be
        # part of a triple-quoted string, whose value must not change.
        non_blank = [i for i, line in enumerate(lines) if line.strip()]
        lines = lines[non_blank[0]:non_blank[-1] + 1] if non_blank else []
        try:
            # Remember the figure buffer as it was before this run: it is
            # retained between calls, so without this every result would
            # re-report the figures of all earlier calls in the session.
            figures_before = runtime._global_vars.get("_captured_figures")
            count_before = len(figures_before) if figures_before else 0

            result = self._run_for_answer(lines, runtime)

            captured_figures = runtime._global_vars.get("_captured_figures", [])
            if figures_before is not None and captured_figures is figures_before:
                # Same list object kept growing: report only this run's figures.
                captured_figures = captured_figures[count_before:]
            if captured_figures:
                result = {
                    'text': str(result).strip(),
                    'images': captured_figures,
                }
            return result, "Success"
        except Exception as e:
            return None, f"Error: {str(e)}"

    def _run_for_answer(self, lines: List[str], runtime: GenericRuntime) -> Any:
        """Execute ``lines`` in ``runtime`` and return the answer selected by the executor's configuration."""
        def run(fn, arg):
            # Each exec/eval call gets its own ``timeout_length`` budget.
            return timeout(self.timeout_length)(fn)(arg)

        source = "\n".join(lines)
        if self.get_answer_from_stdout:
            program_io = io.StringIO()
            with redirect_stdout(program_io):
                run(runtime.exec_code, source)
            return program_io.getvalue()
        if self.answer_symbol:
            run(runtime.exec_code, source)
            return runtime._global_vars.get(self.answer_symbol, "")
        if self.answer_expr:
            run(runtime.exec_code, source)
            return run(runtime.eval_code, self.answer_expr)
        if len(lines) > 1:
            # Execute everything but the last line, then evaluate that line.
            run(runtime.exec_code, "\n".join(lines[:-1]))
            return run(runtime.eval_code, lines[-1])
        run(runtime.exec_code, source)
        return ""

    def _clean_runtime_vars(self, runtime: GenericRuntime):
        """
        Apply the retention policy to ``runtime``'s globals in place.

        Retain-all mode keeps every name not starting with ``_sys_``; otherwise
        only ``image_clue_*`` names and names starting with a whitelisted
        prefix are kept. Everything else (including modules imported by
        earlier code) is dropped. ``_captured_figures`` always exists
        afterwards. Retained values are passed through ``_serialize_var`` only
        when ``serialize_images`` is enabled.
        """
        if self.retain_all_vars:
            def keep(name):
                return not name.startswith('_sys_')
        else:
            prefixes = ('image_clue_', *self.var_whitelist)

            def keep(name):
                return name.startswith(prefixes)

        serialize = getattr(self, 'serialize_images', False)
        persistent_vars = {
            name: self._serialize_var(value) if serialize else value
            for name, value in runtime._global_vars.items()
            if keep(name)
        }
        # Rebuild the namespace in place so the runtime keeps the same dict.
        runtime._global_vars.clear()
        runtime._global_vars.update(persistent_vars)
        runtime._global_vars.setdefault('_captured_figures', [])

    def _serialize_var(self, var_value: Any) -> Any:
        """
        Return ``var_value`` as a base64-encoded PNG string if it is a PIL
        image; return any other value unchanged.

        NOTE(review): Pillow's PNG writer likely rejects some image modes
        (e.g. CMYK), in which case this raises -- confirm if enabling
        ``serialize_images``.
        """
        if isinstance(var_value, Image.Image):
            buf = io.BytesIO()
            var_value.save(buf, format='PNG')
            return base64.b64encode(buf.getvalue()).decode('utf-8')
        return var_value

    def cleanup_session(self, session_id: str):
        """Discard the runtime of ``session_id``, if any."""
        with self._lock:
            self._runtime_pool.pop(session_id, None)

    def cleanup_all(self):
        """Discard the runtimes of all sessions."""
        with self._lock:
            self._runtime_pool.clear()