“Transcendental-Programmer”
FEAT: added server coordination and model aggregation logic
754afec
raw
history blame
2.89 kB
"""coordinator.py module."""
import tensorflow as tf
from typing import List, Dict
import numpy as np
from collections import defaultdict
import logging
import time
class FederatedCoordinator:
    """Server-side coordinator for federated learning.

    Keeps a registry of clients (``self.clients``), tracks round progress
    (``self.current_round`` out of ``self.rounds``), and merges client model
    weights with FedAvg in :meth:`aggregate_weights`.
    """

    def __init__(self, config: Dict):
        """Initialize the federated learning coordinator.

        Args:
            config: Configuration mapping. ``config['server']['federated']``
                may provide ``'min_clients'`` (default 2) and ``'rounds'``
                (default 10). Missing or null sections fall back to defaults.
        """
        self.config = config
        # client_id -> {'size': ..., 'weights': ..., 'metrics': defaultdict(list)}
        self.clients: Dict[int, Dict] = {}
        self.current_round = 0
        # `or {}` (rather than a .get default) also covers sections that are
        # present but set to None, which would otherwise raise AttributeError.
        server_cfg = (config or {}).get('server') or {}
        federated_cfg = server_cfg.get('federated') or {}
        self.min_clients = federated_cfg.get('min_clients', 2)
        self.rounds = federated_cfg.get('rounds', 10)

    def register_client(self, client_id: int, client_size: int):
        """Register a new client, or reset the state of an existing one.

        Re-registering an existing ``client_id`` overwrites its stored size and
        discards its previous weights and metrics.

        Args:
            client_id: Key under which the client is stored in ``self.clients``.
            client_size: Relative weight of this client in
                :meth:`aggregate_weights` (presumably its local sample count —
                confirm with callers).

        Raises:
            ValueError: If ``client_size`` is negative.
        """
        if client_size < 0:
            raise ValueError(
                f"client_size must be non-negative, got {client_size}"
            )
        self.clients[client_id] = {
            'size': client_size,
            'weights': None,
            'metrics': defaultdict(list)
        }

    def aggregate_weights(self, client_updates: List[Dict]) -> List:
        """Aggregate weights using the FedAvg algorithm (size-weighted mean).

        Each update's layers are weighted by the ``size`` registered for its
        ``client_id`` divided by the total size of all updating clients.

        Args:
            client_updates: Dicts with a ``'client_id'`` key (a registered
                client) and a ``'weights'`` key holding a list of per-layer
                arrays. All updates must have the same number of layers, with
                matching shapes.

        Returns:
            List of aggregated per-layer numpy arrays. Floating-point layers keep
            the dtype of the first update's layer; other dtypes (e.g. integer)
            are accumulated as float64 so the weighted sum is representable.

        Raises:
            ValueError: If ``client_updates`` is empty, the clients' sizes sum to
                zero, or the updates disagree on layer count or layer shape.
            KeyError: If an update references an unregistered client.
        """
        if not client_updates:
            raise ValueError("client_updates is empty; nothing to aggregate")

        # Look each client's size up once; reused for the total and the weights.
        sizes = [self.clients[update['client_id']]['size']
                 for update in client_updates]
        total_size = sum(sizes)
        if total_size <= 0:
            raise ValueError(
                "cannot aggregate: total size of updating clients is "
                f"{total_size}"
            )

        # An integer accumulator would reject `+= layer * weight` (float
        # result) under numpy's same-kind casting, so only float/complex
        # dtypes are preserved.
        aggregated_weights = []
        for w in client_updates[0]['weights']:
            w = np.asarray(w)
            acc_dtype = w.dtype if np.issubdtype(w.dtype, np.inexact) else np.float64
            aggregated_weights.append(np.zeros(w.shape, dtype=acc_dtype))

        for update, client_size in zip(client_updates, sizes):
            layers = update['weights']
            if len(layers) != len(aggregated_weights):
                raise ValueError(
                    f"client {update['client_id']!r} sent {len(layers)} layers, "
                    f"expected {len(aggregated_weights)}"
                )
            weight = client_size / total_size
            for i, (acc, layer) in enumerate(zip(aggregated_weights, layers)):
                layer = np.asarray(layer)
                # Reject shape mismatches instead of letting numpy broadcast.
                if layer.shape != acc.shape:
                    raise ValueError(
                        f"client {update['client_id']!r} layer {i} has shape "
                        f"{layer.shape}, expected {acc.shape}"
                    )
                acc += layer * weight
        return aggregated_weights

    def start(self, poll_interval: float = 5):
        """Start the federated learning process.

        Logs the server configuration, then runs until ``self.current_round``
        reaches ``self.rounds``. Before each round it waits, sleeping
        ``poll_interval`` seconds between checks, until at least
        ``self.min_clients`` clients are registered. It then logs the round and
        advances ``self.current_round``.

        Args:
            poll_interval: Seconds to sleep between client-count checks while
                waiting for enough clients. Defaults to 5.
        """
        logger = logging.getLogger(__name__)
        # Print server startup information
        logger.info("\n" + "=" * 60)
        logger.info(f"{'Federated Learning Server Starting':^60}")
        logger.info("=" * 60)
        # Print configuration details
        logger.info("\nServer Configuration:")
        logger.info("-" * 30)
        logger.info(f"Minimum clients required: {self.min_clients}")
        logger.info(f"Total rounds planned: {self.rounds}")
        logger.info(f"Current active clients: {len(self.clients)}")
        logger.info("-" * 30 + "\n")
        while self.current_round < self.rounds:
            # Check before printing the round header so it is logged once per
            # round rather than on every wait iteration.
            # NOTE(review): nothing in this loop registers clients, so the wait
            # only ends if register_client is called concurrently (e.g. from
            # another thread) — confirm how clients join.
            if len(self.clients) < self.min_clients:
                logger.warning(
                    f"Waiting for clients... "
                    f"(active: {len(self.clients)}/{self.min_clients})"
                )
                time.sleep(poll_interval)
                continue
            round_num = self.current_round + 1
            logger.info(f"\nRound {round_num}/{self.rounds}")
            logger.info("-" * 30)
            logger.info(f"Active clients: {list(self.clients.keys())}")
            logger.info(f"Starting training round {round_num}")
            self.current_round += 1