akera committed on
Commit
b78ec70
·
verified ·
1 Parent(s): aa99a22

Update src/utils.py

Browse files
Files changed (1) hide show
  1. src/utils.py +110 -50
src/utils.py CHANGED
@@ -1,57 +1,117 @@
1
  # src/utils.py
2
  import re
3
  import datetime
4
- from typing import Dict, List, Any
5
- import salt.constants
6
-
7
- def get_language_name(lang_code: str) -> str:
8
- """Get full language name from ISO code."""
9
- if lang_code is None:
10
- return "Unknown"
11
- return salt.constants.SALT_LANGUAGE_NAMES.get(lang_code, str(lang_code))
12
-
13
- def format_model_name(model_path: str) -> str:
14
- """Format model name for display in leaderboard."""
15
- if model_path == 'google-translate':
16
- return 'Google Translate'
17
-
18
- # Extract model name from HuggingFace path
19
- if '/' in model_path:
20
- return model_path.split('/')[-1]
21
- return model_path
22
-
23
- def validate_model_path(model_path: str) -> bool:
24
- """Validate if model path is supported."""
25
- if model_path == 'google-translate':
26
- return True
27
-
28
- # Check if it's a valid HuggingFace model path format
29
- pattern = r'^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$'
30
- return bool(re.match(pattern, model_path)) or not '/' in model_path
31
-
32
- def get_model_type(model_path: str) -> str:
33
- """Determine model type from path."""
34
- model_path_lower = model_path.lower()
35
-
36
- if model_path == 'google-translate':
37
- return 'google-translate'
38
- elif 'gemma' in model_path_lower:
39
- return 'gemma'
40
- elif 'qwen' in model_path_lower:
41
- return 'qwen'
42
- elif 'llama' in model_path_lower:
43
- return 'llama'
44
- elif 'nllb' in model_path_lower:
45
- return 'nllb'
46
- else:
47
- return 'other'
48
 
49
  def create_submission_id() -> str:
50
  """Create unique submission ID."""
51
- return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- def sanitize_input(text: str) -> str:
54
- """Sanitize user input."""
55
- if not text:
56
- return ""
57
- return re.sub(r'[^\w\-./]', '', text.strip())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # src/utils.py
2
  import re
3
  import datetime
4
+ import pandas as pd
5
+ from typing import Dict, List, Tuple, Set
6
+ from config import ALL_UG40_LANGUAGES, LANGUAGE_NAMES, GOOGLE_SUPPORTED_LANGUAGES
7
+
8
def get_all_language_pairs(languages: List[str] = None) -> List[Tuple[str, str]]:
    """Get all possible UG40 language pairs.

    Args:
        languages: Optional list of language codes to pair up. Defaults to
            ALL_UG40_LANGUAGES when omitted, preserving the original behavior.

    Returns:
        List of (source, target) tuples covering every ordered pair with
        source != target.
    """
    if languages is None:
        languages = ALL_UG40_LANGUAGES
    # Ordered pairs: (a, b) and (b, a) are distinct translation directions.
    return [(src, tgt) for src in languages for tgt in languages if src != tgt]
16
+
17
def get_google_comparable_pairs(languages: List[str] = None) -> List[Tuple[str, str]]:
    """Get language pairs that can be compared with Google Translate.

    Args:
        languages: Optional list of language codes to pair up. Defaults to
            GOOGLE_SUPPORTED_LANGUAGES when omitted, preserving the original
            behavior.

    Returns:
        List of (source, target) tuples covering every ordered pair with
        source != target.
    """
    if languages is None:
        languages = GOOGLE_SUPPORTED_LANGUAGES
    # Ordered pairs: both translation directions are listed.
    return [(src, tgt) for src in languages for tgt in languages if src != tgt]
25
+
26
def format_language_pair(src: str, tgt: str) -> str:
    """Format a source/target language pair for display.

    Unknown codes fall back to the code itself when not present in
    LANGUAGE_NAMES.
    """
    source_label = LANGUAGE_NAMES.get(src, src)
    target_label = LANGUAGE_NAMES.get(tgt, tgt)
    return f"{source_label} {target_label}"
31
+
32
def validate_language_code(lang: str) -> bool:
    """Return True when *lang* is one of the supported UG40 language codes."""
    is_supported = lang in ALL_UG40_LANGUAGES
    return is_supported
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
def create_submission_id() -> str:
    """Create a unique submission ID: timestamp with millisecond precision."""
    now = datetime.datetime.now()
    # microsecond // 1000 zero-padded to 3 digits is equivalent to taking
    # %f (6-digit microseconds) and dropping the last three characters.
    return f"{now:%Y%m%d_%H%M%S}_{now.microsecond // 1000:03d}"
39
+
40
def sanitize_model_name(name: str) -> str:
    """Sanitize a model name for display.

    Falsy input (empty string / None) yields the placeholder
    "Anonymous Model". Otherwise the name is stripped, every character
    outside word chars, '-' and '.' is replaced by '_', and the result is
    capped at 50 characters.
    """
    if not name:
        return "Anonymous Model"
    cleaned = re.sub(r'[^\w\-.]', '_', name.strip())
    # Keep display names bounded.
    return cleaned[:50]
47
+
48
def format_metric_value(value: float, metric: str) -> str:
    """Format a metric value for display.

    BLEU is shown with two decimals; error rates (cer/wer) above 1 are
    capped at 1.0; everything else gets four decimals.
    """
    if metric == 'bleu':
        return f"{value:.2f}"
    if metric in ('cer', 'wer') and value > 1:
        # Error rates above 1.0 are capped for display.
        return f"{min(value, 1.0):.4f}"
    return f"{value:.4f}"
56
+
57
def get_language_pair_stats(test_data: pd.DataFrame) -> Dict[str, Dict]:
    """Get statistics about language pair coverage in test data.

    Args:
        test_data: DataFrame with 'source_language' and 'target_language'
            columns (one row per test sample).

    Returns:
        Mapping "src_tgt" -> {'count', 'google_comparable', 'display_name'}
        for every ordered UG40 pair; 'count' is 0 when the pair is absent.
    """
    # Count each (source, target) combination once instead of re-filtering
    # the whole DataFrame for every pair (O(P*N) -> O(N)).
    pair_counts = test_data.groupby(['source_language', 'target_language']).size()

    stats = {}
    for src in ALL_UG40_LANGUAGES:
        for tgt in ALL_UG40_LANGUAGES:
            if src != tgt:
                stats[f"{src}_{tgt}"] = {
                    # int() keeps the original plain-int type (len() result).
                    'count': int(pair_counts.get((src, tgt), 0)),
                    'google_comparable': src in GOOGLE_SUPPORTED_LANGUAGES and tgt in GOOGLE_SUPPORTED_LANGUAGES,
                    'display_name': format_language_pair(src, tgt),
                }

    return stats
76
 
77
def validate_submission_completeness(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Validate that a submission covers all required samples.

    Args:
        predictions: DataFrame with a 'sample_id' column.
        test_set: Reference DataFrame with a 'sample_id' column.

    Returns:
        Dict with:
            'is_complete': True when no required sample is missing,
            'missing_count' / 'extra_count': set-difference sizes,
            'missing_ids': first 10 missing ids (sorted, for deterministic display),
            'coverage': fraction of required ids provided (1.0 for an empty test set).
    """
    required_ids = set(test_set['sample_id'].astype(str))
    provided_ids = set(predictions['sample_id'].astype(str))

    missing_ids = required_ids - provided_ids
    extra_ids = provided_ids - required_ids

    # Guard against an empty test set: everything is trivially covered, and
    # dividing by len(required_ids) would raise ZeroDivisionError.
    if required_ids:
        coverage = len(provided_ids & required_ids) / len(required_ids)
    else:
        coverage = 1.0

    return {
        'is_complete': len(missing_ids) == 0,
        'missing_count': len(missing_ids),
        'extra_count': len(extra_ids),
        # Sorted so repeated runs display the same sample of missing ids
        # (set iteration order is arbitrary).
        'missing_ids': sorted(missing_ids)[:10],
        'coverage': coverage,
    }
93
+
94
def calculate_language_pair_coverage(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Calculate prediction coverage by language pair.

    Args:
        predictions: DataFrame with 'sample_id' and 'prediction' columns.
        test_set: DataFrame with 'sample_id', 'source_language' and
            'target_language' columns.

    Returns:
        Mapping "src_tgt" -> {'total', 'predicted', 'coverage'} for each
        ordered UG40 pair that has at least one test sample.
    """
    # Left-merge so every test row survives even without a prediction.
    merged = test_set.merge(predictions, on='sample_id', how='left', suffixes=('', '_pred'))

    # Group once instead of re-filtering the merged frame for every pair
    # (O(P*N) -> O(N)); the UG40 loop below keeps the original key order.
    grouped = {key: frame for key, frame in merged.groupby(['source_language', 'target_language'])}

    coverage = {}
    for src in ALL_UG40_LANGUAGES:
        for tgt in ALL_UG40_LANGUAGES:
            if src != tgt:
                pair_data = grouped.get((src, tgt))
                if pair_data is not None and len(pair_data) > 0:
                    # int() converts the numpy count to a plain, JSON-friendly int.
                    predicted_count = int(pair_data['prediction'].notna().sum())
                    coverage[f"{src}_{tgt}"] = {
                        'total': len(pair_data),
                        'predicted': predicted_count,
                        'coverage': predicted_count / len(pair_data),
                    }

    return coverage