Duibonduil commited on
Commit
19bf17b
·
verified ·
1 Parent(s): eb2d0a3

Upload 2 files

Browse files
aworld/tools/mcp_tool/async_mcp_tool.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+
4
+ from typing import Any, Dict, Tuple, Union
5
+
6
+ from aworld.core.context.base import Context
7
+
8
+ from aworld.config.conf import ToolConfig, ConfigDict
9
+ from aworld.core.common import ActionModel, Observation, ActionResult
10
+ from aworld.core.tool.base import ToolFactory, AsyncTool
11
+ from aworld.logs.util import logger
12
+ from aworld.tools.mcp_tool.executor import MCPToolExecutor
13
+ from aworld.tools.utils import build_observation
14
+
15
+
16
+ @ToolFactory.register(name="mcp",
17
+ desc="mcp execute tool",
18
+ asyn=True)
19
+ class McpTool(AsyncTool):
20
+ def __init__(self, conf: Union[Dict[str, Any], ConfigDict, ToolConfig], **kwargs) -> None:
21
+ """Initialize the McpTool.
22
+
23
+ Args:
24
+ conf: tool config
25
+ """
26
+ super(McpTool, self).__init__(conf, **kwargs)
27
+ self.action_executor = MCPToolExecutor(self)
28
+
29
+ async def reset(self, *, seed: int | None = None, options: Dict[str, str] | None = None) -> Tuple[
30
+ Observation, dict[str, Any]]:
31
+ self._finished = False
32
+ return build_observation(observer=self.name(), ability=""), {}
33
+
34
+ async def close(self) -> None:
35
+ self._finished = True
36
+ # default only close playwright
37
+ await self.action_executor.close(self.conf.get('close_servers', ['ms-playwright']))
38
+
39
+ async def do_step(self,
40
+ actions: list[ActionModel],
41
+ **kwargs) -> Tuple[Observation, float, bool, bool, dict[str, Any]]:
42
+ """Step of tool.
43
+
44
+ Args:
45
+ actions: actions
46
+ **kwargs: -
47
+ Returns:
48
+ Observation, float, bool, bool, dict[str, Any]: -
49
+ """
50
+ from aworld.core.agent.base import AgentFactory
51
+
52
+ self._finished = False
53
+ reward = 0
54
+ fail_error = ""
55
+ terminated = kwargs.get("terminated", False)
56
+ # todo sandbox
57
+ agent = AgentFactory.agent_instance(actions[0].agent_name)
58
+ if not agent:
59
+ logger.warning(f"async_mcp_tool can not get agent,agent_name:{actions[0].agent_name}")
60
+ task_id = Context.instance().task_id
61
+ session_id = Context.instance().session_id
62
+
63
+ if not actions:
64
+ self._finished = True
65
+ observation = build_observation(observer=self.name(),
66
+ content="raw actions is empty",
67
+ ability="")
68
+ return (observation,
69
+ reward,
70
+ terminated,
71
+ kwargs.get("truncated", False),
72
+ {"exception": "actions is empty"})
73
+
74
+ mcp_actions = []
75
+ for action in actions:
76
+ tool_name = action.tool_name
77
+ if 'mcp' != tool_name:
78
+ logger.warning(f"Unsupported tool: {tool_name}")
79
+ continue
80
+ full_tool_name = action.action_name
81
+ names = full_tool_name.split("__")
82
+ if len(names) < 2:
83
+ logger.warning(f"{full_tool_name} illegal format")
84
+ continue
85
+ action.action_name = names[1]
86
+ action.tool_name = names[0]
87
+ mcp_actions.append(action)
88
+ if not mcp_actions:
89
+ self._finished = True
90
+ observation = build_observation(observer=self.name(),
91
+ content="no valid mcp actions",
92
+ ability=actions[-1].action_name)
93
+ return (observation, reward,
94
+ terminated,
95
+ kwargs.get("truncated", False),
96
+ {"exception": "no valid mcp actions"})
97
+
98
+ action_results = None
99
+ try:
100
+ # todo sandbox
101
+ if agent and agent.sandbox:
102
+ sand_box = agent.sandbox
103
+ action_results = await sand_box.mcpservers.call_tool(action_list=mcp_actions,task_id=task_id,session_id=session_id)
104
+ else:
105
+ action_results, ignore = await self.action_executor.async_execute_action(mcp_actions)
106
+ reward = 1
107
+ except Exception as e:
108
+ fail_error = str(e)
109
+ finally:
110
+ self._finished = True
111
+
112
+ observation = build_observation(observer=self.name(),
113
+ ability=actions[-1].action_name)
114
+ if action_results:
115
+ for res in action_results:
116
+ if res.is_done:
117
+ terminated = res.is_done
118
+ if res.error:
119
+ fail_error += res.error
120
+
121
+ observation.action_result = action_results
122
+ observation.content = action_results[-1].content
123
+ else:
124
+ if self.conf.get('exit_on_failure'):
125
+ raise Exception(fail_error)
126
+ else:
127
+ logger.warning(f"{actions} no action results, fail info: {fail_error}, will use fail action results")
128
+ # every action need has the result
129
+ action_results = [ActionResult(success=False, content=fail_error, error=fail_error) for _ in actions]
130
+ observation.action_result = action_results
131
+ observation.content = fail_error
132
+
133
+ info = {"exception": fail_error, **kwargs}
134
+ return (observation,
135
+ reward,
136
+ terminated,
137
+ kwargs.get("truncated", False),
138
+ info)
aworld/tools/mcp_tool/executor.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import os
4
+ import traceback
5
+ import asyncio
6
+
7
+ from typing import Any, Dict, List, Tuple, Union
8
+
9
+ from mcp.types import TextContent, ImageContent
10
+
11
+ from aworld.core.common import ActionModel, ActionResult, Observation
12
+ from aworld.core.tool.base import ToolActionExecutor, Tool, AsyncTool
13
+ from aworld.logs.util import logger
14
+ from aworld.mcp_client.server import MCPServer, MCPServerSse
15
+ import aworld.mcp_client.utils as mcp_utils
16
+ from aworld.utils.common import sync_exec, find_file
17
+
18
+
19
+ class MCPToolExecutor(ToolActionExecutor):
20
+ """A tool executor that uses MCP server to execute actions."""
21
+
22
+ def __init__(self, tool: Union[Tool, AsyncTool] = None):
23
+ """Initialize the MCP tool executor."""
24
+ super().__init__(tool)
25
+ self.initialized = False
26
+ self.mcp_servers: Dict[str, MCPServer] = {}
27
+ self._load_mcp_config()
28
+
29
+ def _replace_env_variables(self, config):
30
+ if isinstance(config, dict):
31
+ for key, value in config.items():
32
+ if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
33
+ env_var_name = value[2:-1]
34
+ config[key] = os.getenv(env_var_name, value)
35
+ logger.info(f"Replaced {value} with {config[key]}")
36
+ elif isinstance(value, dict) or isinstance(value, list):
37
+ self._replace_env_variables(value)
38
+ elif isinstance(config, list):
39
+ for index, item in enumerate(config):
40
+ if isinstance(item, str) and item.startswith("${") and item.endswith("}"):
41
+ env_var_name = item[2:-1]
42
+ config[index] = os.getenv(env_var_name, item)
43
+ logger.info(f"Replaced {item} with {config[index]}")
44
+ elif isinstance(item, dict) or isinstance(item, list):
45
+ self._replace_env_variables(item)
46
+
47
+ def _load_mcp_config(self) -> None:
48
+ """Load MCP server configurations from config file."""
49
+ try:
50
+ config_data = {}
51
+ if mcp_utils.MCP_SERVERS_CONFIG:
52
+ config_data=mcp_utils.MCP_SERVERS_CONFIG
53
+ else:
54
+ # Priority given to the running path.
55
+ config_path = find_file(filename='mcp.json')
56
+ if not os.path.exists(config_path):
57
+ # Use relative path for config file
58
+ current_dir = os.path.dirname(os.path.abspath(__file__))
59
+ config_path = os.path.normpath(os.path.join(current_dir, "../../config/mcp.json"))
60
+ logger.info(f"mcp conf path: {config_path}")
61
+
62
+ with open(config_path, "r") as f:
63
+ config_data = json.load(f)
64
+
65
+ # Replace environment variables in the configuration
66
+ self._replace_env_variables(config_data)
67
+
68
+ # Load all server configurations
69
+ for server_name, server_config in config_data.get("mcpServers", {}).items():
70
+ # Skip disabled servers
71
+ if server_config.get("disabled", False):
72
+ continue
73
+
74
+ # Handle SSE server
75
+ if "url" in server_config:
76
+ self.mcp_servers[server_name] = {
77
+ "type": "sse",
78
+ "url": server_config["url"],
79
+ "instance": None,
80
+ "timeout": server_config.get('timeout', 5.),
81
+ "sse_read_timeout": server_config.get('sse_read_timeout', 300.0),
82
+ "headers": server_config.get('headers')
83
+ }
84
+ # Handle stdio server
85
+ elif "command" in server_config:
86
+ self.mcp_servers[server_name] = {
87
+ "type": "stdio",
88
+ "command": server_config["command"],
89
+ "args": server_config.get("args", []),
90
+ "env": server_config.get("env", {}),
91
+ "cwd": server_config.get("cwd"),
92
+ "encoding": server_config.get("encoding", "utf-8"),
93
+ "encoding_error_handler": server_config.get("encoding_error_handler", "strict"),
94
+ "instance": None
95
+ }
96
+
97
+ self.initialized = True
98
+ except Exception as e:
99
+ logger.error(f"Failed to load MCP config: {traceback.format_exc()}")
100
+
101
+ async def _get_or_create_server(self, server_name: str) -> MCPServer:
102
+ """Get an existing MCP server instance or create a new one."""
103
+ if server_name not in self.mcp_servers:
104
+ raise ValueError(f"MCP server '{server_name}' not found in configuration")
105
+
106
+ server_info = self.mcp_servers[server_name]
107
+
108
+ # If an instance already exists, check if it's available and reuse it
109
+ if server_info.get("instance"):
110
+ return server_info["instance"]
111
+
112
+ server_type = server_info.get("type", "sse")
113
+
114
+ try:
115
+ if server_type == "sse":
116
+ # Create new SSE server instance
117
+ server_params = {
118
+ "url": server_info["url"],
119
+ "timeout": server_info['timeout'],
120
+ "sse_read_timeout": server_info['sse_read_timeout'],
121
+ "headers": server_info['headers']
122
+ }
123
+
124
+ server = MCPServerSse(server_params, cache_tools_list=True, name=server_name)
125
+ elif server_type == "stdio":
126
+ # Create new stdio server instance
127
+ server_params = {
128
+ "command": server_info["command"],
129
+ "args": server_info["args"],
130
+ "env": server_info["env"],
131
+ "cwd": server_info.get("cwd"),
132
+ "encoding": server_info["encoding"],
133
+ "encoding_error_handler": server_info["encoding_error_handler"]
134
+ }
135
+
136
+ from aworld.mcp_client.server import MCPServerStdio
137
+ server = MCPServerStdio(server_params, cache_tools_list=True, name=server_name)
138
+ else:
139
+ raise ValueError(f"Unsupported MCP server type: {server_type}")
140
+
141
+ # Try to connect, with special handling for cancellation exceptions
142
+ try:
143
+ await server.connect()
144
+ except asyncio.CancelledError:
145
+ # When the task is cancelled, ensure resources are cleaned up
146
+ logger.warning(f"Connection to server '{server_name}' was cancelled")
147
+ await server.cleanup()
148
+ raise
149
+
150
+ server_info["instance"] = server
151
+ return server
152
+
153
+ except asyncio.CancelledError:
154
+ # Pass cancellation exceptions up to be handled by the caller
155
+ raise
156
+ except Exception as e:
157
+ logger.error(f"Failed to connect to MCP server '{server_name}': {e}")
158
+ raise
159
+
160
+ async def async_execute_action(self, actions: List[ActionModel], **kwargs) -> Tuple[
161
+ List[ActionResult], Any]:
162
+ """Execute actions using the MCP server.
163
+
164
+ Args:
165
+ actions: A list of action models to execute
166
+ **kwargs: Additional arguments
167
+
168
+ Returns:
169
+ A list of action results
170
+ """
171
+ if not self.initialized:
172
+ raise RuntimeError("MCP Tool Executor not initialized")
173
+
174
+ if not actions:
175
+ return [], None
176
+
177
+ results = []
178
+ for action in actions:
179
+ # Get server and operation information
180
+ server_name = action.tool_name
181
+ if not server_name:
182
+ raise ValueError("Missing tool_name in action model")
183
+
184
+ action_name = action.action_name
185
+ if not action_name:
186
+ raise ValueError("Missing action_name in action model")
187
+
188
+ params = action.params or {}
189
+
190
+ try:
191
+ server = self.mcp_servers.get(server_name, {}).get('instance', None)
192
+ if not server:
193
+ # Get or create MCP server
194
+ server = await self._get_or_create_server(server_name)
195
+
196
+ # Call the tool and process results
197
+ try:
198
+ result = await server.call_tool(action_name, params)
199
+
200
+ if result and result.content:
201
+ if isinstance(result.content[0], TextContent):
202
+ action_result = ActionResult(
203
+ content=result.content[0].text,
204
+ keep=True
205
+ )
206
+ elif isinstance(result.content[0], ImageContent):
207
+ action_result = ActionResult(
208
+ content=f"data:image/jpeg;base64,{result.content[0].data}",
209
+ keep=True
210
+ )
211
+ else:
212
+ action_result = ActionResult(
213
+ content="",
214
+ keep=True
215
+ )
216
+ logger.warning("Unsupported content type is error:")
217
+ else:
218
+ action_result = ActionResult(
219
+ content="",
220
+ keep=True
221
+ )
222
+ logger.warning("mcp result is null")
223
+
224
+ results.append(action_result)
225
+ except asyncio.CancelledError:
226
+ # Log cancellation exception, reset server connection to avoid async context confusion
227
+ logger.warning(f"Tool call to {action_name} on {server_name} was cancelled")
228
+ if server_name in self.mcp_servers and self.mcp_servers[server_name].get("instance"):
229
+ try:
230
+ await self.mcp_servers[server_name]["instance"].cleanup()
231
+ self.mcp_servers[server_name]["instance"] = None
232
+ except Exception as cleanup_error:
233
+ logger.error(f"Error cleaning up server after cancellation: {cleanup_error}")
234
+ # Re-raise exception to notify upper level caller
235
+ raise
236
+
237
+ except asyncio.CancelledError:
238
+ # Pass cancellation exception
239
+ logger.warning("Async execution was cancelled")
240
+ raise
241
+
242
+ except Exception as e:
243
+ # Handle general errors
244
+ error_msg = str(e)
245
+ logger.error(f"Error executing MCP action: {error_msg}")
246
+ action_result = ActionResult(
247
+ content=f"Error executing tool: {error_msg}",
248
+ keep=True
249
+ )
250
+ results.append(action_result)
251
+
252
+ return results, None
253
+
254
+ async def cleanup(self) -> None:
255
+ """Clean up all MCP server connections."""
256
+ for server_name, server_info in self.mcp_servers.items():
257
+ if server_info.get("instance"):
258
+ try:
259
+ await server_info["instance"].cleanup()
260
+ except Exception as e:
261
+ logger.error(f"Error cleaning up MCP server {server_name}: {e}")
262
+
263
+ async def close(self, keys: List[str] = []) -> None:
264
+ if keys:
265
+ for key in keys:
266
+ if key in self.mcp_servers:
267
+ server_info = self.mcp_servers[key]
268
+ # Fixme: Have resources leak, MCP server may clean fail.
269
+ if server_info.get("type") == "stdio":
270
+ server_info["instance"] = None
271
+ continue
272
+
273
+ try:
274
+ await server_info["instance"].cleanup()
275
+ except Exception as e:
276
+ logger.error(f"Error cleaning up MCP server {key}: {e}")
277
+
278
+ def execute_action(self, actions: List[ActionModel], **kwargs) -> Tuple[
279
+ List[ActionResult], Any]:
280
+ return sync_exec(self.async_execute_action, actions, **kwargs)