File size: 3,605 Bytes
f7b283c
 
 
 
 
fc5f33b
f7b283c
 
 
 
 
 
71ca212
f7b283c
71ca212
 
fc5f33b
f7b283c
 
 
 
 
 
 
 
 
 
 
 
71ca212
f7b283c
 
 
 
 
 
 
 
 
 
 
71ca212
f7b283c
 
 
 
 
 
 
 
71ca212
f7b283c
 
 
 
71ca212
f7b283c
 
 
 
 
71ca212
f7b283c
 
71ca212
f7b283c
71ca212
f7b283c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71ca212
f7b283c
 
 
 
 
 
71ca212
 
f7b283c
 
71ca212
f7b283c
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import math
import numbers
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import matplotlib.pyplot as plt

class Plotter:
    """Draw training/validation plots with matplotlib, optionally saving them.

    Every plotting method shows its figure via ``plt.show()``. When the
    plotter was created with a ``save_dir``, the figure is first written there
    as ``<prefix>_<YYYYmmdd_HHMMSS>.png`` (with a numeric suffix added if that
    file already exists).
    """

    def __init__(self, save_dir: Optional[Path] = None):
        """
        Args:
            save_dir: Directory for saved PNGs; created (including parents)
                if missing. A plain string path is also accepted. ``None``
                (or an empty string) disables saving.
        """
        # Normalise to Path so `self.save_dir / name` also works for str input;
        # falsy inputs keep their original meaning of "don't save".
        self.save_dir = Path(save_dir) if save_dir else None
        if self.save_dir is not None:
            self.save_dir.mkdir(parents=True, exist_ok=True)

    def plot_training_history(self, history: Dict[str, List[float]], title: str = "Training History"):
        """Plot (and optionally save) the training-metrics history.

        Args:
            history: Mapping of metric name to a sequence of values.
                ``'train_loss'`` and ``'val_loss'`` are drawn on the loss axes
                when present (either may be missing). ``'learning_rate'``, if
                present, gets its own subplot below. Other keys are ignored.
            title: Figure title.
        """
        has_lr = 'learning_rate' in history
        # Only allocate a second row when there is something to draw in it;
        # otherwise an empty subplot would fill half the figure.
        nrows = 2 if has_lr else 1
        fig, axes = plt.subplots(nrows, 1, figsize=(10, 6 * nrows), squeeze=False)

        # Plot losses
        loss_ax = axes[0][0]
        plotted_any = False
        for key, label in (('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')):
            if key in history:
                loss_ax.plot(history[key], label=label)
                plotted_any = True
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title('Training and Validation Loss')
        if plotted_any:
            # legend() warns when there are no labelled artists.
            loss_ax.legend()
        loss_ax.grid(True)

        # Plot learning rate
        if has_lr:
            lr_ax = axes[1][0]
            lr_ax.plot(history['learning_rate'], label='Learning Rate')
            lr_ax.set_xlabel('Step')
            lr_ax.set_ylabel('Learning Rate')
            lr_ax.set_title('Learning Rate Schedule')
            lr_ax.legend()
            lr_ax.grid(True)

        fig.suptitle(title)
        fig.tight_layout()
        self._finalize(fig, 'training_history')

    def plot_validation_metrics(self, metrics: Dict[str, float]):
        """Plot validation metrics as a bar chart (and optionally save it).

        Args:
            metrics: Mapping of metric name to value. A value may itself be a
                dict, one level deep. Its numeric entries are shown as
                ``"<key>_<subkey>"``. The ``'num_queries_tested'`` key is
                skipped. Non-numeric and boolean values are also skipped.
                Nothing is drawn if no numeric values remain.
        """
        flat_metrics = self._flatten_metrics(metrics)
        if not flat_metrics:
            return

        metric_names = list(flat_metrics)
        values = list(flat_metrics.values())
        positions = list(range(len(metric_names)))

        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(positions, values)

        ax.set_title('Validation Metrics')
        ax.set_xticks(positions)
        ax.set_xticklabels(metric_names, rotation=45, ha='right')
        ax.set_ylabel('Value')

        # Add value labels at the end of each bar (below it for negative bars).
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom' if height >= 0 else 'top')

        # Keep the familiar 0..1.1 range for [0, 1] metrics, but widen it so
        # larger or negative values are not clipped out of view.
        ax.set_ylim(*self._value_axis_limits(values))
        fig.tight_layout()
        self._finalize(fig, 'validation_metrics')

    @staticmethod
    def _flatten_metrics(metrics) -> Dict[str, float]:
        """Return the numeric entries of `metrics`, flattening one nesting level."""
        def is_number(value) -> bool:
            # numbers.Real also covers numpy scalars (e.g. np.float32, which is
            # not a float subclass); bool is excluded so flags aren't plotted.
            return isinstance(value, numbers.Real) and not isinstance(value, bool)

        flat = {}
        for key, value in metrics.items():
            if key == 'num_queries_tested':
                continue
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    if is_number(subvalue):
                        flat[f"{key}_{subkey}"] = subvalue
            elif is_number(value):
                flat[key] = value
        return flat

    @staticmethod
    def _value_axis_limits(values, default_top: float = 1.1):
        """Return (bottom, top) y-limits covering `values` with ~10% headroom.

        Non-finite values (NaN/inf) are ignored, because matplotlib rejects
        them as axis limits.
        """
        finite = [float(v) for v in values if math.isfinite(v)]
        if not finite:
            return 0.0, default_top
        bottom = min(0.0, min(finite) * 1.1)
        top = max(default_top, max(finite) * 1.1)
        return bottom, top

    def _save_path(self, prefix: str) -> Optional[Path]:
        """Return a fresh timestamped PNG path in save_dir, or None if saving is off."""
        if not self.save_dir:
            return None
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        path = self.save_dir / f'{prefix}_{timestamp}.png'
        # Timestamps have one-second resolution; avoid silently overwriting a
        # plot saved earlier within the same second.
        counter = 1
        while path.exists():
            path = self.save_dir / f'{prefix}_{timestamp}_{counter}.png'
            counter += 1
        return path

    def _finalize(self, fig, prefix: str) -> None:
        """Save `fig` if a save_dir is configured, show it, then release it."""
        save_path = self._save_path(prefix)
        if save_path is not None:
            fig.savefig(save_path)
        plt.show()
        # pyplot keeps every figure alive until closed, so repeated calls on a
        # non-interactive (e.g. headless Agg) backend leak memory. In
        # interactive mode show() returns immediately, so closing there would
        # destroy the window the user is looking at.
        if not plt.isinteractive():
            plt.close(fig)