|
"""coordinator.py module.""" |
|
|
|
import tensorflow as tf |
|
from typing import List, Dict |
|
import numpy as np |
|
from collections import defaultdict |
|
import logging |
|
import time |
|
|
|
class FederatedCoordinator:
    """Server-side coordinator for a federated-learning run.

    Keeps a registry of participating clients, combines their model weights
    with a sample-size-weighted average (FedAvg), and drives the round loop.

    Attributes:
        config: The configuration mapping passed to the constructor.
        clients: Registry mapping ``client_id`` to a dict with the keys
            ``'size'``, ``'weights'`` and ``'metrics'``.
        current_round: Number of rounds completed so far.
        min_clients: Minimum number of registered clients needed to run a
            round.
        rounds: Total number of rounds to run.
    """

    # Seconds to pause before re-checking the registry when fewer than
    # ``min_clients`` clients are registered.
    _CLIENT_WAIT_SECONDS = 5

    def __init__(self, config: Dict):
        """Initialize the federated learning coordinator.

        Args:
            config: Nested configuration mapping. ``min_clients`` and
                ``rounds`` are read from ``config['server']['federated']``
                and default to 2 and 10 when any level is missing.
        """
        federated_cfg = config.get('server', {}).get('federated', {})

        self.config = config
        self.clients = {}
        self.current_round = 0
        self.min_clients = federated_cfg.get('min_clients', 2)
        self.rounds = federated_cfg.get('rounds', 10)

    def register_client(self, client_id: int, client_size: int):
        """Register a new client.

        Registering an id that already exists replaces its entry, which
        resets the stored weights and metrics for that client.

        Args:
            client_id: Key under which the client is stored.
            client_size: Size recorded for the client; used as its weight
                in ``aggregate_weights``.
        """
        self.clients[client_id] = {
            'size': client_size,
            'weights': None,
            'metrics': defaultdict(list),
        }

    def aggregate_weights(self, client_updates: List[Dict]) -> List:
        """Aggregate weights using FedAvg algorithm.

        Each client's layers are scaled by ``client_size / total_size``
        (sizes taken from the registry) and summed layer by layer.

        Args:
            client_updates: One dict per client with the keys
                ``'client_id'`` (a registered id) and ``'weights'`` (a
                sequence of per-layer arrays or array-likes). Every update
                must provide the same number of layers.

        Returns:
            A list with one ``numpy.ndarray`` per layer. Floating-point
            (and complex) layers keep the dtype of the first update's
            layer; integer and boolean layers are averaged in float64.

        Raises:
            ValueError: If ``client_updates`` is empty, if the registered
                sizes of the participating clients sum to zero, or if an
                update's layer count differs from the first update's.
            KeyError: If an update references an unregistered client id.
        """
        if not client_updates:
            raise ValueError(
                "client_updates must contain at least one client update"
            )

        sizes = [
            self.clients[update['client_id']]['size']
            for update in client_updates
        ]
        total_size = sum(sizes)
        if total_size == 0:
            raise ValueError(
                "cannot aggregate: participating clients have a total size of 0"
            )

        aggregated_weights = []
        for layer in client_updates[0]['weights']:
            template = np.asarray(layer)
            # An integer accumulator cannot take a fractional in-place add
            # (NumPy refuses the float -> int cast), so non-float layers are
            # accumulated in float64. Float layers keep their own dtype.
            if np.issubdtype(template.dtype, np.inexact):
                acc_dtype = template.dtype
            else:
                acc_dtype = np.float64
            aggregated_weights.append(np.zeros_like(template, dtype=acc_dtype))

        for update, client_size in zip(client_updates, sizes):
            layers = list(update['weights'])
            if len(layers) != len(aggregated_weights):
                # Without this check a shorter update would silently
                # contribute to only some of the layers.
                raise ValueError(
                    f"client {update['client_id']!r} sent {len(layers)} "
                    f"layers, expected {len(aggregated_weights)}"
                )

            weight = client_size / total_size
            for accumulator, layer_weights in zip(aggregated_weights, layers):
                # np.asarray lets plain (nested) lists be scaled like arrays.
                accumulator += np.asarray(layer_weights) * weight

        return aggregated_weights

    def start(self):
        """Start the federated learning process.

        Logs a start-up banner and the server configuration, then loops
        until ``current_round`` reaches ``rounds``. While fewer than
        ``min_clients`` clients are registered, the loop logs a warning,
        sleeps and re-checks without advancing the round; nothing inside
        the loop registers clients, so they have to be added from elsewhere
        for the wait to end. Otherwise the round is logged and
        ``current_round`` is incremented.
        """
        logger = logging.getLogger(__name__)
        banner = "=" * 60
        rule = "-" * 30

        logger.info("\n%s", banner)
        logger.info(f"{'Federated Learning Server Starting':^60}")
        logger.info(banner)

        logger.info("\nServer Configuration:")
        logger.info(rule)
        logger.info("Minimum clients required: %s", self.min_clients)
        logger.info("Total rounds planned: %s", self.rounds)
        logger.info("Current active clients: %s", len(self.clients))
        logger.info("%s\n", rule)

        while self.current_round < self.rounds:
            round_num = self.current_round + 1
            logger.info("\nRound %s/%s", round_num, self.rounds)
            logger.info(rule)

            active = len(self.clients)
            if active < self.min_clients:
                logger.warning(
                    "Waiting for clients... (active: %s/%s)",
                    active,
                    self.min_clients,
                )
                time.sleep(self._CLIENT_WAIT_SECONDS)
                continue

            logger.info("Active clients: %s", list(self.clients))
            logger.info("Starting training round %s", round_num)
            self.current_round += 1
|
|