Spaces:
Running
Running
# src/utils.py
import re | |
import datetime | |
import pandas as pd | |
from typing import Dict, List, Tuple, Set | |
from config import ALL_UG40_LANGUAGES, LANGUAGE_NAMES, GOOGLE_SUPPORTED_LANGUAGES | |
def get_all_language_pairs() -> List[Tuple[str, str]]:
    """Return every ordered (source, target) pair of distinct UG40 languages.

    Pairs are directional, so both ``(a, b)`` and ``(b, a)`` are included.
    The order follows ALL_UG40_LANGUAGES, with the source varying slowest.
    """
    return [
        (source, target)
        for source in ALL_UG40_LANGUAGES
        for target in ALL_UG40_LANGUAGES
        if source != target
    ]
def get_google_comparable_pairs() -> List[Tuple[str, str]]:
    """Return the ordered pairs that can be compared with Google Translate.

    Both members of each pair come from GOOGLE_SUPPORTED_LANGUAGES and a
    language is never paired with itself.
    """
    return [
        (source, target)
        for source in GOOGLE_SUPPORTED_LANGUAGES
        for target in GOOGLE_SUPPORTED_LANGUAGES
        if source != target
    ]
def format_language_pair(src: str, tgt: str) -> str:
    """Build the display label "<source name> → <target name>".

    A code that has no entry in LANGUAGE_NAMES is shown unchanged.
    """
    lookup = LANGUAGE_NAMES.get
    # Passing the code itself as the default keeps unknown codes readable.
    src_name, tgt_name = lookup(src, src), lookup(tgt, tgt)
    return f"{src_name} → {tgt_name}"
def validate_language_code(lang: str) -> bool:
    """Report whether ``lang`` is one of the codes in ALL_UG40_LANGUAGES."""
    is_supported = lang in ALL_UG40_LANGUAGES
    return is_supported
def create_submission_id() -> str:
    """Return a timestamp-based submission ID such as ``20240131_235959_123``.

    The ID is the local time formatted as date, time and microseconds, with
    the last three microsecond digits dropped (millisecond resolution).
    """
    timestamp = datetime.datetime.now()
    with_micros = timestamp.strftime("%Y%m%d_%H%M%S_%f")
    # %f always yields six digits; keep only the leading three (milliseconds).
    return with_micros[:len(with_micros) - 3]
def sanitize_model_name(name: str) -> str:
    """Sanitize a user-supplied model name for display.

    Surrounding whitespace is removed, every character other than a word
    character, hyphen or dot is replaced with an underscore, and the result
    is truncated to 50 characters.

    Args:
        name: Raw model name; may be empty or None.

    Returns:
        The sanitized name, or "Anonymous Model" when ``name`` is missing,
        empty, or consists only of whitespace.
    """
    # Strip before the emptiness check so a whitespace-only name falls back
    # to the placeholder instead of producing an empty display name.
    stripped = name.strip() if name else ""
    if not stripped:
        return "Anonymous Model"
    # Remove special characters, limit length
    return re.sub(r'[^\w\-.]', '_', stripped)[:50]
def format_metric_value(value: float, metric: str) -> str:
    """Turn a metric value into its display string.

    'bleu' is shown with two decimals. Every other metric is shown with four
    decimals, and 'cer' / 'wer' values above 1 are displayed as 1.0000.
    """
    if metric == 'bleu':
        return f"{value:.2f}"
    over_limit = metric in ('cer', 'wer') and value > 1
    if not over_limit:
        return f"{value:.4f}"
    return f"{min(value, 1.0):.4f}"  # Cap error rates at 1.0
def get_language_pair_stats(test_data: pd.DataFrame) -> Dict[str, Dict]:
    """Summarise test-data coverage for each directed UG40 language pair.

    ``test_data`` is filtered on its 'source_language' and 'target_language'
    columns. The result maps "<src>_<tgt>" to a dict holding the matching row
    count, whether both languages are in GOOGLE_SUPPORTED_LANGUAGES, and the
    display label. Every pair gets an entry, including pairs with no rows.
    """
    summary = {}
    for src in ALL_UG40_LANGUAGES:
        for tgt in ALL_UG40_LANGUAGES:
            if src == tgt:
                continue  # a language is never paired with itself
            from_src = test_data['source_language'] == src
            to_tgt = test_data['target_language'] == tgt
            matching_rows = test_data[from_src & to_tgt]
            summary[f"{src}_{tgt}"] = {
                'count': len(matching_rows),
                'google_comparable': (
                    src in GOOGLE_SUPPORTED_LANGUAGES
                    and tgt in GOOGLE_SUPPORTED_LANGUAGES
                ),
                'display_name': format_language_pair(src, tgt),
            }
    return summary
def validate_submission_completeness(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Validate that a submission covers all required samples.

    Sample IDs from both frames are compared as strings, so ``1`` and ``"1"``
    are treated as the same ID.

    Args:
        predictions: Submission frame with a 'sample_id' column.
        test_set: Reference frame with a 'sample_id' column.

    Returns:
        A dict with:
            is_complete: True when no required ID is missing.
            missing_count: Number of required IDs absent from the submission.
            extra_count: Number of submitted IDs not present in the test set.
            missing_ids: Up to 10 missing IDs, sorted, for display.
            coverage: Fraction of required IDs that were provided. Reported
                as 1.0 when the test set has no IDs, matching ``is_complete``.
    """
    required_ids = set(test_set['sample_id'].astype(str))
    provided_ids = set(predictions['sample_id'].astype(str))
    missing_ids = required_ids - provided_ids
    extra_ids = provided_ids - required_ids
    matched_count = len(required_ids) - len(missing_ids)
    # Guard the division: an empty test set has nothing to cover, so it is
    # treated as fully covered instead of raising ZeroDivisionError.
    coverage = matched_count / len(required_ids) if required_ids else 1.0
    return {
        'is_complete': not missing_ids,
        'missing_count': len(missing_ids),
        'extra_count': len(extra_ids),
        # Sorted so the displayed sample is stable across runs; set iteration
        # order is arbitrary.
        'missing_ids': sorted(missing_ids)[:10],  # First 10 for display
        'coverage': coverage,
    }
def calculate_language_pair_coverage(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Calculate prediction coverage for each directed UG40 language pair.

    Args:
        predictions: Submission frame with 'sample_id' and 'prediction'
            columns.
        test_set: Reference frame with 'sample_id', 'source_language' and
            'target_language' columns.

    Returns:
        A dict mapping "<src>_<tgt>" to a dict with 'total' (test rows for
        the pair), 'predicted' (rows with a non-null prediction) and
        'coverage' (predicted / total). Pairs with no test rows are omitted.
    """
    # Compare IDs as strings, as validate_submission_completeness does, so an
    # int-vs-str dtype difference between the frames cannot break the merge.
    test_rows = test_set.assign(sample_id=test_set['sample_id'].astype(str))
    pred_rows = predictions.assign(sample_id=predictions['sample_id'].astype(str))
    # Keep one prediction per sample; duplicates would multiply test rows in
    # the left join and inflate both 'total' and 'predicted'.
    pred_rows = pred_rows.drop_duplicates(subset='sample_id', keep='first')

    # Merge to get language info
    merged = test_rows.merge(pred_rows, on='sample_id', how='left', suffixes=('', '_pred'))

    coverage = {}
    for src in ALL_UG40_LANGUAGES:
        for tgt in ALL_UG40_LANGUAGES:
            if src == tgt:
                continue
            pair_data = merged[
                (merged['source_language'] == src) &
                (merged['target_language'] == tgt)
            ]
            total = len(pair_data)
            if total == 0:
                continue
            # Convert the numpy integer to int so the result serialises
            # cleanly (e.g. with json.dumps).
            predicted_count = int(pair_data['prediction'].notna().sum())
            coverage[f"{src}_{tgt}"] = {
                'total': total,
                'predicted': predicted_count,
                'coverage': predicted_count / total,
            }
    return coverage