Spaces:
Running
Running
# src/plotting.py | |
import plotly.graph_objects as go | |
import plotly.express as px | |
from plotly.subplots import make_subplots | |
import pandas as pd | |
import numpy as np | |
import json | |
from collections import defaultdict | |
from typing import Dict, List, Optional, Union | |
from config import ( | |
LANGUAGE_NAMES, | |
ALL_UG40_LANGUAGES, | |
GOOGLE_SUPPORTED_LANGUAGES, | |
METRICS_CONFIG, | |
EVALUATION_TRACKS, | |
MODEL_CATEGORIES, | |
CHART_CONFIG, | |
) | |
def create_leaderboard_plot(
    df: pd.DataFrame, track: str, metric: str = "quality", top_n: int = 15
) -> go.Figure:
    """Create a horizontal leaderboard bar chart with confidence intervals.

    The ``top_n`` models with a positive ``{track}_{metric}`` score are shown,
    best first, as bars coloured by ``model_category`` (looked up in
    ``MODEL_CATEGORIES``, grey when unknown). Asymmetric error bars are drawn
    from ``{track}_ci_lower`` / ``{track}_ci_upper`` when both columns exist.

    Args:
        df: One row per model. Reads ``{track}_{metric}``, ``model_name``,
            ``model_category`` and, optionally, ``author``,
            ``{track}_ci_lower``, ``{track}_ci_upper`` and
            ``{track}_samples``. The caller's frame is not modified.
        track: Track key; its display name is taken from
            ``EVALUATION_TRACKS`` when present.
        metric: Metric suffix used to build the score column name.
        top_n: Maximum number of models to display.

    Returns:
        A plotly ``Figure``. When there is no usable data, or an unexpected
        error occurs, a figure holding a single explanatory annotation is
        returned instead of raising.
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(
            text="No models available for this track",
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            showarrow=False,
            font=dict(size=16),
        )
        fig.update_layout(
            title=f"No Data Available - {track.title()} Track",
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(0,0,0,0)",
        )
        return fig

    try:
        metric_col = f"{track}_{metric}"
        ci_lower_col = f"{track}_ci_lower"
        ci_upper_col = f"{track}_ci_upper"
        samples_col = f"{track}_samples"

        if metric_col not in df.columns:
            fig = go.Figure()
            fig.add_annotation(
                text=f"Metric {metric} not available for {track} track",
                xref="paper",
                yref="paper",
                x=0.5,
                y=0.5,
                showarrow=False,
            )
            return fig

        # Coerce types on a copy: assigning into ``df`` would mutate the
        # caller's DataFrame.
        data = df.copy()
        data[metric_col] = pd.to_numeric(data[metric_col], errors="coerce").fillna(0.0)
        # Missing CI bounds stay NaN so they produce no error bar, rather than
        # being filled with 0.0 and drawing a whisker down to zero.
        has_ci = ci_lower_col in data.columns and ci_upper_col in data.columns
        for col in (ci_lower_col, ci_upper_col):
            if col in data.columns:
                data[col] = pd.to_numeric(data[col], errors="coerce")

        # Rank by this track's score so "top N" really means the best N; the
        # stable sort keeps the incoming order among tied scores.
        valid_models = (
            data[data[metric_col] > 0]
            .sort_values(metric_col, ascending=False, kind="mergesort")
            .head(top_n)
        )
        if valid_models.empty:
            fig = go.Figure()
            fig.add_annotation(text="No valid models found", x=0.5, y=0.5, showarrow=False)
            return fig

        colors = [
            MODEL_CATEGORIES.get(cat, {}).get("color", "#808080")
            for cat in valid_models["model_category"]
        ]

        error_x = None
        if has_ci:
            scores = valid_models[metric_col]
            # Clip at 0 so an inconsistent bound (e.g. upper < score) cannot
            # produce a negative error length; NaN bounds become 0 (no bar).
            upper_err = (valid_models[ci_upper_col] - scores).clip(lower=0).fillna(0.0)
            lower_err = (scores - valid_models[ci_lower_col]).clip(lower=0).fillna(0.0)
            error_x = dict(
                type="data",
                array=upper_err.tolist(),
                arrayminus=lower_err.tolist(),
                visible=True,
                thickness=2,
                width=4,
            )

        # The score column is float after the coercion above, so formatting
        # cannot fail here.
        text_values = [f"{score:.3f}" for score in valid_models[metric_col]]

        try:
            samples_data = valid_models.get(samples_col, [0] * len(valid_models))
            customdata = list(zip(
                valid_models["model_category"].fillna("unknown"),
                valid_models["author"].fillna("Anonymous"),
                [int(float(x)) if pd.notnull(x) else 0 for x in samples_data],
            ))
        except (KeyError, TypeError, ValueError, OverflowError) as e:
            # Hover details are optional; fall back to placeholders.
            print(f"Error preparing custom data: {e}")
            customdata = [("unknown", "Anonymous", 0)] * len(valid_models)

        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=valid_models["model_name"],
            x=valid_models[metric_col],
            orientation="h",
            marker=dict(color=colors, line=dict(color="black", width=0.5)),
            error_x=error_x,
            text=text_values,
            textposition="auto",
            hovertemplate=(
                "<b>%{y}</b><br>"
                f"{metric.title()}: %{{x:.4f}}<br>"
                "Category: %{customdata[0]}<br>"
                "Author: %{customdata[1]}<br>"
                "Samples: %{customdata[2]}<br>"
                "<extra></extra>"
            ),
            customdata=customdata,
        ))

        track_name = EVALUATION_TRACKS.get(track, {}).get("name", track.title())
        fig.update_layout(
            # NOTE(review): "π" looks like an emoji whose UTF-8 bytes were
            # mis-decoded as ISO-8859-7; the intended glyph cannot be
            # recovered from this file — confirm and replace.
            title=f"π {track_name} - {metric.title()} Score",
            xaxis_title=f"{metric.title()} Score (with 95% CI)",
            yaxis_title="Models",
            height=max(400, len(valid_models) * 35 + 100),
            margin=dict(l=20, r=20, t=60, b=20),
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(0,0,0,0)",
            font=dict(size=12),
        )
        # Rows are sorted best-first; reverse the axis so row 0 is on top.
        fig.update_yaxes(autorange="reversed")
        return fig

    except Exception as e:
        # Top-level boundary: never let a plotting failure crash the UI.
        print(f"Error creating leaderboard plot: {e}")
        fig = go.Figure()
        fig.add_annotation(
            text=f"Error creating plot: {str(e)}",
            x=0.5,
            y=0.5,
            showarrow=False,
        )
        return fig
def create_language_pair_heatmap(
    model_results: Dict, track: str, metric: str = "quality_score"
) -> go.Figure:
    """Create a source-by-target language-pair heatmap for one model.

    Cell ``(src, tgt)`` holds
    ``pair_metrics["{src}_to_{tgt}"][metric]["mean"]``; the diagonal and any
    pair without a usable value are left blank (NaN).

    Args:
        model_results: Results for a single model, read as
            ``model_results["tracks"][track]``, which may carry an
            ``"error"`` entry and is expected to hold a ``"pair_metrics"``
            mapping.
        track: Track key; the language list and display name come from
            ``EVALUATION_TRACKS[track]``.
        metric: Metric name looked up inside each pair entry. ``zmax`` is
            fixed at 1 only for ``"quality_score"``.

    Returns:
        A plotly ``Figure``; an annotation-only figure when no data is
        available.
    """
    if not model_results or "tracks" not in model_results:
        fig = go.Figure()
        fig.add_annotation(text="No model results available", x=0.5, y=0.5, showarrow=False)
        return fig

    track_data = model_results["tracks"].get(track, {})
    # A track missing from EVALUATION_TRACKS has no language list to lay out,
    # so it is reported as "no data" instead of raising KeyError.
    track_info = EVALUATION_TRACKS.get(track)
    if track_data.get("error") or "pair_metrics" not in track_data or track_info is None:
        fig = go.Figure()
        fig.add_annotation(text=f"No data available for {track} track", x=0.5, y=0.5, showarrow=False)
        return fig

    pair_metrics = track_data["pair_metrics"]
    track_languages = track_info["languages"]
    metric_label = metric.replace("_", " ").title()

    n_langs = len(track_languages)
    matrix = np.full((n_langs, n_langs), np.nan)
    for i, src_lang in enumerate(track_languages):
        for j, tgt_lang in enumerate(track_languages):
            if src_lang == tgt_lang:
                continue  # the diagonal is left blank
            try:
                value = pair_metrics[f"{src_lang}_to_{tgt_lang}"][metric]["mean"]
                matrix[i, j] = float(value)
            except (KeyError, TypeError, ValueError):
                # Missing pair/metric or a malformed entry (no "mean", not a
                # dict, non-numeric) leaves the cell as NaN.
                continue

    lang_labels = [LANGUAGE_NAMES.get(lang, lang.upper()) for lang in track_languages]

    fig = go.Figure(data=go.Heatmap(
        z=matrix,
        x=lang_labels,
        y=lang_labels,
        colorscale="Viridis",
        showscale=True,
        # ``colorbar.titleside`` is deprecated and was removed in Plotly 6;
        # the side now lives on the title object.
        colorbar=dict(
            title=dict(text=metric_label, side="right"),
            len=0.8,
        ),
        hovertemplate=(
            "Source: %{y}<br>"
            "Target: %{x}<br>"
            f"{metric_label}: %{{z:.3f}}<br>"
            "<extra></extra>"
        ),
        zmin=0,
        zmax=1 if metric == "quality_score" else None,
    ))

    track_name = track_info.get("name", track.title())
    fig.update_layout(
        title=f"🗺️ {track_name} - {metric_label} by Language Pair",
        xaxis_title="Target Language",
        yaxis_title="Source Language",
        height=600,
        width=700,
        font=dict(size=12),
        xaxis=dict(side="bottom"),
        # Reverse so the first source language is the top row.
        yaxis=dict(autorange="reversed"),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
    )
    return fig
def create_performance_comparison_plot(
    df: pd.DataFrame, track: str, top_n: int = 10
) -> go.Figure:
    """Create a dot-and-whisker plot of quality scores with their 95% CIs.

    Each of the best ``top_n`` models (by ``{track}_quality``) that has both
    CI bounds is drawn as a marker at its score plus a horizontal line from
    its lower to upper bound, coloured by ``model_category``.

    Args:
        df: One row per model. Reads ``{track}_quality``,
            ``{track}_ci_lower``, ``{track}_ci_upper``, ``model_name`` and
            ``model_category``. The caller's frame is not modified.
        track: Track key; its display name is taken from
            ``EVALUATION_TRACKS`` when present.
        top_n: Maximum number of models to plot (default 10).

    Returns:
        A plotly ``Figure``. When there is no usable data, or an unexpected
        error occurs, a figure holding a single explanatory annotation is
        returned instead of raising.
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available", x=0.5, y=0.5, showarrow=False)
        return fig

    try:
        metric_col = f"{track}_quality"
        ci_lower_col = f"{track}_ci_lower"
        ci_upper_col = f"{track}_ci_upper"

        if any(col not in df.columns for col in (metric_col, ci_lower_col, ci_upper_col)):
            fig = go.Figure()
            fig.add_annotation(text="No models with confidence intervals", x=0.5, y=0.5, showarrow=False)
            return fig

        # Coerce types on a copy so the caller's DataFrame is not mutated.
        data = df.copy()
        data[metric_col] = pd.to_numeric(data[metric_col], errors="coerce").fillna(0.0)
        # CI bounds are left NaN when missing: filling them with 0.0 would make
        # the notna() filter below a no-op.
        data[ci_lower_col] = pd.to_numeric(data[ci_lower_col], errors="coerce")
        data[ci_upper_col] = pd.to_numeric(data[ci_upper_col], errors="coerce")

        has_interval = (
            (data[metric_col] > 0)
            & data[ci_lower_col].notna()
            & data[ci_upper_col].notna()
        )
        # Best first; stable sort keeps the incoming order among ties.
        valid_models = (
            data[has_interval]
            .sort_values(metric_col, ascending=False, kind="mergesort")
            .head(top_n)
        )

        if valid_models.empty:
            fig = go.Figure()
            fig.add_annotation(text="No models with confidence intervals", x=0.5, y=0.5, showarrow=False)
            return fig

        fig = go.Figure()
        for i, (_, model) in enumerate(valid_models.iterrows()):
            try:
                category = str(model["model_category"])
                color = MODEL_CATEGORIES.get(category, {}).get("color", "#808080")
                model_name = str(model["model_name"])
                quality_val = float(model[metric_col])
                ci_lower_val = float(model[ci_lower_col])
                ci_upper_val = float(model[ci_upper_col])

                # Point estimate.
                fig.add_trace(go.Scatter(
                    x=[quality_val],
                    y=[i],
                    mode="markers",
                    marker=dict(
                        size=12,
                        color=color,
                        line=dict(color="black", width=1),
                    ),
                    name=model_name,
                    showlegend=False,
                    hovertemplate=(
                        f"<b>{model_name}</b><br>"
                        f"Quality: {quality_val:.4f}<br>"
                        f"95% CI: [{ci_lower_val:.4f}, {ci_upper_val:.4f}]<br>"
                        f"Category: {category}<br>"
                        "<extra></extra>"
                    ),
                ))
                # Confidence interval as a horizontal line at the same y.
                fig.add_trace(go.Scatter(
                    x=[ci_lower_val, ci_upper_val],
                    y=[i, i],
                    mode="lines",
                    line=dict(color=color, width=3),
                    showlegend=False,
                    hoverinfo="skip",
                ))
            except (KeyError, TypeError, ValueError) as e:
                # Skip just this model; the rest of the plot is still useful.
                print(f"Error adding model {i} to comparison plot: {e}")

        if "model_name" in valid_models.columns:
            tick_labels = [str(name) for name in valid_models["model_name"]]
        else:
            tick_labels = [f"Model {i}" for i in range(len(valid_models))]

        track_name = EVALUATION_TRACKS.get(track, {}).get("name", track.title())
        fig.update_layout(
            # NOTE(review): "π" looks like an emoji whose UTF-8 bytes were
            # mis-decoded as ISO-8859-7; the intended glyph cannot be
            # recovered from this file — confirm and replace.
            title=f"π {track_name} - Performance Comparison",
            xaxis_title="Quality Score",
            yaxis_title="Models",
            height=max(400, len(valid_models) * 40 + 100),
            yaxis=dict(
                tickmode="array",
                tickvals=list(range(len(valid_models))),
                ticktext=tick_labels,
                # Rows are best-first; reverse so row 0 is on top.
                autorange="reversed",
            ),
            showlegend=False,
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(0,0,0,0)",
        )
        return fig

    except Exception as e:
        # Top-level boundary: never let a plotting failure crash the UI.
        print(f"Error creating performance comparison plot: {e}")
        fig = go.Figure()
        fig.add_annotation(
            text=f"Error creating plot: {str(e)}",
            x=0.5,
            y=0.5,
            showarrow=False,
        )
        return fig
def create_language_pair_comparison_plot(pairs_df: pd.DataFrame, track: str) -> go.Figure:
    """Create stacked grouped-bar charts comparing every model on every pair.

    Row 1 plots ``Quality Score`` and row 2 plots ``BLEU``. There is one bar
    trace per model in each row, both in the same legend group so toggling a
    model in the legend hides it in both rows. Language pairs are shown in
    sorted order.

    Args:
        pairs_df: Long-format frame with ``Language Pair``, ``Model``,
            ``Category``, ``Quality Score`` and ``BLEU`` columns.
        track: Track key; its display name is taken from
            ``EVALUATION_TRACKS`` when present.

    Returns:
        A plotly ``Figure``. When there is nothing to plot, or the data is
        malformed, a figure holding a single explanatory annotation is
        returned instead of raising.
    """
    if pairs_df.empty:
        fig = go.Figure()
        fig.add_annotation(
            text="No language pair data available",
            x=0.5,
            y=0.5,
            showarrow=False,
        )
        return fig

    try:
        # dropna(): a NaN mixed with string labels makes sorted() raise
        # TypeError (float vs str comparison).
        language_pairs = sorted(pairs_df["Language Pair"].dropna().unique())
        models = sorted(pairs_df["Model"].dropna().unique())

        if not language_pairs or not models:
            fig = go.Figure()
            fig.add_annotation(
                text="Insufficient data for comparison",
                x=0.5,
                y=0.5,
                showarrow=False,
            )
            return fig

        fig = make_subplots(
            rows=2,
            cols=1,
            subplot_titles=("Quality Score by Language Pair", "BLEU Score by Language Pair"),
            vertical_spacing=0.1,
            shared_xaxes=True,
        )

        for model in models:
            model_data = pairs_df[pairs_df["Model"] == model]
            category = model_data["Category"].iloc[0] if not model_data.empty else "community"
            color = MODEL_CATEGORIES.get(category, {}).get("color", "#808080")

            # Quality score (top row) carries the legend entry.
            fig.add_trace(
                go.Bar(
                    name=model,
                    x=model_data["Language Pair"],
                    y=model_data["Quality Score"],
                    marker_color=color,
                    opacity=0.8,
                    legendgroup=model,
                    showlegend=True,
                    hovertemplate=(
                        f"<b>{model}</b><br>"
                        "Language Pair: %{x}<br>"
                        "Quality Score: %{y:.4f}<br>"
                        f"Category: {category}<br>"
                        "<extra></extra>"
                    ),
                ),
                row=1,
                col=1,
            )
            # BLEU (bottom row) shares the legend group but adds no entry.
            fig.add_trace(
                go.Bar(
                    name=model,
                    x=model_data["Language Pair"],
                    y=model_data["BLEU"],
                    marker_color=color,
                    opacity=0.8,
                    legendgroup=model,
                    showlegend=False,
                    hovertemplate=(
                        f"<b>{model}</b><br>"
                        "Language Pair: %{x}<br>"
                        "BLEU: %{y:.2f}<br>"
                        f"Category: {category}<br>"
                        "<extra></extra>"
                    ),
                ),
                row=2,
                col=1,
            )

        track_name = EVALUATION_TRACKS.get(track, {}).get("name", track.title())
        fig.update_layout(
            # NOTE(review): "π" looks like an emoji whose UTF-8 bytes were
            # mis-decoded as ISO-8859-7; the intended glyph cannot be
            # recovered from this file — confirm and replace.
            title=f"π {track_name} - Language Pair Performance Comparison",
            height=800,
            barmode="group",
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(0,0,0,0)",
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1,
            ),
        )

        # Without an explicit category array plotly orders categories by
        # first appearance across traces, ignoring the sorted list above.
        fig.update_xaxes(categoryorder="array", categoryarray=language_pairs)
        fig.update_xaxes(tickangle=45, row=2, col=1)
        fig.update_yaxes(title_text="Quality Score", row=1, col=1)
        fig.update_yaxes(title_text="BLEU Score", row=2, col=1)
        return fig

    except Exception as e:
        # Top-level boundary, matching the other plot builders.
        print(f"Error creating language pair comparison plot: {e}")
        fig = go.Figure()
        fig.add_annotation(
            text=f"Error creating plot: {str(e)}",
            x=0.5,
            y=0.5,
            showarrow=False,
        )
        return fig
def create_category_comparison_plot(df: pd.DataFrame, track: str) -> go.Figure:
    """Create box plots of ``{track}_quality`` grouped by model category.

    One box, with every model drawn as a jittered point, is added for each
    ``MODEL_CATEGORIES`` entry that has at least one model with a positive
    score. Models whose ``model_category`` is not a ``MODEL_CATEGORIES`` key
    are not shown.

    Args:
        df: One row per model. Reads ``{track}_quality``,
            ``model_category`` and ``model_name``. The caller's frame is not
            modified.
        track: Track key; its display name is taken from
            ``EVALUATION_TRACKS`` when present.

    Returns:
        A plotly ``Figure``. When there is no usable data, or an unexpected
        error occurs, a figure holding a single explanatory annotation is
        returned instead of raising.
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available", x=0.5, y=0.5, showarrow=False)
        return fig

    metric_col = f"{track}_quality"
    if metric_col not in df.columns:
        fig = go.Figure()
        fig.add_annotation(
            text=f"Metric quality not available for {track} track",
            x=0.5,
            y=0.5,
            showarrow=False,
        )
        return fig

    try:
        # Coerce on a copy, as the other plots do, so string/missing scores
        # cannot break the ``> 0`` comparison and ``df`` stays untouched.
        data = df.copy()
        data[metric_col] = pd.to_numeric(data[metric_col], errors="coerce")
        valid_models = data[data[metric_col] > 0]

        if valid_models.empty:
            fig = go.Figure()
            fig.add_annotation(text="No valid models found", x=0.5, y=0.5, showarrow=False)
            return fig

        fig = go.Figure()
        for category, info in MODEL_CATEGORIES.items():
            category_models = valid_models[valid_models["model_category"] == category]
            if category_models.empty:
                continue
            fig.add_trace(go.Box(
                y=category_models[metric_col],
                name=info["name"],
                marker_color=info["color"],
                boxpoints="all",  # show every model as a point
                jitter=0.3,
                pointpos=-1.8,
                hovertemplate=(
                    f"<b>{info['name']}</b><br>"
                    "Quality: %{y:.4f}<br>"
                    "Model: %{customdata}<br>"
                    "<extra></extra>"
                ),
                customdata=category_models["model_name"],
            ))

        # Every model may belong to a category absent from MODEL_CATEGORIES;
        # say so instead of returning a blank chart.
        if not fig.data:
            fig.add_annotation(text="No valid models found", x=0.5, y=0.5, showarrow=False)
            return fig

        track_name = EVALUATION_TRACKS.get(track, {}).get("name", track.title())
        fig.update_layout(
            # NOTE(review): "π" looks like an emoji whose UTF-8 bytes were
            # mis-decoded as ISO-8859-7; the intended glyph cannot be
            # recovered from this file — confirm and replace.
            title=f"π {track_name} - Performance by Category",
            xaxis_title="Model Category",
            yaxis_title="Quality Score",
            height=500,
            showlegend=False,
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(0,0,0,0)",
        )
        return fig

    except Exception as e:
        # Top-level boundary, matching the other plot builders.
        print(f"Error creating category comparison plot: {e}")
        fig = go.Figure()
        fig.add_annotation(
            text=f"Error creating plot: {str(e)}",
            x=0.5,
            y=0.5,
            showarrow=False,
        )
        return fig