Spaces:
Sleeping
Sleeping
File size: 9,394 Bytes
fe0d3be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
import logging
import json
from typing_extensions import Optional, List, Dict, Any
from aworld.mcp_client.utils import sandbox_mcp_tool_desc_transform, call_api, get_server_instance, cleanup_server, \
call_function_tool
from mcp.types import TextContent, ImageContent
from aworld.core.common import ActionResult
class McpServers:
    """Dispatch tool calls to MCP servers, function tools, or HTTP API tools.

    Server instances obtained from ``get_server_instance`` are cached per server
    name and reused across calls. A cached instance whose call raises is removed
    from the cache (and cleaned up), so the next call creates a fresh one.

    Args:
        mcp_servers: Names of the servers whose tools ``list_tools`` describes.
        mcp_config: Configuration mapping. Per-server entries are read from its
            ``"mcpServers"`` key; an entry's ``"type"`` selects the call path
            (``"function_tool"``, ``"api"``, anything else -> MCP server).
        sandbox: Optional owner object. When it has a ``_metadata`` mapping,
            every call's input/output is recorded under
            ``_metadata["mcp_metadata"]``.
    """

    def __init__(
        self,
        mcp_servers: Optional[List[str]] = None,
        mcp_config: Optional[Dict[str, Any]] = None,
        sandbox=None,
    ) -> None:
        self.mcp_servers = mcp_servers
        self.mcp_config = mcp_config
        self.sandbox = sandbox
        # Cache of live server instances: {server_name: server_instance}
        self.server_instances = {}
        # Tool descriptions memoized by list_tools (only non-empty results are kept)
        self.tool_list = None

    async def list_tools(
        self,
    ) -> List[Dict[str, Any]]:
        """Return the tool descriptions of the configured servers.

        The first non-empty result is cached on the instance and returned on
        later calls. Returns ``[]`` when no servers or config are set, or when
        fetching the descriptions fails (the failure is logged).
        """
        if self.tool_list:
            return self.tool_list
        if not self.mcp_servers or not self.mcp_config:
            return []
        try:
            self.tool_list = await sandbox_mcp_tool_desc_transform(self.mcp_servers, self.mcp_config)
            return self.tool_list
        except Exception as e:
            logging.warning(f"Failed to list tools: {e}")
            return []

    async def call_tool(
        self,
        action_list: List[Dict[str, Any]] = None,
        task_id: str = None,
        session_id: str = None,
    ) -> "Optional[List[ActionResult]]":
        """Execute each action in ``action_list`` and collect the results.

        Each action is a dict (or an object whose ``vars()`` has the same keys)
        with ``"tool_name"`` (the server name), ``"action_name"`` (the tool on
        that server) and optional ``"params"``. ``task_id`` / ``session_id``,
        when truthy, are added to a *copy* of the params; the caller's params
        are not modified.

        Actions missing a server or tool name are skipped. An action whose call
        fails contributes no entry to the results; the error is logged and
        recorded in the sandbox metadata.

        Returns:
            The collected results in action order, or ``None`` when
            ``action_list`` is empty or an unexpected error aborts the batch.
        """
        if not action_list:
            return None
        results = []
        try:
            for action in action_list:
                action_dict = action if isinstance(action, dict) else vars(action)
                server_name = action_dict.get("tool_name")
                tool_name = action_dict.get("action_name")
                if not server_name or not tool_name:
                    continue

                # Copy so that injecting task/session ids never mutates the caller's action.
                parameter = dict(action_dict.get("params") or {})
                if task_id:
                    parameter["task_id"] = task_id
                if session_id:
                    parameter["session_id"] = session_id

                result_key = f"{server_name}__{tool_name}"
                # Records the parameters actually sent to the tool.
                operation_info = {
                    "server_name": server_name,
                    "tool_name": tool_name,
                    "params": parameter,
                }

                server_type = self._get_server_type(server_name)
                if server_type == "function_tool":
                    try:
                        call_result = await call_function_tool(
                            server_name, tool_name, parameter, self.mcp_config
                        )
                        results.append(call_result)
                        self._update_metadata(result_key, call_result, operation_info)
                    except Exception as e:
                        logging.warning(f"Error calling function_tool tool: {e}")
                        self._update_metadata(result_key, {"error": str(e)}, operation_info)
                    continue

                # For API type servers, use call_api function directly
                if server_type == "api":
                    try:
                        call_result = await call_api(
                            server_name, tool_name, parameter, self.mcp_config
                        )
                        results.append(call_result)
                        self._update_metadata(result_key, call_result, operation_info)
                    except Exception as e:
                        logging.warning(f"Error calling API tool: {e}")
                        self._update_metadata(result_key, {"error": str(e)}, operation_info)
                    continue

                action_result = await self._call_mcp_server(
                    server_name, tool_name, parameter, result_key, operation_info
                )
                if action_result is not None:
                    results.append(action_result)
        except Exception as e:
            logging.warning(f"Failed to call_tool: {e}")
            return None
        return results

    def _get_server_type(self, server_name: str) -> Optional[str]:
        """Return the configured ``type`` for ``server_name``.

        ``None`` when the config has no ``"mcpServers"`` section; ``""`` when
        the server or its ``type`` is not configured.
        """
        servers_config = (self.mcp_config or {}).get("mcpServers")
        if not servers_config:
            return None
        return (servers_config.get(server_name) or {}).get("type", "")

    async def _call_mcp_server(
        self,
        server_name: str,
        tool_name: str,
        parameter: Dict[str, Any],
        result_key: str,
        operation_info: Dict[str, Any],
    ):
        """Call ``tool_name`` on a cached (or newly created) MCP server instance.

        Returns the converted ActionResult, or ``None`` if the server could not
        be created or the call failed. Errors from ``get_server_instance``
        itself propagate to the caller.
        """
        # Prioritize using existing server instances
        server = self.server_instances.get(server_name)
        if server is None:
            server = await get_server_instance(server_name, self.mcp_config)
            if not server:
                logging.warning(f"Created new server failed: {server_name}")
                self._update_metadata(result_key, {"error": "Failed to create server instance"}, operation_info)
                return None
            self.server_instances[server_name] = server
            logging.info(f"Created and cached new server instance for {server_name}")

        try:
            call_result_raw = await server.call_tool(tool_name, parameter)
            action_result = self._to_action_result(server_name, tool_name, call_result_raw)
        except Exception as e:
            logging.warning(f"Error calling tool with cached server: {e}")
            self._update_metadata(result_key, {"error": str(e)}, operation_info)
            # Drop the failing instance so the next call recreates it.
            await self._evict_server(server_name)
            return None

        self._update_metadata(result_key, action_result, operation_info)
        return action_result

    @staticmethod
    def _to_action_result(server_name: str, tool_name: str, call_result_raw):
        """Convert a raw MCP call result into an ActionResult from its first content item.

        Text content becomes the ActionResult content. Image content becomes a
        base64 data URI using the item's ``mimeType`` (``image/jpeg`` if absent).
        Any ``metadata`` extra field on the item is carried over. Other or
        missing content yields an empty-content result.
        """
        kwargs = {
            "tool_name": server_name,
            "action_name": tool_name,
            "content": "",
            "keep": True,
        }
        if call_result_raw and call_result_raw.content:
            first = call_result_raw.content[0]
            # model_extra is None when the pydantic model forbids extra fields.
            extra = getattr(first, "model_extra", None) or {}
            if isinstance(first, TextContent):
                kwargs["content"] = first.text
                kwargs["metadata"] = extra.get("metadata", {})
            elif isinstance(first, ImageContent):
                mime_type = getattr(first, "mimeType", None) or "image/jpeg"
                kwargs["content"] = f"data:{mime_type};base64,{first.data}"
                kwargs["metadata"] = extra.get("metadata", {})
        return ActionResult(**kwargs)

    async def _evict_server(self, server_name: str) -> None:
        """Remove ``server_name`` from the cache, then best-effort clean it up.

        Removal happens first so a server whose cleanup raises is still dropped.
        """
        server = self.server_instances.pop(server_name, None)
        if server is None:
            return
        try:
            await cleanup_server(server)
        except Exception as e:
            logging.warning(f"Failed to cleanup server {server_name}: {e}")

    def _update_metadata(self, result_key: str, result: Any, operation_info: Dict[str, Any]):
        """Append one tool-call record to the sandbox metadata.

        Records are stored as::

            sandbox._metadata["mcp_metadata"][result_key] == [
                {"input": operation_info, "output": result}, ...
            ]

        Repeated calls with the same ``result_key`` accumulate in call order.
        Does nothing when there is no sandbox or it lacks ``_metadata``. Any
        error while updating is logged and swallowed, because metadata is
        best-effort.

        Args:
            result_key: The key name in metadata (``"<server>__<tool>"``).
            result: Tool call result (or an ``{"error": ...}`` dict).
            operation_info: Operation information describing the call input.
        """
        if not self.sandbox or not hasattr(self.sandbox, '_metadata'):
            return
        try:
            mcp_metadata = self.sandbox._metadata.get("mcp_metadata") or {}
            mcp_metadata.setdefault(result_key, []).append(
                {"input": operation_info, "output": result}
            )
            self.sandbox._metadata["mcp_metadata"] = mcp_metadata
        except Exception as e:
            logging.warning(f"Failed to update sandbox metadata: {e}")

    # Add cleanup method, called when Sandbox is destroyed
    async def cleanup(self):
        """Clean up all cached server connections and empty the cache.

        Each server is removed from the cache before its cleanup is attempted,
        so a server whose cleanup fails (logged) is still dropped.
        """
        for server_name in list(self.server_instances):
            server = self.server_instances.pop(server_name, None)
            if server is None:
                continue
            try:
                await cleanup_server(server)
                logging.info(f"Cleaned up server instance for {server_name}")
            except Exception as e:
                logging.warning(f"Failed to cleanup server {server_name}: {e}")
|