akera committed
Commit 34a7f8e · verified · 1 Parent(s): 93b9d03

Create plotting.py

Files changed (1)
  1. src/plotting.py +296 -0
src/plotting.py ADDED
@@ -0,0 +1,296 @@
+ # src/plotting.py
+ import matplotlib.pyplot as plt
+ import matplotlib.gridspec as gridspec
+ import matplotlib.colors as mcolors
+ from colorsys import rgb_to_hls, hls_to_rgb
+ from collections import defaultdict
+ import numpy as np
+ import pandas as pd
+ from config import LANGUAGE_NAMES
+
+ def create_leaderboard_plot(leaderboard_df: pd.DataFrame, metric: str = 'quality_score') -> plt.Figure:
+     """Create a horizontal bar chart showing model rankings."""
+
+     fig, ax = plt.subplots(figsize=(12, 8))
+
+     # Sort ascending so the best-scoring model lands at the top of the horizontal bars
+     df_sorted = leaderboard_df.sort_values(metric, ascending=True)
+
+     # Create color palette
+     colors = plt.cm.viridis(np.linspace(0, 1, len(df_sorted)))
+
+     # Create horizontal bar chart
+     bars = ax.barh(range(len(df_sorted)), df_sorted[metric], color=colors)
+
+     # Customize the plot
+     ax.set_yticks(range(len(df_sorted)))
+     ax.set_yticklabels(df_sorted['model_display_name'])
+     ax.set_xlabel(f'{metric.replace("_", " ").title()} Score')
+     ax.set_title(f'Model Leaderboard - {metric.replace("_", " ").title()}', fontsize=16, pad=20)
+
+     # Add value labels on bars
+     for bar, value in zip(bars, df_sorted[metric]):
+         ax.text(value + 0.001, bar.get_y() + bar.get_height()/2,
+                 f'{value:.3f}', ha='left', va='center', fontweight='bold')
+
+     # Add grid for better readability
+     ax.grid(axis='x', linestyle='--', alpha=0.7)
+     ax.set_axisbelow(True)
+
+     # Set x-axis limits with some padding
+     max_val = df_sorted[metric].max()
+     ax.set_xlim(0, max_val * 1.15)
+
+     plt.tight_layout()
+     return fig
+
+ def create_detailed_comparison_plot(metrics_data: dict, model_names: list) -> plt.Figure:
+     """Create a detailed comparison plot similar to the original evaluation script."""
+
+     # Filter metrics_data to only include models in model_names
+     filtered_metrics = {name: metrics_data[name] for name in model_names if name in metrics_data}
+
+     if not filtered_metrics:
+         # Create an empty plot if there is no data
+         fig, ax = plt.subplots(figsize=(10, 6))
+         ax.text(0.5, 0.5, 'No data available for comparison',
+                 ha='center', va='center', transform=ax.transAxes, fontsize=16)
+         ax.set_xlim(0, 1)
+         ax.set_ylim(0, 1)
+         ax.axis('off')
+         return fig
+
+     return plot_translation_metric_comparison(filtered_metrics, metric='bleu')
+
+ def plot_translation_metric_comparison(metrics_by_model: dict, metric: str = 'bleu') -> plt.Figure:
+     """
+     Create a grouped bar chart comparing a selected metric across translation models.
+     Adapted from the original plotting code.
+     """
+
+     # Split language pairs into xx_to_eng and eng_to_xx categories
+     first_model_data = list(metrics_by_model.values())[0]
+     xx_to_eng = [key for key in first_model_data.keys()
+                  if key.endswith('_to_eng') and key != 'averages']
+     eng_to_xx = [key for key in first_model_data.keys()
+                  if key.startswith('eng_to_') and key != 'averages']
+
+     # Build human-readable labels such as "Luganda→English"
+     def format_label(label):
+         if label.startswith("eng_to_"):
+             source, target = "English", label.replace("eng_to_", "")
+             target = LANGUAGE_NAMES.get(target, target)
+         else:
+             source, target = label.replace("_to_eng", ""), "English"
+             source = LANGUAGE_NAMES.get(source, source)
+         return f"{source}→{target}"
+
+     # Extract metric values for each category, defaulting to 0.0 for missing pairs
+     def extract_metric_values(model_metrics, pairs, metric_name):
+         return [model_metrics.get(pair, {}).get(metric_name, 0.0) for pair in pairs]
+
+     xx_to_eng_data = {
+         model_name: extract_metric_values(model_data, xx_to_eng, metric)
+         for model_name, model_data in metrics_by_model.items()
+     }
+
+     eng_to_xx_data = {
+         model_name: extract_metric_values(model_data, eng_to_xx, metric)
+         for model_name, model_data in metrics_by_model.items()
+     }
+
+     averages_data = {
+         model_name: [model_data.get("averages", {}).get(metric, 0.0)]
+         for model_name, model_data in metrics_by_model.items()
+     }
+
+     # Set up the figure with a custom grid
+     fig = plt.figure(figsize=(18, 12))  # Increased height for better spacing
+
+     # Create a GridSpec with 1 row and 5 columns
+     gs = gridspec.GridSpec(1, 5)
+
+     # Colors for the models
+     model_names = list(metrics_by_model.keys())
+
+     family_base_colors = {
+         'gemma': '#3274A1',
+         'nllb': '#7f7f7f',
+         'qwen': '#E1812C',
+         'google': '#3A923A',
+         'other': '#D62728',
+     }
+
+     # Identify the family for each model
+     def get_family(model_name):
+         model_lower = model_name.lower()
+         if 'gemma' in model_lower:
+             return 'gemma'
+         elif 'qwen' in model_lower:
+             return 'qwen'
+         elif 'nllb' in model_lower:
+             return 'nllb'
+         elif 'google' in model_lower or model_name == 'google-translate':
+             return 'google'
+         else:
+             return 'other'
+
+     # Count how many models belong to each family
+     family_counts = defaultdict(int)
+     for model in model_names:
+         family_counts[get_family(model)] += 1
+
+     # Generate slightly varied lightness within each family
+     colors = []
+     family_indices = defaultdict(int)
+     for model in model_names:
+         family = get_family(model)
+         base_rgb = mcolors.to_rgb(family_base_colors[family])
+         h, l, s = rgb_to_hls(*base_rgb)
+
+         index = family_indices[family]
+         count = family_counts[family]
+
+         # Vary lightness from 0.65 down to 0.35 across the family
+         if count == 1:
+             new_l = l  # Keep the base color for single-model families
+         else:
+             new_l = 0.65 - 0.3 * (index / max(count - 1, 1))
+
+         colors.append(mcolors.to_hex(hls_to_rgb(h, new_l, s)))
+         family_indices[family] += 1
+
+     bar_width = 0.2
+     opacity = 0.8
+
+     # Positions for the bars; the offset centers each tick under its group of bars
+     xx_to_eng_indices = np.arange(len(xx_to_eng))
+     eng_to_xx_indices = np.arange(len(eng_to_xx))
+     avg_index = np.array([0])
+     group_offset = (len(model_names) - 1) * bar_width / 2
+
+     # Determine y-axis limits based on the metric
+     if metric in ['chrf', 'len_ratio']:
+         y_max = 1.1
+     elif metric in ['cer', 'wer']:
+         y_max = 1.0
+     elif metric == 'bleu':
+         y_max = 65  # Increased from 55 to accommodate high scores
+     elif metric in ['rouge1', 'rouge2', 'rougeL']:
+         y_max = 1.0
+     elif metric == 'quality_score':
+         y_max = 0.65
+     else:
+         # Auto-scale based on the data
+         all_values = []
+         for data in [xx_to_eng_data, eng_to_xx_data, averages_data]:
+             for model_data in data.values():
+                 all_values.extend(model_data)
+         y_max = max(all_values) * 1.1 if all_values else 1.0
+
+     # Format the metric name for display
+     metric_display = metric.upper() if metric in ['bleu', 'chrf', 'cer', 'wer'] else metric.replace('_', ' ').title()
+
+     # Bars for xx_to_eng (first two columns)
+     if xx_to_eng:
+         ax1 = plt.subplot(gs[0, 0:2])
+         for i, (model_name, color) in enumerate(zip(model_names, colors)):
+             if model_name in xx_to_eng_data:
+                 ax1.bar(xx_to_eng_indices + i*bar_width, xx_to_eng_data[model_name],
+                         bar_width, alpha=opacity, color=color, label=model_name)
+
+         ax1.set_xlabel('Translation Direction')
+         ax1.set_ylabel(f'{metric_display} Score')
+         ax1.set_title(f'XX→English {metric_display} Performance')
+         ax1.set_xticks(xx_to_eng_indices + group_offset)
+         ax1.set_xticklabels([format_label(label) for label in xx_to_eng], rotation=45, ha='right')
+         ax1.set_ylim(0, y_max)
+         ax1.grid(axis='y', linestyle='--', alpha=0.7)
+
+     # Bars for eng_to_xx (next two columns)
+     if eng_to_xx:
+         ax2 = plt.subplot(gs[0, 2:4])
+         for i, (model_name, color) in enumerate(zip(model_names, colors)):
+             if model_name in eng_to_xx_data:
+                 ax2.bar(eng_to_xx_indices + i*bar_width, eng_to_xx_data[model_name],
+                         bar_width, alpha=opacity, color=color, label=model_name)
+
+         ax2.set_xlabel('Translation Direction')
+         ax2.set_ylabel(f'{metric_display} Score')
+         ax2.set_title(f'English→XX {metric_display} Performance')
+         ax2.set_xticks(eng_to_xx_indices + group_offset)
+         ax2.set_xticklabels([format_label(label) for label in eng_to_xx], rotation=45, ha='right')
+         ax2.set_ylim(0, y_max)
+         ax2.grid(axis='y', linestyle='--', alpha=0.7)
+
+     # Bars for the averages (last column)
+     ax3 = plt.subplot(gs[0, 4])
+     for i, (model_name, color) in enumerate(zip(model_names, colors)):
+         if model_name in averages_data:
+             ax3.bar(avg_index + i*bar_width, averages_data[model_name],
+                     bar_width, alpha=opacity, color=color, label=model_name)
+
+     ax3.set_xlabel('Overall')
+     ax3.set_ylabel(f'{metric_display} Score')
+     ax3.set_title(f'Average {metric_display}')
+     ax3.set_xticks(avg_index + group_offset)
+     ax3.set_xticklabels(['Average'])
+     ax3.set_ylim(0, y_max)
+     ax3.grid(axis='y', linestyle='--', alpha=0.7)
+     ax3.legend()
+
+     # Add a note for metrics where lower is better
+     if metric in ['cer', 'wer']:
+         plt.figtext(0.5, 0.01, "Note: Lower values indicate better performance for this metric",
+                     ha='center', fontsize=12, style='italic')
+
+     # Add an overall title and adjust the layout
+     model_list = ' vs '.join(model_names)
+     plt.suptitle(f'{metric_display} Score Comparison: {model_list}', fontsize=16, y=0.98)
+     plt.tight_layout(rect=[0, 0.02, 1, 0.95])
+
+     return fig
+
+ def create_summary_metrics_plot(leaderboard_df: pd.DataFrame) -> plt.Figure:
+     """Create a summary plot showing multiple metrics for the top models."""
+
+     if leaderboard_df.empty:
+         fig, ax = plt.subplots(figsize=(10, 6))
+         ax.text(0.5, 0.5, 'No data available', ha='center', va='center',
+                 transform=ax.transAxes, fontsize=16)
+         return fig
+
+     # Select the top 5 models by quality score
+     top_models = leaderboard_df.nlargest(5, 'quality_score')
+
+     # Metrics to display
+     metrics = ['bleu', 'chrf', 'quality_score']
+     metric_labels = ['BLEU', 'ChrF', 'Quality Score']
+
+     fig, axes = plt.subplots(1, 3, figsize=(15, 6))
+
+     for ax, metric, label in zip(axes, metrics, metric_labels):
+         # Sort by the current metric so the best model sits at the top
+         sorted_models = top_models.sort_values(metric, ascending=True)
+
+         # Create horizontal bar chart
+         bars = ax.barh(range(len(sorted_models)), sorted_models[metric],
+                        color=plt.cm.viridis(np.linspace(0, 1, len(sorted_models))))
+
+         ax.set_yticks(range(len(sorted_models)))
+         ax.set_yticklabels(sorted_models['model_display_name'])
+         ax.set_xlabel(f'{label} Score')
+         ax.set_title(f'Top Models - {label}')
+         ax.grid(axis='x', linestyle='--', alpha=0.7)
+
+         # Add value labels
+         for bar, value in zip(bars, sorted_models[metric]):
+             ax.text(value + value*0.01, bar.get_y() + bar.get_height()/2,
+                     f'{value:.3f}', ha='left', va='center', fontsize=10)
+
+     plt.tight_layout()
+     return fig
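
Usage note (not part of the commit): a minimal sketch of how these functions might be driven. It assumes the module is importable as plotting, that config.LANGUAGE_NAMES is on the import path, a leaderboard DataFrame with model_display_name, bleu, chrf, and quality_score columns, and a metrics dict keyed by model name with per-direction entries plus an 'averages' entry. All model names, the 'lug' language code, and the scores below are hypothetical toy data.

# usage_sketch.py — hypothetical driver for src/plotting.py (toy data, not from the repo)
import pandas as pd
from plotting import (create_leaderboard_plot,
                      create_summary_metrics_plot,
                      plot_translation_metric_comparison)

# Toy leaderboard with the columns the plotting functions expect
leaderboard = pd.DataFrame({
    'model_display_name': ['gemma-2b', 'qwen-7b', 'nllb-600m'],
    'bleu': [24.1, 28.7, 22.3],
    'chrf': [0.48, 0.53, 0.45],
    'quality_score': [0.41, 0.46, 0.38],
})

# Toy per-direction metrics; 'lug' is a stand-in code that
# LANGUAGE_NAMES may map to a display name such as 'Luganda'
metrics = {
    'gemma-2b': {'lug_to_eng': {'bleu': 26.0}, 'eng_to_lug': {'bleu': 17.5},
                 'averages': {'bleu': 21.8}},
    'qwen-7b':  {'lug_to_eng': {'bleu': 29.4}, 'eng_to_lug': {'bleu': 20.1},
                 'averages': {'bleu': 24.8}},
}

# Each function returns a matplotlib Figure, so the caller decides how to render or save
create_leaderboard_plot(leaderboard, metric='quality_score').savefig('leaderboard.png')
create_summary_metrics_plot(leaderboard).savefig('summary.png')
plot_translation_metric_comparison(metrics, metric='bleu').savefig('comparison.png')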