|
from typing import Dict, Any, List, Optional |
|
from contextlib import AsyncExitStack |
|
import json |
|
import aiohttp |
|
import asyncio |
|
|
|
from mcp import ClientSession, StdioServerParameters |
|
from mcp.client.stdio import stdio_client |
|
from mcp.client.sse import sse_client |
|
|
|
|
|
class MCPClient:
    """Connect to several MCP servers and route tool calls to them.

    Server definitions come from the ``mcpServers`` object of a JSON config
    file. Every transport/session context opened by
    :meth:`connect_to_servers` is registered on ``self.exit_stack`` and is
    released by :meth:`close`.
    """

    # Tools that are only offered to the model when a user_id was supplied.
    _USER_SCOPED_TOOLS = frozenset({"personalized_fashion_recommend"})

    def __init__(self, config_path: str, user_id: Optional[str] = None):
        """Load MCP server definitions from a JSON config file.

        Args:
            config_path: Path to a JSON file whose top-level ``mcpServers``
                object maps a server name to its config dict. Keys read by
                :meth:`connect_to_servers`: ``transport`` (``"stdio"`` by
                default, or ``"sse"``); for stdio ``command``, ``args``,
                ``env``; for sse ``url``.
            user_id: Optional user identifier. When it is falsy, user-scoped
                tools are omitted from :meth:`get_tools`.

        Raises:
            OSError: if the config file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
            KeyError: if the JSON has no ``mcpServers`` key.
        """
        self.user_id = user_id
        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.server_configs = json.load(f)['mcpServers']
        # server name -> initialized ClientSession
        self.sessions: Dict[str, Any] = {}
        # Owns every transport and session context entered while connecting.
        self.exit_stack = AsyncExitStack()

    async def connect_to_servers(self):
        """Connect to every configured MCP server and initialize its session.

        Each initialized session is stored in ``self.sessions`` under its
        server name, and the server's tool names are printed. If this method
        raises part-way, connections opened so far remain registered on
        ``self.exit_stack`` and are released by :meth:`close`.

        Raises:
            ValueError: if a server has an unsupported ``transport``, a stdio
                server has no ``command``, or an sse server has no ``url``.
        """
        for server_name, config in self.server_configs.items():
            transport = config.get("transport", "stdio")
            print(f"Connecting to {server_name} ({transport})...")

            if transport == "stdio":
                session = await self._open_stdio_session(server_name, config)
            elif transport == "sse":
                session = await self._open_sse_session(server_name, config)
            else:
                raise ValueError(f"Unsupported transport type '{transport}' for {server_name}")

            # Shared by all transports: handshake, register, report tools.
            await session.initialize()
            self.sessions[server_name] = session
            print(f"Initialized {transport.upper()} client for {server_name}...")
            print("Listing tools...")
            response = await session.list_tools()
            print(f"Connected to {server_name} with tools:", [tool.name for tool in response.tools])

    async def _open_stdio_session(self, server_name: str, config: Dict[str, Any]):
        """Launch a stdio MCP server and return its (not yet initialized) ClientSession."""
        command = config.get("command")
        if not command:
            raise ValueError(f"No command specified for server {server_name}")
        server_params = StdioServerParameters(
            command=command,
            args=config.get("args", []),
            env=config.get("env"),
        )
        read, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
        return await self.exit_stack.enter_async_context(ClientSession(read, write))

    async def _open_sse_session(self, server_name: str, config: Dict[str, Any]):
        """Connect to ``<url>/sse`` and return its (not yet initialized) ClientSession."""
        server_url = config.get("url", "")
        if not server_url:
            raise ValueError(f"No url specified for server {server_name}")
        # Strip trailing slashes so "http://host/" does not become "http://host//sse".
        streams = await self.exit_stack.enter_async_context(
            sse_client(url=f"{server_url.rstrip('/')}/sse")
        )
        return await self.exit_stack.enter_async_context(ClientSession(*streams))

    def _is_tool_visible(self, tool_name: str) -> bool:
        """Return False for user-scoped tools when no user_id is set."""
        return bool(self.user_id) or tool_name not in self._USER_SCOPED_TOOLS

    async def get_tools(self) -> List[Dict[str, Any]]:
        """Collect the tools of all connected servers as function definitions.

        Each entry has the shape
        ``{"type": "function", "function": {"name", "description", "parameters"}}``
        where ``parameters`` is the tool's ``inputSchema``. Servers are queried
        in connection order; user-scoped tools are skipped when ``user_id``
        is falsy.
        """
        all_tools: List[Dict[str, Any]] = []
        for session in self.sessions.values():
            response = await session.list_tools()
            all_tools.extend(
                {
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.inputSchema,
                    },
                }
                for tool in response.tools
                if self._is_tool_visible(tool.name)
            )
        return all_tools

    async def execute_tool(self, tool_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Run ``tool_name`` on the first connected server that exposes it.

        Tool lists are re-fetched on every call, so tools added or removed on
        a server since connecting are taken into account.

        Returns:
            ``{"result": <CallToolResult.content>, "server": <server name>}``.

        Raises:
            LookupError: if no connected server exposes ``tool_name``.
        """
        # NOTE(review): user-scoped tools hidden by get_tools() are not blocked
        # here; confirm whether callers rely on that.
        for server_name, session in self.sessions.items():
            response = await session.list_tools()
            if any(tool.name == tool_name for tool in response.tools):
                result = await session.call_tool(tool_name, params)
                return {"result": result.content, "server": server_name}
        raise LookupError(f"Tool {tool_name} not found on any connected server")

    async def close(self):
        """Close all server connections and drop the now-unusable sessions."""
        try:
            await self.exit_stack.aclose()
        finally:
            # Sessions whose contexts have exited must not be reused.
            self.sessions.clear()
|
|