Spaces:
Sleeping
Sleeping
""" | |
RESTful API for Federated Learning Server | |
Handles client registration, model updates, and coordination | |
""" | |
from flask import Flask, request, jsonify | |
import logging | |
import threading | |
import time | |
from typing import Dict, Any, List | |
from ..server.coordinator import FederatedCoordinator | |
from ..utils.metrics import calculate_model_similarity | |
logger = logging.getLogger(__name__) | |
class FederatedAPI: | |
def __init__(self, coordinator: FederatedCoordinator, host: str = "0.0.0.0", port: int = 8080): | |
self.app = Flask(__name__) | |
self.coordinator = coordinator | |
self.host = host | |
self.port = port | |
self._setup_routes() | |
def _setup_routes(self): | |
"""Setup API routes""" | |
def health_check(): | |
"""Health check endpoint""" | |
return jsonify({ | |
'status': 'healthy', | |
'timestamp': time.time(), | |
'active_clients': len(self.coordinator.clients), | |
'current_round': getattr(self.coordinator, 'current_round', 0) | |
}) | |
def register_client(): | |
"""Register a new client""" | |
try: | |
data = request.get_json() | |
client_id = data.get('client_id') | |
client_info = data.get('client_info', {}) | |
if not client_id: | |
return jsonify({'error': 'client_id is required'}), 400 | |
success = self.coordinator.register_client(client_id, client_info) | |
if success: | |
return jsonify({ | |
'status': 'registered', | |
'client_id': client_id, | |
'server_config': self.coordinator.get_client_config() | |
}) | |
else: | |
return jsonify({'error': 'Registration failed'}), 400 | |
except Exception as e: | |
logger.error(f"Error registering client: {str(e)}") | |
return jsonify({'error': str(e)}), 500 | |
def get_global_model(): | |
"""Get the current global model""" | |
try: | |
data = request.get_json() | |
client_id = data.get('client_id') | |
if not client_id or client_id not in self.coordinator.clients: | |
return jsonify({'error': 'Invalid client_id'}), 400 | |
model_weights = self.coordinator.get_global_model() | |
return jsonify({ | |
'model_weights': model_weights, | |
'round': getattr(self.coordinator, 'current_round', 0), | |
'timestamp': time.time() | |
}) | |
except Exception as e: | |
logger.error(f"Error getting global model: {str(e)}") | |
return jsonify({'error': str(e)}), 500 | |
def submit_model_update(): | |
"""Submit a model update from client""" | |
try: | |
data = request.get_json() | |
client_id = data.get('client_id') | |
model_weights = data.get('model_weights') | |
training_metrics = data.get('metrics', {}) | |
if not client_id or not model_weights: | |
return jsonify({'error': 'client_id and model_weights are required'}), 400 | |
if client_id not in self.coordinator.clients: | |
return jsonify({'error': 'Client not registered'}), 400 | |
# Store the update | |
self.coordinator.receive_model_update(client_id, model_weights, training_metrics) | |
return jsonify({ | |
'status': 'update_received', | |
'client_id': client_id, | |
'timestamp': time.time() | |
}) | |
except Exception as e: | |
logger.error(f"Error submitting model update: {str(e)}") | |
return jsonify({'error': str(e)}), 500 | |
def get_training_status(): | |
"""Get current training status""" | |
try: | |
return jsonify({ | |
'current_round': getattr(self.coordinator, 'current_round', 0), | |
'total_rounds': self.coordinator.config.get('federated', {}).get('num_rounds', 10), | |
'active_clients': len(self.coordinator.clients), | |
'clients_ready': len(getattr(self.coordinator, 'client_updates', {})), | |
'min_clients': self.coordinator.config.get('federated', {}).get('min_clients', 2), | |
'training_active': getattr(self.coordinator, 'training_active', False) | |
}) | |
except Exception as e: | |
logger.error(f"Error getting training status: {str(e)}") | |
return jsonify({'error': str(e)}), 500 | |
def rag_query(): | |
"""Handle RAG queries""" | |
try: | |
data = request.get_json() | |
query = data.get('query') | |
client_id = data.get('client_id') | |
if not query: | |
return jsonify({'error': 'query is required'}), 400 | |
# This will be implemented when we integrate RAG | |
return jsonify({ | |
'response': 'RAG functionality coming soon', | |
'query': query, | |
'timestamp': time.time() | |
}) | |
except Exception as e: | |
logger.error(f"Error processing RAG query: {str(e)}") | |
return jsonify({'error': str(e)}), 500 | |
def predict(): | |
"""Predict using the current global model.""" | |
try: | |
data = request.get_json() | |
features = data.get('features') | |
if features is None or not isinstance(features, list) or len(features) != 32: | |
return jsonify({'error': 'features must be a list of 32 floats'}), 400 | |
# Get global model weights | |
model_weights = self.coordinator.get_global_model() | |
if model_weights is None: | |
return jsonify({'error': 'Global model not available yet'}), 503 | |
# Build the model (same as client) | |
import tensorflow as tf | |
import numpy as np | |
input_dim = 32 | |
model = tf.keras.Sequential([ | |
tf.keras.layers.Input(shape=(input_dim,)), | |
tf.keras.layers.Dense(128, activation='relu'), | |
tf.keras.layers.Dense(64, activation='relu'), | |
tf.keras.layers.Dense(1) | |
]) | |
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse') | |
model.set_weights([np.array(w) for w in model_weights]) | |
# Prepare input and predict | |
x = np.array(features, dtype=np.float32).reshape(1, -1) | |
pred = model.predict(x) | |
prediction = float(pred[0, 0]) | |
return jsonify({'prediction': prediction}) | |
except Exception as e: | |
logger.error(f"Error in prediction endpoint: {str(e)}") | |
return jsonify({'error': str(e)}), 500 | |
def run(self, debug: bool = False): | |
"""Run the API server""" | |
logger.info(f"Starting Federated API server on {self.host}:{self.port}") | |
self.app.run(host=self.host, port=self.port, debug=debug, threaded=True) | |
def run_threaded(self, debug: bool = False): | |
"""Run the API server in a separate thread""" | |
def run_server(): | |
self.app.run(host=self.host, port=self.port, debug=debug, threaded=True) | |
thread = threading.Thread(target=run_server, daemon=True) | |
thread.start() | |
logger.info(f"Federated API server started in background on {self.host}:{self.port}") | |
return thread | |