File size: 3,957 Bytes
f7b283c
 
 
 
 
fc5f33b
f7b283c
 
 
 
 
 
 
 
 
 
 
fc5f33b
f7b283c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import math
import numbers
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt

class Plotter:
    """Draw training/validation plots and optionally save them as PNG files.

    Every public plotting method builds its own figure, writes it to
    ``save_dir`` as a timestamped PNG when a directory was configured, and
    then displays it with ``plt.show()``.
    """

    def __init__(self, save_dir: Union[str, Path, None] = None):
        """Create a plotter.

        Args:
            save_dir: Directory for saved plots; created (including parents)
                if missing. A plain string path is accepted and converted to
                ``Path``. A falsy value (``None`` / ``""``) disables saving.
        """
        # Normalise to Path so the ``/`` join in _save_and_show also works
        # when callers pass a string path.
        self.save_dir: Optional[Path] = Path(save_dir) if save_dir else None
        if self.save_dir is not None:
            self.save_dir.mkdir(parents=True, exist_ok=True)

    def plot_training_history(self, history: Dict[str, List[float]], title: str = "Training History"):
        """Plot training metrics history and optionally save it.

        Args:
            history: Metric name -> sequence of values. ``'train_loss'`` and
                ``'val_loss'`` are drawn on the loss axis when present;
                ``'learning_rate'`` gets its own subplot when present.
            title: Figure-level title.
        """
        has_lr = 'learning_rate' in history

        # Only allocate the learning-rate subplot when there is data for it;
        # otherwise the figure would carry an empty, unlabeled bottom axis.
        if has_lr:
            fig, (ax_loss, ax_lr) = plt.subplots(2, 1, figsize=(10, 12))
        else:
            fig, ax_loss = plt.subplots(1, 1, figsize=(10, 6))

        # Plot whichever loss curves exist (e.g. runs without validation
        # have no 'val_loss') instead of failing with KeyError.
        plotted_loss = False
        for key, label in (('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')):
            if key in history:
                ax_loss.plot(history[key], label=label)
                plotted_loss = True
        ax_loss.set_xlabel('Epoch')
        ax_loss.set_ylabel('Loss')
        ax_loss.set_title('Training and Validation Loss')
        if plotted_loss:
            # legend() with no labelled artists only emits a warning.
            ax_loss.legend()
        ax_loss.grid(True)

        if has_lr:
            ax_lr.plot(history['learning_rate'], label='Learning Rate')
            ax_lr.set_xlabel('Step')
            ax_lr.set_ylabel('Learning Rate')
            ax_lr.set_title('Learning Rate Schedule')
            ax_lr.legend()
            ax_lr.grid(True)

        fig.suptitle(title)
        fig.tight_layout()
        self._save_and_show(fig, 'training_history')

    def plot_validation_metrics(self, metrics: Dict[str, Any]):
        """Plot validation metrics as a bar chart and optionally save it.

        Args:
            metrics: Metric name -> value. Values may be nested dictionaries
                (any depth); nested entries are plotted as
                ``"<outer>_<inner>"``. Non-numeric values (including booleans)
                and the top-level ``'num_queries_tested'`` entry are skipped.
                Nothing is drawn when no numeric values remain.
        """
        flat_metrics = self._flatten_metrics(metrics)
        if not flat_metrics:
            return

        metric_names = list(flat_metrics.keys())
        values = list(flat_metrics.values())
        positions = list(range(len(metric_names)))

        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(positions, values)

        ax.set_title('Validation Metrics')
        ax.set_xticks(positions)
        ax.set_xticklabels(metric_names, rotation=45, ha='right')
        ax.set_ylabel('Value')

        # Value label on top of each bar.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom')

        # Keep the familiar (0, 1.1) view for ratio-style metrics, but widen it
        # when a value lies outside [0, 1] so no bar or label is clipped.
        ax.set_ylim(*self._value_axis_limits(values))

        fig.tight_layout()
        self._save_and_show(fig, 'validation_metrics')

    @staticmethod
    def _flatten_metrics(metrics: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten nested metric dicts into ``{"outer_inner": number}``.

        Skips the top-level ``'num_queries_tested'`` key and any value that is
        not a real number; booleans are excluded even though ``bool`` is an
        ``int`` subclass. Insertion order is preserved.
        """
        flat: Dict[str, Any] = {}

        def walk(mapping: Dict[Any, Any], prefix: str) -> None:
            for key, value in mapping.items():
                name = f"{prefix}{key}"
                if isinstance(value, dict):
                    walk(value, f"{name}_")
                elif isinstance(value, numbers.Real) and not isinstance(value, bool):
                    flat[name] = value

        walk({k: v for k, v in metrics.items() if k != 'num_queries_tested'}, "")
        return flat

    @staticmethod
    def _value_axis_limits(values: List[float]) -> Tuple[float, float]:
        """Return y-axis limits covering [0, 1] and every finite value.

        The range is padded by 10% of its span (above, and below only when it
        extends under zero), so data in [0, 1] yields exactly ``(0.0, 1.1)``.
        NaN/inf values are ignored because they cannot be used as limits.
        """
        finite = [float(v) for v in values if math.isfinite(v)]
        low = min([0.0, *finite])
        high = max([1.0, *finite])
        pad = 0.1 * (high - low)
        return (low - pad if low < 0 else low, high + pad)

    def _save_and_show(self, fig, stem: str) -> None:
        """Save *fig* as ``<stem>_<timestamp>.png`` if a save dir is set, then show it."""
        if self.save_dir:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            save_path = self.save_dir / f'{stem}_{timestamp}.png'
            fig.savefig(save_path)

        plt.show()
        # On non-GUI backends (e.g. Agg) show() does nothing, so figures would
        # pile up across calls. In interactive mode show() returns immediately
        # and closing would destroy the live window, so only close otherwise.
        if not plt.isinteractive():
            plt.close(fig)