from typing import Dict, Any, List, Optional
from contextlib import AsyncExitStack
import json
import aiohttp
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client


class MCPClient:
    """Client that connects to several MCP servers and routes tool calls to them.

    Server definitions are read from the ``mcpServers`` mapping of a JSON
    config file. Every transport and session context is entered on a single
    ``AsyncExitStack`` so that :meth:`close` tears all of them down together.
    """

    # Tools that get_tools() only advertises when a user_id is set.
    _USER_SCOPED_TOOLS = frozenset({"personalized_fashion_recommend"})

    def __init__(self, config_path: str, user_id: Optional[str] = None):
        """Load the MCP server configurations from a JSON file.

        Args:
            config_path: Path to a JSON file whose top-level ``mcpServers`` key
                maps a server name to its config dict. Keys read from each
                config: ``transport`` (``"stdio"`` by default, or ``"sse"``);
                ``command`` / ``args`` / ``env`` for stdio; ``url`` for sse.
            user_id: Optional user identifier. When falsy, user-scoped tools
                are hidden from :meth:`get_tools`.

        Raises:
            KeyError: If the config file has no ``mcpServers`` key.
        """
        self.user_id = user_id
        # Explicit encoding: JSON is UTF-8, and the platform default encoding
        # (e.g. on Windows) is not guaranteed to be.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.server_configs = json.load(f)['mcpServers']
        # Server name -> initialized MCP ClientSession (same type for both transports).
        self.sessions: Dict[str, ClientSession] = {}
        self.exit_stack = AsyncExitStack()

    async def connect_to_servers(self):
        """Connect to every configured MCP server using its transport type.

        Servers are connected one at a time in config order. If one server's
        config is invalid, a ValueError is raised and the servers connected
        before it stay open; call :meth:`close` to release them.

        Raises:
            ValueError: If a stdio server has no ``command``, an sse server has
                no ``url``, or the transport type is unsupported.
        """
        for server_name, config in self.server_configs.items():
            transport = config.get("transport", "stdio")  # stdio is the default transport
            print(f"Connecting to {server_name} ({transport})...")

            if transport == "stdio":
                await self._connect_stdio(server_name, config)
            elif transport == "sse":
                await self._connect_sse(server_name, config)
            else:
                raise ValueError(f"Unsupported transport type '{transport}' for {server_name}")

    async def _connect_stdio(self, server_name: str, config: Dict[str, Any]) -> ClientSession:
        """Open a stdio transport for the configured command, then initialize and register a session."""
        command = config.get("command")
        if not command:
            raise ValueError(f"No command specified for server {server_name}")
        server_params = StdioServerParameters(
            command=command,
            args=config.get("args", []),
            env=config.get("env"),
        )
        read_stream, write_stream = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        session = await self.exit_stack.enter_async_context(
            ClientSession(read_stream, write_stream)
        )
        await session.initialize()
        self.sessions[server_name] = session
        return session

    async def _connect_sse(self, server_name: str, config: Dict[str, Any]) -> ClientSession:
        """Open an SSE transport at ``<url>/sse``, initialize and register a session, then list its tools."""
        server_url = config.get("url", "")
        if not server_url:
            # The config key is "url"; name it correctly so users know what to set.
            raise ValueError(f"No url specified for server {server_name}")
        # Strip trailing slashes so "http://host/" does not become "http://host//sse".
        streams = await self.exit_stack.enter_async_context(
            sse_client(url=f"{server_url.rstrip('/')}/sse")
        )
        session = await self.exit_stack.enter_async_context(ClientSession(*streams))
        await session.initialize()
        self.sessions[server_name] = session

        # Verify the connection by asking the server for its tool list.
        print(f"Initialized SSE client for {server_name}...")
        print("Listing tools...")
        response = await session.list_tools()
        print(f"Connected to {server_name} with tools:", [tool.name for tool in response.tools])
        return session

    def _is_hidden_tool(self, tool_name: str) -> bool:
        """Return True if *tool_name* is user-scoped and no user_id is set."""
        return not self.user_id and tool_name in self._USER_SCOPED_TOOLS

    async def get_tools(self) -> List[Dict[str, Any]]:
        """Fetch tool definitions from all connected MCP servers.

        Queries every session's tool list and wraps each tool as
        ``{"type": "function", "function": {"name", "description", "parameters"}}``,
        where ``parameters`` is the tool's ``inputSchema``. User-scoped tools
        are skipped when no user_id is set.

        Returns:
            A list of tool definition dicts, in server-connection order.
        """
        all_tools: List[Dict[str, Any]] = []
        for session in self.sessions.values():
            response = await session.list_tools()
            all_tools.extend(
                {
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.inputSchema,
                    },
                }
                for tool in response.tools
                if not self._is_hidden_tool(tool.name)
            )
        return all_tools

    async def execute_tool(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Execute *tool_name* with *params* on the first server that exposes it.

        Servers are searched in connection order by re-querying each one's
        tool list, so a call costs one ``list_tools`` round trip per server
        checked before the match.

        Returns:
            ``{"result": <content of the call_tool result>, "server": <server name>}``.

        Raises:
            Exception: If no connected server exposes *tool_name*.
        """
        for server_name, session in self.sessions.items():
            response = await session.list_tools()
            if any(tool.name == tool_name for tool in response.tools):
                result = await session.call_tool(tool_name, params)
                return {
                    "result": result.content,
                    "server": server_name,
                }
        # NOTE(review): user-scoped tools hidden by get_tools() are not blocked
        # here; confirm whether they should be rejected when user_id is unset.
        raise Exception(f"Tool {tool_name} not found on any connected server")

    async def close(self):
        """Close all server connections and forget their sessions."""
        try:
            await self.exit_stack.aclose()
        finally:
            # Once their contexts have exited the sessions are unusable; dropping
            # them stops get_tools()/execute_tool() from touching closed streams.
            self.sessions.clear()