Spaces:
Sleeping
Sleeping
Update src/utils.py
Browse files- src/utils.py +110 -50
src/utils.py
CHANGED
@@ -1,57 +1,117 @@
|
|
1 |
# src/utils.py
|
2 |
import re
|
3 |
import datetime
|
4 |
-
|
5 |
-
import
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
return
|
31 |
-
|
32 |
-
def
|
33 |
-
"""
|
34 |
-
|
35 |
-
|
36 |
-
if model_path == 'google-translate':
|
37 |
-
return 'google-translate'
|
38 |
-
elif 'gemma' in model_path_lower:
|
39 |
-
return 'gemma'
|
40 |
-
elif 'qwen' in model_path_lower:
|
41 |
-
return 'qwen'
|
42 |
-
elif 'llama' in model_path_lower:
|
43 |
-
return 'llama'
|
44 |
-
elif 'nllb' in model_path_lower:
|
45 |
-
return 'nllb'
|
46 |
-
else:
|
47 |
-
return 'other'
|
48 |
|
49 |
def create_submission_id() -> str:
|
50 |
"""Create unique submission ID."""
|
51 |
-
return datetime.datetime.now().strftime("%Y%m%d_%H%M%
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
-
def
|
54 |
-
"""
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# src/utils.py
|
2 |
import re
|
3 |
import datetime
|
4 |
+
import pandas as pd
|
5 |
+
from typing import Dict, List, Tuple, Set
|
6 |
+
from config import ALL_UG40_LANGUAGES, LANGUAGE_NAMES, GOOGLE_SUPPORTED_LANGUAGES
|
7 |
+
|
8 |
+
def get_all_language_pairs() -> List[Tuple[str, str]]:
|
9 |
+
"""Get all possible UG40 language pairs."""
|
10 |
+
pairs = []
|
11 |
+
for src in ALL_UG40_LANGUAGES:
|
12 |
+
for tgt in ALL_UG40_LANGUAGES:
|
13 |
+
if src != tgt:
|
14 |
+
pairs.append((src, tgt))
|
15 |
+
return pairs
|
16 |
+
|
17 |
+
def get_google_comparable_pairs() -> List[Tuple[str, str]]:
|
18 |
+
"""Get language pairs that can be compared with Google Translate."""
|
19 |
+
pairs = []
|
20 |
+
for src in GOOGLE_SUPPORTED_LANGUAGES:
|
21 |
+
for tgt in GOOGLE_SUPPORTED_LANGUAGES:
|
22 |
+
if src != tgt:
|
23 |
+
pairs.append((src, tgt))
|
24 |
+
return pairs
|
25 |
+
|
26 |
+
def format_language_pair(src: str, tgt: str) -> str:
|
27 |
+
"""Format language pair for display."""
|
28 |
+
src_name = LANGUAGE_NAMES.get(src, src)
|
29 |
+
tgt_name = LANGUAGE_NAMES.get(tgt, tgt)
|
30 |
+
return f"{src_name} → {tgt_name}"
|
31 |
+
|
32 |
+
def validate_language_code(lang: str) -> bool:
|
33 |
+
"""Validate if language code is supported."""
|
34 |
+
return lang in ALL_UG40_LANGUAGES
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
def create_submission_id() -> str:
|
37 |
"""Create unique submission ID."""
|
38 |
+
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
|
39 |
+
|
40 |
+
def sanitize_model_name(name: str) -> str:
|
41 |
+
"""Sanitize model name for display."""
|
42 |
+
if not name:
|
43 |
+
return "Anonymous Model"
|
44 |
+
# Remove special characters, limit length
|
45 |
+
name = re.sub(r'[^\w\-.]', '_', name.strip())
|
46 |
+
return name[:50]
|
47 |
+
|
48 |
+
def format_metric_value(value: float, metric: str) -> str:
|
49 |
+
"""Format metric value for display."""
|
50 |
+
if metric in ['bleu']:
|
51 |
+
return f"{value:.2f}"
|
52 |
+
elif metric in ['cer', 'wer'] and value > 1:
|
53 |
+
return f"{min(value, 1.0):.4f}" # Cap error rates at 1.0
|
54 |
+
else:
|
55 |
+
return f"{value:.4f}"
|
56 |
+
|
57 |
+
def get_language_pair_stats(test_data: pd.DataFrame) -> Dict[str, Dict]:
|
58 |
+
"""Get statistics about language pair coverage in test data."""
|
59 |
+
stats = {}
|
60 |
+
|
61 |
+
for src in ALL_UG40_LANGUAGES:
|
62 |
+
for tgt in ALL_UG40_LANGUAGES:
|
63 |
+
if src != tgt:
|
64 |
+
pair_data = test_data[
|
65 |
+
(test_data['source_language'] == src) &
|
66 |
+
(test_data['target_language'] == tgt)
|
67 |
+
]
|
68 |
+
|
69 |
+
stats[f"{src}_{tgt}"] = {
|
70 |
+
'count': len(pair_data),
|
71 |
+
'google_comparable': src in GOOGLE_SUPPORTED_LANGUAGES and tgt in GOOGLE_SUPPORTED_LANGUAGES,
|
72 |
+
'display_name': format_language_pair(src, tgt)
|
73 |
+
}
|
74 |
+
|
75 |
+
return stats
|
76 |
|
77 |
+
def validate_submission_completeness(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
|
78 |
+
"""Validate that submission covers all required samples."""
|
79 |
+
|
80 |
+
required_ids = set(test_set['sample_id'].astype(str))
|
81 |
+
provided_ids = set(predictions['sample_id'].astype(str))
|
82 |
+
|
83 |
+
missing_ids = required_ids - provided_ids
|
84 |
+
extra_ids = provided_ids - required_ids
|
85 |
+
|
86 |
+
return {
|
87 |
+
'is_complete': len(missing_ids) == 0,
|
88 |
+
'missing_count': len(missing_ids),
|
89 |
+
'extra_count': len(extra_ids),
|
90 |
+
'missing_ids': list(missing_ids)[:10], # First 10 for display
|
91 |
+
'coverage': len(provided_ids & required_ids) / len(required_ids)
|
92 |
+
}
|
93 |
+
|
94 |
+
def calculate_language_pair_coverage(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
|
95 |
+
"""Calculate coverage by language pair."""
|
96 |
+
|
97 |
+
# Merge to get language info
|
98 |
+
merged = test_set.merge(predictions, on='sample_id', how='left', suffixes=('', '_pred'))
|
99 |
+
|
100 |
+
coverage = {}
|
101 |
+
for src in ALL_UG40_LANGUAGES:
|
102 |
+
for tgt in ALL_UG40_LANGUAGES:
|
103 |
+
if src != tgt:
|
104 |
+
pair_data = merged[
|
105 |
+
(merged['source_language'] == src) &
|
106 |
+
(merged['target_language'] == tgt)
|
107 |
+
]
|
108 |
+
|
109 |
+
if len(pair_data) > 0:
|
110 |
+
predicted_count = pair_data['prediction'].notna().sum()
|
111 |
+
coverage[f"{src}_{tgt}"] = {
|
112 |
+
'total': len(pair_data),
|
113 |
+
'predicted': predicted_count,
|
114 |
+
'coverage': predicted_count / len(pair_data)
|
115 |
+
}
|
116 |
+
|
117 |
+
return coverage
|