akera committed on
Commit
0c7e136
·
verified ·
1 Parent(s): 52616c5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +413 -0
app.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ from datasets import load_dataset
6
+ import yaml
7
+ import json
8
+ import torch
9
+ from datetime import datetime
10
+ import traceback
11
+
12
+ # Import our modules
13
+ from src.model_loader import load_model, get_model_info
14
+ from src.evaluation import evaluate_model_full
15
+ from src.leaderboard import load_leaderboard, add_model_results, get_leaderboard_summary, search_models
16
+ from src.plotting import create_leaderboard_plot, create_detailed_comparison_plot, create_summary_metrics_plot
17
+ from src.utils import validate_model_path, get_model_type, sanitize_input
18
+ from config import *
19
+
20
+ # Global variables for caching
21
+ current_leaderboard = None
22
+ test_data = None
23
+
24
def load_salt_data():
    """Load the SALT evaluation data, caching it in the module-level ``test_data``.

    The first successful call builds a DataFrame from
    ``salt.dataset.create(config)`` and stores it in the global ``test_data``;
    later calls return that cached frame without reloading.

    If anything in the loading path raises, the error is printed and a tiny
    two-row placeholder frame is returned so the app can still start. The
    placeholder is intentionally NOT cached: caching it would make every later
    call return the placeholder for the rest of the process lifetime, even if
    the failure was transient. Leaving ``test_data`` as ``None`` lets the next
    call retry the real dataset.

    Returns:
        pandas.DataFrame: the evaluation samples, or the placeholder frame
        (columns ``source``, ``target``, ``source.language``,
        ``target.language``) when loading failed.
    """
    global test_data

    if test_data is not None:
        return test_data

    try:
        print("Loading SALT dataset...")

        # Configuration for SALT dataset.
        # NOTE(review): the nesting below (keys grouped under huggingface_load /
        # source / target) is the intended layout — confirm against the schema
        # salt.dataset.create expects.
        dataset_config = f'''
        huggingface_load:
          path: {SALT_DATASET}
          name: text-all
          split: dev[:{MAX_EVAL_SAMPLES}]
        source:
          type: text
          language: {SUPPORTED_LANGUAGES}
        target:
          type: text
          language: {SUPPORTED_LANGUAGES}
        src_or_tgt_languages_must_contain: eng
        allow_same_src_and_tgt_language: False
        '''

        config = yaml.safe_load(dataset_config)

        # Imported here (inside the try) so a missing `salt` package is
        # handled by the fallback below instead of breaking module import.
        import salt.dataset
        loaded = pd.DataFrame(salt.dataset.create(config))

        # Only a successfully loaded dataset is cached.
        test_data = loaded
        print(f"Loaded {len(test_data)} evaluation samples")
        return test_data

    except Exception as e:
        print(f"Error loading SALT dataset: {e}")
        # Fallback: create minimal test data (returned, but never cached).
        return pd.DataFrame({
            'source': ['Hello world', 'How are you?'],
            'target': ['Amakuru', 'Oli otya?'],
            'source.language': ['eng', 'eng'],
            'target.language': ['lug', 'lug']
        })
69
+
70
def refresh_leaderboard():
    """Reload the leaderboard and cache it in the module-level global.

    Returns:
        The object returned by ``load_leaderboard()``; the same object is
        stored in ``current_leaderboard``.
    """
    global current_leaderboard
    reloaded = load_leaderboard()
    current_leaderboard = reloaded
    return reloaded
75
+
76
def _leaderboard_rank(leaderboard, model_path):
    """Return the 1-based row position of *model_path* in *leaderboard*, or None if absent."""
    # Use the row position, not the index label: after a sort the labels no
    # longer match the display order.
    # NOTE(review): assumes rows are ordered best-first — confirm in add_model_results.
    paths = leaderboard['model_path'].tolist()
    try:
        return paths.index(model_path) + 1
    except ValueError:
        return None


def _format_results_message(model_path, author_name, model_type, avg_metrics, rank, total_models):
    """Build the Markdown summary shown to the user after a successful evaluation."""
    if rank is None:
        ranking_line = "**Ranking:** N/A"
    else:
        ranking_line = f"**Ranking:** #{rank} out of {total_models} models"

    return "\n".join([
        "✅ **Evaluation Complete!**",
        "",
        f"**Model:** {model_path}",
        f"**Author:** {author_name}",
        f"**Type:** {model_type}",
        "",
        "**Results:**",
        f"- Quality Score: {avg_metrics.get('quality_score', 0):.4f}",
        f"- BLEU: {avg_metrics.get('bleu', 0):.2f}",
        f"- ChrF: {avg_metrics.get('chrf', 0):.4f}",
        f"- ROUGE-L: {avg_metrics.get('rougeL', 0):.4f}",
        "",
        ranking_line,
    ])


def evaluate_submission(model_path: str, author_name: str) -> tuple:
    """Evaluate a submitted model and record its scores on the leaderboard.

    Steps, in order: sanitize and validate the inputs, load the evaluation
    data, load the model, run ``evaluate_model_full``, add the averaged
    metrics to the leaderboard (also updating the module-level
    ``current_leaderboard``), then build the plots and the summary message.

    Args:
        model_path: Model identifier entered by the user.
        author_name: Submitter name; an empty value becomes ``"Anonymous"``.

    Returns:
        A 4-tuple ``(message, leaderboard, leaderboard_plot, detailed_plot)``.
        On any failure the message starts with "❌" and the other three items
        are ``None``. This function does not raise: unexpected exceptions are
        caught, printed, and reported in the message together with a traceback.
    """
    global current_leaderboard

    try:
        # Validate inputs
        model_path = sanitize_input(model_path)
        author_name = sanitize_input(author_name)

        if not model_path:
            return "❌ Error: Model path is required", None, None, None

        if not author_name:
            author_name = "Anonymous"

        if not validate_model_path(model_path):
            return "❌ Error: Invalid model path format", None, None, None

        # Load test data
        eval_data = load_salt_data()
        if eval_data is None or len(eval_data) == 0:
            return "❌ Error: Could not load evaluation data", None, None, None

        # Get model info. The result is not used here; the call is kept so any
        # error it raises still surfaces before the expensive model load.
        print(f"Getting model info for: {model_path}")
        get_model_info(model_path)
        model_type = get_model_type(model_path)

        # Load model
        print(f"Loading model: {model_path}")
        try:
            model, tokenizer = load_model(model_path)
        except Exception as e:
            return f"❌ Error loading model: {str(e)}", None, None, None

        # Run evaluation
        print("Starting evaluation...")
        try:
            detailed_metrics = evaluate_model_full(model, tokenizer, model_path, eval_data)
        except Exception as e:
            return f"❌ Error during evaluation: {str(e)}", None, None, None

        # Extract average metrics
        avg_metrics = detailed_metrics.get('averages', {})
        if not avg_metrics:
            return "❌ Error: No metrics calculated", None, None, None

        # Add results to leaderboard
        print("Adding results to leaderboard...")
        updated_leaderboard = add_model_results(
            model_path=model_path,
            author=author_name,
            metrics=avg_metrics,
            detailed_metrics=detailed_metrics,
            evaluation_samples=len(eval_data),
            model_type=model_type
        )

        # Update global leaderboard
        current_leaderboard = updated_leaderboard

        # Create visualizations
        leaderboard_plot = create_leaderboard_plot(updated_leaderboard, 'quality_score')
        detailed_plot = create_detailed_comparison_plot({model_path: detailed_metrics}, [model_path])

        # Format results message. A missing row yields "N/A" rather than an
        # IndexError, which would report a saved evaluation as a failure.
        results_msg = _format_results_message(
            model_path,
            author_name,
            model_type,
            avg_metrics,
            _leaderboard_rank(updated_leaderboard, model_path),
            len(updated_leaderboard),
        )

        return results_msg, updated_leaderboard, leaderboard_plot, detailed_plot

    except Exception as e:
        error_msg = f"❌ Unexpected error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg, None, None, None
164
+
165
def update_leaderboard_display(search_query: str = "") -> tuple:
    """Build the leaderboard tab outputs, optionally filtered by a search query.

    Uses the cached module-level ``current_leaderboard``; it is loaded via
    ``refresh_leaderboard()`` only when nothing has been cached yet.

    Args:
        search_query: Text passed to ``search_models``. An empty value shows
            the whole leaderboard.

    Returns:
        A 4-tuple ``(filtered_df, leaderboard_plot, summary_plot, summary_text)``.
    """
    global current_leaderboard
    if current_leaderboard is None:
        current_leaderboard = refresh_leaderboard()

    # Apply search filter
    if search_query:
        filtered_df = search_models(current_leaderboard, search_query)
    else:
        filtered_df = current_leaderboard

    # Create plots
    leaderboard_plot = create_leaderboard_plot(filtered_df, 'quality_score')
    summary_plot = create_summary_metrics_plot(filtered_df)

    # Get summary stats
    summary = get_leaderboard_summary(filtered_df)

    # Show only the first 10 characters of the latest submission value.
    # NOTE(review): presumably an ISO timestamp, so this is the date — confirm.
    # A real None or a non-string value is handled as well as the 'None'
    # string, so the slice cannot raise.
    latest = summary['latest_submission']
    if latest is None or latest == 'None':
        latest_text = 'None'
    else:
        latest_text = str(latest)[:10]

    summary_text = "\n".join([
        "📊 **Leaderboard Summary**",
        f"- Total Models: {summary['total_models']}",
        f"- Average Quality Score: {summary['avg_quality_score']:.4f}",
        f"- Best Model: {summary['best_model']}",
        f"- Latest Submission: {latest_text}",
    ])

    return filtered_df, leaderboard_plot, summary_plot, summary_text
193
+
194
# Initialize data
print("Initializing SALT Translation Leaderboard...")
load_salt_data()
refresh_leaderboard()


def _reload_leaderboard_display(search_query: str = "") -> tuple:
    """Reload the leaderboard via refresh_leaderboard(), then rebuild the leaderboard tab outputs."""
    # update_leaderboard_display alone only reads the cached leaderboard, so
    # without this reload the "Refresh" button would never pick up new data.
    refresh_leaderboard()
    return update_leaderboard_display(search_query)


# Create Gradio interface.
# The multi-line Markdown strings below are indented with the code for
# readability; Gradio's Markdown component strips common leading indentation.
with gr.Blocks(
    title=TITLE,
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .main-header {
        text-align: center;
        margin-bottom: 2rem;
    }
    .metric-display {
        background: #f8f9fa;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 0.5rem 0;
    }
    """
) as demo:

    # Header
    gr.Markdown(f"""
    <div class="main-header">

    # {TITLE}

    {DESCRIPTION}

    **Supported Languages:** Luganda (lug), Acholi (ach), Swahili (swa), English (eng)

    </div>
    """)

    with gr.Tabs():

        # Tab 1: Submit Model
        with gr.Tab("🚀 Submit Model", id="submit"):

            gr.Markdown("""
            ### Submit Your Translation Model

            Enter a HuggingFace model path (e.g., `microsoft/DialoGPT-medium`) or use `google-translate` to benchmark against Google Translate.

            **Supported Model Types:** Gemma, Qwen, Llama, NLLB, Google Translate
            """)

            with gr.Row():
                with gr.Column(scale=2):
                    model_input = gr.Textbox(
                        label="🤗 HuggingFace Model Path",
                        placeholder="e.g., Sunbird/gemma3-12b-ug40-merged",
                        info="Enter the full HuggingFace model path or 'google-translate'"
                    )

                    author_input = gr.Textbox(
                        label="👤 Author/Organization",
                        placeholder="Your name or organization",
                        value="Anonymous"
                    )

                    submit_btn = gr.Button(
                        "🔄 Evaluate Model",
                        variant="primary",
                        size="lg"
                    )

                with gr.Column(scale=1):
                    gr.Markdown("""
                    **📋 Evaluation Process:**
                    1. Model validation
                    2. Loading model weights
                    3. Generating translations
                    4. Calculating metrics
                    5. Updating leaderboard

                    ⏱️ **Expected time:** 5-15 minutes
                    """)

            # Results section
            with gr.Group():
                results_output = gr.Markdown(label="📊 Results")

                with gr.Row():
                    with gr.Column():
                        results_leaderboard = gr.Dataframe(
                            label="📈 Updated Leaderboard",
                            interactive=False
                        )

                with gr.Row():
                    results_plot = gr.Plot(label="📊 Leaderboard Ranking")
                    detailed_plot = gr.Plot(label="🔍 Detailed Performance")

        # Tab 2: Leaderboard
        with gr.Tab("🏆 Leaderboard", id="leaderboard"):

            with gr.Row():
                search_input = gr.Textbox(
                    label="🔍 Search Models",
                    placeholder="Search by model name, author, or path...",
                    scale=3
                )
                refresh_btn = gr.Button("🔄 Refresh", scale=1)

            summary_stats = gr.Markdown(label="📊 Summary")

            with gr.Row():
                leaderboard_table = gr.Dataframe(
                    label="🏆 Model Rankings",
                    interactive=False,
                    wrap=True
                )

            with gr.Row():
                leaderboard_viz = gr.Plot(label="📊 Performance Comparison")
                summary_viz = gr.Plot(label="📈 Top Models Summary")

        # Tab 3: Documentation
        with gr.Tab("📚 Documentation", id="docs"):

            # This must be an f-string: the text interpolates MAX_EVAL_SAMPLES
            # and LEADERBOARD_DATASET, and the BibTeX entry uses doubled braces
            # that collapse to single braces only under f-string formatting.
            gr.Markdown(f"""
            ## 📖 How to Use the SALT Translation Leaderboard

            ### 🚀 Submitting Your Model

            1. **Prepare your model**: Ensure your model is uploaded to HuggingFace Hub
            2. **Enter model path**: Use the format `username/model-name`
            3. **Add your details**: Provide your name or organization
            4. **Submit**: Click "Evaluate Model" and wait for results

            ### 📊 Metrics Explained

            - **Quality Score**: Combined metric (0-1, higher is better)
            - **BLEU**: Translation quality (0-100, higher is better)
            - **ChrF**: Character-level F-score (0-1, higher is better)
            - **ROUGE-L**: Longest common subsequence (0-1, higher is better)
            - **CER/WER**: Character/Word Error Rate (0-1, lower is better)

            ### 🎯 Supported Models

            - **Gemma**: Google's Gemma models fine-tuned for translation
            - **Qwen**: Alibaba's Qwen models
            - **Llama**: Meta's Llama models
            - **NLLB**: Facebook's No Language Left Behind models
            - **Google Translate**: Baseline comparison

            ### 📋 Dataset Information

            **SALT Dataset**: Sunbird AI's comprehensive translation dataset
            - **Languages**: Luganda, Acholi, Swahili, English
            - **Evaluation Size**: {MAX_EVAL_SAMPLES} samples
            - **Domains**: Multiple domains including news, literature, and conversations

            ### 🔄 API Access

            The leaderboard data is available via HuggingFace Datasets:
            ```python
            from datasets import load_dataset
            leaderboard = load_dataset("{LEADERBOARD_DATASET}")
            ```

            ### 🤝 Contributing

            This leaderboard is maintained by [Sunbird AI](https://sunbird.ai).
            For issues or suggestions, please contact us or submit a GitHub issue.

            ### 📜 License & Citation

            If you use this leaderboard in your research, please cite:
            ```
            @misc{{salt_leaderboard_2024,
              title={{SALT Translation Leaderboard}},
              author={{Sunbird AI}},
              year={{2024}},
              url={{https://huggingface.co/spaces/Sunbird/salt-translation-leaderboard}}
            }}
            ```
            """)

    # Event handlers
    submit_btn.click(
        fn=evaluate_submission,
        inputs=[model_input, author_input],
        outputs=[results_output, results_leaderboard, results_plot, detailed_plot],
        show_progress=True
    )

    # Refresh reloads the leaderboard before redrawing; searching (below)
    # only re-filters the cached copy.
    refresh_btn.click(
        fn=_reload_leaderboard_display,
        inputs=[search_input],
        outputs=[leaderboard_table, leaderboard_viz, summary_viz, summary_stats]
    )

    search_input.change(
        fn=update_leaderboard_display,
        inputs=[search_input],
        outputs=[leaderboard_table, leaderboard_viz, summary_viz, summary_stats]
    )

    # Load initial leaderboard data
    demo.load(
        fn=update_leaderboard_display,
        inputs=[],
        outputs=[leaderboard_table, leaderboard_viz, summary_viz, summary_stats]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )