# leaderboard/src/utils.py — last commit: "Update src/utils.py" by akera (aa9fced, verified)
# src/utils.py
import re
import datetime
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple, Set, Optional, Union
from config import (
ALL_UG40_LANGUAGES,
GOOGLE_SUPPORTED_LANGUAGES,
LANGUAGE_NAMES,
EVALUATION_TRACKS,
MODEL_CATEGORIES,
METRICS_CONFIG,
)
def get_all_language_pairs() -> List[Tuple[str, str]]:
    """Return every ordered (source, target) pair of distinct UG40 languages."""
    return [
        (src, tgt)
        for src in ALL_UG40_LANGUAGES
        for tgt in ALL_UG40_LANGUAGES
        if src != tgt
    ]
def get_google_comparable_pairs() -> List[Tuple[str, str]]:
    """Return ordered (source, target) pairs restricted to languages that
    Google Translate supports, for baseline comparison."""
    return [
        (src, tgt)
        for src in GOOGLE_SUPPORTED_LANGUAGES
        for tgt in GOOGLE_SUPPORTED_LANGUAGES
        if src != tgt
    ]
def get_track_language_pairs(track: str) -> List[Tuple[str, str]]:
    """Return ordered (source, target) pairs for one evaluation track.

    An unknown track name yields an empty list.
    """
    if track not in EVALUATION_TRACKS:
        return []
    langs = EVALUATION_TRACKS[track]["languages"]
    return [(src, tgt) for src in langs for tgt in langs if src != tgt]
def format_language_pair(src: str, tgt: str) -> str:
    """Format a language pair for display, e.g. ``"Luganda → English"``.

    Unknown codes fall back to the upper-cased code itself.

    Fix: the two names were previously concatenated with no separator,
    producing unreadable strings like ``"LugandaEnglish"``.
    """
    src_name = LANGUAGE_NAMES.get(src, src.upper())
    tgt_name = LANGUAGE_NAMES.get(tgt, tgt.upper())
    return f"{src_name} → {tgt_name}"
def validate_language_code(lang: str) -> bool:
    """Return True when ``lang`` is one of the supported UG40 language codes."""
    return any(lang == code for code in ALL_UG40_LANGUAGES)
def create_submission_id() -> str:
    """Create a unique submission ID of the form ``sub_YYYYMMDD_HHMMSS_NNNN``.

    The suffix is a random four-digit number (1000-9999).

    Fix: ``np.random.randint``'s upper bound is exclusive, so the old
    call ``randint(1000, 9999)`` could never draw 9999; the bound is
    now 10000 to make the full range reachable.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    random_suffix = str(np.random.randint(1000, 10000))
    return f"sub_{timestamp}_{random_suffix}"
def sanitize_model_name(name: str) -> str:
    """Sanitize a model name for display and storage.

    Replaces unsafe characters with underscores, collapses underscore
    runs, pads very short names, prefixes reserved names, and caps the
    result at 50 characters. Non-string or empty input yields
    ``"Anonymous_Model"``.

    Fixes: a whitespace-/symbol-only input previously produced the
    dangling name ``"Model_"``; it now falls back to
    ``"Anonymous_Model"``. Truncation no longer leaves a trailing
    underscore.
    """
    if not name or not isinstance(name, str):
        return "Anonymous_Model"
    # Keep word chars, hyphens, and dots; everything else becomes "_".
    name = re.sub(r"[^\w\-.]", "_", name.strip())
    # Collapse consecutive underscores and trim them from the ends.
    name = re.sub(r"_+", "_", name).strip("_")
    if not name:
        # Nothing usable survived sanitization.
        return "Anonymous_Model"
    # Pad very short names.
    if len(name) < 3:
        name = f"Model_{name}"
    # Avoid clashing with reserved system names.
    reserved_names = {"admin", "test", "baseline", "google", "system"}
    if name.lower() in reserved_names:
        name = f"User_{name}"
    # Cap length; drop any underscore exposed by the cut.
    return name[:50].rstrip("_")
def format_metric_value(value: float, metric: str, precision: int = None) -> str:
    """Render a metric value as a display string.

    Percent-formats coverage rates, fixes BLEU at two decimals, caps
    error rates (CER/WER) above 1.0 at 1.0, and falls back to
    ``str(value)`` if formatting fails. ``precision`` defaults to the
    configured display precision. NaN/None render as ``"N/A"``.
    """
    if value is None or pd.isna(value):
        return "N/A"
    try:
        digits = METRICS_CONFIG["display_precision"] if precision is None else precision
        if metric == "coverage_rate":
            return f"{value:.1%}"
        if metric == "bleu":
            return f"{value:.2f}"
        if metric in ("cer", "wer") and value > 1:
            # Error rates above 1.0 are clamped for display.
            return f"{min(value, 1.0):.{digits}f}"
        return f"{value:.{digits}f}"
    except (ValueError, TypeError):
        return str(value)
def safe_divide(numerator: float, denominator: float, default: float = 0.0) -> float:
    """Divide ``numerator`` by ``denominator``, returning ``default`` on any
    problem: zero or NaN denominator, NaN numerator, non-finite result, or
    operands that cannot be divided at all."""
    try:
        if pd.isna(numerator) or pd.isna(denominator) or denominator == 0:
            return default
        quotient = numerator / denominator
        if not np.isfinite(quotient) or pd.isna(quotient):
            return default
        return float(quotient)
    except (TypeError, ValueError, ZeroDivisionError):
        return default
def clean_text_for_evaluation(text: str) -> str:
    """Normalise text prior to scoring.

    Collapses runs of whitespace into single spaces and maps a few
    common Unicode punctuation variants to ASCII equivalents. Non-string
    input is stringified; ``None`` becomes an empty string.
    """
    if not isinstance(text, str):
        return "" if text is None else str(text)
    collapsed = re.sub(r"\s+", " ", text.strip())
    # Common encoding artefacts mapped to plain ASCII in one pass.
    artefacts = str.maketrans({
        "\u00a0": " ",   # non-breaking space
        "\u2019": "'",   # right single quotation mark
        "\u201c": '"',   # left double quotation mark
        "\u201d": '"',   # right double quotation mark
    })
    return collapsed.translate(artefacts)
def validate_dataframe_structure(
    df: pd.DataFrame, required_columns: List[str], track: str = None
) -> Tuple[bool, List[str]]:
    """Check that ``df`` is non-empty and contains the required columns.

    Side effect: when a ``sample_id`` column exists with a non-object
    dtype, it is coerced to string in place. ``track`` is accepted but
    currently unused. Returns ``(is_valid, issues)``.
    """
    if df.empty:
        return False, ["DataFrame is empty"]
    problems: List[str] = []
    absent = [name for name in required_columns if name not in df.columns]
    if absent:
        problems.append(f"Missing columns: {', '.join(absent)}")
    if "sample_id" in df.columns and df["sample_id"].dtype != "object":
        try:
            df["sample_id"] = df["sample_id"].astype(str)
        except Exception:
            problems.append("Cannot convert sample_id to string")
    return not problems, problems
def calculate_track_coverage(predictions: pd.DataFrame, test_set: pd.DataFrame, track: str) -> Dict:
    """Compute how much of a track's test set the predictions cover.

    Returns overall coverage plus a per-language-pair breakdown, or an
    ``{"error": ...}`` dict for unknown tracks / missing test data.
    """
    if track not in EVALUATION_TRACKS:
        return {"error": f"Unknown track: {track}"}
    config = EVALUATION_TRACKS[track]
    languages = config["languages"]
    # Restrict the test set to pairs within this track's languages.
    in_track = (
        test_set["source_language"].isin(languages)
        & test_set["target_language"].isin(languages)
    )
    subset = test_set[in_track]
    if subset.empty:
        return {"error": f"No test data available for {track} track"}
    predicted = set(predictions["sample_id"].astype(str))
    expected = set(subset["sample_id"].astype(str))
    covered = predicted & expected
    # Per-pair breakdown; pairs with no test rows are omitted.
    pair_analysis: Dict[str, Dict] = {}
    for source in languages:
        for target in languages:
            if source == target:
                continue
            pair_rows = subset[
                (subset["source_language"] == source)
                & (subset["target_language"] == target)
            ]
            if len(pair_rows) == 0:
                continue
            pair_expected = set(pair_rows["sample_id"].astype(str))
            pair_covered = predicted & pair_expected
            pair_analysis[f"{source}_to_{target}"] = {
                "total": len(pair_rows),
                "covered": len(pair_covered),
                "coverage_rate": len(pair_covered) / len(pair_rows),
            }
    return {
        "track_name": config["name"],
        "total_samples": len(subset),
        "covered_samples": len(covered),
        "coverage_rate": len(covered) / len(expected),
        "pair_analysis": pair_analysis,
    }
def generate_model_identifier(model_name: str, author: str, category: str) -> str:
    """Build a model identifier of the form
    ``<category>_<name>_<author>_<MMDD_HHMM>``.

    The category falls back to ``"community"`` when unrecognised, and
    the author to ``"Anonymous"`` when missing.
    """
    safe_name = sanitize_model_name(model_name)
    if author:
        safe_author = re.sub(r"[^\w\-]", "_", author.strip())[:20]
    else:
        safe_author = "Anonymous"
    safe_category = category[:10] if category in MODEL_CATEGORIES else "community"
    stamp = datetime.datetime.now().strftime("%m%d_%H%M")
    return f"{safe_category}_{safe_name}_{safe_author}_{stamp}"
def format_duration(seconds: float) -> str:
    """Render a duration in seconds as ``Xs``, ``Xm``, or ``Xh`` (one decimal)."""
    for limit, divisor, unit in ((60, 1, "s"), (3600, 60, "m")):
        if seconds < limit:
            return f"{seconds / divisor:.1f}{unit}"
    return f"{seconds / 3600:.1f}h"
def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """Truncate ``text`` to at most ``max_length`` characters, appending
    ``suffix`` when truncation occurs. Non-string input is stringified.

    Fix: when ``max_length`` was not larger than the suffix, the old
    negative-slice arithmetic produced output *longer* than the limit
    (e.g. ``truncate_text("hello", 2)`` -> ``"hell..."``); such cases
    are now hard-cut to ``max_length`` with no suffix.
    """
    if not isinstance(text, str):
        text = str(text)
    if len(text) <= max_length:
        return text
    if max_length <= len(suffix):
        # No room for the suffix: hard cut to the limit.
        return text[:max_length]
    return text[: max_length - len(suffix)] + suffix
def get_language_pair_display_name(src: str, tgt: str) -> str:
    """Return a display name for a language pair, e.g. ``"Luganda → English"``.

    Unknown codes fall back to the upper-cased code itself.

    Fix: the two names were previously concatenated with no separator,
    producing unreadable strings like ``"LugandaEnglish"``.
    """
    src_name = LANGUAGE_NAMES.get(src, src.upper())
    tgt_name = LANGUAGE_NAMES.get(tgt, tgt.upper())
    return f"{src_name} → {tgt_name}"
def validate_submission_completeness(
    predictions: pd.DataFrame, test_set: pd.DataFrame, track: str = None
) -> Dict:
    """Compare submitted prediction IDs against the required test-set IDs.

    Returns a dict with a completeness flag, missing/extra counts, up to
    ten missing IDs, and the coverage fraction. When ``track`` names a
    known evaluation track, the test set is first restricted to that
    track's language pairs. Unexpected errors are logged and reported as
    an incomplete submission.
    """
    if predictions.empty or test_set.empty:
        return {
            "is_complete": False,
            "missing_count": 0 if test_set.empty else len(test_set),
            "extra_count": 0 if predictions.empty else len(predictions),
            "missing_ids": [],
            "coverage": 0.0,
        }
    if track and track in EVALUATION_TRACKS:
        langs = EVALUATION_TRACKS[track]["languages"]
        test_set = test_set[
            test_set["source_language"].isin(langs)
            & test_set["target_language"].isin(langs)
        ]
    try:
        needed = set(test_set["sample_id"].astype(str))
        supplied = set(predictions["sample_id"].astype(str))
        missing = needed - supplied
        surplus = supplied - needed
        overlap = supplied & needed
        return {
            "is_complete": not missing,
            "missing_count": len(missing),
            "extra_count": len(surplus),
            "missing_ids": list(missing)[:10],
            "coverage": len(overlap) / len(needed) if needed else 0.0,
        }
    except Exception as e:
        print(f"Error in submission completeness validation: {e}")
        return {
            "is_complete": False,
            "missing_count": 0,
            "extra_count": 0,
            "missing_ids": [],
            "coverage": 0.0,
        }
def get_model_summary_stats(model_results: Dict, track: str = None) -> Dict:
    """Summarise evaluation results, for one track or across all tracks.

    With ``track`` set (and present in the results), returns that
    track's headline metrics, or an ``{"error": ...}`` dict when the
    track failed. Otherwise returns per-track quality/sample/pair counts
    for every track that evaluated without error. Empty or malformed
    input yields ``{}``.
    """
    if not model_results or "tracks" not in model_results:
        return {}
    tracks = model_results["tracks"]
    # Single-track view.
    if track and track in tracks:
        data = tracks[track]
        if data.get("error"):
            return {"error": f"No valid data for {track} track"}
        averages = data.get("track_averages", {})
        summary = data.get("summary", {})
        return {
            "track": track,
            "track_name": EVALUATION_TRACKS[track]["name"],
            "quality_score": averages.get("quality_score", 0.0),
            "bleu": averages.get("bleu", 0.0),
            "chrf": averages.get("chrf", 0.0),
            "total_samples": summary.get("total_samples", 0),
            "language_pairs": summary.get("language_pairs_evaluated", 0),
        }
    # Cross-track view: only tracks without errors contribute.
    per_track: Dict[str, Dict] = {}
    for name, data in tracks.items():
        if data.get("error"):
            continue
        averages = data.get("track_averages", {})
        summary = data.get("summary", {})
        per_track[name] = {
            "quality_score": averages.get("quality_score", 0.0),
            "samples": summary.get("total_samples", 0),
            "pairs": summary.get("language_pairs_evaluated", 0),
        }
    return {
        "tracks_evaluated": len(per_track),
        "total_tracks": len(EVALUATION_TRACKS),
        "by_track": per_track,
    }