""" Utility functions for A1D MCP Server Handles API calls and data processing """ import requests import json import os import time import re from typing import Dict, Any, Optional, Tuple from config import A1D_API_BASE_URL, API_KEY, TOOLS_CONFIG class A1DAPIClient: """Client for making API calls to A1D services""" def __init__(self, api_key: Optional[str] = None): # Try to get API key from multiple sources self.api_key = api_key or self._get_api_key() self.base_url = A1D_API_BASE_URL self.session = requests.Session() if not self.api_key: raise ValueError( "API key is required. Set A1D_API_KEY environment variable, pass it directly, or provide via MCP header.") # Set default headers self.session.headers.update({ "Authorization": f"KEY {self.api_key}", "Content-Type": "application/json", "User-Agent": "A1D-MCP-Server/1.0.0" }) def _get_api_key(self) -> Optional[str]: """Get API key from various sources""" # 1. Environment variable api_key = API_KEY if api_key: return api_key # 2. Try to get from Gradio request headers (if available) try: import gradio as gr request = gr.request() if request and hasattr(request, 'headers'): # Check for API_KEY header from MCP client api_key = request.headers.get( 'API_KEY') or request.headers.get('api_key') if api_key: print(f"šŸ“” Using API key from MCP client header") return api_key except: pass return None def make_request(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]: """Make API request to A1D service""" url = f"{self.base_url}{endpoint}" # Add source field to all requests request_data = {**data, "source": "mcp"} # Print detailed request information print("\n" + "="*60) print("šŸš€ A1D API REQUEST DEBUG INFO") print("="*60) print(f"šŸ“” URL: {url}") print(f"šŸ”§ Method: POST") print(f"\nšŸ“‹ Headers:") for key, value in self.session.headers.items(): # Mask API key for security if key.lower() in ['api_key', 'authorization']: masked_value = f"{value[:8]}..." if len(value) > 8 else "***" print(f" {key}: {masked_value}") else: print(f" {key}: {value}") print(f"\nšŸ“¦ Request Body:") print(f" {json.dumps(request_data, indent=2)}") try: print(f"\nā³ Sending request...") response = self.session.post(url, json=request_data, timeout=30) print(f"\nšŸ“Š Response Info:") print(f" Status Code: {response.status_code}") print(f" Status Text: {response.reason}") print(f"\nšŸ“‹ Response Headers:") for key, value in response.headers.items(): print(f" {key}: {value}") print(f"\nšŸ“¦ Response Body:") try: response_json = response.json() print(f" {json.dumps(response_json, indent=2)}") except: print(f" {response.text[:500]}...") print("="*60) response.raise_for_status() return response.json() except requests.exceptions.RequestException as e: print(f"\nāŒ Request failed: {str(e)}") print("="*60) raise Exception(f"API request failed: {str(e)}") except json.JSONDecodeError as e: print(f"\nāŒ JSON decode failed: {str(e)}") print("="*60) raise Exception(f"Failed to parse API response: {str(e)}") def get_task_result(self, task_id: str, timeout: int = 60) -> Dict[str, Any]: """Get task result using SSE endpoint""" url = f"{self.base_url}/api/task/{task_id}/sse" print(f"\nšŸ”„ Getting task result...") print(f"šŸ“” SSE URL: {url}") print(f"ā±ļø Timeout: {timeout}s") headers = { "Authorization": f"KEY {self.api_key}", "Accept": "text/event-stream" } try: response = requests.get( url, headers=headers, stream=True, timeout=timeout) response.raise_for_status() print(f"šŸ“Š SSE Response Status: {response.status_code}") # Parse SSE stream for line in response.iter_lines(decode_unicode=True): if line: print(f"šŸ“„ SSE Line: {line}") # Parse SSE data if line.startswith("data: "): data_str = line[6:] # Remove "data: " prefix if data_str.strip() == "[DONE]": print("āœ… Task completed!") break try: data = json.loads(data_str) print( f"šŸ“¦ Parsed data: {json.dumps(data, indent=2)}") # Check if task is completed status = data.get("status", "").upper() if (status in ["COMPLETED", "FINISHED", "SUCCESS"] or "result" in data or "imageUrl" in data or "videoUrl" in data or "url" in data): print("āœ… Task result received!") return data elif status in ["FAILED", "ERROR"]: raise Exception( f"Task failed: {data.get('error', 'Unknown error')}") else: print( f"ā³ Task status: {data.get('status', 'processing')}") except json.JSONDecodeError: print(f"āš ļø Could not parse JSON: {data_str}") continue raise Exception("Task did not complete within timeout") except requests.exceptions.RequestException as e: print(f"āŒ SSE request failed: {str(e)}") raise Exception(f"Failed to get task result: {str(e)}") def make_request_with_result(self, endpoint: str, data: Dict[str, Any], timeout: int = 60) -> Dict[str, Any]: """Make API request and wait for result""" # First, make the initial request to get task ID response = self.make_request(endpoint, data) if "taskId" not in response: raise Exception("No taskId in response") task_id = response["taskId"] print(f"\nšŸŽÆ Task ID: {task_id}") # Then get the result return self.get_task_result(task_id, timeout) def validate_url(url: str) -> bool: """Validate if the provided string is a valid URL""" import re url_pattern = re.compile( r'^https?://' # http:// or https:// # domain... r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|' r'localhost|' # localhost... r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip r'(?::\d+)?' # optional port r'(?:/?|[/?]\S+)$', re.IGNORECASE) return url_pattern.match(url) is not None def validate_scale(scale: int) -> bool: """Validate scale parameter for image upscaling""" return scale in TOOLS_CONFIG["image_upscaler"]["scale_options"] def prepare_request_data(tool_name: str, **kwargs) -> Dict[str, Any]: """Prepare request data based on tool configuration""" if tool_name not in TOOLS_CONFIG: raise ValueError(f"Unknown tool: {tool_name}") config = TOOLS_CONFIG[tool_name] data = {} # Add required parameters for param in config["required_params"]: if param not in kwargs: raise ValueError(f"Missing required parameter: {param}") # Apply parameter mapping if exists param_mapping = config.get("param_mapping", {}) api_param_name = param_mapping.get(param, param) data[api_param_name] = kwargs[param] # Add optional parameters with defaults for param in config.get("optional_params", []): if param in kwargs: # Apply parameter mapping if exists param_mapping = config.get("param_mapping", {}) api_param_name = param_mapping.get(param, param) data[api_param_name] = kwargs[param] elif param in config.get("default_values", {}): # Apply parameter mapping if exists param_mapping = config.get("param_mapping", {}) api_param_name = param_mapping.get(param, param) data[api_param_name] = config["default_values"][param] return data def format_response_with_preview(response: Dict[str, Any], tool_name: str) -> Tuple[str, Optional[str]]: """Format API response for display with media preview Returns: Tuple of (message, media_url_for_preview) """ if "error" in response: return f"āŒ Error: {response['error']}", None # Handle different response formats result_url = None # Check for A1D API specific fields first result_url = (response.get("imageUrl") or response.get("videoUrl") or response.get("url")) # Then check nested result fields if not result_url and "result" in response: result = response["result"] if isinstance(result, dict): # Try different possible URL fields result_url = (result.get("imageUrl") or result.get("videoUrl") or result.get("url") or result.get("image_url") or result.get("video_url") or result.get("output_url")) elif isinstance(result, str) and result.startswith("http"): result_url = result # Also check other common fields if not result_url: result_url = (response.get("image_url") or response.get("video_url") or response.get("output_url")) if result_url: # Determine media type media_type = "image" if any(ext in result_url.lower() for ext in ['.mp4', '.avi', '.mov', '.webm']): media_type = "video" message = f"āœ… Success! {media_type.title()} generated: {result_url}" return message, result_url return f"āœ… Task completed successfully for {tool_name}", None def format_response(response: Dict[str, Any], tool_name: str) -> str: """Format API response for display (backward compatibility)""" message, _ = format_response_with_preview(response, tool_name) return message def get_tool_info(tool_name: str) -> Dict[str, Any]: """Get tool configuration information""" if tool_name not in TOOLS_CONFIG: raise ValueError(f"Unknown tool: {tool_name}") return TOOLS_CONFIG[tool_name]