Spaces:
Sleeping
Sleeping
Update src/utils.py
Browse files- src/utils.py +199 -58
src/utils.py
CHANGED
@@ -2,8 +2,8 @@
|
|
2 |
import re
|
3 |
import datetime
|
4 |
import pandas as pd
|
5 |
-
from typing import Dict, List, Tuple, Set
|
6 |
-
from config import ALL_UG40_LANGUAGES, LANGUAGE_NAMES, GOOGLE_SUPPORTED_LANGUAGES
|
7 |
|
8 |
def get_all_language_pairs() -> List[Tuple[str, str]]:
|
9 |
"""Get all possible UG40 language pairs."""
|
@@ -25,8 +25,8 @@ def get_google_comparable_pairs() -> List[Tuple[str, str]]:
|
|
25 |
|
26 |
def format_language_pair(src: str, tgt: str) -> str:
|
27 |
"""Format language pair for display."""
|
28 |
-
src_name = LANGUAGE_NAMES.get(src, src)
|
29 |
-
tgt_name = LANGUAGE_NAMES.get(tgt, tgt)
|
30 |
return f"{src_name} → {tgt_name}"
|
31 |
|
32 |
def validate_language_code(lang: str) -> bool:
|
@@ -38,80 +38,221 @@ def create_submission_id() -> str:
|
|
38 |
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
|
39 |
|
40 |
def sanitize_model_name(name: str) -> str:
|
41 |
-
"""Sanitize model name for display."""
|
42 |
-
if not name:
|
43 |
-
return "
|
|
|
44 |
# Remove special characters, limit length
|
45 |
name = re.sub(r'[^\w\-.]', '_', name.strip())
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
def format_metric_value(value: float, metric: str) -> str:
|
49 |
-
"""Format metric value for display."""
|
50 |
-
if
|
51 |
-
return
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
def get_language_pair_stats(test_data: pd.DataFrame) -> Dict[str, Dict]:
|
58 |
"""Get statistics about language pair coverage in test data."""
|
|
|
|
|
|
|
59 |
stats = {}
|
60 |
|
61 |
-
|
62 |
-
for
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
return stats
|
76 |
|
77 |
def validate_submission_completeness(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
|
78 |
"""Validate that submission covers all required samples."""
|
79 |
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
return {
|
87 |
-
'
|
88 |
-
'
|
89 |
-
'
|
90 |
-
'
|
91 |
-
'
|
|
|
|
|
|
|
92 |
}
|
93 |
|
94 |
-
def
|
95 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
-
|
98 |
-
|
|
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
116 |
|
117 |
-
return
|
|
|
2 |
import re
|
3 |
import datetime
|
4 |
import pandas as pd
|
5 |
+
from typing import Dict, List, Tuple, Set, Optional
|
6 |
+
from config import ALL_UG40_LANGUAGES, LANGUAGE_NAMES, GOOGLE_SUPPORTED_LANGUAGES, DISPLAY_CONFIG
|
7 |
|
8 |
def get_all_language_pairs() -> List[Tuple[str, str]]:
|
9 |
"""Get all possible UG40 language pairs."""
|
|
|
25 |
|
26 |
def format_language_pair(src: str, tgt: str) -> str:
    """Render a source/target language pair as ``"<source> → <target>"``.

    Each code is looked up in ``LANGUAGE_NAMES``; a code with no entry is
    shown as the upper-cased code itself.
    """
    # The upper-cased fallbacks are computed up front, exactly as a
    # ``dict.get`` default argument would be.
    src_fallback = src.upper()
    tgt_fallback = tgt.upper()
    labels = (
        LANGUAGE_NAMES.get(src, src_fallback),
        LANGUAGE_NAMES.get(tgt, tgt_fallback),
    )
    return f"{labels[0]} → {labels[1]}"
|
31 |
|
32 |
def validate_language_code(lang: str) -> bool:
|
|
|
38 |
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
|
39 |
|
40 |
def sanitize_model_name(name: str) -> str:
    """Sanitize a model name for display and storage.

    Every character other than word characters, ``-`` and ``.`` is replaced
    by an underscore, runs of underscores are collapsed, and leading/trailing
    underscores are removed. Names shorter than three characters are prefixed
    with ``Model_`` and the result is capped at 50 characters.

    Args:
        name: Raw model name as supplied by the submitter.

    Returns:
        The cleaned name, or ``"Anonymous_Model"`` when ``name`` is empty, is
        not a string, or has nothing left once cleaned.
    """
    fallback = "Anonymous_Model"
    if not name or not isinstance(name, str):
        return fallback

    # Replace disallowed characters, then tidy the underscores that produces
    cleaned = re.sub(r'[^\w\-.]', '_', name.strip())
    cleaned = re.sub(r'_+', '_', cleaned).strip('_')

    if not cleaned:
        # A name made only of disallowed characters would otherwise come out
        # as the dangling prefix "Model_".
        return fallback

    # Ensure minimum length
    if len(cleaned) < 3:
        cleaned = f"Model_{cleaned}"

    # Cutting at 50 characters can expose an underscore at the end; strip it
    # so the "no trailing underscore" guarantee still holds.
    return cleaned[:50].rstrip('_')
|
57 |
|
58 |
def format_metric_value(value: float, metric: str) -> str:
    """Format a metric value for display with the configured precision.

    The number of decimal places is read from
    ``DISPLAY_CONFIG['decimal_places']`` (4 when the metric has no entry).
    ``coverage_rate`` is rendered as a percentage; ``cer`` and ``wer`` are
    capped at 1.0 for display; every other metric is rendered as a plain
    fixed-point number.

    Args:
        value: Metric value; ``None`` and NaN-like values are treated as
            missing.
        metric: Metric name used to select precision and format.

    Returns:
        The formatted string, ``"N/A"`` for a missing value, or ``str(value)``
        when the value cannot be formatted as a number.
    """
    try:
        # The missing-value guard lives inside the try block: pd.isna() on a
        # list/array returns an array whose truth value raises ValueError,
        # which should fall through to the str() fallback rather than escape.
        if value is None or pd.isna(value):
            return "N/A"

        precision = DISPLAY_CONFIG['decimal_places'].get(metric, 4)

        if metric == 'coverage_rate':
            return f"{value:.{precision}%}"

        shown = value
        if metric in ('cer', 'wer'):
            # Cap error rates at 1.0 for display
            shown = min(value, 1.0)
        return f"{shown:.{precision}f}"
    except (ValueError, TypeError):
        return str(value)
|
77 |
|
78 |
def get_language_pair_stats(test_data: pd.DataFrame) -> Dict[str, Dict]:
    """Get statistics about language pair coverage in test data.

    Args:
        test_data: Test set with ``source_language`` and ``target_language``
            columns.

    Returns:
        A dict keyed by ``"<src>_<tgt>"`` for every ordered pair of distinct
        languages in ``ALL_UG40_LANGUAGES``. Each value holds the row
        ``count`` for that pair, whether both languages are in
        ``GOOGLE_SUPPORTED_LANGUAGES`` (``google_comparable``), the
        ``display_name`` and the two language codes. An empty dict is
        returned for empty input or when the statistics cannot be computed.
    """
    if test_data.empty:
        return {}

    stats = {}

    try:
        # Count every pair in one grouped pass instead of re-filtering the
        # whole frame once per language pair.
        pair_counts = (
            test_data.groupby(['source_language', 'target_language'])
            .size()
            .to_dict()
        )

        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src == tgt:
                    continue
                stats[f"{src}_{tgt}"] = {
                    # int() keeps the count a plain Python int
                    'count': int(pair_counts.get((src, tgt), 0)),
                    'google_comparable': (
                        src in GOOGLE_SUPPORTED_LANGUAGES
                        and tgt in GOOGLE_SUPPORTED_LANGUAGES
                    ),
                    'display_name': format_language_pair(src, tgt),
                    'source_language': src,
                    'target_language': tgt,
                }
    except Exception as e:
        # Best-effort helper: report the problem and return no statistics.
        print(f"Error calculating language pair stats: {e}")
        return {}

    return stats
|
106 |
|
107 |
def validate_submission_completeness(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Validate that a submission covers all required samples.

    Sample ids from both frames are compared as strings, so an integer id in
    one frame matches the same id stored as text in the other.

    Args:
        predictions: Submitted predictions with a ``sample_id`` column.
        test_set: Reference test set with a ``sample_id`` column.

    Returns:
        A dict with:
            ``is_complete``: True when no required id is missing.
            ``missing_count``: number of required ids without a prediction.
            ``extra_count``: number of submitted ids not in the test set.
            ``missing_ids``: up to ten missing ids, in sorted order.
            ``coverage``: fraction of required ids that were submitted.
        When either frame is empty, row counts are reported and coverage is
        0.0. If validation fails unexpectedly, an all-zero, incomplete result
        is returned.
    """
    if predictions.empty or test_set.empty:
        return {
            'is_complete': False,
            'missing_count': len(test_set) if not test_set.empty else 0,
            'extra_count': len(predictions) if not predictions.empty else 0,
            'missing_ids': [],
            'coverage': 0.0,
        }

    try:
        required_ids = set(test_set['sample_id'].astype(str))
        provided_ids = set(predictions['sample_id'].astype(str))

        missing_ids = required_ids - provided_ids
        extra_ids = provided_ids - required_ids

        return {
            'is_complete': len(missing_ids) == 0,
            'missing_count': len(missing_ids),
            'extra_count': len(extra_ids),
            # Sort before slicing: set iteration order is arbitrary, so an
            # unsorted slice would show a different sample on every run.
            'missing_ids': sorted(missing_ids)[:10],  # First 10 for display
            'coverage': (
                len(provided_ids & required_ids) / len(required_ids)
                if required_ids else 0.0
            ),
        }
    except Exception as e:
        # Best-effort helper: report the problem and return a failed result.
        print(f"Error validating submission completeness: {e}")
        return {
            'is_complete': False,
            'missing_count': 0,
            'extra_count': 0,
            'missing_ids': [],
            'coverage': 0.0,
        }
|
142 |
+
|
143 |
+
def calculate_language_pair_coverage(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Calculate prediction coverage by language pair.

    A test row counts as predicted when ``predictions`` contains a non-null
    ``prediction`` for its ``sample_id``. Ids are compared as strings, the
    same way ``validate_submission_completeness`` compares them, so differing
    id dtypes between the two frames do not prevent matching. Duplicate
    prediction rows for one id are counted once.

    Args:
        predictions: Submission with ``sample_id`` and ``prediction`` columns.
        test_set: Reference data with ``sample_id``, ``source_language`` and
            ``target_language`` columns.

    Returns:
        A dict keyed by ``"<src>_<tgt>"`` for each pair of distinct languages
        in ``ALL_UG40_LANGUAGES`` that has at least one test row. Each value
        holds ``total``, ``predicted``, ``coverage`` (predicted / total) and
        ``display_name``. An empty dict is returned for empty input or when
        the coverage cannot be computed.
    """
    if predictions.empty or test_set.empty:
        return {}

    try:
        # Ids that actually carry a prediction, normalised to strings
        answered_ids = set(
            predictions.loc[predictions['prediction'].notna(), 'sample_id'].astype(str)
        )

        # One row per test sample; a membership test replaces the left merge
        # so duplicate prediction rows cannot inflate the per-pair totals.
        frame = pd.DataFrame({
            'src': test_set['source_language'].to_numpy(),
            'tgt': test_set['target_language'].to_numpy(),
            'hit': test_set['sample_id'].astype(str).isin(answered_ids).to_numpy(),
        })
        summary = frame.groupby(['src', 'tgt'])['hit'].agg(['size', 'sum'])
        totals = summary['size'].to_dict()
        hits = summary['sum'].to_dict()

        coverage = {}
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src == tgt:
                    continue
                total = int(totals.get((src, tgt), 0))
                if total == 0:
                    continue
                # Plain ints/floats rather than numpy scalars
                predicted = int(hits.get((src, tgt), 0))
                coverage[f"{src}_{tgt}"] = {
                    'total': total,
                    'predicted': predicted,
                    'coverage': predicted / total,
                    'display_name': format_language_pair(src, tgt),
                }

        return coverage
    except Exception as e:
        # Best-effort helper: report the problem and return no coverage data.
        print(f"Error calculating language pair coverage: {e}")
        return {}
|
175 |
+
|
176 |
+
def safe_divide(numerator: float, denominator: float, default: float = 0.0) -> float:
    """Safely divide two numbers, handling edge cases.

    Args:
        numerator: Dividend.
        denominator: Divisor.
        default: Value returned when the division is not meaningful.

    Returns:
        ``numerator / denominator`` as a float, or ``default`` when the
        denominator is zero, either operand is missing (None/NaN), the
        operands cannot be divided, or the result is NaN or infinite.
    """
    # pandas has no top-level ``isfinite``; calling ``pd.isfinite`` raised an
    # AttributeError on every otherwise-successful division. Imported locally
    # so the block does not depend on a module-level ``import math``.
    import math

    try:
        if denominator == 0 or pd.isna(denominator) or pd.isna(numerator):
            return default
        result = float(numerator / denominator)
        # math.isfinite is False for NaN as well as +/- infinity
        if not math.isfinite(result):
            return default
        return result
    except (TypeError, ValueError, ZeroDivisionError):
        return default
|
187 |
+
|
188 |
+
def clean_text_for_evaluation(text: str) -> str:
    """Clean text for evaluation, handling common encoding issues.

    Whitespace runs (including non-breaking spaces) are collapsed to a single
    space, surrounding whitespace is removed, and curly quotation marks are
    replaced by their ASCII equivalents.

    Args:
        text: Text to clean. Non-string values are accepted.

    Returns:
        The cleaned string. ``None`` and missing values (NaN and similar)
        yield ``""``; any other non-string value is returned as ``str(text)``.
    """
    if not isinstance(text, str):
        # A missing cell arrives from pandas as NaN; turning it into the
        # literal text "nan" would have it scored as a real prediction.
        if text is None or (pd.api.types.is_scalar(text) and pd.isna(text)):
            return ""
        return str(text)

    # Remove extra whitespace. For str patterns ``\s`` also matches the
    # non-breaking space (U+00A0), so no separate replacement is needed.
    text = re.sub(r'\s+', ' ', text.strip())

    # Normalise curly quotes, including the left single quote U+2018
    quote_table = str.maketrans({
        '\u2018': "'",  # Left single quotation mark
        '\u2019': "'",  # Right single quotation mark
        '\u201c': '"',  # Left double quotation mark
        '\u201d': '"',  # Right double quotation mark
    })
    return text.translate(quote_table)
|
203 |
+
|
204 |
+
def get_model_summary_stats(model_results: Dict) -> Dict:
    """Extract summary statistics from model evaluation results.

    Args:
        model_results: Evaluation results; the ``averages`` and ``summary``
            entries are read.

    Returns:
        A dict with ``quality_score``, ``bleu``, ``chrf``, ``rouge1`` and
        ``rougeL`` taken from ``averages`` (0.0 when absent) and
        ``total_samples``, ``language_pairs`` and ``google_pairs`` taken from
        ``summary`` (0 when absent). An empty dict is returned when
        ``model_results`` is empty or has no ``averages`` key.
    """
    if not model_results or 'averages' not in model_results:
        return {}

    # ``or {}`` also covers an entry that is present but None, which would
    # otherwise fail on the ``.get`` calls below.
    averages = model_results.get('averages') or {}
    summary = model_results.get('summary') or {}

    return {
        'quality_score': averages.get('quality_score', 0.0),
        'bleu': averages.get('bleu', 0.0),
        'chrf': averages.get('chrf', 0.0),
        'rouge1': averages.get('rouge1', 0.0),
        'rougeL': averages.get('rougeL', 0.0),
        'total_samples': summary.get('total_samples', 0),
        'language_pairs': summary.get('language_pairs_covered', 0),
        'google_pairs': summary.get('google_comparable_pairs', 0),
    }
|
222 |
|
223 |
+
def generate_model_identifier(model_name: str, author: str) -> str:
    """Generate an identifier for a model.

    The identifier is ``"<name>_<author>_<MMDD_HHMM>"``. The name is cleaned
    with ``sanitize_model_name``; in the author, every character other than
    word characters and ``-`` becomes an underscore and the result is cut to
    20 characters. The timestamp has minute resolution, so two identifiers
    built for the same name and author within one minute are identical.

    Args:
        model_name: Raw model name.
        author: Raw author name; may be empty.

    Returns:
        The identifier string. ``"Anonymous"`` stands in for an author that
        is empty, not a string, or has no usable characters.
    """
    clean_name = sanitize_model_name(model_name)

    clean_author = ""
    if isinstance(author, str):
        clean_author = re.sub(r'[^\w\-]', '_', author.strip())[:20]
    # A whitespace-only or symbol-only author cleans down to nothing but
    # underscores; fall back instead of emitting an empty author segment.
    if not clean_author.strip('_'):
        clean_author = "Anonymous"

    timestamp = datetime.datetime.now().strftime("%m%d_%H%M")
    return f"{clean_name}_{clean_author}_{timestamp}"
|
229 |
+
|
230 |
+
def validate_dataframe_structure(df: pd.DataFrame, required_columns: List[str]) -> Tuple[bool, List[str]]:
    """Validate that a DataFrame has the required structure.

    Args:
        df: Frame to check. ``None`` is treated like an empty frame so the
            validator reports a problem instead of raising.
        required_columns: Column names that must be present.

    Returns:
        ``(True, [])`` when the frame is non-empty and has every required
        column; otherwise ``(False, [message])`` where the message says the
        frame is empty or lists the missing columns.
    """
    if df is None or df.empty:
        return False, ["DataFrame is empty"]

    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        return False, [f"Missing columns: {', '.join(missing_columns)}"]

    return True, []
|
240 |
+
|
241 |
+
def format_duration(seconds: float) -> str:
    """Format a duration in seconds as a short human-readable string.

    The value is shown with one decimal place in seconds (``"12.3s"``),
    minutes (``"4.5m"``) or hours (``"1.2h"``).

    Args:
        seconds: Duration in seconds.

    Returns:
        The formatted duration.
    """
    # Choose the unit from the value as it will be displayed (one decimal).
    # Comparing the raw value let 59.97 render as "60.0s" and 3599.9 as
    # "60.0m" instead of rolling over to the next unit.
    if round(seconds, 1) < 60:
        return f"{seconds:.1f}s"

    minutes = seconds / 60
    if round(minutes, 1) < 60:
        return f"{minutes:.1f}m"

    return f"{seconds / 3600:.1f}h"
|
249 |
+
|
250 |
+
def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """Truncate text to a maximum length, marking the cut with a suffix.

    Args:
        text: Text to shorten; non-string values are converted with ``str``.
        max_length: Maximum length of the returned string.
        suffix: Marker appended when the text is cut.

    Returns:
        ``text`` unchanged when it already fits; otherwise a string of at most
        ``max_length`` characters ending in ``suffix``. When ``max_length`` is
        too small to hold the suffix, the text is cut without one.
    """
    if not isinstance(text, str):
        text = str(text)

    if len(text) <= max_length:
        return text

    if max_length < len(suffix):
        # No room for the suffix. A negative slice bound here would keep
        # almost the whole text and overshoot max_length, so hard-cut instead.
        return text[:max(max_length, 0)]

    return text[:max_length - len(suffix)] + suffix
|