“Transcendental-Programmer”
fix : minor fixes
fc5fa78
raw
history blame
2.75 kB
"""aggregator.py module."""
import tensorflow as tf
from typing import List, Dict
import numpy as np
from collections import defaultdict
import logging
class FederatedAggregator:
def __init__(self, config: Dict):
logger = logging.getLogger(__name__)
logger.debug(f"Initializing FederatedAggregator with config: {config}")
# Defensive: try to find aggregation config
agg_config = None
if 'aggregation' in config:
agg_config = config['aggregation']
elif 'server' in config and 'aggregation' in config['server']:
agg_config = config['server']['aggregation']
else:
logger.error(f"No 'aggregation' key found in config passed to FederatedAggregator: {config}")
raise KeyError("'aggregation' config section is required for FederatedAggregator")
self.weighted = agg_config.get('weighted', True)
logger.info(f"FederatedAggregator initialized. Weighted: {self.weighted}")
def compute_metrics(self, client_metrics: List[Dict]) -> Dict:
logger = logging.getLogger(__name__)
logger.debug(f"Computing metrics for {len(client_metrics)} clients")
if not client_metrics:
logger.warning("No client metrics provided to compute_metrics.")
return {}
aggregated_metrics = defaultdict(float)
total_samples = sum(metrics['num_samples'] for metrics in client_metrics)
logger.debug(f"Total samples across clients: {total_samples}")
for metrics in client_metrics:
weight = metrics['num_samples'] / total_samples if self.weighted else 1.0
logger.debug(f"Client metrics: {metrics}, weight: {weight}")
for metric_name, value in metrics['metrics'].items():
aggregated_metrics[metric_name] += value * weight
logger.info(f"Aggregated metrics: {dict(aggregated_metrics)}")
return dict(aggregated_metrics)
def check_convergence(self,
old_weights: List,
new_weights: List,
threshold: float = 1e-5) -> bool:
logger = logging.getLogger(__name__)
logger.debug("Checking convergence...")
if old_weights is None or new_weights is None:
logger.warning("Old or new weights are None in check_convergence.")
return False
weight_differences = [
np.mean(np.abs(old - new))
for old, new in zip(old_weights, new_weights)
]
logger.debug(f"Weight differences: {weight_differences}")
converged = all(diff < threshold for diff in weight_differences)
logger.info(f"Convergence status: {converged}")
return converged