"""aggregator.py: federated metric aggregation and convergence checking."""
from collections import defaultdict
from typing import Dict, List, Optional

import numpy as np

try:
    import tensorflow as tf  # noqa: F401
except ImportError:  # pragma: no cover - depends on the environment
    # NOTE(review): `tf` is not referenced by the aggregation code in this
    # module. The guarded import keeps the name available when TensorFlow is
    # installed without making it a hard import-time requirement; remove it
    # entirely if nothing else relies on `aggregator.tf`.
    tf = None


class FederatedAggregator:
    """Aggregate per-client metrics and check model convergence.

    Whether client metrics are weighted by sample count is controlled by
    ``config['aggregation']['weighted']``.
    """

    def __init__(self, config: Dict):
        """Initialize the federated aggregator.

        Args:
            config: Configuration mapping; ``config['aggregation']['weighted']``
                selects sample-count weighting (truthy) or a plain mean
                (falsy) in :meth:`compute_metrics`.

        Raises:
            KeyError: If ``config`` lacks ``['aggregation']['weighted']``.
        """
        self.weighted = config['aggregation']['weighted']

    def compute_metrics(self, client_metrics: List[Dict]) -> Dict:
        """Compute aggregated metrics from client updates.

        Each entry of ``client_metrics`` must provide a ``'metrics'`` mapping
        of metric name -> numeric value. When weighting is enabled it must
        also provide ``'num_samples'``. Each metric is a weighted average
        whose weights sum to 1 across clients:

        * weighted: ``num_samples / total_samples`` per client. If the total
          sample count is 0, uniform weights are used instead of dividing
          by zero.
        * unweighted: ``1 / number_of_clients`` per client (a plain mean).

        A metric missing from some clients simply receives no contribution
        from them; the remaining weights are not renormalized.

        Args:
            client_metrics: Per-client metric reports.

        Returns:
            Mapping of metric name to aggregated value; ``{}`` when
            ``client_metrics`` is empty.
        """
        if not client_metrics:
            return {}

        num_clients = len(client_metrics)
        uniform = [1.0 / num_clients] * num_clients

        if self.weighted:
            total_samples = sum(m['num_samples'] for m in client_metrics)
            # Guard against every client reporting zero samples, which would
            # otherwise raise ZeroDivisionError.
            if total_samples > 0:
                weights = [m['num_samples'] / total_samples for m in client_metrics]
            else:
                weights = uniform
        else:
            # An unweighted aggregate is the mean, not the sum, of client values.
            weights = uniform

        aggregated_metrics = defaultdict(float)
        for metrics, weight in zip(client_metrics, weights):
            for metric_name, value in metrics['metrics'].items():
                aggregated_metrics[metric_name] += value * weight

        return dict(aggregated_metrics)

    def check_convergence(
        self,
        old_weights: Optional[List],
        new_weights: Optional[List],
        threshold: float = 1e-5,
    ) -> bool:
        """Check if the model has converged.

        The model is considered converged when, for every pair of
        corresponding weight arrays, the mean absolute element-wise
        difference is strictly below ``threshold``.

        Args:
            old_weights: Previous per-layer weights (array-likes), or ``None``.
            new_weights: Current per-layer weights (array-likes), or ``None``.
            threshold: Maximum allowed mean absolute difference per layer.

        Returns:
            ``False`` if either argument is ``None``, or if the number of
            layers or any layer's shape differs. Otherwise ``True`` iff every
            layer's mean absolute difference is below ``threshold``.
        """
        if old_weights is None or new_weights is None:
            return False
        # zip() would silently drop trailing layers; a structural mismatch
        # means the models are not comparable, hence not converged.
        if len(old_weights) != len(new_weights):
            return False

        for old, new in zip(old_weights, new_weights):
            # asarray accepts numpy arrays, nested lists and array-like tensors.
            old_arr = np.asarray(old)
            new_arr = np.asarray(new)
            # Avoid silent broadcasting (or a ValueError) between mismatched layers.
            if old_arr.shape != new_arr.shape:
                return False
            diff = np.mean(np.abs(old_arr - new_arr))
            # `not <` (rather than `>=`) also treats NaN differences as
            # not converged.
            if not diff < threshold:
                return False
        return True