akera commited on
Commit
cfbcff1
Β·
verified Β·
1 Parent(s): 944a871

Create plotting.py

Browse files
Files changed (1) hide show
  1. src/plotting.py +529 -0
src/plotting.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/plotting.py
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.gridspec as gridspec
4
+ import matplotlib.colors as mcolors
5
+ from colorsys import rgb_to_hls, hls_to_rgb
6
+ import plotly.graph_objects as go
7
+ import plotly.express as px
8
+ from plotly.subplots import make_subplots
9
+ import pandas as pd
10
+ import numpy as np
11
+ from collections import defaultdict
12
+ from typing import Dict, List, Optional, Union
13
+ from config import LANGUAGE_NAMES, ALL_UG40_LANGUAGES, GOOGLE_SUPPORTED_LANGUAGES, METRICS_CONFIG
14
+
15
+ plt.style.use('default')
16
+ plt.rcParams['figure.facecolor'] = 'white'
17
+ plt.rcParams['axes.facecolor'] = 'white'
18
+
19
+ def create_leaderboard_ranking_plot(df: pd.DataFrame, metric: str = 'quality_score', top_n: int = 15) -> go.Figure:
20
+ """Create interactive leaderboard ranking plot using Plotly."""
21
+
22
+ if df.empty:
23
+ fig = go.Figure()
24
+ fig.add_annotation(
25
+ text="No data available",
26
+ xref="paper", yref="paper",
27
+ x=0.5, y=0.5, showarrow=False,
28
+ font=dict(size=16)
29
+ )
30
+ return fig
31
+
32
+ # Get top N models
33
+ top_models = df.head(top_n)
34
+
35
+ # Create color scale based on scores
36
+ colors = px.colors.qualitative.Set3[:len(top_models)]
37
+
38
+ # Create horizontal bar chart
39
+ fig = go.Figure(data=[
40
+ go.Bar(
41
+ y=top_models['model_name'],
42
+ x=top_models[metric],
43
+ orientation='h',
44
+ marker=dict(
45
+ color=top_models[metric],
46
+ colorscale='Viridis',
47
+ showscale=True,
48
+ colorbar=dict(title=metric.replace('_', ' ').title())
49
+ ),
50
+ text=[f"{score:.3f}" for score in top_models[metric]],
51
+ textposition='auto',
52
+ hovertemplate=(
53
+ "<b>%{y}</b><br>" +
54
+ f"{metric.replace('_', ' ').title()}: %{{x:.4f}}<br>" +
55
+ "Author: %{customdata[0]}<br>" +
56
+ "Coverage: %{customdata[1]:.1%}<br>" +
57
+ "<extra></extra>"
58
+ ),
59
+ customdata=list(zip(top_models['author'], top_models['coverage_rate']))
60
+ )
61
+ ])
62
+
63
+ fig.update_layout(
64
+ title=f"πŸ† SALT Translation Leaderboard - {metric.replace('_', ' ').title()}",
65
+ xaxis_title=f"{metric.replace('_', ' ').title()} Score",
66
+ yaxis_title="Models",
67
+ height=max(400, len(top_models) * 30 + 100),
68
+ margin=dict(l=20, r=20, t=60, b=20),
69
+ plot_bgcolor='white',
70
+ paper_bgcolor='white'
71
+ )
72
+
73
+ # Reverse y-axis to show best model at top
74
+ fig.update_yaxes(autorange="reversed")
75
+
76
+ return fig
77
+
78
+ def create_metrics_comparison_plot(df: pd.DataFrame, models: List[str] = None, max_models: int = 8) -> go.Figure:
79
+ """Create radar chart comparing multiple metrics across models."""
80
+
81
+ if df.empty:
82
+ return go.Figure().add_annotation(text="No data available", x=0.5, y=0.5)
83
+
84
+ # Select models to compare
85
+ if models is None:
86
+ selected_models = df.head(max_models)
87
+ else:
88
+ selected_models = df[df['model_name'].isin(models)].head(max_models)
89
+
90
+ if len(selected_models) == 0:
91
+ return go.Figure().add_annotation(text="No models found", x=0.5, y=0.5)
92
+
93
+ # Metrics to include in radar chart
94
+ metrics = ['quality_score', 'bleu', 'chrf', 'rouge1', 'rougeL']
95
+ metric_labels = ['Quality Score', 'BLEU (/100)', 'ChrF', 'ROUGE-1', 'ROUGE-L']
96
+
97
+ fig = go.Figure()
98
+
99
+ colors = px.colors.qualitative.Set1[:len(selected_models)]
100
+
101
+ for i, (_, model) in enumerate(selected_models.iterrows()):
102
+ # Normalize BLEU to 0-1 scale for radar chart
103
+ values = []
104
+ for metric in metrics:
105
+ value = model[metric]
106
+ if metric == 'bleu':
107
+ value = value / 100.0 # Normalize BLEU
108
+ values.append(value)
109
+
110
+ # Close the radar chart
111
+ values += values[:1]
112
+ metric_labels_closed = metric_labels + [metric_labels[0]]
113
+
114
+ fig.add_trace(go.Scatterpolar(
115
+ r=values,
116
+ theta=metric_labels_closed,
117
+ fill='toself',
118
+ name=model['model_name'],
119
+ line_color=colors[i % len(colors)],
120
+ fillcolor=colors[i % len(colors)],
121
+ opacity=0.6
122
+ ))
123
+
124
+ fig.update_layout(
125
+ polar=dict(
126
+ radialaxis=dict(
127
+ visible=True,
128
+ range=[0, 1]
129
+ )
130
+ ),
131
+ showlegend=True,
132
+ title="πŸ“Š Multi-Metric Model Comparison",
133
+ height=600
134
+ )
135
+
136
+ return fig
137
+
138
+ def create_language_pair_heatmap(results_dict: Dict, metric: str = 'quality_score') -> go.Figure:
139
+ """Create heatmap showing performance across language pairs."""
140
+
141
+ if not results_dict or 'pair_metrics' not in results_dict:
142
+ return go.Figure().add_annotation(text="No language pair data available", x=0.5, y=0.5)
143
+
144
+ pair_metrics = results_dict['pair_metrics']
145
+
146
+ # Create matrix for heatmap
147
+ languages = ALL_UG40_LANGUAGES
148
+ matrix = np.zeros((len(languages), len(languages)))
149
+
150
+ for i, src_lang in enumerate(languages):
151
+ for j, tgt_lang in enumerate(languages):
152
+ if src_lang != tgt_lang:
153
+ pair_key = f"{src_lang}_to_{tgt_lang}"
154
+ if pair_key in pair_metrics and metric in pair_metrics[pair_key]:
155
+ matrix[i, j] = pair_metrics[pair_key][metric]
156
+ else:
157
+ matrix[i, j] = np.nan
158
+ else:
159
+ matrix[i, j] = np.nan
160
+
161
+ # Create language labels
162
+ lang_labels = [LANGUAGE_NAMES.get(lang, lang) for lang in languages]
163
+
164
+ fig = go.Figure(data=go.Heatmap(
165
+ z=matrix,
166
+ x=lang_labels,
167
+ y=lang_labels,
168
+ colorscale='Viridis',
169
+ showscale=True,
170
+ colorbar=dict(title=metric.replace('_', ' ').title()),
171
+ hoverinfotemplate=(
172
+ "Source: %{y}<br>" +
173
+ "Target: %{x}<br>" +
174
+ f"{metric.replace('_', ' ').title()}: %{{z:.3f}}<br>" +
175
+ "<extra></extra>"
176
+ )
177
+ ))
178
+
179
+ fig.update_layout(
180
+ title=f"πŸ—ΊοΈ Language Pair Performance - {metric.replace('_', ' ').title()}",
181
+ xaxis_title="Target Language",
182
+ yaxis_title="Source Language",
183
+ height=600,
184
+ width=700
185
+ )
186
+
187
+ return fig
188
+
189
+ def create_coverage_analysis_plot(df: pd.DataFrame) -> go.Figure:
190
+ """Create plot analyzing test set coverage across submissions."""
191
+
192
+ if df.empty:
193
+ return go.Figure().add_annotation(text="No data available", x=0.5, y=0.5)
194
+
195
+ fig = make_subplots(
196
+ rows=2, cols=2,
197
+ subplot_titles=(
198
+ "Coverage Distribution",
199
+ "Language Pairs Covered",
200
+ "Sample Count vs Quality",
201
+ "Google Comparable Coverage"
202
+ ),
203
+ specs=[[{"type": "bar"}, {"type": "scatter"}],
204
+ [{"type": "scatter"}, {"type": "bar"}]]
205
+ )
206
+
207
+ # Coverage distribution
208
+ coverage_bins = pd.cut(df['coverage_rate'],
209
+ bins=[0, 0.5, 0.8, 0.9, 0.95, 1.0],
210
+ labels=['<50%', '50-80%', '80-90%', '90-95%', '95-100%'])
211
+ coverage_counts = coverage_bins.value_counts()
212
+
213
+ fig.add_trace(
214
+ go.Bar(x=coverage_counts.index, y=coverage_counts.values, name="Coverage"),
215
+ row=1, col=1
216
+ )
217
+
218
+ # Language pairs covered vs quality
219
+ fig.add_trace(
220
+ go.Scatter(
221
+ x=df['language_pairs_covered'],
222
+ y=df['quality_score'],
223
+ mode='markers',
224
+ text=df['model_name'],
225
+ name="Quality vs Coverage"
226
+ ),
227
+ row=1, col=2
228
+ )
229
+
230
+ # Sample count vs quality
231
+ fig.add_trace(
232
+ go.Scatter(
233
+ x=df['total_samples'],
234
+ y=df['quality_score'],
235
+ mode='markers',
236
+ text=df['model_name'],
237
+ name="Quality vs Samples"
238
+ ),
239
+ row=2, col=1
240
+ )
241
+
242
+ # Google comparable coverage
243
+ google_coverage = df['google_pairs_covered'].value_counts().sort_index()
244
+ fig.add_trace(
245
+ go.Bar(x=google_coverage.index, y=google_coverage.values, name="Google Coverage"),
246
+ row=2, col=2
247
+ )
248
+
249
+ fig.update_layout(
250
+ title="πŸ“ˆ Test Set Coverage Analysis",
251
+ height=800,
252
+ showlegend=False
253
+ )
254
+
255
+ return fig
256
+
257
+ def create_model_performance_timeline(df: pd.DataFrame) -> go.Figure:
258
+ """Create timeline showing model performance over time."""
259
+
260
+ if df.empty:
261
+ return go.Figure().add_annotation(text="No data available", x=0.5, y=0.5)
262
+
263
+ # Convert submission_date to datetime
264
+ df_copy = df.copy()
265
+ df_copy['submission_date'] = pd.to_datetime(df_copy['submission_date'])
266
+ df_copy = df_copy.sort_values('submission_date')
267
+
268
+ fig = go.Figure()
269
+
270
+ # Add scatter plot for each submission
271
+ fig.add_trace(go.Scatter(
272
+ x=df_copy['submission_date'],
273
+ y=df_copy['quality_score'],
274
+ mode='markers+lines',
275
+ marker=dict(
276
+ size=10,
277
+ color=df_copy['quality_score'],
278
+ colorscale='Viridis',
279
+ showscale=True,
280
+ colorbar=dict(title="Quality Score")
281
+ ),
282
+ text=df_copy['model_name'],
283
+ hovertemplate=(
284
+ "<b>%{text}</b><br>" +
285
+ "Date: %{x}<br>" +
286
+ "Quality Score: %{y:.4f}<br>" +
287
+ "<extra></extra>"
288
+ ),
289
+ name="Models"
290
+ ))
291
+
292
+ # Add trend line
293
+ if len(df_copy) > 1:
294
+ z = np.polyfit(range(len(df_copy)), df_copy['quality_score'], 1)
295
+ trend_line = np.poly1d(z)(range(len(df_copy)))
296
+
297
+ fig.add_trace(go.Scatter(
298
+ x=df_copy['submission_date'],
299
+ y=trend_line,
300
+ mode='lines',
301
+ line=dict(dash='dash', color='red'),
302
+ name="Trend",
303
+ hoverinfo='skip'
304
+ ))
305
+
306
+ fig.update_layout(
307
+ title="πŸ“… Model Performance Timeline",
308
+ xaxis_title="Submission Date",
309
+ yaxis_title="Quality Score",
310
+ height=500
311
+ )
312
+
313
+ return fig
314
+
315
+ def create_google_comparison_plot(df: pd.DataFrame) -> go.Figure:
316
+ """Create plot comparing models on Google Translate-comparable language pairs."""
317
+
318
+ # Filter models that have Google comparable results
319
+ google_models = df[df['google_pairs_covered'] > 0].copy()
320
+
321
+ if google_models.empty:
322
+ return go.Figure().add_annotation(
323
+ text="No models with Google Translate comparable results",
324
+ x=0.5, y=0.5
325
+ )
326
+
327
+ fig = go.Figure()
328
+
329
+ # Create scatter plot
330
+ fig.add_trace(go.Scatter(
331
+ x=google_models['google_bleu'],
332
+ y=google_models['google_quality_score'],
333
+ mode='markers+text',
334
+ marker=dict(
335
+ size=12,
336
+ color=google_models['google_chrf'],
337
+ colorscale='Plasma',
338
+ showscale=True,
339
+ colorbar=dict(title="ChrF Score")
340
+ ),
341
+ text=google_models['model_name'],
342
+ textposition="top center",
343
+ hovertemplate=(
344
+ "<b>%{text}</b><br>" +
345
+ "BLEU: %{x:.2f}<br>" +
346
+ "Quality: %{y:.4f}<br>" +
347
+ "ChrF: %{marker.color:.4f}<br>" +
348
+ "<extra></extra>"
349
+ ),
350
+ name="Models"
351
+ ))
352
+
353
+ fig.update_layout(
354
+ title="πŸ€– Google Translate Comparable Performance",
355
+ xaxis_title="BLEU Score",
356
+ yaxis_title="Quality Score",
357
+ height=500
358
+ )
359
+
360
+ return fig
361
+
362
+ def create_detailed_model_analysis(model_results: Dict, model_name: str) -> go.Figure:
363
+ """Create detailed analysis plot for a specific model."""
364
+
365
+ if not model_results or 'pair_metrics' not in model_results:
366
+ return go.Figure().add_annotation(text="No detailed results available", x=0.5, y=0.5)
367
+
368
+ pair_metrics = model_results['pair_metrics']
369
+
370
+ # Extract language pair data
371
+ pairs = []
372
+ bleu_scores = []
373
+ quality_scores = []
374
+ sample_counts = []
375
+ google_comparable = []
376
+
377
+ for pair_key, metrics in pair_metrics.items():
378
+ if 'sample_count' in metrics and metrics['sample_count'] > 0:
379
+ src, tgt = pair_key.split('_to_')
380
+ pair_label = f"{LANGUAGE_NAMES.get(src, src)} β†’ {LANGUAGE_NAMES.get(tgt, tgt)}"
381
+
382
+ pairs.append(pair_label)
383
+ bleu_scores.append(metrics.get('bleu', 0))
384
+ quality_scores.append(metrics.get('quality_score', 0))
385
+ sample_counts.append(metrics.get('sample_count', 0))
386
+
387
+ is_google = (src in GOOGLE_SUPPORTED_LANGUAGES and tgt in GOOGLE_SUPPORTED_LANGUAGES)
388
+ google_comparable.append(is_google)
389
+
390
+ if not pairs:
391
+ return go.Figure().add_annotation(text="No language pair data found", x=0.5, y=0.5)
392
+
393
+ # Create subplot
394
+ fig = make_subplots(
395
+ rows=2, cols=1,
396
+ subplot_titles=(
397
+ f"{model_name} - BLEU Scores by Language Pair",
398
+ f"{model_name} - Quality Scores by Language Pair"
399
+ ),
400
+ vertical_spacing=0.1
401
+ )
402
+
403
+ # Color code by Google comparable
404
+ colors = ['#1f77b4' if gc else '#ff7f0e' for gc in google_comparable]
405
+
406
+ # BLEU scores
407
+ fig.add_trace(
408
+ go.Bar(
409
+ x=pairs,
410
+ y=bleu_scores,
411
+ marker_color=colors,
412
+ name="BLEU",
413
+ text=[f"{score:.1f}" for score in bleu_scores],
414
+ textposition='auto'
415
+ ),
416
+ row=1, col=1
417
+ )
418
+
419
+ # Quality scores
420
+ fig.add_trace(
421
+ go.Bar(
422
+ x=pairs,
423
+ y=quality_scores,
424
+ marker_color=colors,
425
+ name="Quality",
426
+ text=[f"{score:.3f}" for score in quality_scores],
427
+ textposition='auto',
428
+ showlegend=False
429
+ ),
430
+ row=2, col=1
431
+ )
432
+
433
+ fig.update_layout(
434
+ height=800,
435
+ title=f"πŸ“Š Detailed Analysis: {model_name}",
436
+ showlegend=True
437
+ )
438
+
439
+ # Rotate x-axis labels
440
+ fig.update_xaxes(tickangle=45)
441
+
442
+ # Add legend for colors
443
+ fig.add_trace(
444
+ go.Scatter(
445
+ x=[None], y=[None],
446
+ mode='markers',
447
+ marker=dict(size=10, color='#1f77b4'),
448
+ name="Google Comparable",
449
+ showlegend=True
450
+ )
451
+ )
452
+
453
+ fig.add_trace(
454
+ go.Scatter(
455
+ x=[None], y=[None],
456
+ mode='markers',
457
+ marker=dict(size=10, color='#ff7f0e'),
458
+ name="UG40 Only",
459
+ showlegend=True
460
+ )
461
+ )
462
+
463
+ return fig
464
+
465
+ def create_submission_summary_plot(validation_info: Dict, evaluation_results: Dict) -> go.Figure:
466
+ """Create summary plot for a new submission."""
467
+
468
+ fig = make_subplots(
469
+ rows=2, cols=2,
470
+ subplot_titles=(
471
+ "Coverage by Language Pair",
472
+ "Primary Metrics",
473
+ "Error Analysis",
474
+ "Sample Distribution"
475
+ ),
476
+ specs=[[{"type": "bar"}, {"type": "bar"}],
477
+ [{"type": "bar"}, {"type": "pie"}]]
478
+ )
479
+
480
+ # Coverage by language pair
481
+ if 'pair_coverage' in validation_info:
482
+ pair_data = validation_info['pair_coverage']
483
+ pairs = list(pair_data.keys())[:10] # Top 10 pairs
484
+ coverage_rates = [pair_data[p]['coverage_rate'] for p in pairs]
485
+
486
+ fig.add_trace(
487
+ go.Bar(x=pairs, y=coverage_rates, name="Coverage"),
488
+ row=1, col=1
489
+ )
490
+
491
+ # Primary metrics
492
+ if 'summary' in evaluation_results:
493
+ metrics_data = evaluation_results['summary']['primary_metrics']
494
+ metric_names = list(metrics_data.keys())
495
+ metric_values = list(metrics_data.values())
496
+
497
+ fig.add_trace(
498
+ go.Bar(x=metric_names, y=metric_values, name="Metrics"),
499
+ row=1, col=2
500
+ )
501
+
502
+ # Error analysis (CER, WER)
503
+ if 'averages' in evaluation_results:
504
+ error_metrics = ['cer', 'wer']
505
+ error_values = [evaluation_results['averages'].get(m, 0) for m in error_metrics]
506
+
507
+ fig.add_trace(
508
+ go.Bar(x=error_metrics, y=error_values, name="Errors"),
509
+ row=2, col=1
510
+ )
511
+
512
+ # Sample distribution (placeholder)
513
+ fig.add_trace(
514
+ go.Pie(
515
+ labels=["Evaluated", "Missing"],
516
+ values=[validation_info.get('coverage', 0.8) * 100,
517
+ (1 - validation_info.get('coverage', 0.8)) * 100],
518
+ name="Samples"
519
+ ),
520
+ row=2, col=2
521
+ )
522
+
523
+ fig.update_layout(
524
+ title="πŸ“‹ Submission Summary",
525
+ height=700,
526
+ showlegend=False
527
+ )
528
+
529
+ return fig