# File: data_science_agent/utils/huggingface_mcp_llamaindex.py
# Page metadata: avatar caption "bpHigh's picture"; commit "Revamp stuff" (fc0d268)
"""
Robust Hugging Face MCP Client - Optimized for HF Spaces
This module provides a robust client for interacting with Hugging Face's MCP endpoint
with better error handling, TaskGroup avoidance, and compatibility for Hugging Face Spaces.
"""
import asyncio
import json
import logging
import os
from typing import Any, Dict, List, Optional, Union
from datetime import timedelta
from contextlib import asynccontextmanager
from mcp.shared.message import SessionMessage
from mcp.types import (
JSONRPCMessage,
JSONRPCRequest,
JSONRPCNotification,
JSONRPCResponse,
JSONRPCError,
)
from mcp.client.streamable_http import streamablehttp_client
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class RobustHFMCPClient:
    """Robust client for interacting with Hugging Face MCP endpoint optimized for Spaces.

    Each public call opens a fresh streamable-HTTP connection, performs the
    MCP ``initialize`` handshake, sends exactly one request and waits for the
    reply whose JSON-RPC id matches.
    """

    def __init__(self, hf_token: str, timeout: int = 120):
        """
        Initialize the Robust Hugging Face MCP client.

        Args:
            hf_token: Hugging Face API token, sent as a Bearer token.
            timeout: Timeout in seconds for HTTP requests. The SSE read
                timeout is set to twice this value.
        """
        self.hf_token = hf_token
        self.url = "https://huggingface.co/mcp"
        self.headers = {
            "Authorization": f"Bearer {hf_token}",
            "User-Agent": "robust-hf-mcp-client/2.0.0",
            "Accept": "application/json, text/event-stream",
            "Content-Type": "application/json",
        }
        self.timeout = timedelta(seconds=timeout)
        self.sse_read_timeout = timedelta(seconds=timeout * 2)
        self.request_id_counter = 0

    def _get_next_request_id(self) -> int:
        """Increment the per-client counter and return it as the next request ID."""
        self.request_id_counter += 1
        return self.request_id_counter

    async def _execute_single_request_session(
        self,
        method: str,
        params: Optional[Dict[str, Any]] = None
    ) -> Any:
        """
        Execute a complete MCP session for a single request.

        The handshake, the request and the wait for the reply run strictly in
        sequence inside one connection (the original author's stated goal is
        to avoid TaskGroup issues).

        Args:
            method: JSON-RPC method name, e.g. ``"tools/list"``.
            params: Optional JSON-RPC parameters for the request.

        Returns:
            The ``result`` field of the matching JSON-RPC response.

        Raises:
            asyncio.TimeoutError: If no matching reply arrives in time.
            Exception: On a server error reply or a connection failure.
        """
        request_id = self._get_next_request_id()
        main_request = JSONRPCRequest(
            jsonrpc="2.0",
            id=request_id,
            method=method,
            params=params,
        )

        async with streamablehttp_client(
            url=self.url,
            headers=self.headers,
            timeout=self.timeout,
            sse_read_timeout=self.sse_read_timeout,
            terminate_on_close=False,  # Avoid TaskGroup cleanup issues
        ) as (read_stream, write_stream, _get_session_id):
            # Step 1: Initialize the session
            logger.info("Starting MCP session initialization...")
            await self._initialize_session(read_stream, write_stream)

            # Step 2: Send the main request
            logger.info(f"Sending main request: {method}")
            await write_stream.send(SessionMessage(JSONRPCMessage(main_request)))

            # Step 3: Wait for the response
            # NOTE(review): this wait is fixed at 90 s and does not follow
            # the timeout passed to __init__ - confirm that is intended.
            logger.info("Waiting for main request response...")
            return await self._wait_for_response(read_stream, request_id, timeout=90)

    async def _initialize_session(self, read_stream, write_stream) -> None:
        """Perform the MCP handshake: ``initialize`` request, then ``notifications/initialized``.

        Raises whatever ``_wait_for_response`` raises if the server does not
        answer the ``initialize`` request successfully within 60 seconds.
        """
        init_request_id = self._get_next_request_id()
        init_request = JSONRPCRequest(
            jsonrpc="2.0",
            id=init_request_id,
            method="initialize",
            params={
                "protocolVersion": "2024-11-05",
                "capabilities": {
                    "tools": {},
                    "resources": {},
                    "prompts": {},
                },
                "clientInfo": {
                    "name": "robust-hf-mcp-client",
                    "version": "2.0.0",
                },
            },
        )
        await write_stream.send(SessionMessage(JSONRPCMessage(init_request)))

        # The initialize result itself is not used; only its arrival matters.
        await self._wait_for_response(read_stream, init_request_id, timeout=60)
        logger.info("MCP session initialized successfully")

        initialized_notification = JSONRPCNotification(
            jsonrpc="2.0",
            method="notifications/initialized",
        )
        await write_stream.send(SessionMessage(JSONRPCMessage(initialized_notification)))

        # Give the server time to process the notification
        await asyncio.sleep(1.0)

    async def _wait_for_response(
        self,
        read_stream,
        expected_id: int,
        timeout: int = 60
    ) -> Any:
        """
        Wait for the reply to a specific request ID, bounded by ``timeout`` seconds.

        Messages that do not match ``expected_id`` are skipped. The stream is
        polled in slices of at most 10 seconds, and each slice is clamped to
        the time remaining so the overall deadline is not overshot.

        Args:
            read_stream: Stream whose ``receive()`` yields ``SessionMessage``
                objects or exception instances.
            expected_id: JSON-RPC id of the request being awaited.
            timeout: Overall deadline in seconds.

        Returns:
            The ``result`` of the matching ``JSONRPCResponse``.

        Raises:
            asyncio.TimeoutError: If the deadline passes without a match.
            Exception: If the server answers with a JSON-RPC error, the stream
                delivers an exception, or the connection closes.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout

        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                raise asyncio.TimeoutError(f"Timeout waiting for response to request {expected_id}")

            # Only the receive itself is guarded, so errors raised while
            # interpreting a message are not caught and logged a second time.
            try:
                response = await asyncio.wait_for(
                    read_stream.receive(),
                    timeout=min(10.0, remaining),
                )
            except asyncio.TimeoutError:
                # Loop again so the overall deadline is re-checked.
                logger.debug("Receive timeout, continuing to wait...")
                continue
            except Exception as e:
                if "ClosedResourceError" in str(type(e)) or "StreamClosed" in str(e):
                    raise Exception("Connection closed while waiting for response") from e
                logger.error(f"Error while waiting for response: {e}")
                raise

            if isinstance(response, Exception):
                logger.error(f"Received exception in stream: {response}")
                raise response

            if not isinstance(response, SessionMessage):
                continue

            msg_root = response.message.root
            if isinstance(msg_root, JSONRPCResponse) and msg_root.id == expected_id:
                logger.info(f"Received successful response for request {expected_id}")
                return msg_root.result
            if isinstance(msg_root, JSONRPCError) and msg_root.id == expected_id:
                error_msg = f"Server error for request {expected_id}: {msg_root.error}"
                logger.error(error_msg)
                raise Exception(error_msg)

            # Log unexpected messages but continue waiting
            logger.debug(f"Received unexpected message type: {type(msg_root)} with ID: {getattr(msg_root, 'id', 'N/A')}")

    async def get_all_tools(self) -> List[Dict[str, Any]]:
        """
        Get all available tools from the Hugging Face MCP endpoint.

        Returns:
            The list under the ``"tools"`` key of the ``tools/list`` result,
            or an empty list when the result is not a dict with that key.

        Raises:
            Exception: Any failure from the underlying session is logged and
                re-raised.
        """
        try:
            logger.info("Fetching all available tools from Hugging Face MCP")
            result = await self._execute_single_request_session("tools/list")
            if isinstance(result, dict) and "tools" in result:
                tools = result["tools"]
                logger.info(f"Successfully fetched {len(tools)} tools")
                return tools
            logger.warning(f"Unexpected response format for tools/list: {result}")
            return []
        except Exception as e:
            logger.error(f"Failed to get tools: {e}")
            raise

    async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
        """
        Call a specific tool with the given arguments.

        Args:
            tool_name: Name of the tool to call.
            args: Arguments to pass to the tool.

        Returns:
            The ``result`` of the ``tools/call`` response.

        Raises:
            Exception: Any failure from the underlying session is logged and
                re-raised.
        """
        try:
            logger.info(f"Calling tool '{tool_name}' with args: {args}")
            params = {
                "name": tool_name,
                "arguments": args,
            }
            result = await self._execute_single_request_session("tools/call", params)
            logger.info(f"Tool '{tool_name}' executed successfully")
            return result
        except Exception as e:
            logger.error(f"Failed to call tool '{tool_name}': {e}")
            raise
class SimplifiedHFMCPClient:
    """Ultra-simplified client that avoids all TaskGroup usage.

    A background task collects replies from the read stream into a dict keyed
    by JSON-RPC id, while the caller sends requests and polls that dict.
    """

    def __init__(self, hf_token: str, timeout: int = 90):
        """
        Args:
            hf_token: Hugging Face API token, sent as a Bearer token.
            timeout: HTTP timeout in seconds; the SSE read timeout is twice
                this value.
        """
        self.hf_token = hf_token
        self.timeout = timeout
        self.headers = {
            "Authorization": f"Bearer {hf_token}",
            "User-Agent": "simplified-hf-mcp-client/1.0.0",
        }
        self.request_counter = 0

    def _next_id(self) -> int:
        """Increment the per-client counter and return it as the next request ID."""
        self.request_counter += 1
        return self.request_counter

    @staticmethod
    async def _await_reply(
        responses: Dict[Any, Any],
        request_id: int,
        max_polls: int,
        timeout_message: str,
    ) -> Any:
        """Poll ``responses`` every 0.1 s until ``request_id`` shows up.

        Raises the exception stored under ``'error'`` if the collector
        recorded one, or ``Exception(timeout_message)`` once ``max_polls``
        polls have elapsed without a reply.
        """
        for _ in range(max_polls):
            if request_id in responses:
                break
            if 'error' in responses:
                raise responses['error']
            await asyncio.sleep(0.1)
        if request_id not in responses:
            raise Exception(timeout_message)
        return responses[request_id]

    @staticmethod
    def _unwrap_reply(reply: Any) -> Any:
        """Return the result of a success reply, or raise for an error reply."""
        if isinstance(reply, JSONRPCResponse):
            return reply.result
        if isinstance(reply, JSONRPCError):
            raise Exception(f"Server error: {reply.error}")
        raise Exception(f"Unexpected response type: {type(reply)}")

    async def _simple_mcp_call(self, method: str, params: Optional[Dict[str, Any]] = None) -> Any:
        """Make a simple MCP call: handshake, one request, one reply.

        Args:
            method: JSON-RPC method name.
            params: Optional JSON-RPC parameters.

        Returns:
            The ``result`` of the matching JSON-RPC response.

        Raises:
            Exception: On initialization timeout, main request timeout, a
                server error reply (to either request), or a stream error.
        """
        async with streamablehttp_client(
            url="https://huggingface.co/mcp",
            headers=self.headers,
            timeout=timedelta(seconds=self.timeout),
            sse_read_timeout=timedelta(seconds=self.timeout * 2),
            terminate_on_close=False,
        ) as (read_stream, write_stream, _get_session_id):
            responses: Dict[Any, Any] = {}

            async def collect_responses():
                """Record replies by id; store the first stream failure under 'error'."""
                try:
                    async for message in read_stream:
                        if isinstance(message, Exception):
                            responses['error'] = message
                            break
                        if not isinstance(message, SessionMessage):
                            continue
                        msg_root = message.message.root
                        # Only replies are recorded. A server-to-client
                        # request also carries an id and must not be
                        # mistaken for the reply to one of our requests.
                        if (
                            isinstance(msg_root, (JSONRPCResponse, JSONRPCError))
                            and getattr(msg_root, 'id', None) is not None
                        ):
                            responses[msg_root.id] = msg_root
                except Exception as e:
                    responses['error'] = e

            collector_task = asyncio.create_task(collect_responses())
            try:
                # Step 1: Initialize
                init_id = self._next_id()
                init_req = JSONRPCRequest(
                    jsonrpc="2.0",
                    id=init_id,
                    method="initialize",
                    params={
                        "protocolVersion": "2024-11-05",
                        "capabilities": {"tools": {}},
                        "clientInfo": {"name": "simple-hf-mcp", "version": "1.0.0"},
                    },
                )
                await write_stream.send(SessionMessage(JSONRPCMessage(init_req)))

                # NOTE(review): both waits below are fixed (30 s / 60 s) and
                # do not follow self.timeout - confirm that is intended.
                init_reply = await self._await_reply(
                    responses, init_id, 300, "Initialization timeout"  # 30 seconds max
                )
                # A failed handshake must stop the call instead of being ignored.
                self._unwrap_reply(init_reply)

                # Step 2: Send initialized notification
                notif = JSONRPCNotification(
                    jsonrpc="2.0",
                    method="notifications/initialized",
                )
                await write_stream.send(SessionMessage(JSONRPCMessage(notif)))
                await asyncio.sleep(0.5)

                # Step 3: Send main request
                main_id = self._next_id()
                main_req = JSONRPCRequest(
                    jsonrpc="2.0",
                    id=main_id,
                    method=method,
                    params=params,
                )
                await write_stream.send(SessionMessage(JSONRPCMessage(main_req)))

                main_reply = await self._await_reply(
                    responses, main_id, 600, "Main request timeout"  # 60 seconds max
                )
                return self._unwrap_reply(main_reply)
            finally:
                # Always stop the collector before the connection is closed.
                collector_task.cancel()
                try:
                    await collector_task
                except asyncio.CancelledError:
                    pass

    async def get_tools(self) -> List[Dict[str, Any]]:
        """Get all available tools; returns ``[]`` when the result has no ``"tools"`` key."""
        result = await self._simple_mcp_call("tools/list")
        if isinstance(result, dict) and "tools" in result:
            return result["tools"]
        return []

    async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
        """Call a specific tool and return the ``tools/call`` result."""
        params = {
            "name": tool_name,
            "arguments": args,
        }
        return await self._simple_mcp_call("tools/call", params)
# Robust convenience functions
async def get_hf_tools_robust(hf_token: str, max_retries: int = 3) -> List[Dict[str, Any]]:
    """
    Get all available tools with multiple fallback strategies.

    The robust client is tried first, then the simplified client. Each
    strategy is attempted up to ``max_retries`` times, sleeping ``2 ** attempt``
    seconds between attempts of the same strategy.

    Args:
        hf_token: Hugging Face API token
        max_retries: Maximum retry attempts per method

    Returns:
        List of tool definitions

    Raises:
        Exception: If every attempt of every strategy fails. The last
            underlying error is attached as ``__cause__``.
    """
    # (lowercase label, capitalized label, coroutine factory) per strategy.
    # The factories are invoked inside the try block so that constructor
    # failures are retried as well.
    strategies = (
        (
            "robust",
            "Robust",
            lambda: RobustHFMCPClient(hf_token, timeout=90).get_all_tools(),
        ),
        (
            "simplified",
            "Simplified",
            lambda: SimplifiedHFMCPClient(hf_token, timeout=120).get_tools(),
        ),
    )

    last_error = None
    for label, title, fetch in strategies:
        for attempt in range(max_retries):
            try:
                logger.info(f"Trying {label} client (attempt {attempt + 1})")
                tools = await fetch()
                logger.info(f"{title} client succeeded with {len(tools)} tools")
                return tools
            except Exception as e:
                last_error = e
                logger.warning(f"{title} client attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    await asyncio.sleep(2 ** attempt)  # Exponential backoff

    # If all strategies fail, keep the original traceback reachable.
    raise Exception(f"All connection strategies failed. Last error: {last_error}") from last_error
async def call_hf_tool_robust(
    hf_token: str,
    tool_name: str,
    args: Dict[str, Any],
    max_retries: int = 3
) -> Any:
    """
    Call a specific Hugging Face MCP tool with multiple fallback strategies.

    The robust client is tried first, then the simplified client. Each
    strategy is attempted up to ``max_retries`` times, sleeping ``2 ** attempt``
    seconds between attempts of the same strategy.

    Args:
        hf_token: Hugging Face API token
        tool_name: Name of the tool to call
        args: Arguments to pass to the tool
        max_retries: Maximum retry attempts per method

    Returns:
        The tool's response

    Raises:
        Exception: If every attempt of every strategy fails. The last
            underlying error is attached as ``__cause__``.
    """
    # (lowercase label, capitalized label, coroutine factory) per strategy.
    # The factories are invoked inside the try block so that constructor
    # failures are retried as well.
    strategies = (
        (
            "robust",
            "Robust",
            lambda: RobustHFMCPClient(hf_token, timeout=120).call_tool(tool_name, args),
        ),
        (
            "simplified",
            "Simplified",
            lambda: SimplifiedHFMCPClient(hf_token, timeout=150).call_tool(tool_name, args),
        ),
    )

    last_error = None
    for label, title, invoke in strategies:
        for attempt in range(max_retries):
            try:
                logger.info(f"Trying {label} client for tool call (attempt {attempt + 1})")
                result = await invoke()
                logger.info(f"{title} client tool call succeeded")
                return result
            except Exception as e:
                last_error = e
                logger.warning(f"{title} client tool call attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    await asyncio.sleep(2 ** attempt)  # Exponential backoff

    # If all strategies fail, keep the original traceback reachable.
    raise Exception(f"All tool call strategies failed. Last error: {last_error}") from last_error
# Legacy compatibility functions
async def get_hf_tools(hf_token: str) -> List[Dict[str, Any]]:
    """Legacy entry point kept for older callers.

    Delegates to ``get_hf_tools_robust`` with its default retry settings and
    returns whatever that call returns.
    """
    tools = await get_hf_tools_robust(hf_token)
    return tools
async def call_hf_tool(hf_token: str, tool_name: str, args: Dict[str, Any]) -> Any:
    """Legacy entry point kept for older callers.

    Delegates to ``call_hf_tool_robust`` with its default retry settings and
    returns whatever that call returns.
    """
    outcome = await call_hf_tool_robust(hf_token, tool_name, args)
    return outcome
# Enhanced diagnostics
async def diagnose_connection_advanced(hf_token: str) -> Dict[str, Any]:
    """
    Advanced connection diagnostics with multiple test scenarios.

    Runs up to four checks in order: opening a connection, listing tools with
    the robust client, listing tools with the simplified client (only when
    the robust client did not succeed), and calling the first listed tool
    with empty arguments. Failures are recorded in the report instead of
    being raised.

    Args:
        hf_token: Hugging Face API token

    Returns:
        Dict with the keys ``environment``, ``space_id``, ``python_version``,
        ``token_length``, ``has_token``, ``tests`` (pass/fail flag per check),
        ``errors`` (error text per failed check), ``tool_count`` and
        ``sample_tools`` (name and truncated description of up to three tools).
    """
    # Imported locally: the original reached sys through ``os.sys``, which is
    # an implementation detail of the os module rather than a public API.
    import sys

    def summarize(tools):
        """Return name and a 100-character description for the first three tools."""
        return [
            {
                "name": tool.get("name"),
                # ``or ""`` covers a description that is present but None.
                "description": (tool.get("description") or "")[:100],
            }
            for tool in tools[:3]
        ]

    diagnostics = {
        "environment": "huggingface_spaces" if os.getenv("SPACE_ID") else "local",
        "space_id": os.getenv("SPACE_ID"),
        "python_version": sys.version,
        "token_length": len(hf_token) if hf_token else 0,
        "has_token": bool(hf_token),
        "tests": {
            "basic_connection": False,
            "robust_client": False,
            "simplified_client": False,
            "tools_fetch": False,
            "tool_call_test": False,
        },
        "errors": {},
        "tool_count": 0,
        "sample_tools": [],
    }
    tests = diagnostics["tests"]
    errors = diagnostics["errors"]

    # Test 1: Basic connection (passes as soon as the client context opens)
    try:
        async with streamablehttp_client(
            url="https://huggingface.co/mcp",
            headers={"Authorization": f"Bearer {hf_token}"},
            timeout=timedelta(seconds=10),
            terminate_on_close=False,
        ):
            tests["basic_connection"] = True
            logger.info("Basic connection test passed")
    except Exception as e:
        errors["basic_connection"] = str(e)
        logger.error(f"Basic connection test failed: {e}")

    # Test 2: Robust client
    if tests["basic_connection"]:
        try:
            client = RobustHFMCPClient(hf_token, timeout=60)
            tools = await client.get_all_tools()
            tests["robust_client"] = True
            tests["tools_fetch"] = True
            diagnostics["tool_count"] = len(tools)
            diagnostics["sample_tools"] = summarize(tools)
            logger.info(f"Robust client test passed - {len(tools)} tools")
        except Exception as e:
            errors["robust_client"] = str(e)
            logger.error(f"Robust client test failed: {e}")

    # Test 3: Simplified client (runs whenever the robust client did not pass)
    if not tests["robust_client"]:
        try:
            client = SimplifiedHFMCPClient(hf_token, timeout=90)
            tools = await client.get_tools()
            tests["simplified_client"] = True
            if not tests["tools_fetch"]:
                tests["tools_fetch"] = True
                diagnostics["tool_count"] = len(tools)
                diagnostics["sample_tools"] = summarize(tools)
            logger.info(f"Simplified client test passed - {len(tools)} tools")
        except Exception as e:
            errors["simplified_client"] = str(e)
            logger.error(f"Simplified client test failed: {e}")

    # Test 4: Tool call (if we have tools)
    if tests["tools_fetch"] and diagnostics["sample_tools"]:
        try:
            # Try to call a simple tool if available
            sample_tool_name = diagnostics["sample_tools"][0]["name"]
            if sample_tool_name:
                # Use the working client
                if tests["robust_client"]:
                    client = RobustHFMCPClient(hf_token, timeout=60)
                else:
                    client = SimplifiedHFMCPClient(hf_token, timeout=90)

                # Try with empty args first (many tools accept this)
                try:
                    await client.call_tool(sample_tool_name, {})
                    tests["tool_call_test"] = True
                    logger.info(f"Tool call test passed with {sample_tool_name}")
                except Exception as tool_error:
                    # Tool call failed but that might be due to wrong args
                    errors["tool_call_test"] = f"Tool call failed (might need args): {str(tool_error)}"
                    logger.warning(f"Tool call test failed: {tool_error}")
        except Exception as e:
            errors["tool_call_test"] = str(e)
            logger.error(f"Tool call test setup failed: {e}")

    return diagnostics