File size: 9,394 Bytes
fe0d3be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import logging
import json
from typing_extensions import Optional, List, Dict, Any

from aworld.mcp_client.utils import sandbox_mcp_tool_desc_transform, call_api, get_server_instance, cleanup_server, \
    call_function_tool
from mcp.types import TextContent, ImageContent

from aworld.core.common import ActionResult


class McpServers:
    """Dispatch tool calls to MCP servers, function tools and API tools.

    MCP server instances are created lazily through ``get_server_instance`` and
    cached per server name in ``self.server_instances`` so later calls reuse
    them. When a ``sandbox`` with a ``_metadata`` dict is supplied, every call's
    input and output are recorded under ``sandbox._metadata["mcp_metadata"]``.
    """

    def __init__(
            self,
            mcp_servers: Optional[List[str]] = None,
            mcp_config: Optional[Dict[str, Any]] = None,
            sandbox=None,
    ) -> None:
        """
        Args:
            mcp_servers: Server names passed to ``sandbox_mcp_tool_desc_transform``
                when listing tools.
            mcp_config: Configuration dict; per-server entries are read from its
                ``"mcpServers"`` key.
            sandbox: Optional owner object; if it has a ``_metadata`` attribute,
                call records are written into it.
        """
        self.mcp_servers = mcp_servers
        self.mcp_config = mcp_config
        self.sandbox = sandbox
        # Cached server instances: {server_name: server_instance}
        self.server_instances = {}
        # Tool descriptions; a truthy value is returned directly by list_tools().
        self.tool_list = None

    async def list_tools(
            self,
    ) -> List[Dict[str, Any]]:
        """Return the tool descriptions for ``self.mcp_servers``.

        The fetched result is stored on ``self.tool_list`` and, when truthy,
        returned directly on later calls. Returns ``[]`` when no servers or
        config are set, or when listing fails (the failure is logged).
        """
        if self.tool_list:
            return self.tool_list
        if not self.mcp_servers or not self.mcp_config:
            return []
        try:
            self.tool_list = await sandbox_mcp_tool_desc_transform(self.mcp_servers, self.mcp_config)
            return self.tool_list
        except Exception as e:
            logging.warning(f"Failed to list tools: {e}")
            return []

    async def call_tool(
            self,
            action_list: List[Dict[str, Any]] = None,
            task_id: str = None,
            session_id: str = None
    ) -> "Optional[List[ActionResult]]":
        """Execute each action in ``action_list`` and collect the results.

        Each action is a dict (or an object converted with ``vars()``) with the
        keys ``tool_name`` (server name), ``action_name`` (tool name) and
        ``params``. Routing uses the server's ``type`` in
        ``mcp_config["mcpServers"]``: ``"function_tool"`` -> ``call_function_tool``,
        ``"api"`` -> ``call_api``, anything else -> a cached MCP server instance.

        Args:
            action_list: Actions to execute, in order.
            task_id: If truthy, added to each call's parameters as ``task_id``.
            session_id: If truthy, added to each call's parameters as ``session_id``.

        Returns:
            Results of the successful calls, in call order. Actions missing a
            server or tool name, and calls that fail, contribute no entry
            (failures are logged and recorded in sandbox metadata). Returns
            ``None`` when ``action_list`` is empty or an unexpected error
            aborts the whole batch.
        """
        if not action_list:
            return None

        # server type -> (callable, label used in the warning message)
        direct_callers = {
            "function_tool": (call_function_tool, "function_tool"),
            "api": (call_api, "API"),
        }
        results = []

        try:
            for action in action_list:
                action_dict = action if isinstance(action, dict) else vars(action)

                server_name = action_dict.get("tool_name")
                tool_name = action_dict.get("action_name")
                raw_params = action_dict.get("params")
                result_key = f"{server_name}__{tool_name}"

                # Record the params as supplied by the caller (without the
                # injected task/session ids).
                operation_info = {
                    "server_name": server_name,
                    "tool_name": tool_name,
                    "params": raw_params
                }
                # Work on a copy so the caller's action params are not mutated.
                parameter = self._build_params(raw_params, task_id, session_id)

                if not server_name or not tool_name:
                    continue

                server_type = self._server_type(server_name)

                if server_type in direct_callers:
                    caller, label = direct_callers[server_type]
                    try:
                        call_result = await caller(
                            server_name, tool_name, parameter, self.mcp_config
                        )
                        results.append(call_result)
                        self._update_metadata(result_key, call_result, operation_info)
                    except Exception as e:
                        logging.warning(f"Error calling {label} tool: {e}")
                        self._update_metadata(result_key, {"error": str(e)}, operation_info)
                    continue

                server = await self._get_or_create_server(server_name)
                if server is None:
                    self._update_metadata(result_key, {"error": "Failed to create server instance"}, operation_info)
                    continue

                try:
                    call_result_raw = await server.call_tool(tool_name, parameter)
                    action_result = self._to_action_result(server_name, tool_name, call_result_raw)
                    results.append(action_result)
                    self._update_metadata(result_key, action_result, operation_info)
                except Exception as e:
                    logging.warning(f"Error calling tool with cached server: {e}")
                    self._update_metadata(result_key, {"error": str(e)}, operation_info)
                    # Drop the (possibly broken) instance so the next call recreates it.
                    await self._discard_server(server_name)
        except Exception as e:
            logging.warning(f"Failed to call_tool: {e}")
            return None

        return results

    @staticmethod
    def _build_params(params: Any, task_id: Optional[str] = None, session_id: Optional[str] = None) -> Any:
        """Return call parameters with ``task_id``/``session_id`` injected when truthy.

        ``None`` becomes a new empty dict and a dict is shallow-copied, so the
        caller's object is never modified. Other values are used as-is.
        """
        if params is None:
            merged = {}
        elif isinstance(params, dict):
            merged = dict(params)
        else:
            merged = params
        if task_id:
            merged["task_id"] = task_id
        if session_id:
            merged["session_id"] = session_id
        return merged

    def _server_type(self, server_name: str) -> Optional[str]:
        """Return the ``type`` configured for ``server_name``.

        Returns ``None`` when ``mcp_config`` has no (or an empty) ``"mcpServers"``
        section, and ``""`` when the server or its ``type`` is not configured.
        """
        servers = (self.mcp_config or {}).get("mcpServers")
        if not servers:
            return None
        return (servers.get(server_name) or {}).get("type", "")

    async def _get_or_create_server(self, server_name: str):
        """Return the cached instance for ``server_name``, creating and caching it if needed.

        Returns ``None`` (after logging a warning) when creation fails or yields
        a falsy instance.
        """
        server = self.server_instances.get(server_name)
        if server is not None:
            return server
        try:
            server = await get_server_instance(server_name, self.mcp_config)
        except Exception as e:
            logging.warning(f"Created new server failed: {server_name}: {e}")
            return None
        if not server:
            logging.warning(f"Created new server failed: {server_name}")
            return None
        self.server_instances[server_name] = server
        logging.info(f"Created and cached new server instance for {server_name}")
        return server

    async def _discard_server(self, server_name: str) -> None:
        """Remove ``server_name`` from the cache, then clean it up; cleanup errors are logged."""
        # Pop first so a failing cleanup cannot leave a broken instance cached.
        server = self.server_instances.pop(server_name, None)
        if server is None:
            return
        try:
            await cleanup_server(server)
        except Exception as cleanup_err:
            logging.warning(f"Failed to cleanup server {server_name}: {cleanup_err}")

    @staticmethod
    def _content_metadata(item: Any) -> Dict[str, Any]:
        """Return ``item.model_extra["metadata"]``, or ``{}`` when absent."""
        extra = getattr(item, "model_extra", None) or {}
        return extra.get("metadata", {})

    def _to_action_result(self, server_name: str, tool_name: str, call_result_raw: Any):
        """Convert a raw MCP ``call_tool`` result into an ``ActionResult``.

        Only the first content item is used: text becomes ``content``; an image
        becomes a ``data:image/jpeg;base64,...`` string. Missing or other
        content yields an ``ActionResult`` with empty-string ``content``.
        """
        first = None
        if call_result_raw and call_result_raw.content:
            first = call_result_raw.content[0]

        if isinstance(first, TextContent):
            return ActionResult(
                tool_name=server_name,
                action_name=tool_name,
                content=first.text,
                keep=True,
                metadata=self._content_metadata(first),
            )
        if isinstance(first, ImageContent):
            # NOTE(review): the MIME type is always reported as JPEG, whatever
            # the image's own mimeType is.
            return ActionResult(
                tool_name=server_name,
                action_name=tool_name,
                content=f"data:image/jpeg;base64,{first.data}",
                keep=True,
                metadata=self._content_metadata(first),
            )
        return ActionResult(
            tool_name=server_name,
            action_name=tool_name,
            content="",
            keep=True
        )

    def _update_metadata(self, result_key: str, result: Any, operation_info: Dict[str, Any]):
        """Append one tool-call record to the sandbox metadata.

        Records are stored as
        ``sandbox._metadata["mcp_metadata"][result_key] -> [{"input": ..., "output": ...}, ...]``,
        one list entry per call. Does nothing when there is no sandbox or it has
        no ``_metadata`` attribute; errors while updating are logged, not raised.

        Args:
            result_key: Key under ``"mcp_metadata"`` (``"<server>__<tool>"``).
            result: Tool call result (or an ``{"error": ...}`` dict).
            operation_info: Description of the call's input.
        """
        if not self.sandbox or not hasattr(self.sandbox, '_metadata'):
            return

        try:
            store = self.sandbox._metadata
            mcp_metadata = store.get("mcp_metadata")
            if not isinstance(mcp_metadata, dict):
                mcp_metadata = {}
                store["mcp_metadata"] = mcp_metadata
            mcp_metadata.setdefault(result_key, []).append({
                "input": operation_info,
                "output": result
            })
        except Exception as e:
            logging.warning(f"Failed to update sandbox metadata: {e}")

    # Add cleanup method, called when Sandbox is destroyed
    async def cleanup(self):
        """Clean up all cached server connections.

        Instances whose cleanup raises are logged and left in ``server_instances``.
        """
        for server_name, server in list(self.server_instances.items()):
            try:
                await cleanup_server(server)
                del self.server_instances[server_name]
                logging.info(f"Cleaned up server instance for {server_name}")
            except Exception as e:
                logging.warning(f"Failed to cleanup server {server_name}: {e}")