|
"""coordinator.py module.""" |
|
|
|
import tensorflow as tf |
|
from typing import List, Dict, Any, Optional |
|
import numpy as np |
|
from collections import defaultdict |
|
import logging |
|
import time |
|
import threading |
|
from .aggregator import FederatedAggregator |
|
|
|
class FederatedCoordinator:
    """Coordinates federated learning rounds across registered clients.

    Responsibilities:
      * track client registrations and last-seen timestamps,
      * collect per-round model updates from clients,
      * trigger weight aggregation (via ``FederatedAggregator``) once at
        least ``min_clients`` updates have arrived,
      * run the blocking server loop that hosts the HTTP API.

    All shared state is guarded by ``self.lock`` because updates arrive on
    API-server threads while ``start()`` polls from the main thread.
    """

    def __init__(self, config: Dict):
        """Initialize the federated learning coordinator.

        Args:
            config: Full server configuration mapping.  The 'aggregation'
                section may live either at the top level or nested under
                'server'; exactly one of those locations is required.

        Raises:
            ValueError: If no 'aggregation' section can be found.
        """
        logger = logging.getLogger(__name__)
        logger.debug("Initializing FederatedCoordinator with config: %s", config)

        self.config = config
        self.clients = {}                  # client_id -> {'info', 'last_seen', 'metrics'}
        self.client_updates = {}           # client_id -> pending update for the current round
        self.global_model_weights = None   # latest aggregated weights; None until round 1
        self.current_round = 0
        self.training_active = False

        federated_cfg = config.get('server', {}).get('federated', {})
        self.min_clients = federated_cfg.get('min_clients', 2)
        self.rounds = federated_cfg.get('rounds', 10)

        logger.debug("Coordinator received config: %s", config)

        # The aggregator wants the dict that directly contains the
        # 'aggregation' key; accept it at either supported nesting level.
        if 'aggregation' in config:
            agg_config = config
        elif 'aggregation' in config.get('server', {}):
            agg_config = config['server']
        else:
            logger.error("No 'aggregation' key found in config for FederatedAggregator: %s", config)
            raise ValueError("'aggregation' config section is required for FederatedAggregator")

        logger.debug("Passing aggregation config to FederatedAggregator: %s", agg_config)
        try:
            self.aggregator = FederatedAggregator(agg_config)
        except Exception as e:
            logger.error("Error initializing FederatedAggregator: %s", e)
            raise

        self.lock = threading.Lock()
        logger.info("FederatedCoordinator initialized.")

    def register_client(self, client_id: str, client_info: Optional[Dict[str, Any]] = None) -> bool:
        """Register a new client (idempotent).

        Args:
            client_id: Unique identifier supplied by the client.
            client_info: Optional metadata describing the client.

        Returns:
            True always — re-registering an existing client is a warning,
            not an error.
        """
        logger = logging.getLogger(__name__)
        with self.lock:
            if client_id in self.clients:
                logger.warning("Client %s already registered", client_id)
                return True

            self.clients[client_id] = {
                'info': client_info or {},
                'last_seen': time.time(),
                'metrics': defaultdict(list),
            }

            logger.info("Client %s registered successfully", client_id)
            return True

    def get_client_config(self) -> Dict[str, Any]:
        """Return the configuration payload handed to connecting clients."""
        return {
            'model_config': self.config.get('model', {}),
            'training_config': self.config.get('training', {}),
            'current_round': self.current_round,
            'total_rounds': self.rounds,
        }

    def get_global_model(self) -> Optional[List]:
        """Return the current global model weights (None before any round completes)."""
        with self.lock:
            return self.global_model_weights

    def receive_model_update(self, client_id: str, model_weights: List, metrics: Dict[str, Any]):
        """Accept a model update from a registered client.

        Stores the update, refreshes the client's last-seen timestamp, and
        — while still holding the lock, so no update can race in mid-way —
        aggregates as soon as ``min_clients`` updates are pending.

        Raises:
            ValueError: If ``client_id`` was never registered.
        """
        with self.lock:
            if client_id not in self.clients:
                raise ValueError(f"Client {client_id} not registered")

            self.client_updates[client_id] = {
                'weights': model_weights,
                'metrics': metrics,
                'timestamp': time.time(),
            }

            self.clients[client_id]['last_seen'] = time.time()

            logging.getLogger(__name__).info("Received update from client %s", client_id)

            # Aggregate once a quorum of updates has been collected.
            if len(self.client_updates) >= self.min_clients:
                self._aggregate_models()

    def _aggregate_models(self):
        """Aggregate pending client updates into new global weights.

        Caller must hold ``self.lock``.  Failures are logged (with full
        traceback) and swallowed so one bad round cannot kill the server
        loop; pending updates are left in place in that case.
        """
        logger = logging.getLogger(__name__)
        try:
            logger.info("Aggregating models from %d clients", len(self.client_updates))

            # Weight each client by its reported dataset size; fall back to
            # 100 when the client did not report one.
            updates = [
                {
                    'client_id': client_id,
                    'weights': update['weights'],
                    'size': update['metrics'].get('dataset_size', 100),
                }
                for client_id, update in self.client_updates.items()
            ]

            self.global_model_weights = self.aggregator.federated_averaging(updates)

            # Round is complete: discard consumed updates and advance.
            self.client_updates.clear()
            self.current_round += 1

            logger.info("Model aggregation completed for round %d", self.current_round)

        except Exception as e:
            # logger.exception preserves the traceback that the old
            # ``str(e)``-only message threw away.
            logger.exception("Error during model aggregation: %s", e)

    def start(self):
        """Start the federated learning process with the API server.

        Launches the HTTP API in a background thread, then blocks, polling
        once per second until the configured number of rounds completes or
        training is cancelled (KeyboardInterrupt / ``training_active``).
        Aggregation itself happens on the API threads via
        ``receive_model_update``.
        """
        logger = logging.getLogger(__name__)

        logger.info("\n" + "=" * 60)
        logger.info(f"{'Federated Learning Server Starting':^60}")
        logger.info("=" * 60)

        logger.info("\nServer Configuration:")
        logger.info("-" * 30)
        logger.info("Minimum clients required: %s", self.min_clients)
        logger.info("Total rounds planned: %s", self.rounds)
        # NOTE(review): _count_active_clients is not defined in this module
        # view — assumed to be provided elsewhere on this class; confirm.
        active_clients_count = self._count_active_clients()
        logger.info("Current active clients: %s", active_clients_count)
        logger.info("-" * 30 + "\n")

        self.training_active = True

        try:
            # Imported lazily so the coordinator itself does not require the
            # API server's dependencies at module-import time.
            from ..api.server import FederatedAPI

            api_config = self.config.get('server', {}).get('api', {})
            host = api_config.get('host', '0.0.0.0')
            port = api_config.get('port', 8080)

            api_server = FederatedAPI(self, host, port)
            api_thread = api_server.run_threaded()

            logger.info("API server started on %s:%s", host, port)

            try:
                while self.training_active and self.current_round < self.rounds:
                    time.sleep(1)

                    active_clients_count = self._count_active_clients()
                    if active_clients_count > 0:
                        logger.debug(
                            "Round %s/%s, Active Clients: %s, Updates: %s",
                            self.current_round,
                            self.rounds,
                            active_clients_count,
                            len(self.client_updates),
                        )

                logger.info("Federated learning completed successfully")

            except KeyboardInterrupt:
                logger.info("Server shutdown requested")
                self.training_active = False

        except ImportError as e:
            logger.error("Failed to start API server: %s", e)
|
|
|
|
|
|