Spaces:
Sleeping
Sleeping
"""aggregator.py module.""" | |
import tensorflow as tf | |
from typing import List, Dict | |
import numpy as np | |
from collections import defaultdict | |
class FederatedAggregator: | |
def __init__(self, config: Dict): | |
"""Initialize the federated aggregator.""" | |
self.weighted = config['aggregation']['weighted'] | |
def compute_metrics(self, client_metrics: List[Dict]) -> Dict: | |
"""Compute aggregated metrics from client updates.""" | |
if not client_metrics: | |
return {} | |
aggregated_metrics = defaultdict(float) | |
total_samples = sum(metrics['num_samples'] for metrics in client_metrics) | |
for metrics in client_metrics: | |
weight = metrics['num_samples'] / total_samples if self.weighted else 1.0 | |
for metric_name, value in metrics['metrics'].items(): | |
aggregated_metrics[metric_name] += value * weight | |
return dict(aggregated_metrics) | |
def check_convergence(self, | |
old_weights: List, | |
new_weights: List, | |
threshold: float = 1e-5) -> bool: | |
"""Check if the model has converged.""" | |
if old_weights is None or new_weights is None: | |
return False | |
weight_differences = [ | |
np.mean(np.abs(old - new)) | |
for old, new in zip(old_weights, new_weights) | |
] | |
return all(diff < threshold for diff in weight_differences) | |