File size: 6,289 Bytes
490df47
 
 
 
 
 
 
95b7fd7
490df47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import threading
from typing import Dict, Any, List, Tuple, Optional, Union
import io
from contextlib import redirect_stdout
from timeout_decorator import timeout
import base64
from PIL import Image
from vis_python_exe import PythonExecutor, GenericRuntime

class SharedRuntimeExecutor(PythonExecutor):
    """
    Python executor whose variables persist across calls within a session.

    Features:
    1. With var_whitelist="RETAIN_ALL_VARS", every variable except names
       starting with ``_sys_`` is kept between calls.
    2. Otherwise only ``image_clue_*`` variables, ``_captured_figures`` and
       variables whose names start with a whitelisted prefix are kept.
    3. The session -> runtime pool is guarded by a lock.

    NOTE(review): only the runtime pool is locked. Execution itself uses
    ``redirect_stdout`` (which swaps the process-global ``sys.stdout``) and
    ``timeout_decorator.timeout`` (signal-based by default, which presumably
    only works in the main thread), and calls on the same session are not
    serialized, so concurrent ``apply`` calls from several threads are not
    safe as written -- confirm before relying on multi-threaded use.
    """

    def __init__(
        self,
        runtime_class=None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = True,
        timeout_length: int = 20,
        var_whitelist: Union[List[str], str, None] = None,
    ):
        """
        Args:
            runtime_class: Forwarded to ``PythonExecutor``;
                ``self.runtime_class(messages)`` is used to create the runtime
                of a new session.
            get_answer_symbol, get_answer_expr, get_answer_from_stdout,
            timeout_length: Forwarded to ``PythonExecutor``; see
                ``_execute_shared`` for how they select the result.
            var_whitelist:
                - list of str: keep variables whose names start with any of
                  these prefixes (the list is copied, never modified)
                - "RETAIN_ALL_VARS": keep all variables
                - any other non-empty str: treated as a single prefix
                - None / empty: keep only the system variables
        """
        super().__init__(
            runtime_class=runtime_class,
            get_answer_symbol=get_answer_symbol,
            get_answer_expr=get_answer_expr,
            get_answer_from_stdout=get_answer_from_stdout,
            timeout_length=timeout_length,
        )

        # Variable-retention policy.
        self.retain_all_vars = (var_whitelist == "RETAIN_ALL_VARS")
        if self.retain_all_vars or not var_whitelist:
            prefixes: List[str] = []
        elif isinstance(var_whitelist, str):
            # A bare string is one prefix; it has no .append and iterating it
            # would treat every character as a prefix.
            prefixes = [var_whitelist]
        else:
            # Copy so the append below never mutates the caller's list.
            prefixes = list(var_whitelist)

        # The figure buffer must always survive cleanup.
        if '_captured_figures' not in prefixes:
            prefixes.append('_captured_figures')
        self.var_whitelist = prefixes

        # Session id -> runtime; every access goes through self._lock.
        self._runtime_pool: Dict[str, GenericRuntime] = {}
        self._lock = threading.Lock()

    def apply(self, code: str, messages: List[Dict], session_id: str = "default") -> Tuple[Any, str]:
        """
        Execute ``code`` in the runtime of ``session_id`` and prune its variables.

        Args:
            code: Python source to run.
            messages: Passed to ``runtime_class`` only when the session's
                runtime is created; ignored for an existing session.
            session_id: Key of the runtime to (re)use.

        Returns:
            ``(result, report)`` from ``_execute_shared``, or
            ``(None, "Execution failed: <message>")`` if pruning the
            variables raised.
        """
        runtime = self._get_runtime(session_id, messages)

        try:
            result, report = self._execute_shared(code, runtime)
            # The retention policy is applied even when execution reported an error.
            self._clean_runtime_vars(runtime)
            return result, report
        except Exception as e:
            return None, f"Execution failed: {str(e)}"

    def _get_runtime(self, session_id: str, messages: List[Dict]) -> GenericRuntime:
        """Return the runtime for ``session_id``, creating it under the lock if absent."""
        with self._lock:
            if session_id not in self._runtime_pool:
                self._runtime_pool[session_id] = self.runtime_class(messages)
            return self._runtime_pool[session_id]

    def _execute_shared(self, code: str, runtime: GenericRuntime) -> Tuple[Any, str]:
        """
        Run ``code`` in ``runtime`` and extract a result.

        The result comes from, in priority order:
        1. captured stdout (``get_answer_from_stdout``);
        2. the global named ``answer_symbol``;
        3. evaluating ``answer_expr`` after running the code;
        4. evaluating the last non-blank line after running the rest, when
           the code has more than one non-blank line ("" otherwise).

        If the runtime then holds a non-empty ``_captured_figures`` value, the
        result is wrapped as ``{'text': str(result).strip(), 'images': ...}``.

        Returns:
            ``(result, "Success")`` or ``(None, "Error: <message>")``; no
            ``Exception`` escapes.
        """
        lines = code.split('\n')
        # Blank lines stay in the executed source: Python ignores them outside
        # string literals, and dropping them would corrupt multi-line strings.
        non_blank = [idx for idx, line in enumerate(lines) if line.strip()]

        try:
            run = timeout(self.timeout_length)(runtime.exec_code)
            evaluate = timeout(self.timeout_length)(runtime.eval_code)

            if self.get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    run(code)
                result = program_io.getvalue()
            elif self.answer_symbol:
                run(code)
                result = runtime._global_vars.get(self.answer_symbol, "")
            elif self.answer_expr:
                run(code)
                result = evaluate(self.answer_expr)
            elif len(non_blank) > 1:
                # Notebook-style: run everything before the last non-blank
                # line, then evaluate that line as the answer.
                last = non_blank[-1]
                run("\n".join(lines[:last]))
                result = evaluate(lines[last])
            else:
                run(code)
                result = ""

            # Attach any figures captured by the runtime.
            captured_figures = runtime._global_vars.get("_captured_figures", [])
            if captured_figures:
                result = {
                    'text': str(result).strip(),
                    'images': captured_figures
                }

            return result, "Success"

        except Exception as e:
            return None, f"Error: {str(e)}"

    def _clean_runtime_vars(self, runtime: GenericRuntime) -> None:
        """
        Apply the retention policy to ``runtime._global_vars`` in place.

        Kept values pass through ``_serialize_var``. Afterwards
        ``_captured_figures`` is always present (an empty list if it was not
        kept).

        NOTE(review): ``_serialize_var`` turns kept PIL images into base64
        strings inside the live runtime, so code in a later call that uses
        e.g. an ``image_clue_*`` variable as an Image object will get a str --
        confirm this is intended.
        """
        if self.retain_all_vars:
            # Retain-all mode: drop only names reserved for the system (_sys_*).
            persistent_vars = {
                name: self._serialize_var(value)
                for name, value in runtime._global_vars.items()
                if not name.startswith('_sys_')
            }
        else:
            # Whitelist mode: injected images (image_clue_*) plus the
            # whitelisted prefixes, which __init__ guarantees include
            # '_captured_figures'.
            keep_prefixes = ('image_clue_', *self.var_whitelist)
            persistent_vars = {
                name: self._serialize_var(value)
                for name, value in runtime._global_vars.items()
                if name.startswith(keep_prefixes)
            }

        # Rebuild the namespace in place so the runtime keeps the same dict object.
        runtime._global_vars.clear()
        runtime._global_vars.update(persistent_vars)

        # Guarantee the figure buffer exists for the next execution.
        runtime._global_vars.setdefault('_captured_figures', [])

    def _serialize_var(self, var_value: Any) -> Any:
        """Return PIL images as base64-encoded PNG strings; other values unchanged."""
        if isinstance(var_value, Image.Image):
            with io.BytesIO() as buf:
                var_value.save(buf, format='PNG')
                return base64.b64encode(buf.getvalue()).decode('utf-8')
        return var_value

    def cleanup_session(self, session_id: str):
        """Drop the runtime of ``session_id``; unknown ids are ignored."""
        with self._lock:
            self._runtime_pool.pop(session_id, None)

    def cleanup_all(self):
        """Drop every session's runtime."""
        with self._lock:
            self._runtime_pool.clear()