csc525_retrieval_based_chatbot / training_plotter.py
JoeArmani
summarization, reranker, environment setup, and response quality checker
f7b283c
raw
history blame
4.55 kB
import logging
import math
import numbers
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
logger = logging.getLogger(__name__)


@contextmanager
def _quiet_matplotlib_logs():
    """Temporarily raise the ``matplotlib`` logger to at least WARNING.

    matplotlib logs through loggers named ``matplotlib.*``, which can emit many
    DEBUG records (e.g. font lookup) when the root logger is at DEBUG. The
    previous level is restored on exit, even if plotting raises, so the
    caller's logging configuration is left untouched.
    """
    mpl_logger = logging.getLogger("matplotlib")
    previous_level = mpl_logger.level
    # Never *lower* a level that is already stricter (e.g. ERROR).
    mpl_logger.setLevel(max(previous_level, logging.WARNING))
    try:
        yield
    finally:
        mpl_logger.setLevel(previous_level)


class TrainingPlotter:
    """Plot training curves and validation metrics, optionally saving PNGs.

    Args:
        save_dir: Directory in which plots are written as timestamped PNG
            files. It is created (with parents) if missing. A plain ``str``
            path also works. When falsy (``None``), plots are shown but not
            saved.
    """

    def __init__(self, save_dir: Optional[Path] = None):
        # Coerce str paths to Path so ``mkdir`` and ``/`` work. Falsy values
        # keep meaning "don't save", as before.
        self.save_dir = Path(save_dir) if save_dir else None
        if self.save_dir:
            self.save_dir.mkdir(parents=True, exist_ok=True)

    def plot_training_history(self, history: Dict[str, List[float]], title: str = "Training History"):
        """Plot and optionally save training metrics history.

        Args:
            history: Metric sequences keyed by name. ``'train_loss'`` and
                ``'val_loss'`` are drawn on the top axes when present (x axis
                labelled "Epoch"). A non-empty ``'learning_rate'`` sequence
                adds a second axes below, plotted against its index (x axis
                labelled "Step").
            title: Figure-level title.
        """
        # Only allocate the learning-rate axes when there is data for it,
        # instead of leaving an empty second subplot. len() (not truthiness)
        # also works if a numpy array is passed.
        has_lr = len(history.get('learning_rate', ())) > 0
        n_rows = 2 if has_lr else 1

        with _quiet_matplotlib_logs():
            # squeeze=False keeps `axes` 2-D, so [row, 0] indexing works for 1 or 2 rows.
            fig, axes = plt.subplots(n_rows, 1, figsize=(10, 6 * n_rows), squeeze=False)
            loss_ax = axes[0, 0]

            # Missing series are skipped rather than raising KeyError, so a
            # run without validation can still be plotted.
            plotted_any = False
            for key, label in (('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')):
                if key in history:
                    loss_ax.plot(history[key], label=label)
                    plotted_any = True
            loss_ax.set_xlabel('Epoch')
            loss_ax.set_ylabel('Loss')
            loss_ax.set_title('Training and Validation Loss')
            if plotted_any:  # legend() warns when there are no labelled artists
                loss_ax.legend()
            loss_ax.grid(True)

            if has_lr:
                lr_ax = axes[1, 0]
                lr_ax.plot(history['learning_rate'], label='Learning Rate')
                lr_ax.set_xlabel('Step')
                lr_ax.set_ylabel('Learning Rate')
                lr_ax.set_title('Learning Rate Schedule')
                lr_ax.legend()
                lr_ax.grid(True)

            fig.suptitle(title)
            fig.tight_layout()
            self._save_figure(fig, 'training_history', 'training history')

        plt.show()

    def plot_validation_metrics(self, metrics: Dict[str, object]):
        """Plot validation metrics as a bar chart.

        Args:
            metrics: Metric values keyed by name. One level of nested dicts is
                flattened to ``"<key>_<subkey>"``. Non-numeric values and the
                top-level ``'num_queries_tested'`` entry are ignored. If nothing
                numeric remains, a warning is logged and nothing is plotted.
        """
        flat_metrics = self._flatten_metrics(metrics)
        if not flat_metrics:
            logger.warning("No numeric metrics to plot")
            return

        names = list(flat_metrics)
        values = list(flat_metrics.values())
        positions = list(range(len(names)))

        with _quiet_matplotlib_logs():
            fig, ax = plt.subplots(figsize=(12, 6))
            bars = ax.bar(positions, values)

            ax.set_title('Validation Metrics')
            ax.set_xticks(positions)
            ax.set_xticklabels(names, rotation=45, ha='right')
            ax.set_ylabel('Value')

            # Write each value just above its bar.
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2.0, height,
                        f'{height:.3f}',
                        ha='center', va='bottom')

            ax.set_ylim(*self._value_axis_limits(values))
            # Prevent the rotated tick labels from being cut off.
            fig.tight_layout()
            self._save_figure(fig, 'validation_metrics', 'validation metrics')

        plt.show()

    @staticmethod
    def _flatten_metrics(metrics) -> Dict[str, float]:
        """Flatten one level of nesting and keep only real-number values as floats.

        ``{'a': 0.5, 'b': {'x': 1}}`` becomes ``{'a': 0.5, 'b_x': 1.0}``. The
        top-level ``'num_queries_tested'`` entry is skipped. ``numbers.Real``
        accepts int, float and bool, and also numpy scalars such as
        ``np.float32`` / ``np.int64``. Those are not ``float``/``int``
        subclasses, so an ``isinstance(..., (int, float))`` check drops them.
        """
        flat: Dict[str, float] = {}
        for key, value in metrics.items():
            if key == 'num_queries_tested':
                continue
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    if isinstance(subvalue, numbers.Real):
                        flat[f"{key}_{subkey}"] = float(subvalue)
            elif isinstance(value, numbers.Real):
                flat[key] = float(value)
        return flat

    @staticmethod
    def _value_axis_limits(values: List[float]):
        """Return ``(bottom, top)`` y-limits for the metrics bar chart.

        Metrics in [0, 1] get the fixed 0..1.1 view (the 0.1 headroom leaves
        room for the value labels). Larger or negative values widen the range
        by 10% instead of being clipped. Non-finite values are ignored.
        """
        scaled = [v * 1.1 for v in values if math.isfinite(v)]
        return min([0.0] + scaled), max([1.1] + scaled)

    def _save_figure(self, fig, file_prefix: str, description: str) -> None:
        """Save ``fig`` as ``<save_dir>/<file_prefix>_<YYYYmmdd_HHMMSS>.png`` when saving is enabled."""
        if not self.save_dir:
            return
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_path = self.save_dir / f'{file_prefix}_{timestamp}.png'
        fig.savefig(save_path)
        logger.info("Saved %s plot to %s", description, save_path)