leaderboard / src /utils.py
akera's picture
Update src/utils.py
b78ec70 verified
raw
history blame
4.27 kB
# src/utils.py
import re
import datetime
import pandas as pd
from typing import Dict, List, Tuple, Set
from config import ALL_UG40_LANGUAGES, LANGUAGE_NAMES, GOOGLE_SUPPORTED_LANGUAGES
def get_all_language_pairs() -> List[Tuple[str, str]]:
    """Return every ordered (source, target) pair of distinct UG40 languages."""
    return [
        (src, tgt)
        for src in ALL_UG40_LANGUAGES
        for tgt in ALL_UG40_LANGUAGES
        if src != tgt
    ]
def get_google_comparable_pairs() -> List[Tuple[str, str]]:
    """Return ordered pairs of distinct languages that Google Translate supports on both sides."""
    return [
        (src, tgt)
        for src in GOOGLE_SUPPORTED_LANGUAGES
        for tgt in GOOGLE_SUPPORTED_LANGUAGES
        if src != tgt
    ]
def format_language_pair(src: str, tgt: str) -> str:
    """Build the display label for a translation direction.

    Falls back to the raw language code when a name is not registered.

    NOTE(review): the two names are concatenated with no separator — a
    separator (e.g. an arrow) may have been lost upstream; confirm
    against the UI before changing.
    """
    display_src = LANGUAGE_NAMES.get(src, src)
    display_tgt = LANGUAGE_NAMES.get(tgt, tgt)
    return display_src + display_tgt
def validate_language_code(lang: str) -> bool:
    """Return True iff *lang* is one of the UG40 language codes."""
    return any(lang == code for code in ALL_UG40_LANGUAGES)
def create_submission_id() -> str:
    """Return a timestamp-based submission ID with millisecond precision.

    Shape: YYYYMMDD_HHMMSS_mmm (19 characters).
    """
    now = datetime.datetime.now()
    # %f yields microseconds (6 digits); dropping the last 3 leaves milliseconds.
    stamp = now.strftime("%Y%m%d_%H%M%S_%f")
    return stamp[:-3]
def sanitize_model_name(name: str) -> str:
    """Sanitize a user-supplied model name for safe display.

    Replaces any character other than word characters, '-' and '.' with
    '_' and truncates to 50 characters. Empty, None, or whitespace-only
    input falls back to the "Anonymous Model" placeholder.
    """
    # Bug fix: a whitespace-only name previously slipped past the empty
    # check, was stripped to "", and the function returned "" instead of
    # the placeholder.
    if not name or not name.strip():
        return "Anonymous Model"
    cleaned = re.sub(r'[^\w\-.]', '_', name.strip())
    return cleaned[:50]
def format_metric_value(value: float, metric: str) -> str:
    """Format a metric value for display.

    BLEU is shown with 2 decimals (conventional 0-100 scale); error rates
    (CER/WER) are capped at 1.0; everything else gets 4 decimals.
    """
    if metric == 'bleu':
        return f"{value:.2f}"
    if metric in ('cer', 'wer'):
        # min() is a no-op for values <= 1, so this folds the original
        # "value > 1" special case and the default branch into one path
        # with identical output.
        return f"{min(value, 1.0):.4f}"
    return f"{value:.4f}"
def get_language_pair_stats(test_data: pd.DataFrame) -> Dict[str, Dict]:
    """Summarize test-data coverage for every ordered UG40 language pair.

    For each pair, records the sample count, whether both sides are
    supported by Google Translate, and a human-readable display name.
    """
    google = set(GOOGLE_SUPPORTED_LANGUAGES)
    stats: Dict[str, Dict] = {}
    for src, tgt in get_all_language_pairs():
        mask = (
            (test_data['source_language'] == src)
            & (test_data['target_language'] == tgt)
        )
        stats[f"{src}_{tgt}"] = {
            'count': int(mask.sum()),
            'google_comparable': src in google and tgt in google,
            'display_name': format_language_pair(src, tgt),
        }
    return stats
def validate_submission_completeness(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Check that *predictions* covers every sample_id in *test_set*.

    Returns a dict with a completeness flag, missing/extra counts, a
    short deterministic preview of missing IDs, and fractional coverage.
    """
    required_ids = set(test_set['sample_id'].astype(str))
    provided_ids = set(predictions['sample_id'].astype(str))
    missing_ids = required_ids - provided_ids
    extra_ids = provided_ids - required_ids
    matched = len(provided_ids & required_ids)
    return {
        'is_complete': not missing_ids,
        'missing_count': len(missing_ids),
        'extra_count': len(extra_ids),
        # Bug fix: set iteration order is unstable, so the preview was
        # non-deterministic; sort for reproducible display.
        'missing_ids': sorted(missing_ids)[:10],  # first 10 for display
        # Bug fix: guard against ZeroDivisionError on an empty test set;
        # an empty requirement is trivially fully covered.
        'coverage': matched / len(required_ids) if required_ids else 1.0,
    }
def calculate_language_pair_coverage(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Report prediction coverage per UG40 language pair.

    Left-joins *predictions* onto *test_set* by sample_id, then for each
    pair with at least one test sample reports total samples, how many
    received a non-null prediction, and the resulting ratio. Pairs with
    no test samples are omitted.
    """
    # Left join keeps every test row; unmatched rows get a null prediction.
    merged = test_set.merge(predictions, on='sample_id', how='left', suffixes=('', '_pred'))
    coverage: Dict[str, Dict] = {}
    for src, tgt in get_all_language_pairs():
        pair_rows = merged[
            (merged['source_language'] == src)
            & (merged['target_language'] == tgt)
        ]
        total = len(pair_rows)
        if not total:
            continue
        predicted = pair_rows['prediction'].notna().sum()
        coverage[f"{src}_{tgt}"] = {
            'total': total,
            'predicted': predicted,
            'coverage': predicted / total,
        }
    return coverage