"""MCP Client implementation for Universal MCP Client - Fixed Version"""
import asyncio
import json
import re
import logging
import traceback
from typing import Dict, Optional, Tuple, List, Any

from openai import OpenAI

# Import the proper MCP client components
from mcp import ClientSession
from mcp.client.sse import sse_client

from config import MCPServerConfig, AppConfig, HTTPX_AVAILABLE

logger = logging.getLogger(__name__)


class UniversalMCPClient:
    """Universal MCP Client using HuggingFace Inference Providers instead of Anthropic"""

    def __init__(self):
        # Registered servers keyed by user-facing server name.
        self.servers: Dict[str, MCPServerConfig] = {}
        self.enabled_servers: Dict[str, bool] = {}  # Track enabled/disabled servers
        self.hf_client = None
        self.current_provider = None
        self.current_model = None
        # Cache of tools per server: {server_name: {tool_name: {'description', 'schema'}}}
        self.server_tools = {}

        # Initialize HF Inference Client if token is available
        if AppConfig.HF_TOKEN:
            self.hf_client = OpenAI(
                base_url="https://router.huggingface.co/v1",
                api_key=AppConfig.HF_TOKEN
            )
            logger.info("✅ HuggingFace Inference client initialized")
        else:
            logger.warning("⚠️ HF_TOKEN not found")

    def enable_server(self, server_name: str, enabled: bool = True):
        """Enable or disable a server"""
        if server_name in self.servers:
            self.enabled_servers[server_name] = enabled
            logger.info(f"🔧 Server {server_name} {'enabled' if enabled else 'disabled'}")

    def get_enabled_servers(self) -> Dict[str, MCPServerConfig]:
        """Get only enabled servers"""
        # Servers with no explicit flag default to enabled.
        return {name: config for name, config in self.servers.items()
                if self.enabled_servers.get(name, True)}

    def remove_all_servers(self):
        """Remove all servers and return how many were removed."""
        count = len(self.servers)
        self.servers.clear()
        self.enabled_servers.clear()
        self.server_tools.clear()
        logger.info(f"🗑️ Removed all {count} servers")
        return count

    def set_model_and_provider(self, provider_id: str, model_id: str):
        """Set the current provider and model"""
        self.current_provider = provider_id
        self.current_model = model_id
        logger.info(f"🔧 Set provider: {provider_id}, model: {model_id}")

    def get_model_endpoint(self) -> str:
        """Get the current model endpoint for API calls.

        Raises:
            ValueError: if provider/model have not been configured yet.
        """
        if not self.current_provider or not self.current_model:
            raise ValueError("Provider and model must be set before making API calls")
        return AppConfig.get_model_endpoint(self.current_model, self.current_provider)

    async def add_server_async(self, config: MCPServerConfig) -> Tuple[bool, str]:
        """Add an MCP server using pure MCP protocol.

        Normalizes the URL to the Gradio MCP SSE endpoint, detects the
        HuggingFace space ID when applicable, tests the connection, and on
        success registers the server (enabled by default).

        Returns:
            (success, human-readable status message)
        """
        try:
            logger.info(f"🔧 Adding MCP server: {config.name} at {config.url}")

            # Clean and validate URL - handle various input formats
            original_url = config.url.strip()

            # Remove common MCP endpoint variations so we can rebuild a
            # canonical endpoint below.
            base_url = original_url
            for endpoint in ["/gradio_api/mcp/sse", "/gradio_api/mcp/", "/gradio_api/mcp"]:
                if base_url.endswith(endpoint):
                    base_url = base_url[:-len(endpoint)]
                    break

            # Remove trailing slashes
            base_url = base_url.rstrip("/")

            # Construct proper MCP URL
            mcp_url = f"{base_url}/gradio_api/mcp/sse"

            logger.info(f"🔧 Original URL: {original_url}")
            logger.info(f"🔧 Base URL: {base_url}")
            logger.info(f"🔧 MCP URL: {mcp_url}")

            # Extract space ID if it's a HuggingFace space
            if "hf.space" in base_url:
                space_parts = base_url.split("/")
                if len(space_parts) >= 1:
                    space_id = space_parts[-1].replace('.hf.space', '').replace('https://', '').replace('http://', '')
                    if '-' in space_id:
                        # Format: username-spacename.hf.space -> username/spacename
                        config.space_id = space_id.replace('-', '/', 1)
                    else:
                        config.space_id = space_id
                    logger.info(f"📍 Detected HF Space ID: {config.space_id}")

            # Update config with proper MCP URL
            config.url = mcp_url

            # Test MCP connection and cache tools
            success, message = await self._test_mcp_connection(config)

            if success:
                self.servers[config.name] = config
                self.enabled_servers[config.name] = True  # Enable by default
                logger.info(f"✅ MCP Server {config.name} added successfully")
                return True, f"✅ Successfully added MCP server: {config.name}\n{message}"
            else:
                logger.error(f"❌ Failed to connect to MCP server {config.name}: {message}")
                return False, f"❌ Failed to add server: {config.name}\n{message}"

        except Exception as e:
            error_msg = f"Failed to add server {config.name}: {str(e)}"
            logger.error(error_msg)
            logger.error(traceback.format_exc())
            return False, f"❌ {error_msg}"

    async def _test_mcp_connection(self, config: MCPServerConfig) -> Tuple[bool, str]:
        """Test MCP server connection with detailed debugging and tool caching.

        On success the server's tools are cached in ``self.server_tools``.
        """
        try:
            logger.info(f"🔍 Testing MCP connection to {config.url}")

            async with sse_client(config.url, timeout=20.0) as (read_stream, write_stream):
                async with ClientSession(read_stream, write_stream) as session:
                    # Initialize MCP session
                    logger.info("🔧 Initializing MCP session...")
                    await session.initialize()

                    # List available tools
                    logger.info("📋 Listing available tools...")
                    tools = await session.list_tools()

                    # Cache tools for this server
                    server_tools = {}
                    tool_info = []
                    for tool in tools.tools:
                        server_tools[tool.name] = {
                            'description': tool.description,
                            'schema': tool.inputSchema if hasattr(tool, 'inputSchema') else None
                        }
                        tool_info.append(f"  - {tool.name}: {tool.description}")
                        logger.info(f"  📍 Tool: {tool.name}")
                        logger.info(f"     Description: {tool.description}")
                        if hasattr(tool, 'inputSchema') and tool.inputSchema:
                            logger.info(f"     Input Schema: {tool.inputSchema}")

                    # Cache tools for this server
                    self.server_tools[config.name] = server_tools

                    if len(tools.tools) == 0:
                        return False, "No tools found on MCP server"

                    message = f"Connected successfully!\nFound {len(tools.tools)} tools:\n" + "\n".join(tool_info)
                    return True, message

        except asyncio.TimeoutError:
            return False, "Connection timeout - server may be sleeping or unreachable"
        except Exception as e:
            logger.error(f"MCP connection failed: {e}")
            logger.error(traceback.format_exc())
            return False, f"Connection failed: {str(e)}"

    async def call_mcp_tool_async(self, server_name: str, tool_name: str, arguments: dict) -> Tuple[bool, str]:
        """Call a tool on a specific MCP server.

        Returns:
            (success, result_text_or_error_message)
        """
        logger.info(f"🔧 MCP Tool Call - Server: {server_name}, Tool: {tool_name}")
        logger.info(f"🔧 Arguments: {arguments}")

        if server_name not in self.servers:
            error_msg = f"Server {server_name} not found. Available servers: {list(self.servers.keys())}"
            logger.error(f"❌ {error_msg}")
            return False, error_msg

        config = self.servers[server_name]
        logger.info(f"🔧 Using server config: {config.url}")

        try:
            logger.info(f"🔗 Connecting to MCP server at {config.url}")

            async with sse_client(config.url, timeout=30.0) as (read_stream, write_stream):
                async with ClientSession(read_stream, write_stream) as session:
                    # Initialize MCP session
                    logger.info("🔧 Initializing MCP session...")
                    await session.initialize()

                    # Call the tool
                    logger.info(f"🔧 Calling tool {tool_name} with arguments: {arguments}")
                    result = await session.call_tool(tool_name, arguments)

                    # Extract result content; prefer the text payload when present.
                    if result.content:
                        result_text = result.content[0].text if hasattr(result.content[0], 'text') else str(result.content[0])
                        logger.info(f"✅ Tool call successful, result length: {len(result_text)}")
                        logger.info(f"📋 Result preview: {result_text[:200]}...")
                        return True, result_text
                    else:
                        error_msg = "No content returned from tool"
                        logger.error(f"❌ {error_msg}")
                        return False, error_msg

        except asyncio.TimeoutError:
            error_msg = f"Tool call timeout for {tool_name} on {server_name}"
            logger.error(f"❌ {error_msg}")
            return False, error_msg
        except Exception as e:
            error_msg = f"Tool call failed: {str(e)}"
            logger.error(f"❌ MCP tool call failed: {e}")
            logger.error(traceback.format_exc())
            return False, error_msg

    def generate_chat_completion(self, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
        """Generate chat completion using HuggingFace Inference Providers.

        Raises:
            ValueError: if the HF client or provider/model are not configured.
        """
        if not self.hf_client:
            raise ValueError("HuggingFace client not initialized. Please set HF_TOKEN.")

        if not self.current_provider or not self.current_model:
            raise ValueError("Provider and model must be set before making API calls")

        # Get the model endpoint
        model_endpoint = self.get_model_endpoint()

        # BUG FIX: pop local-only kwargs *before* building params. Previously
        # "reasoning_effort" was popped after params.update(kwargs), so it was
        # forwarded to the OpenAI-compatible endpoint, which rejects unknown
        # arguments.
        max_tokens = kwargs.pop("max_tokens", 8192)
        reasoning_effort = kwargs.pop("reasoning_effort", AppConfig.DEFAULT_REASONING_EFFORT)

        # Apply reasoning effort via the system prompt (GPT OSS feature).
        # Work on shallow copies so the caller's message dicts are never
        # mutated (repeated calls previously kept appending Reasoning text).
        if reasoning_effort:
            messages = [dict(msg) for msg in messages]
            for msg in messages:
                if msg.get("role") == "system":
                    msg["content"] += f"\n\nReasoning: {reasoning_effort}"
                    break
            else:
                messages.insert(0, {
                    "role": "system",
                    "content": f"You are a helpful AI assistant. Reasoning: {reasoning_effort}"
                })

        # Set up default parameters for GPT OSS models with higher limits
        params = {
            "model": model_endpoint,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": kwargs.get("temperature", 0.3),
            "stream": kwargs.get("stream", False)
        }

        # Add any remaining kwargs
        params.update(kwargs)

        try:
            logger.info(f"🤖 Calling {model_endpoint} via {self.current_provider}")
            response = self.hf_client.chat.completions.create(**params)
            return response
        except Exception as e:
            logger.error(f"HF Inference API call failed: {e}")
            raise

    def generate_chat_completion_with_mcp_tools(self, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
        """Generate chat completion with MCP tool support.

        Injects a tool-catalog system prompt, asks the model for a response,
        and if the model emits a ``{"use_tool": true, ...}`` JSON request,
        executes the tool over MCP and asks the model for a final answer
        based on the tool result. Tool execution metadata is attached to the
        returned response as ``_tool_execution``.
        """
        enabled_servers = self.get_enabled_servers()
        if not enabled_servers:
            # No enabled MCP servers available, use regular completion
            logger.info("🤖 No enabled MCP servers available, using regular chat completion")
            return self.generate_chat_completion(messages, **kwargs)

        logger.info(f"🔧 Processing chat with {len(enabled_servers)} enabled MCP servers available")

        # Add system message about available tools with exact tool names
        tool_descriptions = []
        server_names = []
        exact_tool_mappings = []

        for server_name, config in enabled_servers.items():
            tool_descriptions.append(f"- **{server_name}**: {config.description}")
            server_names.append(server_name)

            # Add exact tool names if we have them cached
            if server_name in self.server_tools:
                for tool_name, tool_info in self.server_tools[server_name].items():
                    exact_tool_mappings.append(f"  * Server '{server_name}' has tool '{tool_name}': {tool_info['description']}")

        # Get the actual server name (not the space ID)
        server_list = ", ".join([f'"{name}"' for name in server_names])

        tools_system_msg = f"""
You have access to the following MCP tools:
{chr(10).join(tool_descriptions)}

EXACT TOOL MAPPINGS:
{chr(10).join(exact_tool_mappings) if exact_tool_mappings else "Loading tool mappings..."}

IMPORTANT SERVER NAMES: {server_list}

When you need to use a tool, respond with ONLY a JSON object in this EXACT format:
{{"use_tool": true, "server": "exact_server_name", "tool": "exact_tool_name", "arguments": {{"param": "value"}}}}

CRITICAL INSTRUCTIONS:
- Use ONLY the exact server names from this list: {server_list}
- Use the exact tool names as shown in the mappings above
- Always include all required parameters in the arguments
- Do not include any other text before or after the JSON
- Make sure the JSON is complete and properly formatted

If you don't need to use a tool, respond normally without any JSON.
"""

        # Add tools system message with increased context
        enhanced_messages = messages.copy()
        if enhanced_messages and enhanced_messages[0].get("role") == "system":
            enhanced_messages[0]["content"] += "\n\n" + tools_system_msg
        else:
            enhanced_messages.insert(0, {"role": "system", "content": tools_system_msg})

        # Get initial response with higher token limit
        logger.info("🤖 Getting initial response from LLM...")
        response = self.generate_chat_completion(enhanced_messages, **{"max_tokens": 8192})
        # Guard against None content (possible for tool/refusal responses).
        response_text = response.choices[0].message.content or ""

        logger.info(f"🤖 LLM Response (length: {len(response_text)}): {response_text}")

        # Check if the response indicates tool usage
        if '"use_tool": true' in response_text:
            logger.info("🔧 Tool usage detected, parsing JSON...")

            # Extract and parse JSON more robustly
            tool_request = self._extract_tool_json(response_text)

            if not tool_request:
                # Fallback: try to extract tool info manually
                logger.info("🔧 JSON parsing failed, trying manual extraction...")
                tool_request = self._manual_tool_extraction(response_text)

            if tool_request:
                server_name = tool_request.get("server")
                tool_name = tool_request.get("tool")
                arguments = tool_request.get("arguments", {})

                # Replace any local file paths in arguments with uploaded URLs
                if hasattr(self, 'chat_handler_file_mapping'):
                    for arg_key, arg_value in arguments.items():
                        if isinstance(arg_value, str) and arg_value.startswith('/tmp/gradio/'):
                            # Check if we have an uploaded URL for this local path
                            for local_path, uploaded_url in self.chat_handler_file_mapping.items():
                                if local_path in arg_value or arg_value in local_path:
                                    logger.info(f"🔄 Replacing local path {arg_value} with uploaded URL {uploaded_url}")
                                    arguments[arg_key] = uploaded_url
                                    break

                logger.info(f"🔧 Tool request - Server: {server_name}, Tool: {tool_name}, Args: {arguments}")

                if server_name not in self.servers:
                    available_servers = list(self.servers.keys())
                    logger.error(f"❌ Server '{server_name}' not found. Available servers: {available_servers}")

                    # Try to find a matching server by space_id or similar name
                    matching_server = None
                    for srv_name, srv_config in self.servers.items():
                        if (srv_config.space_id and server_name in srv_config.space_id) or server_name in srv_name:
                            matching_server = srv_name
                            logger.info(f"🔧 Found matching server: {matching_server}")
                            break

                    if matching_server and self.enabled_servers.get(matching_server, True):
                        server_name = matching_server
                        logger.info(f"🔧 Using corrected server name: {server_name}")
                    else:
                        # Return error response with server name correction
                        error_msg = f"Server '{server_name}' not found or disabled. Available enabled servers: {[name for name, enabled in self.enabled_servers.items() if enabled]}"
                        response._tool_execution = {
                            "server": server_name,
                            "tool": tool_name,
                            "result": error_msg,
                            "success": False
                        }
                        return response
                elif not self.enabled_servers.get(server_name, True):
                    logger.error(f"❌ Server '{server_name}' is disabled")
                    response._tool_execution = {
                        "server": server_name,
                        "tool": tool_name,
                        "result": f"Server '{server_name}' is currently disabled",
                        "success": False
                    }
                    return response

                # Validate tool name exists for this server
                if server_name in self.server_tools and tool_name not in self.server_tools[server_name]:
                    available_tools = list(self.server_tools[server_name].keys())
                    logger.warning(f"⚠️ Tool '{tool_name}' not found for server '{server_name}'. Available tools: {available_tools}")

                    # Try to find the correct tool name
                    if available_tools:
                        # Use the first available tool if there's only one
                        if len(available_tools) == 1:
                            tool_name = available_tools[0]
                            logger.info(f"🔧 Using only available tool: {tool_name}")
                        # Or try to find a similar tool name
                        else:
                            for available_tool in available_tools:
                                if tool_name.lower() in available_tool.lower() or available_tool.lower() in tool_name.lower():
                                    tool_name = available_tool
                                    logger.info(f"🔧 Found similar tool name: {tool_name}")
                                    break

                # Call the MCP tool on a private event loop (we are in a
                # synchronous context here).
                def run_mcp_tool():
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    try:
                        return loop.run_until_complete(
                            self.call_mcp_tool_async(server_name, tool_name, arguments)
                        )
                    finally:
                        loop.close()

                success, result = run_mcp_tool()

                if success:
                    logger.info(f"✅ Tool call successful, result length: {len(str(result))}")

                    # Add tool result to conversation and get final response with better prompting
                    enhanced_messages.append({"role": "assistant", "content": response_text})
                    enhanced_messages.append({"role": "user", "content": f"Tool '{tool_name}' from server '{server_name}' completed successfully. Result: {result}\n\nPlease provide a helpful response based on this tool result. If the result contains media URLs, present them appropriately."})

                    # Remove the tool instruction from the system message for the final response
                    final_messages = enhanced_messages.copy()
                    if final_messages[0].get("role") == "system":
                        final_messages[0]["content"] = final_messages[0]["content"].split("You have access to the following MCP tools:")[0].strip()

                    logger.info("🤖 Getting final response with tool result...")
                    final_response = self.generate_chat_completion(final_messages, **{"max_tokens": 4096})

                    # Store tool execution info for the chat handler
                    final_response._tool_execution = {
                        "server": server_name,
                        "tool": tool_name,
                        "result": result,
                        "success": True
                    }

                    return final_response
                else:
                    logger.error(f"❌ Tool call failed: {result}")
                    # Return original response with error info
                    response._tool_execution = {
                        "server": server_name,
                        "tool": tool_name,
                        "result": result,
                        "success": False
                    }
                    return response
            else:
                logger.warning("⚠️ Failed to parse tool request JSON")
        else:
            logger.info("💬 No tool usage detected, returning normal response")

        # Return original response if no tool usage or tool call failed
        return response

    def _extract_tool_json(self, text: str) -> Optional[Dict[str, Any]]:
        """Extract a tool-request JSON object from an LLM response.

        Tries several extraction strategies in order and returns the first
        parsed dict with ``use_tool == True``, or None.
        """
        logger.info(f"🔍 Full LLM response text: {text}")

        # Try multiple strategies to extract JSON
        strategies = [
            # Strategy 1: Find complete JSON between outer braces
            lambda t: re.search(r'\{[^{}]*"use_tool"[^{}]*"arguments"[^{}]*\{[^{}]*\}[^{}]*\}', t),
            # Strategy 2: Find JSON that starts with {"use_tool" and reconstruct if needed
            lambda t: self._reconstruct_json_from_start(t),
            # Strategy 3: Find any complete JSON object
            lambda t: re.search(r'\{(?:[^{}]|\{[^{}]*\})*\}', t),
        ]

        json_str = None
        for i, strategy in enumerate(strategies, 1):
            try:
                if i == 2:
                    # Strategy 2 returns a string directly
                    json_str = strategy(text)
                    if not json_str:
                        continue
                else:
                    match = strategy(text)
                    if not match:
                        continue
                    json_str = match.group(0)

                logger.info(f"🔍 JSON extraction strategy {i} found: {json_str}")

                # Clean up the JSON string
                json_str = json_str.strip()

                # Try to parse
                parsed = json.loads(json_str)

                # Validate it's a tool request
                if parsed.get("use_tool") is True:
                    logger.info(f"✅ Valid tool request parsed: {parsed}")
                    return parsed

            except json.JSONDecodeError as e:
                logger.warning(f"⚠️ JSON parse error with strategy {i}: {e}")
                logger.warning(f"⚠️ Problematic JSON: {json_str if json_str is not None else 'N/A'}")
                continue
            except Exception as e:
                logger.warning(f"⚠️ Strategy {i} failed: {e}")
                continue

        logger.error("❌ Failed to extract valid JSON from response")
        return None

    def _manual_tool_extraction(self, text: str) -> Optional[Dict[str, Any]]:
        """Manually extract tool information as fallback (regex field scan)."""
        logger.info("🔧 Attempting manual tool extraction...")

        try:
            # Extract server name
            server_match = re.search(r'"server":\s*"([^"]+)"', text)
            tool_match = re.search(r'"tool":\s*"([^"]+)"', text)

            if not server_match or not tool_match:
                logger.warning("⚠️ Could not find server or tool in manual extraction")
                return None

            server_name = server_match.group(1)
            tool_name = tool_match.group(1)

            # Try to extract arguments (flat string key/value pairs only)
            args_match = re.search(r'"arguments":\s*\{([^}]+)\}', text)
            arguments = {}

            if args_match:
                args_content = args_match.group(1)
                # Simple extraction of key-value pairs
                pairs = re.findall(r'"([^"]+)":\s*"([^"]+)"', args_content)
                arguments = dict(pairs)

            manual_request = {
                "use_tool": True,
                "server": server_name,
                "tool": tool_name,
                "arguments": arguments
            }

            logger.info(f"🔧 Manual extraction successful: {manual_request}")
            return manual_request

        except Exception as e:
            logger.error(f"❌ Manual extraction failed: {e}")
            return None

    def _reconstruct_json_from_start(self, text: str) -> Optional[str]:
        """Try to reconstruct a possibly-truncated tool JSON string.

        Scans from the first ``{"use_tool": true`` occurrence, balancing
        braces (string- and escape-aware); if the object is unterminated,
        appends the missing closing braces.
        """
        # Find start of JSON
        match = re.search(r'\{"use_tool":\s*true[^}]*', text)
        if not match:
            return None

        json_start = match.start()
        json_part = text[json_start:]

        logger.info(f"🔧 Reconstructing JSON from: {json_part[:200]}...")

        # Walk the text balancing braces; ignore braces inside strings.
        brace_count = 0
        end_pos = 0
        in_string = False
        escape_next = False

        for i, char in enumerate(json_part):
            if escape_next:
                escape_next = False
                continue

            if char == '\\':
                escape_next = True
                continue

            if char == '"':
                in_string = not in_string
                continue

            if not in_string:
                if char == '{':
                    brace_count += 1
                elif char == '}':
                    brace_count -= 1
                    if brace_count == 0:
                        end_pos = i + 1
                        break

        if end_pos > 0:
            reconstructed = json_part[:end_pos]
            logger.info(f"🔧 Reconstructed JSON: {reconstructed}")
            return reconstructed
        else:
            # Try to add missing closing braces
            missing_braces = json_part.count('{') - json_part.count('}')
            if missing_braces > 0:
                reconstructed = json_part + '}' * missing_braces
                logger.info(f"🔧 Added {missing_braces} closing braces: {reconstructed}")
                return reconstructed

        return None

    def _extract_media_from_mcp_response(self, result_text: str, config: MCPServerConfig) -> Optional[str]:
        """Enhanced media extraction from MCP responses with better URL resolution.

        Returns an absolute media URL (or data URL) if one can be found in the
        tool result, else None.
        """
        if not isinstance(result_text, str):
            logger.info(f"🔍 Non-string result: {type(result_text)}")
            return None

        base_url = config.url.replace("/gradio_api/mcp/sse", "")

        logger.info(f"🔍 Processing MCP result for media: {result_text[:300]}...")
        logger.info(f"🔍 Base URL: {base_url}")

        # 1. Try to parse as JSON (most Gradio MCP servers return structured data)
        try:
            if result_text.strip().startswith('[') or result_text.strip().startswith('{'):
                logger.info("🔍 Attempting JSON parse...")
                data = json.loads(result_text.strip())
                logger.info(f"🔍 Parsed JSON structure: {data}")

                # Handle array format: [{'image': {'url': '...'}}] or [{'url': '...'}]
                if isinstance(data, list) and len(data) > 0:
                    item = data[0]
                    logger.info(f"🔍 First array item: {item}")

                    if isinstance(item, dict):
                        # Check for nested media structure
                        for media_type in ['image', 'audio', 'video']:
                            if media_type in item and isinstance(item[media_type], dict):
                                media_data = item[media_type]
                                if 'url' in media_data:
                                    url = media_data['url'].strip('\'"')  # Clean quotes
                                    logger.info(f"🎯 Found {media_type} URL: {url}")
                                    return self._resolve_media_url(url, base_url)

                        # Check for direct URL
                        if 'url' in item:
                            url = item['url'].strip('\'"')  # Clean quotes
                            logger.info(f"🎯 Found direct URL: {url}")
                            return self._resolve_media_url(url, base_url)

                # Handle object format: {'image': {'url': '...'}} or {'url': '...'}
                elif isinstance(data, dict):
                    logger.info(f"🔍 Processing dict: {data}")

                    # Check for nested media structure
                    for media_type in ['image', 'audio', 'video']:
                        if media_type in data and isinstance(data[media_type], dict):
                            media_data = data[media_type]
                            if 'url' in media_data:
                                url = media_data['url'].strip('\'"')  # Clean quotes
                                logger.info(f"🎯 Found {media_type} URL: {url}")
                                return self._resolve_media_url(url, base_url)

                    # Check for direct URL
                    if 'url' in data:
                        url = data['url'].strip('\'"')  # Clean quotes
                        logger.info(f"🎯 Found direct URL: {url}")
                        return self._resolve_media_url(url, base_url)

        except json.JSONDecodeError:
            logger.info("🔍 Not valid JSON, trying other formats...")
        except Exception as e:
            logger.warning(f"🔍 JSON parsing error: {e}")

        # 2. Check for Gradio file URLs (common pattern) with better cleaning
        gradio_file_patterns = [
            r'https://[^/]+\.hf\.space/gradio_api/file=/[^/]+/[^/]+/[^"\s\',]+',
            r'https://[^/]+\.hf\.space/file=[^"\s\',]+',
            r'/gradio_api/file=/[^"\s\',]+'
        ]

        for pattern in gradio_file_patterns:
            match = re.search(pattern, result_text)
            if match:
                url = match.group(0).rstrip('\'",:;')  # Remove trailing punctuation
                logger.info(f"🎯 Found Gradio file URL: {url}")
                if url.startswith('/'):
                    url = f"{base_url}{url}"
                return url

        # 3. Check for simple HTTP URLs in the text
        http_url_pattern = r'https?://[^\s"<>]+'
        matches = re.findall(http_url_pattern, result_text)
        for url in matches:
            if AppConfig.is_media_file(url):
                logger.info(f"🎯 Found HTTP media URL: {url}")
                return url

        # 4. Check for data URLs (base64 encoded media)
        if result_text.startswith('data:'):
            logger.info("🎯 Found data URL")
            return result_text

        # 5. For simple file paths, create proper Gradio URLs
        if AppConfig.is_media_file(result_text):
            # Extract just the filename if it's a path
            if '/' in result_text:
                filename = result_text.split('/')[-1]
            else:
                filename = result_text.strip()

            # BUG FIX: the URL previously interpolated a garbled literal
            # instead of the filename computed above.
            media_url = f"{base_url}/file={filename}"
            logger.info(f"🎯 Created media URL from filename: {media_url}")
            return media_url

        logger.info("❌ No media detected in result")
        return None

    def _resolve_media_url(self, url: str, base_url: str) -> str:
        """Resolve relative URLs to absolute URLs with better handling"""
        if url.startswith('http') or url.startswith('data:'):
            return url
        elif url.startswith('/gradio_api/file='):
            return f"{base_url}{url}"
        elif url.startswith('/file='):
            return f"{base_url}/gradio_api{url}"
        elif url.startswith('file='):
            return f"{base_url}/gradio_api/{url}"
        elif url.startswith('/'):
            return f"{base_url}/file={url}"
        else:
            return f"{base_url}/file={url}"

    def get_server_status(self) -> Dict[str, str]:
        """Get status of all configured servers"""
        status = {}
        for name, config in self.servers.items():
            compatibility = self._check_file_upload_compatibility(config)
            status[name] = f"✅ Connected (MCP Protocol) - {compatibility}"
        return status

    def _check_file_upload_compatibility(self, config: MCPServerConfig) -> str:
        """Check if a server likely supports file uploads"""
        if "hf.space" in config.url:
            return "🟡 Hugging Face Space (usually compatible)"
        elif "gradio" in config.url.lower():
            return "🟢 Gradio server (likely compatible)"
        elif "localhost" in config.url or "127.0.0.1" in config.url:
            return "🟢 Local server (file access available)"
        else:
            return "🔴 Remote server (may need public URLs)"