# leaderboard / src/plotting.py
# Last change by akera: "Update src/plotting.py" (commit ce626d3, verified)
# Hugging Face file view: raw | history | blame | 18.6 kB
# src/plotting.py
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.colors as mcolors
from colorsys import rgb_to_hls, hls_to_rgb
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
from collections import defaultdict
from typing import Dict, List, Optional, Union
from config import LANGUAGE_NAMES, ALL_UG40_LANGUAGES, GOOGLE_SUPPORTED_LANGUAGES, METRICS_CONFIG
# Import-time side effect: reset matplotlib's global style to its stock
# defaults, then force white backgrounds for both figures and axes.
plt.style.use('default')
plt.rcParams.update({
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
})
def create_leaderboard_ranking_plot(df: pd.DataFrame, metric: str = 'quality_score', top_n: int = 15) -> go.Figure:
    """Create an interactive horizontal bar chart of the top models for ``metric``.

    Args:
        df: Leaderboard rows. Must contain ``model_name`` and the ``metric``
            column; ``author`` and ``coverage_rate`` feed the hover text and
            fall back to placeholders when absent. Rows are taken in their
            existing order, so ``df`` should already be sorted best-first.
        metric: Column whose values set the bar length, color and label.
        top_n: Number of leading rows to plot.

    Returns:
        A Plotly figure, or a centered "No data available" placeholder when
        ``df`` is empty or lacks the ``model_name``/``metric`` columns.
    """
    metric_label = metric.replace('_', ' ').title()

    if df.empty or 'model_name' not in df.columns or metric not in df.columns:
        fig = go.Figure()
        fig.add_annotation(
            text="No data available",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False,
            font=dict(size=16)
        )
        fig.update_layout(title="No Data Available")
        return fig

    # NOTE(review): no sorting happens here; the "ranking" is the caller's row
    # order, which may not match ``metric`` if df was sorted by another column.
    top_models = df.head(top_n)
    n_rows = len(top_models)

    # Optional hover columns degrade to placeholders instead of raising KeyError.
    authors = top_models['author'] if 'author' in top_models.columns else ['N/A'] * n_rows
    coverage = (
        top_models['coverage_rate'] if 'coverage_rate' in top_models.columns
        else [np.nan] * n_rows
    )

    fig = go.Figure(data=[
        go.Bar(
            y=top_models['model_name'],
            x=top_models[metric],
            orientation='h',
            marker=dict(
                color=top_models[metric],
                colorscale='Viridis',
                showscale=True,
                colorbar=dict(title=metric_label)
            ),
            text=[f"{score:.3f}" for score in top_models[metric]],
            textposition='auto',
            hovertemplate=(
                "<b>%{y}</b><br>"
                f"{metric_label}: %{{x:.4f}}<br>"
                "Author: %{customdata[0]}<br>"
                "Coverage: %{customdata[1]:.1%}<br>"
                "<extra></extra>"
            ),
            customdata=list(zip(authors, coverage))
        )
    ])

    fig.update_layout(
        title=f"🏆 SALT Translation Leaderboard - {metric_label}",
        xaxis_title=f"{metric_label} Score",
        yaxis_title="Models",
        height=max(400, n_rows * 30 + 100),
        margin=dict(l=20, r=20, t=60, b=20),
        plot_bgcolor='white',
        paper_bgcolor='white'
    )

    # Reverse the category axis so the first row (the best model) is drawn at the top.
    fig.update_yaxes(autorange="reversed")
    return fig
def create_metrics_comparison_plot(df: pd.DataFrame, models: Optional[List[str]] = None, max_models: int = 8) -> go.Figure:
    """Create a radar chart comparing several metrics across models.

    Args:
        df: Leaderboard rows with ``model_name`` plus the metric columns
            ``quality_score``, ``bleu``, ``chrf``, ``rouge1`` and ``rougeL``.
        models: Model names to include. When ``None``, the first
            ``max_models`` rows of ``df`` are used instead.
        max_models: Maximum number of models drawn.

    Returns:
        A Plotly figure with one filled polar trace per model, or a centered
        placeholder when ``df`` is empty or no requested model is found.
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        fig.update_layout(title="No Data Available")
        return fig

    if models is None:
        selected_models = df.head(max_models)
    else:
        selected_models = df[df['model_name'].isin(models)].head(max_models)

    if selected_models.empty:
        fig = go.Figure()
        fig.add_annotation(text="No models found", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        fig.update_layout(title="No Models Found")
        return fig

    metrics = ['quality_score', 'bleu', 'chrf', 'rouge1', 'rougeL']
    metric_labels = ['Quality Score', 'BLEU (/100)', 'ChrF', 'ROUGE-1', 'ROUGE-L']
    # Repeat the first axis label so every polygon closes; identical for all models.
    theta = metric_labels + metric_labels[:1]

    colors = px.colors.qualitative.Set1[:len(selected_models)]
    fig = go.Figure()

    for i, (_, model) in enumerate(selected_models.iterrows()):
        values = []
        for metric in metrics:
            # A missing column becomes NaN (a gap in the polygon) instead of a KeyError.
            value = model.get(metric, np.nan)
            if metric == 'bleu':
                # The radial axis is fixed to [0, 1]; BLEU is divided by 100 to fit.
                # NOTE(review): the other metrics are assumed to already be 0-1.
                value = value / 100.0
            values.append(value)
        values.append(values[0])

        color = colors[i % len(colors)]
        fig.add_trace(go.Scatterpolar(
            r=values,
            theta=theta,
            fill='toself',
            name=model['model_name'],
            line_color=color,
            fillcolor=color,
            opacity=0.6
        ))

    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 1]
            )
        ),
        showlegend=True,
        title="📊 Multi-Metric Model Comparison",
        height=600
    )
    return fig
def create_language_pair_heatmap(results_dict: Dict, metric: str = 'quality_score') -> go.Figure:
    """Create a source-by-target heatmap of ``metric`` over all UG40 language pairs.

    Args:
        results_dict: Results containing ``pair_metrics``, a dict keyed
            ``"<src>_to_<tgt>"`` whose values are dicts of metric values.
        metric: Metric key to read from each pair's dict.

    Returns:
        A Plotly heatmap whose rows are source and columns target languages,
        both ordered as ``ALL_UG40_LANGUAGES``. The diagonal and pairs without
        a value are NaN (blank). A centered placeholder is returned when
        ``pair_metrics`` is missing.
    """
    if not results_dict or 'pair_metrics' not in results_dict:
        fig = go.Figure()
        fig.add_annotation(text="No language pair data available", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        fig.update_layout(title="No Language Pair Data Available")
        return fig

    pair_metrics = results_dict['pair_metrics']
    languages = ALL_UG40_LANGUAGES
    metric_label = metric.replace('_', ' ').title()

    # Start all-NaN so the diagonal and unevaluated pairs render as blank cells.
    matrix = np.full((len(languages), len(languages)), np.nan)
    for i, src_lang in enumerate(languages):
        for j, tgt_lang in enumerate(languages):
            if src_lang == tgt_lang:
                continue
            value = (pair_metrics.get(f"{src_lang}_to_{tgt_lang}") or {}).get(metric)
            # An explicit None would raise when stored into the float matrix.
            if value is not None:
                matrix[i, j] = value

    lang_labels = [LANGUAGE_NAMES.get(lang, lang) for lang in languages]

    fig = go.Figure(data=go.Heatmap(
        z=matrix,
        x=lang_labels,
        y=lang_labels,
        colorscale='Viridis',
        showscale=True,
        colorbar=dict(title=metric_label),
        hovertemplate=(
            "Source: %{y}<br>"
            "Target: %{x}<br>"
            f"{metric_label}: %{{z:.3f}}<br>"
            "<extra></extra>"
        )
    ))

    fig.update_layout(
        title=f"🗺️ Language Pair Performance - {metric_label}",
        xaxis_title="Target Language",
        yaxis_title="Source Language",
        height=600,
        width=700
    )
    return fig
def create_coverage_analysis_plot(df: pd.DataFrame) -> go.Figure:
    """Create a 2x2 dashboard describing test-set coverage across submissions.

    Panels: distribution of ``coverage_rate`` bins, quality vs. language pairs
    covered, quality vs. total samples, and counts of ``google_pairs_covered``.

    Args:
        df: Leaderboard rows with ``coverage_rate``, ``language_pairs_covered``,
            ``total_samples``, ``google_pairs_covered``, ``quality_score`` and
            ``model_name`` columns.

    Returns:
        A Plotly figure, or a centered placeholder when ``df`` is empty.
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        fig.update_layout(title="No Data Available")
        return fig

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            "Coverage Distribution",
            "Language Pairs Covered",
            "Sample Count vs Quality",
            "Google Comparable Coverage"
        ),
        specs=[[{"type": "bar"}, {"type": "scatter"}],
               [{"type": "scatter"}, {"type": "bar"}]]
    )

    # Coverage distribution. pd.cut bins are right-closed, so without
    # include_lowest=True a 0% coverage rate would fall outside every bin (NaN).
    # sort=False keeps the bars in bin order instead of by frequency.
    coverage_bins = pd.cut(
        df['coverage_rate'],
        bins=[0, 0.5, 0.8, 0.9, 0.95, 1.0],
        labels=['<50%', '50-80%', '80-90%', '90-95%', '95-100%'],
        include_lowest=True
    )
    coverage_counts = coverage_bins.value_counts(sort=False)
    fig.add_trace(
        go.Bar(x=coverage_counts.index.astype(str), y=coverage_counts.values, name="Coverage"),
        row=1, col=1
    )

    # Language pairs covered vs quality
    fig.add_trace(
        go.Scatter(
            x=df['language_pairs_covered'],
            y=df['quality_score'],
            mode='markers',
            text=df['model_name'],
            name="Quality vs Coverage"
        ),
        row=1, col=2
    )

    # Sample count vs quality
    fig.add_trace(
        go.Scatter(
            x=df['total_samples'],
            y=df['quality_score'],
            mode='markers',
            text=df['model_name'],
            name="Quality vs Samples"
        ),
        row=2, col=1
    )

    # Google comparable coverage, ordered by number of pairs
    google_coverage = df['google_pairs_covered'].value_counts().sort_index()
    fig.add_trace(
        go.Bar(x=google_coverage.index, y=google_coverage.values, name="Google Coverage"),
        row=2, col=2
    )

    # Axis titles make each panel readable without relying on the subplot title alone.
    fig.update_xaxes(title_text="Coverage Rate", row=1, col=1)
    fig.update_yaxes(title_text="Models", row=1, col=1)
    fig.update_xaxes(title_text="Language Pairs Covered", row=1, col=2)
    fig.update_yaxes(title_text="Quality Score", row=1, col=2)
    fig.update_xaxes(title_text="Total Samples", row=2, col=1)
    fig.update_yaxes(title_text="Quality Score", row=2, col=1)
    fig.update_xaxes(title_text="Google Pairs Covered", row=2, col=2)
    fig.update_yaxes(title_text="Models", row=2, col=2)

    fig.update_layout(
        title="📈 Test Set Coverage Analysis",
        height=800,
        showlegend=False
    )
    return fig
def create_model_performance_timeline(df: pd.DataFrame) -> go.Figure:
    """Create a timeline of quality scores by submission date with a linear trend.

    Args:
        df: Leaderboard rows with ``submission_date``, ``quality_score`` and
            ``model_name`` columns.

    Returns:
        A Plotly figure with one marker per submission, colored by quality
        score. When at least two distinct dates with scores exist, it also
        draws a dashed least-squares trend line. A centered placeholder is
        returned when ``df`` is empty.
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        fig.update_layout(title="No Data Available")
        return fig

    df_copy = df.copy()
    # Unparseable dates become NaT instead of aborting the whole plot.
    df_copy['submission_date'] = pd.to_datetime(df_copy['submission_date'], errors='coerce')
    df_copy = df_copy.sort_values('submission_date')

    fig = go.Figure()

    # One marker per submission
    fig.add_trace(go.Scatter(
        x=df_copy['submission_date'],
        y=df_copy['quality_score'],
        mode='markers+lines',
        marker=dict(
            size=10,
            color=df_copy['quality_score'],
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title="Quality Score")
        ),
        text=df_copy['model_name'],
        hovertemplate=(
            "<b>%{text}</b><br>"
            "Date: %{x}<br>"
            "Quality Score: %{y:.4f}<br>"
            "<extra></extra>"
        ),
        name="Models"
    ))

    # Fit the trend against elapsed days rather than row position. An
    # index-based fit is not a straight line on a date axis when submissions
    # are unevenly spaced.
    trend_data = df_copy.dropna(subset=['submission_date', 'quality_score'])
    if len(trend_data) > 1:
        dates = trend_data['submission_date']
        elapsed_days = ((dates - dates.min()) / pd.Timedelta(days=1)).to_numpy(dtype=float)
        # A degree-1 fit is degenerate when every submission shares one timestamp.
        if elapsed_days.max() > elapsed_days.min():
            slope, intercept = np.polyfit(
                elapsed_days, trend_data['quality_score'].to_numpy(dtype=float), 1
            )
            fig.add_trace(go.Scatter(
                x=dates,
                y=slope * elapsed_days + intercept,
                mode='lines',
                line=dict(dash='dash', color='red'),
                name="Trend",
                hoverinfo='skip'
            ))

    fig.update_layout(
        title="📅 Model Performance Timeline",
        xaxis_title="Submission Date",
        yaxis_title="Quality Score",
        height=500
    )
    return fig
def create_google_comparison_plot(df: pd.DataFrame) -> go.Figure:
    """Create a scatter plot of models on Google Translate-comparable pairs.

    Args:
        df: Leaderboard rows with ``google_pairs_covered``, ``google_bleu``,
            ``google_quality_score``, ``google_chrf`` and ``model_name``.

    Returns:
        A Plotly figure plotting BLEU (x) against quality score (y), colored by
        ChrF, for models with ``google_pairs_covered > 0``. A centered
        placeholder is returned when no such model exists.
    """
    # Guard like the other plots: an empty frame may not even have the column.
    if df.empty or 'google_pairs_covered' not in df.columns:
        google_models = pd.DataFrame()
    else:
        google_models = df[df['google_pairs_covered'] > 0]

    if google_models.empty:
        fig = go.Figure()
        fig.add_annotation(
            text="No models with Google Translate comparable results",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False
        )
        fig.update_layout(title="No Google Comparable Models")
        return fig

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=google_models['google_bleu'],
        y=google_models['google_quality_score'],
        mode='markers+text',
        marker=dict(
            size=12,
            color=google_models['google_chrf'],
            colorscale='Plasma',
            showscale=True,
            colorbar=dict(title="ChrF Score")
        ),
        text=google_models['model_name'],
        textposition="top center",
        hovertemplate=(
            "<b>%{text}</b><br>"
            "BLEU: %{x:.2f}<br>"
            "Quality: %{y:.4f}<br>"
            "ChrF: %{marker.color:.4f}<br>"
            "<extra></extra>"
        ),
        name="Models"
    ))

    fig.update_layout(
        title="🤖 Google Translate Comparable Performance",
        xaxis_title="BLEU Score",
        yaxis_title="Quality Score",
        height=500
    )
    return fig
def create_detailed_model_analysis(model_results: Dict, model_name: str) -> go.Figure:
    """Create per-language-pair BLEU and quality bar charts for one model.

    Args:
        model_results: Evaluation results containing ``pair_metrics``, a dict
            keyed ``"<src>_to_<tgt>"`` whose values may hold ``sample_count``,
            ``bleu`` and ``quality_score``. Pairs without a positive
            ``sample_count`` and keys without ``_to_`` are skipped.
        model_name: Name shown in the figure title.

    Returns:
        A two-row Plotly figure (BLEU on top, quality score below). Bars are
        blue when both languages are in ``GOOGLE_SUPPORTED_LANGUAGES`` and
        orange otherwise, with a matching two-entry legend. A centered
        placeholder is returned when there is no usable pair data.
    """
    if not model_results or 'pair_metrics' not in model_results:
        fig = go.Figure()
        fig.add_annotation(text="No detailed results available", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        fig.update_layout(title=f"No Data for {model_name}")
        return fig

    pair_metrics = model_results['pair_metrics']

    # Collect one entry per evaluated language pair.
    pairs = []
    bleu_scores = []
    quality_scores = []
    google_comparable = []
    for pair_key, metrics in pair_metrics.items():
        metrics = metrics or {}
        sample_count = metrics.get('sample_count')
        if sample_count is None or not sample_count > 0:
            continue
        src, sep, tgt = pair_key.partition('_to_')
        if not sep:
            # Not a "<src>_to_<tgt>" key; skip it instead of failing on unpacking.
            continue
        pairs.append(f"{LANGUAGE_NAMES.get(src, src)} → {LANGUAGE_NAMES.get(tgt, tgt)}")
        # `or 0` also covers explicit None values, which would break the label format.
        bleu_scores.append(metrics.get('bleu') or 0)
        quality_scores.append(metrics.get('quality_score') or 0)
        google_comparable.append(
            src in GOOGLE_SUPPORTED_LANGUAGES and tgt in GOOGLE_SUPPORTED_LANGUAGES
        )

    if not pairs:
        fig = go.Figure()
        fig.add_annotation(text="No language pair data found", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        fig.update_layout(title=f"No Language Pair Data for {model_name}")
        return fig

    fig = make_subplots(
        rows=2, cols=1,
        subplot_titles=(
            "BLEU Scores by Language Pair",
            "Quality Scores by Language Pair"
        ),
        vertical_spacing=0.15,
        row_heights=[0.45, 0.45]
    )

    # Blue = Google-comparable pair, orange = UG40-only pair.
    colors = ['#1f77b4' if gc else '#ff7f0e' for gc in google_comparable]

    # The bar traces stay out of the legend: their colors vary per bar, so the
    # legend is built from the two explicit swatches added below.
    fig.add_trace(
        go.Bar(
            x=pairs,
            y=bleu_scores,
            marker_color=colors,
            name="BLEU",
            text=[f"{score:.1f}" for score in bleu_scores],
            textposition='outside',
            textfont=dict(size=10),
            showlegend=False
        ),
        row=1, col=1
    )
    fig.add_trace(
        go.Bar(
            x=pairs,
            y=quality_scores,
            marker_color=colors,
            name="Quality",
            text=[f"{score:.3f}" for score in quality_scores],
            textposition='outside',
            textfont=dict(size=10),
            showlegend=False
        ),
        row=2, col=1
    )

    fig.update_layout(
        height=900,
        title=dict(
            text=f"📊 Detailed Analysis: {model_name}",
            x=0.5,
            xanchor='center'
        ),
        showlegend=True,
        margin=dict(l=50, r=50, t=100, b=150)
    )

    # Rotate the long "Source → Target" tick labels on both rows.
    for row in (1, 2):
        fig.update_xaxes(tickangle=45, tickfont=dict(size=10), row=row, col=1)

    fig.update_yaxes(title_text="BLEU Score", row=1, col=1)
    fig.update_yaxes(title_text="Quality Score", row=2, col=1)

    # Invisible (None-valued) traces that exist only to provide legend swatches.
    for swatch_color, label in (('#1f77b4', "Google Comparable"), ('#ff7f0e', "UG40 Only")):
        fig.add_trace(
            go.Scatter(
                x=[None], y=[None],
                mode='markers',
                marker=dict(size=15, color=swatch_color, symbol='square'),
                name=label,
                showlegend=True
            )
        )

    return fig
def create_submission_summary_plot(validation_info: Dict, evaluation_results: Dict) -> go.Figure:
    """Create a 2x2 summary figure for a newly evaluated submission.

    Args:
        validation_info: Validator output; only ``coverage`` (fraction of the
            test set evaluated) is read. ``None`` is treated as empty.
        evaluation_results: Evaluation output. ``summary`` may provide
            ``primary_metrics`` (name -> value), ``total_samples``,
            ``language_pairs_covered`` and ``google_comparable_pairs``.
            ``averages`` may provide ``cer`` and ``wer``. ``None`` is treated
            as empty.

    Returns:
        A Plotly figure with four panels: a coverage pie, primary-metric bars,
        CER/WER bars and coverage-count bars. Panels whose source data is
        missing are left empty.
    """
    validation_info = validation_info or {}
    evaluation_results = evaluation_results or {}

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            "Sample Distribution",
            "Primary Metrics",
            "Error Analysis",
            "Coverage Summary"
        ),
        specs=[[{"type": "pie"}, {"type": "bar"}],
               [{"type": "bar"}, {"type": "bar"}]]
    )

    # Sample distribution (pie chart).
    # NOTE(review): when coverage is not reported this falls back to 0.8, so the
    # pie shows an 80% slice that was never measured. Confirm this is intended.
    coverage = validation_info.get('coverage')
    if coverage is None:
        coverage = 0.8
    # Clamp to [0, 1] so an out-of-range value cannot produce a negative slice.
    coverage = min(max(float(coverage), 0.0), 1.0)
    fig.add_trace(
        go.Pie(
            labels=["Evaluated", "Missing"],
            values=[coverage * 100, (1 - coverage) * 100],
            name="Samples"
        ),
        row=1, col=1
    )

    summary = evaluation_results.get('summary') or {}

    # Primary metrics (skipped rather than KeyError when absent)
    metrics_data = summary.get('primary_metrics') or {}
    if metrics_data:
        metric_names = list(metrics_data.keys())
        metric_values = list(metrics_data.values())
        fig.add_trace(
            go.Bar(
                x=metric_names,
                y=metric_values,
                name="Metrics",
                text=[f"{val:.3f}" for val in metric_values],
                textposition='auto'
            ),
            row=1, col=2
        )

    # Error analysis (CER, WER)
    if 'averages' in evaluation_results:
        averages = evaluation_results['averages'] or {}
        error_metrics = ['cer', 'wer']
        error_values = [averages.get(m, 0) for m in error_metrics]
        fig.add_trace(
            go.Bar(
                x=error_metrics,
                y=error_values,
                name="Errors",
                text=[f"{val:.3f}" for val in error_values],
                textposition='auto'
            ),
            row=2, col=1
        )

    # Coverage summary
    if 'summary' in evaluation_results:
        coverage_labels = ["Total Samples", "Lang Pairs", "Google Pairs"]
        coverage_values = [
            summary.get('total_samples', 0),
            summary.get('language_pairs_covered', 0),
            summary.get('google_comparable_pairs', 0)
        ]
        fig.add_trace(
            go.Bar(
                x=coverage_labels,
                y=coverage_values,
                name="Coverage",
                text=[str(val) for val in coverage_values],
                textposition='auto'
            ),
            row=2, col=2
        )

    fig.update_layout(
        title="📋 Submission Summary",
        height=700,
        showlegend=False
    )
    return fig