FashionM3 / mcp_client.py
pangkaicheng
first commit
f8a73ec
from typing import Dict, Any, List, Optional
from contextlib import AsyncExitStack
import json
import aiohttp
import asyncio
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
class MCPClient:
    """Client that connects to several MCP servers and exposes their tools.

    Server definitions are read from a JSON config file. The initialized
    ``ClientSession`` of every connected server is kept in ``self.sessions``
    (keyed by the server name from the config). Every transport/session
    context is entered on one shared ``AsyncExitStack``, so :meth:`close`
    tears them all down together.
    """

    # Tools that are hidden from get_tools() when no user_id was supplied.
    _USER_ONLY_TOOLS = frozenset({'personalized_fashion_recommend'})

    def __init__(self, config_path: str, user_id: Optional[str] = None):
        """
        Load MCP server configurations from a JSON file.

        Args:
            config_path: Path to a JSON file whose top-level ``"mcpServers"``
                key maps server names to config dicts. Keys read per server:
                ``"transport"`` (``"stdio"`` by default, or ``"sse"``);
                ``"command"``, ``"args"``, ``"env"`` for stdio servers;
                ``"url"`` for SSE servers.
            user_id: Optional user identifier. When falsy, user-specific tools
                (``personalized_fashion_recommend``) are omitted from
                :meth:`get_tools`.

        Raises:
            OSError: If the config file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If the JSON has no ``"mcpServers"`` key.
        """
        self.user_id = user_id
        # JSON is UTF-8 by specification; don't depend on the platform default.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.server_configs: Dict[str, Dict[str, Any]] = json.load(f)['mcpServers']
        # Server name -> initialized ClientSession (same type for stdio and SSE).
        self.sessions: Dict[str, Any] = {}
        self.exit_stack = AsyncExitStack()

    async def connect_to_servers(self):
        """Connect to every configured MCP server according to its transport.

        Servers are connected one after another in config order. If a
        connection fails, the exception propagates. Servers connected before
        it stay registered on the exit stack, so call :meth:`close` to
        release them.

        Raises:
            ValueError: If a stdio server has no ``command``, an SSE server has
                no ``url``, or the transport type is not supported.
        """
        for server_name, config in self.server_configs.items():
            transport = config.get("transport", "stdio")  # stdio is the default transport
            print(f"Connecting to {server_name} ({transport})...")
            if transport == "stdio":
                await self._connect_stdio(server_name, config)
            elif transport == "sse":
                await self._connect_sse(server_name, config)
            else:
                raise ValueError(f"Unsupported transport type '{transport}' for {server_name}")

    async def _connect_stdio(self, server_name: str, config: Dict[str, Any]) -> None:
        """Open a stdio transport for the configured command and register its session."""
        command = config.get("command")
        if not command:
            raise ValueError(f"No command specified for server {server_name}")
        server_params = StdioServerParameters(
            command=command,
            args=config.get("args", []),
            env=config.get("env", None),
        )
        read_stream, write_stream = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        session = await self.exit_stack.enter_async_context(
            ClientSession(read_stream, write_stream)
        )
        await session.initialize()
        self.sessions[server_name] = session

    async def _connect_sse(self, server_name: str, config: Dict[str, Any]) -> None:
        """Open an SSE transport to ``<url>/sse``, register its session and list its tools."""
        server_url = config.get("url", "")
        if not server_url:
            raise ValueError(f"No url specified for server {server_name}")
        # Strip trailing slashes so "http://host/" does not become "http://host//sse".
        streams = await self.exit_stack.enter_async_context(
            sse_client(url=f"{server_url.rstrip('/')}/sse")
        )
        session = await self.exit_stack.enter_async_context(ClientSession(*streams))
        await session.initialize()
        self.sessions[server_name] = session
        # Verify the connection by asking the server for its tool list.
        print(f"Initialized SSE client for {server_name}...")
        print("Listing tools...")
        response = await session.list_tools()
        print(f"Connected to {server_name} with tools:", [tool.name for tool in response.tools])

    async def get_tools(self) -> List[Dict[str, Any]]:
        """
        Collect the tools of all connected MCP servers.

        Each tool is returned in function-calling form:
        ``{"type": "function", "function": {"name", "description", "parameters"}}``,
        where ``parameters`` is the tool's ``inputSchema``. Tools are listed
        server by server, in connection order. User-specific tools are
        skipped when no ``user_id`` was provided.
        """
        all_tools = []
        for session in self.sessions.values():
            response = await session.list_tools()
            for tool in response.tools:
                if not self.user_id and tool.name in self._USER_ONLY_TOOLS:
                    continue
                all_tools.append(
                    {
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.inputSchema,
                        },
                    }
                )
        return all_tools

    async def execute_tool(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run ``tool_name`` with ``params`` on the first server that offers it.

        Servers are queried in connection order. The tool list is fetched
        fresh on each call, so tools added after connecting are still found.

        Returns:
            ``{"result": <call_tool result content>, "server": <server name>}``.

        Raises:
            LookupError: If no connected server offers ``tool_name``. It is a
                subclass of ``Exception``, so existing broad handlers still
                catch it.
        """
        for server_name, session in self.sessions.items():
            response = await session.list_tools()
            if any(tool.name == tool_name for tool in response.tools):
                result = await session.call_tool(tool_name, params)
                return {
                    "result": result.content,
                    "server": server_name,
                }
        raise LookupError(f"Tool {tool_name} not found on any connected server")

    async def close(self):
        """Close all server connections and forget their sessions."""
        try:
            await self.exit_stack.aclose()
        finally:
            # The sessions are unusable once their contexts are exited; drop them
            # so get_tools()/execute_tool() don't talk to closed connections.
            self.sessions.clear()