akera committed on
Commit 4fa2f10 · verified · 1 Parent(s): c1926c2

Rename src/evaluation.py to src/validation.py

Files changed (2)
  1. src/evaluation.py +0 -413
  2. src/validation.py +274 -0
src/evaluation.py DELETED
@@ -1,413 +0,0 @@
- # src/evaluation.py
- import torch
- import numpy as np
- from tqdm.auto import tqdm
- from sacrebleu.metrics import BLEU, CHRF
- from rouge_score import rouge_scorer
- import Levenshtein
- from collections import defaultdict
- from transformers.models.whisper.english_normalizer import BasicTextNormalizer
- import salt.constants
- import datetime
- import os
- from google.cloud import translate_v3
- from config import GOOGLE_LANG_MAP
-
- def setup_google_translate():
-     """Set up a Google Cloud Translation client if credentials are available."""
-     try:
-         # Check if running in an HF Space with credentials
-         if os.getenv("GOOGLE_APPLICATION_CREDENTIALS") or os.getenv("GOOGLE_CLOUD_PROJECT"):
-             client = translate_v3.TranslationServiceClient()
-             project_id = os.getenv("GOOGLE_CLOUD_PROJECT", "sb-gcp-project-01")
-             parent = f"projects/{project_id}/locations/global"
-             return client, parent
-         else:
-             print("Google Cloud credentials not found. Google Translate will not be available.")
-             return None, None
-     except Exception as e:
-         print(f"Error setting up Google Translate: {e}")
-         return None, None
-
- def google_translate_batch(texts, source_langs, target_langs, client, parent):
-     """Translate using the Google Cloud Translation API."""
-     translations = []
-
-     for text, src_lang, tgt_lang in tqdm(zip(texts, source_langs, target_langs),
-                                          total=len(texts), desc="Google Translate"):
-         try:
-             # Map SALT language codes to Google's format
-             src_google = GOOGLE_LANG_MAP.get(src_lang, src_lang)
-             tgt_google = GOOGLE_LANG_MAP.get(tgt_lang, tgt_lang)
-
-             # Check if the language pair is supported
-             supported_langs = ['lg', 'ach', 'sw', 'en']
-             if src_google not in supported_langs or tgt_google not in supported_langs:
-                 translations.append(f"[UNSUPPORTED: {src_lang}->{tgt_lang}]")
-                 continue
-
-             # Make the translation request
-             request = {
-                 "parent": parent,
-                 "contents": [text],
-                 "mime_type": "text/plain",
-                 "source_language_code": src_google,
-                 "target_language_code": tgt_google,
-             }
-
-             response = client.translate_text(request=request)
-             translation = response.translations[0].translated_text
-             translations.append(translation)
-
-         except Exception as e:
-             print(f"Error translating '{text}': {e}")
-             translations.append(f"[ERROR: {str(e)[:50]}]")
-
-     return translations
-
- def get_translation_function(model, tokenizer, model_path):
-     """Get the appropriate translation function based on model type."""
-
-     if model_path == 'google-translate':
-         client, parent = setup_google_translate()
-         if client is None:
-             raise Exception("Google Translate credentials not available")
-
-         def translation_fn(texts, from_langs, to_langs):
-             return google_translate_batch(texts, from_langs, to_langs, client, parent)
-
-         return translation_fn
-
-     elif 'gemma' in str(type(model)).lower() or 'gemma' in model_path.lower():
-         return get_gemma_translation_fn(model, tokenizer)
-
-     elif hasattr(model, 'base_model') and hasattr(model.base_model, 'model') and 'Qwen2ForCausalLM' in str(type(model.base_model.model)):
-         return get_qwen_translation_fn(model, tokenizer)
-
-     elif 'm2m_100' in str(type(model)).lower():
-         return get_nllb_translation_fn(model, tokenizer)
-
-     elif hasattr(model, 'base_model') and hasattr(model.base_model, 'model') and 'LlamaForCausalLM' in str(type(model.base_model.model)):
-         return get_llama_translation_fn(model, tokenizer)
-
-     else:
-         # Generic function for other models
-         return get_generic_translation_fn(model, tokenizer)
-
- def get_gemma_translation_fn(model, tokenizer):
-     """Translation function for Gemma models."""
-     def translation_fn(texts, from_langs, to_langs):
-         SYSTEM_MESSAGE = 'You are a linguist and translation assistant specialising in Ugandan languages.'
-         translations = []
-         batch_size = 4
-         device = next(model.parameters()).device
-
-         instructions = [
-             f'Translate from {salt.constants.SALT_LANGUAGE_NAMES[from_lang]} '
-             f'to {salt.constants.SALT_LANGUAGE_NAMES[to_lang]}: {text}'
-             for text, from_lang, to_lang in zip(texts, from_langs, to_langs)
-         ]
-
-         for i in tqdm(range(0, len(instructions), batch_size), desc="Generating translations"):
-             batch_instructions = instructions[i:i + batch_size]
-             messages_list = [
-                 [
-                     {"role": "system", "content": SYSTEM_MESSAGE},
-                     {"role": "user", "content": instruction}
-                 ] for instruction in batch_instructions
-             ]
-
-             prompts = [
-                 tokenizer.apply_chat_template(
-                     messages, tokenize=False, add_generation_prompt=True
-                 ) for messages in messages_list
-             ]
-
-             inputs = tokenizer(
-                 prompts, return_tensors="pt",
-                 padding=True, padding_side='left',
-                 max_length=512, truncation=True
-             ).to(device)
-
-             with torch.no_grad():
-                 outputs = model.generate(
-                     **inputs,
-                     max_new_tokens=100,
-                     temperature=0.5,
-                     num_beams=5,
-                     do_sample=True,
-                     no_repeat_ngram_size=5,
-                     pad_token_id=tokenizer.eos_token_id
-                 )
-
-             for j in range(len(outputs)):
-                 translation = tokenizer.decode(
-                     outputs[j, inputs['input_ids'].shape[1]:],
-                     skip_special_tokens=True
-                 )
-                 translations.append(translation)
-
-         return translations
-
-     return translation_fn
-
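For reference, the prompt construction shared by the Gemma, Qwen and Llama paths can be checked in isolation. A minimal sketch, assuming any instruction-tuned tokenizer whose chat template accepts a system turn (the checkpoint name is illustrative, not the one used in this repo):

from transformers import AutoTokenizer

# Illustrative checkpoint; any chat-template tokenizer behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = [
    {"role": "system", "content": "You are a Ugandan language assistant."},
    {"role": "user", "content": "Translate from English to Luganda: Good morning"},
]
# Renders the turns into the model-specific prompt string, ending with the
# assistant marker so generation continues as the reply.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)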
- def get_qwen_translation_fn(model, tokenizer):
-     """Translation function for Qwen models."""
-     def translation_fn(texts, from_langs, to_langs):
-         SYSTEM_MESSAGE = 'You are a Ugandan language assistant.'
-         translations = []
-         batch_size = 8
-         device = next(model.parameters()).device
-
-         instructions = [
-             f'Translate from {salt.constants.SALT_LANGUAGE_NAMES.get(from_lang, from_lang)} '
-             f'to {salt.constants.SALT_LANGUAGE_NAMES.get(to_lang, to_lang)}: {text}'
-             for text, from_lang, to_lang in zip(texts, from_langs, to_langs)
-         ]
-
-         for i in tqdm(range(0, len(instructions), batch_size), desc="Generating translations"):
-             batch_instructions = instructions[i:i + batch_size]
-             messages_list = [
-                 [
-                     {"role": "system", "content": SYSTEM_MESSAGE},
-                     {"role": "user", "content": instruction}
-                 ] for instruction in batch_instructions
-             ]
-
-             prompts = [
-                 tokenizer.apply_chat_template(
-                     messages, tokenize=False, add_generation_prompt=True
-                 ) for messages in messages_list
-             ]
-
-             inputs = tokenizer(
-                 prompts, return_tensors="pt",
-                 padding=True, padding_side='left', truncation=True
-             ).to(device)
-
-             with torch.no_grad():
-                 outputs = model.generate(
-                     **inputs, max_new_tokens=100,
-                     temperature=0.01,
-                     pad_token_id=tokenizer.eos_token_id
-                 )
-
-             for j in range(len(outputs)):
-                 translation = tokenizer.decode(
-                     outputs[j, inputs['input_ids'].shape[1]:],
-                     skip_special_tokens=True
-                 )
-                 translations.append(translation)
-
-         return translations
-
-     return translation_fn
-
- def get_nllb_translation_fn(model, tokenizer):
-     """Translation function for NLLB models."""
-     def translation_fn(texts, source_langs, target_langs):
-         translations = []
-         language_tokens = salt.constants.SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION
-         device = next(model.parameters()).device
-
-         for text, source_language, target_language in tqdm(
-                 zip(texts, source_langs, target_langs), total=len(texts), desc="NLLB Translation"):
-
-             inputs = tokenizer(text, return_tensors="pt").to(device)
-             # Overwrite the leading token with the source-language token
-             inputs['input_ids'][0][0] = language_tokens[source_language]
-
-             with torch.no_grad():
-                 translated_tokens = model.generate(
-                     **inputs,
-                     forced_bos_token_id=language_tokens[target_language],
-                     max_length=100,
-                     num_beams=5,
-                 )
-
-             result = tokenizer.batch_decode(
-                 translated_tokens, skip_special_tokens=True)[0]
-             translations.append(result)
-
-         return translations
-
-     return translation_fn
-
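For comparison, stock NLLB decoding sets the source language on the tokenizer and forces the target via forced_bos_token_id, which is the role the SALT token table plays above. A sketch with illustrative checkpoint and language codes:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Checkpoint and language codes are illustrative.
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang="eng_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

inputs = tokenizer("Good morning", return_tensors="pt")
out = model.generate(
    **inputs,
    # Force decoding to start with the target-language token.
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("lug_Latn"),
    max_length=100,
    num_beams=5,
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])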
- def get_llama_translation_fn(model, tokenizer):
-     """Translation function for Llama models."""
-     def translation_fn(texts, from_langs, to_langs):
-         DATE_TODAY = datetime.datetime.now().strftime("%d %b %Y")
-         SYSTEM_MESSAGE = ''
-         translations = []
-         batch_size = 8
-         device = next(model.parameters()).device
-
-         instructions = [
-             f'Translate from {salt.constants.SALT_LANGUAGE_NAMES.get(from_lang, from_lang)} '
-             f'to {salt.constants.SALT_LANGUAGE_NAMES.get(to_lang, to_lang)}: {text}'
-             for text, from_lang, to_lang in zip(texts, from_langs, to_langs)
-         ]
-
-         for i in tqdm(range(0, len(instructions), batch_size), desc="Llama Translation"):
-             batch_instructions = instructions[i:i + batch_size]
-             messages_list = [
-                 [
-                     {"role": "system", "content": SYSTEM_MESSAGE},
-                     {"role": "user", "content": instruction}
-                 ] for instruction in batch_instructions
-             ]
-
-             prompts = [
-                 tokenizer.apply_chat_template(
-                     messages, tokenize=False, add_generation_prompt=True,
-                     date_string=DATE_TODAY,
-                 ) for messages in messages_list
-             ]
-
-             inputs = tokenizer(
-                 prompts, return_tensors="pt",
-                 padding=True, padding_side='left',
-             ).to(device)
-
-             with torch.no_grad():
-                 outputs = model.generate(
-                     **inputs, max_new_tokens=100,
-                     temperature=0.01,
-                     pad_token_id=tokenizer.eos_token_id
-                 )
-
-             for j in range(len(outputs)):
-                 translation = tokenizer.decode(
-                     outputs[j, inputs['input_ids'].shape[1]:],
-                     skip_special_tokens=True
-                 )
-                 translations.append(translation)
-
-         return translations
-
-     return translation_fn
-
- def get_generic_translation_fn(model, tokenizer):
-     """Generic translation function for unknown model types."""
-     def translation_fn(texts, from_langs, to_langs):
-         translations = []
-         device = next(model.parameters()).device
-
-         for text, from_lang, to_lang in tqdm(zip(texts, from_langs, to_langs),
-                                              desc="Generic Translation"):
-             prompt = f"Translate from {from_lang} to {to_lang}: {text}"
-             inputs = tokenizer(prompt, return_tensors="pt").to(device)
-
-             with torch.no_grad():
-                 outputs = model.generate(
-                     **inputs,
-                     max_new_tokens=100,
-                     temperature=0.7,
-                     pad_token_id=tokenizer.eos_token_id
-                 )
-
-             translation = tokenizer.decode(
-                 outputs[0, inputs['input_ids'].shape[1]:],
-                 skip_special_tokens=True
-             )
-             translations.append(translation)
-
-         return translations
-
-     return translation_fn
-
- def calculate_metrics(reference: str, prediction: str) -> dict:
-     """Calculate multiple translation quality metrics."""
-     bleu = BLEU(effective_order=True)
-     bleu_score = bleu.sentence_score(prediction, [reference]).score
-
-     chrf = CHRF()
-     chrf_score = chrf.sentence_score(prediction, [reference]).score / 100.0
-
-     cer = Levenshtein.distance(reference, prediction) / max(len(reference), 1)
-
-     ref_words = reference.split()
-     pred_words = prediction.split()
-     wer = Levenshtein.distance(ref_words, pred_words) / max(len(ref_words), 1)
-
-     len_ratio = len(prediction) / max(len(reference), 1)
-
-     metrics = {
-         "bleu": bleu_score,
-         "chrf": chrf_score,
-         "cer": cer,
-         "wer": wer,
-         "len_ratio": len_ratio,
-     }
-
-     try:
-         scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
-         rouge_scores = scorer.score(reference, prediction)
-
-         metrics["rouge1"] = rouge_scores['rouge1'].fmeasure
-         metrics["rouge2"] = rouge_scores['rouge2'].fmeasure
-         metrics["rougeL"] = rouge_scores['rougeL'].fmeasure
-
-         metrics["quality_score"] = (
-             bleu_score / 100 +
-             chrf_score +
-             (1 - cer) +
-             (1 - wer) +
-             rouge_scores['rouge1'].fmeasure +
-             rouge_scores['rougeL'].fmeasure
-         ) / 6
-     except Exception as e:
-         print(f"Error calculating ROUGE metrics: {e}")
-         metrics["quality_score"] = (bleu_score / 100 + chrf_score + (1 - cer) + (1 - wer)) / 4
-
-     return metrics
-
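As a quick illustration of what calculate_metrics returns, a toy reference/prediction pair run through it (the numbers in the comment are approximate):

reference = "the cat sat on the mat"
prediction = "the cat sat on a mat"

m = calculate_metrics(reference, prediction)
# One of six words differs, so wer is about 0.17 and cer stays small; the
# composite quality_score averages six terms normalised to [0, 1]:
# (bleu/100 + chrf + (1 - cer) + (1 - wer) + rouge1 + rougeL) / 6
print({k: round(v, 3) for k, v in m.items()})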
- def evaluate_model_full(model, tokenizer, model_path: str, test_data) -> dict:
-     """Complete model evaluation pipeline."""
-
-     # Get translation function
-     translation_fn = get_translation_function(model, tokenizer, model_path)
-
-     # Generate predictions
-     print("Generating translations...")
-     predictions = translation_fn(
-         list(test_data['source']),
-         list(test_data['source.language']),
-         list(test_data['target.language']),
-     )
-
-     # Calculate metrics by language pair
-     print("Calculating metrics...")
-     # Reset the index so positional prediction indices line up with the rows.
-     test_data = test_data.reset_index(drop=True)
-     translation_subsets = defaultdict(list)
-     for idx, row in test_data.iterrows():
-         direction = row['source.language'] + '_to_' + row['target.language']
-         row_dict = dict(row)
-         row_dict['prediction'] = predictions[idx]
-         translation_subsets[direction].append(row_dict)
-
-     normalizer = BasicTextNormalizer()
-     grouped_metrics = defaultdict(dict)
-
-     for subset in translation_subsets.keys():
-         subset_metrics = defaultdict(list)
-         for example in translation_subsets[subset]:
-             prediction = normalizer(str(example['prediction']))
-             reference = normalizer(example['target'])
-             metrics = calculate_metrics(reference, prediction)
-             for m in metrics.keys():
-                 subset_metrics[m].append(metrics[m])
-
-         for m in subset_metrics.keys():
-             if subset_metrics[m]:  # Check that the list is not empty
-                 grouped_metrics[subset][m] = float(np.mean(subset_metrics[m]))
-
-     # Calculate overall averages
-     all_metrics = list(grouped_metrics.values())[0].keys() if grouped_metrics else []
-     for m in all_metrics:
-         metric_values = []
-         for subset in translation_subsets.keys():
-             if m in grouped_metrics[subset]:
-                 metric_values.append(grouped_metrics[subset][m])
-         if metric_values:
-             grouped_metrics['averages'][m] = float(np.mean(metric_values))
-
-     return dict(grouped_metrics)
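A hypothetical end-to-end call, assuming a pandas DataFrame with the column names used above; the language codes and model path are illustrative:

import pandas as pd

test_data = pd.DataFrame({
    'source': ['Good morning'],
    'source.language': ['eng'],
    'target.language': ['lug'],
    'target': ['Wasuze otya nno'],
})

# Model and tokenizer loading is model-specific and omitted here.
# results = evaluate_model_full(model, tokenizer, 'path/or/name', test_data)
# print(results['averages'])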
src/validation.py ADDED
@@ -0,0 +1,274 @@
+ # src/validation.py
+ import pandas as pd
+ import numpy as np
+ from typing import Dict, List, Tuple, Optional
+ import json
+ import io
+ from config import PREDICTION_FORMAT
+
+ def validate_file_format(file_content: bytes, filename: str) -> Dict:
+     """Validate uploaded file format and structure."""
+
+     try:
+         # Determine file type
+         if filename.endswith('.csv'):
+             df = pd.read_csv(io.BytesIO(file_content))
+         elif filename.endswith('.tsv'):
+             df = pd.read_csv(io.BytesIO(file_content), sep='\t')
+         elif filename.endswith('.json'):
+             data = json.loads(file_content.decode('utf-8'))
+             df = pd.DataFrame(data)
+         else:
+             return {
+                 'valid': False,
+                 'error': f"Unsupported file type. Use: {', '.join(PREDICTION_FORMAT['file_types'])}"
+             }
+
+         # Check required columns
+         missing_cols = set(PREDICTION_FORMAT['required_columns']) - set(df.columns)
+         if missing_cols:
+             return {
+                 'valid': False,
+                 'error': f"Missing required columns: {', '.join(missing_cols)}"
+             }
+
+         # Basic data validation
+         if len(df) == 0:
+             return {
+                 'valid': False,
+                 'error': "File is empty"
+             }
+
+         # Check for required data
+         if df['sample_id'].isna().any():
+             return {
+                 'valid': False,
+                 'error': "Missing sample_id values found"
+             }
+
+         if df['prediction'].isna().any():
+             na_count = df['prediction'].isna().sum()
+             return {
+                 'valid': False,
+                 'error': f"Missing prediction values found ({na_count} empty predictions)"
+             }
+
+         # Check for duplicates
+         duplicates = df['sample_id'].duplicated()
+         if duplicates.any():
+             dup_count = duplicates.sum()
+             return {
+                 'valid': False,
+                 'error': f"Duplicate sample_id values found ({dup_count} duplicates)"
+             }
+
+         return {
+             'valid': True,
+             'dataframe': df,
+             'row_count': len(df),
+             'columns': list(df.columns)
+         }
+
+     except Exception as e:
+         return {
+             'valid': False,
+             'error': f"Error parsing file: {str(e)}"
+         }
+
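A minimal happy-path sketch, assuming PREDICTION_FORMAT requires at least the sample_id and prediction columns (which is what the checks above imply):

csv_bytes = b"sample_id,prediction\n1,Wasuze otya nno\n2,Webale nnyo\n"
result = validate_file_format(csv_bytes, "submission.csv")
if result['valid']:
    print(result['row_count'], result['columns'])  # 2 ['sample_id', 'prediction']
else:
    print(result['error'])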
+ def validate_predictions_content(predictions: pd.DataFrame) -> Dict:
+     """Validate prediction content quality."""
+
+     issues = []
+     warnings = []
+
+     # Check prediction text quality
+     empty_predictions = predictions['prediction'].str.strip().eq('').sum()
+     if empty_predictions > 0:
+         issues.append(f"{empty_predictions} empty predictions found")
+
+     # Check for suspiciously short predictions
+     short_predictions = (predictions['prediction'].str.len() < 3).sum()
+     if short_predictions > len(predictions) * 0.1:  # More than 10%
+         warnings.append(f"{short_predictions} very short predictions (< 3 characters)")
+
+     # Check for suspiciously long predictions
+     long_predictions = (predictions['prediction'].str.len() > 500).sum()
+     if long_predictions > 0:
+         warnings.append(f"{long_predictions} very long predictions (> 500 characters)")
+
+     # Check for repeated predictions
+     duplicate_predictions = predictions['prediction'].duplicated().sum()
+     if duplicate_predictions > len(predictions) * 0.5:  # More than 50%
+         warnings.append(f"{duplicate_predictions} duplicate prediction texts")
+
+     # Check for non-text content
+     non_text_pattern = r'^[A-Za-z\s\'".,!?;:()\-]+$'
+     non_text_predictions = ~predictions['prediction'].str.match(non_text_pattern, na=False)
+     if non_text_predictions.sum() > 0:
+         warnings.append(f"{non_text_predictions.sum()} predictions contain unusual characters")
+
+     return {
+         'has_issues': len(issues) > 0,
+         'issues': issues,
+         'warnings': warnings,
+         'quality_score': max(0, 1.0 - len(issues) * 0.2 - len(warnings) * 0.1)
+     }
+
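The quality_score here is a simple heuristic: each issue subtracts 0.2 and each warning 0.1, floored at zero. A toy frame chosen to trigger one issue and two warnings:

import pandas as pd

preds = pd.DataFrame({'prediction': ['ok', '', 'ok']})
res = validate_predictions_content(preds)
# One issue (an empty prediction) plus two warnings (short texts, and the
# empty string failing the character check) give 1.0 - 0.2 - 2 * 0.1, i.e. ~0.6.
print(res['issues'], res['warnings'], res['quality_score'])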
+ def validate_against_test_set(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
+     """Validate predictions against the official test set."""
+
+     # Convert IDs to string for comparison
+     pred_ids = set(predictions['sample_id'].astype(str))
+     test_ids = set(test_set['sample_id'].astype(str))
+
+     # Check coverage
+     missing_ids = test_ids - pred_ids
+     extra_ids = pred_ids - test_ids
+     matching_ids = pred_ids & test_ids
+
+     coverage = len(matching_ids) / len(test_ids)
+
+     # Detailed coverage by language pair
+     pair_coverage = {}
+     for _, row in test_set.iterrows():
+         pair_key = f"{row['source_language']}_{row['target_language']}"
+         if pair_key not in pair_coverage:
+             pair_coverage[pair_key] = {'total': 0, 'covered': 0}
+
+         pair_coverage[pair_key]['total'] += 1
+         if str(row['sample_id']) in pred_ids:
+             pair_coverage[pair_key]['covered'] += 1
+
+     # Calculate pair-wise coverage rates
+     for pair_key in pair_coverage:
+         pair_info = pair_coverage[pair_key]
+         pair_info['coverage_rate'] = pair_info['covered'] / pair_info['total']
+
+     return {
+         'overall_coverage': coverage,
+         'missing_count': len(missing_ids),
+         'extra_count': len(extra_ids),
+         'matching_count': len(matching_ids),
+         'is_complete': coverage == 1.0,
+         'pair_coverage': pair_coverage,
+         'missing_ids_sample': list(missing_ids)[:10],  # First 10 for display
+         'extra_ids_sample': list(extra_ids)[:10]
+     }
+
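A toy coverage check, using the column names the function expects (sample_id, source_language, target_language); the codes are illustrative:

import pandas as pd

preds = pd.DataFrame({'sample_id': [1, 2], 'prediction': ['a', 'b']})
test = pd.DataFrame({
    'sample_id': [1, 2, 3],
    'source_language': ['eng', 'eng', 'lug'],
    'target_language': ['lug', 'ach', 'eng'],
})
cov = validate_against_test_set(preds, test)
print(f"{cov['overall_coverage']:.1%}", cov['missing_count'])  # 66.7% 1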
+ def generate_validation_report(
+     format_result: Dict,
+     content_result: Dict,
+     test_set_result: Dict,
+     model_name: str = ""
+ ) -> str:
+     """Generate a human-readable validation report."""
+
+     report = []
+
+     # Header
+     report.append(f"# Validation Report: {model_name or 'Submission'}")
+     report.append(f"Generated: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}")
+     report.append("")
+
+     # File format validation
+     if format_result['valid']:
+         report.append("✅ **File Format**: Valid")
+         report.append(f" - Rows: {format_result['row_count']:,}")
+         report.append(f" - Columns: {', '.join(format_result['columns'])}")
+     else:
+         report.append("❌ **File Format**: Invalid")
+         report.append(f" - Error: {format_result['error']}")
+         return "\n".join(report)
+
+     # Content validation
+     if content_result['has_issues']:
+         report.append("⚠️ **Content Quality**: Issues Found")
+         for issue in content_result['issues']:
+             report.append(f" - ❌ {issue}")
+     else:
+         report.append("✅ **Content Quality**: Good")
+
+     if content_result['warnings']:
+         for warning in content_result['warnings']:
+             report.append(f" - ⚠️ {warning}")
+
+     # Test set validation
+     coverage = test_set_result['overall_coverage']
+     if coverage == 1.0:
+         report.append("✅ **Test Set Coverage**: Complete")
+     elif coverage >= 0.95:
+         report.append("⚠️ **Test Set Coverage**: Nearly Complete")
+     else:
+         report.append("❌ **Test Set Coverage**: Incomplete")
+
+     report.append(f" - Coverage: {coverage:.1%} ({test_set_result['matching_count']:,} / {test_set_result['matching_count'] + test_set_result['missing_count']:,})")
+
+     if test_set_result['missing_count'] > 0:
+         report.append(f" - Missing: {test_set_result['missing_count']:,} samples")
+
+     if test_set_result['extra_count'] > 0:
+         report.append(f" - Extra: {test_set_result['extra_count']:,} samples")
+
+     # Language pair coverage
+     pair_cov = test_set_result['pair_coverage']
+     incomplete_pairs = [k for k, v in pair_cov.items() if v['coverage_rate'] < 1.0]
+
+     if incomplete_pairs:
+         report.append("")
+         report.append("**Incomplete Language Pairs:**")
+         for pair in incomplete_pairs[:5]:  # Show first 5
+             info = pair_cov[pair]
+             src, tgt = pair.split('_')
+             report.append(f" - {src}→{tgt}: {info['covered']}/{info['total']} ({info['coverage_rate']:.1%})")
+
+         if len(incomplete_pairs) > 5:
+             report.append(f" - ... and {len(incomplete_pairs) - 5} more pairs")
+
+     # Final verdict
+     report.append("")
+     if format_result['valid'] and coverage >= 0.95 and not content_result['has_issues']:
+         report.append("🎉 **Overall**: Ready for evaluation!")
+     elif format_result['valid'] and coverage >= 0.8:
+         report.append("⚠️ **Overall**: Can be evaluated with warnings")
+     else:
+         report.append("❌ **Overall**: Please fix issues before submission")
+
+     return "\n".join(report)
+
+ def validate_submission_complete(file_content: bytes, filename: str, test_set: pd.DataFrame, model_name: str = "") -> Dict:
+     """Complete validation pipeline for a submission."""
+
+     # Step 1: File format validation
+     format_result = validate_file_format(file_content, filename)
+     if not format_result['valid']:
+         return {
+             'valid': False,
+             'report': generate_validation_report(format_result, {}, {}, model_name),
+             'predictions': None
+         }
+
+     predictions = format_result['dataframe']
+
+     # Step 2: Content validation
+     content_result = validate_predictions_content(predictions)
+
+     # Step 3: Test set validation
+     test_set_result = validate_against_test_set(predictions, test_set)
+
+     # Step 4: Generate report
+     report = generate_validation_report(format_result, content_result, test_set_result, model_name)
+
+     # Overall validity
+     is_valid = (
+         format_result['valid'] and
+         not content_result['has_issues'] and
+         test_set_result['overall_coverage'] >= 0.95
+     )
+
+     return {
+         'valid': is_valid,
+         'coverage': test_set_result['overall_coverage'],
+         'report': report,
+         'predictions': predictions,
+         'pair_coverage': test_set_result['pair_coverage']
+     }
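Finally, a hypothetical driver for the whole pipeline; the file path and test-set source are illustrative:

import pandas as pd

# Illustrative inputs: a submission file on disk and the official test set.
test_set = pd.read_csv("test_set.csv")
with open("submission.csv", "rb") as f:
    outcome = validate_submission_complete(f.read(), "submission.csv", test_set, model_name="my-model")

print(outcome['report'])
if outcome['valid']:
    predictions = outcome['predictions']  # ready for scoring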