|
""" |
|
HTTP Client for Federated Learning

Handles communication with the federated server.
|
""" |
|
|
|
import requests |
|
import json |
|
import logging |
|
import time |
|
from typing import Dict, Any, Optional, List |
|
|
|
# Module-level logger named after this module, so the application's logging
# configuration can filter or route this client's messages by module name.
logger = logging.getLogger(__name__)
|
|
|
class FederatedHTTPClient:
    """Client-side HTTP wrapper for talking to the federated-learning server.

    All requests go through one shared ``requests.Session``. The methods that
    return the server's JSON body log any failure and then re-raise it, so the
    caller decides how to react. ``health_check`` and ``wait_for_server`` are
    the exception: they report failure through their boolean return value.

    The instance can be used as a context manager; leaving the ``with`` block
    closes the session.
    """

    def __init__(self, server_url: str, client_id: str, timeout: int = 30):
        """Store the connection settings and open the HTTP session.

        Args:
            server_url: Base URL of the server. Trailing slashes are stripped
                so endpoint paths can be appended directly.
            client_id: Identifier included in the request payloads.
            timeout: Timeout (seconds) passed to every request except the
                health probe, which has its own.
        """
        self.server_url = server_url.rstrip('/')
        self.client_id = client_id
        self.timeout = timeout
        self.session = requests.Session()

    def __enter__(self) -> "FederatedHTTPClient":
        """Return the client itself so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the session; any in-flight exception keeps propagating."""
        self.close()

    def _post_json(self, endpoint: str, payload: Dict[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``endpoint`` and return the decoded body."""
        response = self.session.post(
            f"{self.server_url}{endpoint}",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def _get_json(self, endpoint: str) -> Any:
        """GET ``endpoint`` and return the decoded JSON body."""
        response = self.session.get(
            f"{self.server_url}{endpoint}",
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def register(self, client_info: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Register this client with the server.

        Args:
            client_info: Optional extra data sent under ``'client_info'``;
                an empty dict is sent when omitted or falsy.

        Returns:
            The decoded JSON body of the server's response.

        Raises:
            requests.exceptions.RequestException: On connection problems,
                timeouts or an error HTTP status.
            ValueError: If the response body is not valid JSON.
        """
        payload = {
            'client_id': self.client_id,
            'client_info': client_info or {},
        }
        try:
            result = self._post_json('/register', payload)
        # ValueError covers undecodable JSON on requests versions whose JSON
        # error is not a RequestException subclass.
        except (requests.exceptions.RequestException, ValueError) as e:
            logger.error("Failed to register client %s: %s", self.client_id, e)
            raise

        logger.info("Client %s registered successfully", self.client_id)
        return result

    def get_global_model(self) -> Dict[str, Any]:
        """Fetch the current global model from the server.

        The request is a POST carrying this client's id.

        Returns:
            The decoded JSON body of the server's response.

        Raises:
            requests.exceptions.RequestException: On connection problems,
                timeouts or an error HTTP status.
            ValueError: If the response body is not valid JSON.
        """
        try:
            result = self._post_json('/get_model', {'client_id': self.client_id})
        except (requests.exceptions.RequestException, ValueError) as e:
            logger.error("Failed to get global model: %s", e)
            raise

        # Guard the lookup so a non-dict body cannot make a mere debug log
        # line fail the whole call.
        round_label = result.get('round', 'unknown') if isinstance(result, dict) else 'unknown'
        logger.debug("Retrieved global model for round %s", round_label)
        return result

    def submit_model_update(
        self,
        model_weights: List,
        metrics: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Send locally trained weights (and optional metrics) to the server.

        Args:
            model_weights: Weights to upload. They are sent via ``json=``, so
                they must be JSON-serializable.
            metrics: Optional metrics sent under ``'metrics'``; an empty dict
                is sent when omitted or falsy.

        Returns:
            The decoded JSON body of the server's response.

        Raises:
            requests.exceptions.RequestException: On connection problems,
                timeouts or an error HTTP status.
            ValueError: If the response body is not valid JSON.
        """
        payload = {
            'client_id': self.client_id,
            'model_weights': model_weights,
            'metrics': metrics or {},
        }
        try:
            result = self._post_json('/submit_update', payload)
        except (requests.exceptions.RequestException, ValueError) as e:
            logger.error("Failed to submit model update: %s", e)
            raise

        logger.info("Model update submitted successfully by client %s", self.client_id)
        return result

    def get_training_status(self) -> Dict[str, Any]:
        """Fetch the current training status from the server.

        Returns:
            The decoded JSON body of the server's response.

        Raises:
            requests.exceptions.RequestException: On connection problems,
                timeouts or an error HTTP status.
            ValueError: If the response body is not valid JSON.
        """
        try:
            return self._get_json('/training_status')
        except (requests.exceptions.RequestException, ValueError) as e:
            logger.error("Failed to get training status: %s", e)
            raise

    def health_check(self, timeout: float = 5) -> bool:
        """Report whether the server says it is healthy.

        Args:
            timeout: Timeout (seconds) for the probe; kept separate from
                ``self.timeout`` so a probe stays short.

        Returns:
            ``True`` only when ``/health`` answers with a JSON object whose
            ``'status'`` is ``'healthy'``. Request failures, error statuses,
            non-JSON bodies and non-object bodies all yield ``False``.
        """
        try:
            response = self.session.get(
                f"{self.server_url}/health",
                timeout=timeout,
            )
            response.raise_for_status()
        except requests.exceptions.RequestException:
            return False

        try:
            body = response.json()
        except ValueError:
            # A body that is not JSON means "not healthy", not a crash.
            return False

        return isinstance(body, dict) and body.get('status') == 'healthy'

    def wait_for_server(self, max_wait: int = 60, check_interval: int = 5) -> bool:
        """Poll the health endpoint until the server is up or time runs out.

        The server is always probed at least once, and once more when the
        deadline is reached. Sleeps are capped so the wait never runs past
        ``max_wait`` (apart from the duration of the final probe itself).

        Args:
            max_wait: Maximum time to wait, in seconds.
            check_interval: Pause between probes, in seconds.

        Returns:
            ``True`` as soon as a probe succeeds, ``False`` on timeout.
        """
        # A monotonic clock is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + max_wait

        while True:
            if self.health_check():
                logger.info("Server is available at %s", self.server_url)
                return True

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break

            logger.info("Waiting for server at %s...", self.server_url)
            time.sleep(min(check_interval, remaining))

        logger.error("Server not available after %s seconds", max_wait)
        return False

    def rag_query(self, query: str) -> Dict[str, Any]:
        """Submit a RAG query to the server.

        Args:
            query: Query text, sent under ``'query'`` with this client's id.

        Returns:
            The decoded JSON body of the server's response.

        Raises:
            requests.exceptions.RequestException: On connection problems,
                timeouts or an error HTTP status.
            ValueError: If the response body is not valid JSON.
        """
        payload = {
            'query': query,
            'client_id': self.client_id,
        }
        try:
            return self._post_json('/rag/query', payload)
        except (requests.exceptions.RequestException, ValueError) as e:
            logger.error("Failed to submit RAG query: %s", e)
            raise

    def close(self):
        """Close the HTTP session."""
        self.session.close()
|
|