File size: 4,554 Bytes
f7b283c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from pathlib import Path
from typing import Dict, List, Optional
import matplotlib.pyplot as plt
from datetime import datetime
import logging

logger = logging.getLogger(__name__)

class TrainingPlotter:
    """Render training curves and validation metrics with matplotlib.

    Every plot is displayed with ``plt.show()``. When a save directory is
    configured, the figure is additionally written there as a PNG whose name
    carries a ``YYYYmmdd_HHMMSS`` timestamp.
    """

    def __init__(self, save_dir: Optional[Path] = None):
        """Configure where plots are saved.

        Args:
            save_dir: Directory for the saved PNG files. A plain string is
                also accepted and converted to a ``Path``. The directory and
                any missing parents are created. ``None`` (or an empty
                string) disables saving.
        """
        # Coerce to Path so string arguments work with ``mkdir`` here and
        # with the ``/`` operator used later to build file names.
        self.save_dir = Path(save_dir) if save_dir else None
        if self.save_dir is not None:
            self.save_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _silence_matplotlib() -> int:
        """Raise the ``matplotlib`` logger to at least WARNING.

        Returns:
            The level that was explicitly set on the ``matplotlib`` logger
            before the change, so the caller can restore it afterwards.
        """
        mpl_logger = logging.getLogger("matplotlib")
        previous_level = mpl_logger.level
        # Never make the logger *more* verbose than it effectively was.
        mpl_logger.setLevel(max(mpl_logger.getEffectiveLevel(), logging.WARNING))
        return previous_level

    @staticmethod
    def _flatten_metrics(metrics) -> Dict[str, float]:
        """Collapse a possibly nested metrics mapping into ``name -> number``.

        The top-level key ``num_queries_tested`` is skipped. Values that are
        dictionaries contribute one entry per numeric item, named
        ``"<key>_<subkey>"``. Anything that is not an ``int`` or ``float``
        is dropped (``bool`` is a subclass of ``int`` and is therefore kept).
        """
        flat: Dict[str, float] = {}
        for key, value in metrics.items():
            if key == 'num_queries_tested':
                continue
            if isinstance(value, dict):
                flat.update({
                    f"{key}_{subkey}": subvalue
                    for subkey, subvalue in value.items()
                    if isinstance(subvalue, (int, float))
                })
            elif isinstance(value, (int, float)):
                flat[key] = value
        return flat

    @staticmethod
    def _axis_limits(values: List[float]) -> tuple:
        """Return ``(lower, upper)`` y-limits that keep every bar visible.

        Values inside ``[0, 1]`` yield ``(0, 1.1)``; the range only grows
        (with 10% padding) when a finite value falls outside that interval.
        """
        import math  # local import: only needed by this helper

        # NaN/inf cannot be used as axis limits, so they are ignored here.
        finite = [v for v in values if math.isfinite(v)]
        lower = min(0.0, min(finite, default=0.0) * 1.1)
        upper = max(1.1, max(finite, default=0.0) * 1.1)
        return lower, upper

    def _save_and_show(self, fig, file_stem: str, description: str) -> None:
        """Save *fig* when a directory is configured, then display it.

        Args:
            fig: The matplotlib figure to save and show.
            file_stem: File name prefix placed before the timestamp.
            description: Human-readable plot name used in the log message.
        """
        try:
            if self.save_dir:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                save_path = self.save_dir / f'{file_stem}_{timestamp}.png'
                fig.savefig(save_path)
                logger.info("Saved %s plot to %s", description, save_path)
            plt.show()
        finally:
            # Release the figure so repeated calls do not accumulate open
            # figures. In interactive mode ``show`` does not block, so the
            # window is left open for the user.
            if not plt.isinteractive():
                plt.close(fig)

    def plot_training_history(
        self,
        history: Dict[str, List[float]],
        title: str = "Training History",
    ):
        """Plot and optionally save training metrics history.

        Args:
            history: Dictionary containing training metrics. The keys
                ``train_loss`` and ``val_loss`` are drawn together on a loss
                panel and ``learning_rate`` on its own panel; a key that is
                absent is skipped.
            title: Title for the plot
        """
        loss_series = [
            (key, label)
            for key, label in (('train_loss', 'Train Loss'),
                               ('val_loss', 'Validation Loss'))
            if key in history
        ]
        has_learning_rate = 'learning_rate' in history
        if not loss_series and not has_learning_rate:
            logger.warning("No training history to plot")
            return

        # Silence matplotlib debug messages for the whole render, including
        # saving and showing; the previous level is restored afterwards.
        previous_level = self._silence_matplotlib()
        try:
            # One row per panel that actually has data, so no empty axes.
            n_rows = int(bool(loss_series)) + int(has_learning_rate)
            fig, axes = plt.subplots(
                n_rows, 1, figsize=(10, 6 * n_rows), squeeze=False
            )
            panels = iter(axes[:, 0])

            if loss_series:
                loss_ax = next(panels)
                for key, label in loss_series:
                    loss_ax.plot(history[key], label=label)
                loss_ax.set_xlabel('Epoch')
                loss_ax.set_ylabel('Loss')
                loss_ax.set_title('Training and Validation Loss')
                loss_ax.legend()
                loss_ax.grid(True)

            if has_learning_rate:
                lr_ax = next(panels)
                lr_ax.plot(history['learning_rate'], label='Learning Rate')
                lr_ax.set_xlabel('Step')
                lr_ax.set_ylabel('Learning Rate')
                lr_ax.set_title('Learning Rate Schedule')
                lr_ax.legend()
                lr_ax.grid(True)

            fig.suptitle(title)
            fig.tight_layout()

            self._save_and_show(fig, 'training_history', 'training history')
        finally:
            logging.getLogger("matplotlib").setLevel(previous_level)

    def plot_validation_metrics(self, metrics: Dict[str, float]):
        """Plot validation metrics as a bar chart.

        Args:
            metrics: Dictionary of validation metrics. Can handle nested
                dictionaries, which are flattened to ``"<key>_<subkey>"``.
                Non-numeric values and the ``num_queries_tested`` key are
                ignored. If nothing numeric remains, a warning is logged and
                no figure is created.
        """
        flat_metrics = self._flatten_metrics(metrics)
        if not flat_metrics:
            logger.warning("No numeric metrics to plot")
            return

        metric_names = list(flat_metrics)
        values = list(flat_metrics.values())
        positions = list(range(len(metric_names)))

        # Silence matplotlib debug messages for the whole render, including
        # saving and showing; the previous level is restored afterwards.
        previous_level = self._silence_matplotlib()
        try:
            fig, ax = plt.subplots(figsize=(12, 6))
            bars = ax.bar(positions, values)

            ax.set_title('Validation Metrics')
            ax.set_xticks(positions)
            ax.set_xticklabels(metric_names, rotation=45, ha='right')
            ax.set_ylabel('Value')

            # Add value labels on top of bars
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height,
                        f'{height:.3f}',
                        ha='center', va='bottom')

            # (0, 1.1) for metrics within [0, 1]; widened only when a value
            # would otherwise be clipped.
            ax.set_ylim(*self._axis_limits(values))

            # Adjust layout to prevent label cutoff
            fig.tight_layout()

            self._save_and_show(fig, 'validation_metrics', 'validation metrics')
        finally:
            logging.getLogger("matplotlib").setLevel(previous_level)