""" Robust Hugging Face MCP Client - Optimized for HF Spaces This module provides a robust client for interacting with Hugging Face's MCP endpoint with better error handling, TaskGroup avoidance, and compatibility for Hugging Face Spaces. """ import asyncio import json import logging import os from typing import Any, Dict, List, Optional, Union from datetime import timedelta from contextlib import asynccontextmanager from mcp.shared.message import SessionMessage from mcp.types import ( JSONRPCMessage, JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, JSONRPCError, ) from mcp.client.streamable_http import streamablehttp_client logger = logging.getLogger(__name__) class RobustHFMCPClient: """Robust client for interacting with Hugging Face MCP endpoint optimized for Spaces.""" def __init__(self, hf_token: str, timeout: int = 120): """ Initialize the Robust Hugging Face MCP client. Args: hf_token: Hugging Face API token timeout: Timeout in seconds for HTTP requests """ self.hf_token = hf_token self.url = "https://huggingface.co/mcp" self.headers = { "Authorization": f"Bearer {hf_token}", "User-Agent": "robust-hf-mcp-client/2.0.0", "Accept": "application/json, text/event-stream", "Content-Type": "application/json" } self.timeout = timedelta(seconds=timeout) self.sse_read_timeout = timedelta(seconds=timeout * 2) self.request_id_counter = 0 def _get_next_request_id(self) -> int: """Get the next request ID.""" self.request_id_counter += 1 return self.request_id_counter async def _execute_single_request_session( self, method: str, params: Optional[Dict[str, Any]] = None ) -> Any: """ Execute a complete MCP session for a single request. This avoids TaskGroup issues by handling everything in sequence. """ request_id = self._get_next_request_id() # Create the main request main_request = JSONRPCRequest( jsonrpc="2.0", id=request_id, method=method, params=params ) async with streamablehttp_client( url=self.url, headers=self.headers, timeout=self.timeout, sse_read_timeout=self.sse_read_timeout, terminate_on_close=False # Avoid TaskGroup cleanup issues ) as (read_stream, write_stream, get_session_id): # Step 1: Initialize the session logger.info("Starting MCP session initialization...") await self._initialize_session(read_stream, write_stream) # Step 2: Send the main request logger.info(f"Sending main request: {method}") main_message = JSONRPCMessage(main_request) main_session_message = SessionMessage(main_message) await write_stream.send(main_session_message) # Step 3: Wait for the response logger.info("Waiting for main request response...") response = await self._wait_for_response(read_stream, request_id, timeout=90) return response async def _initialize_session(self, read_stream, write_stream) -> None: """Initialize the MCP session with proper handshake.""" init_request_id = self._get_next_request_id() # Send initialize request init_request = JSONRPCRequest( jsonrpc="2.0", id=init_request_id, method="initialize", params={ "protocolVersion": "2024-11-05", "capabilities": { "tools": {}, "resources": {}, "prompts": {} }, "clientInfo": { "name": "robust-hf-mcp-client", "version": "2.0.0" } } ) init_message = JSONRPCMessage(init_request) init_session_message = SessionMessage(init_message) await write_stream.send(init_session_message) # Wait for initialization response init_response = await self._wait_for_response(read_stream, init_request_id, timeout=60) logger.info("MCP session initialized successfully") # Send initialized notification initialized_notification = JSONRPCNotification( jsonrpc="2.0", method="notifications/initialized" ) init_notif_message = JSONRPCMessage(initialized_notification) init_notif_session_message = SessionMessage(init_notif_message) await write_stream.send(init_notif_session_message) # Give the server time to process the notification await asyncio.sleep(1.0) async def _wait_for_response( self, read_stream, expected_id: int, timeout: int = 60 ) -> Any: """ Wait for a specific response by ID with timeout handling. """ start_time = asyncio.get_event_loop().time() while True: current_time = asyncio.get_event_loop().time() if current_time - start_time > timeout: raise asyncio.TimeoutError(f"Timeout waiting for response to request {expected_id}") try: # Use a shorter timeout for each receive to avoid hanging response = await asyncio.wait_for( read_stream.receive(), timeout=10.0 ) if isinstance(response, Exception): logger.error(f"Received exception in stream: {response}") raise response if isinstance(response, SessionMessage): msg_root = response.message.root if isinstance(msg_root, JSONRPCResponse) and msg_root.id == expected_id: logger.info(f"Received successful response for request {expected_id}") return msg_root.result elif isinstance(msg_root, JSONRPCError) and msg_root.id == expected_id: error_msg = f"Server error for request {expected_id}: {msg_root.error}" logger.error(error_msg) raise Exception(error_msg) else: # Log unexpected messages but continue waiting logger.debug(f"Received unexpected message type: {type(msg_root)} with ID: {getattr(msg_root, 'id', 'N/A')}") continue except asyncio.TimeoutError: # Continue the outer loop to check the overall timeout logger.debug("Receive timeout, continuing to wait...") continue except Exception as e: if "ClosedResourceError" in str(type(e)) or "StreamClosed" in str(e): raise Exception("Connection closed while waiting for response") logger.error(f"Error while waiting for response: {e}") raise async def get_all_tools(self) -> List[Dict[str, Any]]: """ Get all available tools from the Hugging Face MCP endpoint. Returns: List of tool definitions """ try: logger.info("Fetching all available tools from Hugging Face MCP") result = await self._execute_single_request_session("tools/list") if isinstance(result, dict) and "tools" in result: tools = result["tools"] logger.info(f"Successfully fetched {len(tools)} tools") return tools else: logger.warning(f"Unexpected response format for tools/list: {result}") return [] except Exception as e: logger.error(f"Failed to get tools: {e}") raise async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any: """ Call a specific tool with the given arguments. Args: tool_name: Name of the tool to call args: Arguments to pass to the tool Returns: The tool's response """ try: logger.info(f"Calling tool '{tool_name}' with args: {args}") params = { "name": tool_name, "arguments": args } result = await self._execute_single_request_session("tools/call", params) logger.info(f"Tool '{tool_name}' executed successfully") return result except Exception as e: logger.error(f"Failed to call tool '{tool_name}': {e}") raise class SimplifiedHFMCPClient: """Ultra-simplified client that avoids all TaskGroup usage.""" def __init__(self, hf_token: str, timeout: int = 90): self.hf_token = hf_token self.timeout = timeout self.headers = { "Authorization": f"Bearer {hf_token}", "User-Agent": "simplified-hf-mcp-client/1.0.0" } self.request_counter = 0 def _next_id(self) -> int: self.request_counter += 1 return self.request_counter async def _simple_mcp_call(self, method: str, params: Optional[Dict[str, Any]] = None) -> Any: """Make a simple MCP call without complex async patterns.""" async with streamablehttp_client( url="https://huggingface.co/mcp", headers=self.headers, timeout=timedelta(seconds=self.timeout), sse_read_timeout=timedelta(seconds=self.timeout * 2), terminate_on_close=False ) as (read_stream, write_stream, get_session_id): responses = {} # Simple message handler async def collect_responses(): try: async for message in read_stream: if isinstance(message, Exception): responses['error'] = message break elif isinstance(message, SessionMessage): msg_root = message.message.root if hasattr(msg_root, 'id') and msg_root.id is not None: responses[msg_root.id] = msg_root except Exception as e: responses['error'] = e # Start response collector collector_task = asyncio.create_task(collect_responses()) try: # Step 1: Initialize init_id = self._next_id() init_req = JSONRPCRequest( jsonrpc="2.0", id=init_id, method="initialize", params={ "protocolVersion": "2024-11-05", "capabilities": {"tools": {}}, "clientInfo": {"name": "simple-hf-mcp", "version": "1.0.0"} } ) await write_stream.send(SessionMessage(JSONRPCMessage(init_req))) # Wait for init response for _ in range(300): # 30 seconds max if init_id in responses: break if 'error' in responses: raise responses['error'] await asyncio.sleep(0.1) if init_id not in responses: raise Exception("Initialization timeout") # Step 2: Send initialized notification notif = JSONRPCNotification( jsonrpc="2.0", method="notifications/initialized" ) await write_stream.send(SessionMessage(JSONRPCMessage(notif))) await asyncio.sleep(0.5) # Step 3: Send main request main_id = self._next_id() main_req = JSONRPCRequest( jsonrpc="2.0", id=main_id, method=method, params=params ) await write_stream.send(SessionMessage(JSONRPCMessage(main_req))) # Wait for main response for _ in range(600): # 60 seconds max if main_id in responses: break if 'error' in responses: raise responses['error'] await asyncio.sleep(0.1) if main_id not in responses: raise Exception("Main request timeout") result = responses[main_id] if isinstance(result, JSONRPCResponse): return result.result elif isinstance(result, JSONRPCError): raise Exception(f"Server error: {result.error}") else: raise Exception(f"Unexpected response type: {type(result)}") finally: collector_task.cancel() try: await collector_task except asyncio.CancelledError: pass async def get_tools(self) -> List[Dict[str, Any]]: """Get all available tools.""" result = await self._simple_mcp_call("tools/list") if isinstance(result, dict) and "tools" in result: return result["tools"] return [] async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any: """Call a specific tool.""" params = { "name": tool_name, "arguments": args } return await self._simple_mcp_call("tools/call", params) # Robust convenience functions async def get_hf_tools_robust(hf_token: str, max_retries: int = 3) -> List[Dict[str, Any]]: """ Get all available tools with multiple fallback strategies. Args: hf_token: Hugging Face API token max_retries: Maximum retry attempts per method Returns: List of tool definitions """ last_error = None # Strategy 1: Try the robust client for attempt in range(max_retries): try: logger.info(f"Trying robust client (attempt {attempt + 1})") client = RobustHFMCPClient(hf_token, timeout=90) tools = await client.get_all_tools() logger.info(f"Robust client succeeded with {len(tools)} tools") return tools except Exception as e: last_error = e logger.warning(f"Robust client attempt {attempt + 1} failed: {e}") if attempt < max_retries - 1: await asyncio.sleep(2 ** attempt) # Exponential backoff # Strategy 2: Try the simplified client for attempt in range(max_retries): try: logger.info(f"Trying simplified client (attempt {attempt + 1})") client = SimplifiedHFMCPClient(hf_token, timeout=120) tools = await client.get_tools() logger.info(f"Simplified client succeeded with {len(tools)} tools") return tools except Exception as e: last_error = e logger.warning(f"Simplified client attempt {attempt + 1} failed: {e}") if attempt < max_retries - 1: await asyncio.sleep(2 ** attempt) # If all strategies fail raise Exception(f"All connection strategies failed. Last error: {last_error}") async def call_hf_tool_robust( hf_token: str, tool_name: str, args: Dict[str, Any], max_retries: int = 3 ) -> Any: """ Call a specific Hugging Face MCP tool with multiple fallback strategies. Args: hf_token: Hugging Face API token tool_name: Name of the tool to call args: Arguments to pass to the tool max_retries: Maximum retry attempts per method Returns: The tool's response """ last_error = None # Strategy 1: Try the robust client for attempt in range(max_retries): try: logger.info(f"Trying robust client for tool call (attempt {attempt + 1})") client = RobustHFMCPClient(hf_token, timeout=120) result = await client.call_tool(tool_name, args) logger.info(f"Robust client tool call succeeded") return result except Exception as e: last_error = e logger.warning(f"Robust client tool call attempt {attempt + 1} failed: {e}") if attempt < max_retries - 1: await asyncio.sleep(2 ** attempt) # Strategy 2: Try the simplified client for attempt in range(max_retries): try: logger.info(f"Trying simplified client for tool call (attempt {attempt + 1})") client = SimplifiedHFMCPClient(hf_token, timeout=150) result = await client.call_tool(tool_name, args) logger.info(f"Simplified client tool call succeeded") return result except Exception as e: last_error = e logger.warning(f"Simplified client tool call attempt {attempt + 1} failed: {e}") if attempt < max_retries - 1: await asyncio.sleep(2 ** attempt) # If all strategies fail raise Exception(f"All tool call strategies failed. Last error: {last_error}") # Legacy compatibility functions async def get_hf_tools(hf_token: str) -> List[Dict[str, Any]]: """Legacy function - now uses robust implementation.""" return await get_hf_tools_robust(hf_token) async def call_hf_tool(hf_token: str, tool_name: str, args: Dict[str, Any]) -> Any: """Legacy function - now uses robust implementation.""" return await call_hf_tool_robust(hf_token, tool_name, args) # Enhanced diagnostics async def diagnose_connection_advanced(hf_token: str) -> Dict[str, Any]: """ Advanced connection diagnostics with multiple test scenarios. Args: hf_token: Hugging Face API token Returns: Comprehensive diagnostic information """ diagnostics = { "environment": "huggingface_spaces" if os.getenv("SPACE_ID") else "local", "space_id": os.getenv("SPACE_ID"), "python_version": os.sys.version, "token_length": len(hf_token) if hf_token else 0, "has_token": bool(hf_token), "tests": { "basic_connection": False, "robust_client": False, "simplified_client": False, "tools_fetch": False, "tool_call_test": False }, "errors": {}, "tool_count": 0, "sample_tools": [] } # Test 1: Basic connection try: async with streamablehttp_client( url="https://huggingface.co/mcp", headers={"Authorization": f"Bearer {hf_token}"}, timeout=timedelta(seconds=10), terminate_on_close=False ) as (read_stream, write_stream, get_session_id): diagnostics["tests"]["basic_connection"] = True logger.info("Basic connection test passed") except Exception as e: diagnostics["errors"]["basic_connection"] = str(e) logger.error(f"Basic connection test failed: {e}") # Test 2: Robust client if diagnostics["tests"]["basic_connection"]: try: client = RobustHFMCPClient(hf_token, timeout=60) tools = await client.get_all_tools() diagnostics["tests"]["robust_client"] = True diagnostics["tests"]["tools_fetch"] = True diagnostics["tool_count"] = len(tools) diagnostics["sample_tools"] = [ {"name": tool.get("name"), "description": tool.get("description", "")[:100]} for tool in tools[:3] ] logger.info(f"Robust client test passed - {len(tools)} tools") except Exception as e: diagnostics["errors"]["robust_client"] = str(e) logger.error(f"Robust client test failed: {e}") # Test 3: Simplified client if not diagnostics["tests"]["robust_client"]: try: client = SimplifiedHFMCPClient(hf_token, timeout=90) tools = await client.get_tools() diagnostics["tests"]["simplified_client"] = True if not diagnostics["tests"]["tools_fetch"]: diagnostics["tests"]["tools_fetch"] = True diagnostics["tool_count"] = len(tools) diagnostics["sample_tools"] = [ {"name": tool.get("name"), "description": tool.get("description", "")[:100]} for tool in tools[:3] ] logger.info(f"Simplified client test passed - {len(tools)} tools") except Exception as e: diagnostics["errors"]["simplified_client"] = str(e) logger.error(f"Simplified client test failed: {e}") # Test 4: Tool call (if we have tools) if diagnostics["tests"]["tools_fetch"] and diagnostics["sample_tools"]: try: # Try to call a simple tool if available sample_tool_name = diagnostics["sample_tools"][0]["name"] if sample_tool_name: # Use the working client if diagnostics["tests"]["robust_client"]: client = RobustHFMCPClient(hf_token, timeout=60) else: client = SimplifiedHFMCPClient(hf_token, timeout=90) # Try with empty args first (many tools accept this) try: result = await client.call_tool(sample_tool_name, {}) diagnostics["tests"]["tool_call_test"] = True logger.info(f"Tool call test passed with {sample_tool_name}") except Exception as tool_error: # Tool call failed but that might be due to wrong args diagnostics["errors"]["tool_call_test"] = f"Tool call failed (might need args): {str(tool_error)}" logger.warning(f"Tool call test failed: {tool_error}") except Exception as e: diagnostics["errors"]["tool_call_test"] = str(e) logger.error(f"Tool call test setup failed: {e}") return diagnostics