Spaces:
Sleeping
Sleeping
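#!/usr/bin/env python3
"""
Smoke-test script for the federated learning system.

Calls the running server's HTTP endpoints in turn (health check, client
registration, training status, prediction) and reports pass/fail for each.
On success it prints how to start the web app and additional clients.
"""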
import requests | |
import time | |
import json | |
import numpy as np | |
from pathlib import Path | |
import subprocess | |
import sys | |
import threading | |
def test_server_health(server_url="http://localhost:8080"):
    """Check that the server's ``/health`` endpoint answers with HTTP 200.

    Args:
        server_url: Base URL of the server (scheme, host and port), without
            a trailing slash.

    Returns:
        True when ``GET {server_url}/health`` returns status 200. False for
        any other status code, or when the request itself raises (for
        example because nothing is listening on that address).
    """
    try:
        response = requests.get(f"{server_url}/health", timeout=5)
    except Exception as e:
        # Deliberately broad: this is a smoke test, so any failure to get a
        # response is reported as "cannot connect" instead of crashing.
        print(f"❌ Cannot connect to server: {e}")
        return False

    if response.status_code == 200:
        print("✅ Server health check passed")
        return True

    print(f"❌ Server health check failed: {response.status_code}")
    return False
def test_prediction(server_url="http://localhost:8080"):
    """Send a random feature vector to ``/predict`` and check the reply.

    Args:
        server_url: Base URL of the server, without a trailing slash.

    Returns:
        True when the server answers with status 200 and a JSON body whose
        ``"prediction"`` entry is a number. False on any other status code,
        on a missing or non-numeric prediction, or when the request or the
        JSON decoding raises.
    """
    try:
        # 32 standard-normal values as a plain list so they serialize to JSON.
        features = np.random.randn(32).tolist()
        response = requests.post(
            f"{server_url}/predict",
            json={"features": features},
            timeout=10,
        )
        if response.status_code != 200:
            print(f"❌ Prediction test failed: {response.status_code}")
            return False

        prediction = response.json().get("prediction")
        # Without this check a missing/non-numeric value only surfaces as an
        # opaque format error from the ``:.4f`` below.
        if not isinstance(prediction, (int, float)):
            print(f"❌ Prediction test failed: unexpected prediction {prediction!r}")
            return False

        print(f"✅ Prediction test passed: {prediction:.4f}")
        return True
    except Exception as e:
        # Deliberately broad: connection errors, timeouts and malformed JSON
        # all count as a failed smoke test, not as a crash.
        print(f"❌ Prediction test error: {e}")
        return False
def test_training_status(server_url="http://localhost:8080"):
    """Query ``/training_status`` and report the current training round.

    Args:
        server_url: Base URL of the server, without a trailing slash.

    Returns:
        True when the server answers with status 200 and a JSON body; the
        body's ``"current_round"`` entry is printed (0 when absent). False on
        any other status code, or when the request or JSON decoding raises.
    """
    try:
        response = requests.get(f"{server_url}/training_status", timeout=5)
        if response.status_code != 200:
            print(f"❌ Training status test failed: {response.status_code}")
            return False

        data = response.json()
        print(f"✅ Training status test passed: Round {data.get('current_round', 0)}")
        return True
    except Exception as e:
        # Deliberately broad: any transport or decoding problem is reported
        # as a failed check so the remaining output stays readable.
        print(f"❌ Training status test error: {e}")
        return False
def test_client_registration(server_url="http://localhost:8080"):
    """Register a fixed test client through the ``/register`` endpoint.

    Posts the client id ``'test_client'`` together with a small, hard-coded
    description of the client (dataset size, parameter count, capabilities).

    Args:
        server_url: Base URL of the server, without a trailing slash.

    Returns:
        True when the server answers with status 200. False on any other
        status code or when the request raises.
    """
    client_info = {
        'dataset_size': 100,
        'model_params': 10000,
        'capabilities': ['training', 'inference'],
    }
    payload = {'client_id': 'test_client', 'client_info': client_info}

    try:
        response = requests.post(f"{server_url}/register", json=payload, timeout=10)
    except Exception as e:
        # Deliberately broad: any failure to obtain a response counts as a
        # failed registration and must not crash the script.
        print(f"❌ Client registration test error: {e}")
        return False

    if response.status_code == 200:
        print("✅ Client registration test passed")
        return True

    print(f"❌ Client registration test failed: {response.status_code}")
    return False
def run_complete_test(server_url="http://localhost:8080"):
    """Run every check against one server and stop at the first failure.

    The checks run in a fixed order: health, client registration, training
    status, prediction. A failed health check additionally prints the
    command for starting the server.

    Args:
        server_url: Base URL of the server under test. Defaults to the
            local development address this script has always used.

    Returns:
        True when all checks pass, False as soon as one of them fails.
    """
    print("🚀 Testing Complete Federated Learning System")
    print("=" * 50)

    # Nothing else can succeed without a reachable server, so check it first
    # and tell the user how to start one.
    if not test_server_health(server_url):
        print("\n❌ Server is not running. Please start the server first:")
        print("python -m src.main --mode server --config config/server_config.yaml")
        return False

    # (check function, message printed when it fails), in execution order.
    remaining_checks = (
        (test_client_registration, "Client registration failed"),
        (test_training_status, "Training status failed"),
        (test_prediction, "Prediction failed"),
    )
    for check, failure_message in remaining_checks:
        if not check(server_url):
            print(f"\n❌ {failure_message}")
            return False

    print("\n🎉 All tests passed! The federated learning system is working correctly.")
    print("\nNext steps:")
    print("1. Start the web app: streamlit run webapp/streamlit_app.py")
    print("2. Start additional clients: python -m src.main --mode client --config config/client_config.yaml")
    print("3. Use the web interface to interact with the system")
    return True
if __name__ == "__main__":
    # Mirror the overall result in the process exit code (0 = all checks
    # passed, 1 = at least one failed) so shells and CI can react to it.
    success = run_complete_test()
    sys.exit(0 if success else 1)