Rename src/plotting.py to src/leaderboard.py
src/leaderboard.py +381 -0
src/plotting.py +0 -296
src/leaderboard.py
ADDED
@@ -0,0 +1,381 @@
# src/leaderboard.py
import pandas as pd
from datasets import Dataset, load_dataset
import json
import datetime
from typing import Dict, List, Optional, Tuple
import os
from config import LEADERBOARD_DATASET, HF_TOKEN, ALL_UG40_LANGUAGES, GOOGLE_SUPPORTED_LANGUAGES
from src.utils import create_submission_id, sanitize_model_name, get_all_language_pairs, get_google_comparable_pairs

def initialize_leaderboard() -> pd.DataFrame:
    """Initialize empty leaderboard DataFrame."""

    columns = {
        'submission_id': [],
        'model_name': [],
        'author': [],
        'submission_date': [],
        'model_type': [],
        'description': [],

        # Primary metrics
        'quality_score': [],
        'bleu': [],
        'chrf': [],

        # Secondary metrics
        'rouge1': [],
        'rouge2': [],
        'rougeL': [],
        'cer': [],
        'wer': [],
        'len_ratio': [],

        # Google comparable metrics
        'google_quality_score': [],
        'google_bleu': [],
        'google_chrf': [],

        # Coverage info
        'total_samples': [],
        'language_pairs_covered': [],
        'google_pairs_covered': [],
        'coverage_rate': [],

        # Detailed results
        'detailed_metrics': [],  # JSON string
        'validation_report': [],

        # Metadata
        'evaluation_date': [],
        'leaderboard_version': []
    }

    return pd.DataFrame(columns)

def load_leaderboard() -> pd.DataFrame:
    """Load current leaderboard from HuggingFace dataset."""

    try:
        print("Loading leaderboard...")
        dataset = load_dataset(LEADERBOARD_DATASET, split='train')
        df = dataset.to_pandas()

        # Ensure all required columns exist
        required_columns = list(initialize_leaderboard().columns)
        for col in required_columns:
            if col not in df.columns:
                if col in ['quality_score', 'bleu', 'chrf', 'rouge1', 'rouge2', 'rougeL',
                           'cer', 'wer', 'len_ratio', 'google_quality_score', 'google_bleu',
                           'google_chrf', 'total_samples', 'language_pairs_covered',
                           'google_pairs_covered', 'coverage_rate']:
                    df[col] = 0.0
                elif col in ['leaderboard_version']:
                    df[col] = 1
                else:
                    df[col] = ''

        print(f"Loaded leaderboard with {len(df)} entries")
        return df

    except Exception as e:
        print(f"Could not load leaderboard: {e}")
        print("Initializing empty leaderboard...")
        return initialize_leaderboard()

def save_leaderboard(df: pd.DataFrame) -> bool:
    """Save leaderboard to HuggingFace dataset."""

    try:
        # Clean data before saving
        df_clean = df.copy()

        # Ensure numeric columns are proper types
        numeric_columns = ['quality_score', 'bleu', 'chrf', 'rouge1', 'rouge2', 'rougeL',
                           'cer', 'wer', 'len_ratio', 'google_quality_score', 'google_bleu',
                           'google_chrf', 'total_samples', 'language_pairs_covered',
                           'google_pairs_covered', 'coverage_rate', 'leaderboard_version']

        for col in numeric_columns:
            if col in df_clean.columns:
                df_clean[col] = pd.to_numeric(df_clean[col], errors='coerce').fillna(0.0)

        # Convert to dataset
        dataset = Dataset.from_pandas(df_clean)

        # Push to hub
        dataset.push_to_hub(
            LEADERBOARD_DATASET,
            token=HF_TOKEN,
            commit_message=f"Update leaderboard - {datetime.datetime.now().isoformat()[:19]}"
        )

        print("Leaderboard saved successfully!")
        return True

    except Exception as e:
        print(f"Error saving leaderboard: {e}")
        return False

def add_model_to_leaderboard(
    model_name: str,
    author: str,
    evaluation_results: Dict,
    validation_info: Dict,
    model_type: str = "",
    description: str = ""
) -> pd.DataFrame:
    """Add new model results to leaderboard."""

    # Load current leaderboard
    df = load_leaderboard()

    # Check if model already exists
    existing_mask = df['model_name'] == model_name
    if existing_mask.any():
        print(f"Model '{model_name}' already exists. Updating...")
        df = df[~existing_mask]  # Remove existing entry

    # Extract metrics
    averages = evaluation_results.get('averages', {})
    google_averages = evaluation_results.get('google_comparable_averages', {})
    summary = evaluation_results.get('summary', {})

    # Create new entry
    new_entry = {
        'submission_id': create_submission_id(),
        'model_name': sanitize_model_name(model_name),
        'author': author[:100] if author else 'Anonymous',
        'submission_date': datetime.datetime.now().isoformat(),
        'model_type': model_type[:50] if model_type else 'unknown',
        'description': description[:500] if description else '',

        # Primary metrics
        'quality_score': float(averages.get('quality_score', 0.0)),
        'bleu': float(averages.get('bleu', 0.0)),
        'chrf': float(averages.get('chrf', 0.0)),

        # Secondary metrics
        'rouge1': float(averages.get('rouge1', 0.0)),
        'rouge2': float(averages.get('rouge2', 0.0)),
        'rougeL': float(averages.get('rougeL', 0.0)),
        'cer': float(averages.get('cer', 0.0)),
        'wer': float(averages.get('wer', 0.0)),
        'len_ratio': float(averages.get('len_ratio', 0.0)),

        # Google comparable metrics
        'google_quality_score': float(google_averages.get('quality_score', 0.0)),
        'google_bleu': float(google_averages.get('bleu', 0.0)),
        'google_chrf': float(google_averages.get('chrf', 0.0)),

        # Coverage info
        'total_samples': int(summary.get('total_samples', 0)),
        'language_pairs_covered': int(summary.get('language_pairs_covered', 0)),
        'google_pairs_covered': int(summary.get('google_comparable_pairs', 0)),
        'coverage_rate': float(validation_info.get('coverage', 0.0)),

        # Detailed results
        'detailed_metrics': json.dumps(evaluation_results),
        'validation_report': validation_info.get('report', ''),

        # Metadata
        'evaluation_date': datetime.datetime.now().isoformat(),
        'leaderboard_version': 1
    }

    # Add to dataframe
    new_row_df = pd.DataFrame([new_entry])
    updated_df = pd.concat([df, new_row_df], ignore_index=True)

    # Sort by quality score (descending)
    updated_df = updated_df.sort_values('quality_score', ascending=False).reset_index(drop=True)

    # Save updated leaderboard
    if save_leaderboard(updated_df):
        print(f"Added '{model_name}' to leaderboard")
        return updated_df
    else:
        print("Failed to save leaderboard")
        return df

def get_leaderboard_stats(df: pd.DataFrame) -> Dict:
    """Get summary statistics for the leaderboard."""

    if df.empty:
        return {
            'total_models': 0,
            'avg_quality_score': 0.0,
            'best_model': None,
            'latest_submission': None,
            'google_comparable_models': 0,
            'coverage_distribution': {},
            'language_pair_coverage': {}
        }

    # Basic stats
    stats = {
        'total_models': len(df),
        'avg_quality_score': float(df['quality_score'].mean()),
        'best_model': {
            'name': df.iloc[0]['model_name'],
            'score': float(df.iloc[0]['quality_score']),
            'author': df.iloc[0]['author']
        } if len(df) > 0 else None,
        'latest_submission': df['submission_date'].max() if len(df) > 0 else None
    }

    # Google comparable models
    stats['google_comparable_models'] = int((df['google_pairs_covered'] > 0).sum())

    # Coverage distribution
    coverage_bins = pd.cut(df['coverage_rate'], bins=[0, 0.5, 0.8, 0.95, 1.0],
                           labels=['<50%', '50-80%', '80-95%', '95-100%'])
    stats['coverage_distribution'] = coverage_bins.value_counts().to_dict()

    # Language pair coverage
    if len(df) > 0:
        stats['avg_pairs_covered'] = float(df['language_pairs_covered'].mean())
        stats['max_pairs_covered'] = int(df['language_pairs_covered'].max())
        stats['total_possible_pairs'] = len(get_all_language_pairs())

    return stats

def filter_leaderboard(
    df: pd.DataFrame,
    search_query: str = "",
    model_type: str = "",
    min_coverage: float = 0.0,
    google_comparable_only: bool = False,
    top_n: int = None
) -> pd.DataFrame:
    """Filter leaderboard based on various criteria."""

    filtered_df = df.copy()

    # Text search
    if search_query:
        query_lower = search_query.lower()
        mask = (
            filtered_df['model_name'].str.lower().str.contains(query_lower, na=False) |
            filtered_df['author'].str.lower().str.contains(query_lower, na=False) |
            filtered_df['description'].str.lower().str.contains(query_lower, na=False)
        )
        filtered_df = filtered_df[mask]

    # Model type filter
    if model_type and model_type != "all":
        filtered_df = filtered_df[filtered_df['model_type'] == model_type]

    # Coverage filter
    if min_coverage > 0:
        filtered_df = filtered_df[filtered_df['coverage_rate'] >= min_coverage]

    # Google comparable filter
    if google_comparable_only:
        filtered_df = filtered_df[filtered_df['google_pairs_covered'] > 0]

    # Top N filter
    if top_n:
        filtered_df = filtered_df.head(top_n)

    return filtered_df

def get_model_comparison(df: pd.DataFrame, model_names: List[str]) -> Dict:
    """Get detailed comparison between specific models."""

    models = df[df['model_name'].isin(model_names)]

    if len(models) == 0:
        return {'error': 'No models found'}

    comparison = {
        'models': [],
        'metrics_comparison': {},
        'detailed_results': {}
    }

    # Extract basic info for each model
    for _, model in models.iterrows():
        comparison['models'].append({
            'name': model['model_name'],
            'author': model['author'],
            'submission_date': model['submission_date'],
            'model_type': model['model_type']
        })

        # Parse detailed metrics if available
        try:
            detailed = json.loads(model['detailed_metrics'])
            comparison['detailed_results'][model['model_name']] = detailed
        except:
            comparison['detailed_results'][model['model_name']] = {}

    # Compare metrics
    metrics = ['quality_score', 'bleu', 'chrf', 'rouge1', 'rougeL', 'cer', 'wer']
    for metric in metrics:
        if metric in models.columns:
            comparison['metrics_comparison'][metric] = {
                model_name: float(score)
                for model_name, score in zip(models['model_name'], models[metric])
            }

    return comparison

def export_leaderboard(df: pd.DataFrame, format: str = 'csv', include_detailed: bool = False) -> str:
    """Export leaderboard in specified format."""

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    # Select columns for export
    if include_detailed:
        export_df = df.copy()
    else:
        basic_columns = [
            'model_name', 'author', 'submission_date', 'model_type',
            'quality_score', 'bleu', 'chrf', 'rouge1', 'rougeL',
            'total_samples', 'language_pairs_covered', 'coverage_rate'
        ]
        export_df = df[basic_columns].copy()

    if format == 'csv':
        filename = f"salt_leaderboard_{timestamp}.csv"
        export_df.to_csv(filename, index=False)
    elif format == 'json':
        filename = f"salt_leaderboard_{timestamp}.json"
        export_df.to_json(filename, orient='records', indent=2)
    elif format == 'xlsx':
        filename = f"salt_leaderboard_{timestamp}.xlsx"
        export_df.to_excel(filename, index=False)
    else:
        raise ValueError(f"Unsupported format: {format}")

    return filename

def get_ranking_history(df: pd.DataFrame, model_name: str) -> Dict:
    """Get ranking history for a specific model (if multiple submissions)."""

    model_entries = df[df['model_name'] == model_name].sort_values('submission_date')

    if len(model_entries) == 0:
        return {'error': 'Model not found'}

    history = []
    for _, entry in model_entries.iterrows():
        # Calculate rank at time of submission
        submission_date = entry['submission_date']
        historical_df = df[df['submission_date'] <= submission_date]
        rank = (historical_df['quality_score'] > entry['quality_score']).sum() + 1

        history.append({
            'submission_date': submission_date,
            'quality_score': float(entry['quality_score']),
            'rank': int(rank),
            'total_models': len(historical_df)
        })

    return {
        'model_name': model_name,
        'history': history,
        'current_rank': history[-1]['rank'] if history else None
    }
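For context, here is a minimal usage sketch of the new module, not part of the diff itself. It assumes `config` defines `LEADERBOARD_DATASET` and `HF_TOKEN`, that `src.utils` provides the helpers imported above, and that the `evaluation_results` / `validation_info` dictionaries follow the shapes the code reads (`averages`, `google_comparable_averages`, `summary`, `coverage`, `report`); all concrete names and values below are hypothetical.

# Hypothetical usage sketch; values are illustrative only.
# Note: add_model_to_leaderboard() loads the current leaderboard and pushes
# the updated dataset to LEADERBOARD_DATASET using HF_TOKEN.
from src.leaderboard import (
    add_model_to_leaderboard, filter_leaderboard,
    get_leaderboard_stats, export_leaderboard,
)

evaluation_results = {
    'averages': {'quality_score': 0.41, 'bleu': 22.3, 'chrf': 0.48},
    'google_comparable_averages': {'quality_score': 0.39, 'bleu': 21.0, 'chrf': 0.46},
    'summary': {'total_samples': 5000, 'language_pairs_covered': 12,
                'google_comparable_pairs': 8},
}
validation_info = {'coverage': 0.92, 'report': 'All required pairs present.'}

# Add (or replace) a submission, then inspect the updated table.
df = add_model_to_leaderboard(
    model_name='example-mt-model',
    author='Example Author',
    evaluation_results=evaluation_results,
    validation_info=validation_info,
    model_type='seq2seq',
    description='Illustrative submission.',
)
print(get_leaderboard_stats(df)['total_models'])

# Narrow the view and export a snapshot to CSV.
top = filter_leaderboard(df, min_coverage=0.8, top_n=10)
print(export_leaderboard(top, format='csv'))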
src/plotting.py
DELETED
@@ -1,296 +0,0 @@
# src/plotting.py
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.colors as mcolors
from colorsys import rgb_to_hls, hls_to_rgb
from collections import defaultdict
import numpy as np
import pandas as pd
from config import LANGUAGE_NAMES

def create_leaderboard_plot(leaderboard_df: pd.DataFrame, metric: str = 'quality_score') -> plt.Figure:
    """Create a horizontal bar chart showing model rankings."""

    fig, ax = plt.subplots(figsize=(12, 8))

    # Sort by the selected metric (descending)
    df_sorted = leaderboard_df.sort_values(metric, ascending=True)

    # Create color palette
    colors = plt.cm.viridis(np.linspace(0, 1, len(df_sorted)))

    # Create horizontal bar chart
    bars = ax.barh(range(len(df_sorted)), df_sorted[metric], color=colors)

    # Customize the plot
    ax.set_yticks(range(len(df_sorted)))
    ax.set_yticklabels(df_sorted['model_display_name'])
    ax.set_xlabel(f'{metric.replace("_", " ").title()} Score')
    ax.set_title(f'Model Leaderboard - {metric.replace("_", " ").title()}', fontsize=16, pad=20)

    # Add value labels on bars
    for i, (bar, value) in enumerate(zip(bars, df_sorted[metric])):
        ax.text(value + 0.001, bar.get_y() + bar.get_height()/2,
                f'{value:.3f}', ha='left', va='center', fontweight='bold')

    # Add grid for better readability
    ax.grid(axis='x', linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)

    # Set x-axis limits with some padding
    max_val = df_sorted[metric].max()
    ax.set_xlim(0, max_val * 1.15)

    plt.tight_layout()
    return fig

def create_detailed_comparison_plot(metrics_data: dict, model_names: list) -> plt.Figure:
    """Create detailed comparison plot similar to the original evaluation script."""

    # Filter metrics_data to only include models in model_names
    filtered_metrics = {name: metrics_data[name] for name in model_names if name in metrics_data}

    if not filtered_metrics:
        # Create empty plot if no data
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, 'No data available for comparison',
                ha='center', va='center', transform=ax.transAxes, fontsize=16)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
        return fig

    return plot_translation_metric_comparison(filtered_metrics, metric='bleu')

def plot_translation_metric_comparison(metrics_by_model: dict, metric: str = 'bleu') -> plt.Figure:
    """
    Creates a grouped bar chart comparing a selected metric across translation models.
    Adapted from the original plotting code.
    """

    # Split language pairs into xx_to_eng and eng_to_xx categories
    first_model_data = list(metrics_by_model.values())[0]
    xx_to_eng = [key for key in first_model_data.keys()
                 if key.endswith('_to_eng') and key != 'averages']
    eng_to_xx = [key for key in first_model_data.keys()
                 if key.startswith('eng_to_') and key != 'averages']

    # Function to create nice labels
    def format_label(label):
        if label.startswith("eng_to_"):
            source, target = "English", label.replace("eng_to_", "")
            target = LANGUAGE_NAMES.get(target, target)
        else:
            source, target = label.replace("_to_eng", ""), "English"
            source = LANGUAGE_NAMES.get(source, source)
        return f"{source}→{target}"

    # Extract metric values for each category
    def extract_metric_values(model_metrics, pairs, metric_name):
        return [model_metrics.get(pair, {}).get(metric_name, 0.0) for pair in pairs]

    xx_to_eng_data = {
        model_name: extract_metric_values(model_data, xx_to_eng, metric)
        for model_name, model_data in metrics_by_model.items()
    }

    eng_to_xx_data = {
        model_name: extract_metric_values(model_data, eng_to_xx, metric)
        for model_name, model_data in metrics_by_model.items()
    }

    averages_data = {
        model_name: [model_data.get("averages", {}).get(metric, 0.0)]
        for model_name, model_data in metrics_by_model.items()
    }

    # Set up plot with custom grid
    fig = plt.figure(figsize=(18, 12))  # Increased height for better spacing

    # Create a GridSpec with 1 row and 5 columns
    gs = gridspec.GridSpec(1, 5)

    # Colors for the models
    model_names = list(metrics_by_model.keys())

    family_base_colors = {
        'gemma': '#3274A1',
        'nllb': '#7f7f7f',
        'qwen': '#E1812C',
        'google': '#3A923A',
        'other': '#D62728',
    }

    # Identify the family for each model
    def get_family(model_name):
        model_lower = model_name.lower()
        if 'gemma' in model_lower:
            return 'gemma'
        elif 'qwen' in model_lower:
            return 'qwen'
        elif 'nllb' in model_lower:
            return 'nllb'
        elif 'google' in model_lower or model_name == 'google-translate':
            return 'google'
        else:
            return 'other'

    # Count how many models belong to each family
    family_counts = defaultdict(int)
    for model in model_names:
        family = get_family(model)
        family_counts[family] += 1

    # Generate slightly varied lightness within each family
    colors = []
    family_indices = defaultdict(int)
    for model in model_names:
        family = get_family(model)
        base_rgb = mcolors.to_rgb(family_base_colors[family])
        h, l, s = rgb_to_hls(*base_rgb)

        index = family_indices[family]
        count = family_counts[family]

        # Vary lightness: from 0.35 to 0.65
        if count == 1:
            new_l = l  # Keep original for single models
        else:
            new_l = 0.65 - 0.3 * (index / max(count - 1, 1))

        varied_rgb = hls_to_rgb(h, new_l, s)
        hex_color = mcolors.to_hex(varied_rgb)
        colors.append(hex_color)
        family_indices[family] += 1

    bar_width = 0.2
    opacity = 0.8

    # Positions for the bars
    xx_to_eng_indices = np.arange(len(xx_to_eng))
    eng_to_xx_indices = np.arange(len(eng_to_xx))
    avg_index = np.array([0])

    # Determine y-axis limits based on metric
    if metric in ['chrf', 'len_ratio']:
        y_max = 1.1
    elif metric in ['cer', 'wer']:
        y_max = 1.0
    elif metric == 'bleu':
        y_max = 65  # Increased from 55 to accommodate high scores
    elif metric in ['rouge1', 'rouge2', 'rougeL']:
        y_max = 1.0
    elif metric == 'quality_score':
        y_max = 0.65
    else:
        # Auto-scale based on data
        all_values = []
        for data in [xx_to_eng_data, eng_to_xx_data, averages_data]:
            for model_data in data.values():
                all_values.extend(model_data)
        y_max = max(all_values) * 1.1 if all_values else 1.0

    # Format metric name for display
    metric_display = metric.upper() if metric in ['bleu', 'chrf', 'cer', 'wer'] else metric.replace('_', ' ').title()

    # Create bars for xx_to_eng (using first 2 columns)
    if xx_to_eng:
        ax1 = plt.subplot(gs[0, 0:2])
        for i, (model_name, color) in enumerate(zip(model_names, colors)):
            if model_name in xx_to_eng_data:
                ax1.bar(xx_to_eng_indices + i*bar_width, xx_to_eng_data[model_name],
                        bar_width, alpha=opacity, color=color, label=model_name)

        ax1.set_xlabel('Translation Direction')
        ax1.set_ylabel(f'{metric_display} Score')
        ax1.set_title(f'XX→English {metric_display} Performance')
        ax1.set_xticks(xx_to_eng_indices + bar_width)
        ax1.set_xticklabels([format_label(label) for label in xx_to_eng], rotation=45, ha='right')
        ax1.set_ylim(0, y_max)
        ax1.grid(axis='y', linestyle='--', alpha=0.7)

    # Create bars for eng_to_xx (using next 2 columns)
    if eng_to_xx:
        ax2 = plt.subplot(gs[0, 2:4])
        for i, (model_name, color) in enumerate(zip(model_names, colors)):
            if model_name in eng_to_xx_data:
                ax2.bar(eng_to_xx_indices + i*bar_width, eng_to_xx_data[model_name],
                        bar_width, alpha=opacity, color=color, label=model_name)

        ax2.set_xlabel('Translation Direction')
        ax2.set_ylabel(f'{metric_display} Score')
        ax2.set_title(f'English→XX {metric_display} Performance')
        ax2.set_xticks(eng_to_xx_indices + bar_width)
        ax2.set_xticklabels([format_label(label) for label in eng_to_xx], rotation=45, ha='right')
        ax2.set_ylim(0, y_max)
        ax2.grid(axis='y', linestyle='--', alpha=0.7)

    # Create bars for averages (using last column)
    ax3 = plt.subplot(gs[0, 4])
    for i, (model_name, color) in enumerate(zip(model_names, colors)):
        if model_name in averages_data:
            ax3.bar(avg_index + i*bar_width, averages_data[model_name],
                    bar_width, alpha=opacity, color=color, label=model_name)

    ax3.set_xlabel('Overall')
    ax3.set_ylabel(f'{metric_display} Score')
    ax3.set_title(f'Average {metric_display}')
    ax3.set_xticks(avg_index + bar_width)
    ax3.set_xticklabels(['Average'])
    ax3.set_ylim(0, y_max)
    ax3.grid(axis='y', linestyle='--', alpha=0.7)
    ax3.legend()

    # Add note for metrics where lower is better
    if metric in ['cer', 'wer']:
        plt.figtext(0.5, 0.01, "Note: Lower values indicate better performance for this metric",
                    ha='center', fontsize=12, style='italic')

    # Add an overall title and adjust layout
    model_list = ' vs '.join(model_names)
    plt.suptitle(f'{metric_display} Score Comparison: {model_list}', fontsize=16, y=0.98)
    plt.tight_layout(rect=[0, 0.02, 1, 0.95])

    return fig

def create_summary_metrics_plot(leaderboard_df: pd.DataFrame) -> plt.Figure:
    """Create a summary plot showing multiple metrics for top models."""

    if leaderboard_df.empty:
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, 'No data available', ha='center', va='center',
                transform=ax.transAxes, fontsize=16)
        return fig

    # Select top 5 models by quality score
    top_models = leaderboard_df.nlargest(5, 'quality_score')

    # Metrics to display
    metrics = ['bleu', 'chrf', 'quality_score']
    metric_labels = ['BLEU', 'ChrF', 'Quality Score']

    fig, axes = plt.subplots(1, 3, figsize=(15, 6))

    for i, (metric, label) in enumerate(zip(metrics, metric_labels)):
        ax = axes[i]

        # Sort by current metric
        sorted_models = top_models.sort_values(metric, ascending=True)

        # Create horizontal bar chart
        bars = ax.barh(range(len(sorted_models)), sorted_models[metric],
                       color=plt.cm.viridis(np.linspace(0, 1, len(sorted_models))))

        ax.set_yticks(range(len(sorted_models)))
        ax.set_yticklabels(sorted_models['model_display_name'])
        ax.set_xlabel(f'{label} Score')
        ax.set_title(f'Top Models - {label}')
        ax.grid(axis='x', linestyle='--', alpha=0.7)

        # Add value labels
        for j, (bar, value) in enumerate(zip(bars, sorted_models[metric])):
            ax.text(value + value*0.01, bar.get_y() + bar.get_height()/2,
                    f'{value:.3f}', ha='left', va='center', fontsize=10)

    plt.tight_layout()
    return fig