Spaces:
Running
Running
""" | |
Utility functions for A1D MCP Server | |
Handles API calls and data processing | |
""" | |
import requests | |
import json | |
import os | |
import time | |
import re | |
from typing import Dict, Any, Optional, Tuple | |
from config import A1D_API_BASE_URL, API_KEY, TOOLS_CONFIG | |
class A1DAPIClient:
    """Client for making API calls to A1D services.

    Every request carries an ``Authorization: KEY <api key>`` header. Tasks are
    started with :meth:`make_request`; their results are read from the task's
    server-sent-events (SSE) endpoint by :meth:`get_task_result`.
    :meth:`make_request_with_result` combines the two.
    """

    # Lower-cased header names whose values are secrets and must never be
    # printed verbatim in the debug output.
    _SENSITIVE_HEADERS = frozenset({"api_key", "authorization"})
    # Upper-cased task statuses meaning the task succeeded / failed.
    _SUCCESS_STATUSES = frozenset({"COMPLETED", "FINISHED", "SUCCESS"})
    _FAILURE_STATUSES = frozenset({"FAILED", "ERROR"})
    # Payload keys whose mere presence means a result has been delivered.
    _RESULT_KEYS = ("result", "imageUrl", "videoUrl", "url")

    def __init__(self, api_key: Optional[str] = None):
        """Create a client with its own ``requests.Session``.

        Args:
            api_key: Explicit API key; when falsy, ``_get_api_key`` is consulted.

        Raises:
            ValueError: If no API key could be found.
        """
        # Try to get API key from multiple sources
        self.api_key = api_key or self._get_api_key()
        self.base_url = A1D_API_BASE_URL
        # Validate before creating the Session so a failed construction does
        # not leave behind a Session that nobody will ever close.
        if not self.api_key:
            raise ValueError(
                "API key is required. Set A1D_API_KEY environment variable, pass it directly, or provide via MCP header.")
        self.session = requests.Session()
        # Set default headers
        self.session.headers.update({
            "Authorization": f"KEY {self.api_key}",
            "Content-Type": "application/json",
            "User-Agent": "A1D-MCP-Server/1.0.0"
        })

    def _get_api_key(self) -> Optional[str]:
        """Return the first API key found, or ``None``.

        Sources, in order:
        1. ``config.API_KEY`` (presumably populated from the ``A1D_API_KEY``
           environment variable, per the error message in ``__init__``).
        2. An ``API_KEY`` / ``api_key`` header on the current Gradio request
           (best effort).
        """
        # 1. Configured key
        if API_KEY:
            return API_KEY
        # 2. Best effort: a key forwarded by the MCP client as a request header.
        # NOTE(review): ``gr.request()`` does not look like a documented Gradio
        # API — confirm. Any failure here (gradio missing, no active request,
        # ...) simply falls through to "no key".
        try:
            import gradio as gr
            request = gr.request()
            if request and hasattr(request, 'headers'):
                api_key = request.headers.get(
                    'API_KEY') or request.headers.get('api_key')
                if api_key:
                    print("π‘ Using API key from MCP client header")
                    return api_key
        except Exception:
            # Deliberately best-effort; unlike a bare ``except:`` this does not
            # swallow KeyboardInterrupt / SystemExit.
            pass
        return None

    @staticmethod
    def _mask_secret(value: str) -> str:
        """Mask a secret header value, keeping only a leading auth scheme.

        ``"KEY abc123"`` -> ``"KEY ***"``; a value without a scheme -> ``"***"``.
        No character of the secret itself is ever shown.
        """
        scheme, sep, _secret = value.partition(" ")
        return f"{scheme} ***" if sep else "***"

    def _print_request_debug(self, url: str, request_data: Dict[str, Any]) -> None:
        """Print the outgoing request: URL, method, masked headers and JSON body."""
        print("\n" + "="*60)
        print("π A1D API REQUEST DEBUG INFO")
        print("="*60)
        print(f"π‘ URL: {url}")
        print("π§ Method: POST")
        print("\nπ Headers:")
        for key, value in self.session.headers.items():
            # Mask API key for security
            if key.lower() in self._SENSITIVE_HEADERS:
                value = self._mask_secret(value)
            print(f" {key}: {value}")
        print("\nπ¦ Request Body:")
        print(f" {json.dumps(request_data, indent=2)}")

    @staticmethod
    def _print_response_debug(response) -> Tuple[Any, Optional[ValueError]]:
        """Print status, headers and body of ``response``.

        Returns:
            ``(json_body, None)`` when the body is valid JSON, otherwise
            ``(None, decode_error)``.
        """
        print("\nπ Response Info:")
        print(f" Status Code: {response.status_code}")
        print(f" Status Text: {response.reason}")
        print("\nπ Response Headers:")
        for key, value in response.headers.items():
            print(f" {key}: {value}")
        print("\nπ¦ Response Body:")
        try:
            body = response.json()
        except ValueError as e:
            # json, simplejson and requests' own JSONDecodeError all subclass
            # ValueError, so this covers every requests version.
            print(f" {response.text[:500]}...")
            print("="*60)
            return None, e
        print(f" {json.dumps(body, indent=2)}")
        print("="*60)
        return body, None

    def make_request(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``data`` (plus ``"source": "mcp"``) to ``endpoint`` and return the JSON reply.

        Args:
            endpoint: Path appended to ``self.base_url``.
            data: JSON-serializable request body; it is not mutated.

        Raises:
            Exception: ``"API request failed: ..."`` on transport errors or a
                non-2xx status; ``"Failed to parse API response: ..."`` when a
                successful response body is not JSON.
        """
        url = f"{self.base_url}{endpoint}"
        # Add source field to all requests
        request_data = {**data, "source": "mcp"}
        self._print_request_debug(url, request_data)
        try:
            print("\nβ³ Sending request...")
            response = self.session.post(url, json=request_data, timeout=30)
            body, body_error = self._print_response_debug(response)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"\nβ Request failed: {str(e)}")
            print("="*60)
            raise Exception(f"API request failed: {str(e)}") from e
        # Handled outside the try above: requests' JSONDecodeError is also a
        # RequestException and would otherwise be misreported as a request
        # failure. HTTP errors still take priority via raise_for_status().
        if body_error is not None:
            print(f"\nβ JSON decode failed: {str(body_error)}")
            print("="*60)
            raise Exception(
                f"Failed to parse API response: {str(body_error)}") from body_error
        return body

    @staticmethod
    def _sse_data(line: str) -> Optional[str]:
        """Return the value of an SSE ``data:`` field line, or ``None`` for other lines.

        Per the SSE format, a single space after the colon is optional and is
        stripped: both ``"data: x"`` and ``"data:x"`` yield ``"x"``.
        """
        if not line.startswith("data:"):
            return None
        value = line[len("data:"):]
        return value[1:] if value.startswith(" ") else value

    def get_task_result(self, task_id: str, timeout: int = 60) -> Dict[str, Any]:
        """Stream the task's SSE endpoint until a result or a failure event arrives.

        Args:
            task_id: Identifier returned by the task-creation request.
            timeout: Passed to ``requests`` as its timeout, i.e. the connect
                timeout and the longest wait for the next chunk of the stream;
                it does not cap the total time spent waiting.

        Returns:
            The first ``data:`` payload whose status is a success status or that
            carries one of ``result`` / ``imageUrl`` / ``videoUrl`` / ``url``.

        Raises:
            Exception: ``"Task failed: ..."`` on a failure status;
                ``"Task did not complete within timeout"`` when the stream ends
                (or sends ``[DONE]``) without a result;
                ``"Failed to get task result: ..."`` on transport/HTTP errors.
        """
        url = f"{self.base_url}/api/task/{task_id}/sse"
        print("\nπ Getting task result...")
        print(f"π‘ SSE URL: {url}")
        print(f"β±οΈ Timeout: {timeout}s")
        headers = {
            "Authorization": f"KEY {self.api_key}",
            "Accept": "text/event-stream"
        }
        try:
            # ``with`` releases the streamed connection on every exit path,
            # including the early return and the "Task failed" raise.
            with requests.get(url, headers=headers, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                print(f"π SSE Response Status: {response.status_code}")
                # Event streams are always UTF-8. Without this, requests assumes
                # ISO-8859-1 for text/* responses lacking a charset, which
                # garbles non-ASCII payloads.
                response.encoding = "utf-8"
                for line in response.iter_lines(decode_unicode=True):
                    if not line:
                        continue
                    print(f"π₯ SSE Line: {line}")
                    data_str = self._sse_data(line)
                    if data_str is None:
                        continue  # event:, id:, comments, ... carry no payload
                    if data_str.strip() == "[DONE]":
                        print("β Task completed!")
                        break
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        print(f"β οΈ Could not parse JSON: {data_str}")
                        continue
                    if not isinstance(data, dict):
                        # A bare number/string/list has no status or result keys.
                        print(f"β οΈ Ignoring non-object SSE payload: {data_str}")
                        continue
                    print(f"π¦ Parsed data: {json.dumps(data, indent=2)}")
                    # ``or ""`` also tolerates an explicit ``"status": null``.
                    status = str(data.get("status") or "").upper()
                    if (status in self._SUCCESS_STATUSES
                            or any(key in data for key in self._RESULT_KEYS)):
                        print("β Task result received!")
                        return data
                    if status in self._FAILURE_STATUSES:
                        raise Exception(
                            f"Task failed: {data.get('error', 'Unknown error')}")
                    print(f"β³ Task status: {data.get('status', 'processing')}")
                raise Exception("Task did not complete within timeout")
        except requests.exceptions.RequestException as e:
            print(f"β SSE request failed: {str(e)}")
            raise Exception(f"Failed to get task result: {str(e)}") from e

    def make_request_with_result(self, endpoint: str, data: Dict[str, Any], timeout: int = 60) -> Dict[str, Any]:
        """Start a task with :meth:`make_request` and wait for its SSE result.

        Raises:
            Exception: ``"No taskId in response"`` when the creation reply is not
                a JSON object with a ``taskId`` key, plus anything raised by
                :meth:`make_request` / :meth:`get_task_result`.
        """
        # First, make the initial request to get task ID
        response = self.make_request(endpoint, data)
        # A non-object JSON reply (list, null, ...) cannot carry a taskId; check
        # the type first so ``in`` cannot raise TypeError or match list items.
        if not isinstance(response, dict) or "taskId" not in response:
            raise Exception("No taskId in response")
        task_id = response["taskId"]
        print(f"\nπ― Task ID: {task_id}")
        # Then get the result
        return self.get_task_result(task_id, timeout)
# Compiled once at import time instead of on every call.
_URL_PATTERN = re.compile(
    r'^https?://'  # http:// or https://
    # domain: dot-terminated labels, then an alphabetic TLD (up to 63 chars,
    # the DNS label limit) or a punycode IDN TLD such as "xn--p1ai"
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,63}|XN--[A-Z0-9-]{1,59})\.?|'
    r'localhost|'  # localhost...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or ip
    r'(?::\d+)?'  # optional port
    r'(?:/?|[/?]\S+)$', re.IGNORECASE)


def validate_url(url: str) -> bool:
    """Return True if ``url`` looks like an absolute http(s) URL.

    Accepted hosts are dotted domain names, ``localhost`` or a dotted-quad IPv4
    address, optionally followed by a port and a whitespace-free path/query.
    Non-string input returns False instead of raising.
    """
    if not isinstance(url, str):
        return False
    # fullmatch rather than match: a bare "$" also matches just before a
    # trailing newline, which previously let "https://example.com\n" pass.
    return _URL_PATTERN.fullmatch(url) is not None
def validate_scale(scale: int) -> bool:
    """Tell whether ``scale`` is a supported image-upscaler factor.

    The allowed values are looked up in
    ``TOOLS_CONFIG["image_upscaler"]["scale_options"]`` on every call.
    """
    upscaler_options = TOOLS_CONFIG["image_upscaler"]
    return scale in upscaler_options["scale_options"]
def prepare_request_data(tool_name: str, **kwargs) -> Dict[str, Any]:
    """Build the API payload for ``tool_name`` from keyword arguments.

    Keys read from ``TOOLS_CONFIG[tool_name]``:
    - ``required_params``: each must be present in ``kwargs``.
    - ``optional_params``: copied when supplied, otherwise filled from
      ``default_values`` when a default exists, otherwise omitted.
    - ``param_mapping``: renames a parameter to its API field name.
    Keyword arguments listed in neither parameter list are ignored.

    Raises:
        ValueError: For an unknown tool or a missing required parameter.
    """
    if tool_name not in TOOLS_CONFIG:
        raise ValueError(f"Unknown tool: {tool_name}")
    tool_config = TOOLS_CONFIG[tool_name]
    rename = tool_config.get("param_mapping", {})
    defaults = tool_config.get("default_values", {})
    payload: Dict[str, Any] = {}

    for name in tool_config["required_params"]:
        if name not in kwargs:
            raise ValueError(f"Missing required parameter: {name}")
        payload[rename.get(name, name)] = kwargs[name]

    for name in tool_config.get("optional_params", []):
        if name in kwargs:
            value = kwargs[name]
        elif name in defaults:
            value = defaults[name]
        else:
            continue
        payload[rename.get(name, name)] = value

    return payload
# Lower-case file extensions that mark a result URL as a video.
_VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.webm')


def _first_url(source: Dict[str, Any], keys: Tuple[str, ...]) -> Optional[str]:
    """Return the first value under ``keys`` in ``source`` that is a non-empty string."""
    for key in keys:
        value = source.get(key)
        if isinstance(value, str) and value:
            return value
    return None


def format_response_with_preview(response: Dict[str, Any], tool_name: str) -> Tuple[str, Optional[str]]:
    """Format API response for display with media preview.

    The result URL is looked up, in order, in the top-level ``imageUrl`` /
    ``videoUrl`` / ``url`` fields, then in ``response["result"]`` (a dict with
    ``imageUrl`` / ``videoUrl`` / ``url`` / ``image_url`` / ``video_url`` /
    ``output_url``, or an ``http...`` string), then in the top-level
    ``image_url`` / ``video_url`` / ``output_url`` fields. Only non-empty
    strings count as URLs.

    Returns:
        Tuple of (message, media_url_for_preview); the URL is ``None`` for an
        error response or when no result URL is found.
    """
    from urllib.parse import urlsplit

    # Only a truthy error counts: payloads such as {"error": null, "imageUrl": ...}
    # describe a success and must not be reported as "Error: None".
    if response.get("error"):
        return f"β Error: {response['error']}", None

    # Check for A1D API specific fields first
    result_url = _first_url(response, ("imageUrl", "videoUrl", "url"))
    # Then check nested result fields
    if not result_url and "result" in response:
        result = response["result"]
        if isinstance(result, dict):
            result_url = _first_url(result, ("imageUrl", "videoUrl", "url",
                                             "image_url", "video_url", "output_url"))
        elif isinstance(result, str) and result.startswith("http"):
            result_url = result
    # Also check other common fields
    if not result_url:
        result_url = _first_url(response, ("image_url", "video_url", "output_url"))

    if not result_url:
        return f"β Task completed successfully for {tool_name}", None

    # Classify by the URL *path* only, so the host (e.g. "img.movieapp.com" or
    # a ".mov" TLD) and the query string cannot make an image look like a video.
    path = urlsplit(result_url).path.lower()
    media_type = "video" if path.endswith(_VIDEO_EXTENSIONS) else "image"
    message = f"β Success! {media_type.title()} generated: {result_url}"
    return message, result_url
def format_response(response: Dict[str, Any], tool_name: str) -> str:
    """Return only the display message for ``response``.

    Kept for backward compatibility with callers that predate media previews:
    it delegates to ``format_response_with_preview`` and drops the preview URL.
    """
    return format_response_with_preview(response, tool_name)[0]
def get_tool_info(tool_name: str) -> Dict[str, Any]:
    """Return the ``TOOLS_CONFIG`` entry for ``tool_name``.

    Raises:
        ValueError: If ``tool_name`` is not a configured tool.
    """
    if tool_name in TOOLS_CONFIG:
        return TOOLS_CONFIG[tool_name]
    raise ValueError(f"Unknown tool: {tool_name}")