File size: 2,887 Bytes
754afec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Federated learning coordinator: client registry, FedAvg weight aggregation, and the round loop."""

import tensorflow as tf
from typing import List, Dict
import numpy as np
from collections import defaultdict
import logging
import time

class FederatedCoordinator:
    """Server-side coordinator for a federated learning run.

    Keeps a registry of clients (keyed by client id), aggregates client
    weight updates with sample-size-weighted averaging (FedAvg), and drives
    the round counter in :meth:`start`.

    Attributes:
        config: The configuration mapping passed to the constructor.
        clients: Maps client id to a dict with keys ``'size'``, ``'weights'``
            and ``'metrics'``.
        current_round: Number of rounds started so far.
        min_clients: Minimum number of registered clients needed to run a
            round; read from ``config['server']['federated']['min_clients']``
            (default 2).
        rounds: Total number of rounds to run; read from
            ``config['server']['federated']['rounds']`` (default 10).
    """

    def __init__(self, config: Dict):
        """Initialize the federated learning coordinator.

        Args:
            config: Nested configuration mapping. Missing ``'server'`` or
                ``'federated'`` sections (or a ``None`` config) fall back to
                the defaults.
        """
        self.config = config
        # Resolve the nested section once instead of repeating the lookup.
        federated = (config or {}).get('server', {}).get('federated', {})
        self.clients = {}
        self.current_round = 0
        self.min_clients = federated.get('min_clients', 2)
        self.rounds = federated.get('rounds', 10)

    def register_client(self, client_id: int, client_size: int):
        """Register a new client.

        Registering an id that already exists replaces its record, resetting
        the stored weights and metrics.

        Args:
            client_id: Key under which the client is stored.
            client_size: Value stored as the client's ``'size'``; used as the
                averaging weight in :meth:`aggregate_weights`.
        """
        self.clients[client_id] = {
            'size': client_size,
            'weights': None,
            'metrics': defaultdict(list),
        }

    def aggregate_weights(self, client_updates: List[Dict]) -> List:
        """Aggregate weights using the FedAvg algorithm.

        Each client's layers are scaled by ``client_size / total_size`` and
        summed layer by layer.

        Args:
            client_updates: One dict per client with keys ``'client_id'``
                (must be registered) and ``'weights'`` (a sequence of
                array-likes, one per layer).

        Returns:
            A list of numpy arrays, one per layer. Floating/complex layers
            keep the dtype of the first update's layer; integer or boolean
            layers are accumulated in float64.

        Raises:
            ValueError: If ``client_updates`` is empty, the combined client
                size is not positive, or an update has a different number of
                layers than the first update.
            KeyError: If an update references an unregistered client id.
        """
        if not client_updates:
            raise ValueError("client_updates must contain at least one update")

        sizes = [
            self.clients[update['client_id']]['size']
            for update in client_updates
        ]
        total_size = sum(sizes)
        if total_size <= 0:
            raise ValueError(
                "Total size of participating clients must be positive"
            )

        aggregated_weights = []
        for layer in client_updates[0]['weights']:
            template = np.asarray(layer)
            # An integer accumulator cannot take the float-weighted sum
            # in place, so non-inexact dtypes are promoted to float64.
            if np.issubdtype(template.dtype, np.inexact):
                acc_dtype = template.dtype
            else:
                acc_dtype = np.float64
            aggregated_weights.append(np.zeros(template.shape, dtype=acc_dtype))

        for update, client_size in zip(client_updates, sizes):
            layers = update['weights']
            if len(layers) != len(aggregated_weights):
                raise ValueError(
                    f"Client {update['client_id']} sent {len(layers)} layers, "
                    f"expected {len(aggregated_weights)}"
                )
            weight = client_size / total_size
            for i, layer_weights in enumerate(layers):
                # np.asarray lets plain lists participate in the scaling.
                aggregated_weights[i] += np.asarray(layer_weights) * weight

        return aggregated_weights

    def start(self, poll_interval: float = 5.0):
        """Start the federated learning process.

        Logs the server configuration, then advances ``current_round`` until
        it reaches ``rounds``. While fewer than ``min_clients`` clients are
        registered the loop sleeps and re-checks; it does not time out, so
        clients must be registered from elsewhere for it to make progress.

        Args:
            poll_interval: Seconds to sleep between checks while waiting for
                enough clients.
        """
        logger = logging.getLogger(__name__)

        # Server startup banner.
        logger.info("\n" + "=" * 60)
        logger.info(f"{'Federated Learning Server Starting':^60}")
        logger.info("=" * 60)

        # Configuration details.
        logger.info("\nServer Configuration:")
        logger.info("-" * 30)
        logger.info("Minimum clients required: %s", self.min_clients)
        logger.info("Total rounds planned: %s", self.rounds)
        logger.info("Current active clients: %s", len(self.clients))
        logger.info("-" * 30 + "\n")

        while self.current_round < self.rounds:
            # Check the quorum first so the round header is logged once per
            # round rather than on every poll while waiting.
            if len(self.clients) < self.min_clients:
                logger.warning(
                    "Waiting for clients... (active: %s/%s)",
                    len(self.clients),
                    self.min_clients,
                )
                time.sleep(poll_interval)
                continue

            round_num = self.current_round + 1
            logger.info("\nRound %s/%s", round_num, self.rounds)
            logger.info("-" * 30)
            logger.info("Active clients: %s", list(self.clients.keys()))
            logger.info("Starting training round %s", round_num)
            self.current_round += 1