|
import threading |
|
from typing import Dict, Any, List, Tuple, Optional, Union |
|
import io |
|
from contextlib import redirect_stdout |
|
from timeout_decorator import timeout |
|
import base64 |
|
from PIL import Image |
|
from vis_python_exe import PythonExecutor, GenericRuntime |
|
|
|
class SharedRuntimeExecutor(PythonExecutor):
    """Python executor that keeps one runtime (and its variables) per session.

    Features:

    1. With ``var_whitelist="RETAIN_ALL_VARS"`` every global variable survives
       between calls, except names starting with ``_sys_``.
    2. Otherwise only ``image_clue_*`` variables, names matching a whitelisted
       prefix, and imported module objects survive between calls.
    3. Access to the session -> runtime pool is guarded by a lock.

    NOTE(review): only the pool itself is locked. Concurrent ``apply`` calls on
    the *same* session share one globals dict. ``redirect_stdout`` also swaps the
    process-wide ``sys.stdout``, so concurrent stdout capture can interleave.
    ``timeout_decorator``'s default signal-based timeout is documented to work
    only in the main thread; confirm how ``apply`` is invoked.
    """

    def __init__(
        self,
        runtime_class=None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = True,
        timeout_length: int = 20,
        var_whitelist: Union[List[str], str, None] = None,
    ):
        """Create the executor.

        Args:
            runtime_class: Runtime factory, forwarded to ``PythonExecutor``;
                it is called as ``runtime_class(messages)`` for each new
                session.
            get_answer_symbol: Global variable whose value is the result
                (used when ``get_answer_from_stdout`` is false).
            get_answer_expr: Expression evaluated after execution to produce
                the result (used when no answer symbol is set).
            get_answer_from_stdout: If true, the captured stdout is the result.
            timeout_length: Per-call timeout in seconds.
            var_whitelist:
                - list of str: keep variables whose names start with any of
                  these prefixes.
                - a single str: treated as a one-element prefix list.
                - ``"RETAIN_ALL_VARS"``: keep all variables.
                - ``None`` / empty: keep only the built-in retention set.
                The caller's list is copied, never mutated.
        """
        super().__init__(
            runtime_class=runtime_class,
            get_answer_symbol=get_answer_symbol,
            get_answer_expr=get_answer_expr,
            get_answer_from_stdout=get_answer_from_stdout,
            timeout_length=timeout_length,
        )

        self.retain_all_vars = var_whitelist == "RETAIN_ALL_VARS"

        # Normalize into a private list. Copying avoids mutating the caller's
        # list; a bare string is a single prefix, not something to .append to.
        if self.retain_all_vars or not var_whitelist:
            whitelist: List[str] = []
        elif isinstance(var_whitelist, str):
            whitelist = [var_whitelist]
        else:
            whitelist = list(var_whitelist)

        # The figure list must always survive cleanup so the runtime can
        # append to it during execution.
        if '_captured_figures' not in whitelist:
            whitelist.append('_captured_figures')
        self.var_whitelist: List[str] = whitelist

        self._runtime_pool: Dict[str, GenericRuntime] = {}
        self._lock = threading.Lock()

    def apply(self, code: str, messages: List[Dict], session_id: str = "default") -> Tuple[Any, str]:
        """Execute ``code`` in the session's runtime and keep its state.

        ``messages`` is used only when the session's runtime is first
        created. Returns ``(result, report)``. ``report`` is ``"Success"`` or an
        ``"Error: ..."`` / ``"Execution failed: ..."`` message, and ``result``
        is ``None`` on failure.
        """
        runtime = self._get_runtime(session_id, messages)
        try:
            result, report = self._execute_shared(code, runtime)
            # Apply the retention policy after every call, successful or not.
            self._clean_runtime_vars(runtime)
            return result, report
        except Exception as e:
            return None, f"Execution failed: {str(e)}"

    def _get_runtime(self, session_id: str, messages: List[Dict]) -> GenericRuntime:
        """Return the runtime for ``session_id``, creating it under the lock."""
        with self._lock:
            runtime = self._runtime_pool.get(session_id)
            if runtime is None:
                runtime = self.runtime_class(messages)
                self._runtime_pool[session_id] = runtime
            return runtime

    @staticmethod
    def _split_code_lines(code: str) -> List[str]:
        """Split ``code`` into lines, trimming only leading/trailing blank lines.

        Interior blank lines are kept so multi-line string literals run
        verbatim.
        """
        lines = code.split('\n')
        start = 0
        while start < len(lines) and not lines[start].strip():
            start += 1
        end = len(lines)
        while end > start and not lines[end - 1].strip():
            end -= 1
        return lines[start:end]

    def _execute_shared(self, code: str, runtime: GenericRuntime) -> Tuple[Any, str]:
        """Run ``code`` in ``runtime`` and extract the result.

        The result comes from, in priority order: captured stdout, the answer
        symbol, the answer expression, or (fallback) the evaluated last line.
        In the fallback, single-line code is only executed and gives ``""``.
        If the run captured figures, the result becomes
        ``{'text': str(result).strip(), 'images': [...]}``.
        Any exception is returned as ``(None, "Error: ...")``.
        """
        code_lines = self._split_code_lines(code)
        program = "\n".join(code_lines)

        try:
            # Reset per call so a result reports only figures produced by
            # this execution, not every figure from earlier calls.
            runtime._global_vars['_captured_figures'] = []

            run = timeout(self.timeout_length)(runtime.exec_code)
            evaluate = timeout(self.timeout_length)(runtime.eval_code)

            if self.get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    run(program)
                result = program_io.getvalue()
            elif self.answer_symbol:
                run(program)
                result = runtime._global_vars.get(self.answer_symbol, "")
            elif self.answer_expr:
                run(program)
                result = evaluate(self.answer_expr)
            elif len(code_lines) > 1:
                # Execute everything but the last line, then evaluate it.
                run("\n".join(code_lines[:-1]))
                result = evaluate(code_lines[-1])
            else:
                run(program)
                result = ""

            captured_figures = runtime._global_vars.get("_captured_figures", [])
            if captured_figures:
                result = {
                    'text': str(result).strip(),
                    'images': list(captured_figures),
                }

            return result, "Success"

        except Exception as e:
            return None, f"Error: {str(e)}"

    def _clean_runtime_vars(self, runtime: GenericRuntime) -> None:
        """Prune the runtime's globals according to the retention policy.

        Retained values are kept as the original objects (not serialized), so
        later calls in the same session can keep using them, e.g. PIL images.
        """
        from types import ModuleType

        global_vars = runtime._global_vars
        if self.retain_all_vars:
            persistent_vars = {
                name: value
                for name, value in global_vars.items()
                if not name.startswith('_sys_')
            }
        else:
            prefixes = ('image_clue_', *self.var_whitelist)
            persistent_vars = {
                name: value
                for name, value in global_vars.items()
                # Keep imported modules too, so imports made earlier in the
                # session do not vanish between calls.
                if name.startswith(prefixes) or isinstance(value, ModuleType)
            }

        global_vars.clear()
        global_vars.update(persistent_vars)

        # Guarantee the figure list exists for the next execution.
        global_vars.setdefault('_captured_figures', [])

    def _serialize_var(self, var_value: Any) -> Any:
        """Return PIL images as base64-encoded PNG strings; return others unchanged.

        Kept for callers that need a string-friendly form. Cleanup no longer
        applies it to the live runtime globals: replacing an Image with a
        string there would break later code that still uses it as an Image.
        """
        if isinstance(var_value, Image.Image):
            buf = io.BytesIO()
            var_value.save(buf, format='PNG')
            return base64.b64encode(buf.getvalue()).decode('utf-8')
        return var_value

    def cleanup_session(self, session_id: str) -> None:
        """Drop the runtime for ``session_id``, if any."""
        with self._lock:
            self._runtime_pool.pop(session_id, None)

    def cleanup_all(self) -> None:
        """Drop every session's runtime."""
        with self._lock:
            self._runtime_pool.clear()