Spaces:
Sleeping
Sleeping
File size: 1,527 Bytes
754afec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
"""aggregator.py module."""
import tensorflow as tf
from typing import List, Dict
import numpy as np
from collections import defaultdict
class FederatedAggregator:
    """Aggregate client-reported metrics and test weight convergence."""

    def __init__(self, config: Dict):
        """Initialize the federated aggregator.

        Args:
            config: Configuration mapping. ``config['aggregation']['weighted']``
                selects sample-weighted (truthy) or uniform (falsy) averaging
                in :meth:`compute_metrics`.
        """
        self.weighted = config['aggregation']['weighted']

    def compute_metrics(self, client_metrics: List[Dict]) -> Dict:
        """Compute aggregated metrics from client updates.

        Each entry of ``client_metrics`` must contain a ``'metrics'`` mapping
        of metric name to numeric value. When weighted aggregation is enabled
        it must also contain ``'num_samples'``.

        Args:
            client_metrics: One dict per client, as described above.

        Returns:
            A dict mapping each metric name to its average across clients.
            The average is weighted by ``num_samples`` when weighting is
            enabled. It is a plain mean when weighting is disabled, or when
            every client reports zero samples. Empty input yields ``{}``.
            A metric missing from some clients contributes only from the
            clients that report it.
        """
        if not client_metrics:
            return {}

        num_clients = len(client_metrics)
        # Unweighted aggregation is a plain mean, not a sum, so metrics such
        # as accuracy stay within their natural range.
        weights = [1.0 / num_clients] * num_clients
        if self.weighted:
            total_samples = sum(m['num_samples'] for m in client_metrics)
            # Avoid ZeroDivisionError when no client reports any samples;
            # the uniform weights above are kept in that case.
            if total_samples:
                weights = [m['num_samples'] / total_samples
                           for m in client_metrics]

        aggregated_metrics = defaultdict(float)
        for metrics, weight in zip(client_metrics, weights):
            for metric_name, value in metrics['metrics'].items():
                aggregated_metrics[metric_name] += value * weight
        return dict(aggregated_metrics)

    def check_convergence(self,
                          old_weights: List,
                          new_weights: List,
                          threshold: float = 1e-5) -> bool:
        """Check if the model has converged.

        Convergence means that every pair of corresponding weight arrays has
        a mean absolute element-wise difference strictly below ``threshold``.

        Args:
            old_weights: Weight arrays from the previous round, or ``None``.
            new_weights: Weight arrays from the current round, or ``None``.
            threshold: Maximum allowed mean absolute change per array.

        Returns:
            ``False`` if either argument is ``None`` or the two lists hold a
            different number of arrays. Otherwise ``True`` only if every
            array moved by less than ``threshold`` on average.
        """
        if old_weights is None or new_weights is None:
            return False
        # zip() would silently drop trailing arrays on a length mismatch and
        # could report convergence without comparing every layer.
        if len(old_weights) != len(new_weights):
            return False
        # np.asarray lets plain Python lists be compared as well as arrays.
        return all(
            np.mean(np.abs(np.asarray(old) - np.asarray(new))) < threshold
            for old, new in zip(old_weights, new_weights)
        )
|