leaderboard / src /plotting.py
akera's picture
Create plotting.py
34a7f8e verified
raw
history blame
11.4 kB
# src/plotting.py
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.colors as mcolors
from colorsys import rgb_to_hls, hls_to_rgb
from collections import defaultdict
import numpy as np
import pandas as pd
from config import LANGUAGE_NAMES
def create_leaderboard_plot(leaderboard_df: pd.DataFrame, metric: str = 'quality_score') -> plt.Figure:
    """
    Create a horizontal bar chart ranking models by one metric.

    Args:
        leaderboard_df: Expected to contain a ``model_display_name`` column and
            a numeric column named ``metric``.
        metric: Column to rank by; the highest value is drawn at the top.

    Returns:
        The Figure, or a placeholder figure when ``leaderboard_df`` is empty.
    """
    if leaderboard_df.empty:
        # An empty table would yield a NaN axis limit, which matplotlib rejects.
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, 'No data available', ha='center', va='center',
                transform=ax.transAxes, fontsize=16)
        ax.axis('off')
        return fig

    fig, ax = plt.subplots(figsize=(12, 8))

    # barh draws row 0 at the bottom, so an ascending sort puts the best model on top.
    df_sorted = leaderboard_df.sort_values(metric, ascending=True)
    values = df_sorted[metric]
    positions = np.arange(len(df_sorted))

    colors = plt.cm.viridis(np.linspace(0, 1, len(df_sorted)))
    bars = ax.barh(positions, values, color=colors)

    metric_title = metric.replace("_", " ").title()
    ax.set_yticks(positions)
    ax.set_yticklabels(df_sorted['model_display_name'])
    ax.set_xlabel(f'{metric_title} Score')
    ax.set_title(f'Model Leaderboard - {metric_title}', fontsize=16, pad=20)

    # Grid behind the bars for readability.
    ax.grid(axis='x', linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)

    # 15% headroom for the value labels; fall back to 1.0 when the maximum is
    # zero/negative/NaN, which would otherwise give a degenerate or invalid range.
    max_val = values.max()
    x_max = max_val * 1.15 if np.isfinite(max_val) and max_val > 0 else 1.0
    ax.set_xlim(0, x_max)

    # Offset labels proportionally to the axis so the gap looks the same for
    # 0-1 metrics and for larger-scale metrics such as BLEU.
    label_offset = x_max * 0.005
    for bar, value in zip(bars, values):
        ax.text(value + label_offset, bar.get_y() + bar.get_height() / 2,
                f'{value:.3f}', ha='left', va='center', fontweight='bold')

    fig.tight_layout()
    return fig
def create_detailed_comparison_plot(metrics_data: dict, model_names: list, metric: str = 'bleu') -> plt.Figure:
    """
    Create a per-direction comparison chart for a chosen subset of models.

    Args:
        metrics_data: Maps model name to that model's metrics dict, in the
            format accepted by ``plot_translation_metric_comparison``.
        model_names: Models to include, in display order. Names missing from
            ``metrics_data`` are skipped; ``None`` is treated as no models.
        metric: Metric to compare. Defaults to ``'bleu'``, the value that was
            previously hard-coded here.

    Returns:
        The comparison Figure, or a placeholder figure reading
        "No data available for comparison" when none of ``model_names`` are
        present in ``metrics_data``.
    """
    # A dict keeps the order of model_names; duplicate names collapse to one entry.
    filtered_metrics = {
        name: metrics_data[name]
        for name in (model_names or ())
        if name in metrics_data
    }

    if not filtered_metrics:
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, 'No data available for comparison',
                ha='center', va='center', transform=ax.transAxes, fontsize=16)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
        return fig

    return plot_translation_metric_comparison(filtered_metrics, metric=metric)
# Base colour per model family; members of a family get lighter/darker variants.
_FAMILY_BASE_COLORS = {
    'gemma': '#3274A1',
    'nllb': '#7f7f7f',
    'qwen': '#E1812C',
    'google': '#3A923A',
    'other': '#D62728',
}

# Preset upper y-limits for metrics with a typical range. Other metrics are
# auto-scaled from the data, and presets are widened when data exceeds them.
_METRIC_Y_MAX = {
    'chrf': 1.1,
    'len_ratio': 1.1,
    'cer': 1.0,
    'wer': 1.0,
    'bleu': 65,
    'rouge1': 1.0,
    'rouge2': 1.0,
    'rougeL': 1.0,
    'quality_score': 0.65,
}


def _model_family(model_name: str) -> str:
    """Return the colour family of a model by case-insensitive substring match."""
    model_lower = model_name.lower()
    # First match wins, in this order.
    for family in ('gemma', 'qwen', 'nllb', 'google'):
        if family in model_lower:
            return family
    return 'other'


def _model_colors(model_names: list) -> list:
    """Return one hex colour per model, varying lightness within each family."""
    families = [_model_family(name) for name in model_names]
    family_counts = defaultdict(int)
    for family in families:
        family_counts[family] += 1

    colors = []
    seen = defaultdict(int)
    for family in families:
        hue, lightness, saturation = rgb_to_hls(*mcolors.to_rgb(_FAMILY_BASE_COLORS[family]))
        count = family_counts[family]
        if count > 1:
            # Spread members evenly from lightness 0.65 (first) to 0.35 (last);
            # a lone member keeps the base colour's own lightness.
            lightness = 0.65 - 0.3 * seen[family] / (count - 1)
        colors.append(mcolors.to_hex(hls_to_rgb(hue, lightness, saturation)))
        seen[family] += 1
    return colors


def _format_direction_label(label: str) -> str:
    """Render 'eng_to_xx' as 'English→<name>' and 'xx_to_eng' as '<name>→English'.

    <name> is looked up in LANGUAGE_NAMES, falling back to the raw code.
    """
    if label.startswith('eng_to_'):
        code = label[len('eng_to_'):]
        return f"English→{LANGUAGE_NAMES.get(code, code)}"
    code = label[:-len('_to_eng')] if label.endswith('_to_eng') else label
    return f"{LANGUAGE_NAMES.get(code, code)}→English"


def _y_axis_limit(metric: str, values: list) -> float:
    """Return the y-axis upper limit for ``metric`` such that no bar is clipped."""
    headroom = max(values, default=0.0) * 1.1
    preset = _METRIC_Y_MAX.get(metric)
    limit = headroom if preset is None else max(preset, headroom)
    # All-zero data would otherwise give a degenerate (0, 0) axis.
    return limit if limit > 0 else 1.0


def _draw_metric_panel(ax, series: dict, tick_labels: list, *, model_names: list,
                       colors: list, bar_width: float, y_max: float,
                       xlabel: str, ylabel: str, title: str,
                       rotate_labels: bool) -> None:
    """Draw one grouped-bar panel: one bar per model at every tick position."""
    positions = np.arange(len(tick_labels))
    for i, (model_name, color) in enumerate(zip(model_names, colors)):
        ax.bar(positions + i * bar_width, series[model_name], bar_width,
               alpha=0.8, color=color, label=model_name)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Centre each tick under its group of bars, whatever the number of models.
    ax.set_xticks(positions + bar_width * (len(model_names) - 1) / 2)
    if rotate_labels:
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')
    else:
        ax.set_xticklabels(tick_labels)
    ax.set_ylim(0, y_max)
    ax.grid(axis='y', linestyle='--', alpha=0.7)


def plot_translation_metric_comparison(metrics_by_model: dict, metric: str = 'bleu') -> plt.Figure:
    """
    Create a grouped bar chart comparing one metric across translation models.

    The figure has up to three panels: XX→English directions, English→XX
    directions, and each model's ``averages`` entry (always drawn, with the legend).

    Args:
        metrics_by_model: Maps model name to that model's metrics dict. Inner
            keys ending in ``_to_eng`` or starting with ``eng_to_`` are treated
            as translation directions, and ``averages`` holds overall scores.
            Each of those values is expected to be a dict of metric name to
            number (assumed from the ``.get(metric)`` lookups; confirm with callers).
        metric: Metric name to plot, e.g. ``'bleu'``, ``'chrf'``, ``'cer'``.

    Returns:
        The matplotlib Figure, or a placeholder figure when
        ``metrics_by_model`` is empty.
    """
    if not metrics_by_model:
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, 'No data available for comparison',
                ha='center', va='center', transform=ax.transAxes, fontsize=16)
        ax.axis('off')
        return fig

    model_names = list(metrics_by_model)
    colors = _model_colors(model_names)

    # Union of directions across all models in first-seen order, so a pair the
    # first model lacks is still shown (missing scores are drawn as 0.0).
    all_keys = [key for data in metrics_by_model.values() for key in data]
    xx_to_eng = list(dict.fromkeys(k for k in all_keys if k.endswith('_to_eng')))
    eng_to_xx = list(dict.fromkeys(k for k in all_keys if k.startswith('eng_to_')))

    def scores_for(keys):
        # Per model, the metric value under each key, or 0.0 when absent.
        return {
            name: [data.get(key, {}).get(metric, 0.0) for key in keys]
            for name, data in metrics_by_model.items()
        }

    xx_to_eng_data = scores_for(xx_to_eng)
    eng_to_xx_data = scores_for(eng_to_xx)
    averages_data = scores_for(['averages'])

    all_values = [
        value
        for series in (xx_to_eng_data, eng_to_xx_data, averages_data)
        for values in series.values()
        for value in values
    ]
    y_max = _y_axis_limit(metric, all_values)

    metric_display = metric.upper() if metric in ['bleu', 'chrf', 'cer', 'wer'] else metric.replace('_', ' ').title()

    # Groups sit one unit apart. Cap each group at 0.8 wide so groups never
    # overlap; this keeps the original 0.2 per bar for up to 4 models.
    bar_width = min(0.2, 0.8 / len(model_names))
    panel = dict(model_names=model_names, colors=colors, bar_width=bar_width,
                 y_max=y_max, ylabel=f'{metric_display} Score')

    fig = plt.figure(figsize=(18, 12))
    # 1x5 grid: two columns for each direction panel, one for the averages panel.
    gs = gridspec.GridSpec(1, 5)

    if xx_to_eng:
        _draw_metric_panel(fig.add_subplot(gs[0, 0:2]), xx_to_eng_data,
                           [_format_direction_label(k) for k in xx_to_eng],
                           xlabel='Translation Direction',
                           title=f'XX→English {metric_display} Performance',
                           rotate_labels=True, **panel)

    if eng_to_xx:
        _draw_metric_panel(fig.add_subplot(gs[0, 2:4]), eng_to_xx_data,
                           [_format_direction_label(k) for k in eng_to_xx],
                           xlabel='Translation Direction',
                           title=f'English→XX {metric_display} Performance',
                           rotate_labels=True, **panel)

    ax_avg = fig.add_subplot(gs[0, 4])
    _draw_metric_panel(ax_avg, averages_data, ['Average'],
                       xlabel='Overall', title=f'Average {metric_display}',
                       rotate_labels=False, **panel)
    ax_avg.legend()

    if metric in ['cer', 'wer']:
        fig.text(0.5, 0.01, "Note: Lower values indicate better performance for this metric",
                 ha='center', fontsize=12, style='italic')

    model_list = ' vs '.join(model_names)
    fig.suptitle(f'{metric_display} Score Comparison: {model_list}', fontsize=16, y=0.98)
    fig.tight_layout(rect=[0, 0.02, 1, 0.95])
    return fig
def create_summary_metrics_plot(leaderboard_df: pd.DataFrame, top_n: int = 5) -> plt.Figure:
    """
    Create a three-panel summary (BLEU, ChrF, quality score) for the top models.

    Args:
        leaderboard_df: Expected to contain ``model_display_name``, ``bleu``,
            ``chrf`` and ``quality_score`` columns.
        top_n: Number of models to show, selected by highest ``quality_score``.
            Defaults to 5, the previously hard-coded value.

    Returns:
        The Figure, or a placeholder figure when ``leaderboard_df`` is empty.
    """
    if leaderboard_df.empty:
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, 'No data available', ha='center', va='center',
                transform=ax.transAxes, fontsize=16)
        ax.axis('off')
        return fig

    top_models = leaderboard_df.nlargest(top_n, 'quality_score')

    metrics = ['bleu', 'chrf', 'quality_score']
    metric_labels = ['BLEU', 'ChrF', 'Quality Score']

    fig, axes = plt.subplots(1, len(metrics), figsize=(15, 6))

    for ax, metric, label in zip(axes, metrics, metric_labels):
        # barh draws row 0 at the bottom, so ascending order puts the best on top.
        sorted_models = top_models.sort_values(metric, ascending=True)
        values = sorted_models[metric]
        positions = np.arange(len(sorted_models))

        bars = ax.barh(positions, values,
                       color=plt.cm.viridis(np.linspace(0, 1, len(sorted_models))))

        ax.set_yticks(positions)
        ax.set_yticklabels(sorted_models['model_display_name'])
        ax.set_xlabel(f'{label} Score')
        ax.set_title(f'Top Models - {label}')
        ax.grid(axis='x', linestyle='--', alpha=0.7)
        ax.set_axisbelow(True)

        # Leave room for the value labels so they don't spill into the next
        # panel; fall back to 1.0 for zero/negative/NaN maxima.
        max_val = values.max()
        x_max = max_val * 1.15 if np.isfinite(max_val) and max_val > 0 else 1.0
        ax.set_xlim(0, x_max)

        # Axis-relative offset keeps a visible gap even for zero-valued bars.
        label_offset = x_max * 0.01
        for bar, value in zip(bars, values):
            ax.text(value + label_offset, bar.get_y() + bar.get_height() / 2,
                    f'{value:.3f}', ha='left', va='center', fontsize=10)

    fig.tight_layout()
    return fig