Spaces:
Sleeping
Sleeping
"""aggregator.py module.""" | |
import tensorflow as tf | |
from typing import List, Dict | |
import numpy as np | |
from collections import defaultdict | |
import logging | |
class FederatedAggregator: | |
def __init__(self, config: Dict): | |
logger = logging.getLogger(__name__) | |
logger.debug(f"Initializing FederatedAggregator with config: {config}") | |
# Defensive: try to find aggregation config | |
agg_config = None | |
if 'aggregation' in config: | |
agg_config = config['aggregation'] | |
elif 'server' in config and 'aggregation' in config['server']: | |
agg_config = config['server']['aggregation'] | |
else: | |
logger.error(f"No 'aggregation' key found in config passed to FederatedAggregator: {config}") | |
raise KeyError("'aggregation' config section is required for FederatedAggregator") | |
self.weighted = agg_config.get('weighted', True) | |
logger.info(f"FederatedAggregator initialized. Weighted: {self.weighted}") | |
def federated_averaging(self, updates: List[Dict]) -> List: | |
"""Perform federated averaging (FedAvg) on model weights.""" | |
logger = logging.getLogger(__name__) | |
logger.info(f"Performing federated averaging on {len(updates)} client updates") | |
if not updates: | |
logger.warning("No updates provided for federated averaging") | |
return None | |
# Calculate total samples across all clients | |
total_samples = sum(update['size'] for update in updates) | |
logger.debug(f"Total samples across clients: {total_samples}") | |
# Initialize aggregated weights with zeros | |
first_weights = updates[0]['weights'] | |
aggregated_weights = [np.zeros_like(w) for w in first_weights] | |
# Weighted average of model weights | |
for update in updates: | |
client_weights = update['weights'] | |
client_size = update['size'] | |
weight_factor = client_size / total_samples if self.weighted else 1.0 / len(updates) | |
logger.debug(f"Client {update['client_id']}: size={client_size}, weight_factor={weight_factor}") | |
# Add weighted contribution to aggregated weights | |
for i, (agg_w, client_w) in enumerate(zip(aggregated_weights, client_weights)): | |
aggregated_weights[i] += np.array(client_w) * weight_factor | |
logger.info("Federated averaging completed successfully") | |
return aggregated_weights | |
def compute_metrics(self, client_metrics: List[Dict]) -> Dict: | |
logger = logging.getLogger(__name__) | |
logger.debug(f"Computing metrics for {len(client_metrics)} clients") | |
if not client_metrics: | |
logger.warning("No client metrics provided to compute_metrics.") | |
return {} | |
aggregated_metrics = defaultdict(float) | |
total_samples = sum(metrics['num_samples'] for metrics in client_metrics) | |
logger.debug(f"Total samples across clients: {total_samples}") | |
for metrics in client_metrics: | |
weight = metrics['num_samples'] / total_samples if self.weighted else 1.0 | |
logger.debug(f"Client metrics: {metrics}, weight: {weight}") | |
for metric_name, value in metrics['metrics'].items(): | |
aggregated_metrics[metric_name] += value * weight | |
logger.info(f"Aggregated metrics: {dict(aggregated_metrics)}") | |
return dict(aggregated_metrics) | |
def check_convergence(self, | |
old_weights: List, | |
new_weights: List, | |
threshold: float = 1e-5) -> bool: | |
logger = logging.getLogger(__name__) | |
logger.debug("Checking convergence...") | |
if old_weights is None or new_weights is None: | |
logger.warning("Old or new weights are None in check_convergence.") | |
return False | |
weight_differences = [ | |
np.mean(np.abs(old - new)) | |
for old, new in zip(old_weights, new_weights) | |
] | |
logger.debug(f"Weight differences: {weight_differences}") | |
converged = all(diff < threshold for diff in weight_differences) | |
logger.info(f"Convergence status: {converged}") | |
return converged | |