import threading
from typing import Dict, Any, List, Tuple, Optional, Union
import io
from contextlib import redirect_stdout
from timeout_decorator import timeout
import base64
from PIL import Image
from vis_python_exe import PythonExecutor, GenericRuntime


class SharedRuntimeExecutor(PythonExecutor):
    """Python executor that keeps a persistent runtime per session.

    Consecutive ``apply`` calls with the same ``session_id`` run in the same
    ``GenericRuntime``, so variables can survive between calls according to a
    retention policy:

    1. ``var_whitelist="RETAIN_ALL_VARS"``: keep every variable except names
       starting with ``_sys_``.
    2. Otherwise: keep only ``image_clue_*`` variables, ``_captured_figures``
       and variables whose names start with a whitelisted prefix.
    3. Access to the session pool is guarded by a lock.

    NOTE(review): only the session pool is locked. Execution itself is not
    serialized per session, ``redirect_stdout`` replaces the process-wide
    ``sys.stdout``, and ``timeout_decorator.timeout`` uses signals by default
    (main thread only), so concurrent ``apply`` calls from several threads are
    presumably unsafe -- confirm before relying on it.
    """

    def __init__(
        self,
        runtime_class=None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = True,
        timeout_length: int = 20,
        var_whitelist: Union[List[str], str, None] = None,
    ):
        """
        Args:
            runtime_class: Passed to ``PythonExecutor.__init__``; new sessions
                are created with ``self.runtime_class(messages)``.
            get_answer_symbol, get_answer_expr, get_answer_from_stdout,
            timeout_length: Passed unchanged to ``PythonExecutor.__init__``.
            var_whitelist: Variable retention policy between calls:
                - list of str: keep variables whose names start with any of
                  these prefixes (the caller's list is copied, not modified);
                - "RETAIN_ALL_VARS": keep all variables;
                - any other non-empty str: treated as a single prefix;
                - None / empty: keep only the system variables.
        """
        super().__init__(
            runtime_class=runtime_class,
            get_answer_symbol=get_answer_symbol,
            get_answer_expr=get_answer_expr,
            get_answer_from_stdout=get_answer_from_stdout,
            timeout_length=timeout_length,
        )
        # Retention policy.
        self.retain_all_vars = var_whitelist == "RETAIN_ALL_VARS"
        if self.retain_all_vars or not var_whitelist:
            whitelist: List[str] = []
        elif isinstance(var_whitelist, str):
            # A lone prefix string; previously this crashed on ``.append``.
            whitelist = [var_whitelist]
        else:
            # Copy so the caller's list is never mutated below.
            whitelist = list(var_whitelist)

        # The figure store must always survive cleanup.
        if '_captured_figures' not in whitelist:
            whitelist.append('_captured_figures')
        self.var_whitelist = whitelist

        # Session id -> runtime; guarded by ``_lock``.
        self._runtime_pool: Dict[str, GenericRuntime] = {}
        self._lock = threading.Lock()

    def apply(self, code: str, messages: List[Dict], session_id: str = "default") -> Tuple[Any, str]:
        """Execute ``code`` in the session's persistent runtime.

        Args:
            code: Python source to execute.
            messages: Passed to ``runtime_class`` only when the session is
                created; ignored for an existing session.
            session_id: Key selecting which runtime to reuse.

        Returns:
            ``(result, report)`` from ``_execute_shared``, or
            ``(None, "Execution failed: <message>")`` if execution or variable
            cleanup raises. Errors while creating the runtime propagate.
        """
        runtime = self._get_runtime(session_id, messages)
        try:
            result, report = self._execute_shared(code, runtime)
            # The retention policy is applied after every call, even failed ones.
            self._clean_runtime_vars(runtime)
            return result, report
        except Exception as e:
            return None, f"Execution failed: {str(e)}"

    def _get_runtime(self, session_id: str, messages: List[Dict]) -> GenericRuntime:
        """Return the runtime for ``session_id``, creating it on first use."""
        with self._lock:
            runtime = self._runtime_pool.get(session_id)
            if runtime is None:
                runtime = self.runtime_class(messages)
                self._runtime_pool[session_id] = runtime
            return runtime

    def _execute_shared(self, code: str, runtime: GenericRuntime) -> Tuple[Any, str]:
        """Execute ``code`` in ``runtime`` and return ``(result, report)``.

        Blank lines are dropped first. The result is chosen by the first
        applicable mode:

        - ``get_answer_from_stdout``: everything printed to stdout;
        - ``answer_symbol``: that global's value (``""`` if unset);
        - ``answer_expr``: the value of evaluating that expression;
        - otherwise: the last line is evaluated and its value returned
          (single-line code is only executed and yields ``""``).

        Each exec/eval step is wrapped in ``timeout(self.timeout_length)``.
        If the run produced entries in ``_captured_figures``, the result becomes
        ``{'text': str(result).strip(), 'images': [...]}``. The figure store is
        emptied afterwards (success or failure) so a later call in the same
        session never re-reports old figures.

        NOTE(review): dropping blank lines also removes blank lines inside
        multi-line string literals.

        Returns:
            ``(result, "Success")``, or ``(None, "Error: <message>")`` if any
            step raises.
        """
        code_lines = [line for line in code.split('\n') if line.strip()]
        try:
            run = timeout(self.timeout_length)(runtime.exec_code)
            evaluate = timeout(self.timeout_length)(runtime.eval_code)
            source = "\n".join(code_lines)

            if self.get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    run(source)
                result = program_io.getvalue()
            elif self.answer_symbol:
                run(source)
                result = runtime._global_vars.get(self.answer_symbol, "")
            elif self.answer_expr:
                run(source)
                result = evaluate(self.answer_expr)
            elif len(code_lines) > 1:
                run("\n".join(code_lines[:-1]))
                result = evaluate(code_lines[-1])
            else:
                run(source)
                result = ""

            # Attach figures captured during this run. Copy them, because the
            # live store is emptied below and would otherwise alias the result.
            captured_figures = runtime._global_vars.get("_captured_figures") or []
            if captured_figures:
                result = {
                    'text': str(result).strip(),
                    'images': list(captured_figures),
                }
            return result, "Success"
        except Exception as e:
            return None, f"Error: {str(e)}"
        finally:
            self._clear_captured_figures(runtime)

    @staticmethod
    def _clear_captured_figures(runtime: GenericRuntime) -> None:
        """Empty the runtime's ``_captured_figures`` list in place, if present."""
        figures = runtime._global_vars.get("_captured_figures")
        # Clear in place, not rebind, so any hook holding this list keeps working.
        if isinstance(figures, list):
            figures.clear()

    def _clean_runtime_vars(self, runtime: GenericRuntime) -> None:
        """Apply the retention policy to ``runtime``'s globals in place.

        Retained values pass through ``_serialize_var``; everything else is
        removed. ``_captured_figures`` is always present afterwards.

        NOTE(review): in whitelist mode this also drops ``__builtins__``, imports,
        functions from earlier calls, and anything the runtime put in its globals
        at construction, unless their names match a retained prefix.
        ``exec``/``eval`` re-insert ``__builtins__`` automatically; the rest must
        be re-created by later code -- confirm this is intended.
        """
        global_vars = runtime._global_vars
        if self.retain_all_vars:
            # Names prefixed with ``_sys_`` are treated as internal and dropped.
            persistent_vars = {
                name: self._serialize_var(value)
                for name, value in global_vars.items()
                if not name.startswith('_sys_')
            }
        else:
            # Injected images plus user-whitelisted prefixes; ``str.startswith``
            # with a tuple matches if any prefix matches.
            prefixes = ('image_clue_', *self.var_whitelist)
            persistent_vars = {
                name: self._serialize_var(value)
                for name, value in global_vars.items()
                if name.startswith(prefixes)
            }
        # The new mapping is built before clearing, so an error above leaves the
        # runtime untouched.
        global_vars.clear()
        global_vars.update(persistent_vars)
        global_vars.setdefault('_captured_figures', [])

    def _serialize_var(self, var_value: Any) -> Any:
        """Convert a PIL image to a base64-encoded PNG string.

        Any other value, or an image whose mode cannot be written as PNG, is
        returned unchanged.

        NOTE(review): retained images therefore become ``str`` for subsequent
        calls, so later code cannot call Image methods on them -- confirm this
        is intended.
        """
        if not isinstance(var_value, Image.Image):
            return var_value
        buf = io.BytesIO()
        try:
            var_value.save(buf, format='PNG')
        except (OSError, ValueError):
            # e.g. CMYK or float images cannot be stored as PNG; keep the object
            # instead of failing the whole ``apply`` after a successful run.
            return var_value
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    def cleanup_session(self, session_id: str) -> None:
        """Drop the runtime of ``session_id``; no-op if it does not exist."""
        with self._lock:
            self._runtime_pool.pop(session_id, None)

    def cleanup_all(self) -> None:
        """Drop every session runtime."""
        with self._lock:
            self._runtime_pool.clear()