|
"""
Robust Hugging Face MCP Client - Optimized for HF Spaces

This module provides clients for interacting with Hugging Face's MCP endpoint
(``RobustHFMCPClient`` and ``SimplifiedHFMCPClient``), plus retrying helper
functions and connection diagnostics. It focuses on better error handling,
avoiding TaskGroup-related issues, and compatibility with Hugging Face Spaces.
"""
|
|
|
import asyncio
import json
import logging
import os
import sys
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import Any, Dict, List, Optional, Union

from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.message import SessionMessage
from mcp.types import (
    JSONRPCError,
    JSONRPCMessage,
    JSONRPCNotification,
    JSONRPCRequest,
    JSONRPCResponse,
)
|
|
|
# Module-level logger used by the clients, retry helpers and diagnostics below.
logger = logging.getLogger(__name__)
|
|
|
|
|
class RobustHFMCPClient:
    """Robust client for interacting with Hugging Face MCP endpoint optimized for Spaces.

    Every public call opens a fresh streamable-HTTP connection, performs the MCP
    ``initialize`` handshake, sends a single request and then reads messages
    sequentially until the matching response arrives.
    """

    # Seconds to wait for the ``initialize`` response.
    _INIT_RESPONSE_TIMEOUT = 60
    # Seconds to wait for the main request's response.
    # NOTE(review): this is independent of the ``timeout`` constructor argument.
    _MAIN_RESPONSE_TIMEOUT = 90
    # Upper bound on a single ``receive()`` wait. The overall deadline is enforced
    # separately; this only controls how often a "still waiting" debug line is logged.
    _RECEIVE_POLL_SECONDS = 10.0
    # Pause after sending ``notifications/initialized`` before the main request.
    _POST_INIT_DELAY_SECONDS = 1.0
    # Exception class names that signal the transport stream is closed. They are
    # matched by name so this module does not need to import anyio directly.
    _CLOSED_STREAM_ERROR_NAMES = frozenset({"ClosedResourceError", "EndOfStream"})

    def __init__(self, hf_token: str, timeout: int = 120):
        """
        Initialize the Robust Hugging Face MCP client.

        Args:
            hf_token: Hugging Face API token, sent as a Bearer token.
            timeout: Timeout in seconds for HTTP requests. The SSE read timeout
                is twice this value.
        """
        self.hf_token = hf_token
        self.url = "https://huggingface.co/mcp"
        self.headers = {
            "Authorization": f"Bearer {hf_token}",
            "User-Agent": "robust-hf-mcp-client/2.0.0",
            "Accept": "application/json, text/event-stream",
            "Content-Type": "application/json",
        }
        self.timeout = timedelta(seconds=timeout)
        self.sse_read_timeout = timedelta(seconds=timeout * 2)
        self.request_id_counter = 0

    def _get_next_request_id(self) -> int:
        """Return a new JSON-RPC request id, unique within this client instance."""
        self.request_id_counter += 1
        return self.request_id_counter

    @staticmethod
    async def _send_message(write_stream, payload) -> None:
        """Wrap a JSON-RPC request/notification in a ``SessionMessage`` and send it."""
        await write_stream.send(SessionMessage(JSONRPCMessage(payload)))

    @classmethod
    def _is_closed_stream_error(cls, exc: BaseException) -> bool:
        """Return True if *exc* indicates that the transport stream was closed."""
        class_names = {klass.__name__ for klass in type(exc).__mro__}
        return bool(class_names & cls._CLOSED_STREAM_ERROR_NAMES) or "StreamClosed" in str(exc)

    async def _execute_single_request_session(
        self,
        method: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """
        Execute a complete MCP session for a single request.

        Connection, handshake, request and response are handled strictly in
        sequence (no concurrent reader task), which avoids TaskGroup issues.

        Args:
            method: JSON-RPC method name, e.g. ``"tools/list"``.
            params: Optional JSON-RPC params for the request.

        Returns:
            The ``result`` field of the matching JSON-RPC response.

        Raises:
            asyncio.TimeoutError: If a response does not arrive in time.
            Exception: On a server JSON-RPC error or a closed connection.
        """
        request_id = self._get_next_request_id()
        main_request = JSONRPCRequest(
            jsonrpc="2.0",
            id=request_id,
            method=method,
            params=params,
        )

        async with streamablehttp_client(
            url=self.url,
            headers=self.headers,
            timeout=self.timeout,
            sse_read_timeout=self.sse_read_timeout,
            terminate_on_close=False,
        ) as (read_stream, write_stream, _get_session_id):
            logger.info("Starting MCP session initialization...")
            await self._initialize_session(read_stream, write_stream)

            logger.info(f"Sending main request: {method}")
            await self._send_message(write_stream, main_request)

            logger.info("Waiting for main request response...")
            return await self._wait_for_response(
                read_stream, request_id, timeout=self._MAIN_RESPONSE_TIMEOUT
            )

    async def _initialize_session(self, read_stream, write_stream) -> None:
        """Run the MCP handshake: ``initialize`` request, then ``notifications/initialized``."""
        init_request_id = self._get_next_request_id()
        init_request = JSONRPCRequest(
            jsonrpc="2.0",
            id=init_request_id,
            method="initialize",
            params={
                "protocolVersion": "2024-11-05",
                # NOTE(review): in the MCP spec "tools"/"resources"/"prompts" are
                # server capabilities (client ones are e.g. "roots", "sampling").
                # Kept unchanged to preserve the current request payload.
                "capabilities": {
                    "tools": {},
                    "resources": {},
                    "prompts": {},
                },
                "clientInfo": {
                    "name": "robust-hf-mcp-client",
                    "version": "2.0.0",
                },
            },
        )
        await self._send_message(write_stream, init_request)

        # The initialize result itself is not used; we only need the server to
        # accept the handshake (a JSON-RPC error raises here).
        await self._wait_for_response(
            read_stream, init_request_id, timeout=self._INIT_RESPONSE_TIMEOUT
        )
        logger.info("MCP session initialized successfully")

        await self._send_message(
            write_stream,
            JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized"),
        )
        # Fixed settle delay kept from the original implementation; presumably it
        # gives the server time to process the notification.
        await asyncio.sleep(self._POST_INIT_DELAY_SECONDS)

    async def _wait_for_response(
        self,
        read_stream,
        expected_id: int,
        timeout: int = 60,
    ) -> Any:
        """
        Read messages until the response for *expected_id* arrives.

        Messages with other ids, or of other types, are skipped.

        Args:
            read_stream: Receive side of the transport; ``receive()`` yields
                ``SessionMessage`` objects or ``Exception`` instances.
            expected_id: JSON-RPC id of the request being answered.
            timeout: Overall deadline in seconds.

        Returns:
            The ``result`` of the matching ``JSONRPCResponse``.

        Raises:
            asyncio.TimeoutError: If the deadline passes before a match arrives.
            Exception: On a matching ``JSONRPCError``, a closed stream, or any
                exception delivered through / raised by the stream.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout

        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                raise asyncio.TimeoutError(
                    f"Timeout waiting for response to request {expected_id}"
                )

            # Only the receive is guarded, so exceptions raised while handling a
            # message are not re-caught (and possibly swallowed) by these handlers.
            # The per-receive wait never exceeds the remaining deadline.
            try:
                message = await asyncio.wait_for(
                    read_stream.receive(),
                    timeout=min(self._RECEIVE_POLL_SECONDS, remaining),
                )
            except asyncio.TimeoutError:
                logger.debug("Receive timeout, continuing to wait...")
                continue
            except Exception as e:
                if self._is_closed_stream_error(e):
                    raise Exception("Connection closed while waiting for response") from e
                logger.error(f"Error while waiting for response: {e}")
                raise

            if isinstance(message, Exception):
                logger.error(f"Received exception in stream: {message}")
                if self._is_closed_stream_error(message):
                    raise Exception("Connection closed while waiting for response") from message
                raise message

            if not isinstance(message, SessionMessage):
                continue

            msg_root = message.message.root
            if isinstance(msg_root, JSONRPCResponse) and msg_root.id == expected_id:
                logger.info(f"Received successful response for request {expected_id}")
                return msg_root.result

            if isinstance(msg_root, JSONRPCError) and msg_root.id == expected_id:
                error_msg = f"Server error for request {expected_id}: {msg_root.error}"
                logger.error(error_msg)
                raise Exception(error_msg)

            logger.debug(
                f"Received unexpected message type: {type(msg_root)} "
                f"with ID: {getattr(msg_root, 'id', 'N/A')}"
            )

    async def get_all_tools(self) -> List[Dict[str, Any]]:
        """
        Get all available tools from the Hugging Face MCP endpoint.

        Only the first ``tools/list`` page is returned. Each call uses a new
        session, and cursors should not be reused across sessions, so a
        ``nextCursor`` in the result is logged rather than followed.

        Returns:
            List of tool definitions (empty if the response has no ``"tools"`` key).
        """
        try:
            logger.info("Fetching all available tools from Hugging Face MCP")
            result = await self._execute_single_request_session("tools/list")

            if isinstance(result, dict) and "tools" in result:
                tools = result["tools"]
                if result.get("nextCursor"):
                    logger.warning(
                        "tools/list returned a nextCursor; additional pages were not fetched"
                    )
                logger.info(f"Successfully fetched {len(tools)} tools")
                return tools

            logger.warning(f"Unexpected response format for tools/list: {result}")
            return []

        except Exception as e:
            logger.error(f"Failed to get tools: {e}")
            raise

    async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
        """
        Call a specific tool with the given arguments.

        Args:
            tool_name: Name of the tool to call.
            args: Arguments to pass to the tool.

        Returns:
            The ``result`` payload of the ``tools/call`` response.
        """
        try:
            logger.info(f"Calling tool '{tool_name}' with args: {args}")
            params = {"name": tool_name, "arguments": args}
            result = await self._execute_single_request_session("tools/call", params)
            logger.info(f"Tool '{tool_name}' executed successfully")
            return result

        except Exception as e:
            logger.error(f"Failed to call tool '{tool_name}': {e}")
            raise
|
|
|
|
|
class SimplifiedHFMCPClient:
    """Ultra-simplified client that avoids all TaskGroup usage.

    Each call opens its own streamable-HTTP connection and runs the MCP
    ``initialize`` handshake plus one request. A single background ``asyncio``
    task records incoming replies, and the caller polls for the one it needs.
    """

    # Delay between polls of the collected-reply table.
    _POLL_INTERVAL_SECONDS = 0.1
    # Poll budgets: 300 * 0.1 s ~= 30 s for initialize and 600 * 0.1 s ~= 60 s for
    # the main request. NOTE(review): these are independent of ``self.timeout``.
    _INIT_MAX_POLLS = 300
    _REQUEST_MAX_POLLS = 600

    def __init__(self, hf_token: str, timeout: int = 90):
        """
        Args:
            hf_token: Hugging Face API token, sent as a Bearer token.
            timeout: HTTP timeout in seconds. The SSE read timeout is twice this value.
        """
        self.hf_token = hf_token
        self.timeout = timeout
        self.headers = {
            "Authorization": f"Bearer {hf_token}",
            "User-Agent": "simplified-hf-mcp-client/1.0.0",
        }
        self.request_counter = 0

    def _next_id(self) -> int:
        """Return a new JSON-RPC request id, unique within this client instance."""
        self.request_counter += 1
        return self.request_counter

    @classmethod
    async def _poll_for_response(
        cls,
        responses: Dict[Any, Any],
        stream_errors: List[BaseException],
        request_id: Any,
        max_polls: int,
        timeout_message: str,
    ) -> Any:
        """
        Wait until *request_id* appears in *responses* and return that reply.

        Raises:
            BaseException: The first recorded stream error, if the stream failed
                or ended before the reply arrived.
            Exception: With *timeout_message* once *max_polls* polls are exhausted.
        """
        for _ in range(max_polls):
            if request_id in responses:
                return responses[request_id]
            if stream_errors:
                raise stream_errors[0]
            await asyncio.sleep(cls._POLL_INTERVAL_SECONDS)
        # A reply may have landed during the final sleep.
        if request_id in responses:
            return responses[request_id]
        raise Exception(timeout_message)

    @staticmethod
    def _unwrap_result(reply: Any) -> Any:
        """Return the ``result`` of a JSON-RPC reply; raise for error or unknown replies."""
        if isinstance(reply, JSONRPCResponse):
            return reply.result
        if isinstance(reply, JSONRPCError):
            raise Exception(f"Server error: {reply.error}")
        raise Exception(f"Unexpected response type: {type(reply)}")

    async def _simple_mcp_call(self, method: str, params: Optional[Dict[str, Any]] = None) -> Any:
        """
        Make a simple MCP call without complex async patterns.

        Args:
            method: JSON-RPC method name, e.g. ``"tools/list"``.
            params: Optional JSON-RPC params.

        Returns:
            The ``result`` of the server's reply.

        Raises:
            Exception: On initialization or request timeout, a server JSON-RPC error
                (including a failed ``initialize``), or a closed connection. A
                stream-delivered exception is re-raised as-is.
        """
        async with streamablehttp_client(
            url="https://huggingface.co/mcp",
            headers=self.headers,
            timeout=timedelta(seconds=self.timeout),
            sse_read_timeout=timedelta(seconds=self.timeout * 2),
            terminate_on_close=False,
        ) as (read_stream, write_stream, _get_session_id):
            responses: Dict[Any, Any] = {}
            stream_errors: List[BaseException] = []

            async def collect_responses() -> None:
                """Record JSON-RPC replies by id until the stream fails or ends."""
                try:
                    async for message in read_stream:
                        if isinstance(message, Exception):
                            stream_errors.append(message)
                            return
                        if isinstance(message, SessionMessage):
                            msg_root = message.message.root
                            # Only replies are stored. Server-initiated requests carry
                            # ids from the server's own id space and could otherwise
                            # be mistaken for replies to our requests.
                            if isinstance(msg_root, (JSONRPCResponse, JSONRPCError)):
                                responses[msg_root.id] = msg_root
                except Exception as e:
                    stream_errors.append(e)
                else:
                    # The stream ended cleanly, so no further replies can arrive.
                    # Record it so waiters fail fast instead of exhausting their budget.
                    stream_errors.append(
                        Exception("Connection closed before a response was received")
                    )

            collector_task = asyncio.create_task(collect_responses())

            try:
                init_id = self._next_id()
                init_req = JSONRPCRequest(
                    jsonrpc="2.0",
                    id=init_id,
                    method="initialize",
                    params={
                        "protocolVersion": "2024-11-05",
                        "capabilities": {"tools": {}},
                        "clientInfo": {"name": "simple-hf-mcp", "version": "1.0.0"},
                    },
                )
                await write_stream.send(SessionMessage(JSONRPCMessage(init_req)))
                init_reply = await self._poll_for_response(
                    responses, stream_errors, init_id, self._INIT_MAX_POLLS, "Initialization timeout"
                )
                # A rejected initialize must not be treated as a successful handshake.
                self._unwrap_result(init_reply)

                notif = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
                await write_stream.send(SessionMessage(JSONRPCMessage(notif)))
                await asyncio.sleep(0.5)

                main_id = self._next_id()
                main_req = JSONRPCRequest(
                    jsonrpc="2.0",
                    id=main_id,
                    method=method,
                    params=params,
                )
                await write_stream.send(SessionMessage(JSONRPCMessage(main_req)))
                main_reply = await self._poll_for_response(
                    responses, stream_errors, main_id, self._REQUEST_MAX_POLLS, "Main request timeout"
                )
                return self._unwrap_result(main_reply)

            finally:
                collector_task.cancel()
                try:
                    await collector_task
                except asyncio.CancelledError:
                    pass

    async def get_tools(self) -> List[Dict[str, Any]]:
        """
        Get all available tools (first ``tools/list`` page only).

        Returns:
            List of tool definitions; empty if the reply has no ``"tools"`` key.
        """
        result = await self._simple_mcp_call("tools/list")
        if isinstance(result, dict) and "tools" in result:
            if result.get("nextCursor"):
                logger.warning(
                    "tools/list returned a nextCursor; additional pages were not fetched"
                )
            return result["tools"]
        return []

    async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
        """Call a specific tool and return the ``result`` of the ``tools/call`` reply."""
        params = {
            "name": tool_name,
            "arguments": args,
        }
        return await self._simple_mcp_call("tools/call", params)
|
|
|
|
|
|
|
async def get_hf_tools_robust(hf_token: str, max_retries: int = 3) -> List[Dict[str, Any]]:
    """
    Get all available tools with multiple fallback strategies.

    Tries ``RobustHFMCPClient`` up to *max_retries* times, then
    ``SimplifiedHFMCPClient`` up to *max_retries* times. Between attempts of the
    same strategy it sleeps with exponential backoff (1 s, 2 s, 4 s, ...).

    Args:
        hf_token: Hugging Face API token
        max_retries: Maximum retry attempts per method. With a value below 1 no
            attempt is made and the final error is raised immediately.

    Returns:
        List of tool definitions

    Raises:
        Exception: If every attempt fails. The last attempt's exception is
            chained as ``__cause__`` so its traceback is preserved.
    """
    # (label, factory returning the coroutine for one attempt), in fallback order.
    strategies = (
        ("robust", lambda: RobustHFMCPClient(hf_token, timeout=90).get_all_tools()),
        ("simplified", lambda: SimplifiedHFMCPClient(hf_token, timeout=120).get_tools()),
    )
    last_error: Optional[Exception] = None

    for label, run_attempt in strategies:
        title = label.capitalize()
        for attempt in range(max_retries):
            try:
                logger.info(f"Trying {label} client (attempt {attempt + 1})")
                tools = await run_attempt()
                logger.info(f"{title} client succeeded with {len(tools)} tools")
                return tools
            except Exception as e:
                last_error = e
                logger.warning(f"{title} client attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    await asyncio.sleep(2 ** attempt)

    raise Exception(f"All connection strategies failed. Last error: {last_error}") from last_error
|
|
|
|
|
async def call_hf_tool_robust(
    hf_token: str,
    tool_name: str,
    args: Dict[str, Any],
    max_retries: int = 3,
) -> Any:
    """
    Call a specific Hugging Face MCP tool with multiple fallback strategies.

    Tries ``RobustHFMCPClient`` up to *max_retries* times, then
    ``SimplifiedHFMCPClient`` up to *max_retries* times. Between attempts of the
    same strategy it sleeps with exponential backoff (1 s, 2 s, 4 s, ...).

    NOTE(review): every failure is retried, including server-side errors. A tool
    with side effects may therefore run more than once — confirm that callers
    only use this with idempotent tools.

    Args:
        hf_token: Hugging Face API token
        tool_name: Name of the tool to call
        args: Arguments to pass to the tool
        max_retries: Maximum retry attempts per method. With a value below 1 no
            attempt is made and the final error is raised immediately.

    Returns:
        The tool's response

    Raises:
        Exception: If every attempt fails. The last attempt's exception is
            chained as ``__cause__`` so its traceback is preserved.
    """
    # (label, factory returning the coroutine for one attempt), in fallback order.
    strategies = (
        ("robust", lambda: RobustHFMCPClient(hf_token, timeout=120).call_tool(tool_name, args)),
        ("simplified", lambda: SimplifiedHFMCPClient(hf_token, timeout=150).call_tool(tool_name, args)),
    )
    last_error: Optional[Exception] = None

    for label, run_attempt in strategies:
        title = label.capitalize()
        for attempt in range(max_retries):
            try:
                logger.info(f"Trying {label} client for tool call (attempt {attempt + 1})")
                result = await run_attempt()
                logger.info(f"{title} client tool call succeeded")
                return result
            except Exception as e:
                last_error = e
                logger.warning(f"{title} client tool call attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    await asyncio.sleep(2 ** attempt)

    raise Exception(f"All tool call strategies failed. Last error: {last_error}") from last_error
|
|
|
|
|
|
|
async def get_hf_tools(hf_token: str) -> List[Dict[str, Any]]:
    """
    Legacy entry point for listing Hugging Face MCP tools.

    Kept for backward compatibility. It forwards to ``get_hf_tools_robust``
    with that function's default retry settings and returns its result.
    """
    tool_list = await get_hf_tools_robust(hf_token)
    return tool_list
|
|
|
|
|
async def call_hf_tool(hf_token: str, tool_name: str, args: Dict[str, Any]) -> Any:
    """
    Legacy entry point for invoking a Hugging Face MCP tool.

    Kept for backward compatibility. It forwards to ``call_hf_tool_robust``
    with that function's default retry settings and returns the tool's result.
    """
    outcome = await call_hf_tool_robust(hf_token, tool_name=tool_name, args=args)
    return outcome
|
|
|
|
|
|
|
def _summarize_tools(tools: List[Dict[str, Any]], limit: int = 3) -> List[Dict[str, Any]]:
    """
    Return ``{"name", "description"}`` summaries for the first *limit* tools.

    Descriptions are truncated to 100 characters. A missing or ``None``
    description becomes ``""``.
    """
    return [
        {"name": tool.get("name"), "description": (tool.get("description") or "")[:100]}
        for tool in tools[:limit]
    ]


async def diagnose_connection_advanced(hf_token: str) -> Dict[str, Any]:
    """
    Advanced connection diagnostics with multiple test scenarios.

    The tests run in order:

    1. Open the streamable-HTTP transport.
    2. Fetch ``tools/list`` with ``RobustHFMCPClient``, only if step 1 passed.
    3. Fetch with ``SimplifiedHFMCPClient``, only if step 2 did not succeed.
    4. Call the first listed tool with empty arguments.

    Failures are recorded under ``"errors"`` instead of being raised.

    NOTE(review): step 4 invokes a real tool. If that tool has side effects,
    running diagnostics triggers them.

    Args:
        hf_token: Hugging Face API token

    Returns:
        A dict containing:

        - environment information;
        - per-test booleans under ``"tests"``;
        - error strings under ``"errors"``;
        - ``"tool_count"`` and ``"sample_tools"`` (up to three summarized tools).
    """
    space_id = os.getenv("SPACE_ID")
    diagnostics: Dict[str, Any] = {
        "environment": "huggingface_spaces" if space_id else "local",
        "space_id": space_id,
        "python_version": sys.version,
        "token_length": len(hf_token) if hf_token else 0,
        "has_token": bool(hf_token),
        "tests": {
            "basic_connection": False,
            "robust_client": False,
            "simplified_client": False,
            "tools_fetch": False,
            "tool_call_test": False,
        },
        "errors": {},
        "tool_count": 0,
        "sample_tools": [],
    }
    tests = diagnostics["tests"]
    errors = diagnostics["errors"]

    # NOTE(review): entering ``streamablehttp_client`` may not send any HTTP
    # request by itself, so this check may pass without real connectivity.
    # Verify against the installed ``mcp`` version.
    try:
        async with streamablehttp_client(
            url="https://huggingface.co/mcp",
            headers={"Authorization": f"Bearer {hf_token}"},
            timeout=timedelta(seconds=10),
            terminate_on_close=False,
        ):
            tests["basic_connection"] = True
            logger.info("Basic connection test passed")
    except Exception as e:
        errors["basic_connection"] = str(e)
        logger.error(f"Basic connection test failed: {e}")

    if tests["basic_connection"]:
        try:
            client = RobustHFMCPClient(hf_token, timeout=60)
            tools = await client.get_all_tools()
            tests["robust_client"] = True
            tests["tools_fetch"] = True
            diagnostics["tool_count"] = len(tools)
            diagnostics["sample_tools"] = _summarize_tools(tools)
            logger.info(f"Robust client test passed - {len(tools)} tools")
        except Exception as e:
            errors["robust_client"] = str(e)
            logger.error(f"Robust client test failed: {e}")

    # Fallback: also runs when the basic connection check failed.
    if not tests["robust_client"]:
        try:
            client = SimplifiedHFMCPClient(hf_token, timeout=90)
            tools = await client.get_tools()
            tests["simplified_client"] = True
            if not tests["tools_fetch"]:
                tests["tools_fetch"] = True
                diagnostics["tool_count"] = len(tools)
                diagnostics["sample_tools"] = _summarize_tools(tools)
            logger.info(f"Simplified client test passed - {len(tools)} tools")
        except Exception as e:
            errors["simplified_client"] = str(e)
            logger.error(f"Simplified client test failed: {e}")

    if tests["tools_fetch"] and diagnostics["sample_tools"]:
        try:
            sample_tool_name = diagnostics["sample_tools"][0]["name"]
            if sample_tool_name:
                # Use the same client type that successfully listed the tools.
                if tests["robust_client"]:
                    client = RobustHFMCPClient(hf_token, timeout=60)
                else:
                    client = SimplifiedHFMCPClient(hf_token, timeout=90)

                try:
                    await client.call_tool(sample_tool_name, {})
                    tests["tool_call_test"] = True
                    logger.info(f"Tool call test passed with {sample_tool_name}")
                except Exception as tool_error:
                    # The call uses empty arguments, so a failure may only mean the
                    # tool requires arguments; the message records that caveat.
                    errors["tool_call_test"] = f"Tool call failed (might need args): {str(tool_error)}"
                    logger.warning(f"Tool call test failed: {tool_error}")
        except Exception as e:
            errors["tool_call_test"] = str(e)
            logger.error(f"Tool call test setup failed: {e}")

    return diagnostics