Upload 3 files
- aworld/mcp_client/decorator.py +233 -0
- aworld/mcp_client/server.py +330 -0
- aworld/mcp_client/utils.py +696 -0
aworld/mcp_client/decorator.py
ADDED
@@ -0,0 +1,233 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.

"""
This module defines decorators for creating MCP servers.
By using the @mcp_server decorator, you can convert a Python class into an MCP server,
where the class methods will be automatically converted into MCP tools.
"""

import inspect
import functools
import threading
from typing import Type, Dict, Any, Optional, Union, List

# Import FastMCP
from mcp.server import FastMCP

from aworld.core.factory import Factory
from aworld.logs.util import logger


# Save all decorated MCP server classes
class MCPServerRegistry(Factory):
    """Register all MCP server classes"""

    def __init__(self, type_name: str = None):
        super().__init__(type_name)
        self._instance = {}

    def register(self, name: str, cls: Type, **kwargs):
        """Register MCP server class"""
        self._cls[name] = cls

    def get_instance(self, name: str, *args, **kwargs):
        """Get MCP server instance"""
        if name not in self._instance:
            if name not in self._cls:
                raise ValueError(f"MCP server {name} not registered")
            self._instance[name] = self._cls[name](*args, **kwargs)
        return self._instance[name]


# Create global registry instance
MCPServers = MCPServerRegistry()


def extract_param_desc(method, param_name):
    """Extract parameter description from method docstring"""
    if not method.__doc__:
        return None

    param_docs = [
        line.strip() for line in method.__doc__.split('\n')
        if line.strip().startswith(f":param {param_name}:")
    ]
    if param_docs:
        return param_docs[0].replace(f":param {param_name}:", "").strip()
    return None


def mcp_server(name: str = None, **server_config):
    """
    Decorator to convert a class into an MCP server

    Args:
        name: Server name, if None, uses the class name
        **server_config: Server configuration parameters
            - mode: Server running mode, supports 'stdio' and 'sse' (default: 'sse')
            - host: Host address in SSE mode (default: '127.0.0.1')
            - port: Port number in SSE mode (default: 8888)
            - sse_path: Path in SSE mode (default: '/sse')
            - auto_start: Whether to automatically start the server (default: True)

    Example:
        @mcp_server(
            name="simple-calculator",
            mode="sse",
            host="localhost",
            port=8085,
            sse_path="/calculator/sse"
        )
        class Calculator:
            '''Server description'''

            def __init__(self):
                self.data = {}

            def get_data(self, key: str) -> str:
                '''Get data
                :param key: Data key
                :return: Data value
                '''
                return self.data.get(key, "")
    """
    # Extract server configuration or use defaults
    mode = server_config.get('mode', 'sse')
    host = server_config.get('host', '127.0.0.1')
    port = server_config.get('port', 8888)
    sse_path = server_config.get('sse_path', '/sse')
    auto_start = server_config.get('auto_start', True)

    def decorator(cls):
        server_name = name or cls.__name__

        # Use class docstring as server description
        server_description = cls.__doc__ or f"{server_name} MCP Server"

        # Original initialization method
        original_init = cls.__init__

        @functools.wraps(original_init)
        def new_init(self, *args, **kwargs):
            # Call original initialization method
            original_init(self, *args, **kwargs)

            # Create FastMCP instance, set server name and description
            self._mcp = FastMCP(server_name, description=server_description.strip())

            # Tool name list for recording
            tool_names = []

            # Get all methods, filter out built-in and private methods
            for method_name, method in inspect.getmembers(self, inspect.ismethod):
                if not method_name.startswith('_') and method_name != 'run':
                    # Get method docstring as tool description
                    tool_description = method.__doc__ or f"{method_name} tool"
                    tool_description = tool_description.strip()

                    # Record tool name
                    tool_names.append(method_name)

                    # Create tool and register, using a function generator to ensure each method is correctly bound
                    def create_tool_wrapper(method_to_call):
                        # Check if method is async
                        is_async = inspect.iscoroutinefunction(method_to_call)

                        if is_async:
                            @self._mcp.tool(name=method_name, description=tool_description)
                            @functools.wraps(method_to_call)
                            async def wrapped_method(*args, **kwargs):
                                return await method_to_call(*args, **kwargs)
                        else:
                            @self._mcp.tool(name=method_name, description=tool_description)
                            @functools.wraps(method_to_call)
                            def wrapped_method(*args, **kwargs):
                                return method_to_call(*args, **kwargs)

                        return wrapped_method

                    # Create a dedicated wrapper for each method
                    create_tool_wrapper(method)

            # Print server information
            logger.info(f"Creating MCP server: {server_name}")
            logger.info(f"Server description: {server_description.strip()}")
            if tool_names:
                logger.info(f"Registered tools: {', '.join(tool_names)}")

            # Save configuration
            self._server_config = {
                'mode': mode,
                'host': host,
                'port': port,
                'sse_path': sse_path
            }

            # Auto start server if configured
            if auto_start:
                # Start server in a new thread to avoid blocking
                thread = threading.Thread(
                    target=self.run,
                    kwargs=self._server_config,
                    daemon=True
                )
                thread.start()
                logger.info(f"Server {server_name} started in a background thread")
                self._server_thread = thread

        # Replace initialization method
        cls.__init__ = new_init

        # Add method to run server
        def run(self, mode: str = mode, host: str = host, port: int = port, sse_path: str = sse_path):
            """
            Run MCP server

            Args:
                mode: Server running mode, supports 'stdio' and 'sse'
                host: Host address in SSE mode
                port: Port number in SSE mode
                sse_path: Path in SSE mode
            """
            if not hasattr(self, '_mcp') or self._mcp is None:
                raise RuntimeError("MCP server not initialized")

            # Run server according to mode
            if mode == "stdio":
                self._mcp.run(transport="stdio")
            elif mode == "sse":
                # Configure SSE mode settings
                self._mcp.settings.host = host
                self._mcp.settings.port = port
                self._mcp.settings.sse_path = sse_path

                # Print running information
                print(f"Running MCP server: {server_name}")
                print(f"Description: {server_description.strip()}")
                print(f"Address: http://{host}:{port}{sse_path}")

                self._mcp.run(transport="sse")
            else:
                raise ValueError(f"Unsupported mode: {mode}, supported modes are 'stdio' and 'sse'")

        cls.run = run

        # Add a stop method to gracefully stop the server
        def stop(self):
            """Stop the MCP server if it's running"""
            if hasattr(self, '_mcp') and self._mcp is not None:
                # TODO: Implement proper stopping mechanism based on FastMCP API
                logger.info(f"Stopping server {server_name}")
                # Currently there might not be a proper way to stop FastMCP server
                # This is a placeholder for future implementation

        cls.stop = stop

        # Register to MCP server registry
        MCPServers.register(server_name, cls)

        # Return modified class
        return cls

    return decorator
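A brief usage note on the decorator above, for reviewers: the sketch below is illustrative only and not part of this commit. The Adder class, its tool, and the port are made-up placeholders, and it assumes the aworld package and the mcp dependency are importable.

# Hypothetical usage sketch. Instantiating a decorated class builds the FastMCP
# app, registers every public method as a tool, records the class in the
# MCPServers registry, and (since auto_start defaults to True) serves it over
# SSE from a daemon thread.
import time

from aworld.mcp_client.decorator import mcp_server


@mcp_server(name="adder", mode="sse", host="127.0.0.1", port=8086, sse_path="/adder/sse")
class Adder:
    """A toy server exposing a single tool."""

    def add(self, a: int, b: int) -> int:
        """Add two integers
        :param a: First operand
        :param b: Second operand
        """
        return a + b


if __name__ == "__main__":
    Adder()  # tools are registered and the SSE endpoint starts serving here
    time.sleep(60)  # keep the main thread alive while the daemon thread serves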
aworld/mcp_client/server.py
ADDED
@@ -0,0 +1,330 @@
from __future__ import annotations

import abc
import asyncio
from datetime import timedelta
import logging
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from pathlib import Path
from typing import Any, Literal

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
from mcp.client.sse import sse_client
from mcp.types import CallToolResult, JSONRPCMessage
from typing_extensions import NotRequired, TypedDict


class MCPServer(abc.ABC):
    """Base class for Model Context Protocol servers."""

    @abc.abstractmethod
    async def connect(self):
        """Connect to the server. For example, this might mean spawning a subprocess or
        opening a network connection. The server is expected to remain connected until
        `cleanup()` is called.
        """
        pass

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """A readable name for the server."""
        pass

    @abc.abstractmethod
    async def cleanup(self):
        """Cleanup the server. For example, this might mean closing a subprocess or
        closing a network connection.
        """
        pass

    @abc.abstractmethod
    async def list_tools(self) -> list[MCPTool]:
        """List the tools available on the server."""
        pass

    @abc.abstractmethod
    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
        """Invoke a tool on the server."""
        pass


class _MCPServerWithClientSession(MCPServer, abc.ABC):
    """Base class for MCP servers that use a `ClientSession` to communicate with the server."""

    def __init__(self, cache_tools_list: bool, session_connect_timeout_seconds: int = 30):
        """
        Args:
            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
                cached and only fetched from the server once. If `False`, the tools list will be
                fetched from the server on each call to `list_tools()`. The cache can be invalidated
                by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
                server will not change its tools list, because it can drastically improve latency
                (by avoiding a round-trip to the server every time).

            session_connect_timeout_seconds: Session connect timeout, in seconds.
        """
        self.session: ClientSession | None = None
        self.exit_stack: AsyncExitStack = AsyncExitStack()
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        self.cache_tools_list = cache_tools_list
        self.session_connect_timeout_seconds = timedelta(seconds=session_connect_timeout_seconds)

        # The cache is always dirty at startup, so that we fetch tools at least once
        self._cache_dirty = True
        self._tools_list: list[MCPTool] | None = None

    @abc.abstractmethod
    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Create the streams for the server."""
        pass

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.cleanup()

    def invalidate_tools_cache(self):
        """Invalidate the tools cache."""
        self._cache_dirty = True

    async def connect(self):
        """Connect to the server."""
        try:
            # Ensure closing previous exit_stack to avoid nested async contexts
            if hasattr(self, 'exit_stack') and self.exit_stack:
                try:
                    await self.exit_stack.aclose()
                except Exception as e:
                    logging.error(f"Error closing previous exit stack: {e}")

            self.exit_stack = AsyncExitStack()

            # Use a single task context to create the connection
            transport = await self.exit_stack.enter_async_context(self.create_streams())
            read, write = transport
            session = await self.exit_stack.enter_async_context(ClientSession(read, write, read_timeout_seconds=self.session_connect_timeout_seconds))
            await session.initialize()
            self.session = session
        except Exception as e:
            logging.error(f"Error initializing MCP server: {e}")
            # Ensure resources are cleaned up if connection fails
            await self.cleanup()
            raise

    async def list_tools(self) -> list[MCPTool]:
        """List the tools available on the server."""
        if not self.session:
            raise RuntimeError("Server not initialized. Make sure you call `connect()` first.")

        # Return from cache if caching is enabled, we have tools, and the cache is not dirty
        if self.cache_tools_list and not self._cache_dirty and self._tools_list:
            return self._tools_list

        # Reset the cache dirty to False
        self._cache_dirty = False

        # Fetch the tools from the server
        self._tools_list = (await self.session.list_tools()).tools
        return self._tools_list

    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
        """Invoke a tool on the server."""
        if not self.session:
            raise RuntimeError("Server not initialized. Make sure you call `connect()` first.")

        return await self.session.call_tool(tool_name, arguments)

    async def cleanup(self):
        """Cleanup the server."""
        async with self._cleanup_lock:
            try:
                # Ensure cleanup operations occur in the same task context
                session = self.session
                self.session = None  # Remove reference first

                # Wait briefly to ensure any pending operations complete
                try:
                    await asyncio.sleep(0.1)
                except asyncio.CancelledError:
                    # Ignore cancellation exceptions, continue cleaning resources
                    pass

                # Clean up exit_stack, ensuring all resources are properly closed
                exit_stack = self.exit_stack
                if exit_stack:
                    try:
                        await exit_stack.aclose()
                    except Exception as e:
                        logging.error(f"Error closing exit stack during cleanup: {e}")
            except Exception as e:
                logging.error(f"Error during server cleanup: {e}")


class MCPServerStdioParams(TypedDict):
    """Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
    import.
    """

    command: str
    """The executable to run to start the server. For example, `python` or `node`."""

    args: NotRequired[list[str]]
    """Command line args to pass to the `command` executable. For example, `['foo.py']` or
    `['server.js', '--port', '8080']`."""

    env: NotRequired[dict[str, str]]
    """The environment variables to set for the server."""

    cwd: NotRequired[str | Path]
    """The working directory to use when spawning the process."""

    encoding: NotRequired[str]
    """The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""

    encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
    """The text encoding error handler. Defaults to `strict`.

    See https://docs.python.org/3/library/codecs.html#codec-base-classes for
    explanations of possible values.
    """


class MCPServerStdio(_MCPServerWithClientSession):
    """MCP server implementation that uses the stdio transport. See the [spec]
    (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
    details.
    """

    def __init__(
        self,
        params: MCPServerStdioParams,
        cache_tools_list: bool = False,
        name: str | None = None,
    ):
        """Create a new MCP server based on the stdio transport.

        Args:
            params: The params that configure the server. This includes the command to run to
                start the server, the args to pass to the command, the environment variables to
                set for the server, the working directory to use when spawning the process, and
                the text encoding used when sending/receiving messages to the server.
            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
                cached and only fetched from the server once. If `False`, the tools list will be
                fetched from the server on each call to `list_tools()`. The cache can be
                invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
                if you know the server will not change its tools list, because it can drastically
                improve latency (by avoiding a round-trip to the server every time).
            name: A readable name for the server. If not provided, we'll create one from the
                command.
        """
        # `env` is optional in MCPServerStdioParams; fall back to an empty dict before reading the timeout.
        super().__init__(cache_tools_list, int((params.get("env") or {}).get("SESSION_REQUEST_CONNECT_TIMEOUT", "60")))

        self.params = StdioServerParameters(
            command=params["command"],
            args=params.get("args", []),
            env=params.get("env"),
            cwd=params.get("cwd"),
            encoding=params.get("encoding", "utf-8"),
            encoding_error_handler=params.get("encoding_error_handler", "strict"),
        )

        self._name = name or f"stdio: {self.params.command}"

    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Create the streams for the server."""
        return stdio_client(self.params)

    @property
    def name(self) -> str:
        """A readable name for the server."""
        return self._name


class MCPServerSseParams(TypedDict):
    """Mirrors the params in `mcp.client.sse.sse_client`."""

    url: str
    """The URL of the server."""

    headers: NotRequired[dict[str, str]]
    """The headers to send to the server."""

    timeout: NotRequired[float]
    """The timeout for the HTTP request. Defaults to 60 seconds."""

    sse_read_timeout: NotRequired[float]
    """The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""


class MCPServerSse(_MCPServerWithClientSession):
    """MCP server implementation that uses the HTTP with SSE transport. See the [spec]
    (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
    for details.
    """

    def __init__(
        self,
        params: MCPServerSseParams,
        cache_tools_list: bool = False,
        name: str | None = None,
    ):
        """Create a new MCP server based on the HTTP with SSE transport.

        Args:
            params: The params that configure the server. This includes the URL of the server,
                the headers to send to the server, the timeout for the HTTP request, and the
                timeout for the SSE connection.

            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
                cached and only fetched from the server once. If `False`, the tools list will be
                fetched from the server on each call to `list_tools()`. The cache can be
                invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
                if you know the server will not change its tools list, because it can drastically
                improve latency (by avoiding a round-trip to the server every time).

            name: A readable name for the server. If not provided, we'll create one from the
                URL.
        """
        super().__init__(cache_tools_list)

        self.params = params
        self._name = name or f"sse: {self.params['url']}"

    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Create the streams for the server."""
        return sse_client(
            url=self.params["url"],
            headers=self.params.get("headers", None),
            timeout=self.params.get("timeout", 60),
            sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
        )

    @property
    def name(self) -> str:
        """A readable name for the server."""
        return self._name
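For reviewers, a minimal client-side sketch of the classes above (not part of this commit): the command, script name, and tool name are placeholders for any MCP-compliant stdio server.

# Hypothetical sketch: connect to a stdio MCP server, list its tools, and call one.
import asyncio

from aworld.mcp_client.server import MCPServerStdio


async def main():
    server = MCPServerStdio(
        params={"command": "python", "args": ["my_mcp_server.py"], "env": {}},
        cache_tools_list=True,  # skip the extra round-trip on repeated list_tools() calls
        name="example-stdio",
    )
    # __aenter__ calls connect(); __aexit__ calls cleanup()
    async with server:
        tools = await server.list_tools()
        print([tool.name for tool in tools])
        result = await server.call_tool("echo", {"text": "hello"})  # "echo" is a placeholder tool
        print(result.content)


asyncio.run(main())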
aworld/mcp_client/utils.py
ADDED
@@ -0,0 +1,696 @@
import logging
from typing import List, Dict, Any
import json
import os
from contextlib import AsyncExitStack
import traceback

import requests
from mcp.types import TextContent, ImageContent

from aworld.core.common import ActionResult

from aworld.logs.util import logger
from aworld.mcp_client.server import MCPServer, MCPServerSse, MCPServerStdio
from aworld.tools import get_function_tools
from aworld.utils.common import find_file

MCP_SERVERS_CONFIG = {}


def get_function_tool(sever_name: str) -> List[Dict[str, Any]]:
    openai_tools = []
    try:
        if not sever_name:
            return []
        tool_server = get_function_tools(sever_name)
        if not tool_server:
            return []
        tools = tool_server.list_tools()
        if not tools:
            return []
        for tool in tools:
            required = []
            properties = {}
            if tool.inputSchema and tool.inputSchema.get("properties"):
                required = tool.inputSchema.get("required", [])
                _properties = tool.inputSchema["properties"]
                for param_name, param_info in _properties.items():
                    param_type = (
                        param_info.get("type")
                        if param_info.get("type") != "str"
                        and param_info.get("type") is not None
                        else "string"
                    )
                    param_desc = param_info.get("description", "")
                    if param_type == "array":
                        # Handle array type parameters
                        items_info = param_info.get("items", {})
                        item_type = items_info.get("type", "string")

                        # Process nested array type parameters
                        if item_type == "array":
                            nested_items = items_info.get("items", {})
                            nested_type = nested_items.get("type", "string")

                            # If the nested type is an object
                            if nested_type == "object":
                                properties[param_name] = {
                                    "description": param_desc,
                                    "type": param_type,
                                    "items": {
                                        "type": item_type,
                                        "items": {
                                            "type": nested_type,
                                            "properties": nested_items.get(
                                                "properties", {}
                                            ),
                                            "required": nested_items.get(
                                                "required", []
                                            ),
                                        },
                                    },
                                }
                            else:
                                properties[param_name] = {
                                    "description": param_desc,
                                    "type": param_type,
                                    "items": {
                                        "type": item_type,
                                        "items": {"type": nested_type},
                                    },
                                }
                        # Process object type cases
                        elif item_type == "object":
                            properties[param_name] = {
                                "description": param_desc,
                                "type": param_type,
                                "items": {
                                    "type": item_type,
                                    "properties": items_info.get("properties", {}),
                                    "required": items_info.get("required", []),
                                },
                            }
                        # Process basic type cases
                        else:
                            if item_type == "str":
                                item_type = "string"
                            properties[param_name] = {
                                "description": param_desc,
                                "type": param_type,
                                "items": {"type": item_type},
                            }
                    else:
                        # Handle non-array type parameters
                        properties[param_name] = {
                            "description": param_desc,
                            "type": param_type,
                        }

            openai_function_schema = {
                "name": f"mcp__{sever_name}__{tool.name}",
                "description": tool.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            }
            openai_tools.append(
                {
                    "type": "function",
                    "function": openai_function_schema,
                }
            )
        logging.info(
            f"✅ function_tool_server #({sever_name}) connected success,tools: {len(tools)}"
        )

    except Exception as e:
        logging.warning(
            f"server_name-get_function_tool:{sever_name} translate failed: {e}"
        )
        return []
    finally:
        return openai_tools


async def run(mcp_servers: list[MCPServer]) -> List[Dict[str, Any]]:
    openai_tools = []
    for i, server in enumerate(mcp_servers):
        try:
            tools = await server.list_tools()
            for tool in tools:
                required = []
                properties = {}
                if tool.inputSchema and tool.inputSchema.get("properties"):
                    required = tool.inputSchema.get("required", [])
                    _properties = tool.inputSchema["properties"]
                    for param_name, param_info in _properties.items():
                        param_type = (
                            param_info.get("type")
                            if param_info.get("type") != "str"
                            and param_info.get("type") is not None
                            else "string"
                        )
                        param_desc = param_info.get("description", "")
                        if param_type == "array":
                            # Handle array type parameters
                            items_info = param_info.get("items", {})
                            item_type = items_info.get("type", "string")

                            # Process nested array type parameters
                            if item_type == "array":
                                nested_items = items_info.get("items", {})
                                nested_type = nested_items.get("type", "string")

                                # If the nested type is an object
                                if nested_type == "object":
                                    properties[param_name] = {
                                        "description": param_desc,
                                        "type": param_type,
                                        "items": {
                                            "type": item_type,
                                            "items": {
                                                "type": nested_type,
                                                "properties": nested_items.get(
                                                    "properties", {}
                                                ),
                                                "required": nested_items.get(
                                                    "required", []
                                                ),
                                            },
                                        },
                                    }
                                else:
                                    properties[param_name] = {
                                        "description": param_desc,
                                        "type": param_type,
                                        "items": {
                                            "type": item_type,
                                            "items": {"type": nested_type},
                                        },
                                    }
                            # Process object type cases
                            elif item_type == "object":
                                properties[param_name] = {
                                    "description": param_desc,
                                    "type": param_type,
                                    "items": {
                                        "type": item_type,
                                        "properties": items_info.get("properties", {}),
                                        "required": items_info.get("required", []),
                                    },
                                }
                            # Process basic type cases
                            else:
                                if item_type == "str":
                                    item_type = "string"
                                properties[param_name] = {
                                    "description": param_desc,
                                    "type": param_type,
                                    "items": {"type": item_type},
                                }
                        else:
                            # Handle non-array type parameters
                            properties[param_name] = {
                                "description": param_desc,
                                "type": param_type,
                            }

                openai_function_schema = {
                    "name": f"{server.name}__{tool.name}",
                    "description": tool.description,
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": required,
                    },
                }
                openai_tools.append(
                    {
                        "type": "function",
                        "function": openai_function_schema,
                    }
                )
            logging.info(
                f"✅ server #{i + 1} ({server.name}) connected success,tools: {len(tools)}"
            )

        except Exception as e:
            logging.error(f"❌ server #{i + 1} ({server.name}) connect fail: {e}")
            return []

    return openai_tools


async def mcp_tool_desc_transform(
    tools: List[str] = None, mcp_config: Dict[str, Any] = None
) -> List[Dict[str, Any]]:
    """Default implementation: transform the framework's standard tool descriptions into the OpenAI tool protocol."""
    config = {}
    global MCP_SERVERS_CONFIG

    def _replace_env_variables(config):
        if isinstance(config, dict):
            for key, value in config.items():
                if (
                    isinstance(value, str)
                    and value.startswith("${")
                    and value.endswith("}")
                ):
                    env_var_name = value[2:-1]
                    config[key] = os.getenv(env_var_name, value)
                    logging.info(f"Replaced {value} with {config[key]}")
                elif isinstance(value, dict) or isinstance(value, list):
                    _replace_env_variables(value)
        elif isinstance(config, list):
            for index, item in enumerate(config):
                if (
                    isinstance(item, str)
                    and item.startswith("${")
                    and item.endswith("}")
                ):
                    env_var_name = item[2:-1]
                    config[index] = os.getenv(env_var_name, item)
                    logging.info(f"Replaced {item} with {config[index]}")
                elif isinstance(item, dict) or isinstance(item, list):
                    _replace_env_variables(item)

    if mcp_config:
        try:
            config = mcp_config
            MCP_SERVERS_CONFIG = config
        except Exception as e:
            logging.error(f"mcp_config error: {e}")
            return []
    else:
        # Priority given to the running path.
        config_path = find_file(filename="mcp.json")
        if not os.path.exists(config_path):
            current_dir = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.normpath(
                os.path.join(current_dir, "../config/mcp.json")
            )
        logger.info(f"mcp conf path: {config_path}")

        if not os.path.exists(config_path):
            logging.info(f"mcp config is not exist: {config_path}")
            return []

        try:
            with open(config_path, "r") as f:
                config = json.load(f)
        except Exception as e:
            logging.info(f"load config fail: {e}")
            return []
        _replace_env_variables(config)

        MCP_SERVERS_CONFIG = config

    mcp_servers_config = config.get("mcpServers", {})

    server_configs = []
    for server_name, server_config in mcp_servers_config.items():
        # Skip disabled servers
        if server_config.get("disabled", False):
            continue

        if tools is None or server_name in tools:
            # Handle SSE server
            if "url" in server_config:
                server_configs.append(
                    {
                        "name": "mcp__" + server_name,
                        "type": "sse",
                        "params": {"url": server_config["url"]},
                    }
                )
            # Handle stdio server
            elif "command" in server_config:
                server_configs.append(
                    {
                        "name": "mcp__" + server_name,
                        "type": "stdio",
                        "params": {
                            "command": server_config["command"],
                            "args": server_config.get("args", []),
                            "env": server_config.get("env", {}),
                            "cwd": server_config.get("cwd"),
                            "encoding": server_config.get("encoding", "utf-8"),
                            "encoding_error_handler": server_config.get(
                                "encoding_error_handler", "strict"
                            ),
                        },
                    }
                )

    if not server_configs:
        return []

    async with AsyncExitStack() as stack:
        servers = []
        for server_config in server_configs:
            try:
                if server_config["type"] == "sse":
                    server = MCPServerSse(
                        name=server_config["name"], params=server_config["params"]
                    )
                elif server_config["type"] == "stdio":
                    from aworld.mcp_client.server import MCPServerStdio

                    server = MCPServerStdio(
                        name=server_config["name"], params=server_config["params"]
                    )
                else:
                    logging.warning(
                        f"Unsupported MCP server type: {server_config['type']}"
                    )
                    continue

                server = await stack.enter_async_context(server)
                servers.append(server)
            except BaseException as err:
                # A single server failure should not abort the others
                logging.error(
                    f"Failed to get tools for MCP server '{server_config['name']}'.\n"
                    f"Error: {err}\n"
                    f"Traceback:\n{traceback.format_exc()}"
                )

        openai_tools = await run(servers)

    return openai_tools


async def sandbox_mcp_tool_desc_transform(
    tools: List[str] = None, mcp_config: Dict[str, Any] = None
) -> List[Dict[str, Any]]:
    # TODO: get the sandbox mcp_config from the registry

    if not mcp_config:
        return None
    config = mcp_config
    mcp_servers_config = config.get("mcpServers", {})
    server_configs = []
    openai_tools = []
    mcp_openai_tools = []

    for server_name, server_config in mcp_servers_config.items():
        # Skip disabled servers
        if server_config.get("disabled", False):
            continue

        if tools is None or server_name in tools:
            # Handle function_tool server
            if "function_tool" == server_config.get("type", ""):
                try:
                    tmp_function_tool = get_function_tool(server_name)
                    openai_tools.extend(tmp_function_tool)
                except Exception as e:
                    logging.warning(f"server_name:{server_name} translate failed: {e}")
            elif "api" == server_config.get("type", ""):
                api_result = requests.get(server_config["url"] + "/list_tools")
                try:
                    if not api_result or not api_result.text:
                        continue
                        # return None
                    data = json.loads(api_result.text)
                    if not data or not data.get("tools"):
                        continue
                    for item in data.get("tools"):
                        tmp_function = {
                            "type": "function",
                            "function": {
                                "name": "mcp__" + server_name + "__" + item["name"],
                                "description": item["description"],
                                "parameters": {
                                    **item["parameters"],
                                    "properties": {
                                        k: v
                                        for k, v in item["parameters"]
                                        .get("properties", {})
                                        .items()
                                        if "default" not in v
                                    },
                                },
                            },
                        }
                        openai_tools.append(tmp_function)
                except Exception as e:
                    logging.warning(f"server_name:{server_name} translate failed: {e}")
            elif "sse" == server_config.get("type", ""):
                server_configs.append(
                    {
                        "name": "mcp__" + server_name,
                        "type": "sse",
                        "params": {
                            "url": server_config["url"],
                            "headers": server_config.get("headers"),
                        },
                    }
                )
            # Handle stdio server
            else:
                # elif "stdio" == server_config.get("type", ""):
                server_configs.append(
                    {
                        "name": "mcp__" + server_name,
                        "type": "stdio",
                        "params": {
                            "command": server_config["command"],
                            "args": server_config.get("args", []),
                            "env": server_config.get("env", {}),
                            "cwd": server_config.get("cwd"),
                            "encoding": server_config.get("encoding", "utf-8"),
                            "encoding_error_handler": server_config.get(
                                "encoding_error_handler", "strict"
                            ),
                        },
                    }
                )

    if not server_configs:
        return openai_tools

    async with AsyncExitStack() as stack:
        servers = []
        for server_config in server_configs:
            try:
                if server_config["type"] == "sse":
                    server = MCPServerSse(
                        name=server_config["name"], params=server_config["params"]
                    )
                elif server_config["type"] == "stdio":
                    server = MCPServerStdio(
                        name=server_config["name"], params=server_config["params"]
                    )
                else:
                    logging.warning(
                        f"Unsupported MCP server type: {server_config['type']}"
                    )
                    continue

                server = await stack.enter_async_context(server)
                servers.append(server)
            except BaseException as err:
                # A single server failure should not abort the others
                logging.error(
                    f"Failed to get tools for MCP server '{server_config['name']}'.\n"
                    f"Error: {err}\n"
                )

        mcp_openai_tools = await run(servers)

    if mcp_openai_tools:
        openai_tools.extend(mcp_openai_tools)

    return openai_tools


async def call_function_tool(
    server_name: str,
    tool_name: str,
    parameter: Dict[str, Any] = None,
    mcp_config: Dict[str, Any] = None,
) -> ActionResult:
    """Handle function tool type server calls

    Args:
        server_name: Server name
        tool_name: Tool name
        parameter: Parameters
        mcp_config: MCP configuration

    Returns:
        ActionResult: Call result
    """
    action_result = ActionResult(
        tool_name=server_name, action_name=tool_name, content="", keep=True
    )
    try:
        tool_server = get_function_tools(server_name)
        if not tool_server:
            return action_result
        call_result_raw = tool_server.call_tool(tool_name, parameter)
        if call_result_raw and call_result_raw.content:
            if isinstance(call_result_raw.content[0], TextContent):
                action_result = ActionResult(
                    tool_name=server_name,
                    action_name=tool_name,
                    content=call_result_raw.content[0].text,
                    keep=True,
                    metadata=call_result_raw.content[0].model_extra.get("metadata", {}),
                )
            elif isinstance(call_result_raw.content[0], ImageContent):
                action_result = ActionResult(
                    tool_name=server_name,
                    action_name=tool_name,
                    content=f"data:image/jpeg;base64,{call_result_raw.content[0].data}",
                    keep=True,
                    metadata=call_result_raw.content[0].model_extra.get("metadata", {}),
                )

    except Exception as e:
        logging.warning(f"call_function_tool ({server_name})({tool_name}) failed: {e}")
        action_result = ActionResult(
            tool_name=server_name, action_name=tool_name, content="", keep=True
        )

    return action_result


async def call_api(
    server_name: str,
    tool_name: str,
    parameter: Dict[str, Any] = None,
    mcp_config: Dict[str, Any] = None,
) -> ActionResult:
    """Specifically handle API type server calls

    Args:
        server_name: Server name
        tool_name: Tool name
        parameter: Parameters
        mcp_config: MCP configuration

    Returns:
        ActionResult: Call result
    """
    action_result = ActionResult(
        tool_name=server_name, action_name=tool_name, content="", keep=True
    )

    if not mcp_config or mcp_config.get("mcpServers") is None:
        return action_result

    mcp_servers = mcp_config.get("mcpServers")
    if not mcp_servers.get(server_name):
        return action_result

    server_config = mcp_servers.get(server_name)
    if "api" != server_config.get("type", ""):
        logging.warning(
            f"Server {server_name} is not API type, should use call_tool instead"
        )
        return action_result

    try:
        headers = {"Content-Type": "application/json"}
        response = requests.post(
            url=server_config["url"] + "/" + tool_name, headers=headers, json=parameter
        )
        action_result = ActionResult(
            tool_name=server_name,
            action_name=tool_name,
            content=response.text,
            keep=True,
        )
    except Exception as e:
        logging.warning(f"call_api ({server_name})({tool_name}) failed: {e}")
        action_result = ActionResult(
            tool_name=server_name,
            action_name=tool_name,
            content=f"Error calling API: {str(e)}",
            keep=True,
        )

    return action_result


async def get_server_instance(
    server_name: str, mcp_config: Dict[str, Any] = None
) -> Any:
    """Get server instance, create a new one if it doesn't exist

    Args:
        server_name: Server name
        mcp_config: MCP configuration

    Returns:
        Server instance or None (if creation fails)
    """
    if not mcp_config or mcp_config.get("mcpServers") is None:
        return None

    mcp_servers = mcp_config.get("mcpServers")
    if not mcp_servers.get(server_name):
        return None

    server_config = mcp_servers.get(server_name)
    try:
        # API type servers use special handling, no need for persistent connections
        # Note: We've already handled API type in McpServers.call_tool method
        # Here we don't return None, but let the caller handle it
        if "api" == server_config.get("type", ""):
            logging.info(f"API server {server_name} doesn't need persistent connection")
            return None
        elif "sse" == server_config.get("type", ""):
            server = MCPServerSse(
                name=server_name,
                params={
                    "url": server_config["url"],
                    "headers": server_config.get("headers"),
                    "timeout": server_config.get("timeout", 5.0),
                    "sse_read_timeout": server_config.get("sse_read_timeout", 300.0),
                },
            )
            await server.connect()
            logging.info(f"Successfully connected to SSE server: {server_name}")
            return server
        else:  # stdio type
            params = {
                "command": server_config["command"],
                "args": server_config.get("args", []),
                "env": server_config.get("env", {}),
                "cwd": server_config.get("cwd"),
                "encoding": server_config.get("encoding", "utf-8"),
                "encoding_error_handler": server_config.get(
                    "encoding_error_handler", "strict"
                ),
            }
            server = MCPServerStdio(name=server_name, params=params)
            await server.connect()
            logging.info(f"Successfully connected to stdio server: {server_name}")
            return server
    except Exception as e:
        logging.warning(f"Failed to create server instance for {server_name}: {e}")
        return None


async def cleanup_server(server):
    """Clean up server connection

    Args:
        server: Server instance
    """
    try:
        if hasattr(server, "cleanup"):
            await server.cleanup()
        elif hasattr(server, "close"):
            await server.close()
        logging.info(
            f"Successfully cleaned up server: {getattr(server, 'name', 'unknown')}"
        )
    except Exception as e:
        logging.warning(f"Failed to cleanup server: {e}")
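Finally, a sketch of how the helpers above might be driven with an inline configuration (not part of this commit): the "filesystem" entry, its command, and its args are placeholders for any MCP server reachable over stdio.

# Hypothetical sketch: turn the tools exposed by configured MCP servers into
# OpenAI-style function descriptions via mcp_tool_desc_transform.
import asyncio

from aworld.mcp_client.utils import mcp_tool_desc_transform

config = {
    "mcpServers": {
        "filesystem": {
            "command": "npx",
            "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
            "env": {},
        }
    }
}


async def main():
    tools = await mcp_tool_desc_transform(tools=["filesystem"], mcp_config=config)
    # Each entry follows the OpenAI tools schema; names are prefixed as
    # "mcp__<server>__<tool>".
    for tool in tools:
        print(tool["function"]["name"])


asyncio.run(main())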