# PyVision / shared_vis_python_exe.py
# Last commit by stzhao: "Update shared_vis_python_exe.py" (95b7fd7, verified)
# File size: 6.29 kB
import threading
from typing import Dict, Any, List, Tuple, Optional, Union
import io
from contextlib import redirect_stdout
from timeout_decorator import timeout
import base64
from PIL import Image
from vis_python_exe import PythonExecutor, GenericRuntime
class SharedRuntimeExecutor(PythonExecutor):
    """Python executor that keeps one runtime per session so state persists across calls.

    Retention policy, applied after every ``apply`` call:

    1. ``var_whitelist="RETAIN_ALL_VARS"``: keep every global whose name does
       not start with ``_sys_``.
    2. Otherwise: keep only globals whose names start with ``image_clue_`` or
       with one of the whitelisted prefixes (``_captured_figures`` is always
       whitelisted). Every other global is dropped after each call, including
       anything the runtime may have pre-defined or pre-imported.
    3. Creating, looking up and removing session runtimes is guarded by a lock.
       Executing code inside a runtime is not locked, so concurrent ``apply``
       calls for the same session are not serialized.
    """

    def __init__(
        self,
        runtime_class=None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = True,
        timeout_length: int = 20,
        var_whitelist: Union[List[str], str, None] = None,
        serialize_images: bool = False,
    ):
        """
        Args:
            runtime_class: Forwarded to ``PythonExecutor``; ``self.runtime_class``
                is later called as ``runtime_class(messages)`` to create a
                session runtime.
            get_answer_symbol: Name of a global read as the result (used when
                ``get_answer_from_stdout`` is false).
            get_answer_expr: Expression evaluated after execution to produce
                the result (used when neither of the above applies).
            get_answer_from_stdout: If true, the result is the captured stdout.
            timeout_length: Passed to ``timeout_decorator.timeout`` for each
                exec/eval call.
            var_whitelist:
                - list of str: keep variables whose names start with any prefix
                - "RETAIN_ALL_VARS": keep all variables (except ``_sys_*``)
                - any other non-empty str: treated as a single prefix
                - None / empty: keep only the system variables
            serialize_images: If true, PIL images that survive cleanup are
                replaced by base64-encoded PNG strings. This was the earlier,
                unconditional behavior. The default is false so persisted
                images (e.g. ``image_clue_*``) remain usable as images in
                later calls.
        """
        super().__init__(
            runtime_class=runtime_class,
            get_answer_symbol=get_answer_symbol,
            get_answer_expr=get_answer_expr,
            get_answer_from_stdout=get_answer_from_stdout,
            timeout_length=timeout_length,
        )
        # Variable retention policy.
        self.retain_all_vars = var_whitelist == "RETAIN_ALL_VARS"
        if self.retain_all_vars or not var_whitelist:
            prefixes: List[str] = []
        elif isinstance(var_whitelist, str):
            # A bare string is one prefix (it used to crash on ``str.append``).
            prefixes = [var_whitelist]
        else:
            # Copy, so the append below never mutates the caller's list.
            prefixes = list(var_whitelist)
        # The figure buffer is a required system variable and must survive cleanup.
        if "_captured_figures" not in prefixes:
            prefixes.append("_captured_figures")
        self.var_whitelist = prefixes
        self.serialize_images = serialize_images
        # Session id -> runtime; access guarded by ``self._lock``.
        self._runtime_pool: Dict[str, GenericRuntime] = {}
        self._lock = threading.Lock()

    def apply(self, code: str, messages: List[Dict], session_id: str = "default") -> Tuple[Any, str]:
        """Run ``code`` in the runtime of ``session_id`` and return ``(result, report)``.

        ``messages`` is only used when the session's runtime is first created;
        later calls reuse the existing runtime and ignore it. The retention
        policy is applied after every execution, successful or not.
        """
        runtime = self._get_runtime(session_id, messages)
        try:
            result, report = self._execute_shared(code, runtime)
            self._clean_runtime_vars(runtime)
            return result, report
        except Exception as e:
            return None, f"Execution failed: {e}"

    def _get_runtime(self, session_id: str, messages: List[Dict]) -> GenericRuntime:
        """Return the runtime for ``session_id``, creating it under the lock if missing."""
        with self._lock:
            if session_id not in self._runtime_pool:
                self._runtime_pool[session_id] = self.runtime_class(messages)
            return self._runtime_pool[session_id]

    def _execute_shared(self, code: str, runtime: GenericRuntime) -> Tuple[Any, str]:
        """Execute ``code`` in ``runtime`` and return ``(result, report)``.

        The result is chosen, in order, from: captured stdout; the global named
        ``self.answer_symbol`` ("" if absent); ``eval`` of ``self.answer_expr``;
        otherwise ``eval`` of the last non-blank line after executing the lines
        before it (or "" when there is at most one non-blank line). If
        ``runtime._global_vars["_captured_figures"]`` is non-empty, the result is
        wrapped as ``{"text": str(result).strip(), "images": figures}``.
        Any exception yields ``(None, "Error: <message>")``.

        NOTE(review): ``timeout_decorator.timeout`` defaults to a signal-based
        timeout, which presumably only works in the main thread — confirm if
        ``apply`` is called from worker threads.
        """
        lines = code.split("\n")
        # Blank lines are kept in the executed code: deleting them would silently
        # alter multi-line string literals and shift error line numbers. They are
        # only ignored when picking the "last line" to evaluate.
        content_idx = [i for i, line in enumerate(lines) if line.strip()]
        guarded = timeout(self.timeout_length)
        try:
            if self.get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    guarded(runtime.exec_code)(code)
                result = program_io.getvalue()
            elif self.answer_symbol:
                guarded(runtime.exec_code)(code)
                result = runtime._global_vars.get(self.answer_symbol, "")
            elif self.answer_expr:
                guarded(runtime.exec_code)(code)
                result = guarded(runtime.eval_code)(self.answer_expr)
            elif len(content_idx) > 1:
                # Execute everything before the last non-blank line, then evaluate that line.
                last = content_idx[-1]
                guarded(runtime.exec_code)("\n".join(lines[:last]))
                result = guarded(runtime.eval_code)(lines[last])
            else:
                guarded(runtime.exec_code)(code)
                result = ""
            # NOTE(review): ``_captured_figures`` survives cleanup; unless the runtime
            # resets it per execution, figures from earlier calls are returned
            # again — confirm against the runtime implementation.
            captured_figures = runtime._global_vars.get("_captured_figures", [])
            if captured_figures:
                result = {
                    "text": str(result).strip(),
                    "images": captured_figures,
                }
            return result, "Success"
        except Exception as e:
            return None, f"Error: {e}"

    def _clean_runtime_vars(self, runtime: GenericRuntime):
        """Apply the retention policy to ``runtime._global_vars`` in place."""
        global_vars = runtime._global_vars
        if self.retain_all_vars:
            # Keep everything except names reserved with the ``_sys_`` prefix.
            persistent_vars = {
                k: v for k, v in global_vars.items() if not k.startswith("_sys_")
            }
        else:
            # Injected images (``image_clue_*``) plus the user whitelist prefixes.
            prefixes = ("image_clue_", *self.var_whitelist)
            persistent_vars = {
                k: v for k, v in global_vars.items() if k.startswith(prefixes)
            }
        if self.serialize_images:
            persistent_vars = {
                k: self._serialize_var(v) for k, v in persistent_vars.items()
            }
        # Rebuild the namespace from the retained variables only.
        global_vars.clear()
        global_vars.update(persistent_vars)
        # The figure buffer must always exist for the next execution.
        global_vars.setdefault("_captured_figures", [])

    def _serialize_var(self, var_value: Any) -> Any:
        """Return a PIL image as a base64-encoded PNG string (no data-URI prefix); other values unchanged."""
        if isinstance(var_value, Image.Image):
            buf = io.BytesIO()
            var_value.save(buf, format="PNG")
            return base64.b64encode(buf.getvalue()).decode("utf-8")
        return var_value

    def cleanup_session(self, session_id: str):
        """Drop the runtime of ``session_id``; a no-op if the session does not exist."""
        with self._lock:
            self._runtime_pool.pop(session_id, None)

    def cleanup_all(self):
        """Drop the runtimes of all sessions."""
        with self._lock:
            self._runtime_pool.clear()