import logging
import json

from typing_extensions import Optional, List, Dict, Any

from aworld.mcp_client.utils import sandbox_mcp_tool_desc_transform, call_api, get_server_instance, cleanup_server, \
    call_function_tool
from mcp.types import TextContent, ImageContent

from aworld.core.common import ActionResult


class McpServers:
    """Route tool calls to function tools, API tools or MCP servers.

    MCP server instances are created lazily, cached per server name in
    ``self.server_instances`` and reused across calls; ``cleanup`` releases them.
    """

    def __init__(
            self,
            mcp_servers: Optional[List[str]] = None,
            mcp_config: Optional[Dict[str, Any]] = None,
            sandbox=None,
    ) -> None:
        """
        Args:
            mcp_servers: Names of the servers whose tools are listed by ``list_tools``.
            mcp_config: Configuration mapping; per-server entries are read from its
                ``"mcpServers"`` key.
            sandbox: Optional owner object. If it has a ``_metadata`` mapping, each tool
                call is recorded under ``_metadata["mcp_metadata"]``.
        """
        self.mcp_servers = mcp_servers
        self.mcp_config = mcp_config
        self.sandbox = sandbox
        # Dictionary to store server instances {server_name: server_instance}
        self.server_instances = {}
        # Cached result of list_tools(); stays None until a listing succeeds.
        self.tool_list = None

    async def list_tools(
            self,
    ) -> List[Dict[str, Any]]:
        """Return tool descriptions for the configured servers.

        A non-empty listing is cached on ``self.tool_list`` and returned on later calls.
        Returns ``[]`` when servers or config are missing, or when listing fails
        (the error is logged).
        """
        if self.tool_list:
            return self.tool_list
        if not self.mcp_servers or not self.mcp_config:
            return []
        try:
            self.tool_list = await sandbox_mcp_tool_desc_transform(self.mcp_servers, self.mcp_config)
            return self.tool_list
        except Exception as e:
            logging.warning(f"Failed to list tools: {e}")
            return []

    async def call_tool(
            self,
            action_list: Optional[List[Dict[str, Any]]] = None,
            task_id: Optional[str] = None,
            session_id: Optional[str] = None
    ) -> Optional[List[ActionResult]]:
        """Execute every action in ``action_list`` and collect the results.

        Each action is a dict, or an object whose ``vars()`` yields one, with the keys
        ``tool_name`` (the server name), ``action_name`` (the tool on that server) and
        ``params``. Actions missing a server or tool name are skipped.

        ``task_id`` and ``session_id``, when truthy, are injected into a copy of the
        params. The caller's params are never modified.

        Routing depends on the server's configured ``type``:

        - ``"function_tool"`` goes through ``call_function_tool``.
        - ``"api"`` goes through ``call_api``.
        - Anything else goes through a cached MCP server instance.

        A failure for a single action is logged, recorded in the sandbox metadata and
        skipped.

        Returns:
            The collected results. Returns None when ``action_list`` is empty, or when an
            unexpected error aborts the loop; results gathered so far are discarded in
            that case.
        """
        if not action_list:
            return None

        results = []
        try:
            for action in action_list:
                action_dict = action if isinstance(action, dict) else vars(action)

                # Get values from dictionary
                server_name = action_dict.get("tool_name")
                tool_name = action_dict.get("action_name")
                if not server_name or not tool_name:
                    continue

                parameter = self._build_parameters(action_dict.get("params"), task_id, session_id)
                result_key = f"{server_name}__{tool_name}"
                operation_info = {
                    "server_name": server_name,
                    "tool_name": tool_name,
                    "params": parameter
                }

                server_type = self._get_server_type(server_name)

                if server_type == "function_tool":
                    try:
                        call_result = await call_function_tool(
                            server_name, tool_name, parameter, self.mcp_config
                        )
                        results.append(call_result)
                        self._update_metadata(result_key, call_result, operation_info)
                    except Exception as e:
                        logging.warning(f"Error calling function_tool tool: {e}")
                        self._update_metadata(result_key, {"error": str(e)}, operation_info)
                    continue

                # For API type servers, use call_api function directly
                if server_type == "api":
                    try:
                        call_result = await call_api(
                            server_name, tool_name, parameter, self.mcp_config
                        )
                        results.append(call_result)
                        self._update_metadata(result_key, call_result, operation_info)
                    except Exception as e:
                        logging.warning(f"Error calling API tool: {e}")
                        self._update_metadata(result_key, {"error": str(e)}, operation_info)
                    continue

                server = await self._get_or_create_server(server_name)
                if server is None:
                    self._update_metadata(result_key, {"error": "Failed to create server instance"}, operation_info)
                    continue

                try:
                    call_result_raw = await server.call_tool(tool_name, parameter)
                    action_result = self._to_action_result(server_name, tool_name, call_result_raw)
                    results.append(action_result)
                    self._update_metadata(result_key, action_result, operation_info)
                except Exception as e:
                    logging.warning(f"Error calling tool with cached server: {e}")
                    self._update_metadata(result_key, {"error": str(e)}, operation_info)
                    # Drop the instance so the next call recreates it.
                    await self._discard_server(server_name)
        except Exception as e:
            logging.warning(f"Failed to call_tool: {e}")
            return None

        return results

    @staticmethod
    def _build_parameters(params: Optional[Dict[str, Any]], task_id: Optional[str],
                          session_id: Optional[str]) -> Dict[str, Any]:
        """Return a copy of ``params`` (``{}`` if None) with task/session ids injected.

        Copying prevents the injected ids from leaking into the caller's action,
        which would otherwise be mutated in place on every call.
        """
        parameter = {} if params is None else dict(params)
        if task_id:
            parameter["task_id"] = task_id
        if session_id:
            parameter["session_id"] = session_id
        return parameter

    def _get_server_type(self, server_name: str) -> Optional[str]:
        """Return the configured ``type`` of ``server_name``.

        Returns None when the config has no ``"mcpServers"`` section, and ``""`` when
        the server (or its ``type``) is not configured.
        """
        if not self.mcp_config or not self.mcp_config.get("mcpServers"):
            return None
        server_config = self.mcp_config["mcpServers"].get(server_name, {})
        return server_config.get("type", "")

    async def _get_or_create_server(self, server_name: str):
        """Return the cached server instance, creating and caching one if needed.

        Returns None (after logging) when a new instance could not be created.
        """
        # Prioritize using existing server instances
        server = self.server_instances.get(server_name)
        if server is not None:
            return server

        server = await get_server_instance(server_name, self.mcp_config)
        if server:
            self.server_instances[server_name] = server
            logging.info(f"Created and cached new server instance for {server_name}")
            return server

        logging.warning(f"Created new server failed: {server_name}")
        return None

    async def _discard_server(self, server_name: str) -> None:
        """Remove a cached server instance and try to clean it up.

        The instance leaves the cache before cleanup, so a broken connection is not
        reused even if ``cleanup_server`` itself fails.
        """
        server = self.server_instances.pop(server_name, None)
        if server is None:
            return
        try:
            await cleanup_server(server)
        except Exception as cleanup_error:
            logging.warning(f"Failed to cleanup server {server_name}: {cleanup_error}")

    @staticmethod
    def _to_action_result(server_name: str, tool_name: str, call_result_raw) -> ActionResult:
        """Convert a server ``call_tool`` response into an ActionResult.

        Only the first content item is used:

        - Text content is returned as-is.
        - Image content becomes a base64 data URI using the item's ``mimeType``, or
          ``image/jpeg`` if none is given.

        Missing or other content yields an empty ``content``.
        """
        first = call_result_raw.content[0] if call_result_raw and call_result_raw.content else None

        if isinstance(first, TextContent):
            content = first.text
        elif isinstance(first, ImageContent):
            mime_type = getattr(first, "mimeType", None) or "image/jpeg"
            content = f"data:{mime_type};base64,{first.data}"
        else:
            return ActionResult(
                tool_name=server_name,
                action_name=tool_name,
                content="",
                keep=True
            )

        # model_extra can be None when a model does not allow extra fields.
        extra = first.model_extra or {}
        return ActionResult(
            tool_name=server_name,
            action_name=tool_name,
            content=content,
            keep=True,
            metadata=extra.get("metadata", {}),
        )

    def _update_metadata(self, result_key: str, result: Any, operation_info: Dict[str, Any]):
        """
        Append a single tool call record to the sandbox metadata.

        Records accumulate as ``sandbox._metadata["mcp_metadata"][result_key]``, a list
        of ``{"input": operation_info, "output": result}`` entries in call order.
        Does nothing without a sandbox or when it has no ``_metadata`` attribute.
        Errors are logged, never raised.

        Args:
            result_key: The key name in metadata
            result: Tool call result
            operation_info: Operation information
        """
        if not self.sandbox or not hasattr(self.sandbox, '_metadata'):
            return
        try:
            mcp_metadata = self.sandbox._metadata.get("mcp_metadata") or {}
            mcp_metadata.setdefault(result_key, []).append({
                "input": operation_info,
                "output": result
            })
            self.sandbox._metadata["mcp_metadata"] = mcp_metadata
        except Exception as e:
            logging.warning(f"Failed to update sandbox metadata: {e}")

    async def cleanup(self):
        """Clean up all cached server connections.

        Successfully cleaned instances are removed from the cache. Instances whose
        cleanup raised are logged and left in place, so a later ``cleanup`` retries them.
        """
        for server_name, server in list(self.server_instances.items()):
            try:
                await cleanup_server(server)
                del self.server_instances[server_name]
                logging.info(f"Cleaned up server instance for {server_name}")
            except Exception as e:
                logging.warning(f"Failed to cleanup server {server_name}: {e}")