Spaces:
Running
Running
| """ | |
| Utility functions for A1D MCP Server | |
| Handles API calls and data processing | |
| """ | |
| import requests | |
| import json | |
| import os | |
| import time | |
| import re | |
| from typing import Dict, Any, Optional, Tuple | |
| from config import A1D_API_BASE_URL, API_KEY, TOOLS_CONFIG | |
class A1DAPIClient:
    """Client for making API calls to A1D services.

    Holds a ``requests.Session`` pre-loaded with the ``Authorization: KEY ...``
    header. Every call prints verbose debug information to stdout.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Create a client.

        Args:
            api_key: Explicit API key. When omitted or empty, ``_get_api_key``
                is consulted instead.

        Raises:
            ValueError: If no API key could be found from any source.
        """
        # Try to get API key from multiple sources
        self.api_key = api_key or self._get_api_key()
        self.base_url = A1D_API_BASE_URL
        # Validate before opening the session so a failed construction does
        # not leave an unused requests.Session behind.
        if not self.api_key:
            raise ValueError(
                "API key is required. Set A1D_API_KEY environment variable, pass it directly, or provide via MCP header.")
        self.session = requests.Session()
        # Set default headers
        self.session.headers.update({
            "Authorization": f"KEY {self.api_key}",
            "Content-Type": "application/json",
            "User-Agent": "A1D-MCP-Server/1.0.0"
        })

    def _get_api_key(self) -> Optional[str]:
        """Return the first API key found, or None.

        Sources, in order:
        1. ``API_KEY`` imported from the project ``config`` module.
        2. An ``API_KEY`` / ``api_key`` header on the current Gradio request
           (best-effort; any failure falls through to ``None``).
        """
        # 1. Configured key (presumably from the A1D_API_KEY env var, per the
        #    error message in __init__ -- confirm in config.py)
        if API_KEY:
            return API_KEY
        # 2. Try to get from Gradio request headers (if available)
        try:
            import gradio as gr
            # NOTE(review): Gradio normally injects ``gr.Request`` as a handler
            # argument; a module-level ``gr.request()`` may not exist, in which
            # case this raises and we fall through -- confirm.
            request = gr.request()
            if request and hasattr(request, 'headers'):
                # Check for API_KEY header from MCP client
                api_key = request.headers.get(
                    'API_KEY') or request.headers.get('api_key')
                if api_key:
                    print("π‘ Using API key from MCP client header")
                    return api_key
        except Exception:
            # Best-effort lookup: missing gradio, no active request, etc. must
            # not crash here; __init__ reports the missing key instead.
            # (Narrowed from a bare ``except:`` so Ctrl-C / SystemExit escape.)
            pass
        return None

    def _print_request_debug(self, url: str, request_data: Dict[str, Any]) -> None:
        # Dump the outgoing POST to stdout, masking credential-bearing headers.
        print("\n" + "="*60)
        print("π A1D API REQUEST DEBUG INFO")
        print("="*60)
        print(f"π‘ URL: {url}")
        print("π§ Method: POST")
        print("\nπ Headers:")
        for key, value in self.session.headers.items():
            # Mask API key for security
            if key.lower() in ['api_key', 'authorization']:
                masked_value = f"{value[:8]}..." if len(value) > 8 else "***"
                print(f" {key}: {masked_value}")
            else:
                print(f" {key}: {value}")
        print("\nπ¦ Request Body:")
        print(f" {json.dumps(request_data, indent=2)}")

    @staticmethod
    def _print_response_debug(response) -> None:
        # Dump status, headers and body (pretty JSON, else first 500 chars).
        print("\nπ Response Info:")
        print(f" Status Code: {response.status_code}")
        print(f" Status Text: {response.reason}")
        print("\nπ Response Headers:")
        for key, value in response.headers.items():
            print(f" {key}: {value}")
        print("\nπ¦ Response Body:")
        try:
            print(f" {json.dumps(response.json(), indent=2)}")
        except ValueError:
            # requests' JSON decode errors are ValueError subclasses.
            print(f" {response.text[:500]}...")
        print("="*60)

    def make_request(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``data`` to ``base_url + endpoint`` and return the decoded JSON.

        A ``"source": "mcp"`` field is merged into the payload. The request
        uses a 30-second timeout.

        Raises:
            Exception: ``"API request failed: ..."`` on transport errors or an
                HTTP error status; ``"Failed to parse API response: ..."`` when
                the body is not valid JSON.
        """
        url = f"{self.base_url}{endpoint}"
        # Add source field to all requests
        request_data = {**data, "source": "mcp"}
        self._print_request_debug(url, request_data)

        try:
            print("\nβ³ Sending request...")
            response = self.session.post(url, json=request_data, timeout=30)
            self._print_response_debug(response)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"\nβ Request failed: {str(e)}")
            print("="*60)
            raise Exception(f"API request failed: {str(e)}") from e

        # Decode outside the transport try-block: modern requests raises a
        # JSONDecodeError that is ALSO a RequestException, which used to be
        # misreported as a request failure.
        try:
            return response.json()
        except ValueError as e:
            print(f"\nβ JSON decode failed: {str(e)}")
            print("="*60)
            raise Exception(f"Failed to parse API response: {str(e)}") from e

    def get_task_result(self, task_id: str, timeout: int = 60) -> Dict[str, Any]:
        """Read the task's SSE stream until a result event arrives.

        Args:
            task_id: Task identifier returned by the submission endpoint.
            timeout: Passed to ``requests.get`` as its timeout. That is the
                connect/read timeout in seconds, i.e. how long the socket may
                stay silent. It is not a cap on total streaming time.

        Returns:
            The first decoded ``data`` event (a JSON object) that has status
            COMPLETED / FINISHED / SUCCESS (case-insensitive) or contains a
            ``result``, ``imageUrl``, ``videoUrl`` or ``url`` key.

        Raises:
            Exception: ``"Task failed: ..."`` when an event reports FAILED or
                ERROR; ``"Failed to get task result: ..."`` on transport/HTTP
                errors; ``"Task did not complete within timeout"`` when the
                stream ends, or sends ``[DONE]``, without a result.
        """
        url = f"{self.base_url}/api/task/{task_id}/sse"
        print("\nπ Getting task result...")
        print(f"π‘ SSE URL: {url}")
        print(f"β±οΈ Timeout: {timeout}s")
        headers = {
            "Authorization": f"KEY {self.api_key}",
            "Accept": "text/event-stream"
        }
        try:
            # A stream=True response holds its connection until closed; the
            # with-block releases it on return, raise and break alike.
            with requests.get(
                    url, headers=headers, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                print(f"π SSE Response Status: {response.status_code}")
                # Event streams are UTF-8 by spec. Without a charset, requests
                # assumes ISO-8859-1 for text/* (garbling non-ASCII) or yields
                # bytes when no encoding is known.
                response.encoding = "utf-8"
                for line in response.iter_lines(decode_unicode=True):
                    if not line:
                        continue
                    print(f"π₯ SSE Line: {line}")
                    # Only the "data" field carries payloads. Per the SSE
                    # format it is "data:" optionally followed by one space.
                    if not line.startswith("data:"):
                        continue
                    data_str = line[5:]
                    if data_str.startswith(" "):
                        data_str = data_str[1:]
                    if data_str.strip() == "[DONE]":
                        print("β Task completed!")
                        break
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        print(f"β οΈ Could not parse JSON: {data_str}")
                        continue
                    if not isinstance(data, dict):
                        # Status checks below need a JSON object.
                        print(f"Ignoring non-object SSE payload: {data_str}")
                        continue
                    print(f"π¦ Parsed data: {json.dumps(data, indent=2)}")
                    # ``or ""`` tolerates an explicit "status": null.
                    status = str(data.get("status") or "").upper()
                    finished = status in ("COMPLETED", "FINISHED", "SUCCESS")
                    has_payload = any(
                        key in data for key in ("result", "imageUrl", "videoUrl", "url"))
                    if finished or has_payload:
                        print("β Task result received!")
                        return data
                    if status in ("FAILED", "ERROR"):
                        raise Exception(
                            f"Task failed: {data.get('error', 'Unknown error')}")
                    print(f"β³ Task status: {data.get('status', 'processing')}")
        except requests.exceptions.RequestException as e:
            print(f"β SSE request failed: {str(e)}")
            raise Exception(f"Failed to get task result: {str(e)}") from e
        raise Exception("Task did not complete within timeout")

    def make_request_with_result(self, endpoint: str, data: Dict[str, Any], timeout: int = 60) -> Dict[str, Any]:
        """Submit a task via ``make_request`` and block until its result arrives.

        Raises:
            Exception: ``"No taskId in response"`` if the submission response
                lacks a ``taskId``, plus anything raised by ``make_request`` or
                ``get_task_result``.
        """
        # First, make the initial request to get task ID
        response = self.make_request(endpoint, data)
        if "taskId" not in response:
            raise Exception("No taskId in response")
        task_id = response["taskId"]
        print(f"\nπ― Task ID: {task_id}")
        # Then get the result
        return self.get_task_result(task_id, timeout)
# Compiled once at import time rather than on every call. Anchored with \Z
# (not $) because $ also matches just before a trailing newline.
_URL_PATTERN = re.compile(
    r'^https?://'  # http:// or https://
    # domain... (TLD up to 63 chars, the DNS label length limit)
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,63}\.?|'
    r'localhost|'  # localhost...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or ip (octets not range-checked)
    r'(?::\d+)?'  # optional port
    r'(?:/?|[/?]\S+)\Z', re.IGNORECASE)


def validate_url(url: str) -> bool:
    """Return True if *url* looks like an http(s) URL.

    The host may be a dotted domain name, ``localhost`` or a dotted-quad
    IPv4 address. It may be followed by an optional port and an optional
    path/query containing no whitespace. Matching is case-insensitive.
    Non-string input returns False instead of raising.
    """
    if not isinstance(url, str):
        return False
    return _URL_PATTERN.match(url) is not None
def validate_scale(scale: int) -> bool:
    """Return True when *scale* is one of the upscaler's configured factors.

    The allowed values come from
    ``TOOLS_CONFIG["image_upscaler"]["scale_options"]``.
    """
    allowed_scales = TOOLS_CONFIG["image_upscaler"]["scale_options"]
    return scale in allowed_scales
def prepare_request_data(tool_name: str, **kwargs) -> Dict[str, Any]:
    """Build the API payload for *tool_name* from caller keyword arguments.

    Each required parameter must be present in ``kwargs``. Each optional
    parameter is taken from ``kwargs`` when given, otherwise from the tool's
    ``default_values`` when it has one. Otherwise it is omitted. Names are
    translated through the tool's ``param_mapping`` when an entry exists.

    Raises:
        ValueError: For an unknown tool or a missing required parameter.
    """
    if tool_name not in TOOLS_CONFIG:
        raise ValueError(f"Unknown tool: {tool_name}")

    tool_config = TOOLS_CONFIG[tool_name]
    rename = tool_config.get("param_mapping", {})
    defaults = tool_config.get("default_values", {})
    payload = {}

    for name in tool_config["required_params"]:
        if name not in kwargs:
            raise ValueError(f"Missing required parameter: {name}")
        payload[rename.get(name, name)] = kwargs[name]

    for name in tool_config.get("optional_params", []):
        if name in kwargs:
            payload[rename.get(name, name)] = kwargs[name]
        elif name in defaults:
            payload[rename.get(name, name)] = defaults[name]

    return payload
def format_response_with_preview(response: Dict[str, Any], tool_name: str) -> Tuple[str, Optional[str]]:
    """Format an API response for display, with a media URL for preview.

    URL lookup order:
    1. Top-level ``imageUrl``, ``videoUrl``, ``url``.
    2. A nested ``result`` dict (the same keys plus ``image_url``,
       ``video_url``, ``output_url``), or a ``result`` string starting with
       ``http``.
    3. Top-level ``image_url``, ``video_url``, ``output_url``.

    Only non-empty strings count as URLs. The media is labelled "video" if
    the URL came from a ``videoUrl``/``video_url`` key, or if the URL path
    ends in .mp4/.avi/.mov/.webm. Otherwise it is labelled "image".

    Returns:
        Tuple of (message, media_url_for_preview). The URL is None for errors
        and for responses that carry no media URL.
    """
    from urllib.parse import urlparse

    # A present-but-null/empty "error" field is not an error.
    if response.get("error"):
        return f"β Error: {response['error']}", None

    def _first_url(source, keys):
        # First key whose value is a non-empty string, as (url, key).
        for key in keys:
            value = source.get(key)
            if isinstance(value, str) and value:
                return value, key
        return None, None

    # Check for A1D API specific fields first
    result_url, source_key = _first_url(response, ("imageUrl", "videoUrl", "url"))

    # Then check nested result fields
    if result_url is None:
        result = response.get("result")
        if isinstance(result, dict):
            result_url, source_key = _first_url(
                result,
                ("imageUrl", "videoUrl", "url", "image_url", "video_url", "output_url"))
        elif isinstance(result, str) and result.startswith("http"):
            result_url = result

    # Also check other common fields
    if result_url is None:
        result_url, source_key = _first_url(
            response, ("image_url", "video_url", "output_url"))

    if result_url is None:
        return f"β Task completed successfully for {tool_name}", None

    # Match the extension against the URL path only. A substring test over the
    # whole URL misfires on hosts like "cdn.movies.com" or names like
    # "a.movie.png". The query string is ignored.
    try:
        path = urlparse(result_url).path
    except ValueError:  # e.g. malformed IPv6 host
        path = result_url
    is_video = (source_key in ("videoUrl", "video_url")
                or path.lower().endswith((".mp4", ".avi", ".mov", ".webm")))
    media_type = "video" if is_video else "image"
    message = f"β Success! {media_type.title()} generated: {result_url}"
    return message, result_url
def format_response(response: Dict[str, Any], tool_name: str) -> str:
    """Return only the display message for *response* (legacy API).

    Kept for backward compatibility with callers written before media
    previews existed. The preview URL from format_response_with_preview is
    discarded.
    """
    return format_response_with_preview(response, tool_name)[0]
def get_tool_info(tool_name: str) -> Dict[str, Any]:
    """Return the TOOLS_CONFIG entry for *tool_name*.

    Raises:
        ValueError: If *tool_name* is not a configured tool.
    """
    if tool_name in TOOLS_CONFIG:
        return TOOLS_CONFIG[tool_name]
    raise ValueError(f"Unknown tool: {tool_name}")