File size: 11,320 Bytes
aaa3e82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
"""
Utility functions for A1D MCP Server
Handles API calls and data processing
"""

import requests
import json
import os
import time
import re
from typing import Dict, Any, Optional, Tuple
from config import A1D_API_BASE_URL, API_KEY, TOOLS_CONFIG


class A1DAPIClient:
    """Client for making API calls to A1D services.

    Wraps a ``requests.Session`` preconfigured with the A1D ``Authorization``
    header, and provides helpers to submit a task and then wait for its
    result on the task's server-sent-events (SSE) endpoint.
    """

    # Upper-cased ``status`` values that mark a task as finished / failed.
    _DONE_STATUSES = ("COMPLETED", "FINISHED", "SUCCESS")
    _FAILED_STATUSES = ("FAILED", "ERROR")
    # Payload keys whose presence means the event carries the task output.
    _RESULT_KEYS = ("result", "imageUrl", "videoUrl", "url")

    def __init__(self, api_key: Optional[str] = None):
        """Create a client.

        Args:
            api_key: Explicit API key. When omitted, ``_get_api_key`` is
                consulted (config ``API_KEY`` first, then a Gradio request
                header).

        Raises:
            ValueError: If no API key could be found.
        """
        # Try to get API key from multiple sources
        self.api_key = api_key or self._get_api_key()
        self.base_url = A1D_API_BASE_URL

        # Validate before creating the session so a failed construction
        # does not allocate one for nothing.
        if not self.api_key:
            raise ValueError(
                "API key is required. Set A1D_API_KEY environment variable, pass it directly, or provide via MCP header.")

        self.session = requests.Session()
        # Set default headers
        self.session.headers.update({
            "Authorization": f"KEY {self.api_key}",
            "Content-Type": "application/json",
            "User-Agent": "A1D-MCP-Server/1.0.0"
        })

    def _get_api_key(self) -> Optional[str]:
        """Look up an API key when none was passed to the constructor.

        Sources, in order:
        1. ``API_KEY`` imported from ``config``.
        2. An ``API_KEY`` / ``api_key`` header on the current Gradio request,
           if one can be obtained.

        Returns:
            The key, or ``None`` if neither source provides one.
        """
        # 1. Environment variable
        if API_KEY:
            return API_KEY

        # 2. Try to get from Gradio request headers (if available).
        # Best-effort: gradio may be missing, or there may be no active
        # request. NOTE(review): confirm `gr.request()` exists in the
        # installed Gradio version; handlers normally receive `gr.Request`
        # as an injected parameter instead.
        try:
            import gradio as gr
            request = gr.request()
            if request and hasattr(request, 'headers'):
                # Check for API_KEY header from MCP client
                api_key = (request.headers.get('API_KEY')
                           or request.headers.get('api_key'))
                if api_key:
                    print(f"πŸ“‘ Using API key from MCP client header")
                    return api_key
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C / SystemExit propagate.
            pass

        return None

    def make_request(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST *data* to ``base_url + endpoint`` and return the decoded JSON body.

        A ``"source": "mcp"`` field is added to every payload. Request and
        response details are printed for debugging. The Authorization and
        api_key header values are truncated to their first 8 characters.

        Raises:
            Exception: On any ``requests`` error (including a non-2xx status)
                or when the response body cannot be parsed as JSON.
        """
        url = f"{self.base_url}{endpoint}"

        # Add source field to all requests
        request_data = {**data, "source": "mcp"}

        # Print detailed request information
        print("\n" + "="*60)
        print("πŸš€ A1D API REQUEST DEBUG INFO")
        print("="*60)
        print(f"πŸ“‘ URL: {url}")
        print(f"πŸ”§ Method: POST")

        print(f"\nπŸ“‹ Headers:")
        for key, value in self.session.headers.items():
            # Mask API key for security
            if key.lower() in ['api_key', 'authorization']:
                masked_value = f"{value[:8]}..." if len(value) > 8 else "***"
                print(f"   {key}: {masked_value}")
            else:
                print(f"   {key}: {value}")

        print(f"\nπŸ“¦ Request Body:")
        print(f"   {json.dumps(request_data, indent=2)}")

        try:
            print(f"\n⏳ Sending request...")
            response = self.session.post(url, json=request_data, timeout=30)

            print(f"\nπŸ“Š Response Info:")
            print(f"   Status Code: {response.status_code}")
            print(f"   Status Text: {response.reason}")

            print(f"\nπŸ“‹ Response Headers:")
            for key, value in response.headers.items():
                print(f"   {key}: {value}")

            print(f"\nπŸ“¦ Response Body:")
            try:
                response_json = response.json()
            except ValueError:
                # Body is not JSON (JSON decode errors subclass ValueError):
                # show a text preview instead.
                print(f"   {response.text[:500]}...")
            else:
                print(f"   {json.dumps(response_json, indent=2)}")

            print("="*60)

            response.raise_for_status()
            return response.json()

        except requests.exceptions.RequestException as e:
            print(f"\n❌ Request failed: {str(e)}")
            print("="*60)
            raise Exception(f"API request failed: {str(e)}") from e
        except json.JSONDecodeError as e:
            print(f"\n❌ JSON decode failed: {str(e)}")
            print("="*60)
            raise Exception(f"Failed to parse API response: {str(e)}") from e

    @staticmethod
    def _classify_task_event(data: Any) -> str:
        """Classify one decoded SSE payload as "done", "failed" or "pending".

        The failure status is checked first. A FAILED/ERROR event that also
        carries e.g. ``"result": null`` is therefore not mistaken for success.
        A payload is "done" when its status is COMPLETED/FINISHED/SUCCESS or
        it contains any of the result keys. Anything else, including non-dict
        payloads, is "pending".
        """
        if not isinstance(data, dict):
            return "pending"
        # `or ""` guards against an explicit JSON null status.
        status = str(data.get("status") or "").upper()
        if status in A1DAPIClient._FAILED_STATUSES:
            return "failed"
        if (status in A1DAPIClient._DONE_STATUSES
                or any(key in data for key in A1DAPIClient._RESULT_KEYS)):
            return "done"
        return "pending"

    def get_task_result(self, task_id: str, timeout: int = 60) -> Dict[str, Any]:
        """Wait for a task's result on its SSE endpoint.

        Args:
            task_id: ID returned by the task-creation request.
            timeout: Overall number of seconds to wait for a result. Also
                passed to ``requests`` as the connect/read socket timeout.

        Returns:
            The first decoded SSE payload classified as "done" by
            ``_classify_task_event``.

        Raises:
            Exception: If the task reports a failure status, the deadline
                passes, the stream ends without a result, or the HTTP
                request fails.
        """
        url = f"{self.base_url}/api/task/{task_id}/sse"

        print(f"\nπŸ”„ Getting task result...")
        print(f"πŸ“‘ SSE URL: {url}")
        print(f"⏱️ Timeout: {timeout}s")

        headers = {
            "Authorization": f"KEY {self.api_key}",
            "Accept": "text/event-stream"
        }

        # The requests timeout only bounds connect and each individual read.
        # A stream that keeps emitting progress events would never time out
        # without this wall-clock deadline.
        deadline = time.monotonic() + timeout

        try:
            # `with` closes the streamed response (releasing its connection)
            # on every exit path: return, raise, or end of stream.
            with requests.get(
                    url, headers=headers, stream=True, timeout=timeout) as response:
                response.raise_for_status()

                print(f"πŸ“Š SSE Response Status: {response.status_code}")

                # Parse SSE stream
                for line in response.iter_lines(decode_unicode=True):
                    # Checked first so keep-alive blank lines also enforce it.
                    if time.monotonic() > deadline:
                        raise Exception("Task did not complete within timeout")
                    if not line:
                        continue

                    print(f"πŸ“₯ SSE Line: {line}")

                    # Only "data: " lines carry payloads.
                    if not line.startswith("data: "):
                        continue
                    data_str = line[6:]  # Remove "data: " prefix
                    if data_str.strip() == "[DONE]":
                        print("βœ… Task completed!")
                        break

                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        print(f"⚠️ Could not parse JSON: {data_str}")
                        continue

                    print(f"πŸ“¦ Parsed data: {json.dumps(data, indent=2)}")

                    state = self._classify_task_event(data)
                    if state == "done":
                        print("βœ… Task result received!")
                        return data
                    if state == "failed":
                        raise Exception(
                            f"Task failed: {data.get('error', 'Unknown error')}")
                    status_text = (data.get('status', 'processing')
                                   if isinstance(data, dict) else 'processing')
                    print(f"⏳ Task status: {status_text}")

            # The stream closed (or sent [DONE]) without a result payload.
            raise Exception("Task stream ended without a result")

        except requests.exceptions.RequestException as e:
            print(f"❌ SSE request failed: {str(e)}")
            raise Exception(f"Failed to get task result: {str(e)}") from e

    def make_request_with_result(self, endpoint: str, data: Dict[str, Any], timeout: int = 60) -> Dict[str, Any]:
        """Submit a task via ``make_request`` and wait for its result.

        Raises:
            Exception: If the submission response has no ``taskId``, or if
                ``make_request`` / ``get_task_result`` fail.
        """
        # First, make the initial request to get task ID
        response = self.make_request(endpoint, data)

        # isinstance guard: a JSON list/str body would otherwise fail on
        # response["taskId"] with a TypeError.
        if not isinstance(response, dict) or "taskId" not in response:
            raise Exception("No taskId in response")

        task_id = response["taskId"]
        print(f"\n🎯 Task ID: {task_id}")

        # Then get the result
        return self.get_task_result(task_id, timeout)


# Compiled once at import time instead of on every call.
_URL_PATTERN = re.compile(
    r'^https?://'  # http:// or https://
    # domain: dot-separated labels, then an alphabetic TLD (up to the
    # 63-character DNS label limit) or an IDN punycode TLD such as "xn--p1ai"
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,63}|XN--[A-Z0-9-]{1,59})\.?|'
    r'localhost|'  # localhost...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or ip
    r'(?::\d+)?'  # optional port
    # optional path/query; \Z (unlike $) does not accept a trailing newline
    r'(?:/?|[/?]\S+)\Z', re.IGNORECASE)


def validate_url(url: str) -> bool:
    """Return True if *url* looks like an http(s) URL.

    Accepts a hostname with an alphabetic or punycode TLD, ``localhost``,
    or a dotted-quad IPv4 address (octet ranges are not checked). These may
    be followed by an optional port and an optional whitespace-free
    path/query. Non-string input returns False instead of raising.
    """
    if not isinstance(url, str):
        return False
    return _URL_PATTERN.match(url) is not None


def validate_scale(scale: int) -> bool:
    """Report whether *scale* is one of the image upscaler's configured factors.

    The allowed values come from
    ``TOOLS_CONFIG["image_upscaler"]["scale_options"]``.
    """
    upscaler_settings = TOOLS_CONFIG["image_upscaler"]
    allowed_scales = upscaler_settings["scale_options"]
    return scale in allowed_scales


def prepare_request_data(tool_name: str, **kwargs) -> Dict[str, Any]:
    """Build the API request payload for *tool_name* from keyword arguments.

    Uses the tool's entry in ``TOOLS_CONFIG``:

    - ``required_params``: each must be present in *kwargs*.
    - ``optional_params``: copied from *kwargs* when given. Otherwise the
      value from ``default_values`` is used if one exists; if not, the
      parameter is omitted.
    - ``param_mapping``: renames a parameter to its API field name.

    Keyword arguments not listed in either parameter list are ignored.

    Raises:
        ValueError: If *tool_name* is unknown or a required parameter is missing.
    """
    if tool_name not in TOOLS_CONFIG:
        raise ValueError(f"Unknown tool: {tool_name}")

    config = TOOLS_CONFIG[tool_name]
    # Loop-invariant lookups, done once instead of once per parameter.
    param_mapping = config.get("param_mapping", {})
    default_values = config.get("default_values", {})
    data = {}

    # Add required parameters (a config without the list has none).
    for param in config.get("required_params", []):
        if param not in kwargs:
            raise ValueError(f"Missing required parameter: {param}")
        data[param_mapping.get(param, param)] = kwargs[param]

    # Add optional parameters: caller value first, then configured default.
    for param in config.get("optional_params", []):
        if param in kwargs:
            value = kwargs[param]
        elif param in default_values:
            value = default_values[param]
        else:
            continue
        data[param_mapping.get(param, param)] = value

    return data


def format_response_with_preview(response: Dict[str, Any], tool_name: str) -> Tuple[str, Optional[str]]:
    """Format API response for display with media preview

    URL lookup order:

    1. Top-level ``imageUrl`` / ``videoUrl`` / ``url``.
    2. Inside ``result``: the same keys plus ``image_url`` / ``video_url`` /
       ``output_url`` when it is a dict, or the value itself when it is a
       string starting with ``http``.
    3. Top-level ``image_url`` / ``video_url`` / ``output_url``.

    Only non-empty string values count as URLs.

    Returns:
        Tuple of (message, media_url_for_preview). The URL is ``None`` on
        error or when no URL was found.
    """
    from urllib.parse import urlparse

    # Only a truthy error counts: a success payload may carry "error": null.
    if response.get("error"):
        return f"❌ Error: {response['error']}", None

    video_keys = ("videoUrl", "video_url")
    video_extensions = ('.mp4', '.avi', '.mov', '.webm')

    def first_url(source: Dict[str, Any], keys: Tuple[str, ...]) -> Tuple[Optional[str], Optional[str]]:
        """Return (url, key) for the first key whose value is a non-empty str."""
        for key in keys:
            value = source.get(key)
            if isinstance(value, str) and value:
                return value, key
        return None, None

    # Check for A1D API specific fields first
    result_url, found_key = first_url(response, ("imageUrl", "videoUrl", "url"))

    # Then check nested result fields
    if result_url is None and "result" in response:
        result = response["result"]
        if isinstance(result, dict):
            result_url, found_key = first_url(
                result,
                ("imageUrl", "videoUrl", "url", "image_url", "video_url", "output_url"))
        elif isinstance(result, str) and result.startswith("http"):
            result_url = result

    # Also check other common fields
    if result_url is None:
        result_url, found_key = first_url(
            response, ("image_url", "video_url", "output_url"))

    if result_url is None:
        return f"βœ… Task completed successfully for {tool_name}", None

    # A URL under a video key is a video. Otherwise judge by the path's
    # extension; a plain substring test would misfire on names like
    # "my.movie-poster.png" or hosts like "cdn.mov.example.com".
    path = urlparse(result_url).path.lower()
    is_video = found_key in video_keys or path.endswith(video_extensions)
    media_type = "video" if is_video else "image"

    message = f"βœ… Success! {media_type.title()} generated: {result_url}"
    return message, result_url


def format_response(response: Dict[str, Any], tool_name: str) -> str:
    """Return only the text part of ``format_response_with_preview``.

    Kept for backward compatibility with callers that expect a plain
    message; the preview URL is discarded.
    """
    return format_response_with_preview(response, tool_name)[0]


def get_tool_info(tool_name: str) -> Dict[str, Any]:
    """Return the ``TOOLS_CONFIG`` entry for *tool_name*.

    Raises:
        ValueError: If *tool_name* has no entry in ``TOOLS_CONFIG``.
    """
    if tool_name in TOOLS_CONFIG:
        return TOOLS_CONFIG[tool_name]
    raise ValueError(f"Unknown tool: {tool_name}")