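# app.py
# NOTE: the "Spaces: Sleeping" lines captured above this header were hosting-page
# status text picked up with the file; they are not part of the program.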
import gradio as gr | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from datasets import load_dataset | |
import yaml | |
import json | |
import torch | |
from datetime import datetime | |
import traceback | |
# Import our modules | |
from src.model_loader import load_model, get_model_info | |
from src.evaluation import evaluate_model_full | |
from src.leaderboard import load_leaderboard, add_model_results, get_leaderboard_summary, search_models | |
from src.plotting import create_leaderboard_plot, create_detailed_comparison_plot, create_summary_metrics_plot | |
from src.utils import validate_model_path, get_model_type, sanitize_input | |
from config import * | |
# Global variables for caching (shared by every callback in this module)
current_leaderboard = None  # assigned by refresh_leaderboard() / evaluate_submission()
test_data = None  # assigned by load_salt_data() on its first call
def load_salt_data():
    """Load the SALT evaluation set, caching it in the module-level ``test_data``.

    The frame is built once through ``salt.dataset.create`` from a YAML
    config and reused on every later call. If any step of that path raises
    (missing ``salt`` package, download failure, bad config, ...), the error
    is printed and a two-row English->Luganda placeholder frame is cached
    and returned instead, so callers always receive a DataFrame.

    Returns:
        pandas.DataFrame: the cached evaluation samples, or the placeholder.
    """
    global test_data
    if test_data is not None:
        return test_data

    try:
        print("Loading SALT dataset...")
        # Configuration for SALT dataset.
        # Nested keys are indented under their parent so yaml.safe_load builds
        # nested mappings; the grouping (huggingface_load / source / target)
        # is reconstructed from key order -- TODO confirm against the
        # salt.dataset configuration format.
        dataset_config = f'''
huggingface_load:
  path: {SALT_DATASET}
  name: text-all
  split: dev[:{MAX_EVAL_SAMPLES}]
source:
  type: text
  language: {SUPPORTED_LANGUAGES}
target:
  type: text
  language: {SUPPORTED_LANGUAGES}
src_or_tgt_languages_must_contain: eng
allow_same_src_and_tgt_language: False
'''
        config = yaml.safe_load(dataset_config)
        # Imported here (not at module top) so that a missing `salt` package
        # is caught below and triggers the fallback instead of crashing.
        import salt.dataset
        test_data = pd.DataFrame(salt.dataset.create(config))
        print(f"Loaded {len(test_data)} evaluation samples")
    except Exception as e:
        # Deliberate best-effort boundary: report the problem and keep going.
        print(f"Error loading SALT dataset: {e}")
        # Fallback: create minimal test data
        test_data = pd.DataFrame({
            'source': ['Hello world', 'How are you?'],
            'target': ['Amakuru', 'Oli otya?'],
            'source.language': ['eng', 'eng'],
            'target.language': ['lug', 'lug']
        })
    return test_data
def refresh_leaderboard():
    """Reload the leaderboard through ``load_leaderboard`` and cache it.

    Returns:
        The value returned by ``load_leaderboard()``, which is also stored in
        the module-level ``current_leaderboard``.
    """
    global current_leaderboard
    current_leaderboard = load_leaderboard()
    return current_leaderboard
def evaluate_submission(model_path: str, author_name: str) -> tuple:
    """Evaluate a submitted model and record its scores on the leaderboard.

    Steps: sanitize and validate the inputs, load the evaluation data, load
    the model, run ``evaluate_model_full``, store the averaged metrics with
    ``add_model_results``, then build the two plots and a Markdown summary.

    Args:
        model_path: model identifier typed by the user.
        author_name: submitter name; replaced by "Anonymous" when empty
            after sanitizing.

    Returns:
        A 4-tuple ``(message, leaderboard, leaderboard_plot, detailed_plot)``.
        On any failure ``message`` describes the error and the other three
        items are ``None``; this function does not raise.
    """
    # Declared up front: the leaderboard cache is rebound after a successful run.
    global current_leaderboard

    try:
        # Validate inputs
        model_path = sanitize_input(model_path)
        author_name = sanitize_input(author_name)
        if not model_path:
            return "β Error: Model path is required", None, None, None
        if not author_name:
            author_name = "Anonymous"
        if not validate_model_path(model_path):
            return "β Error: Invalid model path format", None, None, None

        # Load test data (local name avoids shadowing the module-level cache)
        eval_data = load_salt_data()
        if eval_data is None or len(eval_data) == 0:
            return "β Error: Could not load evaluation data", None, None, None

        # Get model info. The returned value is not used here; the call is
        # kept in case it raises for an unknown model -- TODO confirm.
        print(f"Getting model info for: {model_path}")
        get_model_info(model_path)
        model_type = get_model_type(model_path)

        # Load model
        print(f"Loading model: {model_path}")
        try:
            model, tokenizer = load_model(model_path)
        except Exception as e:
            return f"β Error loading model: {str(e)}", None, None, None

        # Run evaluation
        print("Starting evaluation...")
        try:
            detailed_metrics = evaluate_model_full(model, tokenizer, model_path, eval_data)
        except Exception as e:
            return f"β Error during evaluation: {str(e)}", None, None, None

        # Extract average metrics
        avg_metrics = detailed_metrics.get('averages', {})
        if not avg_metrics:
            return "β Error: No metrics calculated", None, None, None

        # Add results to leaderboard
        print("Adding results to leaderboard...")
        updated_leaderboard = add_model_results(
            model_path=model_path,
            author=author_name,
            metrics=avg_metrics,
            detailed_metrics=detailed_metrics,
            evaluation_samples=len(eval_data),
            model_type=model_type
        )

        # Update global leaderboard
        current_leaderboard = updated_leaderboard

        # Create visualizations
        leaderboard_plot = create_leaderboard_plot(updated_leaderboard, 'quality_score')
        detailed_plot = create_detailed_comparison_plot({model_path: detailed_metrics}, [model_path])

        # Rank = 1-based row position of this model in the returned frame.
        # Position (rather than the index label) stays meaningful when the
        # index was not reset, and an absent row yields "N/A" instead of an
        # IndexError after the result has already been recorded.
        # Assumes add_model_results returns rows in rank order -- TODO confirm.
        row_positions = (updated_leaderboard['model_path'] == model_path).to_numpy().nonzero()[0]
        rank_text = f"#{row_positions[0] + 1}" if len(row_positions) else "N/A"

        # Format results message
        results_msg = f"""
β **Evaluation Complete!**
**Model:** {model_path}
**Author:** {author_name}
**Type:** {model_type}
**Results:**
- Quality Score: {avg_metrics.get('quality_score', 0):.4f}
- BLEU: {avg_metrics.get('bleu', 0):.2f}
- ChrF: {avg_metrics.get('chrf', 0):.4f}
- ROUGE-L: {avg_metrics.get('rougeL', 0):.4f}
**Ranking:** {rank_text} out of {len(updated_leaderboard)} models
"""
        return results_msg, updated_leaderboard, leaderboard_plot, detailed_plot
    except Exception as e:
        # Top-level boundary for the Gradio callback: report instead of raising.
        error_msg = f"β Unexpected error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg, None, None, None
def update_leaderboard_display(search_query: str = "") -> tuple:
    """Build the leaderboard table, plots and summary text, optionally filtered.

    Uses the cached ``current_leaderboard``; it is loaded through
    ``refresh_leaderboard()`` only when the cache is still empty.

    Args:
        search_query: text handed to ``search_models``; an empty value shows
            the whole leaderboard.

    Returns:
        ``(filtered_df, leaderboard_plot, summary_plot, summary_text)``.
    """
    global current_leaderboard
    if current_leaderboard is None:
        current_leaderboard = refresh_leaderboard()

    # Apply search filter
    if search_query:
        filtered_df = search_models(current_leaderboard, search_query)
    else:
        filtered_df = current_leaderboard

    # Create plots
    leaderboard_plot = create_leaderboard_plot(filtered_df, 'quality_score')
    summary_plot = create_summary_metrics_plot(filtered_df)

    # Get summary stats
    summary = get_leaderboard_summary(filtered_df)
    # Keep only the first 10 characters of the timestamp (presumably the
    # YYYY-MM-DD part -- TODO confirm); the literal 'None' marks "no submissions".
    latest = summary['latest_submission']
    latest_text = latest[:10] if latest != 'None' else 'None'
    summary_text = f"""
π **Leaderboard Summary**
- Total Models: {summary['total_models']}
- Average Quality Score: {summary['avg_quality_score']:.4f}
- Best Model: {summary['best_model']}
- Latest Submission: {latest_text}
"""
    return filtered_df, leaderboard_plot, summary_plot, summary_text
# Initialize data at import time so both caches are filled before the UI is built
print("Initializing SALT Translation Leaderboard...")
load_salt_data()
refresh_leaderboard()
def _reload_and_update_display(search_query: str = "") -> tuple:
    """Reload the leaderboard, then rebuild the display for ``search_query``.

    Used by the Refresh button: ``update_leaderboard_display`` alone reuses the
    cached leaderboard, so without the explicit reload the button would never
    pick up new data.
    """
    refresh_leaderboard()
    return update_leaderboard_display(search_query)


# Create Gradio interface
with gr.Blocks(
    title=TITLE,
    theme=gr.themes.Soft(),
    css="""
.gradio-container {
max-width: 1200px !important;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
.metric-display {
background: #f8f9fa;
padding: 1rem;
border-radius: 0.5rem;
margin: 0.5rem 0;
}
"""
) as demo:
    # Header
    gr.Markdown(f"""
<div class="main-header">
# {TITLE}
{DESCRIPTION}
**Supported Languages:** Luganda (lug), Acholi (ach), Swahili (swa), English (eng)
</div>
""")

    with gr.Tabs():
        # Tab 1: Submit Model
        with gr.Tab("π Submit Model", id="submit"):
            gr.Markdown("""
### Submit Your Translation Model
Enter a HuggingFace model path (e.g., `microsoft/DialoGPT-medium`) or use `google-translate` to benchmark against Google Translate.
**Supported Model Types:** Gemma, Qwen, Llama, NLLB, Google Translate
""")
            with gr.Row():
                with gr.Column(scale=2):
                    model_input = gr.Textbox(
                        label="π€ HuggingFace Model Path",
                        placeholder="e.g., Sunbird/gemma3-12b-ug40-merged",
                        info="Enter the full HuggingFace model path or 'google-translate'"
                    )
                    author_input = gr.Textbox(
                        label="π€ Author/Organization",
                        placeholder="Your name or organization",
                        value="Anonymous"
                    )
                    submit_btn = gr.Button(
                        "π Evaluate Model",
                        variant="primary",
                        size="lg"
                    )
                with gr.Column(scale=1):
                    gr.Markdown("""
**π Evaluation Process:**
1. Model validation
2. Loading model weights
3. Generating translations
4. Calculating metrics
5. Updating leaderboard
β±οΈ **Expected time:** 5-15 minutes
""")

            # Results section
            with gr.Group():
                results_output = gr.Markdown(label="π Results")
                with gr.Row():
                    with gr.Column():
                        results_leaderboard = gr.Dataframe(
                            label="π Updated Leaderboard",
                            interactive=False
                        )
                with gr.Row():
                    results_plot = gr.Plot(label="π Leaderboard Ranking")
                    detailed_plot = gr.Plot(label="π Detailed Performance")

        # Tab 2: Leaderboard
        with gr.Tab("π Leaderboard", id="leaderboard"):
            with gr.Row():
                search_input = gr.Textbox(
                    label="π Search Models",
                    placeholder="Search by model name, author, or path...",
                    scale=3
                )
                refresh_btn = gr.Button("π Refresh", scale=1)
            summary_stats = gr.Markdown(label="π Summary")
            with gr.Row():
                leaderboard_table = gr.Dataframe(
                    label="π Model Rankings",
                    interactive=False,
                    wrap=True
                )
            with gr.Row():
                leaderboard_viz = gr.Plot(label="π Performance Comparison")
                summary_viz = gr.Plot(label="π Top Models Summary")

        # Tab 3: Documentation
        with gr.Tab("π Documentation", id="docs"):
            # f-string: the text interpolates MAX_EVAL_SAMPLES and
            # LEADERBOARD_DATASET, and its doubled braces are f-string escapes
            # for the literal braces of the BibTeX entry.
            gr.Markdown(f"""
## π How to Use the SALT Translation Leaderboard
### π Submitting Your Model
1. **Prepare your model**: Ensure your model is uploaded to HuggingFace Hub
2. **Enter model path**: Use the format `username/model-name`
3. **Add your details**: Provide your name or organization
4. **Submit**: Click "Evaluate Model" and wait for results
### π Metrics Explained
- **Quality Score**: Combined metric (0-1, higher is better)
- **BLEU**: Translation quality (0-100, higher is better)
- **ChrF**: Character-level F-score (0-1, higher is better)
- **ROUGE-L**: Longest common subsequence (0-1, higher is better)
- **CER/WER**: Character/Word Error Rate (0-1, lower is better)
### π― Supported Models
- **Gemma**: Google's Gemma models fine-tuned for translation
- **Qwen**: Alibaba's Qwen models
- **Llama**: Meta's Llama models
- **NLLB**: Facebook's No Language Left Behind models
- **Google Translate**: Baseline comparison
### π Dataset Information
**SALT Dataset**: Sunbird AI's comprehensive translation dataset
- **Languages**: Luganda, Acholi, Swahili, English
- **Evaluation Size**: {MAX_EVAL_SAMPLES} samples
- **Domains**: Multiple domains including news, literature, and conversations
### π API Access
The leaderboard data is available via HuggingFace Datasets:
```python
from datasets import load_dataset
leaderboard = load_dataset("{LEADERBOARD_DATASET}")
```
### π€ Contributing
This leaderboard is maintained by [Sunbird AI](https://sunbird.ai).
For issues or suggestions, please contact us or submit a GitHub issue.
### π License & Citation
If you use this leaderboard in your research, please cite:
```
@misc{{salt_leaderboard_2024,
title={{SALT Translation Leaderboard}},
author={{Sunbird AI}},
year={{2024}},
url={{https://huggingface.co/spaces/Sunbird/salt-translation-leaderboard}}
}}
```
""")

    # Event handlers
    submit_btn.click(
        fn=evaluate_submission,
        inputs=[model_input, author_input],
        outputs=[results_output, results_leaderboard, results_plot, detailed_plot],
        show_progress=True
    )
    # Refresh reloads the leaderboard before redrawing; typing in the search
    # box only re-filters the cached copy.
    refresh_btn.click(
        fn=_reload_and_update_display,
        inputs=[search_input],
        outputs=[leaderboard_table, leaderboard_viz, summary_viz, summary_stats]
    )
    search_input.change(
        fn=update_leaderboard_display,
        inputs=[search_input],
        outputs=[leaderboard_table, leaderboard_viz, summary_viz, summary_stats]
    )
    # Load initial leaderboard data
    demo.load(
        fn=update_leaderboard_display,
        inputs=[],
        outputs=[leaderboard_table, leaderboard_viz, summary_viz, summary_stats]
    )
# Launch the app (only when executed as a script, not when imported)
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=False,
        show_error=True
    )