# a1d-mcp-server / utils.py
# yuxh1996's picture
# Initial commit: A1D MCP Server with Gradio interface
# aaa3e82
"""
Utility functions for A1D MCP Server
Handles API calls and data processing
"""
import requests
import json
import os
import time
import re
from typing import Dict, Any, Optional, Tuple
from config import A1D_API_BASE_URL, API_KEY, TOOLS_CONFIG
class A1DAPIClient:
    """Client for making API calls to A1D services.

    Holds a ``requests.Session`` whose default headers carry the
    ``Authorization: KEY <api_key>`` credential. Jobs are submitted with
    :meth:`make_request`, their results are awaited over the task SSE
    endpoint with :meth:`get_task_result`, and :meth:`make_request_with_result`
    chains the two.
    """

    # Upper-cased ``status`` values that mark a task as finished / failed.
    _DONE_STATUSES = ("COMPLETED", "FINISHED", "SUCCESS")
    _FAILED_STATUSES = ("FAILED", "ERROR")
    # Keys whose presence alone means a payload carries the task result.
    _RESULT_KEYS = ("result", "imageUrl", "videoUrl", "url")

    def __init__(self, api_key: Optional[str] = None):
        """Create a client.

        Args:
            api_key: Explicit API key. When falsy, :meth:`_get_api_key` is
                consulted instead.

        Raises:
            ValueError: If no API key could be found.
        """
        # Try to get API key from multiple sources
        self.api_key = api_key or self._get_api_key()
        self.base_url = A1D_API_BASE_URL

        # Validate before opening a Session so a failed construction does not
        # leave an unclosed session behind.
        if not self.api_key:
            raise ValueError(
                "API key is required. Set A1D_API_KEY environment variable, pass it directly, or provide via MCP header.")

        self.session = requests.Session()
        # Set default headers
        self.session.headers.update({
            "Authorization": f"KEY {self.api_key}",
            "Content-Type": "application/json",
            "User-Agent": "A1D-MCP-Server/1.0.0"
        })

    def _get_api_key(self) -> Optional[str]:
        """Look up an API key when none was passed explicitly.

        Sources, in order:
        1. ``API_KEY`` imported from ``config``.
        2. An ``API_KEY`` / ``api_key`` header on the current Gradio request,
           if one can be obtained.

        Returns:
            The key, or ``None`` if neither source provides one.
        """
        # 1. Configured key (imported from config)
        api_key = API_KEY
        if api_key:
            return api_key

        # 2. Try to get from Gradio request headers (if available).
        # NOTE(review): ``gr.request()`` does not appear to be part of Gradio's
        # public API (request access is normally via a ``gr.Request``-annotated
        # handler parameter); if so, this branch always ends in the except
        # below. Verify against the installed Gradio version.
        try:
            import gradio as gr
            request = gr.request()
            if request and hasattr(request, 'headers'):
                # Check for API_KEY header from MCP client
                api_key = request.headers.get(
                    'API_KEY') or request.headers.get('api_key')
                if api_key:
                    print(f"πŸ“‘ Using API key from MCP client header")
                    return api_key
        except Exception:
            # Best-effort lookup: Gradio may be missing or have no active
            # request. ``Exception`` (not a bare except) so that
            # KeyboardInterrupt / SystemExit still propagate.
            pass

        return None

    def make_request(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Make API request to A1D service.

        POSTs ``data`` plus ``"source": "mcp"`` as JSON to
        ``base_url + endpoint`` (30 s timeout) and prints request/response
        details for debugging. ``Authorization`` / ``API_KEY`` header values
        are cut to their first 8 characters in that output.

        Returns:
            The decoded JSON response body.

        Raises:
            Exception: Wrapping the underlying error if the request fails,
                the server returns an HTTP error status, or the body is not
                valid JSON.
        """
        url = f"{self.base_url}{endpoint}"

        # Add source field to all requests
        request_data = {**data, "source": "mcp"}

        # Print detailed request information
        print("\n" + "="*60)
        print("πŸš€ A1D API REQUEST DEBUG INFO")
        print("="*60)
        print(f"πŸ“‘ URL: {url}")
        print(f"πŸ”§ Method: POST")
        print(f"\nπŸ“‹ Headers:")
        for key, value in self.session.headers.items():
            # Mask API key for security.
            # NOTE(review): for "Authorization: KEY <key>" the 8-char prefix
            # still exposes the first 4 characters of the key.
            if key.lower() in ['api_key', 'authorization']:
                masked_value = f"{value[:8]}..." if len(value) > 8 else "***"
                print(f" {key}: {masked_value}")
            else:
                print(f" {key}: {value}")
        print(f"\nπŸ“¦ Request Body:")
        print(f" {json.dumps(request_data, indent=2)}")

        try:
            print(f"\n⏳ Sending request...")
            response = self.session.post(url, json=request_data, timeout=30)

            print(f"\nπŸ“Š Response Info:")
            print(f" Status Code: {response.status_code}")
            print(f" Status Text: {response.reason}")
            print(f"\nπŸ“‹ Response Headers:")
            for key, value in response.headers.items():
                print(f" {key}: {value}")

            print(f"\nπŸ“¦ Response Body:")
            try:
                response_json = response.json()
                print(f" {json.dumps(response_json, indent=2)}")
            except ValueError:
                # Non-JSON body (JSON decode errors subclass ValueError): show
                # a truncated text preview instead. Narrowed from a bare
                # ``except`` so KeyboardInterrupt is not swallowed.
                print(f" {response.text[:500]}...")
            print("="*60)

            response.raise_for_status()
            return response.json()

        except requests.exceptions.RequestException as e:
            print(f"\n❌ Request failed: {str(e)}")
            print("="*60)
            raise Exception(f"API request failed: {str(e)}") from e
        except json.JSONDecodeError as e:
            print(f"\n❌ JSON decode failed: {str(e)}")
            print("="*60)
            raise Exception(f"Failed to parse API response: {str(e)}") from e

    @classmethod
    def _classify_payload(cls, data: Any) -> str:
        """Classify a decoded SSE payload as ``"done"``, ``"failed"`` or ``"pending"``.

        A dict is ``"done"`` when its upper-cased ``status`` is in
        ``_DONE_STATUSES`` or it contains any key from ``_RESULT_KEYS``; these
        take precedence over a failure status. It is ``"failed"`` when the
        status is in ``_FAILED_STATUSES``. A missing or ``null`` status, or a
        payload that is not a dict at all, yields ``"pending"`` rather than
        raising.
        """
        if not isinstance(data, dict):
            return "pending"
        # ``or ""`` covers an explicit null status; ``str`` covers non-strings.
        status = str(data.get("status") or "").upper()
        if status in cls._DONE_STATUSES or any(key in data for key in cls._RESULT_KEYS):
            return "done"
        if status in cls._FAILED_STATUSES:
            return "failed"
        return "pending"

    def get_task_result(self, task_id: str, timeout: int = 60) -> Dict[str, Any]:
        """Get task result using SSE endpoint.

        Reads ``/api/task/{task_id}/sse`` line by line. Each ``data:`` line is
        parsed as JSON and classified by :meth:`_classify_payload`. A finished
        payload is returned and a failed one raises. Anything else is logged
        and reading continues.

        Args:
            task_id: Task identifier returned by the submit request.
            timeout: Overall wait budget in seconds. It is also passed to
                ``requests`` as the connect/read timeout, which on its own
                only bounds the gap between reads, not the total wait.

        Returns:
            The first payload classified as finished.

        Raises:
            Exception: If the task reports failure, the stream ends (or sends
                ``[DONE]``) without a result, the overall ``timeout`` elapses,
                or the HTTP request fails.
        """
        url = f"{self.base_url}/api/task/{task_id}/sse"

        print(f"\nπŸ”„ Getting task result...")
        print(f"πŸ“‘ SSE URL: {url}")
        print(f"⏱️ Timeout: {timeout}s")

        headers = {
            "Authorization": f"KEY {self.api_key}",
            "Accept": "text/event-stream"
        }
        deadline = time.monotonic() + timeout

        try:
            # ``with`` closes the streamed response (releasing its connection)
            # on every exit path, including the early return below.
            with requests.get(url, headers=headers, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                print(f"πŸ“Š SSE Response Status: {response.status_code}")

                # Parse SSE stream
                for line in response.iter_lines(decode_unicode=True):
                    if line:
                        print(f"πŸ“₯ SSE Line: {line}")
                    if line and line.startswith("data: "):
                        data_str = line[6:]  # Remove "data: " prefix
                        if data_str.strip() == "[DONE]":
                            print("βœ… Task completed!")
                            break
                        try:
                            data = json.loads(data_str)
                        except json.JSONDecodeError:
                            print(f"⚠️ Could not parse JSON: {data_str}")
                        else:
                            print(
                                f"πŸ“¦ Parsed data: {json.dumps(data, indent=2)}")
                            outcome = self._classify_payload(data)
                            if outcome == "done":
                                print("βœ… Task result received!")
                                return data
                            if outcome == "failed":
                                raise Exception(
                                    f"Task failed: {data.get('error', 'Unknown error')}")
                            status = (data.get('status', 'processing')
                                      if isinstance(data, dict) else 'processing')
                            print(f"⏳ Task status: {status}")
                    # Enforce the overall budget. It is checked after each
                    # line is handled, so a result that already arrived is
                    # never discarded.
                    if time.monotonic() >= deadline:
                        break

            raise Exception("Task did not complete within timeout")

        except requests.exceptions.RequestException as e:
            print(f"❌ SSE request failed: {str(e)}")
            raise Exception(f"Failed to get task result: {str(e)}") from e

    def make_request_with_result(self, endpoint: str, data: Dict[str, Any], timeout: int = 60) -> Dict[str, Any]:
        """Make API request and wait for result.

        Submits via :meth:`make_request`, reads ``taskId`` from the reply,
        then blocks in :meth:`get_task_result` for up to ``timeout`` seconds.

        Raises:
            Exception: If the reply has no ``taskId``, or from the two calls
                above.
        """
        # First, make the initial request to get task ID
        response = self.make_request(endpoint, data)

        # A JSON body need not be an object; guard the membership test.
        if not isinstance(response, dict) or "taskId" not in response:
            raise Exception("No taskId in response")

        task_id = response["taskId"]
        print(f"\n🎯 Task ID: {task_id}")

        # Then get the result
        return self.get_task_result(task_id, timeout)
# Compiled once at import time rather than on every call. Anchoring is done
# by ``fullmatch`` in validate_url, so the pattern carries no ^/$.
_URL_PATTERN = re.compile(
    r'https?://'  # http:// or https://
    # domain... (TLD allows up to the 63-char DNS label maximum)
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,63}\.?|'
    r'localhost|'  # localhost...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or ip
    r'(?::\d+)?'  # optional port
    r'(?:/?|[/?]\S+)', re.IGNORECASE)


def validate_url(url: str) -> bool:
    """Validate if the provided string is a valid URL.

    Accepts ``http``/``https`` with a dotted hostname, ``localhost`` or a
    dotted-quad IPv4 address, an optional port, and an optional path/query
    containing no whitespace. The check is syntactic only: IPv4 octets are
    not range-checked and nothing is resolved or fetched.

    Returns:
        True if *url* matches; False otherwise, including for non-strings.
    """
    if not isinstance(url, str):
        return False
    # fullmatch rather than match + "$": "$" also matches just before a
    # trailing newline, which let e.g. "https://example.com\n" validate.
    return _URL_PATTERN.fullmatch(url) is not None
def validate_scale(scale: int) -> bool:
    """Return True when *scale* is an allowed image-upscaler scale factor.

    The allowed values are read from
    ``TOOLS_CONFIG["image_upscaler"]["scale_options"]``.
    """
    upscaler_cfg = TOOLS_CONFIG["image_upscaler"]
    allowed_scales = upscaler_cfg["scale_options"]
    return scale in allowed_scales
def prepare_request_data(tool_name: str, **kwargs) -> Dict[str, Any]:
    """Build the API payload for *tool_name* from the caller's keyword arguments.

    Every name in the tool's ``required_params`` must be present in
    ``kwargs``. Each name in ``optional_params`` is taken from ``kwargs``
    when given, otherwise from the tool's ``default_values`` when listed
    there, and is omitted otherwise. Keys are renamed through the tool's
    ``param_mapping`` (identity when absent). Keyword arguments not named
    in the config are ignored.

    Raises:
        ValueError: For an unknown tool or a missing required parameter.
    """
    if tool_name not in TOOLS_CONFIG:
        raise ValueError(f"Unknown tool: {tool_name}")

    tool_cfg = TOOLS_CONFIG[tool_name]
    rename = tool_cfg.get("param_mapping", {})
    payload: Dict[str, Any] = {}

    # Required parameters: all must be supplied by the caller.
    for name in tool_cfg["required_params"]:
        if name not in kwargs:
            raise ValueError(f"Missing required parameter: {name}")
        payload[rename.get(name, name)] = kwargs[name]

    # Optional parameters: caller value wins, then configured default.
    defaults = tool_cfg.get("default_values", {})
    for name in tool_cfg.get("optional_params", []):
        if name in kwargs:
            payload[rename.get(name, name)] = kwargs[name]
        elif name in defaults:
            payload[rename.get(name, name)] = defaults[name]

    return payload
def format_response_with_preview(response: Dict[str, Any], tool_name: str) -> Tuple[str, Optional[str]]:
    """Format API response for display with media preview.

    The result URL is the first non-empty value among, in order:
    top-level ``imageUrl`` / ``videoUrl`` / ``url``; a nested ``result``
    that is either a dict (``imageUrl``, ``videoUrl``, ``url``,
    ``image_url``, ``video_url``, ``output_url``) or a string starting with
    ``http``; top-level ``image_url`` / ``video_url`` / ``output_url``.

    Args:
        response: Decoded API / task payload.
        tool_name: Tool name, used only in the no-URL success message.

    Returns:
        Tuple of (message, media_url_for_preview). The URL is ``None`` when
        the payload reports an error or carries no URL.
    """
    from urllib.parse import urlparse

    # Only a truthy ``error`` counts as a failure, so a present-but-empty
    # field (None / "") next to a real result is not reported as an error.
    if response.get("error"):
        return f"❌ Error: {response['error']}", None

    # Check for A1D API specific fields first
    result_url = (response.get("imageUrl") or
                  response.get("videoUrl") or
                  response.get("url"))

    # Then check nested result fields
    if not result_url and "result" in response:
        result = response["result"]
        if isinstance(result, dict):
            # Try different possible URL fields
            result_url = (result.get("imageUrl") or
                          result.get("videoUrl") or
                          result.get("url") or
                          result.get("image_url") or
                          result.get("video_url") or
                          result.get("output_url"))
        elif isinstance(result, str) and result.startswith("http"):
            result_url = result

    # Also check other common fields
    if not result_url:
        result_url = (response.get("image_url") or
                      response.get("video_url") or
                      response.get("output_url"))

    if not result_url:
        return f"βœ… Task completed successfully for {tool_name}", None

    # Decide image vs video from the suffix of the URL *path* (query string
    # and fragment excluded). A substring test over the whole URL misfired on
    # hosts such as "cdn.webmaster.com" (".webm") or "site.movie" (".mov").
    path = urlparse(str(result_url)).path.lower()
    media_type = "video" if path.endswith(('.mp4', '.avi', '.mov', '.webm')) else "image"

    message = f"βœ… Success! {media_type.title()} generated: {result_url}"
    return message, result_url
def format_response(response: Dict[str, Any], tool_name: str) -> str:
    """Return only the display message for *response*.

    Kept for backward compatibility: delegates to
    ``format_response_with_preview`` and drops the preview URL.
    """
    formatted = format_response_with_preview(response, tool_name)
    return formatted[0]
def get_tool_info(tool_name: str) -> Dict[str, Any]:
    """Return the ``TOOLS_CONFIG`` entry for *tool_name*.

    The configuration dict itself is returned, not a copy.

    Raises:
        ValueError: If *tool_name* is not a key of ``TOOLS_CONFIG``.
    """
    if tool_name in TOOLS_CONFIG:
        return TOOLS_CONFIG[tool_name]
    raise ValueError(f"Unknown tool: {tool_name}")