File size: 4,990 Bytes
f8a73ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from typing import Dict, Any, List, Optional
from contextlib import AsyncExitStack
import json
import aiohttp
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client


class MCPClient:
    """Connect to several MCP servers and expose their tools in one place.

    Server definitions come from a JSON config file whose top-level
    ``mcpServers`` object maps a server name to a settings dict. Keys read
    from each settings dict:

    * ``transport`` -- ``"stdio"`` (default) or ``"sse"``.
    * ``command`` (required), ``args`` (default ``[]``), ``env`` (default
      ``None``) -- used for ``stdio`` servers.
    * ``url`` (required) -- base URL for ``sse`` servers; ``/sse`` is appended.
    """

    # Tools that get_tools() only advertises when a user_id was supplied.
    USER_SCOPED_TOOLS = frozenset({"personalized_fashion_recommend"})

    def __init__(self, config_path: str, user_id: Optional[str] = None):
        """
        Load the server configuration. No connections are opened here.

        Args:
            config_path: Path to a JSON file with a top-level "mcpServers" object.
            user_id: Optional user identifier. When falsy, tools listed in
                USER_SCOPED_TOOLS are omitted from get_tools().

        Raises:
            OSError: If the config file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If the JSON document has no "mcpServers" key.
        """
        self.user_id = user_id
        # Explicit UTF-8 so configs containing non-ASCII text load identically on every platform.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.server_configs = json.load(f)['mcpServers']
        # Server name -> initialized ClientSession (the same type for stdio and sse transports).
        self.sessions: Dict[str, Any] = {}
        # Owns every transport/session context so close() can tear them all down together.
        self.exit_stack = AsyncExitStack()

    async def connect_to_servers(self):
        """
        Connect to every configured server and initialize its session.

        Each initialized session is stored in ``self.sessions`` under its server
        name; the underlying contexts are registered on ``self.exit_stack``.

        Raises:
            ValueError: If a server uses an unsupported transport, or is missing
                its ``command`` (stdio) or ``url`` (sse).
        """
        for server_name, config in self.server_configs.items():
            transport = config.get("transport", "stdio")  # stdio is the default transport
            print(f"Connecting to {server_name} ({transport})...")
            if transport == "stdio":
                session = await self._open_stdio_session(server_name, config)
            elif transport == "sse":
                session = await self._open_sse_session(server_name, config)
            else:
                raise ValueError(f"Unsupported transport type '{transport}' for {server_name}")
            self.sessions[server_name] = session
            if transport == "sse":
                # Sanity-check the SSE connection by listing the server's tools.
                await self._report_sse_tools(server_name, session)

    async def _open_stdio_session(self, server_name: str, config: Dict[str, Any]) -> Any:
        """Start the configured command through stdio_client and return its initialized ClientSession."""
        command = config.get("command")
        if not command:
            raise ValueError(f"No command specified for server {server_name}")
        server_params = StdioServerParameters(
            command=command,
            args=config.get("args", []),
            env=config.get("env", None),
        )
        read_stream, write_stream = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
        await session.initialize()
        return session

    async def _open_sse_session(self, server_name: str, config: Dict[str, Any]) -> Any:
        """Open an SSE connection to ``<url>/sse`` and return its initialized ClientSession."""
        server_url = config.get("url", "")
        if not server_url:
            raise ValueError(f"No url specified for server {server_name}")
        # Strip a trailing slash so "http://host/" does not become "http://host//sse".
        streams = await self.exit_stack.enter_async_context(
            sse_client(url=f"{server_url.rstrip('/')}/sse")
        )
        session = await self.exit_stack.enter_async_context(ClientSession(*streams))
        await session.initialize()
        return session

    async def _report_sse_tools(self, server_name: str, session: Any) -> None:
        """Print the tools an SSE server reports, as a connection check."""
        print(f"Initialized SSE client for {server_name}...")
        print("Listing tools...")
        response = await session.list_tools()
        print(f"Connected to {server_name} with tools:", [tool.name for tool in response.tools])

    async def get_tools(self) -> List[Dict[str, Any]]:
        """
        Collect the tools of all connected servers in function-calling format.

        Each entry is ``{"type": "function", "function": {"name", "description",
        "parameters"}}``, where ``parameters`` is the tool's ``inputSchema``.
        Tools in USER_SCOPED_TOOLS are skipped when ``self.user_id`` is falsy.
        Servers are queried in ``self.sessions`` order.
        """
        all_tools = []
        for session in self.sessions.values():
            response = await session.list_tools()
            for tool in response.tools:
                if not self.user_id and tool.name in self.USER_SCOPED_TOOLS:
                    continue
                all_tools.append(
                    {
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.inputSchema,
                        },
                    }
                )
        return all_tools

    async def execute_tool(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run ``tool_name`` with ``params`` on the first server that lists it.

        Returns:
            ``{"result": <call_tool result .content>, "server": <server name>}``.

        Raises:
            Exception: If no connected server lists ``tool_name``.
        """
        # NOTE(review): unlike get_tools(), this does not enforce USER_SCOPED_TOOLS when
        # user_id is unset — confirm whether callers rely on that.
        for server_name, session in self.sessions.items():
            response = await session.list_tools()
            if any(tool.name == tool_name for tool in response.tools):
                result = await session.call_tool(tool_name, params)
                return {
                    "result": result.content,
                    "server": server_name,
                }
        raise Exception(f"Tool {tool_name} not found on any connected server")

    async def close(self):
        """Close all server connections and forget their sessions."""
        try:
            await self.exit_stack.aclose()
        finally:
            # The sessions' contexts have been exited; drop them so later calls don't reuse them.
            self.sessions.clear()