akera committed on
Commit
fa045d5
·
verified ·
1 Parent(s): aecc3e1

Update src/utils.py

Browse files
Files changed (1) hide show
  1. src/utils.py +199 -58
src/utils.py CHANGED
@@ -2,8 +2,8 @@
2
  import re
3
  import datetime
4
  import pandas as pd
5
- from typing import Dict, List, Tuple, Set
6
- from config import ALL_UG40_LANGUAGES, LANGUAGE_NAMES, GOOGLE_SUPPORTED_LANGUAGES
7
 
8
  def get_all_language_pairs() -> List[Tuple[str, str]]:
9
  """Get all possible UG40 language pairs."""
@@ -25,8 +25,8 @@ def get_google_comparable_pairs() -> List[Tuple[str, str]]:
25
 
26
  def format_language_pair(src: str, tgt: str) -> str:
27
  """Format language pair for display."""
28
- src_name = LANGUAGE_NAMES.get(src, src)
29
- tgt_name = LANGUAGE_NAMES.get(tgt, tgt)
30
  return f"{src_name} → {tgt_name}"
31
 
32
  def validate_language_code(lang: str) -> bool:
@@ -38,80 +38,221 @@ def create_submission_id() -> str:
38
  return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
39
 
40
  def sanitize_model_name(name: str) -> str:
41
- """Sanitize model name for display."""
42
- if not name:
43
- return "Anonymous Model"
 
44
  # Remove special characters, limit length
45
  name = re.sub(r'[^\w\-.]', '_', name.strip())
46
- return name[:50]
 
 
 
 
 
 
 
 
 
47
 
48
  def format_metric_value(value: float, metric: str) -> str:
49
- """Format metric value for display."""
50
- if metric in ['bleu']:
51
- return f"{value:.2f}"
52
- elif metric in ['cer', 'wer'] and value > 1:
53
- return f"{min(value, 1.0):.4f}" # Cap error rates at 1.0
54
- else:
55
- return f"{value:.4f}"
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def get_language_pair_stats(test_data: pd.DataFrame) -> Dict[str, Dict]:
58
  """Get statistics about language pair coverage in test data."""
 
 
 
59
  stats = {}
60
 
61
- for src in ALL_UG40_LANGUAGES:
62
- for tgt in ALL_UG40_LANGUAGES:
63
- if src != tgt:
64
- pair_data = test_data[
65
- (test_data['source_language'] == src) &
66
- (test_data['target_language'] == tgt)
67
- ]
68
-
69
- stats[f"{src}_{tgt}"] = {
70
- 'count': len(pair_data),
71
- 'google_comparable': src in GOOGLE_SUPPORTED_LANGUAGES and tgt in GOOGLE_SUPPORTED_LANGUAGES,
72
- 'display_name': format_language_pair(src, tgt)
73
- }
 
 
 
 
 
 
74
 
75
  return stats
76
 
77
  def validate_submission_completeness(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
78
  """Validate that submission covers all required samples."""
79
 
80
- required_ids = set(test_set['sample_id'].astype(str))
81
- provided_ids = set(predictions['sample_id'].astype(str))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- missing_ids = required_ids - provided_ids
84
- extra_ids = provided_ids - required_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  return {
87
- 'is_complete': len(missing_ids) == 0,
88
- 'missing_count': len(missing_ids),
89
- 'extra_count': len(extra_ids),
90
- 'missing_ids': list(missing_ids)[:10], # First 10 for display
91
- 'coverage': len(provided_ids & required_ids) / len(required_ids)
 
 
 
92
  }
93
 
94
- def calculate_language_pair_coverage(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
95
- """Calculate coverage by language pair."""
 
 
 
 
 
 
 
 
 
96
 
97
- # Merge to get language info
98
- merged = test_set.merge(predictions, on='sample_id', how='left', suffixes=('', '_pred'))
 
99
 
100
- coverage = {}
101
- for src in ALL_UG40_LANGUAGES:
102
- for tgt in ALL_UG40_LANGUAGES:
103
- if src != tgt:
104
- pair_data = merged[
105
- (merged['source_language'] == src) &
106
- (merged['target_language'] == tgt)
107
- ]
108
-
109
- if len(pair_data) > 0:
110
- predicted_count = pair_data['prediction'].notna().sum()
111
- coverage[f"{src}_{tgt}"] = {
112
- 'total': len(pair_data),
113
- 'predicted': predicted_count,
114
- 'coverage': predicted_count / len(pair_data)
115
- }
 
 
116
 
117
- return coverage
 
2
  import re
3
  import datetime
4
  import pandas as pd
5
+ from typing import Dict, List, Tuple, Set, Optional
6
+ from config import ALL_UG40_LANGUAGES, LANGUAGE_NAMES, GOOGLE_SUPPORTED_LANGUAGES, DISPLAY_CONFIG
7
 
8
  def get_all_language_pairs() -> List[Tuple[str, str]]:
9
  """Get all possible UG40 language pairs."""
 
25
 
26
def format_language_pair(src: str, tgt: str) -> str:
    """Return a human-readable "Source → Target" label for a language pair.

    Falls back to the upper-cased language code for any code missing
    from LANGUAGE_NAMES.
    """
    display_names = [LANGUAGE_NAMES.get(code, code.upper()) for code in (src, tgt)]
    return " → ".join(display_names)
31
 
32
  def validate_language_code(lang: str) -> bool:
 
38
  return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
39
 
40
def sanitize_model_name(name: str) -> str:
    """Sanitize model name for display and storage.

    Non-string or empty input yields the "Anonymous_Model" placeholder.
    Otherwise special characters become underscores, underscore runs are
    collapsed, edge underscores are stripped, very short results gain a
    "Model_" prefix, and the final name is capped at 50 characters.
    """
    if not isinstance(name, str) or not name:
        return "Anonymous_Model"

    cleaned = re.sub(r'[^\w\-.]', '_', name.strip())
    cleaned = re.sub(r'_+', '_', cleaned).strip('_')

    # Guarantee a minimally descriptive name.
    if len(cleaned) < 3:
        cleaned = f"Model_{cleaned}"

    return cleaned[:50]  # Limit to 50 characters
57
 
58
def format_metric_value(value: float, metric: str) -> str:
    """Format a metric value for display with metric-appropriate precision.

    Missing values render as "N/A"; coverage_rate renders as a percentage;
    cer/wer above 1.0 are clamped for display. Precision comes from
    DISPLAY_CONFIG['decimal_places'] (default: 4 digits).
    """
    if value is None or pd.isna(value):
        return "N/A"

    try:
        digits = DISPLAY_CONFIG['decimal_places'].get(metric, 4)

        if metric == 'coverage_rate':
            return f"{value:.{digits}%}"
        if metric in ('cer', 'wer') and value > 1:
            # Error rates are capped at 1.0 for display.
            value = min(value, 1.0)
        return f"{value:.{digits}f}"
    except (ValueError, TypeError):
        # Fall back to the raw representation for unformattable values.
        return str(value)
77
 
78
def get_language_pair_stats(test_data: pd.DataFrame) -> Dict[str, Dict]:
    """Get statistics about language pair coverage in test data.

    For every ordered UG40 pair, records the row count, whether both
    sides are Google-translatable, and display metadata under the key
    "src_tgt". Empty input or any unexpected failure yields {}.
    """
    if test_data.empty:
        return {}

    stats = {}
    try:
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src == tgt:
                    continue
                pair_mask = (
                    (test_data['source_language'] == src)
                    & (test_data['target_language'] == tgt)
                )
                stats[f"{src}_{tgt}"] = {
                    'count': int(pair_mask.sum()),
                    'google_comparable': (
                        src in GOOGLE_SUPPORTED_LANGUAGES
                        and tgt in GOOGLE_SUPPORTED_LANGUAGES
                    ),
                    'display_name': format_language_pair(src, tgt),
                    'source_language': src,
                    'target_language': tgt,
                }
    except Exception as e:
        # Best-effort: malformed frames (e.g. missing columns) yield {}.
        print(f"Error calculating language pair stats: {e}")
        return {}

    return stats
106
 
107
def validate_submission_completeness(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Validate that submission covers all required samples.

    Compares the sample_id sets (as strings) and reports completeness,
    missing/extra counts, up to 10 missing ids, and the coverage ratio.
    Empty inputs or unexpected failures yield a zeroed-out report.
    """
    if predictions.empty or test_set.empty:
        return {
            'is_complete': False,
            'missing_count': 0 if test_set.empty else len(test_set),
            'extra_count': 0 if predictions.empty else len(predictions),
            'missing_ids': [],
            'coverage': 0.0,
        }

    try:
        required = set(test_set['sample_id'].astype(str))
        provided = set(predictions['sample_id'].astype(str))

        missing = required - provided
        extra = provided - required
        coverage = len(provided & required) / len(required) if required else 0.0

        return {
            'is_complete': not missing,
            'missing_count': len(missing),
            'extra_count': len(extra),
            'missing_ids': list(missing)[:10],  # First 10 for display
            'coverage': coverage,
        }
    except Exception as e:
        # Best-effort: malformed frames (e.g. no sample_id column) report failure.
        print(f"Error validating submission completeness: {e}")
        return {
            'is_complete': False,
            'missing_count': 0,
            'extra_count': 0,
            'missing_ids': [],
            'coverage': 0.0,
        }
142
+
143
def calculate_language_pair_coverage(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Calculate prediction coverage for every UG40 language pair.

    Left-joins predictions onto the test set by sample_id, then reports,
    for each ordered pair with at least one test row, the row total, the
    number of non-null predictions, and their ratio. Empty inputs or
    unexpected failures yield {}.
    """
    if predictions.empty or test_set.empty:
        return {}

    try:
        # Left join keeps every test row; rows without a prediction stay NaN.
        merged = test_set.merge(predictions, on='sample_id', how='left', suffixes=('', '_pred'))

        coverage = {}
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src == tgt:
                    continue
                pair_rows = merged[
                    (merged['source_language'] == src)
                    & (merged['target_language'] == tgt)
                ]
                if len(pair_rows) == 0:
                    continue
                hits = pair_rows['prediction'].notna().sum()
                coverage[f"{src}_{tgt}"] = {
                    'total': len(pair_rows),
                    'predicted': hits,
                    'coverage': hits / len(pair_rows),
                    'display_name': format_language_pair(src, tgt),
                }
        return coverage
    except Exception as e:
        # Best-effort: malformed frames yield an empty coverage map.
        print(f"Error calculating language pair coverage: {e}")
        return {}
175
+
176
def safe_divide(numerator: float, denominator: float, default: float = 0.0) -> float:
    """Safely divide two numbers, handling edge cases.

    Returns `default` for zero/NaN denominators, NaN numerators,
    non-finite results, and type errors; otherwise the quotient as float.
    """
    import math  # local import keeps this fix self-contained

    try:
        if denominator == 0 or pd.isna(denominator) or pd.isna(numerator):
            return default
        result = numerator / denominator
        # Bug fix: the original called pd.isfinite(), which does not exist in
        # pandas, so every successful division raised AttributeError (an
        # exception NOT in the caught tuple). math.isfinite is the correct check.
        if pd.isna(result) or not math.isfinite(result):
            return default
        return float(result)
    except (TypeError, ValueError, ZeroDivisionError):
        return default
187
+
188
def clean_text_for_evaluation(text: str) -> str:
    """Clean text for evaluation, handling common encoding issues.

    Collapses whitespace runs, trims the ends, and normalizes a few
    typographic characters (non-breaking space, curly quotes) to ASCII.
    Non-string input is stringified ("" for None) without cleaning.
    """
    if not isinstance(text, str):
        return "" if text is None else str(text)

    # Collapse internal whitespace and trim the ends.
    text = re.sub(r'\s+', ' ', text.strip())

    # Normalize characters that commonly leak in from copy/pasted unicode.
    normalization = {
        '\u00a0': ' ',   # non-breaking space
        '\u2019': "'",   # right single quotation mark
        '\u201c': '"',   # left double quotation mark
        '\u201d': '"',   # right double quotation mark
    }
    return text.translate(str.maketrans(normalization))
203
+
204
def get_model_summary_stats(model_results: Dict) -> Dict:
    """Extract summary statistics from model evaluation results.

    Expects a results dict with 'averages' (metric name -> score) and
    optionally 'summary' (count fields); returns {} when 'averages' is
    absent. Missing metrics default to 0.0, counts to 0.
    """
    if not model_results or 'averages' not in model_results:
        return {}

    averages = model_results.get('averages', {})
    summary = model_results.get('summary', {})

    metric_keys = ('quality_score', 'bleu', 'chrf', 'rouge1', 'rougeL')
    stats = {key: averages.get(key, 0.0) for key in metric_keys}
    stats['total_samples'] = summary.get('total_samples', 0)
    stats['language_pairs'] = summary.get('language_pairs_covered', 0)
    stats['google_pairs'] = summary.get('google_comparable_pairs', 0)
    return stats
222
 
223
def generate_model_identifier(model_name: str, author: str) -> str:
    """Generate a unique identifier for a model.

    Joins the sanitized model name, a sanitized author handle
    ("Anonymous" when absent), and a MMDD_HHMM timestamp with underscores.
    """
    clean_name = sanitize_model_name(model_name)
    if author:
        clean_author = re.sub(r'[^\w\-]', '_', author.strip())[:20]
    else:
        clean_author = "Anonymous"
    timestamp = datetime.datetime.now().strftime("%m%d_%H%M")
    return "_".join((clean_name, clean_author, timestamp))
229
+
230
def validate_dataframe_structure(df: pd.DataFrame, required_columns: List[str]) -> Tuple[bool, List[str]]:
    """Validate that a DataFrame has the required structure.

    Returns (is_valid, problems): invalid when the frame is empty or any
    required column is absent, with one human-readable message per problem.
    """
    if df.empty:
        return False, ["DataFrame is empty"]

    present = set(df.columns)
    missing = [name for name in required_columns if name not in present]
    if missing:
        return False, [f"Missing columns: {', '.join(missing)}"]

    return True, []
240
+
241
def format_duration(seconds: float) -> str:
    """Format duration in seconds to a short human-readable string.

    Under a minute -> "12.3s", under an hour -> "4.5m", otherwise "6.7h".
    """
    if seconds >= 3600:
        return f"{seconds / 3600:.1f}h"
    if seconds >= 60:
        return f"{seconds / 60:.1f}m"
    return f"{seconds:.1f}s"
249
+
250
def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """Truncate text to at most max_length characters, appending suffix.

    Non-string input is stringified first. Bug fix: when max_length is
    not larger than len(suffix), the original computed a negative slice
    (text[:max_length - len(suffix)]) and returned a string LONGER than
    max_length; now a plain head slice within the limit is returned.
    Behavior for the normal case (max_length > len(suffix)) is unchanged.
    """
    if not isinstance(text, str):
        text = str(text)

    if len(text) <= max_length:
        return text

    if max_length <= len(suffix):
        # Suffix would not fit; hard cut keeps the result within the limit.
        return text[:max_length]

    return text[:max_length - len(suffix)] + suffix