Spaces:
Running
Running
# src/utils.py | |
import re | |
import datetime | |
import pandas as pd | |
import numpy as np | |
from typing import Dict, List, Tuple, Set, Optional, Union | |
from config import ( | |
ALL_UG40_LANGUAGES, | |
GOOGLE_SUPPORTED_LANGUAGES, | |
LANGUAGE_NAMES, | |
EVALUATION_TRACKS, | |
MODEL_CATEGORIES, | |
METRICS_CONFIG, | |
) | |
def get_all_language_pairs() -> List[Tuple[str, str]]:
    """Return every ordered (source, target) pair of distinct UG40 languages.

    Pairs follow the order of ``ALL_UG40_LANGUAGES`` with the source varying
    slowest; pairs whose source equals the target are skipped.
    """
    return [
        (source, target)
        for source in ALL_UG40_LANGUAGES
        for target in ALL_UG40_LANGUAGES
        if source != target
    ]
def get_google_comparable_pairs() -> List[Tuple[str, str]]:
    """Return ordered pairs of distinct languages supported by Google Translate.

    The result is the cross product of ``GOOGLE_SUPPORTED_LANGUAGES`` with
    itself, minus same-language pairs, in source-major order.
    """
    languages = GOOGLE_SUPPORTED_LANGUAGES
    return [(a, b) for a in languages for b in languages if a != b]
def get_track_language_pairs(track: str) -> List[Tuple[str, str]]:
    """Return ordered pairs of distinct languages configured for *track*.

    An unknown track name yields an empty list rather than raising.
    """
    if track in EVALUATION_TRACKS:
        languages = EVALUATION_TRACKS[track]["languages"]
        return [(s, t) for s in languages for t in languages if s != t]
    return []
def format_language_pair(src: str, tgt: str) -> str:
    """Render a pair as ``"<source name> → <target name>"``.

    Codes without an entry in ``LANGUAGE_NAMES`` are shown upper-cased.
    """
    source_label, target_label = (
        LANGUAGE_NAMES.get(code, code.upper()) for code in (src, tgt)
    )
    return "{} → {}".format(source_label, target_label)
def validate_language_code(lang: str) -> bool:
    """Tell whether *lang* is one of the UG40 language codes."""
    is_supported = lang in ALL_UG40_LANGUAGES
    return is_supported
def create_submission_id() -> str:
    """Create a submission ID of the form ``sub_YYYYmmdd_HHMMSS_NNNN``.

    The four-digit suffix (1000-9999 inclusive) is drawn from :mod:`secrets`
    rather than NumPy's global RNG. If any code in the process seeds the
    global RNG (``np.random.seed``), a NumPy-based suffix becomes repeatable,
    and two submissions made in the same second could receive the same ID.
    """
    import secrets

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # randbelow(9000) -> 0..8999, so the suffix always has exactly 4 digits.
    random_suffix = str(1000 + secrets.randbelow(9000))
    return f"sub_{timestamp}_{random_suffix}"
def sanitize_model_name(name: str) -> str:
    """Sanitize a user-supplied model name for display and storage.

    Processing steps:
      * Characters other than word characters, ``-`` and ``.`` become ``_``.
      * Runs of ``_`` collapse to one, and leading/trailing ``_`` are removed.
      * Input that is empty, not a string, or cleans down to nothing
        (e.g. ``"   "`` or ``"@@@"``) returns ``"Anonymous_Model"``.
      * Names shorter than 3 characters are prefixed with ``Model_``.
      * Reserved names (case-insensitive) are prefixed with ``User_``.

    The result is at most 50 characters and never ends with ``_``.
    """
    if not name or not isinstance(name, str):
        return "Anonymous_Model"

    cleaned = re.sub(r"[^\w\-.]", "_", name.strip())
    cleaned = re.sub(r"_+", "_", cleaned).strip("_")

    # Whitespace-only / all-symbol input would otherwise become "Model_".
    if not cleaned:
        return "Anonymous_Model"

    if len(cleaned) < 3:
        cleaned = f"Model_{cleaned}"

    reserved_names = {"admin", "test", "baseline", "google", "system"}
    if cleaned.lower() in reserved_names:
        cleaned = f"User_{cleaned}"

    # Cutting at 50 chars can expose a separator; don't leave it dangling.
    return cleaned[:50].rstrip("_")
def format_metric_value(value: float, metric: str, precision: int = None) -> str:
    """Format a metric value for display.

    Formatting rules:
      * Missing values (``None``/NaN) render as ``"N/A"``.
      * ``coverage_rate`` renders as a percentage with one decimal.
      * ``bleu`` renders with two decimals.
      * ``cer``/``wer`` above 1 are capped at 1.0.
      * Everything else uses *precision* decimals, defaulting to
        ``METRICS_CONFIG["display_precision"]``.

    Values that cannot be formatted numerically are returned via ``str()``.
    """
    if pd.isna(value) or value is None:
        return "N/A"

    try:
        digits = METRICS_CONFIG["display_precision"] if precision is None else precision
        if metric == "coverage_rate":
            formatted = f"{value:.1%}"
        elif metric == "bleu":
            formatted = f"{value:.2f}"
        elif metric in ("cer", "wer") and value > 1:
            # Error rates can exceed 1.0; show them capped.
            formatted = f"{min(value, 1.0):.{digits}f}"
        else:
            formatted = f"{value:.{digits}f}"
    except (ValueError, TypeError):
        formatted = str(value)
    return formatted
def safe_divide(numerator: float, denominator: float, default: float = 0.0) -> float:
    """Divide *numerator* by *denominator*, falling back to *default*.

    *default* is returned in any of these cases:
      * the denominator is zero;
      * either operand is missing (None/NaN);
      * the operands cannot be divided;
      * the quotient is NaN or infinite.
    """
    try:
        invalid_input = denominator == 0 or pd.isna(denominator) or pd.isna(numerator)
        if invalid_input:
            return default
        quotient = numerator / denominator
        if not pd.isna(quotient) and np.isfinite(quotient):
            return float(quotient)
        return default
    except (TypeError, ValueError, ZeroDivisionError):
        return default
# Typographic characters mapped to plain ASCII before scoring.
_TYPOGRAPHIC_TRANSLATION = str.maketrans(
    {
        "\u00a0": " ",  # non-breaking space
        "\u2018": "'",  # left single quotation mark
        "\u2019": "'",  # right single quotation mark
        "\u201c": '"',  # left double quotation mark
        "\u201d": '"',  # right double quotation mark
    }
)


def clean_text_for_evaluation(text: str) -> str:
    """Normalize text before it is scored.

    Rules:
      * ``None`` and scalar pandas missing values (NaN, ``pd.NA``) become
        ``""``. They are not turned into the literal text ``"nan"``, which
        would otherwise be scored as a real prediction.
      * Other non-string values are converted with ``str()`` and then
        cleaned like any other text.
      * Curly single and double quotes (both left and right) become ASCII
        quotes, and non-breaking spaces become spaces.
      * Whitespace runs collapse to one space; the ends are stripped.
    """
    if not isinstance(text, str):
        if text is None or (pd.api.types.is_scalar(text) and pd.isna(text)):
            return ""
        text = str(text)

    text = text.translate(_TYPOGRAPHIC_TRANSLATION)
    return re.sub(r"\s+", " ", text).strip()
def validate_dataframe_structure(
    df: pd.DataFrame, required_columns: List[str], track: str = None
) -> Tuple[bool, List[str]]:
    """Check that *df* is non-empty and contains every column in *required_columns*.

    Side effect: when *df* has a ``sample_id`` column whose dtype is not
    ``object``, that column is converted to ``str`` in place on the caller's
    DataFrame. *track* is accepted but not used by this check.

    Returns ``(is_valid, issues)``, where *issues* lists human-readable problems.
    """
    if df.empty:
        return False, ["DataFrame is empty"]

    problems: List[str] = []

    absent = [column for column in required_columns if column not in df.columns]
    if absent:
        problems.append("Missing columns: " + ", ".join(absent))

    if "sample_id" in df.columns and not (df["sample_id"].dtype == "object"):
        try:
            df["sample_id"] = df["sample_id"].astype(str)
        except Exception:
            problems.append("Cannot convert sample_id to string")

    return not problems, problems
def calculate_track_coverage(predictions: pd.DataFrame, test_set: pd.DataFrame, track: str) -> Dict:
    """Calculate how much of a track's test set a submission covers.

    The test set is restricted to rows whose source and target languages both
    belong to the track. Coverage compares ``sample_id`` values as strings,
    both overall and per ordered language pair (source != target).

    Returns:
        On success, a dict with these keys:
          * ``track_name``.
          * ``total_samples``: row count of the filtered test set.
          * ``covered_samples`` and ``coverage_rate``: unique matching IDs,
            and their share of the unique test IDs.
          * ``pair_analysis``: keyed ``"<src>_to_<tgt>"``, each entry holding
            ``total``/``covered``/``coverage_rate``.
        Otherwise ``{"error": <message>}``. This happens for an unknown track,
        a track with no test data, or predictions lacking a ``sample_id``
        column.
    """
    if track not in EVALUATION_TRACKS:
        return {"error": f"Unknown track: {track}"}
    if "sample_id" not in predictions.columns:
        return {"error": "Predictions are missing the 'sample_id' column"}

    track_config = EVALUATION_TRACKS[track]
    track_languages = track_config["languages"]

    # Keep only rows where both sides of the pair belong to the track.
    track_test_set = test_set[
        test_set["source_language"].isin(track_languages)
        & test_set["target_language"].isin(track_languages)
    ]
    if track_test_set.empty:
        return {"error": f"No test data available for {track} track"}

    pred_ids = set(predictions["sample_id"].astype(str))
    track_ids = track_test_set["sample_id"].astype(str)
    test_ids = set(track_ids)
    matching_ids = pred_ids & test_ids
    coverage_rate = len(matching_ids) / len(test_ids)

    # Bucket sample IDs by (source, target) in a single pass instead of
    # re-filtering the whole test set once per language pair.
    ids_by_pair: Dict[Tuple[str, str], List[str]] = {}
    for src, tgt, sample_id in zip(
        track_test_set["source_language"], track_test_set["target_language"], track_ids
    ):
        ids_by_pair.setdefault((src, tgt), []).append(sample_id)

    # Walk pairs in configured language order so pair_analysis keys follow it.
    pair_analysis = {}
    for src in track_languages:
        for tgt in track_languages:
            if src == tgt:
                continue
            pair_ids = ids_by_pair.get((src, tgt))
            if not pair_ids:
                continue
            covered = len(pred_ids & set(pair_ids))
            pair_analysis[f"{src}_to_{tgt}"] = {
                "total": len(pair_ids),  # row count, as for total_samples
                "covered": covered,
                "coverage_rate": covered / len(pair_ids),
            }

    return {
        "track_name": track_config["name"],
        "total_samples": len(track_test_set),
        "covered_samples": len(matching_ids),
        "coverage_rate": coverage_rate,
        "pair_analysis": pair_analysis,
    }
def generate_model_identifier(model_name: str, author: str, category: str) -> str:
    """Build an identifier ``<category>_<name>_<author>_<MMDD_HHMM>`` for a model.

    Components:
      * The name goes through :func:`sanitize_model_name`.
      * In the author, characters other than word characters and ``-``
        become ``_``, and the result is cut to 20 characters.
      * A missing, non-string, or blank author becomes ``"Anonymous"``.
        Previously a blank author produced an empty segment (``..._name__...``)
        and a non-string one raised AttributeError.
      * The category is kept (cut to 10 characters) only if it is in
        ``MODEL_CATEGORIES``; otherwise ``"community"`` is used.

    NOTE: the timestamp has minute resolution, so identical inputs within the
    same minute produce identical identifiers.
    """
    clean_name = sanitize_model_name(model_name)

    clean_author = ""
    if isinstance(author, str):
        clean_author = re.sub(r"[^\w\-]", "_", author.strip())[:20]
    if not clean_author:
        clean_author = "Anonymous"

    clean_category = category[:10] if category in MODEL_CATEGORIES else "community"
    timestamp = datetime.datetime.now().strftime("%m%d_%H%M")
    return f"{clean_category}_{clean_name}_{clean_author}_{timestamp}"
def format_duration(seconds: float) -> str:
    """Render *seconds* as seconds, minutes or hours with one decimal place.

    Under a minute prints ``"<n>s"``, under an hour ``"<n>m"``, otherwise ``"<n>h"``.
    """
    if seconds < 60:
        amount, unit = seconds, "s"
    elif seconds < 3600:
        amount, unit = seconds / 60, "m"
    else:
        amount, unit = seconds / 3600, "h"
    return f"{amount:.1f}{unit}"
def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """Shorten *text* so it is at most *max_length* characters long.

    Non-string input is converted with ``str()`` first. When truncation is
    needed and the suffix fits, the tail is replaced by *suffix*, so the
    result is exactly *max_length* characters.

    If *max_length* is shorter than *suffix*, the text is hard-cut with no
    suffix. The naive slice ``text[:max_length - len(suffix)]`` would go
    negative there and return a string longer than *max_length*.
    """
    if not isinstance(text, str):
        text = str(text)
    if len(text) <= max_length:
        return text
    # No room for the suffix (or a non-positive limit): plain cut.
    if max_length < len(suffix):
        return text[: max(max_length, 0)]
    return text[: max_length - len(suffix)] + suffix
def get_language_pair_display_name(src: str, tgt: str) -> str:
    """Return the human-readable label for a source/target language pair.

    The rendering is identical to :func:`format_language_pair`: names come
    from ``LANGUAGE_NAMES``, with upper-cased codes as the fallback. This
    function therefore delegates to it instead of duplicating the logic.
    """
    return format_language_pair(src, tgt)
def validate_submission_completeness(
    predictions: pd.DataFrame, test_set: pd.DataFrame, track: str = None
) -> Dict:
    """Compare submitted ``sample_id`` values against the required test IDs.

    When *track* names a known evaluation track, the test set is first
    restricted to rows whose source and target languages both belong to that
    track. This is done before empty inputs are handled, so an empty
    submission reports the track's missing samples rather than the whole
    test set's.

    Returns a dict with:
        is_complete: True only when at least one ID is required and none
            are missing.
        missing_count / extra_count: number of unique IDs missing from,
            or not expected in, the submission.
        missing_ids: up to 10 missing IDs, sorted so the sample is stable
            across runs (string-set iteration order is not).
        coverage: fraction of required IDs present (0.0 when none are required).

    On an unexpected error while reading IDs, the error is printed and a
    zeroed, incomplete result is returned.
    """
    if track and track in EVALUATION_TRACKS and not test_set.empty:
        track_languages = EVALUATION_TRACKS[track]["languages"]
        test_set = test_set[
            test_set["source_language"].isin(track_languages)
            & test_set["target_language"].isin(track_languages)
        ]

    try:
        # Empty frames may not even have a sample_id column; don't index them.
        required_ids = set() if test_set.empty else set(test_set["sample_id"].astype(str))
        provided_ids = set() if predictions.empty else set(predictions["sample_id"].astype(str))
    except Exception as e:
        print(f"Error in submission completeness validation: {e}")
        return {
            "is_complete": False,
            "missing_count": 0,
            "extra_count": 0,
            "missing_ids": [],
            "coverage": 0.0,
        }

    missing_ids = required_ids - provided_ids
    extra_ids = provided_ids - required_ids
    matching_ids = provided_ids & required_ids

    return {
        "is_complete": bool(required_ids) and not missing_ids,
        "missing_count": len(missing_ids),
        "extra_count": len(extra_ids),
        "missing_ids": sorted(missing_ids)[:10],
        "coverage": len(matching_ids) / len(required_ids) if required_ids else 0.0,
    }
def get_model_summary_stats(model_results: Dict, track: str = None) -> Dict:
    """Extract summary statistics from a model's evaluation results.

    *model_results* is expected to hold a ``"tracks"`` mapping of track name
    to per-track data. Each entry may carry ``error``, ``track_averages``
    and ``summary`` sub-dicts.

    Behavior:
      * No results, or no ``"tracks"`` key: returns ``{}``.
      * *track* present in the results: returns that track's stats, or an
        error dict if the track recorded an error. A track missing from
        ``EVALUATION_TRACKS`` (e.g. results saved under an older config)
        falls back to its raw key as ``track_name`` instead of raising
        KeyError.
      * Otherwise (including when *track* is absent from the results):
        returns a cross-track summary of tracks that have no error.
    """
    if not model_results or "tracks" not in model_results:
        return {}

    tracks = model_results["tracks"]

    if track and track in tracks:
        track_data = tracks[track]
        if track_data.get("error"):
            return {"error": f"No valid data for {track} track"}

        track_averages = track_data.get("track_averages", {})
        summary = track_data.get("summary", {})
        track_name = EVALUATION_TRACKS[track]["name"] if track in EVALUATION_TRACKS else track
        return {
            "track": track,
            "track_name": track_name,
            "quality_score": track_averages.get("quality_score", 0.0),
            "bleu": track_averages.get("bleu", 0.0),
            "chrf": track_averages.get("chrf", 0.0),
            "total_samples": summary.get("total_samples", 0),
            "language_pairs": summary.get("language_pairs_evaluated", 0),
        }

    all_tracks_summary = {
        "tracks_evaluated": sum(1 for t in tracks.values() if not t.get("error")),
        "total_tracks": len(EVALUATION_TRACKS),
        "by_track": {},
    }
    for track_name, track_data in tracks.items():
        if track_data.get("error"):
            continue
        track_averages = track_data.get("track_averages", {})
        summary = track_data.get("summary", {})
        all_tracks_summary["by_track"][track_name] = {
            "quality_score": track_averages.get("quality_score", 0.0),
            "samples": summary.get("total_samples", 0),
            "pairs": summary.get("language_pairs_evaluated", 0),
        }
    return all_tracks_summary