File size: 5,291 Bytes
fc5fa78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
HTTP Client for Federated Learning
Handles communication with the federated server
"""

import requests
import json
import logging
import time
from typing import Dict, Any, Optional, List

# Module-level logger named after this module; this file attaches no handlers
# or levels itself, so output depends on the application's logging setup.
logger = logging.getLogger(__name__)

class FederatedHTTPClient:
    """HTTP client for talking to a federated-learning server.

    All requests go through one ``requests.Session`` so connections can be
    reused.  Apart from ``health_check`` (which reports failures as ``False``)
    and ``wait_for_server``, every public call logs a failure and re-raises it
    as ``requests.exceptions.RequestException``.

    The client can be used as a context manager; leaving the ``with`` block
    closes the underlying session.
    """

    def __init__(self, server_url: str, client_id: str, timeout: int = 30):
        """Create a client.

        Args:
            server_url: Base URL of the server. Trailing slashes are stripped
                so endpoint paths can be joined with a single ``/``.
            client_id: Identifier included in the payload of client requests.
            timeout: Timeout in seconds passed to ``requests`` for regular
                calls (``health_check`` uses its own, shorter timeout).
        """
        self.server_url = server_url.rstrip('/')
        self.client_id = client_id
        self.timeout = timeout
        self.session = requests.Session()

    def __enter__(self) -> "FederatedHTTPClient":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Release pooled connections even if the with-body raised; returning
        # None lets any in-flight exception propagate unchanged.
        self.close()

    # -- internal helpers ---------------------------------------------------

    def _post_json(self, path: str, payload: Dict[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``<server_url>/<path>``; return the decoded body.

        Raises:
            requests.exceptions.RequestException: on connection/timeout errors
                or a non-2xx status (via ``raise_for_status``).
        """
        response = self.session.post(
            f"{self.server_url}/{path}",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def _get_json(self, path: str, timeout: Optional[float] = None) -> Any:
        """GET ``<server_url>/<path>``; return the decoded JSON body.

        ``timeout`` overrides ``self.timeout`` when it is not ``None``.

        Raises:
            requests.exceptions.RequestException: on connection/timeout errors
                or a non-2xx status (via ``raise_for_status``).
        """
        response = self.session.get(
            f"{self.server_url}/{path}",
            timeout=self.timeout if timeout is None else timeout,
        )
        response.raise_for_status()
        return response.json()

    # -- public API ---------------------------------------------------------

    def register(self, client_info: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Register this client with the server via ``POST /register``.

        Args:
            client_info: Optional extra metadata; sent as ``{}`` when omitted.

        Returns:
            The decoded JSON response body.

        Raises:
            requests.exceptions.RequestException: if the request fails.
        """
        payload = {
            'client_id': self.client_id,
            'client_info': client_info or {},
        }
        try:
            result = self._post_json('register', payload)
        except requests.exceptions.RequestException as e:
            logger.error("Failed to register client %s: %s", self.client_id, e)
            raise
        logger.info("Client %s registered successfully", self.client_id)
        return result

    def get_global_model(self) -> Dict[str, Any]:
        """Fetch the current global model via ``POST /get_model``.

        Returns:
            The decoded JSON response body.

        Raises:
            requests.exceptions.RequestException: if the request fails.
        """
        try:
            result = self._post_json('get_model', {'client_id': self.client_id})
        except requests.exceptions.RequestException as e:
            logger.error("Failed to get global model: %s", e)
            raise
        # Only used for logging; don't let an unexpected body shape crash the call.
        round_no = result.get('round', 'unknown') if isinstance(result, dict) else 'unknown'
        logger.debug("Retrieved global model for round %s", round_no)
        return result

    def submit_model_update(self, model_weights: List,
                            metrics: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Send a local model update via ``POST /submit_update``.

        Args:
            model_weights: Weights to upload; must be JSON-serializable.
            metrics: Optional training metrics; sent as ``{}`` when omitted.

        Returns:
            The decoded JSON response body.

        Raises:
            requests.exceptions.RequestException: if the request fails.
        """
        payload = {
            'client_id': self.client_id,
            'model_weights': model_weights,
            'metrics': metrics or {},
        }
        try:
            result = self._post_json('submit_update', payload)
        except requests.exceptions.RequestException as e:
            logger.error("Failed to submit model update: %s", e)
            raise
        logger.info("Model update submitted successfully by client %s", self.client_id)
        return result

    def get_training_status(self) -> Dict[str, Any]:
        """Fetch the server's training status via ``GET /training_status``.

        Raises:
            requests.exceptions.RequestException: if the request fails.
        """
        try:
            return self._get_json('training_status')
        except requests.exceptions.RequestException as e:
            logger.error("Failed to get training status: %s", e)
            raise

    def health_check(self, timeout: float = 5) -> bool:
        """Return True iff ``GET /health`` answers with ``{"status": "healthy"}``.

        Never raises for request or decoding problems: connection errors,
        non-2xx statuses, non-JSON bodies and non-object JSON all yield False.

        Args:
            timeout: Request timeout in seconds; kept short so polling stays
                responsive.
        """
        try:
            result = self._get_json('health', timeout=timeout)
        except (requests.exceptions.RequestException, ValueError):
            # ValueError covers JSON decode errors on requests versions where
            # they are not RequestException subclasses.
            return False
        return isinstance(result, dict) and result.get('status') == 'healthy'

    def wait_for_server(self, max_wait: int = 60, check_interval: int = 5) -> bool:
        """Poll ``health_check`` until the server is healthy or ``max_wait`` elapses.

        The server is always checked at least once, and once more when the
        deadline is reached, so a server that comes up during the last sleep
        is not missed. Sleeps never extend past the deadline, although an
        individual ``health_check`` call may itself take up to its timeout.

        Args:
            max_wait: Total time budget in seconds.
            check_interval: Seconds to wait between checks.

        Returns:
            True if the server became healthy within the budget, else False.
        """
        # monotonic() is immune to wall-clock adjustments, unlike time().
        deadline = time.monotonic() + max_wait
        while True:
            if self.health_check():
                logger.info("Server is available at %s", self.server_url)
                return True
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            logger.info("Waiting for server at %s...", self.server_url)
            time.sleep(min(check_interval, remaining))

        logger.error("Server not available after %s seconds", max_wait)
        return False

    def rag_query(self, query: str) -> Dict[str, Any]:
        """Submit a RAG query via ``POST /rag/query``.

        Returns:
            The decoded JSON response body.

        Raises:
            requests.exceptions.RequestException: if the request fails.
        """
        payload = {
            'query': query,
            'client_id': self.client_id,
        }
        try:
            return self._post_json('rag/query', payload)
        except requests.exceptions.RequestException as e:
            logger.error("Failed to submit RAG query: %s", e)
            raise

    def close(self):
        """Close the underlying HTTP session and its pooled connections."""
        self.session.close()