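# NOTE: "Spaces: Sleeping / Sleeping" appears to be a hosting-page status banner captured with this file; it is not part of the module.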
# src/plotting.py | |
import json | |
import matplotlib.pyplot as plt | |
import matplotlib.gridspec as gridspec | |
import plotly.graph_objects as go | |
import plotly.express as px | |
from plotly.subplots import make_subplots | |
import pandas as pd | |
import numpy as np | |
import json | |
from collections import defaultdict | |
from typing import Dict, List, Optional, Union | |
from config import ( | |
LANGUAGE_NAMES, | |
ALL_UG40_LANGUAGES, | |
GOOGLE_SUPPORTED_LANGUAGES, | |
METRICS_CONFIG, | |
EVALUATION_TRACKS, | |
MODEL_CATEGORIES, | |
CHART_CONFIG, | |
STATISTICAL_CONFIG, | |
SAMPLE_SIZE_RECOMMENDATIONS, | |
) | |
# Scientific plotting style: start from matplotlib's defaults, then apply
# white backgrounds and the shared font sizes.
plt.style.use("default")
for _rc_key, _rc_value in (
    ("figure.facecolor", "white"),
    ("axes.facecolor", "white"),
    ("font.size", 10),
    ("axes.labelsize", 12),
    ("axes.titlesize", 14),
    ("xtick.labelsize", 10),
    ("ytick.labelsize", 10),
):
    plt.rcParams[_rc_key] = _rc_value
def create_scientific_leaderboard_plot(
    df: pd.DataFrame, track: str, metric: str = "quality", top_n: int = 15
) -> go.Figure:
    """Create a horizontal leaderboard bar chart for one evaluation track.

    Args:
        df: Leaderboard table. Reads the ``{track}_{metric}``,
            ``model_name``, ``model_category`` and ``author`` columns.
            ``{track}_ci_lower`` / ``{track}_ci_upper`` (error bars) and
            ``{track}_samples`` (hover text) are used when present.
        track: Key into ``EVALUATION_TRACKS``; also the column prefix.
        metric: Metric suffix of the column to plot.
        top_n: Maximum number of models to show.

    Returns:
        A bar chart of the ``top_n`` highest-scoring models with a positive
        score, or a figure containing only an explanatory annotation when
        there is nothing to plot.
    """
    # Anchor placeholder messages to the paper so they sit in the middle of
    # the figure regardless of the (empty) axis ranges.
    centered = dict(xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)

    if df.empty:
        fig = go.Figure()
        fig.add_annotation(
            text="No models available for this track",
            font=dict(size=16),
            **centered,
        )
        fig.update_layout(title=f"No Data Available - {track.title()} Track")
        return fig

    metric_col = f"{track}_{metric}"
    ci_lower_col = f"{track}_ci_lower"
    ci_upper_col = f"{track}_ci_upper"

    if metric_col not in df.columns:
        fig = go.Figure()
        fig.add_annotation(
            text=f"Metric {metric} not available for {track} track",
            **centered,
        )
        return fig

    # Keep models with a positive score and rank them by the plotted metric,
    # so that head(top_n) really is the top N. The stable sort preserves the
    # incoming order among ties.
    valid_models = (
        df[df[metric_col] > 0]
        .sort_values(metric_col, ascending=False, kind="stable")
        .head(top_n)
    )
    if valid_models.empty:
        fig = go.Figure()
        fig.add_annotation(text="No valid models found", **centered)
        return fig

    # One bar colour per model, grey for categories missing from the config.
    colors = [
        MODEL_CATEGORIES.get(cat, {}).get("color", "#808080")
        for cat in valid_models["model_category"]
    ]

    # Horizontal error bars, only when both CI columns exist.
    # NOTE(review): the CI columns are not metric-specific
    # ("{track}_ci_lower"); presumably they describe the quality score -
    # confirm before using this plot with another metric.
    has_ci = (
        ci_lower_col in valid_models.columns
        and ci_upper_col in valid_models.columns
    )
    if has_ci:
        error_bars = dict(
            type="data",
            array=valid_models[ci_upper_col] - valid_models[metric_col],
            arrayminus=valid_models[metric_col] - valid_models[ci_lower_col],
            visible=True,
            thickness=2,
            width=4,
        )
    else:
        error_bars = None

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=valid_models["model_name"],
        x=valid_models[metric_col],
        orientation="h",
        marker=dict(color=colors, line=dict(color="black", width=0.5)),
        error_x=error_bars,
        text=[f"{score:.3f}" for score in valid_models[metric_col]],
        textposition="auto",
        hovertemplate=(
            "<b>%{y}</b><br>" +
            f"{metric.title()}: %{{x:.4f}}<br>" +
            "Category: %{customdata[0]}<br>" +
            "Author: %{customdata[1]}<br>" +
            "Samples: %{customdata[2]}<br>" +
            "<extra></extra>"
        ),
        customdata=list(zip(
            valid_models["model_category"],
            valid_models["author"],
            valid_models.get(f"{track}_samples", [0] * len(valid_models)),
        )),
        # The legend is built from the per-category dummy traces below; hide
        # this unnamed trace so it does not appear there as "trace 0".
        showlegend=False,
    ))

    track_info = EVALUATION_TRACKS[track]
    score_label = f"{metric.title()} Score"
    fig.update_layout(
        # NOTE(review): the leading glyph looks like a mis-decoded emoji -
        # confirm the intended character.
        title=f"π {track_info['name']} - {metric.title()} Score",
        # Only advertise the confidence interval when error bars are drawn.
        xaxis_title=f"{score_label} (with 95% CI)" if has_ci else score_label,
        yaxis_title="Models",
        height=max(400, len(valid_models) * 35 + 100),
        margin=dict(l=20, r=20, t=60, b=20),
        plot_bgcolor="white",
        paper_bgcolor="white",
        font=dict(size=12),
    )
    # Reverse the y-axis so the first (best) model is drawn at the top.
    fig.update_yaxes(autorange="reversed")

    # Legend: one empty marker trace per category that is actually shown.
    shown_categories = set(valid_models["model_category"])
    for category, info in MODEL_CATEGORIES.items():
        if category in shown_categories:
            fig.add_trace(go.Scatter(
                x=[None], y=[None],
                mode="markers",
                marker=dict(size=10, color=info["color"]),
                name=info["name"],
                showlegend=True,
            ))
    return fig
def create_language_pair_heatmap_scientific(
    model_results: Dict, track: str, metric: str = "quality_score"
) -> go.Figure:
    """Create a source-by-target heatmap of one metric for a single model.

    Args:
        model_results: Evaluation results. Reads
            ``model_results["tracks"][track]["pair_metrics"]``, a mapping
            keyed ``"{src}_to_{tgt}"`` whose values map metric names to a
            dict containing ``"mean"``.
        track: Key into ``EVALUATION_TRACKS``; its ``"languages"`` entry
            defines both heatmap axes.
        metric: Name of the per-pair metric to display.

    Returns:
        A heatmap with source languages as rows and target languages as
        columns (missing pairs and the diagonal are NaN), or a figure
        containing only an explanatory annotation when no data is available.
    """
    # Anchor placeholder messages to the paper so they are centred even
    # though the figure has no data to define axis ranges.
    centered = dict(xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)

    if not model_results or "tracks" not in model_results:
        fig = go.Figure()
        fig.add_annotation(text="No model results available", **centered)
        return fig

    track_data = model_results["tracks"].get(track, {})
    if track_data.get("error") or "pair_metrics" not in track_data:
        fig = go.Figure()
        fig.add_annotation(
            text=f"No data available for {track} track", **centered
        )
        return fig

    pair_metrics = track_data["pair_metrics"]
    track_languages = EVALUATION_TRACKS[track]["languages"]

    # Rows are source languages, columns are target languages; NaN cells
    # are left for the diagonal and for pairs without the metric.
    n_langs = len(track_languages)
    matrix = np.full((n_langs, n_langs), np.nan)
    for i, src_lang in enumerate(track_languages):
        for j, tgt_lang in enumerate(track_languages):
            if src_lang == tgt_lang:
                continue
            pair_key = f"{src_lang}_to_{tgt_lang}"
            if pair_key in pair_metrics and metric in pair_metrics[pair_key]:
                matrix[i, j] = pair_metrics[pair_key][metric]["mean"]

    lang_labels = [
        LANGUAGE_NAMES.get(lang, lang.upper()) for lang in track_languages
    ]
    metric_label = metric.replace("_", " ").title()

    fig = go.Figure(data=go.Heatmap(
        z=matrix,
        x=lang_labels,
        y=lang_labels,
        colorscale="Viridis",
        showscale=True,
        colorbar=dict(
            # `titleside` is a legacy attribute that recent Plotly releases
            # reject; the nested title dict is the supported spelling.
            title=dict(text=metric_label, side="right"),
            len=0.8,
        ),
        hovertemplate=(
            "Source: %{y}<br>" +
            "Target: %{x}<br>" +
            f"{metric_label}: %{{z:.3f}}<br>" +
            "<extra></extra>"
        ),
        zmin=0,
        # Only the quality score has a fixed upper bound of 1 here; other
        # metrics let Plotly pick the upper end of the colour scale.
        zmax=1 if metric == "quality_score" else None,
    ))

    track_info = EVALUATION_TRACKS[track]
    fig.update_layout(
        title=f"🗺️ {track_info['name']} - {metric_label} by Language Pair",
        xaxis_title="Target Language",
        yaxis_title="Source Language",
        height=600,
        width=700,
        font=dict(size=12),
        xaxis=dict(side="bottom"),
        yaxis=dict(autorange="reversed"),  # Source languages from top to bottom
    )
    return fig
def create_statistical_comparison_plot(df: pd.DataFrame, track: str) -> go.Figure:
    """Create a dot-and-interval plot of quality scores with their CIs.

    Args:
        df: Leaderboard table. Reads ``{track}_quality``,
            ``{track}_ci_lower``, ``{track}_ci_upper``, ``model_name`` and
            ``model_category``.
        track: Key into ``EVALUATION_TRACKS``; also the column prefix.

    Returns:
        One row per model (at most the first 10 rows of ``df`` that have a
        positive quality score and both CI bounds), or a figure containing
        only an explanatory annotation when nothing can be plotted.
    """
    # Anchor placeholder messages to the paper so they are centred on an
    # otherwise empty figure.
    centered = dict(xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)

    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available", **centered)
        return fig

    metric_col = f"{track}_quality"
    ci_lower_col = f"{track}_ci_lower"
    ci_upper_col = f"{track}_ci_upper"

    # Without all three columns there is nothing to compare; return a
    # placeholder instead of raising KeyError.
    if any(col not in df.columns for col in (metric_col, ci_lower_col, ci_upper_col)):
        fig = go.Figure()
        fig.add_annotation(text="No models with confidence intervals", **centered)
        return fig

    # Keep models that have a score and both CI bounds for this track.
    valid_models = df[
        (df[metric_col] > 0)
        & df[ci_lower_col].notna()
        & df[ci_upper_col].notna()
    ].head(10)
    if valid_models.empty:
        fig = go.Figure()
        fig.add_annotation(text="No models with confidence intervals", **centered)
        return fig

    fig = go.Figure()
    # Each model occupies integer row i; the tick labels below map the rows
    # back to model names.
    for i, (_, model) in enumerate(valid_models.iterrows()):
        category = model["model_category"]
        color = MODEL_CATEGORIES.get(category, {}).get("color", "#808080")
        ci_bounds = [model[ci_lower_col], model[ci_upper_col]]

        # Point estimate.
        fig.add_trace(go.Scatter(
            x=[model[metric_col]],
            y=[i],
            mode="markers",
            marker=dict(
                size=12,
                color=color,
                line=dict(color="black", width=1),
            ),
            name=model["model_name"],
            showlegend=False,
            hovertemplate=(
                f"<b>{model['model_name']}</b><br>" +
                f"Quality: {model[metric_col]:.4f}<br>" +
                f"95% CI: [{model[ci_lower_col]:.4f}, {model[ci_upper_col]:.4f}]<br>" +
                f"Category: {category}<br>" +
                "<extra></extra>"
            ),
        ))
        # Confidence-interval span.
        fig.add_trace(go.Scatter(
            x=ci_bounds,
            y=[i, i],
            mode="lines",
            line=dict(color=color, width=3),
            showlegend=False,
            hoverinfo="skip",
        ))
        # Interval end caps. "line-ns" is a stroke-only symbol, so the
        # category colour is also set on the marker outline.
        fig.add_trace(go.Scatter(
            x=ci_bounds,
            y=[i, i],
            mode="markers",
            marker=dict(
                symbol="line-ns",
                size=10,
                color=color,
                line=dict(color=color, width=2),
            ),
            showlegend=False,
            hoverinfo="skip",
        ))

    track_info = EVALUATION_TRACKS[track]
    fig.update_layout(
        # NOTE(review): the leading glyph looks like a mis-decoded emoji -
        # confirm the intended character.
        title=f"π {track_info['name']} - Statistical Comparison",
        xaxis_title="Quality Score",
        yaxis_title="Models",
        height=max(400, len(valid_models) * 40 + 100),
        yaxis=dict(
            tickmode="array",
            tickvals=list(range(len(valid_models))),
            ticktext=valid_models["model_name"].tolist(),
            autorange="reversed",
        ),
        showlegend=False,
        plot_bgcolor="white",
        paper_bgcolor="white",
    )
    return fig
def create_category_comparison_plot(df: pd.DataFrame, track: str) -> go.Figure:
    """Create one quality-score box plot per model category.

    Args:
        df: Leaderboard table. Reads ``{track}_quality``,
            ``{track}_adequate``, ``model_category`` and ``model_name``.
        track: Key into ``EVALUATION_TRACKS``; also the column prefix.

    Returns:
        A box plot (all points shown) for every category in
        ``MODEL_CATEGORIES`` that has at least one model flagged adequate
        with a positive quality score, or a figure containing only an
        explanatory annotation when nothing can be plotted.
    """
    # Anchor placeholder messages to the paper so they are centred on an
    # otherwise empty figure.
    centered = dict(xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)

    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available", **centered)
        return fig

    metric_col = f"{track}_quality"
    adequate_col = f"{track}_adequate"

    # A table without the track's columns has no adequate models to show;
    # return the placeholder instead of raising KeyError.
    if metric_col not in df.columns or adequate_col not in df.columns:
        fig = go.Figure()
        fig.add_annotation(text="No adequate models found", **centered)
        return fig

    # Compare against True so missing (NaN/None) adequacy flags count as
    # "not adequate" rather than breaking the boolean mask.
    is_adequate = df[adequate_col].eq(True)
    valid_models = df[is_adequate & (df[metric_col] > 0)]
    if valid_models.empty:
        fig = go.Figure()
        fig.add_annotation(text="No adequate models found", **centered)
        return fig

    fig = go.Figure()
    for category, info in MODEL_CATEGORIES.items():
        category_models = valid_models[valid_models["model_category"] == category]
        if category_models.empty:
            continue
        fig.add_trace(go.Box(
            y=category_models[metric_col],
            name=info["name"],
            marker_color=info["color"],
            boxpoints="all",  # Show every model as a point next to its box
            jitter=0.3,
            pointpos=-1.8,
            hovertemplate=(
                f"<b>{info['name']}</b><br>" +
                "Quality: %{y:.4f}<br>" +
                "Model: %{customdata}<br>" +
                "<extra></extra>"
            ),
            customdata=category_models["model_name"],
        ))

    track_info = EVALUATION_TRACKS[track]
    fig.update_layout(
        # NOTE(review): the leading glyph looks like a mis-decoded emoji -
        # confirm the intended character.
        title=f"π {track_info['name']} - Performance by Category",
        xaxis_title="Model Category",
        yaxis_title="Quality Score",
        height=500,
        showlegend=False,
        plot_bgcolor="white",
        paper_bgcolor="white",
    )
    return fig
def create_adequacy_analysis_plot(df: pd.DataFrame) -> go.Figure:
    """Create a 2x2 overview of sample sizes, adequacy and categories.

    Panels:
        * (1, 1) bar chart: summed positive ``{track}_samples`` per track
          in ``EVALUATION_TRACKS`` (tracks without that column are skipped);
        * (1, 2) pie chart: ``scientific_adequacy_score`` binned into
          Poor / Fair / Good / Excellent;
        * (2, 1) histogram of ``scientific_adequacy_score``;
        * (2, 2) bar chart: number of models per ``model_category``.

    Args:
        df: Leaderboard table. Every column above is optional; a panel whose
            column is missing is left empty.

    Returns:
        The subplot figure, or a figure containing only an explanatory
        annotation when ``df`` is empty.
    """
    if df.empty:
        fig = go.Figure()
        # Paper coordinates keep the message centred on the empty figure.
        fig.add_annotation(
            text="No data available",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False,
        )
        return fig

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            "Sample Sizes by Track",
            "Statistical Adequacy Distribution",
            "Scientific Adequacy Scores",
            "Model Categories Distribution",
        ),
        specs=[
            [{"type": "bar"}, {"type": "pie"}],
            [{"type": "histogram"}, {"type": "bar"}],
        ],
    )

    # Sample sizes by track.
    track_names = []
    sample_counts = []
    for track in EVALUATION_TRACKS:
        samples_col = f"{track}_samples"
        if samples_col in df.columns:
            samples = df[samples_col]
            track_names.append(track.replace("_", " ").title())
            sample_counts.append(samples[samples > 0].sum())
    if track_names:
        fig.add_trace(
            go.Bar(x=track_names, y=sample_counts, name="Samples"),
            row=1, col=1,
        )

    score_col = "scientific_adequacy_score"
    if score_col in df.columns:
        scores = df[score_col]

        # Statistical adequacy distribution. include_lowest=True puts a
        # score of exactly 0 into "Poor"; without it pd.cut leaves the left
        # edge open and silently drops those models.
        adequacy_bins = pd.cut(
            scores,
            bins=[0, 0.3, 0.6, 0.8, 1.0],
            labels=["Poor", "Fair", "Good", "Excellent"],
            include_lowest=True,
        )
        adequacy_counts = adequacy_bins.value_counts()
        # value_counts on a categorical lists every label (possibly with a
        # zero count), so test the total rather than emptiness.
        if adequacy_counts.sum() > 0:
            fig.add_trace(
                go.Pie(
                    labels=adequacy_counts.index.tolist(),
                    values=adequacy_counts.values,
                    name="Adequacy",
                ),
                row=1, col=2,
            )

        # Scientific adequacy scores histogram.
        fig.add_trace(
            go.Histogram(x=scores, nbinsx=20, name="Adequacy Scores"),
            row=2, col=1,
        )

    # Model categories distribution (grey for categories not in the config).
    if "model_category" in df.columns:
        category_counts = df["model_category"].value_counts()
        category_colors = [
            MODEL_CATEGORIES.get(cat, {}).get("color", "#808080")
            for cat in category_counts.index
        ]
        fig.add_trace(
            go.Bar(
                x=category_counts.index,
                y=category_counts.values,
                marker_color=category_colors,
                name="Categories",
            ),
            row=2, col=2,
        )

    fig.update_layout(
        # NOTE(review): the leading glyph looks like a mis-decoded emoji -
        # confirm the intended character.
        title="π Scientific Evaluation Analysis",
        height=800,
        showlegend=False,
    )
    return fig
def create_cross_track_analysis_plot(df: pd.DataFrame) -> go.Figure:
    """Create a scatter plot comparing model quality between two tracks.

    The tracks compared are the first two entries of ``EVALUATION_TRACKS``
    whose ``{track}_quality`` column exists in ``df``. Only models with a
    positive score in *every* available quality column are plotted, and at
    least three such models are required.

    Args:
        df: Leaderboard table. Reads the ``{track}_quality`` columns,
            ``model_category`` and ``model_name``.

    Returns:
        A scatter plot with one trace per category in ``MODEL_CATEGORIES``
        plus a dashed y = x reference line, or a figure containing only an
        explanatory annotation when the comparison is not possible.
    """
    # Anchor placeholder messages to the paper so they are centred on an
    # otherwise empty figure.
    centered = dict(xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)

    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available", **centered)
        return fig

    available_cols = [
        col
        for col in (f"{track}_quality" for track in EVALUATION_TRACKS)
        if col in df.columns
    ]
    if len(available_cols) < 2:
        fig = go.Figure()
        fig.add_annotation(text="Need at least 2 tracks for comparison", **centered)
        return fig

    # Keep models that scored in every available track (one boolean mask
    # instead of re-filtering a copy column by column).
    scored_everywhere = (df[available_cols] > 0).all(axis=1)
    multi_track_models = df[scored_everywhere]
    if len(multi_track_models) < 3:
        fig = go.Figure()
        fig.add_annotation(
            text="Insufficient models for cross-track analysis", **centered
        )
        return fig

    # Only the first pair of tracks is plotted; the guard above guarantees
    # that both columns exist.
    x_col, y_col = available_cols[0], available_cols[1]
    x_track = x_col.replace("_quality", "").replace("_", " ").title()
    y_track = y_col.replace("_quality", "").replace("_", " ").title()

    fig = go.Figure()
    # One trace per configured category so each gets its colour and legend
    # entry.
    for category, info in MODEL_CATEGORIES.items():
        category_models = multi_track_models[
            multi_track_models["model_category"] == category
        ]
        if category_models.empty:
            continue
        fig.add_trace(go.Scatter(
            x=category_models[x_col],
            y=category_models[y_col],
            mode="markers",
            marker=dict(
                size=10,
                color=info["color"],
                line=dict(color="black", width=1),
            ),
            name=info["name"],
            text=category_models["model_name"],
            hovertemplate=(
                "<b>%{text}</b><br>" +
                f"{x_track}: %{{x:.4f}}<br>" +
                f"{y_track}: %{{y:.4f}}<br>" +
                f"Category: {info['name']}<br>" +
                "<extra></extra>"
            ),
        ))

    # Dashed y = x reference line spanning the observed score range.
    min_val = min(multi_track_models[x_col].min(), multi_track_models[y_col].min())
    max_val = max(multi_track_models[x_col].max(), multi_track_models[y_col].max())
    fig.add_trace(go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode="lines",
        line=dict(dash="dash", color="gray", width=2),
        name="Perfect Correlation",
        showlegend=False,
        hoverinfo="skip",
    ))

    fig.update_layout(
        # NOTE(review): the leading glyph looks like a mis-decoded emoji -
        # confirm the intended character.
        title=f"π Cross-Track Performance: {x_track} vs {y_track}",
        xaxis_title=f"{x_track} Quality Score",
        yaxis_title=f"{y_track} Quality Score",
        height=600,
        width=600,
        plot_bgcolor="white",
        paper_bgcolor="white",
    )
    return fig
def create_scientific_model_detail_plot(model_results: Dict, model_name: str, track: str) -> go.Figure:
    """Create per-language-pair quality and BLEU bar charts for one model.

    Args:
        model_results: Evaluation results. Reads
            ``model_results["tracks"][track]["pair_metrics"]``, a mapping
            keyed ``"{src}_to_{tgt}"``. A pair is plotted when it has both
            ``"quality_score"`` (a dict with ``"mean"`` and, optionally,
            ``"ci_lower"`` / ``"ci_upper"``) and ``"sample_count"``; its
            ``"bleu"`` mean defaults to 0 when absent.
        model_name: Name shown in the figure title.
        track: Key into ``EVALUATION_TRACKS``; its ``"languages"`` entry
            defines which pairs are looked up and in what order.

    Returns:
        A two-row figure (quality with error bars on top, BLEU below), or a
        figure containing only an explanatory annotation when no data is
        available.
    """
    # Anchor placeholder messages to the paper so they are centred on an
    # otherwise empty figure.
    centered = dict(xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)

    if not model_results or "tracks" not in model_results:
        fig = go.Figure()
        fig.add_annotation(text="No model results available", **centered)
        return fig

    track_data = model_results["tracks"].get(track, {})
    if track_data.get("error") or "pair_metrics" not in track_data:
        fig = go.Figure()
        fig.add_annotation(text=f"No data for {track} track", **centered)
        return fig

    pair_metrics = track_data["pair_metrics"]
    track_languages = EVALUATION_TRACKS[track]["languages"]

    # Collect the plotted series in parallel lists, one entry per pair.
    pairs = []
    quality_means = []
    ci_above = []  # ci_upper - mean
    ci_below = []  # mean - ci_lower
    bleu_means = []
    sample_counts = []
    for src in track_languages:
        for tgt in track_languages:
            if src == tgt:
                continue
            metrics = pair_metrics.get(f"{src}_to_{tgt}")
            if not metrics:
                continue
            if "quality_score" not in metrics or "sample_count" not in metrics:
                continue

            quality_stats = metrics["quality_score"]
            mean = quality_stats["mean"]
            pairs.append(
                f"{LANGUAGE_NAMES.get(src, src)} → {LANGUAGE_NAMES.get(tgt, tgt)}"
            )
            quality_means.append(mean)
            # A missing CI bound falls back to the mean, i.e. a zero-length
            # error bar, instead of raising KeyError.
            ci_above.append(quality_stats.get("ci_upper", mean) - mean)
            ci_below.append(mean - quality_stats.get("ci_lower", mean))
            bleu_means.append(metrics.get("bleu", {}).get("mean", 0))
            sample_counts.append(metrics["sample_count"])

    if not pairs:
        fig = go.Figure()
        fig.add_annotation(text="No language pair data available", **centered)
        return fig

    fig = make_subplots(
        rows=2, cols=1,
        subplot_titles=(
            "Quality Scores by Language Pair (with 95% CI)",
            "BLEU Scores by Language Pair",
        ),
        vertical_spacing=0.15,
    )

    # Quality scores with (possibly asymmetric) confidence intervals.
    fig.add_trace(
        go.Bar(
            x=pairs,
            y=quality_means,
            error_y=dict(
                type="data",
                array=ci_above,
                arrayminus=ci_below,
                visible=True,
                thickness=2,
                width=4,
            ),
            name="Quality Score",
            marker_color="steelblue",
            text=[f"{score:.3f}" for score in quality_means],
            textposition="outside",
            hovertemplate=(
                "<b>%{x}</b><br>" +
                "Quality: %{y:.4f}<br>" +
                "Samples: %{customdata}<br>" +
                "<extra></extra>"
            ),
            customdata=sample_counts,
        ),
        row=1, col=1,
    )

    # BLEU scores.
    fig.add_trace(
        go.Bar(
            x=pairs,
            y=bleu_means,
            name="BLEU Score",
            marker_color="coral",
            text=[f"{score:.1f}" for score in bleu_means],
            textposition="outside",
        ),
        row=2, col=1,
    )

    track_info = EVALUATION_TRACKS[track]
    fig.update_layout(
        title=f"🔬 Detailed Analysis: {model_name} - {track_info['name']}",
        height=900,
        showlegend=False,
        # Large bottom margin leaves room for the rotated pair labels.
        margin=dict(l=50, r=50, t=100, b=150),
    )
    # Rotate x-axis labels on both rows so long pair names do not overlap.
    fig.update_xaxes(tickangle=45, row=1, col=1)
    fig.update_xaxes(tickangle=45, row=2, col=1)
    return fig