Spaces:
Sleeping
Sleeping
File size: 2,887 Bytes
754afec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
"""coordinator.py module."""
import tensorflow as tf
from typing import List, Dict
import numpy as np
from collections import defaultdict
import logging
import time
class FederatedCoordinator:
    """Server-side coordinator for a federated learning run.

    Keeps a registry of participating clients, combines their model
    updates with a size-weighted average (FedAvg), and drives the
    round counter.
    """

    def __init__(self, config: Dict):
        """Initialize the federated learning coordinator.

        Args:
            config: Nested settings mapping. ``min_clients`` (default 2)
                and ``rounds`` (default 10) are read from
                ``config['server']['federated']``; missing or empty
                sections fall back to the defaults.
        """
        self.config = config
        # client_id -> {'size', 'weights', 'metrics'}; see register_client.
        self.clients = {}
        self.current_round = 0
        # ``or {}`` tolerates sections that exist but are empty/None
        # (e.g. a bare ``server:`` key in a YAML file).
        federated_cfg = (config.get('server') or {}).get('federated') or {}
        self.min_clients = federated_cfg.get('min_clients', 2)
        self.rounds = federated_cfg.get('rounds', 10)

    def register_client(self, client_id: int, client_size: int):
        """Register a new client.

        Registering an id that is already present replaces its record,
        which resets the stored weights and metrics.

        Args:
            client_id: Key under which the client is stored.
            client_size: Relative weight of this client during
                aggregation (presumably its local sample count --
                confirm against the caller).
        """
        self.clients[client_id] = {
            'size': client_size,
            'weights': None,
            'metrics': defaultdict(list),
        }

    def aggregate_weights(self, client_updates: List[Dict]) -> List:
        """Aggregate weights using the FedAvg algorithm.

        Each layer of the result is the sum of the clients' layers, each
        scaled by ``client_size / total_size`` where the sizes are the
        values given to ``register_client``.

        Args:
            client_updates: One dict per client with the keys
                ``'client_id'`` (a registered id) and ``'weights'``
                (a sequence of per-layer arrays).

        Returns:
            A list with one aggregated array per layer. Floating-point
            layers keep their dtype; integer and boolean layers are
            accumulated as float64.

        Raises:
            ValueError: If ``client_updates`` is empty, the combined
                client size is not positive, or an update has a
                different number of layers than the first update.
            KeyError: If an update refers to an unregistered client.
        """
        if not client_updates:
            raise ValueError(
                "client_updates must contain at least one update"
            )

        sizes = [
            self.clients[update['client_id']]['size']
            for update in client_updates
        ]
        total_size = sum(sizes)
        if total_size <= 0:
            raise ValueError(
                "combined size of the participating clients must be "
                f"positive, got {total_size}"
            )

        aggregated_weights = []
        for reference_layer in client_updates[0]['weights']:
            reference_layer = np.asarray(reference_layer)
            # An integer/bool accumulator would reject the in-place
            # addition of float-scaled values, so promote those to float.
            if reference_layer.dtype.kind in 'biu':
                acc_dtype = np.float64
            else:
                acc_dtype = reference_layer.dtype
            aggregated_weights.append(
                np.zeros(reference_layer.shape, dtype=acc_dtype)
            )

        for update, client_size in zip(client_updates, sizes):
            layers = update['weights']
            # A shorter update would otherwise be averaged into only a
            # prefix of the layers without any error.
            if len(layers) != len(aggregated_weights):
                raise ValueError(
                    f"client {update['client_id']!r} sent {len(layers)} "
                    f"layers, expected {len(aggregated_weights)}"
                )
            weight = client_size / total_size
            for i, layer_weights in enumerate(layers):
                aggregated_weights[i] += np.asarray(layer_weights) * weight
        return aggregated_weights

    def start(self, poll_interval: float = 5.0):
        """Start the federated learning process.

        Logs the server configuration, then advances ``current_round``
        until it reaches ``rounds``. While fewer than ``min_clients``
        clients are registered the loop sleeps and re-checks, so it
        blocks until clients are registered from elsewhere (e.g. another
        thread).

        Args:
            poll_interval: Seconds to sleep between checks while waiting
                for enough clients.
        """
        logger = logging.getLogger(__name__)

        # Print server startup information
        logger.info("\n" + "=" * 60)
        logger.info(f"{'Federated Learning Server Starting':^60}")
        logger.info("=" * 60)

        # Print configuration details
        logger.info("\nServer Configuration:")
        logger.info("-" * 30)
        logger.info("Minimum clients required: %s", self.min_clients)
        logger.info("Total rounds planned: %s", self.rounds)
        logger.info("Current active clients: %s", len(self.clients))
        logger.info("-" * 30 + "\n")

        while self.current_round < self.rounds:
            round_num = self.current_round + 1
            logger.info("\nRound %s/%s", round_num, self.rounds)
            logger.info("-" * 30)

            if len(self.clients) < self.min_clients:
                logger.warning(
                    "Waiting for clients... (active: %s/%s)",
                    len(self.clients),
                    self.min_clients,
                )
                time.sleep(poll_interval)
                # Retry the same round; the counter is not advanced.
                continue

            logger.info("Active clients: %s", list(self.clients))
            logger.info("Starting training round %s", round_num)
            self.current_round += 1
|