Duibonduil commited on
Commit
fe0d3be
·
verified ·
1 Parent(s): c9325c1

Upload mcp_servers.py

Browse files
Files changed (1) hide show
  1. aworld/sandbox/run/mcp_servers.py +225 -0
aworld/sandbox/run/mcp_servers.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import json
3
+ from typing_extensions import Optional, List, Dict, Any
4
+
5
+ from aworld.mcp_client.utils import sandbox_mcp_tool_desc_transform, call_api, get_server_instance, cleanup_server, \
6
+ call_function_tool
7
+ from mcp.types import TextContent, ImageContent
8
+
9
+ from aworld.core.common import ActionResult
10
+
11
+
12
+ class McpServers:
13
+
14
+ def __init__(
15
+ self,
16
+ mcp_servers: Optional[List[str]] = None,
17
+ mcp_config: Dict[str, Any] = None,
18
+ sandbox = None,
19
+ ) -> None:
20
+ self.mcp_servers = mcp_servers
21
+ self.mcp_config = mcp_config
22
+ self.sandbox = sandbox
23
+ # Dictionary to store server instances {server_name: server_instance}
24
+ self.server_instances = {}
25
+ self.tool_list = None
26
+
27
+ async def list_tools(
28
+ self,
29
+ ) -> List[Dict[str, Any]]:
30
+ if self.tool_list:
31
+ return self.tool_list
32
+ if not self.mcp_servers or not self.mcp_config:
33
+ return []
34
+ try:
35
+ self.tool_list = await sandbox_mcp_tool_desc_transform(self.mcp_servers, self.mcp_config)
36
+ return self.tool_list
37
+ except Exception as e:
38
+ logging.warning(f"Failed to list tools: {e}")
39
+ return []
40
+
41
+ async def call_tool(
42
+ self,
43
+ action_list: List[Dict[str, Any]] = None,
44
+ task_id: str = None,
45
+ session_id: str = None
46
+ ) -> List[ActionResult]:
47
+ results = []
48
+ if not action_list:
49
+ return None
50
+
51
+ try:
52
+ for action in action_list:
53
+ if not isinstance(action, dict):
54
+ action_dict = vars(action)
55
+ else:
56
+ action_dict = action
57
+
58
+ # Get values from dictionary
59
+ server_name = action_dict.get("tool_name")
60
+ tool_name = action_dict.get("action_name")
61
+ parameter = action_dict.get("params")
62
+ result_key = f"{server_name}__{tool_name}"
63
+
64
+
65
+ operation_info = {
66
+ "server_name": server_name,
67
+ "tool_name": tool_name,
68
+ "params": parameter
69
+ }
70
+
71
+ if parameter is None:
72
+ parameter = {}
73
+ if task_id:
74
+ parameter["task_id"] = task_id
75
+ if session_id:
76
+ parameter["session_id"] = session_id
77
+
78
+ if not server_name or not tool_name:
79
+ continue
80
+
81
+ # Check server type
82
+ server_type = None
83
+ if self.mcp_config and self.mcp_config.get("mcpServers"):
84
+ server_config = self.mcp_config.get("mcpServers").get(server_name, {})
85
+ server_type = server_config.get("type", "")
86
+
87
+ if server_type == "function_tool":
88
+ try:
89
+ call_result = await call_function_tool(
90
+ server_name, tool_name, parameter, self.mcp_config
91
+ )
92
+ results.append(call_result)
93
+
94
+ self._update_metadata(result_key, call_result, operation_info)
95
+ except Exception as e:
96
+ logging.warning(f"Error calling function_tool tool: {e}")
97
+ self._update_metadata(result_key, {"error": str(e)}, operation_info)
98
+ continue
99
+
100
+ # For API type servers, use call_api function directly
101
+ if server_type == "api":
102
+ try:
103
+ call_result = await call_api(
104
+ server_name, tool_name, parameter, self.mcp_config
105
+ )
106
+ results.append(call_result)
107
+
108
+ self._update_metadata(result_key, call_result, operation_info)
109
+ except Exception as e:
110
+ logging.warning(f"Error calling API tool: {e}")
111
+ self._update_metadata(result_key, {"error": str(e)}, operation_info)
112
+ continue
113
+
114
+ # Prioritize using existing server instances
115
+ server = self.server_instances.get(server_name)
116
+ if server is None:
117
+ # If it doesn't exist, create a new instance and save it
118
+ server = await get_server_instance(server_name, self.mcp_config)
119
+ if server:
120
+ self.server_instances[server_name] = server
121
+ logging.info(f"Created and cached new server instance for {server_name}")
122
+ else:
123
+ logging.warning(f"Created new server failed: {server_name}")
124
+
125
+ self._update_metadata(result_key, {"error": "Failed to create server instance"}, operation_info)
126
+ continue
127
+
128
+ # Use server instance to call the tool
129
+ try:
130
+ call_result_raw = await server.call_tool(tool_name, parameter)
131
+
132
+ # Process the return result, consistent with the original logic
133
+ action_result = ActionResult(
134
+ tool_name=server_name,
135
+ action_name=tool_name,
136
+ content="",
137
+ keep=True
138
+ )
139
+
140
+ if call_result_raw and call_result_raw.content:
141
+ if isinstance(call_result_raw.content[0], TextContent):
142
+ action_result = ActionResult(
143
+ tool_name=server_name,
144
+ action_name=tool_name,
145
+ content=call_result_raw.content[0].text,
146
+ keep=True,
147
+ metadata=call_result_raw.content[0].model_extra.get(
148
+ "metadata", {}
149
+ ),
150
+ )
151
+ elif isinstance(call_result_raw.content[0], ImageContent):
152
+ action_result = ActionResult(
153
+ tool_name=server_name,
154
+ action_name=tool_name,
155
+ content=f"data:image/jpeg;base64,{call_result_raw.content[0].data}",
156
+ keep=True,
157
+ metadata=call_result_raw.content[0].model_extra.get("metadata", {}),
158
+ )
159
+ results.append(action_result)
160
+ self._update_metadata(result_key, action_result, operation_info)
161
+
162
+ except Exception as e:
163
+ logging.warning(f"Error calling tool with cached server: {e}")
164
+
165
+ self._update_metadata(result_key, {"error": str(e)}, operation_info)
166
+
167
+ # If using cached server instance fails, try to clean up and recreate
168
+ if server_name in self.server_instances:
169
+ try:
170
+ await cleanup_server(self.server_instances[server_name])
171
+ del self.server_instances[server_name]
172
+ except Exception as e:
173
+ logging.warning(f"Failed to cleanup server {server_name}: {e}")
174
+ except Exception as e:
175
+ logging.warning(f"Failed to call_tool: {e}")
176
+ return None
177
+
178
+ return results
179
+
180
+ def _update_metadata(self, result_key: str, result: Any, operation_info: Dict[str, Any]):
181
+ """
182
+ Update sandbox metadata with a single tool call result
183
+
184
+ Args:
185
+ result_key: The key name in metadata
186
+ result: Tool call result
187
+ operation_info: Operation information
188
+ """
189
+ if not self.sandbox or not hasattr(self.sandbox, '_metadata'):
190
+ return
191
+
192
+ try:
193
+ metadata = self.sandbox._metadata.get("mcp_metadata",{})
194
+ tmp_data = {
195
+ "input": operation_info,
196
+ "output": result
197
+ }
198
+ if not metadata:
199
+ metadata["mcp_metadata"] = {}
200
+ metadata["mcp_metadata"][result_key] = [tmp_data]
201
+ self.sandbox._metadata["mcp_metadata"] = metadata
202
+ return
203
+
204
+ _metadata = metadata.get(result_key, [])
205
+ if not _metadata:
206
+ _metadata[result_key] = [_metadata]
207
+ else:
208
+ _metadata[result_key].append(tmp_data)
209
+ metadata[result_key] = _metadata
210
+ self.sandbox._metadata["mcp_metadata"] = metadata
211
+ return
212
+
213
+ except Exception as e:
214
+ logging.warning(f"Failed to update sandbox metadata: {e}")
215
+
216
+ # Add cleanup method, called when Sandbox is destroyed
217
+ async def cleanup(self):
218
+ """Clean up all server connections"""
219
+ for server_name, server in list(self.server_instances.items()):
220
+ try:
221
+ await cleanup_server(server)
222
+ del self.server_instances[server_name]
223
+ logging.info(f"Cleaned up server instance for {server_name}")
224
+ except Exception as e:
225
+ logging.warning(f"Failed to cleanup server {server_name}: {e}")