“Transcendental-Programmer”
FEAT: added server coordination and model aggregation logic
754afec
raw
history blame
1.53 kB
"""aggregator.py module."""
import tensorflow as tf
from typing import List, Dict
import numpy as np
from collections import defaultdict
class FederatedAggregator:
def __init__(self, config: Dict):
"""Initialize the federated aggregator."""
self.weighted = config['aggregation']['weighted']
def compute_metrics(self, client_metrics: List[Dict]) -> Dict:
"""Compute aggregated metrics from client updates."""
if not client_metrics:
return {}
aggregated_metrics = defaultdict(float)
total_samples = sum(metrics['num_samples'] for metrics in client_metrics)
for metrics in client_metrics:
weight = metrics['num_samples'] / total_samples if self.weighted else 1.0
for metric_name, value in metrics['metrics'].items():
aggregated_metrics[metric_name] += value * weight
return dict(aggregated_metrics)
def check_convergence(self,
old_weights: List,
new_weights: List,
threshold: float = 1e-5) -> bool:
"""Check if the model has converged."""
if old_weights is None or new_weights is None:
return False
weight_differences = [
np.mean(np.abs(old - new))
for old, new in zip(old_weights, new_weights)
]
return all(diff < threshold for diff in weight_differences)