Duibonduil commited on
Commit
da697a7
·
verified ·
1 Parent(s): 59256a2

Upload 3 files

Browse files
aworld/mcp_client/decorator.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+
4
+ """
5
+ This module defines decorators for creating MCP servers.
6
+ By using the @mcp_server decorator, you can convert a Python class into an MCP server,
7
+ where the class methods will be automatically converted into MCP tools.
8
+ """
9
+
10
+ import inspect
11
+ import functools
12
+ import threading
13
+ from typing import Type, Dict, Any, Optional, Union, List
14
+
15
+ # Import FastMCP
16
+ from mcp.server import FastMCP
17
+
18
+ from aworld.core.factory import Factory
19
+ from aworld.logs.util import logger
20
+
21
+
22
+ # Save all decorated MCP server classes
23
+ class MCPServerRegistry(Factory):
24
+ """Register all MCP server classes"""
25
+
26
+ def __init__(self, type_name: str = None):
27
+ super().__init__(type_name)
28
+ self._instance = {}
29
+
30
+ def register(self, name: str, cls: Type, **kwargs):
31
+ """Register MCP server class"""
32
+ self._cls[name] = cls
33
+
34
+ def get_instance(self, name: str, *args, **kwargs):
35
+ """Get MCP server instance"""
36
+ if name not in self._instance:
37
+ if name not in self._cls:
38
+ raise ValueError(f"MCP server {name} not registered")
39
+ self._instance[name] = self._cls[name](*args, **kwargs)
40
+ return self._instance[name]
41
+
42
+
43
+ # Create global registry instance
44
+ MCPServers = MCPServerRegistry()
45
+
46
+
47
+ def extract_param_desc(method, param_name):
48
+ """Extract parameter description from method docstring"""
49
+ if not method.__doc__:
50
+ return None
51
+
52
+ param_docs = [
53
+ line.strip() for line in method.__doc__.split('\n')
54
+ if line.strip().startswith(f":param {param_name}:")
55
+ ]
56
+ if param_docs:
57
+ return param_docs[0].replace(f":param {param_name}:", "").strip()
58
+ return None
59
+
60
+
61
+ def mcp_server(name: str = None, **server_config):
62
+ """
63
+ Decorator to convert a class into an MCP server
64
+
65
+ Args:
66
+ name: Server name, if None, uses the class name
67
+ **server_config: Server configuration parameters
68
+ - mode: Server running mode, supports 'stdio' and 'sse' (default: 'sse')
69
+ - host: Host address in SSE mode (default: '127.0.0.1')
70
+ - port: Port number in SSE mode (default: 8888)
71
+ - sse_path: Path in SSE mode (default: '/sse')
72
+ - auto_start: Whether to automatically start the server (default: True)
73
+
74
+ Example:
75
+ @mcp_server(
76
+ name="simple-calculator",
77
+ mode="sse",
78
+ host="localhost",
79
+ port=8085,
80
+ sse_path="/calculator/sse"
81
+ )
82
+ class Calculator:
83
+ '''Server description'''
84
+
85
+ def __init__(self):
86
+ self.data = {}
87
+
88
+ def get_data(self, key: str) -> str:
89
+ '''Get data
90
+ :param key: Data key
91
+ :return: Data value
92
+ '''
93
+ return self.data.get(key, "")
94
+ """
95
+ # Extract server configuration or use defaults
96
+ mode = server_config.get('mode', 'sse')
97
+ host = server_config.get('host', '127.0.0.1')
98
+ port = server_config.get('port', 8888)
99
+ sse_path = server_config.get('sse_path', '/sse')
100
+ auto_start = server_config.get('auto_start', True)
101
+
102
+ def decorator(cls):
103
+ server_name = name or cls.__name__
104
+
105
+ # Use class docstring as server description
106
+ server_description = cls.__doc__ or f"{server_name} MCP Server"
107
+
108
+ # Original initialization method
109
+ original_init = cls.__init__
110
+
111
+ @functools.wraps(original_init)
112
+ def new_init(self, *args, **kwargs):
113
+ # Call original initialization method
114
+ original_init(self, *args, **kwargs)
115
+
116
+ # Create FastMCP instance, set server name and description
117
+ self._mcp = FastMCP(server_name, description=server_description.strip())
118
+
119
+ # Tool name list for recording
120
+ tool_names = []
121
+
122
+ # Get all methods, filter out built-in and private methods
123
+ for method_name, method in inspect.getmembers(self, inspect.ismethod):
124
+ if not method_name.startswith('_') and method_name != 'run':
125
+ # Get method docstring as tool description
126
+ tool_description = method.__doc__ or f"{method_name} tool"
127
+ tool_description = tool_description.strip()
128
+
129
+ # Record tool name
130
+ tool_names.append(method_name)
131
+
132
+ # Create tool and register, using a function generator to ensure each method is correctly bound
133
+ def create_tool_wrapper(method_to_call):
134
+ # Check if method is async
135
+ is_async = inspect.iscoroutinefunction(method_to_call)
136
+
137
+ if is_async:
138
+ @self._mcp.tool(name=method_name, description=tool_description)
139
+ @functools.wraps(method_to_call)
140
+ async def wrapped_method(*args, **kwargs):
141
+ return await method_to_call(*args, **kwargs)
142
+ else:
143
+ @self._mcp.tool(name=method_name, description=tool_description)
144
+ @functools.wraps(method_to_call)
145
+ def wrapped_method(*args, **kwargs):
146
+ return method_to_call(*args, **kwargs)
147
+
148
+ return wrapped_method
149
+
150
+ # Create a dedicated wrapper for each method
151
+ create_tool_wrapper(method)
152
+
153
+ # Print server information
154
+ logger.info(f"Creating MCP server: {server_name}")
155
+ logger.info(f"Server description: {server_description.strip()}")
156
+ if tool_names:
157
+ logger.info(f"Registered tools: {', '.join(tool_names)}")
158
+
159
+ # Save configuration
160
+ self._server_config = {
161
+ 'mode': mode,
162
+ 'host': host,
163
+ 'port': port,
164
+ 'sse_path': sse_path
165
+ }
166
+
167
+ # Auto start server if configured
168
+ if auto_start:
169
+ # Start server in a new thread to avoid blocking
170
+ thread = threading.Thread(
171
+ target=self.run,
172
+ kwargs=self._server_config,
173
+ daemon=True
174
+ )
175
+ thread.start()
176
+ logger.info(f"Server {server_name} started in a background thread")
177
+ self._server_thread = thread
178
+
179
+ # Replace initialization method
180
+ cls.__init__ = new_init
181
+
182
+ # Add method to run server
183
+ def run(self, mode: str = mode, host: str = host, port: int = port, sse_path: str = sse_path):
184
+ """
185
+ Run MCP server
186
+
187
+ Args:
188
+ mode: Server running mode, supports 'stdio' and 'sse'
189
+ host: Host address in SSE mode
190
+ port: Port number in SSE mode
191
+ sse_path: Path in SSE mode
192
+ """
193
+ if not hasattr(self, '_mcp') or self._mcp is None:
194
+ raise RuntimeError("MCP server not initialized")
195
+
196
+ # Run server according to mode
197
+ if mode == "stdio":
198
+ self._mcp.run(transport="stdio")
199
+ elif mode == "sse":
200
+ # Configure SSE mode settings
201
+ self._mcp.settings.host = host
202
+ self._mcp.settings.port = port
203
+ self._mcp.settings.sse_path = sse_path
204
+
205
+ # Print running information
206
+ print(f"Running MCP server: {server_name}")
207
+ print(f"Description: {server_description.strip()}")
208
+ print(f"Address: http://{host}:{port}{sse_path}")
209
+
210
+ self._mcp.run(transport="sse")
211
+ else:
212
+ raise ValueError(f"Unsupported mode: {mode}, supported modes are 'stdio' and 'sse'")
213
+
214
+ cls.run = run
215
+
216
+ # Add a stop method to gracefully stop the server
217
+ def stop(self):
218
+ """Stop the MCP server if it's running"""
219
+ if hasattr(self, '_mcp') and self._mcp is not None:
220
+ # TODO: Implement proper stopping mechanism based on FastMCP API
221
+ logger.info(f"Stopping server {server_name}")
222
+ # Currently there might not be a proper way to stop FastMCP server
223
+ # This is a placeholder for future implementation
224
+
225
+ cls.stop = stop
226
+
227
+ # Register to MCP server registry
228
+ MCPServers.register(server_name, cls)
229
+
230
+ # Return modified class
231
+ return cls
232
+
233
+ return decorator
aworld/mcp_client/server.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ import asyncio
5
+ from datetime import timedelta
6
+ import logging
7
+ from contextlib import AbstractAsyncContextManager, AsyncExitStack
8
+ from pathlib import Path
9
+ from typing import Any, Literal
10
+
11
+ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
12
+ from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
13
+ from mcp.client.sse import sse_client
14
+ from mcp.types import CallToolResult, JSONRPCMessage
15
+ from typing_extensions import NotRequired, TypedDict
16
+
17
+
18
+ class MCPServer(abc.ABC):
19
+ """Base class for Model Context Protocol servers."""
20
+
21
+ @abc.abstractmethod
22
+ async def connect(self):
23
+ """Connect to the server. For example, this might mean spawning a subprocess or
24
+ opening a network connection. The server is expected to remain connected until
25
+ `cleanup()` is called.
26
+ """
27
+ pass
28
+
29
+ @property
30
+ @abc.abstractmethod
31
+ def name(self) -> str:
32
+ """A readable name for the server."""
33
+ pass
34
+
35
+ @abc.abstractmethod
36
+ async def cleanup(self):
37
+ """Cleanup the server. For example, this might mean closing a subprocess or
38
+ closing a network connection.
39
+ """
40
+ pass
41
+
42
+ @abc.abstractmethod
43
+ async def list_tools(self) -> list[MCPTool]:
44
+ """List the tools available on the server."""
45
+ pass
46
+
47
+ @abc.abstractmethod
48
+ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
49
+ """Invoke a tool on the server."""
50
+ pass
51
+
52
+
53
+ class _MCPServerWithClientSession(MCPServer, abc.ABC):
54
+ """Base class for MCP servers that use a `ClientSession` to communicate with the server."""
55
+
56
+ def __init__(self, cache_tools_list: bool, session_connect_timeout_seconds: int = 30):
57
+ """
58
+ Args:
59
+ cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
60
+ cached and only fetched from the server once. If `False`, the tools list will be
61
+ fetched from the server on each call to `list_tools()`. The cache can be invalidated
62
+ by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
63
+ server will not change its tools list, because it can drastically improve latency
64
+ (by avoiding a round-trip to the server every time).
65
+
66
+ session_connect_timeout_seconds: session connect timeout seconds
67
+ """
68
+ self.session: ClientSession | None = None
69
+ self.exit_stack: AsyncExitStack = AsyncExitStack()
70
+ self._cleanup_lock: asyncio.Lock = asyncio.Lock()
71
+ self.cache_tools_list = cache_tools_list
72
+ self.session_connect_timeout_seconds = timedelta(seconds=session_connect_timeout_seconds)
73
+
74
+ # The cache is always dirty at startup, so that we fetch tools at least once
75
+ self._cache_dirty = True
76
+ self._tools_list: list[MCPTool] | None = None
77
+
78
+ @abc.abstractmethod
79
+ def create_streams(
80
+ self,
81
+ ) -> AbstractAsyncContextManager[
82
+ tuple[
83
+ MemoryObjectReceiveStream[JSONRPCMessage | Exception],
84
+ MemoryObjectSendStream[JSONRPCMessage],
85
+ ]
86
+ ]:
87
+ """Create the streams for the server."""
88
+ pass
89
+
90
+ async def __aenter__(self):
91
+ await self.connect()
92
+ return self
93
+
94
+ async def __aexit__(self, exc_type, exc_value, traceback):
95
+ await self.cleanup()
96
+
97
+ def invalidate_tools_cache(self):
98
+ """Invalidate the tools cache."""
99
+ self._cache_dirty = True
100
+
101
+ async def connect(self):
102
+ """Connect to the server."""
103
+ try:
104
+ # Ensure closing previous exit_stack to avoid nested async contexts
105
+ if hasattr(self, 'exit_stack') and self.exit_stack:
106
+ try:
107
+ await self.exit_stack.aclose()
108
+ except Exception as e:
109
+ logging.error(f"Error closing previous exit stack: {e}")
110
+
111
+ self.exit_stack = AsyncExitStack()
112
+
113
+ # Use a single task context to create the connection
114
+ transport = await self.exit_stack.enter_async_context(self.create_streams())
115
+ read, write = transport
116
+ session = await self.exit_stack.enter_async_context(ClientSession(read, write, read_timeout_seconds=self.session_connect_timeout_seconds))
117
+ await session.initialize()
118
+ self.session = session
119
+ except Exception as e:
120
+ logging.error(f"Error initializing MCP server: {e}")
121
+ # Ensure resources are cleaned up if connection fails
122
+ await self.cleanup()
123
+ raise
124
+
125
+ async def list_tools(self) -> list[MCPTool]:
126
+ """List the tools available on the server."""
127
+ if not self.session:
128
+ raise RuntimeError("Server not initialized. Make sure you call `connect()` first.")
129
+
130
+ # Return from cache if caching is enabled, we have tools, and the cache is not dirty
131
+ if self.cache_tools_list and not self._cache_dirty and self._tools_list:
132
+ return self._tools_list
133
+
134
+ # Reset the cache dirty to False
135
+ self._cache_dirty = False
136
+
137
+ # Fetch the tools from the server
138
+ self._tools_list = (await self.session.list_tools()).tools
139
+ return self._tools_list
140
+
141
+ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
142
+ """Invoke a tool on the server."""
143
+ if not self.session:
144
+ raise RuntimeError("Server not initialized. Make sure you call `connect()` first.")
145
+
146
+ return await self.session.call_tool(tool_name, arguments)
147
+
148
+ async def cleanup(self):
149
+ """Cleanup the server."""
150
+ async with self._cleanup_lock:
151
+ try:
152
+ # Ensure cleanup operations occur in the same task context
153
+ session = self.session
154
+ self.session = None # Remove reference first
155
+
156
+ # Wait briefly to ensure any pending operations complete
157
+ try:
158
+ await asyncio.sleep(0.1)
159
+ except asyncio.CancelledError:
160
+ # Ignore cancellation exceptions, continue cleaning resources
161
+ pass
162
+
163
+ # Clean up exit_stack, ensuring all resources are properly closed
164
+ exit_stack = self.exit_stack
165
+ if exit_stack:
166
+ try:
167
+ await exit_stack.aclose()
168
+ except Exception as e:
169
+ logging.error(f"Error closing exit stack during cleanup: {e}")
170
+ except Exception as e:
171
+ logging.error(f"Error during server cleanup: {e}")
172
+
173
+
174
+ class MCPServerStdioParams(TypedDict):
175
+ """Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
176
+ import.
177
+ """
178
+
179
+ command: str
180
+ """The executable to run to start the server. For example, `python` or `node`."""
181
+
182
+ args: NotRequired[list[str]]
183
+ """Command line args to pass to the `command` executable. For example, `['foo.py']` or
184
+ `['server.js', '--port', '8080']`."""
185
+
186
+ env: NotRequired[dict[str, str]]
187
+ """The environment variables to set for the server. ."""
188
+
189
+ cwd: NotRequired[str | Path]
190
+ """The working directory to use when spawning the process."""
191
+
192
+ encoding: NotRequired[str]
193
+ """The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""
194
+
195
+ encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
196
+ """The text encoding error handler. Defaults to `strict`.
197
+
198
+ See https://docs.python.org/3/library/codecs.html#codec-base-classes for
199
+ explanations of possible values.
200
+ """
201
+
202
+
203
+ class MCPServerStdio(_MCPServerWithClientSession):
204
+ """MCP server implementation that uses the stdio transport. See the [spec]
205
+ (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
206
+ details.
207
+ """
208
+
209
+ def __init__(
210
+ self,
211
+ params: MCPServerStdioParams,
212
+ cache_tools_list: bool = False,
213
+ name: str | None = None,
214
+ ):
215
+ """Create a new MCP server based on the stdio transport.
216
+
217
+ Args:
218
+ params: The params that configure the server. This includes the command to run to
219
+ start the server, the args to pass to the command, the environment variables to
220
+ set for the server, the working directory to use when spawning the process, and
221
+ the text encoding used when sending/receiving messages to the server.
222
+ cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
223
+ cached and only fetched from the server once. If `False`, the tools list will be
224
+ fetched from the server on each call to `list_tools()`. The cache can be
225
+ invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
226
+ if you know the server will not change its tools list, because it can drastically
227
+ improve latency (by avoiding a round-trip to the server every time).
228
+ name: A readable name for the server. If not provided, we'll create one from the
229
+ command.
230
+ """
231
+ super().__init__(cache_tools_list, int(params.get("env").get("SESSION_REQUEST_CONNECT_TIMEOUT", "60")))
232
+
233
+ self.params = StdioServerParameters(
234
+ command=params["command"],
235
+ args=params.get("args", []),
236
+ env=params.get("env"),
237
+ cwd=params.get("cwd"),
238
+ encoding=params.get("encoding", "utf-8"),
239
+ encoding_error_handler=params.get("encoding_error_handler", "strict"),
240
+ )
241
+
242
+ self._name = name or f"stdio: {self.params.command}"
243
+
244
+ def create_streams(
245
+ self,
246
+ ) -> AbstractAsyncContextManager[
247
+ tuple[
248
+ MemoryObjectReceiveStream[JSONRPCMessage | Exception],
249
+ MemoryObjectSendStream[JSONRPCMessage],
250
+ ]
251
+ ]:
252
+ """Create the streams for the server."""
253
+ return stdio_client(self.params)
254
+
255
+ @property
256
+ def name(self) -> str:
257
+ """A readable name for the server."""
258
+ return self._name
259
+
260
+
261
+ class MCPServerSseParams(TypedDict):
262
+ """Mirrors the params in`mcp.client.sse.sse_client`."""
263
+
264
+ url: str
265
+ """The URL of the server."""
266
+
267
+ headers: NotRequired[dict[str, str]]
268
+ """The headers to send to the server."""
269
+
270
+ timeout: NotRequired[float]
271
+ """The timeout for the HTTP request. Defaults to 60 seconds."""
272
+
273
+ sse_read_timeout: NotRequired[float]
274
+ """The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
275
+
276
+
277
+ class MCPServerSse(_MCPServerWithClientSession):
278
+ """MCP server implementation that uses the HTTP with SSE transport. See the [spec]
279
+ (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
280
+ for details.
281
+ """
282
+
283
+ def __init__(
284
+ self,
285
+ params: MCPServerSseParams,
286
+ cache_tools_list: bool = False,
287
+ name: str | None = None,
288
+ ):
289
+ """Create a new MCP server based on the HTTP with SSE transport.
290
+
291
+ Args:
292
+ params: The params that configure the server. This includes the URL of the server,
293
+ the headers to send to the server, the timeout for the HTTP request, and the
294
+ timeout for the SSE connection.
295
+
296
+ cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
297
+ cached and only fetched from the server once. If `False`, the tools list will be
298
+ fetched from the server on each call to `list_tools()`. The cache can be
299
+ invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
300
+ if you know the server will not change its tools list, because it can drastically
301
+ improve latency (by avoiding a round-trip to the server every time).
302
+
303
+ name: A readable name for the server. If not provided, we'll create one from the
304
+ URL.
305
+ """
306
+ super().__init__(cache_tools_list)
307
+
308
+ self.params = params
309
+ self._name = name or f"sse: {self.params['url']}"
310
+
311
+ def create_streams(
312
+ self,
313
+ ) -> AbstractAsyncContextManager[
314
+ tuple[
315
+ MemoryObjectReceiveStream[JSONRPCMessage | Exception],
316
+ MemoryObjectSendStream[JSONRPCMessage],
317
+ ]
318
+ ]:
319
+ """Create the streams for the server."""
320
+ return sse_client(
321
+ url=self.params["url"],
322
+ headers=self.params.get("headers", None),
323
+ timeout=self.params.get("timeout", 60),
324
+ sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
325
+ )
326
+
327
+ @property
328
+ def name(self) -> str:
329
+ """A readable name for the server."""
330
+ return self._name
aworld/mcp_client/utils.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Dict, Any
3
+ import json
4
+ import os
5
+ from contextlib import AsyncExitStack
6
+ import traceback
7
+
8
+ import requests
9
+ from mcp.types import TextContent, ImageContent
10
+
11
+ from aworld.core.common import ActionResult
12
+
13
+ from aworld.logs.util import logger
14
+ from aworld.mcp_client.server import MCPServer, MCPServerSse, MCPServerStdio
15
+ from aworld.tools import get_function_tools
16
+ from aworld.utils.common import find_file
17
+
18
+ MCP_SERVERS_CONFIG = {}
19
+
20
+
21
+ def get_function_tool(sever_name: str) -> List[Dict[str, Any]]:
22
+ openai_tools = []
23
+ try:
24
+ if not sever_name:
25
+ return []
26
+ tool_server = get_function_tools(sever_name)
27
+ if not tool_server:
28
+ return []
29
+ tools = tool_server.list_tools()
30
+ if not tools:
31
+ return []
32
+ for tool in tools:
33
+ required = []
34
+ properties = {}
35
+ if tool.inputSchema and tool.inputSchema.get("properties"):
36
+ required = tool.inputSchema.get("required", [])
37
+ _properties = tool.inputSchema["properties"]
38
+ for param_name, param_info in _properties.items():
39
+ param_type = (
40
+ param_info.get("type")
41
+ if param_info.get("type") != "str"
42
+ and param_info.get("type") is not None
43
+ else "string"
44
+ )
45
+ param_desc = param_info.get("description", "")
46
+ if param_type == "array":
47
+ # Handle array type parameters
48
+ items_info = param_info.get("items", {})
49
+ item_type = items_info.get("type", "string")
50
+
51
+ # Process nested array type parameters
52
+ if item_type == "array":
53
+ nested_items = items_info.get("items", {})
54
+ nested_type = nested_items.get("type", "string")
55
+
56
+ # If the nested type is an object
57
+ if nested_type == "object":
58
+ properties[param_name] = {
59
+ "description": param_desc,
60
+ "type": param_type,
61
+ "items": {
62
+ "type": item_type,
63
+ "items": {
64
+ "type": nested_type,
65
+ "properties": nested_items.get(
66
+ "properties", {}
67
+ ),
68
+ "required": nested_items.get(
69
+ "required", []
70
+ ),
71
+ },
72
+ },
73
+ }
74
+ else:
75
+ properties[param_name] = {
76
+ "description": param_desc,
77
+ "type": param_type,
78
+ "items": {
79
+ "type": item_type,
80
+ "items": {"type": nested_type},
81
+ },
82
+ }
83
+ # Process object type cases
84
+ elif item_type == "object":
85
+ properties[param_name] = {
86
+ "description": param_desc,
87
+ "type": param_type,
88
+ "items": {
89
+ "type": item_type,
90
+ "properties": items_info.get("properties", {}),
91
+ "required": items_info.get("required", []),
92
+ },
93
+ }
94
+ # Process basic type cases
95
+ else:
96
+ if item_type == "str":
97
+ item_type = "string"
98
+ properties[param_name] = {
99
+ "description": param_desc,
100
+ "type": param_type,
101
+ "items": {"type": item_type},
102
+ }
103
+ else:
104
+ # Handle non-array type parameters
105
+ properties[param_name] = {
106
+ "description": param_desc,
107
+ "type": param_type,
108
+ }
109
+
110
+ openai_function_schema = {
111
+ "name": f"mcp__{sever_name}__{tool.name}",
112
+ "description": tool.description,
113
+ "parameters": {
114
+ "type": "object",
115
+ "properties": properties,
116
+ "required": required,
117
+ },
118
+ }
119
+ openai_tools.append(
120
+ {
121
+ "type": "function",
122
+ "function": openai_function_schema,
123
+ }
124
+ )
125
+ logging.info(
126
+ f"✅ function_tool_server #({sever_name}) connected success,tools: {len(tools)}"
127
+ )
128
+
129
+ except Exception as e:
130
+ logging.warning(
131
+ f"server_name-get_function_tool:{sever_name} translate failed: {e}"
132
+ )
133
+ return []
134
+ finally:
135
+ return openai_tools
136
+
137
+
138
+ async def run(mcp_servers: list[MCPServer]) -> List[Dict[str, Any]]:
139
+ openai_tools = []
140
+ for i, server in enumerate(mcp_servers):
141
+ try:
142
+ tools = await server.list_tools()
143
+ for tool in tools:
144
+ required = []
145
+ properties = {}
146
+ if tool.inputSchema and tool.inputSchema.get("properties"):
147
+ required = tool.inputSchema.get("required", [])
148
+ _properties = tool.inputSchema["properties"]
149
+ for param_name, param_info in _properties.items():
150
+ param_type = (
151
+ param_info.get("type")
152
+ if param_info.get("type") != "str"
153
+ and param_info.get("type") is not None
154
+ else "string"
155
+ )
156
+ param_desc = param_info.get("description", "")
157
+ if param_type == "array":
158
+ # Handle array type parameters
159
+ items_info = param_info.get("items", {})
160
+ item_type = items_info.get("type", "string")
161
+
162
+ # Process nested array type parameters
163
+ if item_type == "array":
164
+ nested_items = items_info.get("items", {})
165
+ nested_type = nested_items.get("type", "string")
166
+
167
+ # If the nested type is an object
168
+ if nested_type == "object":
169
+ properties[param_name] = {
170
+ "description": param_desc,
171
+ "type": param_type,
172
+ "items": {
173
+ "type": item_type,
174
+ "items": {
175
+ "type": nested_type,
176
+ "properties": nested_items.get(
177
+ "properties", {}
178
+ ),
179
+ "required": nested_items.get(
180
+ "required", []
181
+ ),
182
+ },
183
+ },
184
+ }
185
+ else:
186
+ properties[param_name] = {
187
+ "description": param_desc,
188
+ "type": param_type,
189
+ "items": {
190
+ "type": item_type,
191
+ "items": {"type": nested_type},
192
+ },
193
+ }
194
+ # Process object type cases
195
+ elif item_type == "object":
196
+ properties[param_name] = {
197
+ "description": param_desc,
198
+ "type": param_type,
199
+ "items": {
200
+ "type": item_type,
201
+ "properties": items_info.get("properties", {}),
202
+ "required": items_info.get("required", []),
203
+ },
204
+ }
205
+ # Process basic type cases
206
+ else:
207
+ if item_type == "str":
208
+ item_type = "string"
209
+ properties[param_name] = {
210
+ "description": param_desc,
211
+ "type": param_type,
212
+ "items": {"type": item_type},
213
+ }
214
+ else:
215
+ # Handle non-array type parameters
216
+ properties[param_name] = {
217
+ "description": param_desc,
218
+ "type": param_type,
219
+ }
220
+
221
+ openai_function_schema = {
222
+ "name": f"{server.name}__{tool.name}",
223
+ "description": tool.description,
224
+ "parameters": {
225
+ "type": "object",
226
+ "properties": properties,
227
+ "required": required,
228
+ },
229
+ }
230
+ openai_tools.append(
231
+ {
232
+ "type": "function",
233
+ "function": openai_function_schema,
234
+ }
235
+ )
236
+ logging.info(
237
+ f"✅ server #{i + 1} ({server.name}) connected success,tools: {len(tools)}"
238
+ )
239
+
240
+ except Exception as e:
241
+ logging.error(f"❌ server #{i + 1} ({server.name}) connect fail: {e}")
242
+ return []
243
+
244
+ return openai_tools
245
+
246
+
247
+ async def mcp_tool_desc_transform(
248
+ tools: List[str] = None, mcp_config: Dict[str, Any] = None
249
+ ) -> List[Dict[str, Any]]:
250
+ """Default implement transform framework standard protocol to openai protocol of tool description."""
251
+ config = {}
252
+ global MCP_SERVERS_CONFIG
253
+
254
+ def _replace_env_variables(config):
255
+ if isinstance(config, dict):
256
+ for key, value in config.items():
257
+ if (
258
+ isinstance(value, str)
259
+ and value.startswith("${")
260
+ and value.endswith("}")
261
+ ):
262
+ env_var_name = value[2:-1]
263
+ config[key] = os.getenv(env_var_name, value)
264
+ logging.info(f"Replaced {value} with {config[key]}")
265
+ elif isinstance(value, dict) or isinstance(value, list):
266
+ _replace_env_variables(value)
267
+ elif isinstance(config, list):
268
+ for index, item in enumerate(config):
269
+ if (
270
+ isinstance(item, str)
271
+ and item.startswith("${")
272
+ and item.endswith("}")
273
+ ):
274
+ env_var_name = item[2:-1]
275
+ config[index] = os.getenv(env_var_name, item)
276
+ logging.info(f"Replaced {item} with {config[index]}")
277
+ elif isinstance(item, dict) or isinstance(item, list):
278
+ _replace_env_variables(item)
279
+
280
+ if mcp_config:
281
+ try:
282
+ config = mcp_config
283
+ MCP_SERVERS_CONFIG = config
284
+ except Exception as e:
285
+ logging.error(f"mcp_config error: {e}")
286
+ return []
287
+ else:
288
+ # Priority given to the running path.
289
+ config_path = find_file(filename="mcp.json")
290
+ if not os.path.exists(config_path):
291
+ current_dir = os.path.dirname(os.path.abspath(__file__))
292
+ config_path = os.path.normpath(
293
+ os.path.join(current_dir, "../config/mcp.json")
294
+ )
295
+ logger.info(f"mcp conf path: {config_path}")
296
+
297
+ if not os.path.exists(config_path):
298
+ logging.info(f"mcp config is not exist: {config_path}")
299
+ return []
300
+
301
+ try:
302
+ with open(config_path, "r") as f:
303
+ config = json.load(f)
304
+ except Exception as e:
305
+ logging.info(f"load config fail: {e}")
306
+ return []
307
+ _replace_env_variables(config)
308
+
309
+ MCP_SERVERS_CONFIG = config
310
+
311
+ mcp_servers_config = config.get("mcpServers", {})
312
+
313
+ server_configs = []
314
+ for server_name, server_config in mcp_servers_config.items():
315
+ # Skip disabled servers
316
+ if server_config.get("disabled", False):
317
+ continue
318
+
319
+ if tools is None or server_name in tools:
320
+ # Handle SSE server
321
+ if "url" in server_config:
322
+ server_configs.append(
323
+ {
324
+ "name": "mcp__" + server_name,
325
+ "type": "sse",
326
+ "params": {"url": server_config["url"]},
327
+ }
328
+ )
329
+ # Handle stdio server
330
+ elif "command" in server_config:
331
+ server_configs.append(
332
+ {
333
+ "name": "mcp__" + server_name,
334
+ "type": "stdio",
335
+ "params": {
336
+ "command": server_config["command"],
337
+ "args": server_config.get("args", []),
338
+ "env": server_config.get("env", {}),
339
+ "cwd": server_config.get("cwd"),
340
+ "encoding": server_config.get("encoding", "utf-8"),
341
+ "encoding_error_handler": server_config.get(
342
+ "encoding_error_handler", "strict"
343
+ ),
344
+ },
345
+ }
346
+ )
347
+
348
+ if not server_configs:
349
+ return []
350
+
351
+ async with AsyncExitStack() as stack:
352
+ servers = []
353
+ for server_config in server_configs:
354
+ try:
355
+ if server_config["type"] == "sse":
356
+ server = MCPServerSse(
357
+ name=server_config["name"], params=server_config["params"]
358
+ )
359
+ elif server_config["type"] == "stdio":
360
+ from aworld.mcp_client.server import MCPServerStdio
361
+
362
+ server = MCPServerStdio(
363
+ name=server_config["name"], params=server_config["params"]
364
+ )
365
+ else:
366
+ logging.warning(
367
+ f"Unsupported MCP server type: {server_config['type']}"
368
+ )
369
+ continue
370
+
371
+ server = await stack.enter_async_context(server)
372
+ servers.append(server)
373
+ except BaseException as err:
374
+ # single
375
+ logging.error(
376
+ f"Failed to get tools for MCP server '{server_config['name']}'.\n"
377
+ f"Error: {err}\n"
378
+ f"Traceback:\n{traceback.format_exc()}"
379
+ )
380
+
381
+ openai_tools = await run(servers)
382
+
383
+ return openai_tools
384
+
385
+
386
+ async def sandbox_mcp_tool_desc_transform(
387
+ tools: List[str] = None, mcp_config: Dict[str, Any] = None
388
+ ) -> List[Dict[str, Any]]:
389
+ # todo sandbox mcp_config get from registry
390
+
391
+ if not mcp_config:
392
+ return None
393
+ config = mcp_config
394
+ mcp_servers_config = config.get("mcpServers", {})
395
+ server_configs = []
396
+ openai_tools = []
397
+ mcp_openai_tools = []
398
+
399
+ for server_name, server_config in mcp_servers_config.items():
400
+ # Skip disabled servers
401
+ if server_config.get("disabled", False):
402
+ continue
403
+
404
+ if tools is None or server_name in tools:
405
+ # Handle SSE server
406
+ if "function_tool" == server_config.get("type", ""):
407
+ try:
408
+ tmp_function_tool = get_function_tool(server_name)
409
+ openai_tools.extend(tmp_function_tool)
410
+ except Exception as e:
411
+ logging.warning(f"server_name:{server_name} translate failed: {e}")
412
+ elif "api" == server_config.get("type", ""):
413
+ api_result = requests.get(server_config["url"] + "/list_tools")
414
+ try:
415
+ if not api_result or not api_result.text:
416
+ continue
417
+ # return None
418
+ data = json.loads(api_result.text)
419
+ if not data or not data.get("tools"):
420
+ continue
421
+ for item in data.get("tools"):
422
+ tmp_function = {
423
+ "type": "function",
424
+ "function": {
425
+ "name": "mcp__" + server_name + "__" + item["name"],
426
+ "description": item["description"],
427
+ "parameters": {
428
+ **item["parameters"],
429
+ "properties": {
430
+ k: v
431
+ for k, v in item["parameters"]
432
+ .get("properties", {})
433
+ .items()
434
+ if "default" not in v
435
+ },
436
+ },
437
+ },
438
+ }
439
+ openai_tools.append(tmp_function)
440
+ except Exception as e:
441
+ logging.warning(f"server_name:{server_name} translate failed: {e}")
442
+ elif "sse" == server_config.get("type", ""):
443
+ server_configs.append(
444
+ {
445
+ "name": "mcp__" + server_name,
446
+ "type": "sse",
447
+ "params": {
448
+ "url": server_config["url"],
449
+ "headers": server_config.get("headers"),
450
+ },
451
+ }
452
+ )
453
+ # Handle stdio server
454
+ else:
455
+ # elif "stdio" == server_config.get("type", ""):
456
+ server_configs.append(
457
+ {
458
+ "name": "mcp__" + server_name,
459
+ "type": "stdio",
460
+ "params": {
461
+ "command": server_config["command"],
462
+ "args": server_config.get("args", []),
463
+ "env": server_config.get("env", {}),
464
+ "cwd": server_config.get("cwd"),
465
+ "encoding": server_config.get("encoding", "utf-8"),
466
+ "encoding_error_handler": server_config.get(
467
+ "encoding_error_handler", "strict"
468
+ ),
469
+ },
470
+ }
471
+ )
472
+
473
+ if not server_configs:
474
+ return openai_tools
475
+
476
+ async with AsyncExitStack() as stack:
477
+ servers = []
478
+ for server_config in server_configs:
479
+ try:
480
+ if server_config["type"] == "sse":
481
+ server = MCPServerSse(
482
+ name=server_config["name"], params=server_config["params"]
483
+ )
484
+ elif server_config["type"] == "stdio":
485
+ server = MCPServerStdio(
486
+ name=server_config["name"], params=server_config["params"]
487
+ )
488
+ else:
489
+ logging.warning(
490
+ f"Unsupported MCP server type: {server_config['type']}"
491
+ )
492
+ continue
493
+
494
+ server = await stack.enter_async_context(server)
495
+ servers.append(server)
496
+ except BaseException as err:
497
+ # single
498
+ logging.error(
499
+ f"Failed to get tools for MCP server '{server_config['name']}'.\n"
500
+ f"Error: {err}\n"
501
+ )
502
+
503
+ mcp_openai_tools = await run(servers)
504
+
505
+ if mcp_openai_tools:
506
+ openai_tools.extend(mcp_openai_tools)
507
+
508
+ return openai_tools
509
+
510
+
511
+ async def call_function_tool(
512
+ server_name: str,
513
+ tool_name: str,
514
+ parameter: Dict[str, Any] = None,
515
+ mcp_config: Dict[str, Any] = None,
516
+ ) -> ActionResult:
517
+ """Specifically handle API type server calls
518
+
519
+ Args:
520
+ server_name: Server name
521
+ tool_name: Tool name
522
+ parameter: Parameters
523
+ mcp_config: MCP configuration
524
+
525
+ Returns:
526
+ ActionResult: Call result
527
+ """
528
+ action_result = ActionResult(
529
+ tool_name=server_name, action_name=tool_name, content="", keep=True
530
+ )
531
+ try:
532
+ tool_server = get_function_tools(server_name)
533
+ if not tool_server:
534
+ return action_result
535
+ call_result_raw = tool_server.call_tool(tool_name, parameter)
536
+ if call_result_raw and call_result_raw.content:
537
+ if isinstance(call_result_raw.content[0], TextContent):
538
+ action_result = ActionResult(
539
+ tool_name=server_name,
540
+ action_name=tool_name,
541
+ content=call_result_raw.content[0].text,
542
+ keep=True,
543
+ metadata=call_result_raw.content[0].model_extra.get("metadata", {}),
544
+ )
545
+ elif isinstance(call_result_raw.content[0], ImageContent):
546
+ action_result = ActionResult(
547
+ tool_name=server_name,
548
+ action_name=tool_name,
549
+ content=f"data:image/jpeg;base64,{call_result_raw.content[0].data}",
550
+ keep=True,
551
+ metadata=call_result_raw.content[0].model_extra.get("metadata", {}),
552
+ )
553
+
554
+ except Exception as e:
555
+ logging.warning(f"call_function_tool ({server_name})({tool_name}) failed: {e}")
556
+ action_result = ActionResult(
557
+ tool_name=server_name, action_name=tool_name, content="", keep=True
558
+ )
559
+
560
+ return action_result
561
+
562
+
563
+ async def call_api(
564
+ server_name: str,
565
+ tool_name: str,
566
+ parameter: Dict[str, Any] = None,
567
+ mcp_config: Dict[str, Any] = None,
568
+ ) -> ActionResult:
569
+ """Specifically handle API type server calls
570
+
571
+ Args:
572
+ server_name: Server name
573
+ tool_name: Tool name
574
+ parameter: Parameters
575
+ mcp_config: MCP configuration
576
+
577
+ Returns:
578
+ ActionResult: Call result
579
+ """
580
+ action_result = ActionResult(
581
+ tool_name=server_name, action_name=tool_name, content="", keep=True
582
+ )
583
+
584
+ if not mcp_config or mcp_config.get("mcpServers") is None:
585
+ return action_result
586
+
587
+ mcp_servers = mcp_config.get("mcpServers")
588
+ if not mcp_servers.get(server_name):
589
+ return action_result
590
+
591
+ server_config = mcp_servers.get(server_name)
592
+ if "api" != server_config.get("type", ""):
593
+ logging.warning(
594
+ f"Server {server_name} is not API type, should use call_tool instead"
595
+ )
596
+ return action_result
597
+
598
+ try:
599
+ headers = {"Content-Type": "application/json"}
600
+ response = requests.post(
601
+ url=server_config["url"] + "/" + tool_name, headers=headers, json=parameter
602
+ )
603
+ action_result = ActionResult(
604
+ tool_name=server_name,
605
+ action_name=tool_name,
606
+ content=response.text,
607
+ keep=True,
608
+ )
609
+ except Exception as e:
610
+ logging.warning(f"call_api ({server_name})({tool_name}) failed: {e}")
611
+ action_result = ActionResult(
612
+ tool_name=server_name,
613
+ action_name=tool_name,
614
+ content=f"Error calling API: {str(e)}",
615
+ keep=True,
616
+ )
617
+
618
+ return action_result
619
+
620
+
621
+ async def get_server_instance(
622
+ server_name: str, mcp_config: Dict[str, Any] = None
623
+ ) -> Any:
624
+ """Get server instance, create a new one if it doesn't exist
625
+
626
+ Args:
627
+ server_name: Server name
628
+ mcp_config: MCP configuration
629
+
630
+ Returns:
631
+ Server instance or None (if creation fails)
632
+ """
633
+ if not mcp_config or mcp_config.get("mcpServers") is None:
634
+ return None
635
+
636
+ mcp_servers = mcp_config.get("mcpServers")
637
+ if not mcp_servers.get(server_name):
638
+ return None
639
+
640
+ server_config = mcp_servers.get(server_name)
641
+ try:
642
+ # API type servers use special handling, no need for persistent connections
643
+ # Note: We've already handled API type in McpServers.call_tool method
644
+ # Here we don't return None, but let the caller handle it
645
+ if "api" == server_config.get("type", ""):
646
+ logging.info(f"API server {server_name} doesn't need persistent connection")
647
+ return None
648
+ elif "sse" == server_config.get("type", ""):
649
+ server = MCPServerSse(
650
+ name=server_name,
651
+ params={
652
+ "url": server_config["url"],
653
+ "headers": server_config.get("headers"),
654
+ "timeout": server_config.get("timeout", 5.0),
655
+ "sse_read_timeout": server_config.get("sse_read_timeout", 300.0),
656
+ },
657
+ )
658
+ await server.connect()
659
+ logging.info(f"Successfully connected to SSE server: {server_name}")
660
+ return server
661
+ else: # stdio type
662
+ params = {
663
+ "command": server_config["command"],
664
+ "args": server_config.get("args", []),
665
+ "env": server_config.get("env", {}),
666
+ "cwd": server_config.get("cwd"),
667
+ "encoding": server_config.get("encoding", "utf-8"),
668
+ "encoding_error_handler": server_config.get(
669
+ "encoding_error_handler", "strict"
670
+ ),
671
+ }
672
+ server = MCPServerStdio(name=server_name, params=params)
673
+ await server.connect()
674
+ logging.info(f"Successfully connected to stdio server: {server_name}")
675
+ return server
676
+ except Exception as e:
677
+ logging.warning(f"Failed to create server instance for {server_name}: {e}")
678
+ return None
679
+
680
+
681
+ async def cleanup_server(server):
682
+ """Clean up server connection
683
+
684
+ Args:
685
+ server: Server instance
686
+ """
687
+ try:
688
+ if hasattr(server, "cleanup"):
689
+ await server.cleanup()
690
+ elif hasattr(server, "close"):
691
+ await server.close()
692
+ logging.info(
693
+ f"Successfully cleaned up server: {getattr(server, 'name', 'unknown')}"
694
+ )
695
+ except Exception as e:
696
+ logging.warning(f"Failed to cleanup server: {e}")