import math
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import matplotlib.pyplot as plt


class Plotter:
    """Draw training/validation plots and optionally save them as PNG files.

    Each public plotting method builds a new figure, saves it under
    ``save_dir`` (when one was configured) with a timestamped filename,
    calls ``plt.show()``, and then closes the figure unless pyplot is in
    interactive mode.
    """

    def __init__(self, save_dir: Optional[Path] = None):
        """
        Args:
            save_dir: Directory where plots are written. It is created
                (including parents) if missing. A plain string path is also
                accepted. ``None`` disables saving.
        """
        # Normalize str -> Path so the `/` join and `.mkdir` below work for both.
        self.save_dir = Path(save_dir) if save_dir else None
        if self.save_dir:
            self.save_dir.mkdir(parents=True, exist_ok=True)

    def plot_training_history(self, history: Dict[str, List[float]],
                              title: str = "Training History"):
        """Plot and save training metrics history.

        The top panel plots ``history['train_loss']`` and, if present,
        ``history['val_loss']`` against their index (labelled "Epoch").
        The bottom panel plots ``history['learning_rate']`` against its index
        (labelled "Step") if that key is present. Otherwise the panel is
        hidden.

        Args:
            history: Dict with training metrics. ``'train_loss'`` is required;
                ``'val_loss'`` and ``'learning_rate'`` are optional.
            title: Figure title.

        Raises:
            KeyError: If ``history`` has no ``'train_loss'`` entry.
        """
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))

        # Losses
        ax1.plot(history['train_loss'], label='Train Loss')
        if 'val_loss' in history:
            ax1.plot(history['val_loss'], label='Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.legend()
        ax1.grid(True)

        # Learning rate schedule (optional)
        if 'learning_rate' in history:
            ax2.plot(history['learning_rate'], label='Learning Rate')
            ax2.set_xlabel('Step')
            ax2.set_ylabel('Learning Rate')
            ax2.set_title('Learning Rate Schedule')
            ax2.legend()
            ax2.grid(True)
        else:
            # Don't draw an empty axes frame when there is nothing to show.
            ax2.axis('off')

        fig.suptitle(title)
        fig.tight_layout()

        self._save_and_show(fig, 'training_history')

    def plot_validation_metrics(self, metrics: Dict[str, float]):
        """Plot validation metrics as a bar chart.

        Args:
            metrics: Dictionary of validation metrics. Values may be numbers
                or one level of nested dicts. Nested entries are plotted as
                ``"<key>_<subkey>"``. Non-numeric values and the
                ``'num_queries_tested'`` key are ignored.

        Nothing is drawn if no numeric metric remains after filtering.
        """
        flat_metrics = self._flatten_metrics(metrics)
        if not flat_metrics:
            return

        metric_names = list(flat_metrics.keys())
        values = list(flat_metrics.values())
        positions = list(range(len(metric_names)))

        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(positions, values)

        ax.set_title('Validation Metrics')
        ax.set_xticks(positions)
        ax.set_xticklabels(metric_names, rotation=45, ha='right')
        ax.set_ylabel('Value')

        # Value label just above each bar.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height,
                    f'{height:.3f}', ha='center', va='bottom')

        # [0, 1.1] suits rate-style metrics, but widen it when a value falls
        # outside so bars/labels are not clipped. NaN/inf are ignored here
        # because matplotlib rejects non-finite axis limits.
        finite = [v for v in values if math.isfinite(v)]
        lower = min([0.0] + [v * 1.1 for v in finite])
        upper = max([1.1] + [v * 1.1 for v in finite])
        ax.set_ylim(lower, upper)
        fig.tight_layout()

        self._save_and_show(fig, 'validation_metrics')

    @staticmethod
    def _flatten_metrics(metrics: Dict[str, float]) -> Dict[str, float]:
        """Return the numeric metrics, flattening one level of nested dicts.

        Nested values become ``"<key>_<subkey>"`` entries. Non-numeric values
        are dropped. ``'num_queries_tested'`` is skipped (presumably a count
        rather than a score -- confirm with the producer of ``metrics``).
        """
        flat: Dict[str, float] = {}
        for key, value in metrics.items():
            if key == 'num_queries_tested':
                continue
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    if isinstance(subvalue, (int, float)):
                        flat[f"{key}_{subkey}"] = subvalue
            elif isinstance(value, (int, float)):
                flat[key] = value
        return flat

    def _save_and_show(self, fig, prefix: str) -> None:
        """Save ``fig`` as ``<prefix>_<YYYYmmdd_HHMMSS>.png`` (if saving is
        enabled), show it, then release it."""
        try:
            if self.save_dir:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                fig.savefig(self.save_dir / f'{prefix}_{timestamp}.png')
            plt.show()
        finally:
            # pyplot keeps every figure alive until closed. With a
            # non-interactive (e.g. headless) backend `show()` does not
            # release it, so repeated calls would leak figures. In interactive
            # mode, leave the figure open so its window stays visible.
            if not plt.isinteractive():
                plt.close(fig)