"""coordinator.py module."""
import tensorflow as tf
from typing import List, Dict, Any, Optional
import numpy as np
from collections import defaultdict
import logging
import time
import threading
from .aggregator import FederatedAggregator

logger = logging.getLogger(__name__)
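
# Usage sketch (illustrative only; the exact schema of `config` is defined by
# the caller, not by this module):
#
#   config = {
#       'server': {
#           'federated': {'min_clients': 2, 'rounds': 10},
#           'api': {'host': '0.0.0.0', 'port': 8080},
#       },
#       'aggregation': {...},  # consumed by FederatedAggregator
#   }
#   FederatedCoordinator(config).start()
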
class FederatedCoordinator:
def __init__(self, config: Dict):
"""Initialize the federated learning coordinator."""
        logger.debug(f"Initializing FederatedCoordinator with config: {config}")
self.config = config
self.clients = {}
self.client_updates = {} # Store updates for current round
self.global_model_weights = None
self.current_round = 0
self.training_active = False
self.min_clients = config.get('server', {}).get('federated', {}).get('min_clients', 2)
self.rounds = config.get('server', {}).get('federated', {}).get('rounds', 10)
# Robustly extract aggregation config
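        # FederatedAggregator presumably looks up agg_config['aggregation'] itself,
        # so pass whichever level of the config dict contains that key.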
agg_config = None
if 'aggregation' in config:
agg_config = config
elif 'server' in config and 'aggregation' in config['server']:
agg_config = config['server']
else:
logger.error(f"No 'aggregation' key found in config for FederatedAggregator: {config}")
raise ValueError("'aggregation' config section is required for FederatedAggregator")
logger.debug(f"Passing aggregation config to FederatedAggregator: {agg_config}")
try:
self.aggregator = FederatedAggregator(agg_config)
except Exception as e:
logger.error(f"Error initializing FederatedAggregator: {e}")
raise
self.lock = threading.Lock() # Thread safety for concurrent API calls
logger.info("FederatedCoordinator initialized.")
    def register_client(self, client_id: str, client_info: Optional[Dict[str, Any]] = None) -> bool:
"""Register a new client."""
with self.lock:
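            # Re-registering an existing client is treated as an idempotent no-op.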
if client_id in self.clients:
                logger.warning(f"Client {client_id} already registered")
return True
self.clients[client_id] = {
'info': client_info or {},
'last_seen': time.time(),
'metrics': defaultdict(list)
}
            logger.info(f"Client {client_id} registered successfully")
return True
def get_client_config(self) -> Dict[str, Any]:
"""Get configuration to send to clients"""
return {
'model_config': self.config.get('model', {}),
'training_config': self.config.get('training', {}),
'current_round': self.current_round,
'total_rounds': self.rounds
}
def get_global_model(self) -> Optional[List]:
"""Get the current global model weights"""
with self.lock:
return self.global_model_weights
def receive_model_update(self, client_id: str, model_weights: List, metrics: Dict[str, Any]):
"""Receive a model update from a client"""
with self.lock:
if client_id not in self.clients:
raise ValueError(f"Client {client_id} not registered")
self.client_updates[client_id] = {
'weights': model_weights,
'metrics': metrics,
'timestamp': time.time()
}
self.clients[client_id]['last_seen'] = time.time()
            logger.info(f"Received update from client {client_id}")
# Check if we have enough updates for aggregation
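            # Note: aggregation runs synchronously here while the lock is still
            # held, so a round is finalized atomically with the completing update.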
if len(self.client_updates) >= self.min_clients:
self._aggregate_models()
def _aggregate_models(self):
"""Aggregate models from all client updates"""
try:
            logger.info(f"Aggregating models from {len(self.client_updates)} clients")
# Prepare updates for aggregation
updates = []
for client_id, update in self.client_updates.items():
client_size = update['metrics'].get('dataset_size', 100) # Default size
updates.append({
'client_id': client_id,
'weights': update['weights'],
'size': client_size
})
# Aggregate using FedAvg
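            # (FedAvg computes, per layer, the dataset-size-weighted mean
            #   w_global = sum_k (n_k / n_total) * w_k,
            # with n_k taken from each update's 'size' field; the exact scheme
            # is defined by the aggregator, so this describes expected behavior)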
self.global_model_weights = self.aggregator.federated_averaging(updates)
# Clear updates for next round
self.client_updates.clear()
self.current_round += 1
logger.info(f"Model aggregation completed for round {self.current_round}")
        except Exception as e:
            logger.error(f"Error during model aggregation: {str(e)}")
def start(self):
"""Start the federated learning process with API server"""
logger = logging.getLogger(__name__)
# Print server startup information
logger.info("\n" + "=" * 60)
logger.info(f"{'Federated Learning Server Starting':^60}")
logger.info("=" * 60)
# Print configuration details
logger.info("\nServer Configuration:")
logger.info("-" * 30)
logger.info(f"Minimum clients required: {self.min_clients}")
logger.info(f"Total rounds planned: {self.rounds}")
active_clients_count = self._count_active_clients()
logger.info(f"Current active clients: {active_clients_count}")
logger.info("-" * 30 + "\n")
self.training_active = True
# Import and start API server
try:
from ..api.server import FederatedAPI
api_config = self.config.get('server', {}).get('api', {})
host = api_config.get('host', '0.0.0.0')
port = api_config.get('port', 8080)
api_server = FederatedAPI(self, host, port)
api_thread = api_server.run_threaded()
logger.info(f"API server started on {host}:{port}")
# Keep server running
try:
while self.training_active and self.current_round < self.rounds:
time.sleep(1) # Keep main thread alive
                    # Log progress once per second at debug level
active_clients_count = self._count_active_clients()
if active_clients_count > 0:
logger.debug(f"Round {self.current_round}/{self.rounds}, "
f"Active Clients: {active_clients_count}, "
f"Updates: {len(self.client_updates)}")
logger.info("Federated learning completed successfully")
except KeyboardInterrupt:
logger.info("Server shutdown requested")
self.training_active = False
except ImportError as e:
logger.error(f"Failed to start API server: {str(e)}")
# Fallback to original behavior
# ...existing code...