"""model.py module.""" | |
from typing import Dict, List, Optional, Tuple | |
import tensorflow as tf | |
import numpy as np | |
import logging | |
import time | |
from ..api.client import FederatedHTTPClient | |
from .data_handler import FinancialDataHandler | |


class FederatedClient:
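    """Federated learning client that trains a local Keras model on local data
    and exchanges model weights with a central server over HTTP."""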

    def __init__(self, client_id: str, config: Dict, server_url: Optional[str] = None):
        """Initialize the federated client."""
        self.client_id = str(client_id)
        self.config = config.get('client', {})
        self.model = self._build_model()
        self.data_handler = FinancialDataHandler(config)

        # HTTP client for server communication
        self.server_url = server_url or self.config.get('server_url', 'http://localhost:8080')
        self.http_client = FederatedHTTPClient(self.server_url, self.client_id)

        # Training state
        self.registered = False
        self.current_round = 0

    def start(self):
        """Start the federated client process with server communication."""
        logger = logging.getLogger(__name__)
        logger.info(f"Client {self.client_id} starting...")

        try:
            # Wait for server to be available
            if not self.http_client.wait_for_server():
                raise ConnectionError(f"Cannot connect to server at {self.server_url}")

            # Register with server
            self._register_with_server()

            # Main federated learning loop
            self._federated_learning_loop()
        except Exception as e:
            logger.error(f"Error during client execution: {str(e)}")
            raise
        finally:
            self.http_client.close()

    def _register_with_server(self):
        """Register this client with the federated server."""
        logger = logging.getLogger(__name__)

        try:
            # Generate local data to get client info
            X, y = self._generate_dummy_data()
            client_info = {
                'dataset_size': len(X),
                'model_params': self.model.count_params(),
                'capabilities': ['training', 'inference']
            }

            self.http_client.register(client_info)
            self.registered = True

            logger.info("Successfully registered with server")
            logger.info(f"Dataset size: {client_info['dataset_size']}")
            logger.info(f"Model parameters: {client_info['model_params']:,}")
        except Exception as e:
            logger.error(f"Failed to register with server: {str(e)}")
            raise

    def _federated_learning_loop(self):
        """Main federated learning loop."""
        logger = logging.getLogger(__name__)

        while True:
            try:
                # Get training status from server
                status = self.http_client.get_training_status()

                if not status.get('training_active', True):
                    logger.info("Training completed on server")
                    break

                server_round = status.get('current_round', 0)
                if server_round > self.current_round:
                    self._participate_in_round(server_round)
                    self.current_round = server_round

                time.sleep(5)  # Check every 5 seconds
            except Exception as e:
                logger.error(f"Error in federated learning loop: {str(e)}")
                time.sleep(10)  # Wait longer on error

    def _participate_in_round(self, round_num: int):
        """Participate in a federated learning round."""
        logger = logging.getLogger(__name__)
        logger.info(f"Participating in round {round_num}")

        try:
            # Get global model from server
            model_response = self.http_client.get_global_model()
            global_weights = model_response.get('model_weights')

            if global_weights:
                self.set_weights(global_weights)
                logger.info("Updated local model with global weights")

            # Generate/load local data
            X, y = self._generate_dummy_data()
            logger.info(f"Training on {len(X)} samples")

            # Train locally
            history = self.train_local((X, y))

            # Prepare metrics
            metrics = {
                'dataset_size': len(X),
                'final_loss': history['loss'][-1] if history['loss'] else 0.0,
                'epochs_trained': len(history['loss']),
                'round': round_num
            }

            # Submit update to server
            local_weights = self.get_weights()
            self.http_client.submit_model_update(local_weights, metrics)

            logger.info(f"Round {round_num} completed - Final loss: {metrics['final_loss']:.4f}")
        except Exception as e:
            logger.error(f"Error in round {round_num}: {str(e)}")
            raise

    def _generate_dummy_data(self):
        """Generate dummy data for testing."""
        try:
            # Try to use the data handler for more realistic data
            return self.data_handler.generate_synthetic_data(100)
        except Exception:
            # Fall back to simple dummy data
            num_samples = 100
            input_dim = 32  # Match the model's input dimension

            # Generate input data
            X = tf.random.normal((num_samples, input_dim))
            # Generate target data (for this example, we predict the sum of the inputs)
            y = tf.reduce_sum(X, axis=1, keepdims=True)

            return X.numpy(), y.numpy()

    def _build_model(self):
        """Build the initial model architecture."""
        input_dim = 32  # Match the data generation
        model = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(input_dim,)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(1)  # Output layer for regression
        ])

        model.compile(
            optimizer=tf.keras.optimizers.Adam(
                learning_rate=self.config.get('training', {}).get('learning_rate', 0.001)
            ),
            loss='mse'
        )
        return model

    def train_local(self, data):
        """Train the model on local data."""
        logger = logging.getLogger(__name__)
        X, y = data

        # Ensure data is in the right format
        if isinstance(X, np.ndarray):
            X = tf.convert_to_tensor(X, dtype=tf.float32)
        if isinstance(y, np.ndarray):
            y = tf.convert_to_tensor(y, dtype=tf.float32)

        # Resolve training parameters once so the log matches what fit() actually uses
        training_config = self.config.get('training', {})
        batch_size = training_config.get('batch_size', 32)
        epochs = training_config.get('local_epochs', 3)

        # Log training parameters
        logger.info("Training parameters:")
        logger.info(f"Input shape: {X.shape}")
        logger.info(f"Output shape: {y.shape}")
        logger.info(f"Batch size: {batch_size}")
        logger.info(f"Epochs: {epochs}")

        class LogCallback(tf.keras.callbacks.Callback):
            def on_epoch_end(self, epoch, logs=None):
                logs = logs or {}
                logger.debug(f"Epoch {epoch + 1} - loss: {logs.get('loss', float('nan')):.4f}")

        # Train the model
        history = self.model.fit(
            X, y,
            batch_size=batch_size,
            epochs=epochs,
            verbose=0,
            callbacks=[LogCallback()]
        )
        return history.history

    def get_weights(self) -> List:
        """Get the model weights."""
        weights = self.model.get_weights()
        # Convert to a serializable format
        return [w.tolist() for w in weights]

    def set_weights(self, weights: List):
        """Update local model with global weights."""
        # Convert from the serializable format back to numpy arrays
        np_weights = [np.array(w) for w in weights]
        self.model.set_weights(np_weights)
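

# Example usage: a minimal sketch, assuming a config dict shaped the way
# FederatedClient reads it above ('client' -> 'server_url' / 'training') and a
# federated server already reachable at that URL.
#
#     config = {
#         'client': {
#             'server_url': 'http://localhost:8080',
#             'training': {'learning_rate': 0.001, 'batch_size': 32, 'local_epochs': 3},
#         }
#     }
#     client = FederatedClient(client_id='client-1', config=config)
#     client.start()  # registers, then polls the server and trains each round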