"""Federated learning client (model.py module).

Defines :class:`FederatedClient`, which owns a small Keras regression model,
registers with a federated server over HTTP, and polls that server for
training rounds to take part in.
"""

import logging
import time
from typing import Dict, List, Optional, Tuple

import numpy as np
import tensorflow as tf

from ..api.client import FederatedHTTPClient
from .data_handler import FinancialDataHandler

logger = logging.getLogger(__name__)


class FederatedClient:
    """Client side of a federated learning setup.

    The client builds a local Keras model, registers itself with the server,
    and then loops: poll the server status, pull the global weights when a new
    round starts, train locally, and submit the updated weights plus metrics.

    Attributes:
        client_id: Identifier of this client, always stored as ``str``.
        config: The ``'client'`` sub-dictionary of the config passed in.
        model: Compiled Keras model built by :meth:`_build_model`.
        data_handler: ``FinancialDataHandler`` constructed from ``config``.
        server_url: URL used to construct the HTTP client.
        http_client: ``FederatedHTTPClient`` used for all server calls.
        registered: ``True`` once :meth:`_register_with_server` succeeded.
        current_round: Last server round this client took part in.
    """

    # Feature width shared by the model input layer and the fallback data
    # generator; the two must agree or ``fit`` fails on a shape mismatch.
    INPUT_DIM = 32
    # Number of samples requested from the data handler / generated as fallback.
    DUMMY_SAMPLE_COUNT = 100
    # Defaults used when the 'training' config section omits a value.
    DEFAULT_LEARNING_RATE = 0.001
    DEFAULT_BATCH_SIZE = 32
    DEFAULT_LOCAL_EPOCHS = 3
    # Seconds between status polls, and the longer pause after a failed poll.
    POLL_INTERVAL_SECONDS = 5
    ERROR_BACKOFF_SECONDS = 10

    def __init__(self, client_id: str, config: Dict, server_url: Optional[str] = None):
        """Initialize the federated client.

        Args:
            client_id: Identifier for this client; converted with ``str()``.
            config: Full configuration mapping; only its ``'client'`` entry
                (default ``{}``) is kept.
            server_url: Server base URL. When falsy, the ``'server_url'``
                config entry is used, falling back to
                ``'http://localhost:8080'``.
        """
        self.client_id = str(client_id)
        self.config = config.get('client', {})
        self.model = self._build_model()
        self.data_handler = FinancialDataHandler(self.config)

        # HTTP client for server communication
        self.server_url = server_url or self.config.get('server_url', 'http://localhost:8080')
        self.http_client = FederatedHTTPClient(self.server_url, self.client_id)

        # Training state
        self.registered = False
        self.current_round = 0

    def _training_config(self) -> Dict:
        """Return the ``'training'`` section of the client config (or ``{}``)."""
        return self.config.get('training', {})

    def start(self):
        """Start the federated client process with server communication.

        Waits for the server, registers, then runs the polling loop. The HTTP
        client is closed on the way out whether or not an error occurred.

        Raises:
            ConnectionError: If ``wait_for_server()`` reports failure.
            Exception: Any error from registration or the loop is logged and
                re-raised unchanged.
        """
        logger.info(f"Client {self.client_id} starting...")

        try:
            # Wait for server to be available
            if not self.http_client.wait_for_server():
                raise ConnectionError(f"Cannot connect to server at {self.server_url}")

            # Register with server
            self._register_with_server()

            # Main federated learning loop
            self._federated_learning_loop()
        except Exception as e:
            logger.error(f"Error during client execution: {str(e)}")
            raise
        finally:
            self.http_client.close()

    def _register_with_server(self):
        """Register this client with the federated server.

        Sends the local dataset size, the model parameter count and a fixed
        capability list, then marks the client as registered.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            # Generate local data only to report its size to the server.
            X, _ = self._generate_dummy_data()

            client_info = {
                'dataset_size': len(X),
                'model_params': self.model.count_params(),
                'capabilities': ['training', 'inference']
            }

            # The call's return value is not used; a failure surfaces as an exception.
            self.http_client.register(client_info)
            self.registered = True

            logger.info("Successfully registered with server")
            logger.info(f"Dataset size: {client_info['dataset_size']}")
            logger.info(f"Model parameters: {client_info['model_params']:,}")
        except Exception as e:
            logger.error(f"Failed to register with server: {str(e)}")
            raise

    def _federated_learning_loop(self):
        """Poll the server and take part in each new round until training ends.

        The loop exits when the status reports ``training_active`` as falsy
        (a missing key counts as active). Errors inside an iteration are
        logged and the loop retries after a longer pause; ``current_round``
        only advances after a round completed without raising.
        """
        while True:
            try:
                # Get training status from server
                status = self.http_client.get_training_status()

                if not status.get('training_active', True):
                    logger.info("Training completed on server")
                    break

                server_round = status.get('current_round', 0)
                if server_round > self.current_round:
                    self._participate_in_round(server_round)
                    self.current_round = server_round

                time.sleep(self.POLL_INTERVAL_SECONDS)
            except Exception as e:
                # The error is swallowed here, so keep the traceback in the log.
                logger.exception(f"Error in federated learning loop: {str(e)}")
                time.sleep(self.ERROR_BACKOFF_SECONDS)  # Wait longer on error

    def _participate_in_round(self, round_num: int):
        """Participate in a federated learning round.

        Pulls the global weights (if the server supplied any), trains on
        locally generated data, and submits the resulting weights and metrics.

        Args:
            round_num: Server round number, echoed back in the metrics.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        logger.info(f"Participating in round {round_num}")

        try:
            # Get global model from server
            model_response = self.http_client.get_global_model()
            global_weights = model_response.get('model_weights')

            # Explicit None/length check: plain truthiness raises ValueError
            # if the payload is a multi-element NumPy array.
            if global_weights is not None and len(global_weights) > 0:
                self.set_weights(global_weights)
                logger.info("Updated local model with global weights")

            # Generate/load local data
            X, y = self._generate_dummy_data()
            logger.info(f"Training on {len(X)} samples")

            # Train locally
            history = self.train_local((X, y))
            losses = history['loss']

            # Prepare metrics
            metrics = {
                'dataset_size': len(X),
                'final_loss': losses[-1] if losses else 0.0,
                'epochs_trained': len(losses),
                'round': round_num
            }

            # Submit update to server
            local_weights = self.get_weights()
            self.http_client.submit_model_update(local_weights, metrics)

            logger.info(f"Round {round_num} completed - Final loss: {metrics['final_loss']:.4f}")
        except Exception as e:
            logger.error(f"Error in round {round_num}: {str(e)}")
            raise

    def _generate_dummy_data(self):
        """Generate dummy data for testing.

        First asks the data handler for ``DUMMY_SAMPLE_COUNT`` synthetic
        samples. If that raises, falls back to random-normal inputs of width
        ``INPUT_DIM`` whose target is the row-wise sum of the inputs.

        Returns:
            An ``(X, y)`` pair. The fallback returns NumPy arrays of shape
            ``(DUMMY_SAMPLE_COUNT, INPUT_DIM)`` and ``(DUMMY_SAMPLE_COUNT, 1)``.
            NOTE(review): the data handler's return type is not visible here;
            callers unpack it as ``(X, y)`` — confirm.
        """
        try:
            # Try to use the data handler for more realistic data
            return self.data_handler.generate_synthetic_data(self.DUMMY_SAMPLE_COUNT)
        except Exception as e:
            # Best-effort fallback, but do not hide why the handler failed.
            logger.warning(f"Synthetic data generation failed, using random fallback data: {e}")

        # Generate input data
        X = tf.random.normal((self.DUMMY_SAMPLE_COUNT, self.INPUT_DIM))

        # Generate target data (for this example, we'll predict the sum of inputs)
        y = tf.reduce_sum(X, axis=1, keepdims=True)

        return X.numpy(), y.numpy()

    def _build_model(self):
        """Build and compile the initial model architecture.

        Returns:
            A ``tf.keras.Sequential`` model: input of width ``INPUT_DIM``,
            two ReLU dense layers (128 and 64 units) and a single linear
            output unit, compiled with Adam and MSE loss. The learning rate
            comes from ``config['training']['learning_rate']`` when present.
        """
        model = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(self.INPUT_DIM,)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(1)  # Output layer for regression
        ])

        learning_rate = self._training_config().get('learning_rate', self.DEFAULT_LEARNING_RATE)
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss='mse'
        )

        return model

    def train_local(self, data: Tuple) -> Dict:
        """Train the model on local data.

        Args:
            data: ``(X, y)`` pair. NumPy arrays are converted to ``float32``
                tensors; anything else is passed to ``fit`` unchanged.

        Returns:
            The ``history`` dictionary of the Keras ``History`` object
            returned by ``fit`` (contains the per-epoch ``'loss'`` list).
        """
        X, y = data

        # Ensure data is in the right format
        if isinstance(X, np.ndarray):
            X = tf.convert_to_tensor(X, dtype=tf.float32)
        if isinstance(y, np.ndarray):
            y = tf.convert_to_tensor(y, dtype=tf.float32)

        # Resolve the hyper-parameters once so the values that are logged are
        # exactly the values handed to ``fit``.
        training_cfg = self._training_config()
        batch_size = training_cfg.get('batch_size', self.DEFAULT_BATCH_SIZE)
        epochs = training_cfg.get('local_epochs', self.DEFAULT_LOCAL_EPOCHS)

        # Log training parameters
        logger.info("Training Parameters:")
        logger.info(f"Input shape: {X.shape}")
        logger.info(f"Output shape: {y.shape}")
        logger.info(f"Batch size: {batch_size}")
        logger.info(f"Epochs: {epochs}")

        class LogCallback(tf.keras.callbacks.Callback):
            """Log the loss at the end of every epoch at DEBUG level."""

            def on_epoch_end(self, epoch, logs=None):
                # ``logs`` defaults to None, so guard before reading the loss.
                loss = (logs or {}).get('loss')
                if loss is not None:
                    logger.debug(f"Epoch {epoch + 1} - loss: {loss:.4f}")

        # Train the model
        history = self.model.fit(
            X, y,
            batch_size=batch_size,
            epochs=epochs,
            verbose=0,
            callbacks=[LogCallback()]
        )

        return history.history

    def get_weights(self) -> List:
        """Get the model weights as returned by the model's ``get_weights()``."""
        return self.model.get_weights()

    def set_weights(self, weights: List):
        """Update local model with global weights.

        Args:
            weights: Weights passed straight to the model's ``set_weights()``.
        """
        self.model.set_weights(weights)