Transcendental-Programmer
feat: added the server coordinator and aggregator
45309a1
raw
history blame
4.27 kB
"""aggregator.py module."""
import tensorflow as tf
from typing import List, Dict
import numpy as np
from collections import defaultdict
import logging
class FederatedAggregator:
    """Server-side aggregation of client results for federated learning.

    Combines per-client model weights with federated averaging (FedAvg),
    aggregates per-client metrics, and checks whether the global weights
    have stopped changing between two rounds.

    Attributes:
        weighted: When true, clients contribute in proportion to the sample
            counts they report; when false, every client counts equally.
    """

    # One logger shared by all methods instead of re-fetching it per call.
    _logger = logging.getLogger(__name__)

    def __init__(self, config: Dict):
        """Read the aggregation settings from ``config``.

        The ``aggregation`` section is looked up at the top level first and
        then under ``config['server']``.

        Args:
            config: Configuration mapping. Only the ``weighted`` flag
                (default ``True``) of the aggregation section is read.

        Raises:
            KeyError: If no ``aggregation`` section exists in either place.
        """
        log = self._logger
        log.debug("Initializing FederatedAggregator with config: %s", config)

        if 'aggregation' in config:
            agg_config = config['aggregation']
        elif 'server' in config and 'aggregation' in (config['server'] or {}):
            agg_config = config['server']['aggregation']
        else:
            log.error(
                "No 'aggregation' key found in config passed to FederatedAggregator: %s",
                config,
            )
            raise KeyError("'aggregation' config section is required for FederatedAggregator")

        # A section that is present but empty is treated as "use defaults"
        # rather than failing on ``None.get``.
        if agg_config is None:
            agg_config = {}
        self.weighted = agg_config.get('weighted', True)
        log.info("FederatedAggregator initialized. Weighted: %s", self.weighted)

    @staticmethod
    def _zeros_for(layer) -> np.ndarray:
        """Return a zero accumulator shaped like ``layer`` with a float dtype.

        Floating/complex dtypes are kept as-is; anything else (ints, bools)
        is promoted to float64 so fractional contributions can be added
        in place without a casting error.
        """
        arr = np.asarray(layer)
        dtype = arr.dtype if np.issubdtype(arr.dtype, np.inexact) else np.float64
        return np.zeros(arr.shape, dtype=dtype)

    def federated_averaging(self, updates: List[Dict]) -> List:
        """Perform federated averaging (FedAvg) on model weights.

        Args:
            updates: One dict per client. ``'weights'`` is the client's list
                of per-layer arrays. ``'size'`` (sample count) is required
                only when ``self.weighted`` is true. ``'client_id'`` is
                optional and used for logging only.

        Returns:
            A list of arrays, one per layer, holding the averaged weights,
            or ``None`` when ``updates`` is empty or ``None``.

        Raises:
            ValueError: If a client's layer count or a layer's shape differs
                from the first client's.
        """
        log = self._logger
        num_clients = len(updates) if updates else 0
        log.info("Performing federated averaging on %d client updates", num_clients)
        if not updates:
            log.warning("No updates provided for federated averaging")
            return None

        # Sample counts only matter in weighted mode; a non-positive total
        # cannot be used as a divisor, so fall back to a plain mean.
        use_sizes = bool(self.weighted)
        total_samples = 0
        if use_sizes:
            total_samples = sum(update['size'] for update in updates)
            log.debug("Total samples across clients: %s", total_samples)
            if total_samples <= 0:
                log.warning(
                    "Total sample count is %s; falling back to an unweighted average",
                    total_samples,
                )
                use_sizes = False

        aggregated_weights = [self._zeros_for(w) for w in updates[0]['weights']]

        for update in updates:
            client_weights = list(update['weights'])
            client_id = update.get('client_id', 'unknown')
            if len(client_weights) != len(aggregated_weights):
                # zip() would silently drop the extra layers and leave the
                # result wrongly scaled, so reject the round instead.
                raise ValueError(
                    f"Client {client_id} sent {len(client_weights)} weight arrays, "
                    f"expected {len(aggregated_weights)}"
                )

            if use_sizes:
                weight_factor = float(update['size'] / total_samples)
            else:
                weight_factor = 1.0 / num_clients
            log.debug(
                "Client %s: size=%s, weight_factor=%s",
                client_id, update.get('size'), weight_factor,
            )

            # Add this client's scaled contribution layer by layer.
            for i, client_w in enumerate(client_weights):
                layer = np.asarray(client_w)
                if layer.shape != aggregated_weights[i].shape:
                    raise ValueError(
                        f"Client {client_id} layer {i} has shape {layer.shape}, "
                        f"expected {aggregated_weights[i].shape}"
                    )
                aggregated_weights[i] += layer * weight_factor

        log.info("Federated averaging completed successfully")
        return aggregated_weights

    def compute_metrics(self, client_metrics: List[Dict]) -> Dict:
        """Average per-client metrics into a single dict.

        Args:
            client_metrics: One dict per client. ``'metrics'`` maps metric
                names to numeric values. ``'num_samples'`` is required only
                when ``self.weighted`` is true.

        Returns:
            A dict mapping each metric name to its (sample-weighted or
            plain) mean across clients; empty when no metrics were given.
            A metric missing from some clients simply receives no
            contribution from them.
        """
        log = self._logger
        num_clients = len(client_metrics) if client_metrics else 0
        log.debug("Computing metrics for %d clients", num_clients)
        if not client_metrics:
            log.warning("No client metrics provided to compute_metrics.")
            return {}

        use_sizes = bool(self.weighted)
        total_samples = 0
        if use_sizes:
            total_samples = sum(metrics['num_samples'] for metrics in client_metrics)
            log.debug("Total samples across clients: %s", total_samples)
            if total_samples <= 0:
                log.warning(
                    "Total sample count is %s; falling back to an unweighted average",
                    total_samples,
                )
                use_sizes = False

        aggregated_metrics = defaultdict(float)
        for metrics in client_metrics:
            # Unweighted mode is a mean (1/N each), matching federated_averaging.
            if use_sizes:
                weight = metrics['num_samples'] / total_samples
            else:
                weight = 1.0 / num_clients
            log.debug("Client metrics: %s, weight: %s", metrics, weight)
            for metric_name, value in metrics['metrics'].items():
                aggregated_metrics[metric_name] += value * weight

        result = dict(aggregated_metrics)
        log.info("Aggregated metrics: %s", result)
        return result

    def check_convergence(self,
                          old_weights: List,
                          new_weights: List,
                          threshold: float = 1e-5) -> bool:
        """Report whether every layer changed by less than ``threshold``.

        The change of a layer is the mean absolute element-wise difference
        between its old and new values.

        Args:
            old_weights: Per-layer weights from the previous round.
            new_weights: Per-layer weights from the current round.
            threshold: Strict upper bound on each layer's mean absolute change.

        Returns:
            ``True`` when all layers are below ``threshold``; ``False`` when
            any layer is not, when either argument is ``None``, or when the
            two lists have a different number of layers.
        """
        log = self._logger
        log.debug("Checking convergence...")
        if old_weights is None or new_weights is None:
            log.warning("Old or new weights are None in check_convergence.")
            return False
        if len(old_weights) != len(new_weights):
            # zip() would compare only the common prefix and could report
            # convergence for structurally different models.
            log.warning(
                "Layer count mismatch in check_convergence: %d vs %d",
                len(old_weights), len(new_weights),
            )
            return False

        # np.asarray lets plain (nested) lists be compared as well as arrays.
        weight_differences = [
            np.mean(np.abs(np.asarray(old) - np.asarray(new)))
            for old, new in zip(old_weights, new_weights)
        ]
        log.debug("Weight differences: %s", weight_differences)
        converged = all(diff < threshold for diff in weight_differences)
        log.info("Convergence status: %s", converged)
        return converged