import contextlib
import logging
import math
import numbers
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator, List, Mapping, Optional, Union

import matplotlib.pyplot as plt

logger = logging.getLogger(__name__)

# Top-level metric keys that are bookkeeping counters rather than scores;
# they are excluded from the validation bar chart.
_SKIPPED_METRIC_KEYS = frozenset({"num_queries_tested"})


@contextlib.contextmanager
def _quiet_matplotlib_logging(level: int = logging.WARNING) -> Iterator[None]:
    """Temporarily raise the ``matplotlib`` logger to at least *level*.

    The previous level is restored on exit, including when an exception is
    raised. Only matplotlib's logger is touched, so the application's own
    logging configuration (including this module's logger) is left alone.
    """
    mpl_logger = logging.getLogger("matplotlib")
    previous = mpl_logger.level
    # Never *lower* a level the application already raised above ``level``.
    mpl_logger.setLevel(max(previous, level))
    try:
        yield
    finally:
        mpl_logger.setLevel(previous)


class TrainingPlotter:
    """Render training-history and validation-metric plots.

    Each plot is shown with ``plt.show()`` and, when ``save_dir`` is set,
    first saved as a timestamped PNG.

    Args:
        save_dir: Directory to write PNG files into. It is created (with
            parents) if missing. Accepts a ``Path`` or a ``str``. When falsy,
            plots are shown but never saved.
    """

    def __init__(self, save_dir: Optional[Union[str, Path]] = None):
        # Coerce to Path so plain string paths work; falsy values (None, "")
        # disable saving exactly as before.
        self.save_dir = Path(save_dir) if save_dir else None
        if self.save_dir is not None:
            self.save_dir.mkdir(parents=True, exist_ok=True)

    def _save_and_show(self, fig, stem: str, description: str) -> None:
        """Save *fig* as ``{stem}_{timestamp}.png`` (if saving is enabled), then show it."""
        if self.save_dir is not None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            save_path = self.save_dir / f"{stem}_{timestamp}.png"
            fig.savefig(save_path)
            logger.info("Saved %s plot to %s", description, save_path)
        plt.show()
        # In non-interactive mode show() has either blocked until the window
        # was closed or was a no-op (e.g. the Agg backend), so closing is
        # safe and prevents open figures from piling up across calls.
        if not plt.isinteractive():
            plt.close(fig)

    def plot_training_history(self, history: Dict[str, List[float]], title: str = "Training History"):
        """Plot loss curves and, when available, the learning-rate schedule.

        Args:
            history: Mapping of metric name to a sequence of values.
                ``'train_loss'`` and ``'val_loss'`` are drawn on the top axes
                (x-axis labelled "Epoch") when present. ``'learning_rate'`` is
                drawn on the bottom axes (x-axis labelled "Step") when
                present; otherwise the bottom axes is hidden.
            title: Figure-level title.
        """
        with _quiet_matplotlib_logging():
            fig, (loss_ax, lr_ax) = plt.subplots(2, 1, figsize=(10, 12))

            # Plot whichever loss curves exist instead of raising KeyError.
            for key, label in (("train_loss", "Train Loss"), ("val_loss", "Validation Loss")):
                if key in history:
                    loss_ax.plot(history[key], label=label)
            loss_ax.set_xlabel("Epoch")
            loss_ax.set_ylabel("Loss")
            loss_ax.set_title("Training and Validation Loss")
            if loss_ax.get_lines():  # legend() warns when there is nothing labelled
                loss_ax.legend()
            loss_ax.grid(True)

            if "learning_rate" in history:
                lr_ax.plot(history["learning_rate"], label="Learning Rate")
                lr_ax.set_xlabel("Step")
                lr_ax.set_ylabel("Learning Rate")
                lr_ax.set_title("Learning Rate Schedule")
                lr_ax.legend()
                lr_ax.grid(True)
            else:
                # Don't leave an empty, unlabeled panel in the figure.
                lr_ax.set_visible(False)

            fig.suptitle(title)
            fig.tight_layout()

            self._save_and_show(fig, "training_history", "training history")

    @staticmethod
    def _flatten_metrics(metrics: Mapping[str, Any], prefix: str = "") -> Dict[str, float]:
        """Flatten (arbitrarily) nested metric mappings into ``{"outer_inner": value}``.

        Only real numbers are kept, including numpy scalars, which register
        with ``numbers.Real``. Booleans and all other values are dropped.
        Insertion order is preserved.
        """
        flat: Dict[str, float] = {}
        for key, value in metrics.items():
            name = f"{prefix}{key}"
            if isinstance(value, Mapping):
                flat.update(TrainingPlotter._flatten_metrics(value, prefix=f"{name}_"))
            elif isinstance(value, numbers.Real) and not isinstance(value, bool):
                flat[name] = float(value)
        return flat

    def plot_validation_metrics(self, metrics: Mapping[str, Any]):
        """Plot validation metrics as a labelled bar chart.

        Args:
            metrics: Mapping of metric name to value. Nested mappings are
                flattened with ``"{outer}_{inner}"`` names. Non-numeric
                values and the top-level ``'num_queries_tested'`` counter are
                skipped. If nothing numeric remains, a warning is logged and
                no figure is created.
        """
        flat_metrics = self._flatten_metrics(
            {k: v for k, v in metrics.items() if k not in _SKIPPED_METRIC_KEYS}
        )
        if not flat_metrics:
            logger.warning("No numeric metrics to plot")
            return

        metric_names = list(flat_metrics)
        values = list(flat_metrics.values())
        positions = range(len(metric_names))

        with _quiet_matplotlib_logging():
            fig, ax = plt.subplots(figsize=(12, 6))
            bars = ax.bar(positions, values)

            ax.set_title("Validation Metrics")
            ax.set_xticks(list(positions))
            ax.set_xticklabels(metric_names, rotation=45, ha="right")
            ax.set_ylabel("Value")

            # Value label at the end of each bar (below it for negative bars).
            for bar in bars:
                height = bar.get_height()
                ax.text(
                    bar.get_x() + bar.get_width() / 2.0,
                    height,
                    f"{height:.3f}",
                    ha="center",
                    va="bottom" if height >= 0 else "top",
                )

            # Keep the familiar 0..1.1 view for scores in [0, 1], but widen it
            # when any finite value falls outside so bars are never clipped.
            finite = [v for v in values if math.isfinite(v)]
            lower = min(0.0, min(finite, default=0.0))
            upper = max(1.0, max(finite, default=1.0))
            pad = 0.1 * (upper - lower)
            ax.set_ylim(lower - pad if lower < 0 else lower, upper + pad)

            fig.tight_layout()  # prevent rotated tick labels from being cut off

            self._save_and_show(fig, "validation_metrics", "validation metrics")