# src/plotting.py
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.colors as mcolors
from colorsys import rgb_to_hls, hls_to_rgb
from collections import defaultdict
import numpy as np
import pandas as pd
from config import LANGUAGE_NAMES
def create_leaderboard_plot(leaderboard_df: pd.DataFrame, metric: str = 'quality_score') -> plt.Figure:
"""Create a horizontal bar chart showing model rankings."""
fig, ax = plt.subplots(figsize=(12, 8))
    # Sort ascending so the highest-scoring model lands at the top of the barh chart
    df_sorted = leaderboard_df.sort_values(metric, ascending=True)
# Create color palette
colors = plt.cm.viridis(np.linspace(0, 1, len(df_sorted)))
# Create horizontal bar chart
bars = ax.barh(range(len(df_sorted)), df_sorted[metric], color=colors)
# Customize the plot
ax.set_yticks(range(len(df_sorted)))
ax.set_yticklabels(df_sorted['model_display_name'])
    metric_label = metric.replace("_", " ").title()
    # Avoid labels like "Quality Score Score" when the metric name already ends in "Score"
    xlabel = metric_label if metric_label.endswith('Score') else f'{metric_label} Score'
    ax.set_xlabel(xlabel)
    ax.set_title(f'Model Leaderboard - {metric_label}', fontsize=16, pad=20)
# Add value labels on bars
    for bar, value in zip(bars, df_sorted[metric]):
ax.text(value + 0.001, bar.get_y() + bar.get_height()/2,
f'{value:.3f}', ha='left', va='center', fontweight='bold')
# Add grid for better readability
ax.grid(axis='x', linestyle='--', alpha=0.7)
ax.set_axisbelow(True)
# Set x-axis limits with some padding
max_val = df_sorted[metric].max()
ax.set_xlim(0, max_val * 1.15)
plt.tight_layout()
return fig
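
# Usage sketch (illustrative, not part of the original module): the leaderboard
# DataFrame is assumed to carry a 'model_display_name' column plus one numeric
# column per metric; the model names and scores below are made up.
#
#   demo_df = pd.DataFrame({
#       'model_display_name': ['model-a', 'model-b', 'model-c'],
#       'quality_score': [0.41, 0.37, 0.29],
#   })
#   fig = create_leaderboard_plot(demo_df, metric='quality_score')
#   fig.savefig('leaderboard.png')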
def create_detailed_comparison_plot(metrics_data: dict, model_names: list) -> plt.Figure:
"""Create detailed comparison plot similar to the original evaluation script."""
# Filter metrics_data to only include models in model_names
filtered_metrics = {name: metrics_data[name] for name in model_names if name in metrics_data}
if not filtered_metrics:
# Create empty plot if no data
fig, ax = plt.subplots(figsize=(10, 6))
ax.text(0.5, 0.5, 'No data available for comparison',
ha='center', va='center', transform=ax.transAxes, fontsize=16)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis('off')
return fig
    # The detailed view defaults to BLEU; other metrics can be plotted by
    # calling plot_translation_metric_comparison directly.
    return plot_translation_metric_comparison(filtered_metrics, metric='bleu')
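
# Input-shape sketch (hypothetical language codes and scores): metrics_data is
# assumed to nest per-model dicts keyed by direction ('<lang>_to_eng' /
# 'eng_to_<lang>') plus an 'averages' entry, each mapping metric names to floats:
#
#   metrics_data = {
#       'model-a': {
#           'fra_to_eng': {'bleu': 32.1, 'chrf': 0.58},
#           'eng_to_fra': {'bleu': 28.4, 'chrf': 0.54},
#           'averages': {'bleu': 30.3, 'chrf': 0.56},
#       },
#   }
#   fig = create_detailed_comparison_plot(metrics_data, ['model-a'])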
def plot_translation_metric_comparison(metrics_by_model: dict, metric: str = 'bleu') -> plt.Figure:
"""
Creates a grouped bar chart comparing a selected metric across translation models.
Adapted from the original plotting code.
"""
# Split language pairs into xx_to_eng and eng_to_xx categories
    first_model_data = next(iter(metrics_by_model.values()))
xx_to_eng = [key for key in first_model_data.keys()
if key.endswith('_to_eng') and key != 'averages']
eng_to_xx = [key for key in first_model_data.keys()
if key.startswith('eng_to_') and key != 'averages']
# Function to create nice labels
def format_label(label):
if label.startswith("eng_to_"):
source, target = "English", label.replace("eng_to_", "")
target = LANGUAGE_NAMES.get(target, target)
else:
source, target = label.replace("_to_eng", ""), "English"
source = LANGUAGE_NAMES.get(source, source)
return f"{source}→{target}"
# Extract metric values for each category
def extract_metric_values(model_metrics, pairs, metric_name):
return [model_metrics.get(pair, {}).get(metric_name, 0.0) for pair in pairs]
xx_to_eng_data = {
model_name: extract_metric_values(model_data, xx_to_eng, metric)
for model_name, model_data in metrics_by_model.items()
}
eng_to_xx_data = {
model_name: extract_metric_values(model_data, eng_to_xx, metric)
for model_name, model_data in metrics_by_model.items()
}
averages_data = {
model_name: [model_data.get("averages", {}).get(metric, 0.0)]
for model_name, model_data in metrics_by_model.items()
}
# Set up plot with custom grid
    fig = plt.figure(figsize=(18, 12))  # Tall figure leaves room for rotated tick labels
# Create a GridSpec with 1 row and 5 columns
gs = gridspec.GridSpec(1, 5)
# Colors for the models
model_names = list(metrics_by_model.keys())
family_base_colors = {
'gemma': '#3274A1',
'nllb': '#7f7f7f',
'qwen': '#E1812C',
'google': '#3A923A',
'other': '#D62728',
}
# Identify the family for each model
def get_family(model_name):
model_lower = model_name.lower()
if 'gemma' in model_lower:
return 'gemma'
elif 'qwen' in model_lower:
return 'qwen'
elif 'nllb' in model_lower:
return 'nllb'
        elif 'google' in model_lower:  # covers 'google-translate' as well
            return 'google'
else:
return 'other'
# Count how many models belong to each family
family_counts = defaultdict(int)
for model in model_names:
family = get_family(model)
family_counts[family] += 1
# Generate slightly varied lightness within each family
colors = []
family_indices = defaultdict(int)
for model in model_names:
family = get_family(model)
base_rgb = mcolors.to_rgb(family_base_colors[family])
h, l, s = rgb_to_hls(*base_rgb)
index = family_indices[family]
count = family_counts[family]
# Vary lightness: from 0.35 to 0.65
if count == 1:
new_l = l # Keep original for single models
else:
new_l = 0.65 - 0.3 * (index / max(count - 1, 1))
varied_rgb = hls_to_rgb(h, new_l, s)
hex_color = mcolors.to_hex(varied_rgb)
colors.append(hex_color)
family_indices[family] += 1
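    # Worked example: with three models in the 'gemma' family, lightness steps
    # through 0.65, 0.50, 0.35 at the family's base hue, so siblings share a hue
    # but stay distinguishable; hue and saturation are left untouched.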
    bar_width = 0.2
    opacity = 0.8
    # Offset that centers each x tick under its group of model bars,
    # whatever the number of models
    group_offset = bar_width * (len(model_names) - 1) / 2
# Positions for the bars
xx_to_eng_indices = np.arange(len(xx_to_eng))
eng_to_xx_indices = np.arange(len(eng_to_xx))
avg_index = np.array([0])
# Determine y-axis limits based on metric
if metric in ['chrf', 'len_ratio']:
y_max = 1.1
elif metric in ['cer', 'wer']:
y_max = 1.0
    elif metric == 'bleu':
        y_max = 65  # BLEU is on a 0-100 scale; 65 leaves headroom over typical scores
elif metric in ['rouge1', 'rouge2', 'rougeL']:
y_max = 1.0
elif metric == 'quality_score':
y_max = 0.65
else:
# Auto-scale based on data
all_values = []
for data in [xx_to_eng_data, eng_to_xx_data, averages_data]:
for model_data in data.values():
all_values.extend(model_data)
y_max = max(all_values) * 1.1 if all_values else 1.0
# Format metric name for display
    # Canonical display names for the common metrics (e.g. 'chrF', not 'CHRF')
    display_names = {'bleu': 'BLEU', 'chrf': 'chrF', 'cer': 'CER', 'wer': 'WER'}
    metric_display = display_names.get(metric, metric.replace('_', ' ').title())
# Create bars for xx_to_eng (using first 2 columns)
if xx_to_eng:
ax1 = plt.subplot(gs[0, 0:2])
for i, (model_name, color) in enumerate(zip(model_names, colors)):
if model_name in xx_to_eng_data:
ax1.bar(xx_to_eng_indices + i*bar_width, xx_to_eng_data[model_name],
bar_width, alpha=opacity, color=color, label=model_name)
ax1.set_xlabel('Translation Direction')
ax1.set_ylabel(f'{metric_display} Score')
ax1.set_title(f'XX→English {metric_display} Performance')
        ax1.set_xticks(xx_to_eng_indices + group_offset)
ax1.set_xticklabels([format_label(label) for label in xx_to_eng], rotation=45, ha='right')
ax1.set_ylim(0, y_max)
ax1.grid(axis='y', linestyle='--', alpha=0.7)
# Create bars for eng_to_xx (using next 2 columns)
if eng_to_xx:
ax2 = plt.subplot(gs[0, 2:4])
for i, (model_name, color) in enumerate(zip(model_names, colors)):
if model_name in eng_to_xx_data:
ax2.bar(eng_to_xx_indices + i*bar_width, eng_to_xx_data[model_name],
bar_width, alpha=opacity, color=color, label=model_name)
ax2.set_xlabel('Translation Direction')
ax2.set_ylabel(f'{metric_display} Score')
ax2.set_title(f'English→XX {metric_display} Performance')
        ax2.set_xticks(eng_to_xx_indices + group_offset)
ax2.set_xticklabels([format_label(label) for label in eng_to_xx], rotation=45, ha='right')
ax2.set_ylim(0, y_max)
ax2.grid(axis='y', linestyle='--', alpha=0.7)
# Create bars for averages (using last column)
ax3 = plt.subplot(gs[0, 4])
for i, (model_name, color) in enumerate(zip(model_names, colors)):
if model_name in averages_data:
ax3.bar(avg_index + i*bar_width, averages_data[model_name],
bar_width, alpha=opacity, color=color, label=model_name)
ax3.set_xlabel('Overall')
ax3.set_ylabel(f'{metric_display} Score')
ax3.set_title(f'Average {metric_display}')
    ax3.set_xticks(avg_index + group_offset)
ax3.set_xticklabels(['Average'])
ax3.set_ylim(0, y_max)
ax3.grid(axis='y', linestyle='--', alpha=0.7)
ax3.legend()
# Add note for metrics where lower is better
if metric in ['cer', 'wer']:
plt.figtext(0.5, 0.01, "Note: Lower values indicate better performance for this metric",
ha='center', fontsize=12, style='italic')
# Add an overall title and adjust layout
model_list = ' vs '.join(model_names)
plt.suptitle(f'{metric_display} Score Comparison: {model_list}', fontsize=16, y=0.98)
plt.tight_layout(rect=[0, 0.02, 1, 0.95])
return fig
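
# This function accepts the same nested dict shape sketched near
# create_detailed_comparison_plot, e.g. (hypothetical data):
#
#   fig = plot_translation_metric_comparison(metrics_data, metric='chrf')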
def create_summary_metrics_plot(leaderboard_df: pd.DataFrame) -> plt.Figure:
"""Create a summary plot showing multiple metrics for top models."""
if leaderboard_df.empty:
fig, ax = plt.subplots(figsize=(10, 6))
ax.text(0.5, 0.5, 'No data available', ha='center', va='center',
transform=ax.transAxes, fontsize=16)
return fig
# Select top 5 models by quality score
top_models = leaderboard_df.nlargest(5, 'quality_score')
# Metrics to display
    metrics = ['bleu', 'chrf', 'quality_score']
    metric_labels = ['BLEU', 'chrF', 'Quality Score']
fig, axes = plt.subplots(1, 3, figsize=(15, 6))
for i, (metric, label) in enumerate(zip(metrics, metric_labels)):
ax = axes[i]
# Sort by current metric
sorted_models = top_models.sort_values(metric, ascending=True)
# Create horizontal bar chart
bars = ax.barh(range(len(sorted_models)), sorted_models[metric],
color=plt.cm.viridis(np.linspace(0, 1, len(sorted_models))))
ax.set_yticks(range(len(sorted_models)))
ax.set_yticklabels(sorted_models['model_display_name'])
ax.set_xlabel(f'{label} Score')
ax.set_title(f'Top Models - {label}')
ax.grid(axis='x', linestyle='--', alpha=0.7)
# Add value labels
        for bar, value in zip(bars, sorted_models[metric]):
ax.text(value + value*0.01, bar.get_y() + bar.get_height()/2,
f'{value:.3f}', ha='left', va='center', fontsize=10)
plt.tight_layout()
    return fig
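
if __name__ == '__main__':
    # Smoke-test sketch with synthetic data: every number and model name here is
    # made up; only the column/key names used by the functions above are real.
    demo_metrics = {
        'model-a': {
            'fra_to_eng': {'bleu': 32.1, 'chrf': 0.58, 'quality_score': 0.45},
            'eng_to_fra': {'bleu': 28.4, 'chrf': 0.54, 'quality_score': 0.41},
            'averages': {'bleu': 30.3, 'chrf': 0.56, 'quality_score': 0.43},
        },
        'model-b': {
            'fra_to_eng': {'bleu': 27.0, 'chrf': 0.52, 'quality_score': 0.39},
            'eng_to_fra': {'bleu': 24.5, 'chrf': 0.49, 'quality_score': 0.36},
            'averages': {'bleu': 25.8, 'chrf': 0.51, 'quality_score': 0.38},
        },
    }
    demo_board = pd.DataFrame({
        'model_display_name': ['model-a', 'model-b'],
        'bleu': [30.3, 25.8],
        'chrf': [0.56, 0.51],
        'quality_score': [0.43, 0.38],
    })
    create_leaderboard_plot(demo_board).savefig('demo_leaderboard.png')
    create_detailed_comparison_plot(demo_metrics, list(demo_metrics)).savefig('demo_comparison.png')
    create_summary_metrics_plot(demo_board).savefig('demo_summary.png')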