# Extracted from a web file viewer: file size 5,291 bytes, revision fc5fa78.
"""
HTTP Client for Federated Learning
Handles communication with the federated server
"""
import requests
import json
import logging
import time
from typing import Dict, Any, Optional, List
# Module logger, named after this module so applications can tune its level
# and handlers through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(__name__)
class FederatedHTTPClient:
    """JSON-over-HTTP client for a federated-learning server.

    All requests share one ``requests.Session`` and are sent to endpoint paths
    appended to ``server_url``. Every public request method logs a failure and
    re-raises ``requests.exceptions.RequestException``, except
    :meth:`health_check`, which reports failure by returning ``False``.

    The client may be used as a context manager; the session is closed on exit.
    """

    def __init__(self, server_url: str, client_id: str, timeout: float = 30):
        """Create a client.

        Args:
            server_url: Base URL of the server. A trailing ``/`` is stripped so
                endpoint paths can be joined with a single ``/``.
            client_id: Identifier included in the JSON payload of every POST.
            timeout: Timeout in seconds passed to ``requests`` for regular
                calls (health checks use their own, shorter timeout).
        """
        self.server_url = server_url.rstrip('/')
        self.client_id = client_id
        self.timeout = timeout
        self.session = requests.Session()

    def __enter__(self) -> "FederatedHTTPClient":
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the session; exceptions raised in the block still propagate."""
        self.close()

    def _post_json(self, endpoint: str, payload: Dict[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``endpoint`` and return the decoded body.

        Raises:
            requests.exceptions.RequestException: on connection failures,
                timeouts, or an error HTTP status (via ``raise_for_status``).
        """
        response = self.session.post(
            f"{self.server_url}/{endpoint}",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def _get_json(self, endpoint: str, timeout: Optional[float] = None) -> Any:
        """GET ``endpoint`` and return the decoded JSON body.

        Args:
            endpoint: Path relative to ``server_url``.
            timeout: Overrides ``self.timeout`` for this call when not None.

        Raises:
            requests.exceptions.RequestException: on connection failures,
                timeouts, or an error HTTP status (via ``raise_for_status``).
        """
        response = self.session.get(
            f"{self.server_url}/{endpoint}",
            timeout=self.timeout if timeout is None else timeout,
        )
        response.raise_for_status()
        return response.json()

    def register(
        self, client_info: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Register this client with the server.

        Args:
            client_info: Optional metadata sent with the client id; ``None``
                (or any falsy value) is sent as an empty dict.

        Returns:
            The server's decoded JSON response.

        Raises:
            requests.exceptions.RequestException: if the request fails.
        """
        payload = {
            'client_id': self.client_id,
            'client_info': client_info or {},
        }
        try:
            result = self._post_json("register", payload)
        except requests.exceptions.RequestException as e:
            logger.error("Failed to register client %s: %s", self.client_id, e)
            raise
        logger.info("Client %s registered successfully", self.client_id)
        return result

    def get_global_model(self) -> Dict[str, Any]:
        """Fetch the current global model from the server.

        Returns:
            The server's decoded JSON response; its ``'round'`` key, if
            present, is logged at debug level.

        Raises:
            requests.exceptions.RequestException: if the request fails.
        """
        try:
            result = self._post_json("get_model", {'client_id': self.client_id})
        except requests.exceptions.RequestException as e:
            logger.error("Failed to get global model: %s", e)
            raise
        logger.debug(
            "Retrieved global model for round %s", result.get('round', 'unknown')
        )
        return result

    def submit_model_update(
        self, model_weights: List, metrics: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Send locally trained weights (and optional metrics) to the server.

        Args:
            model_weights: JSON-serializable weights; sent as-is.
                NOTE(review): presumably nested lists of floats — confirm
                against the server's /submit_update handler.
            metrics: Optional training metrics; ``None`` is sent as ``{}``.

        Returns:
            The server's decoded JSON response.

        Raises:
            requests.exceptions.RequestException: if the request fails.
        """
        payload = {
            'client_id': self.client_id,
            'model_weights': model_weights,
            'metrics': metrics or {},
        }
        try:
            result = self._post_json("submit_update", payload)
        except requests.exceptions.RequestException as e:
            logger.error("Failed to submit model update: %s", e)
            raise
        logger.info("Model update submitted successfully by client %s", self.client_id)
        return result

    def get_training_status(self) -> Dict[str, Any]:
        """Return the server's decoded ``/training_status`` JSON response.

        Raises:
            requests.exceptions.RequestException: if the request fails.
        """
        try:
            return self._get_json("training_status")
        except requests.exceptions.RequestException as e:
            logger.error("Failed to get training status: %s", e)
            raise

    def health_check(self, timeout: float = 5) -> bool:
        """Return True if the server's ``/health`` endpoint reports healthy.

        Connection, HTTP-status and JSON-decoding failures, as well as a body
        that is not a JSON object, are reported as False instead of raising.

        Args:
            timeout: Seconds to wait for the health endpoint. Kept short by
                default so polling in :meth:`wait_for_server` stays responsive.
        """
        try:
            result = self._get_json("health", timeout=timeout)
        except (requests.exceptions.RequestException, ValueError):
            # ValueError: older ``requests`` versions surface a JSON decode
            # error as a plain ValueError subclass, not a RequestException.
            return False
        # A bare string/list body has no .get(); treat it as unhealthy.
        return isinstance(result, dict) and result.get('status') == 'healthy'

    def wait_for_server(self, max_wait: float = 60, check_interval: float = 5) -> bool:
        """Poll :meth:`health_check` until the server is healthy or time runs out.

        A monotonic clock is used so wall-clock changes cannot shorten or
        extend the wait, and the last sleep is capped at the time remaining.
        The total may still exceed ``max_wait`` by up to one health-check
        timeout.

        Args:
            max_wait: Maximum number of seconds to keep polling.
            check_interval: Seconds to sleep between failed checks.

        Returns:
            True as soon as a health check succeeds, otherwise False.
        """
        deadline = time.monotonic() + max_wait
        while time.monotonic() < deadline:
            if self.health_check():
                logger.info("Server is available at %s", self.server_url)
                return True
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break  # time.sleep() rejects negative durations
            logger.info("Waiting for server at %s...", self.server_url)
            time.sleep(min(check_interval, remaining))
        logger.error("Server not available after %s seconds", max_wait)
        return False

    def rag_query(self, query: str) -> Dict[str, Any]:
        """Send a RAG query to the server's ``/rag/query`` endpoint.

        Returns:
            The server's decoded JSON response.

        Raises:
            requests.exceptions.RequestException: if the request fails.
        """
        payload = {
            'query': query,
            'client_id': self.client_id,
        }
        try:
            return self._post_json("rag/query", payload)
        except requests.exceptions.RequestException as e:
            logger.error("Failed to submit RAG query: %s", e)
            raise

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self.session.close()