Spaces:
Sleeping
Sleeping
File size: 1,834 Bytes
41d470a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
"""test_server.py module."""
import pytest
import numpy as np
import tensorflow as tf
from src.server.coordinator import FederatedCoordinator
from src.server.aggregator import FederatedAggregator
import yaml
@pytest.fixture
def server_config():
    """Load the ``server`` section of the server configuration file.

    The path is relative, so it is resolved against the current working
    directory; presumably pytest is run from the repository root (TODO confirm).

    Returns:
        The value stored under the top-level ``server`` key of
        ``config/server_config.yaml``.
    """
    # Explicit encoding: the platform default is not guaranteed to be UTF-8
    # (e.g. on Windows), which would make parsing locale-dependent.
    with open('config/server_config.yaml', 'r', encoding='utf-8') as f:
        # safe_load only builds plain YAML types, never arbitrary Python objects.
        return yaml.safe_load(f)['server']
@pytest.fixture
def coordinator(server_config):
    """Provide a ``FederatedCoordinator`` constructed from the loaded server config."""
    instance = FederatedCoordinator(server_config)
    return instance
@pytest.fixture
def aggregator(server_config):
    """Provide a ``FederatedAggregator`` constructed from the loaded server config."""
    instance = FederatedAggregator(server_config)
    return instance
def test_coordinator_initialization(coordinator, server_config):
    """The coordinator's settings should mirror the ``federated`` config section."""
    federated = server_config['federated']
    assert coordinator.min_clients == federated['min_clients']
    assert coordinator.rounds == federated['rounds']
    assert coordinator.sample_fraction == federated['sample_fraction']
def test_client_registration(coordinator):
    """Registering a client stores it in ``coordinator.clients`` with its size."""
    cid, n_samples = 1, 1000
    coordinator.register_client(cid, n_samples)

    assert cid in coordinator.clients
    registered = coordinator.clients[cid]
    assert registered['size'] == n_samples
def test_client_selection(coordinator):
    """``select_clients()`` should return at least ``min_clients`` registered ids.

    The number of registered clients is derived from ``coordinator.min_clients``
    so the outcome does not silently depend on the config value being <= 5.
    """
    # A fixed count of 5 would make this test fail for any config with
    # min_clients > 5, independent of whether select_clients() is correct.
    num_clients = max(5, coordinator.min_clients)
    for client_id in range(num_clients):
        coordinator.register_client(client_id, 1000)

    selected_clients = coordinator.select_clients()

    assert len(selected_clients) >= coordinator.min_clients
    assert all(client_id in coordinator.clients for client_id in selected_clients)
def test_weight_aggregation(aggregator):
    """Call ``aggregator.compute_metrics`` on synthetic client updates and expect a dict.

    NOTE(review): despite its name, this test exercises ``compute_metrics``
    rather than a weight-aggregation method; confirm which aggregator API
    is meant to be covered here.
    """
    # Seeded generator so the synthetic weights (and any failure) are reproducible.
    rng = np.random.default_rng(0)
    client_updates = [
        {
            'client_id': i,
            'weights': [rng.standard_normal((10, 10)) for _ in range(3)],
            'metrics': {'loss': 0.5},
        }
        for i in range(3)
    ]
    # The result comes from compute_metrics, so name it accordingly.
    metrics = aggregator.compute_metrics(client_updates)
    assert isinstance(metrics, dict)
|