Spaces:
Running
Running
# src/plotting.py | |
import matplotlib.pyplot as plt | |
import matplotlib.gridspec as gridspec | |
import matplotlib.colors as mcolors | |
from colorsys import rgb_to_hls, hls_to_rgb | |
import plotly.graph_objects as go | |
import plotly.express as px | |
from plotly.subplots import make_subplots | |
import pandas as pd | |
import numpy as np | |
from collections import defaultdict | |
from typing import Dict, List, Optional, Union | |
from config import LANGUAGE_NAMES, ALL_UG40_LANGUAGES, GOOGLE_SUPPORTED_LANGUAGES, METRICS_CONFIG | |
# Module-wide matplotlib defaults: reset to the stock style, then force white
# figure and axes backgrounds.
plt.style.use('default')
plt.rcParams.update({
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
})
def create_leaderboard_ranking_plot(df: pd.DataFrame, metric: str = 'quality_score', top_n: int = 15) -> go.Figure:
    """Build a horizontal Plotly bar chart ranking the leading models.

    The first ``top_n`` rows of ``df`` are plotted in their existing order
    (this function does not sort), and the y-axis is reversed so the first
    row is drawn at the top.

    Args:
        df: Leaderboard rows; must provide ``model_name``, ``author``,
            ``coverage_rate`` and the ``metric`` column.
        metric: Column used for bar length, bar colour and the value labels.
        top_n: Maximum number of rows to plot.

    Returns:
        A ``go.Figure``; a placeholder figure with a centred message when
        ``df`` is empty.
    """
    if df.empty:
        placeholder = go.Figure()
        placeholder.add_annotation(
            text="No data available",
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            showarrow=False,
            font=dict(size=16),
        )
        placeholder.update_layout(title="No Data Available")
        return placeholder

    pretty = metric.replace('_', ' ').title()
    shown = df.head(top_n)
    scores = shown[metric]

    hover = "".join([
        "<b>%{y}</b><br>",
        f"{pretty}: %{{x:.4f}}<br>",
        "Author: %{customdata[0]}<br>",
        "Coverage: %{customdata[1]:.1%}<br>",
        "<extra></extra>",
    ])
    bar = go.Bar(
        y=shown['model_name'],
        x=scores,
        orientation='h',
        marker=dict(
            color=scores,
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title=pretty),
        ),
        text=[f"{value:.3f}" for value in scores],
        textposition='auto',
        hovertemplate=hover,
        # customdata[0] -> author, customdata[1] -> coverage rate (see hover).
        customdata=list(zip(shown['author'], shown['coverage_rate'])),
    )

    fig = go.Figure(data=[bar])
    fig.update_layout(
        # NOTE(review): the leading title character looks like a mis-encoded
        # emoji — confirm the intended glyph.
        title=f"π SALT Translation Leaderboard - {pretty}",
        xaxis_title=f"{pretty} Score",
        yaxis_title="Models",
        height=max(400, len(shown) * 30 + 100),
        margin=dict(l=20, r=20, t=60, b=20),
        plot_bgcolor='white',
        paper_bgcolor='white',
    )
    # Reverse the y-axis so the first row of ``df`` appears at the top.
    fig.update_yaxes(autorange="reversed")
    return fig
def create_metrics_comparison_plot(df: pd.DataFrame, models: Optional[List[str]] = None, max_models: int = 8) -> go.Figure:
    """Create a radar (polar) chart comparing several metrics across models.

    Each selected model becomes one filled polygon over quality score, BLEU,
    ChrF, ROUGE-1 and ROUGE-L. BLEU is divided by 100 before plotting so that
    it shares the fixed 0-1 radial range used for every axis.

    Args:
        df: Leaderboard rows with ``model_name`` plus the ``quality_score``,
            ``bleu``, ``chrf``, ``rouge1`` and ``rougeL`` columns.
        models: Model names to include. ``None`` takes the first rows of
            ``df``. Selected rows keep ``df``'s order, not the order of
            ``models``.
        max_models: Upper bound on the number of polygons drawn.

    Returns:
        The radar figure, or a placeholder figure (message centred on the
        canvas) when ``df`` is empty or no requested model is found.
    """
    if df.empty:
        fig = go.Figure()
        # Paper coordinates centre the message; data coordinates would place
        # it relative to the empty default axes instead.
        fig.add_annotation(
            text="No data available",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False,
        )
        fig.update_layout(title="No Data Available")
        return fig

    if models is None:
        selected_models = df.head(max_models)
    else:
        selected_models = df[df['model_name'].isin(models)].head(max_models)

    if len(selected_models) == 0:
        fig = go.Figure()
        fig.add_annotation(
            text="No models found",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False,
        )
        fig.update_layout(title="No Models Found")
        return fig

    metrics = ['quality_score', 'bleu', 'chrf', 'rouge1', 'rougeL']
    metric_labels = ['Quality Score', 'BLEU (/100)', 'ChrF', 'ROUGE-1', 'ROUGE-L']
    # Repeat the first axis at the end so each polygon closes on itself.
    theta = metric_labels + metric_labels[:1]
    # Colours cycle if more models than palette entries are requested.
    palette = px.colors.qualitative.Set1

    fig = go.Figure()
    for i, (_, model) in enumerate(selected_models.iterrows()):
        # BLEU is reported on a 0-100 scale; bring it onto the 0-1 radial axis.
        values = [model[m] / 100.0 if m == 'bleu' else model[m] for m in metrics]
        values.append(values[0])
        color = palette[i % len(palette)]
        fig.add_trace(go.Scatterpolar(
            r=values,
            theta=theta,
            fill='toself',
            name=model['model_name'],
            line_color=color,
            fillcolor=color,
            opacity=0.6,
        ))

    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 1],
            )
        ),
        showlegend=True,
        title="π Multi-Metric Model Comparison",
        height=600,
    )
    return fig
def create_language_pair_heatmap(results_dict: Dict, metric: str = 'quality_score') -> go.Figure:
    """Create a source x target heatmap of one metric over language pairs.

    Rows are source languages and columns target languages, both in
    ``ALL_UG40_LANGUAGES`` order and labelled via ``LANGUAGE_NAMES``. Cells
    for the diagonal (source == target) and for pairs without a value are
    NaN, which renders as a gap with no hover label.

    Args:
        results_dict: Mapping containing ``'pair_metrics'``: a mapping from
            ``"<src>_to_<tgt>"`` keys to per-pair metric dicts.
        metric: Key read from each per-pair dict.

    Returns:
        The heatmap figure, or a centred placeholder when ``results_dict`` is
        empty or has no ``'pair_metrics'``.
    """
    if not results_dict or 'pair_metrics' not in results_dict:
        fig = go.Figure()
        # Paper coordinates centre the message on the empty canvas.
        fig.add_annotation(
            text="No language pair data available",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False,
        )
        fig.update_layout(title="No Language Pair Data Available")
        return fig

    pair_metrics = results_dict['pair_metrics']
    languages = ALL_UG40_LANGUAGES
    pretty = metric.replace('_', ' ').title()

    # Every cell starts as NaN (a gap); only pairs that report the metric are
    # filled in. The diagonal is never filled.
    matrix = np.full((len(languages), len(languages)), np.nan)
    for i, src_lang in enumerate(languages):
        for j, tgt_lang in enumerate(languages):
            if src_lang == tgt_lang:
                continue
            pair_scores = pair_metrics.get(f"{src_lang}_to_{tgt_lang}")
            if pair_scores and metric in pair_scores:
                matrix[i, j] = pair_scores[metric]

    lang_labels = [LANGUAGE_NAMES.get(lang, lang) for lang in languages]

    fig = go.Figure(data=go.Heatmap(
        z=matrix,
        x=lang_labels,
        y=lang_labels,
        colorscale='Viridis',
        showscale=True,
        colorbar=dict(title=pretty),
        # Don't show "NaN" hover labels over the diagonal / missing pairs.
        hoverongaps=False,
        hovertemplate=(
            "Source: %{y}<br>"
            "Target: %{x}<br>"
            f"{pretty}: %{{z:.3f}}<br>"
            "<extra></extra>"
        ),
    ))

    fig.update_layout(
        title=f"πΊοΈ Language Pair Performance - {pretty}",
        xaxis_title="Target Language",
        yaxis_title="Source Language",
        height=600,
        width=700,
    )
    return fig
def create_coverage_analysis_plot(df: pd.DataFrame) -> go.Figure:
    """Create a 2x2 dashboard analysing test-set coverage across submissions.

    Panels:
        (1, 1) number of submissions per ``coverage_rate`` bucket;
        (1, 2) quality score vs. language pairs covered;
        (2, 1) quality score vs. total samples;
        (2, 2) number of submissions per ``google_pairs_covered`` value.

    Args:
        df: One row per submission with ``coverage_rate``,
            ``language_pairs_covered``, ``total_samples``, ``quality_score``,
            ``google_pairs_covered`` and ``model_name`` columns.

    Returns:
        The subplot figure, or a centred placeholder when ``df`` is empty.
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(
            text="No data available",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False,
        )
        fig.update_layout(title="No Data Available")
        return fig

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            "Coverage Distribution",
            "Language Pairs Covered",
            "Sample Count vs Quality",
            "Google Comparable Coverage"
        ),
        specs=[[{"type": "bar"}, {"type": "scatter"}],
               [{"type": "scatter"}, {"type": "bar"}]]
    )

    # Bins are right-closed; include_lowest=True makes the first bin [0, 0.5]
    # so a submission with exactly 0% coverage is counted rather than turned
    # into NaN and silently dropped.
    coverage_bins = pd.cut(
        df['coverage_rate'],
        bins=[0, 0.5, 0.8, 0.9, 0.95, 1.0],
        labels=['<50%', '50-80%', '80-90%', '90-95%', '95-100%'],
        include_lowest=True,
    )
    # sort=False keeps buckets in low-to-high order (empty buckets show as 0)
    # instead of ordering the bars by frequency.
    coverage_counts = coverage_bins.value_counts(sort=False)
    fig.add_trace(
        go.Bar(
            x=[str(label) for label in coverage_counts.index],
            y=coverage_counts.values,
            name="Coverage",
        ),
        row=1, col=1
    )

    fig.add_trace(
        go.Scatter(
            x=df['language_pairs_covered'],
            y=df['quality_score'],
            mode='markers',
            text=df['model_name'],
            name="Quality vs Coverage"
        ),
        row=1, col=2
    )

    fig.add_trace(
        go.Scatter(
            x=df['total_samples'],
            y=df['quality_score'],
            mode='markers',
            text=df['model_name'],
            name="Quality vs Samples"
        ),
        row=2, col=1
    )

    google_coverage = df['google_pairs_covered'].value_counts().sort_index()
    fig.add_trace(
        go.Bar(x=google_coverage.index, y=google_coverage.values, name="Google Coverage"),
        row=2, col=2
    )

    fig.update_layout(
        title="π Test Set Coverage Analysis",
        height=800,
        showlegend=False
    )
    return fig
def create_model_performance_timeline(df: pd.DataFrame) -> go.Figure:
    """Plot each submission's quality score against its submission date.

    Points are joined in chronological order and coloured by quality score.
    A dashed least-squares linear trend is overlaid when at least two
    submissions with distinct dates and finite scores exist.

    Args:
        df: Rows with ``submission_date`` (anything ``pd.to_datetime``
            accepts), ``quality_score`` and ``model_name`` columns. Not
            modified.

    Returns:
        The timeline figure, or a centred placeholder when ``df`` is empty.
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(
            text="No data available",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False,
        )
        fig.update_layout(title="No Data Available")
        return fig

    # Work on a copy so the caller's frame is not mutated.
    timeline = df.copy()
    timeline['submission_date'] = pd.to_datetime(timeline['submission_date'])
    timeline = timeline.sort_values('submission_date')

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=timeline['submission_date'],
        y=timeline['quality_score'],
        mode='markers+lines',
        marker=dict(
            size=10,
            color=timeline['quality_score'],
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title="Quality Score")
        ),
        text=timeline['model_name'],
        hovertemplate=(
            "<b>%{text}</b><br>" +
            "Date: %{x}<br>" +
            "Quality Score: %{y:.4f}<br>" +
            "<extra></extra>"
        ),
        name="Models"
    ))

    dates = timeline['submission_date']
    scores = timeline['quality_score'].to_numpy(dtype=float)
    # Fit against elapsed days, not row position: the trend is drawn on a date
    # axis, so an index-based fit misstates the slope whenever submissions are
    # unevenly spaced in time.
    elapsed_days = ((dates - dates.min()) / pd.Timedelta(days=1)).to_numpy(dtype=float)
    usable = np.isfinite(elapsed_days) & np.isfinite(scores)
    # A degree-1 fit needs at least two distinct x values to be well posed.
    if np.unique(elapsed_days[usable]).size > 1:
        slope, intercept = np.polyfit(elapsed_days[usable], scores[usable], 1)
        fig.add_trace(go.Scatter(
            x=dates,
            y=slope * elapsed_days + intercept,
            mode='lines',
            line=dict(dash='dash', color='red'),
            name="Trend",
            hoverinfo='skip'
        ))

    fig.update_layout(
        title="π Model Performance Timeline",
        xaxis_title="Submission Date",
        yaxis_title="Quality Score",
        height=500
    )
    return fig
def create_google_comparison_plot(df: pd.DataFrame) -> go.Figure:
    """Scatter models on the Google-Translate-comparable language pairs.

    Only rows with ``google_pairs_covered > 0`` are plotted. The x-axis is
    ``google_bleu``, the y-axis is ``google_quality_score``, the marker colour
    is ``google_chrf``, and each point is labelled with ``model_name``.

    Args:
        df: Leaderboard rows with the columns named above.

    Returns:
        The scatter figure, or a centred placeholder when no row qualifies
        (including an empty ``df`` or one without ``google_pairs_covered``).
    """
    # An empty frame may not even have the column; fall through to the
    # placeholder instead of raising KeyError.
    if df.empty or 'google_pairs_covered' not in df.columns:
        google_models = None
    else:
        google_models = df[df['google_pairs_covered'] > 0]

    if google_models is None or google_models.empty:
        fig = go.Figure()
        fig.add_annotation(
            text="No models with Google Translate comparable results",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False
        )
        fig.update_layout(title="No Google Comparable Models")
        return fig

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=google_models['google_bleu'],
        y=google_models['google_quality_score'],
        mode='markers+text',
        marker=dict(
            size=12,
            color=google_models['google_chrf'],
            colorscale='Plasma',
            showscale=True,
            colorbar=dict(title="ChrF Score")
        ),
        text=google_models['model_name'],
        textposition="top center",
        hovertemplate=(
            "<b>%{text}</b><br>" +
            "BLEU: %{x:.2f}<br>" +
            "Quality: %{y:.4f}<br>" +
            "ChrF: %{marker.color:.4f}<br>" +
            "<extra></extra>"
        ),
        name="Models"
    ))

    fig.update_layout(
        title="π€ Google Translate Comparable Performance",
        xaxis_title="BLEU Score",
        yaxis_title="Quality Score",
        height=500
    )
    return fig
def create_detailed_model_analysis(model_results: Dict, model_name: str) -> go.Figure:
    """Create a per-language-pair breakdown for a single model.

    The figure has two stacked bar charts: BLEU on top and quality score
    below. There is one bar per language pair with a positive
    ``sample_count``.

    A bar is blue when both languages are in ``GOOGLE_SUPPORTED_LANGUAGES``
    and orange otherwise. Two data-less marker traces supply the legend
    entries for that colour coding. Hovering a bar also shows its sample
    count.

    Args:
        model_results: Mapping containing ``'pair_metrics'``: a mapping from
            ``"<src>_to_<tgt>"`` keys to per-pair dicts. The keys read are
            ``sample_count``, ``bleu`` and ``quality_score``.
        model_name: Display name used in the titles.

    Returns:
        The figure, or a centred placeholder when there is no usable pair
        data.
    """
    if not model_results or 'pair_metrics' not in model_results:
        fig = go.Figure()
        fig.add_annotation(
            text="No detailed results available",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False,
        )
        fig.update_layout(title=f"No Data for {model_name}")
        return fig

    pairs = []
    bleu_scores = []
    quality_scores = []
    sample_counts = []
    google_comparable = []
    for pair_key, metrics in model_results['pair_metrics'].items():
        sample_count = metrics.get('sample_count') or 0
        if sample_count <= 0:
            continue
        # partition() avoids a ValueError from tuple unpacking on keys that do
        # not follow the "<src>_to_<tgt>" pattern; such keys are skipped.
        src, sep, tgt = pair_key.partition('_to_')
        if not sep:
            continue
        # NOTE(review): 'β' looks like a mis-encoded arrow separator — confirm
        # the intended character.
        pairs.append(f"{LANGUAGE_NAMES.get(src, src)} β {LANGUAGE_NAMES.get(tgt, tgt)}")
        bleu_scores.append(metrics.get('bleu') or 0)
        quality_scores.append(metrics.get('quality_score') or 0)
        sample_counts.append(sample_count)
        google_comparable.append(
            src in GOOGLE_SUPPORTED_LANGUAGES and tgt in GOOGLE_SUPPORTED_LANGUAGES
        )

    if not pairs:
        fig = go.Figure()
        fig.add_annotation(
            text="No language pair data found",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False,
        )
        fig.update_layout(title=f"No Language Pair Data for {model_name}")
        return fig

    fig = make_subplots(
        rows=2, cols=1,
        subplot_titles=(
            "BLEU Scores by Language Pair",
            "Quality Scores by Language Pair"
        ),
        vertical_spacing=0.15,
        row_heights=[0.45, 0.45]
    )

    # Blue = Google-comparable pair, orange = UG40-only pair.
    colors = ['#1f77b4' if gc else '#ff7f0e' for gc in google_comparable]

    fig.add_trace(
        go.Bar(
            x=pairs,
            y=bleu_scores,
            marker_color=colors,
            name="BLEU",
            text=[f"{score:.1f}" for score in bleu_scores],
            textposition='outside',
            textfont=dict(size=10),
            customdata=sample_counts,
            hovertemplate="%{x}<br>BLEU: %{y:.2f}<br>Samples: %{customdata}<extra></extra>",
            # The legend is reserved for the colour-coding entries added below.
            showlegend=False
        ),
        row=1, col=1
    )

    fig.add_trace(
        go.Bar(
            x=pairs,
            y=quality_scores,
            marker_color=colors,
            name="Quality",
            text=[f"{score:.3f}" for score in quality_scores],
            textposition='outside',
            textfont=dict(size=10),
            customdata=sample_counts,
            hovertemplate="%{x}<br>Quality: %{y:.4f}<br>Samples: %{customdata}<extra></extra>",
            showlegend=False
        ),
        row=2, col=1
    )

    fig.update_layout(
        height=900,
        title=dict(
            text=f"π Detailed Analysis: {model_name}",
            x=0.5,
            xanchor='center'
        ),
        showlegend=True,
        margin=dict(l=50, r=50, t=100, b=150)
    )

    # Same tick styling on both x-axes; rotated labels keep long pair names legible.
    fig.update_xaxes(tickangle=45, tickfont=dict(size=10))
    fig.update_yaxes(title_text="BLEU Score", row=1, col=1)
    fig.update_yaxes(title_text="Quality Score", row=2, col=1)

    # Legend-only traces (no data points) explaining the bar colours.
    fig.add_trace(
        go.Scatter(
            x=[None], y=[None],
            mode='markers',
            marker=dict(size=15, color='#1f77b4', symbol='square'),
            name="Google Comparable",
            showlegend=True
        )
    )
    fig.add_trace(
        go.Scatter(
            x=[None], y=[None],
            mode='markers',
            marker=dict(size=15, color='#ff7f0e', symbol='square'),
            name="UG40 Only",
            showlegend=True
        )
    )
    return fig
def create_submission_summary_plot(validation_info: Dict, evaluation_results: Dict) -> go.Figure:
    """Create a 2x2 summary dashboard for a newly evaluated submission.

    Panels:
        (1, 1) pie of evaluated vs. missing share, from
               ``validation_info['coverage']``;
        (1, 2) bars of ``evaluation_results['summary']['primary_metrics']``;
        (2, 1) bars of ``cer``/``wer`` from ``evaluation_results['averages']``;
        (2, 2) bars of total samples, language pairs and Google-comparable
               pairs from ``evaluation_results['summary']``.

    A panel whose source data is absent is left empty rather than raising.

    Args:
        validation_info: Validation output; only ``'coverage'`` is read. It
            is assumed to be a fraction in [0, 1] and is clamped to that range.
        evaluation_results: Evaluation output with optional ``'summary'``
            and ``'averages'`` mappings.

    Returns:
        The subplot figure.
    """
    validation_info = validation_info or {}
    evaluation_results = evaluation_results or {}

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            "Sample Distribution",
            "Primary Metrics",
            "Error Analysis",
            "Coverage Summary"
        ),
        specs=[[{"type": "pie"}, {"type": "bar"}],
               [{"type": "bar"}, {"type": "bar"}]]
    )

    # NOTE(review): 0.8 is shown when 'coverage' is missing — confirm this
    # placeholder fallback is intended.
    coverage = validation_info.get('coverage', 0.8)
    # Clamp so neither pie slice can go negative on out-of-range input.
    coverage = min(max(coverage, 0.0), 1.0)
    fig.add_trace(
        go.Pie(
            labels=["Evaluated", "Missing"],
            values=[coverage * 100, (1 - coverage) * 100],
            name="Samples"
        ),
        row=1, col=1
    )

    has_summary = 'summary' in evaluation_results
    summary = evaluation_results.get('summary') or {}

    # A summary without 'primary_metrics' leaves this panel empty instead of
    # raising KeyError.
    primary_metrics = summary.get('primary_metrics') or {}
    if primary_metrics:
        metric_names = list(primary_metrics.keys())
        metric_values = list(primary_metrics.values())
        fig.add_trace(
            go.Bar(
                x=metric_names,
                y=metric_values,
                name="Metrics",
                text=[f"{val:.3f}" for val in metric_values],
                textposition='auto'
            ),
            row=1, col=2
        )

    averages = evaluation_results.get('averages')
    if averages is not None:
        error_metrics = ['cer', 'wer']
        error_values = [averages.get(m, 0) for m in error_metrics]
        fig.add_trace(
            go.Bar(
                x=error_metrics,
                y=error_values,
                name="Errors",
                text=[f"{val:.3f}" for val in error_values],
                textposition='auto'
            ),
            row=2, col=1
        )

    if has_summary:
        coverage_labels = ["Total Samples", "Lang Pairs", "Google Pairs"]
        coverage_values = [
            summary.get('total_samples', 0),
            summary.get('language_pairs_covered', 0),
            summary.get('google_comparable_pairs', 0)
        ]
        fig.add_trace(
            go.Bar(
                x=coverage_labels,
                y=coverage_values,
                name="Coverage",
                text=[f"{val}" for val in coverage_values],
                textposition='auto'
            ),
            row=2, col=2
        )

    fig.update_layout(
        title="π Submission Summary",
        height=700,
        showlegend=False
    )
    return fig