akera committed on
Commit 944a871 · verified · 1 Parent(s): 423834f

Rename src/plotting.py to src/leaderboard.py

Files changed (2)
  1. src/leaderboard.py +381 -0
  2. src/plotting.py +0 -296
src/leaderboard.py ADDED
@@ -0,0 +1,381 @@
+ # src/leaderboard.py
+ import pandas as pd
+ from datasets import Dataset, load_dataset
+ import json
+ import datetime
+ from typing import Dict, List, Optional, Tuple
+ import os
+ from config import LEADERBOARD_DATASET, HF_TOKEN, ALL_UG40_LANGUAGES, GOOGLE_SUPPORTED_LANGUAGES
+ from src.utils import create_submission_id, sanitize_model_name, get_all_language_pairs, get_google_comparable_pairs
+
+ def initialize_leaderboard() -> pd.DataFrame:
+     """Initialize empty leaderboard DataFrame."""
+
+     columns = {
+         'submission_id': [],
+         'model_name': [],
+         'author': [],
+         'submission_date': [],
+         'model_type': [],
+         'description': [],
+
+         # Primary metrics
+         'quality_score': [],
+         'bleu': [],
+         'chrf': [],
+
+         # Secondary metrics
+         'rouge1': [],
+         'rouge2': [],
+         'rougeL': [],
+         'cer': [],
+         'wer': [],
+         'len_ratio': [],
+
+         # Google comparable metrics
+         'google_quality_score': [],
+         'google_bleu': [],
+         'google_chrf': [],
+
+         # Coverage info
+         'total_samples': [],
+         'language_pairs_covered': [],
+         'google_pairs_covered': [],
+         'coverage_rate': [],
+
+         # Detailed results
+         'detailed_metrics': [],  # JSON string
+         'validation_report': [],
+
+         # Metadata
+         'evaluation_date': [],
+         'leaderboard_version': []
+     }
+
+     return pd.DataFrame(columns)
+
+ def load_leaderboard() -> pd.DataFrame:
+     """Load current leaderboard from HuggingFace dataset."""
+
+     try:
+         print("Loading leaderboard...")
+         dataset = load_dataset(LEADERBOARD_DATASET, split='train')
+         df = dataset.to_pandas()
+
+         # Ensure all required columns exist
+         required_columns = list(initialize_leaderboard().columns)
+         for col in required_columns:
+             if col not in df.columns:
+                 if col in ['quality_score', 'bleu', 'chrf', 'rouge1', 'rouge2', 'rougeL',
+                            'cer', 'wer', 'len_ratio', 'google_quality_score', 'google_bleu',
+                            'google_chrf', 'total_samples', 'language_pairs_covered',
+                            'google_pairs_covered', 'coverage_rate']:
+                     df[col] = 0.0
+                 elif col in ['leaderboard_version']:
+                     df[col] = 1
+                 else:
+                     df[col] = ''
+
+         print(f"Loaded leaderboard with {len(df)} entries")
+         return df
+
+     except Exception as e:
+         print(f"Could not load leaderboard: {e}")
+         print("Initializing empty leaderboard...")
+         return initialize_leaderboard()
+
+ def save_leaderboard(df: pd.DataFrame) -> bool:
+     """Save leaderboard to HuggingFace dataset."""
+
+     try:
+         # Clean data before saving
+         df_clean = df.copy()
+
+         # Ensure numeric columns are proper types
+         numeric_columns = ['quality_score', 'bleu', 'chrf', 'rouge1', 'rouge2', 'rougeL',
+                            'cer', 'wer', 'len_ratio', 'google_quality_score', 'google_bleu',
+                            'google_chrf', 'total_samples', 'language_pairs_covered',
+                            'google_pairs_covered', 'coverage_rate', 'leaderboard_version']
+
+         for col in numeric_columns:
+             if col in df_clean.columns:
+                 df_clean[col] = pd.to_numeric(df_clean[col], errors='coerce').fillna(0.0)
+
+         # Convert to dataset
+         dataset = Dataset.from_pandas(df_clean)
+
+         # Push to hub
+         dataset.push_to_hub(
+             LEADERBOARD_DATASET,
+             token=HF_TOKEN,
+             commit_message=f"Update leaderboard - {datetime.datetime.now().isoformat()[:19]}"
+         )
+
+         print("Leaderboard saved successfully!")
+         return True
+
+     except Exception as e:
+         print(f"Error saving leaderboard: {e}")
+         return False
+
+ def add_model_to_leaderboard(
+     model_name: str,
+     author: str,
+     evaluation_results: Dict,
+     validation_info: Dict,
+     model_type: str = "",
+     description: str = ""
+ ) -> pd.DataFrame:
+     """Add new model results to leaderboard."""
+
+     # Load current leaderboard
+     df = load_leaderboard()
+
+     # Check if model already exists
+     existing_mask = df['model_name'] == model_name
+     if existing_mask.any():
+         print(f"Model '{model_name}' already exists. Updating...")
+         df = df[~existing_mask]  # Remove existing entry
+
+     # Extract metrics
+     averages = evaluation_results.get('averages', {})
+     google_averages = evaluation_results.get('google_comparable_averages', {})
+     summary = evaluation_results.get('summary', {})
+
+     # Create new entry
+     new_entry = {
+         'submission_id': create_submission_id(),
+         'model_name': sanitize_model_name(model_name),
+         'author': author[:100] if author else 'Anonymous',
+         'submission_date': datetime.datetime.now().isoformat(),
+         'model_type': model_type[:50] if model_type else 'unknown',
+         'description': description[:500] if description else '',
+
+         # Primary metrics
+         'quality_score': float(averages.get('quality_score', 0.0)),
+         'bleu': float(averages.get('bleu', 0.0)),
+         'chrf': float(averages.get('chrf', 0.0)),
+
+         # Secondary metrics
+         'rouge1': float(averages.get('rouge1', 0.0)),
+         'rouge2': float(averages.get('rouge2', 0.0)),
+         'rougeL': float(averages.get('rougeL', 0.0)),
+         'cer': float(averages.get('cer', 0.0)),
+         'wer': float(averages.get('wer', 0.0)),
+         'len_ratio': float(averages.get('len_ratio', 0.0)),
+
+         # Google comparable metrics
+         'google_quality_score': float(google_averages.get('quality_score', 0.0)),
+         'google_bleu': float(google_averages.get('bleu', 0.0)),
+         'google_chrf': float(google_averages.get('chrf', 0.0)),
+
+         # Coverage info
+         'total_samples': int(summary.get('total_samples', 0)),
+         'language_pairs_covered': int(summary.get('language_pairs_covered', 0)),
+         'google_pairs_covered': int(summary.get('google_comparable_pairs', 0)),
+         'coverage_rate': float(validation_info.get('coverage', 0.0)),
+
+         # Detailed results
+         'detailed_metrics': json.dumps(evaluation_results),
+         'validation_report': validation_info.get('report', ''),
+
+         # Metadata
+         'evaluation_date': datetime.datetime.now().isoformat(),
+         'leaderboard_version': 1
+     }
+
+     # Add to dataframe
+     new_row_df = pd.DataFrame([new_entry])
+     updated_df = pd.concat([df, new_row_df], ignore_index=True)
+
+     # Sort by quality score (descending)
+     updated_df = updated_df.sort_values('quality_score', ascending=False).reset_index(drop=True)
+
+     # Save updated leaderboard
+     if save_leaderboard(updated_df):
+         print(f"Added '{model_name}' to leaderboard")
+         return updated_df
+     else:
+         print("Failed to save leaderboard")
+         return df
+
+ def get_leaderboard_stats(df: pd.DataFrame) -> Dict:
+     """Get summary statistics for the leaderboard."""
+
+     if df.empty:
+         return {
+             'total_models': 0,
+             'avg_quality_score': 0.0,
+             'best_model': None,
+             'latest_submission': None,
+             'google_comparable_models': 0,
+             'coverage_distribution': {},
+             'language_pair_coverage': {}
+         }
+
+     # Basic stats
+     stats = {
+         'total_models': len(df),
+         'avg_quality_score': float(df['quality_score'].mean()),
+         'best_model': {
+             'name': df.iloc[0]['model_name'],
+             'score': float(df.iloc[0]['quality_score']),
+             'author': df.iloc[0]['author']
+         } if len(df) > 0 else None,
+         'latest_submission': df['submission_date'].max() if len(df) > 0 else None
+     }
+
+     # Google comparable models
+     stats['google_comparable_models'] = int((df['google_pairs_covered'] > 0).sum())
+
+     # Coverage distribution
+     coverage_bins = pd.cut(df['coverage_rate'], bins=[0, 0.5, 0.8, 0.95, 1.0],
+                            labels=['<50%', '50-80%', '80-95%', '95-100%'])
+     stats['coverage_distribution'] = coverage_bins.value_counts().to_dict()
+
+     # Language pair coverage
+     if len(df) > 0:
+         stats['avg_pairs_covered'] = float(df['language_pairs_covered'].mean())
+         stats['max_pairs_covered'] = int(df['language_pairs_covered'].max())
+         stats['total_possible_pairs'] = len(get_all_language_pairs())
+
+     return stats
+
+ def filter_leaderboard(
+     df: pd.DataFrame,
+     search_query: str = "",
+     model_type: str = "",
+     min_coverage: float = 0.0,
+     google_comparable_only: bool = False,
+     top_n: int = None
+ ) -> pd.DataFrame:
+     """Filter leaderboard based on various criteria."""
+
+     filtered_df = df.copy()
+
+     # Text search
+     if search_query:
+         query_lower = search_query.lower()
+         mask = (
+             filtered_df['model_name'].str.lower().str.contains(query_lower, na=False) |
+             filtered_df['author'].str.lower().str.contains(query_lower, na=False) |
+             filtered_df['description'].str.lower().str.contains(query_lower, na=False)
+         )
+         filtered_df = filtered_df[mask]
+
+     # Model type filter
+     if model_type and model_type != "all":
+         filtered_df = filtered_df[filtered_df['model_type'] == model_type]
+
+     # Coverage filter
+     if min_coverage > 0:
+         filtered_df = filtered_df[filtered_df['coverage_rate'] >= min_coverage]
+
+     # Google comparable filter
+     if google_comparable_only:
+         filtered_df = filtered_df[filtered_df['google_pairs_covered'] > 0]
+
+     # Top N filter
+     if top_n:
+         filtered_df = filtered_df.head(top_n)
+
+     return filtered_df
+
+ def get_model_comparison(df: pd.DataFrame, model_names: List[str]) -> Dict:
+     """Get detailed comparison between specific models."""
+
+     models = df[df['model_name'].isin(model_names)]
+
+     if len(models) == 0:
+         return {'error': 'No models found'}
+
+     comparison = {
+         'models': [],
+         'metrics_comparison': {},
+         'detailed_results': {}
+     }
+
+     # Extract basic info for each model
+     for _, model in models.iterrows():
+         comparison['models'].append({
+             'name': model['model_name'],
+             'author': model['author'],
+             'submission_date': model['submission_date'],
+             'model_type': model['model_type']
+         })
+
+         # Parse detailed metrics if available
+         try:
+             detailed = json.loads(model['detailed_metrics'])
+             comparison['detailed_results'][model['model_name']] = detailed
+         except:
+             comparison['detailed_results'][model['model_name']] = {}
+
+     # Compare metrics
+     metrics = ['quality_score', 'bleu', 'chrf', 'rouge1', 'rougeL', 'cer', 'wer']
+     for metric in metrics:
+         if metric in models.columns:
+             comparison['metrics_comparison'][metric] = {
+                 model_name: float(score)
+                 for model_name, score in zip(models['model_name'], models[metric])
+             }
+
+     return comparison
+
+ def export_leaderboard(df: pd.DataFrame, format: str = 'csv', include_detailed: bool = False) -> str:
+     """Export leaderboard in specified format."""
+
+     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+
+     # Select columns for export
+     if include_detailed:
+         export_df = df.copy()
+     else:
+         basic_columns = [
+             'model_name', 'author', 'submission_date', 'model_type',
+             'quality_score', 'bleu', 'chrf', 'rouge1', 'rougeL',
+             'total_samples', 'language_pairs_covered', 'coverage_rate'
+         ]
+         export_df = df[basic_columns].copy()
+
+     if format == 'csv':
+         filename = f"salt_leaderboard_{timestamp}.csv"
+         export_df.to_csv(filename, index=False)
+     elif format == 'json':
+         filename = f"salt_leaderboard_{timestamp}.json"
+         export_df.to_json(filename, orient='records', indent=2)
+     elif format == 'xlsx':
+         filename = f"salt_leaderboard_{timestamp}.xlsx"
+         export_df.to_excel(filename, index=False)
+     else:
+         raise ValueError(f"Unsupported format: {format}")
+
+     return filename
+
+ def get_ranking_history(df: pd.DataFrame, model_name: str) -> Dict:
+     """Get ranking history for a specific model (if multiple submissions)."""
+
+     model_entries = df[df['model_name'] == model_name].sort_values('submission_date')
+
+     if len(model_entries) == 0:
+         return {'error': 'Model not found'}
+
+     history = []
+     for _, entry in model_entries.iterrows():
+         # Calculate rank at time of submission
+         submission_date = entry['submission_date']
+         historical_df = df[df['submission_date'] <= submission_date]
+         rank = (historical_df['quality_score'] > entry['quality_score']).sum() + 1
+
+         history.append({
+             'submission_date': submission_date,
+             'quality_score': float(entry['quality_score']),
+             'rank': int(rank),
+             'total_models': len(historical_df)
+         })
+
+     return {
+         'model_name': model_name,
+         'history': history,
+         'current_rank': history[-1]['rank'] if history else None
+     }
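For context, a minimal usage sketch of the new module (not part of this commit). The model name, metric values, and dict contents below are illustrative placeholders; the only assumptions are that config.py provides LEADERBOARD_DATASET and HF_TOKEN, and that evaluation_results / validation_info follow the structure read by add_model_to_leaderboard above. Note that add_model_to_leaderboard pushes the updated dataset to the Hub, so it needs a token with write access.

# usage_sketch.py (illustrative only, not committed)
from src.leaderboard import (
    add_model_to_leaderboard,
    load_leaderboard,
    filter_leaderboard,
    get_leaderboard_stats,
)

# Placeholder evaluation output in the shape expected by add_model_to_leaderboard.
evaluation_results = {
    "averages": {"quality_score": 0.42, "bleu": 21.3, "chrf": 0.48},
    "google_comparable_averages": {"quality_score": 0.45, "bleu": 23.1, "chrf": 0.50},
    "summary": {"total_samples": 5000, "language_pairs_covered": 12,
                "google_comparable_pairs": 6},
}
validation_info = {"coverage": 0.85, "report": "All checks passed."}

# Add (or update) an entry; this loads, re-sorts, and pushes the leaderboard dataset.
updated = add_model_to_leaderboard(
    model_name="example-org/nllb-ug40-finetune",  # hypothetical model name
    author="Example Author",
    evaluation_results=evaluation_results,
    validation_info=validation_info,
    model_type="nllb",
    description="Illustrative entry for the usage sketch.",
)

# Query the leaderboard: top 10 entries that cover Google-comparable pairs.
top10 = filter_leaderboard(load_leaderboard(), google_comparable_only=True, top_n=10)
print(get_leaderboard_stats(top10))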
src/plotting.py DELETED
@@ -1,296 +0,0 @@
- # src/plotting.py
- import matplotlib.pyplot as plt
- import matplotlib.gridspec as gridspec
- import matplotlib.colors as mcolors
- from colorsys import rgb_to_hls, hls_to_rgb
- from collections import defaultdict
- import numpy as np
- import pandas as pd
- from config import LANGUAGE_NAMES
-
- def create_leaderboard_plot(leaderboard_df: pd.DataFrame, metric: str = 'quality_score') -> plt.Figure:
-     """Create a horizontal bar chart showing model rankings."""
-
-     fig, ax = plt.subplots(figsize=(12, 8))
-
-     # Sort by the selected metric (descending)
-     df_sorted = leaderboard_df.sort_values(metric, ascending=True)
-
-     # Create color palette
-     colors = plt.cm.viridis(np.linspace(0, 1, len(df_sorted)))
-
-     # Create horizontal bar chart
-     bars = ax.barh(range(len(df_sorted)), df_sorted[metric], color=colors)
-
-     # Customize the plot
-     ax.set_yticks(range(len(df_sorted)))
-     ax.set_yticklabels(df_sorted['model_display_name'])
-     ax.set_xlabel(f'{metric.replace("_", " ").title()} Score')
-     ax.set_title(f'Model Leaderboard - {metric.replace("_", " ").title()}', fontsize=16, pad=20)
-
-     # Add value labels on bars
-     for i, (bar, value) in enumerate(zip(bars, df_sorted[metric])):
-         ax.text(value + 0.001, bar.get_y() + bar.get_height()/2,
-                 f'{value:.3f}', ha='left', va='center', fontweight='bold')
-
-     # Add grid for better readability
-     ax.grid(axis='x', linestyle='--', alpha=0.7)
-     ax.set_axisbelow(True)
-
-     # Set x-axis limits with some padding
-     max_val = df_sorted[metric].max()
-     ax.set_xlim(0, max_val * 1.15)
-
-     plt.tight_layout()
-     return fig
-
- def create_detailed_comparison_plot(metrics_data: dict, model_names: list) -> plt.Figure:
-     """Create detailed comparison plot similar to the original evaluation script."""
-
-     # Filter metrics_data to only include models in model_names
-     filtered_metrics = {name: metrics_data[name] for name in model_names if name in metrics_data}
-
-     if not filtered_metrics:
-         # Create empty plot if no data
-         fig, ax = plt.subplots(figsize=(10, 6))
-         ax.text(0.5, 0.5, 'No data available for comparison',
-                 ha='center', va='center', transform=ax.transAxes, fontsize=16)
-         ax.set_xlim(0, 1)
-         ax.set_ylim(0, 1)
-         ax.axis('off')
-         return fig
-
-     return plot_translation_metric_comparison(filtered_metrics, metric='bleu')
-
- def plot_translation_metric_comparison(metrics_by_model: dict, metric: str = 'bleu') -> plt.Figure:
-     """
-     Creates a grouped bar chart comparing a selected metric across translation models.
-     Adapted from the original plotting code.
-     """
-
-     # Split language pairs into xx_to_eng and eng_to_xx categories
-     first_model_data = list(metrics_by_model.values())[0]
-     xx_to_eng = [key for key in first_model_data.keys()
-                  if key.endswith('_to_eng') and key != 'averages']
-     eng_to_xx = [key for key in first_model_data.keys()
-                  if key.startswith('eng_to_') and key != 'averages']
-
-     # Function to create nice labels
-     def format_label(label):
-         if label.startswith("eng_to_"):
-             source, target = "English", label.replace("eng_to_", "")
-             target = LANGUAGE_NAMES.get(target, target)
-         else:
-             source, target = label.replace("_to_eng", ""), "English"
-             source = LANGUAGE_NAMES.get(source, source)
-         return f"{source}→{target}"
-
-     # Extract metric values for each category
-     def extract_metric_values(model_metrics, pairs, metric_name):
-         return [model_metrics.get(pair, {}).get(metric_name, 0.0) for pair in pairs]
-
-     xx_to_eng_data = {
-         model_name: extract_metric_values(model_data, xx_to_eng, metric)
-         for model_name, model_data in metrics_by_model.items()
-     }
-
-     eng_to_xx_data = {
-         model_name: extract_metric_values(model_data, eng_to_xx, metric)
-         for model_name, model_data in metrics_by_model.items()
-     }
-
-     averages_data = {
-         model_name: [model_data.get("averages", {}).get(metric, 0.0)]
-         for model_name, model_data in metrics_by_model.items()
-     }
-
-     # Set up plot with custom grid
-     fig = plt.figure(figsize=(18, 12))  # Increased height for better spacing
-
-     # Create a GridSpec with 1 row and 5 columns
-     gs = gridspec.GridSpec(1, 5)
-
-     # Colors for the models
-     model_names = list(metrics_by_model.keys())
-
-     family_base_colors = {
-         'gemma': '#3274A1',
-         'nllb': '#7f7f7f',
-         'qwen': '#E1812C',
-         'google': '#3A923A',
-         'other': '#D62728',
-     }
-
-     # Identify the family for each model
-     def get_family(model_name):
-         model_lower = model_name.lower()
-         if 'gemma' in model_lower:
-             return 'gemma'
-         elif 'qwen' in model_lower:
-             return 'qwen'
-         elif 'nllb' in model_lower:
-             return 'nllb'
-         elif 'google' in model_lower or model_name == 'google-translate':
-             return 'google'
-         else:
-             return 'other'
-
-     # Count how many models belong to each family
-     family_counts = defaultdict(int)
-     for model in model_names:
-         family = get_family(model)
-         family_counts[family] += 1
-
-     # Generate slightly varied lightness within each family
-     colors = []
-     family_indices = defaultdict(int)
-     for model in model_names:
-         family = get_family(model)
-         base_rgb = mcolors.to_rgb(family_base_colors[family])
-         h, l, s = rgb_to_hls(*base_rgb)
-
-         index = family_indices[family]
-         count = family_counts[family]
-
-         # Vary lightness: from 0.35 to 0.65
-         if count == 1:
-             new_l = l  # Keep original for single models
-         else:
-             new_l = 0.65 - 0.3 * (index / max(count - 1, 1))
-
-         varied_rgb = hls_to_rgb(h, new_l, s)
-         hex_color = mcolors.to_hex(varied_rgb)
-         colors.append(hex_color)
-         family_indices[family] += 1
-
-     bar_width = 0.2
-     opacity = 0.8
-
-     # Positions for the bars
-     xx_to_eng_indices = np.arange(len(xx_to_eng))
-     eng_to_xx_indices = np.arange(len(eng_to_xx))
-     avg_index = np.array([0])
-
-     # Determine y-axis limits based on metric
-     if metric in ['chrf', 'len_ratio']:
-         y_max = 1.1
-     elif metric in ['cer', 'wer']:
-         y_max = 1.0
-     elif metric == 'bleu':
-         y_max = 65  # Increased from 55 to accommodate high scores
-     elif metric in ['rouge1', 'rouge2', 'rougeL']:
-         y_max = 1.0
-     elif metric == 'quality_score':
-         y_max = 0.65
-     else:
-         # Auto-scale based on data
-         all_values = []
-         for data in [xx_to_eng_data, eng_to_xx_data, averages_data]:
-             for model_data in data.values():
-                 all_values.extend(model_data)
-         y_max = max(all_values) * 1.1 if all_values else 1.0
-
-     # Format metric name for display
-     metric_display = metric.upper() if metric in ['bleu', 'chrf', 'cer', 'wer'] else metric.replace('_', ' ').title()
-
-     # Create bars for xx_to_eng (using first 2 columns)
-     if xx_to_eng:
-         ax1 = plt.subplot(gs[0, 0:2])
-         for i, (model_name, color) in enumerate(zip(model_names, colors)):
-             if model_name in xx_to_eng_data:
-                 ax1.bar(xx_to_eng_indices + i*bar_width, xx_to_eng_data[model_name],
-                         bar_width, alpha=opacity, color=color, label=model_name)
-
-         ax1.set_xlabel('Translation Direction')
-         ax1.set_ylabel(f'{metric_display} Score')
-         ax1.set_title(f'XX→English {metric_display} Performance')
-         ax1.set_xticks(xx_to_eng_indices + bar_width)
-         ax1.set_xticklabels([format_label(label) for label in xx_to_eng], rotation=45, ha='right')
-         ax1.set_ylim(0, y_max)
-         ax1.grid(axis='y', linestyle='--', alpha=0.7)
-
-     # Create bars for eng_to_xx (using next 2 columns)
-     if eng_to_xx:
-         ax2 = plt.subplot(gs[0, 2:4])
-         for i, (model_name, color) in enumerate(zip(model_names, colors)):
-             if model_name in eng_to_xx_data:
-                 ax2.bar(eng_to_xx_indices + i*bar_width, eng_to_xx_data[model_name],
-                         bar_width, alpha=opacity, color=color, label=model_name)
-
-         ax2.set_xlabel('Translation Direction')
-         ax2.set_ylabel(f'{metric_display} Score')
-         ax2.set_title(f'English→XX {metric_display} Performance')
-         ax2.set_xticks(eng_to_xx_indices + bar_width)
-         ax2.set_xticklabels([format_label(label) for label in eng_to_xx], rotation=45, ha='right')
-         ax2.set_ylim(0, y_max)
-         ax2.grid(axis='y', linestyle='--', alpha=0.7)
-
-     # Create bars for averages (using last column)
-     ax3 = plt.subplot(gs[0, 4])
-     for i, (model_name, color) in enumerate(zip(model_names, colors)):
-         if model_name in averages_data:
-             ax3.bar(avg_index + i*bar_width, averages_data[model_name],
-                     bar_width, alpha=opacity, color=color, label=model_name)
-
-     ax3.set_xlabel('Overall')
-     ax3.set_ylabel(f'{metric_display} Score')
-     ax3.set_title(f'Average {metric_display}')
-     ax3.set_xticks(avg_index + bar_width)
-     ax3.set_xticklabels(['Average'])
-     ax3.set_ylim(0, y_max)
-     ax3.grid(axis='y', linestyle='--', alpha=0.7)
-     ax3.legend()
-
-     # Add note for metrics where lower is better
-     if metric in ['cer', 'wer']:
-         plt.figtext(0.5, 0.01, "Note: Lower values indicate better performance for this metric",
-                     ha='center', fontsize=12, style='italic')
-
-     # Add an overall title and adjust layout
-     model_list = ' vs '.join(model_names)
-     plt.suptitle(f'{metric_display} Score Comparison: {model_list}', fontsize=16, y=0.98)
-     plt.tight_layout(rect=[0, 0.02, 1, 0.95])
-
-     return fig
-
- def create_summary_metrics_plot(leaderboard_df: pd.DataFrame) -> plt.Figure:
-     """Create a summary plot showing multiple metrics for top models."""
-
-     if leaderboard_df.empty:
-         fig, ax = plt.subplots(figsize=(10, 6))
-         ax.text(0.5, 0.5, 'No data available', ha='center', va='center',
-                 transform=ax.transAxes, fontsize=16)
-         return fig
-
-     # Select top 5 models by quality score
-     top_models = leaderboard_df.nlargest(5, 'quality_score')
-
-     # Metrics to display
-     metrics = ['bleu', 'chrf', 'quality_score']
-     metric_labels = ['BLEU', 'ChrF', 'Quality Score']
-
-     fig, axes = plt.subplots(1, 3, figsize=(15, 6))
-
-     for i, (metric, label) in enumerate(zip(metrics, metric_labels)):
-         ax = axes[i]
-
-         # Sort by current metric
-         sorted_models = top_models.sort_values(metric, ascending=True)
-
-         # Create horizontal bar chart
-         bars = ax.barh(range(len(sorted_models)), sorted_models[metric],
-                        color=plt.cm.viridis(np.linspace(0, 1, len(sorted_models))))
-
-         ax.set_yticks(range(len(sorted_models)))
-         ax.set_yticklabels(sorted_models['model_display_name'])
-         ax.set_xlabel(f'{label} Score')
-         ax.set_title(f'Top Models - {label}')
-         ax.grid(axis='x', linestyle='--', alpha=0.7)
-
-         # Add value labels
-         for j, (bar, value) in enumerate(zip(bars, sorted_models[metric])):
-             ax.text(value + value*0.01, bar.get_y() + bar.get_height()/2,
-                     f'{value:.3f}', ha='left', va='center', fontsize=10)
-
-     plt.tight_layout()
-     return fig