akera commited on
Commit
d82b528
·
verified ·
1 Parent(s): 7827065

Update src/test_set.py

Browse files
Files changed (1) hide show
  1. src/test_set.py +93 -516
src/test_set.py CHANGED
@@ -12,31 +12,22 @@ from config import (
12
  ALL_UG40_LANGUAGES,
13
  GOOGLE_SUPPORTED_LANGUAGES,
14
  EVALUATION_TRACKS,
15
- SAMPLE_SIZE_RECOMMENDATIONS,
16
- STATISTICAL_CONFIG,
17
  )
18
  import salt.dataset
19
- from src.utils import get_all_language_pairs, get_track_language_pairs
20
  from typing import Dict, List, Optional, Tuple
21
 
22
 
23
-
24
  # Local CSV filenames for persistence
25
- LOCAL_PUBLIC_CSV = "salt_test_set_scientific.csv"
26
- LOCAL_COMPLETE_CSV = "salt_complete_test_set_scientific.csv"
27
- LOCAL_TRACK_CSVS = {
28
- track: f"salt_test_set_{track}.csv" for track in EVALUATION_TRACKS.keys()
29
- }
30
 
31
 
32
- def generate_scientific_test_set(
33
- max_samples_per_pair: int = MAX_TEST_SAMPLES,
34
- stratified_sampling: bool = True,
35
- balance_tracks: bool = True,
36
- ) -> pd.DataFrame:
37
- """Generate scientifically rigorous test set with stratified sampling."""
38
 
39
- print("🔬 Generating scientific SALT test set...")
40
 
41
  try:
42
  # Build SALT dataset config
@@ -63,22 +54,12 @@ def generate_scientific_test_set(
63
  test_samples = []
64
  sample_id_counter = 1
65
 
66
- # Calculate target samples per track for balanced evaluation
67
- track_targets = calculate_track_sampling_targets(balance_tracks)
68
-
69
- # Generate samples for each language pair with stratified sampling
70
  for src_lang in ALL_UG40_LANGUAGES:
71
  for tgt_lang in ALL_UG40_LANGUAGES:
72
  if src_lang == tgt_lang:
73
  continue
74
 
75
- # Determine target sample size for this pair
76
- pair_targets = calculate_pair_sampling_targets(
77
- src_lang, tgt_lang, track_targets, max_samples_per_pair
78
- )
79
-
80
- target_samples = max(pair_targets.values()) if pair_targets else max_samples_per_pair
81
-
82
  # Filter for this language pair
83
  pair_data = full_data[
84
  (full_data["source.language"] == src_lang) &
@@ -89,24 +70,13 @@ def generate_scientific_test_set(
89
  print(f"⚠️ No data found for {src_lang} → {tgt_lang}")
90
  continue
91
 
92
- # Stratified sampling if enabled
93
- if stratified_sampling and len(pair_data) > target_samples:
94
- sampled = stratified_sample_pair_data(pair_data, target_samples)
95
- else:
96
- # Simple random sampling
97
- n_samples = min(len(pair_data), target_samples)
98
- sampled = pair_data.sample(n=n_samples, random_state=42)
99
 
100
  print(f"✅ {src_lang} → {tgt_lang}: {len(sampled)} samples")
101
 
102
  for _, row in sampled.iterrows():
103
- # Determine which tracks include this pair
104
- tracks_included = []
105
- for track_name, track_config in EVALUATION_TRACKS.items():
106
- if (src_lang in track_config["languages"] and
107
- tgt_lang in track_config["languages"]):
108
- tracks_included.append(track_name)
109
-
110
  test_samples.append({
111
  "sample_id": f"salt_{sample_id_counter:06d}",
112
  "source_text": row["source"],
@@ -118,10 +88,6 @@ def generate_scientific_test_set(
118
  src_lang in GOOGLE_SUPPORTED_LANGUAGES and
119
  tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
120
  ),
121
- "tracks_included": ",".join(tracks_included),
122
- "statistical_weight": calculate_statistical_weight(
123
- src_lang, tgt_lang, tracks_included
124
- ),
125
  })
126
  sample_id_counter += 1
127
 
@@ -130,313 +96,70 @@ def generate_scientific_test_set(
130
  if test_df.empty:
131
  raise ValueError("No test samples generated - check SALT dataset availability")
132
 
133
- # Validate scientific adequacy
134
- adequacy_report = validate_test_set_scientific_adequacy(test_df)
135
-
136
- print(f"✅ Generated scientific test set: {len(test_df):,} samples")
137
- print(f"📈 Test set adequacy: {adequacy_report['overall_adequacy']}")
138
 
139
  return test_df
140
 
141
  except Exception as e:
142
- print(f"❌ Error generating scientific test set: {e}")
143
  return pd.DataFrame(columns=[
144
  "sample_id", "source_text", "target_text", "source_language",
145
- "target_language", "domain", "google_comparable", "tracks_included",
146
- "statistical_weight"
147
  ])
148
 
149
 
150
- def calculate_track_sampling_targets(balance_tracks: bool) -> Dict[str, int]:
151
- """Calculate target sample sizes for each track to ensure statistical adequacy."""
152
-
153
- track_targets = {}
154
-
155
- for track_name, track_config in EVALUATION_TRACKS.items():
156
- # Base requirement from config
157
- min_per_pair = track_config["min_samples_per_pair"]
158
-
159
- # Number of language pairs in this track
160
- n_pairs = len(track_config["languages"]) * (len(track_config["languages"]) - 1)
161
-
162
- # Calculate total samples needed for statistical adequacy
163
- if balance_tracks:
164
- # Use publication-quality recommendation
165
- target_per_pair = max(
166
- min_per_pair,
167
- SAMPLE_SIZE_RECOMMENDATIONS["publication_quality"] // n_pairs
168
- )
169
- else:
170
- target_per_pair = min_per_pair
171
-
172
- track_targets[track_name] = target_per_pair * n_pairs
173
-
174
- print(f"📊 {track_name}: targeting {target_per_pair} samples/pair × {n_pairs} pairs = {track_targets[track_name]} total")
175
-
176
- return track_targets
177
-
178
-
179
- def calculate_pair_sampling_targets(
180
- src_lang: str, tgt_lang: str, track_targets: Dict[str, int], max_samples: int
181
- ) -> Dict[str, int]:
182
- """Calculate sampling targets for a specific language pair across tracks."""
183
-
184
- pair_targets = {}
185
-
186
- for track_name, track_config in EVALUATION_TRACKS.items():
187
- if (src_lang in track_config["languages"] and
188
- tgt_lang in track_config["languages"]):
189
-
190
- n_pairs_in_track = len(track_config["languages"]) * (len(track_config["languages"]) - 1)
191
- target_per_pair = track_targets[track_name] // n_pairs_in_track
192
-
193
- pair_targets[track_name] = min(target_per_pair, max_samples)
194
-
195
- return pair_targets
196
-
197
-
198
- def stratified_sample_pair_data(pair_data: pd.DataFrame, target_samples: int) -> pd.DataFrame:
199
- """Perform stratified sampling on pair data to ensure representativeness."""
200
-
201
- # Try to stratify by domain if available
202
- if "domain" in pair_data.columns and pair_data["domain"].nunique() > 1:
203
- # Sample proportionally from each domain
204
- domain_counts = pair_data["domain"].value_counts()
205
- sampled_parts = []
206
-
207
- for domain, count in domain_counts.items():
208
- domain_data = pair_data[pair_data["domain"] == domain]
209
-
210
- # Calculate proportional sample size
211
- proportion = count / len(pair_data)
212
- domain_target = max(1, int(target_samples * proportion))
213
- domain_target = min(domain_target, len(domain_data))
214
-
215
- if len(domain_data) >= domain_target:
216
- domain_sample = domain_data.sample(n=domain_target, random_state=42)
217
- sampled_parts.append(domain_sample)
218
-
219
- if sampled_parts:
220
- stratified_sample = pd.concat(sampled_parts, ignore_index=True)
221
-
222
- # If we didn't get enough samples, fill with random sampling
223
- if len(stratified_sample) < target_samples:
224
- remaining_data = pair_data[~pair_data.index.isin(stratified_sample.index)]
225
- additional_needed = target_samples - len(stratified_sample)
226
-
227
- if len(remaining_data) >= additional_needed:
228
- additional_sample = remaining_data.sample(n=additional_needed, random_state=42)
229
- stratified_sample = pd.concat([stratified_sample, additional_sample], ignore_index=True)
230
-
231
- return stratified_sample.head(target_samples)
232
-
233
- # Fallback to simple random sampling
234
- return pair_data.sample(n=min(target_samples, len(pair_data)), random_state=42)
235
-
236
-
237
- def calculate_statistical_weight(
238
- src_lang: str, tgt_lang: str, tracks_included: List[str]
239
- ) -> float:
240
- """Calculate statistical weight for a sample based on track inclusion."""
241
 
242
- # Base weight
243
- weight = 1.0
244
 
245
- # Higher weight for samples in multiple tracks (more valuable)
246
- weight *= len(tracks_included)
247
-
248
- # Higher weight for Google-comparable pairs (enable baseline comparison)
249
- if (src_lang in GOOGLE_SUPPORTED_LANGUAGES and
250
- tgt_lang in GOOGLE_SUPPORTED_LANGUAGES):
251
- weight *= 1.5
252
-
253
- # Normalize to reasonable range
254
- return min(weight, 5.0)
255
-
256
-
257
- def validate_test_set_scientific_adequacy(test_df: pd.DataFrame) -> Dict:
258
- """Validate that the test set meets scientific adequacy requirements."""
259
-
260
- adequacy_report = {
261
- "overall_adequacy": "insufficient",
262
- "track_adequacy": {},
263
- "issues": [],
264
- "recommendations": [],
265
- "statistics": {},
266
- }
267
-
268
- if test_df.empty:
269
- adequacy_report["issues"].append("Test set is empty")
270
- return adequacy_report
271
-
272
- # Check each track
273
- track_adequacies = []
274
-
275
- for track_name, track_config in EVALUATION_TRACKS.items():
276
- track_languages = track_config["languages"]
277
- min_per_pair = track_config["min_samples_per_pair"]
278
-
279
- # Filter to track data
280
- track_data = test_df[
281
- (test_df["source_language"].isin(track_languages)) &
282
- (test_df["target_language"].isin(track_languages))
283
- ]
284
-
285
- # Analyze pair coverage
286
- pair_counts = {}
287
- for src in track_languages:
288
- for tgt in track_languages:
289
- if src == tgt:
290
- continue
291
-
292
- pair_samples = track_data[
293
- (track_data["source_language"] == src) &
294
- (track_data["target_language"] == tgt)
295
- ]
296
- pair_counts[f"{src}_{tgt}"] = len(pair_samples)
297
-
298
- # Calculate adequacy metrics
299
- total_pairs = len(pair_counts)
300
- adequate_pairs = sum(1 for count in pair_counts.values() if count >= min_per_pair)
301
- adequacy_rate = adequate_pairs / max(total_pairs, 1)
302
-
303
- # Determine track adequacy level
304
- if adequacy_rate >= 0.9:
305
- track_adequacy = "excellent"
306
- elif adequacy_rate >= 0.8:
307
- track_adequacy = "good"
308
- elif adequacy_rate >= 0.6:
309
- track_adequacy = "fair"
310
- else:
311
- track_adequacy = "insufficient"
312
-
313
- adequacy_report["track_adequacy"][track_name] = {
314
- "adequacy": track_adequacy,
315
- "adequacy_rate": adequacy_rate,
316
- "total_samples": len(track_data),
317
- "total_pairs": total_pairs,
318
- "adequate_pairs": adequate_pairs,
319
- "min_samples_per_pair": min_per_pair,
320
- "pair_counts": pair_counts,
321
- }
322
-
323
- track_adequacies.append(track_adequacy)
324
-
325
- # Add specific issues
326
- if track_adequacy == "insufficient":
327
- inadequate_pairs = [k for k, v in pair_counts.items() if v < min_per_pair]
328
- adequacy_report["issues"].append(
329
- f"{track_name}: {len(inadequate_pairs)} pairs below minimum"
330
- )
331
-
332
- # Overall adequacy assessment
333
- if all(adequacy in ["excellent", "good"] for adequacy in track_adequacies):
334
- adequacy_report["overall_adequacy"] = "excellent"
335
- elif all(adequacy in ["excellent", "good", "fair"] for adequacy in track_adequacies):
336
- adequacy_report["overall_adequacy"] = "good"
337
- elif any(adequacy in ["good", "fair"] for adequacy in track_adequacies):
338
- adequacy_report["overall_adequacy"] = "fair"
339
- else:
340
- adequacy_report["overall_adequacy"] = "insufficient"
341
-
342
- # Overall statistics
343
- adequacy_report["statistics"] = {
344
- "total_samples": len(test_df),
345
- "total_language_pairs": len(test_df.groupby(["source_language", "target_language"])),
346
- "google_comparable_samples": int(test_df["google_comparable"].sum()),
347
- "domain_distribution": test_df["domain"].value_counts().to_dict(),
348
- "track_sample_distribution": {
349
- track: adequacy_report["track_adequacy"][track]["total_samples"]
350
- for track in EVALUATION_TRACKS.keys()
351
- },
352
- }
353
-
354
- # Generate recommendations
355
- if adequacy_report["overall_adequacy"] in ["insufficient", "fair"]:
356
- adequacy_report["recommendations"].append(
357
- "Consider increasing sample size for better statistical power"
358
- )
359
-
360
- if adequacy_report["statistics"]["google_comparable_samples"] < 1000:
361
- adequacy_report["recommendations"].append(
362
- "More Google-comparable samples recommended for baseline comparison"
363
- )
364
-
365
- return adequacy_report
366
-
367
-
368
- def _generate_and_save_scientific_test_set() -> Tuple[pd.DataFrame, pd.DataFrame]:
369
- """Generate and save both public and complete versions of the scientific test set."""
370
-
371
- print("🔬 Generating and saving scientific test sets...")
372
-
373
- full_df = generate_scientific_test_set()
374
 
375
  if full_df.empty:
376
- print("❌ Failed to generate scientific test set")
377
  empty_public = pd.DataFrame(columns=[
378
  "sample_id", "source_text", "source_language",
379
- "target_language", "domain", "google_comparable",
380
- "tracks_included", "statistical_weight"
381
  ])
382
  empty_complete = pd.DataFrame(columns=[
383
  "sample_id", "source_text", "target_text", "source_language",
384
- "target_language", "domain", "google_comparable",
385
- "tracks_included", "statistical_weight"
386
  ])
387
  return empty_public, empty_complete
388
 
389
  # Public version (no target_text)
390
  public_df = full_df[[
391
  "sample_id", "source_text", "source_language",
392
- "target_language", "domain", "google_comparable",
393
- "tracks_included", "statistical_weight"
394
  ]].copy()
395
 
396
- # Save main versions
397
  try:
398
  public_df.to_csv(LOCAL_PUBLIC_CSV, index=False)
399
  full_df.to_csv(LOCAL_COMPLETE_CSV, index=False)
400
- print(f"✅ Saved main test sets: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
401
  except Exception as e:
402
- print(f"⚠️ Error saving main CSVs: {e}")
403
-
404
- # Save track-specific versions for easier analysis
405
- for track_name, track_config in EVALUATION_TRACKS.items():
406
- try:
407
- track_languages = track_config["languages"]
408
- track_public = public_df[
409
- (public_df["source_language"].isin(track_languages)) &
410
- (public_df["target_language"].isin(track_languages))
411
- ]
412
-
413
- track_filename = LOCAL_TRACK_CSVS[track_name]
414
- track_public.to_csv(track_filename, index=False)
415
- print(f"✅ Saved {track_name} track: {track_filename} ({len(track_public):,} samples)")
416
-
417
- except Exception as e:
418
- print(f"⚠️ Error saving {track_name} track CSV: {e}")
419
 
420
  return public_df, full_df
421
 
422
 
423
- def get_public_test_set_scientific() -> pd.DataFrame:
424
- """Load the scientific public test set with enhanced fallback logic."""
425
 
426
  # 1) Try HF Hub
427
  try:
428
- print("📥 Attempting to load scientific test set from HF Hub...")
429
- ds = load_dataset(TEST_SET_DATASET + "-scientific", split="train", token=HF_TOKEN)
430
  df = ds.to_pandas()
431
 
432
- # Validate scientific structure
433
- required_cols = ["sample_id", "source_text", "source_language", "target_language",
434
- "tracks_included", "statistical_weight"]
435
  if all(col in df.columns for col in required_cols):
436
- print(f"✅ Loaded scientific test set from HF Hub ({len(df):,} samples)")
437
  return df
438
  else:
439
- print("⚠️ HF Hub test set missing scientific columns, regenerating...")
440
 
441
  except Exception as e:
442
  print(f"⚠️ HF Hub load failed: {e}")
@@ -447,35 +170,34 @@ def get_public_test_set_scientific() -> pd.DataFrame:
447
  df = pd.read_csv(LOCAL_PUBLIC_CSV)
448
  required_cols = ["sample_id", "source_text", "source_language", "target_language"]
449
  if all(col in df.columns for col in required_cols):
450
- print(f"✅ Loaded scientific test set from local CSV ({len(df):,} samples)")
451
  return df
452
  else:
453
  print("⚠️ Local CSV has invalid structure, regenerating...")
454
  except Exception as e:
455
- print(f"⚠️ Failed to read local scientific CSV: {e}")
456
 
457
  # 3) Regenerate & save
458
- print("🔄 Generating new scientific test set...")
459
- public_df, _ = _generate_and_save_scientific_test_set()
460
  return public_df
461
 
462
 
463
- def get_complete_test_set_scientific() -> pd.DataFrame:
464
- """Load the complete scientific test set with targets."""
465
 
466
  # 1) Try HF Hub private
467
  try:
468
- print("📥 Attempting to load complete scientific test set from HF Hub...")
469
- ds = load_dataset(TEST_SET_DATASET + "-scientific-private", split="train", token=HF_TOKEN)
470
  df = ds.to_pandas()
471
 
472
- required_cols = ["sample_id", "source_text", "target_text", "source_language",
473
- "target_language", "tracks_included", "statistical_weight"]
474
  if all(col in df.columns for col in required_cols):
475
- print(f"✅ Loaded complete scientific test set from HF Hub ({len(df):,} samples)")
476
  return df
477
  else:
478
- print("⚠️ HF Hub complete test set missing scientific columns, regenerating...")
479
 
480
  except Exception as e:
481
  print(f"⚠️ HF Hub private load failed: {e}")
@@ -486,63 +208,31 @@ def get_complete_test_set_scientific() -> pd.DataFrame:
486
  df = pd.read_csv(LOCAL_COMPLETE_CSV)
487
  required_cols = ["sample_id", "source_text", "target_text", "source_language", "target_language"]
488
  if all(col in df.columns for col in required_cols):
489
- print(f"✅ Loaded complete scientific test set from local CSV ({len(df):,} samples)")
490
  return df
491
  else:
492
  print("⚠️ Local complete CSV has invalid structure, regenerating...")
493
  except Exception as e:
494
- print(f"⚠️ Failed to read local complete scientific CSV: {e}")
495
 
496
  # 3) Regenerate & save
497
- print("🔄 Generating new complete scientific test set...")
498
- _, complete_df = _generate_and_save_scientific_test_set()
499
  return complete_df
500
 
501
 
502
- def get_track_test_set(track: str) -> pd.DataFrame:
503
- """Get test set filtered for a specific track."""
504
-
505
- if track not in EVALUATION_TRACKS:
506
- print(f"❌ Unknown track: {track}")
507
- return pd.DataFrame()
508
-
509
- # Try track-specific CSV first
510
- track_csv = LOCAL_TRACK_CSVS.get(track)
511
- if track_csv and os.path.exists(track_csv):
512
- try:
513
- df = pd.read_csv(track_csv)
514
- print(f"✅ Loaded {track} test set from track-specific CSV ({len(df):,} samples)")
515
- return df
516
- except Exception as e:
517
- print(f"⚠️ Failed to read {track} CSV: {e}")
518
-
519
- # Fallback to filtering main test set
520
- public_df = get_public_test_set_scientific()
521
-
522
- if public_df.empty:
523
- return pd.DataFrame()
524
-
525
- track_languages = EVALUATION_TRACKS[track]["languages"]
526
- track_df = public_df[
527
- (public_df["source_language"].isin(track_languages)) &
528
- (public_df["target_language"].isin(track_languages))
529
- ]
530
-
531
- print(f"✅ Filtered {track} test set from main set ({len(track_df):,} samples)")
532
- return track_df
533
-
534
-
535
- def create_test_set_download_scientific() -> Tuple[str, Dict]:
536
- """Create scientific test set download with comprehensive metadata."""
537
 
538
- public_df = get_public_test_set_scientific()
539
 
540
  if public_df.empty:
541
  stats = {
542
  "total_samples": 0,
543
  "track_breakdown": {},
544
- "adequacy_assessment": "insufficient",
545
- "scientific_metadata": {},
 
546
  }
547
  return LOCAL_PUBLIC_CSV, stats
548
 
@@ -552,7 +242,7 @@ def create_test_set_download_scientific() -> Tuple[str, Dict]:
552
  try:
553
  public_df.to_csv(download_path, index=False)
554
  except Exception as e:
555
- print(f"⚠️ Error updating scientific CSV: {e}")
556
 
557
  # Calculate comprehensive statistics
558
  try:
@@ -560,7 +250,7 @@ def create_test_set_download_scientific() -> Tuple[str, Dict]:
560
  stats = {
561
  "total_samples": len(public_df),
562
  "languages": sorted(list(set(public_df["source_language"]).union(public_df["target_language"]))),
563
- "domains": public_df["domain"].unique().tolist() if "domain" in public_df.columns else ["general"],
564
  }
565
 
566
  # Track-specific breakdown
@@ -573,11 +263,9 @@ def create_test_set_download_scientific() -> Tuple[str, Dict]:
573
  ]
574
 
575
  track_breakdown[track_name] = {
576
- "name": track_config["name"],
577
  "total_samples": len(track_data),
578
  "language_pairs": len(track_data.groupby(["source_language", "target_language"])),
579
- "min_samples_per_pair": track_config["min_samples_per_pair"],
580
- "statistical_adequacy": len(track_data) >= track_config["min_samples_per_pair"] * len(track_languages) * (len(track_languages) - 1),
581
  }
582
 
583
  stats["track_breakdown"] = track_breakdown
@@ -585,52 +273,56 @@ def create_test_set_download_scientific() -> Tuple[str, Dict]:
585
  # Google-comparable statistics
586
  if "google_comparable" in public_df.columns:
587
  stats["google_comparable_samples"] = int(public_df["google_comparable"].sum())
588
- stats["google_comparable_rate"] = float(public_df["google_comparable"].mean())
589
  else:
590
  stats["google_comparable_samples"] = 0
591
- stats["google_comparable_rate"] = 0.0
592
-
593
- # Scientific adequacy assessment
594
- adequacy_report = validate_test_set_scientific_adequacy(public_df)
595
- stats["adequacy_assessment"] = adequacy_report["overall_adequacy"]
596
- stats["adequacy_details"] = adequacy_report
597
-
598
- # Scientific metadata
599
- stats["scientific_metadata"] = {
600
- "stratified_sampling": True,
601
- "statistical_weighting": "statistical_weight" in public_df.columns,
602
- "track_balanced": True,
603
- "confidence_level": STATISTICAL_CONFIG["confidence_level"],
604
- "recommended_for": [
605
- track for track, info in track_breakdown.items()
606
- if info["statistical_adequacy"]
607
- ],
608
- }
609
 
610
  except Exception as e:
611
- print(f"⚠️ Error calculating scientific stats: {e}")
612
  stats = {
613
  "total_samples": len(public_df),
614
  "track_breakdown": {},
615
- "adequacy_assessment": "unknown",
616
- "scientific_metadata": {},
 
617
  }
618
 
619
  return download_path, stats
620
 
621
 
622
- def validate_test_set_integrity_scientific() -> Dict:
623
- """Comprehensive validation of scientific test set integrity."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624
 
625
  try:
626
- public_df = get_public_test_set_scientific()
627
- complete_df = get_complete_test_set_scientific()
628
 
629
  if public_df.empty or complete_df.empty:
630
  return {
631
  "alignment_check": False,
632
  "total_samples": 0,
633
- "scientific_adequacy": {},
634
  "track_analysis": {},
635
  "error": "Test sets are empty or could not be loaded",
636
  }
@@ -642,7 +334,6 @@ def validate_test_set_integrity_scientific() -> Dict:
642
  track_analysis = {}
643
  for track_name, track_config in EVALUATION_TRACKS.items():
644
  track_languages = track_config["languages"]
645
- min_per_pair = track_config["min_samples_per_pair"]
646
 
647
  # Analyze public set for this track
648
  track_public = public_df[
@@ -656,140 +347,26 @@ def validate_test_set_integrity_scientific() -> Dict:
656
  (complete_df["target_language"].isin(track_languages))
657
  ]
658
 
659
- # Calculate coverage
660
- pair_coverage = {}
661
- for src in track_languages:
662
- for tgt in track_languages:
663
- if src == tgt:
664
- continue
665
-
666
- public_subset = track_public[
667
- (track_public["source_language"] == src) &
668
- (track_public["target_language"] == tgt)
669
- ]
670
-
671
- complete_subset = track_complete[
672
- (track_complete["source_language"] == src) &
673
- (track_complete["target_language"] == tgt)
674
- ]
675
-
676
- pair_coverage[f"{src}_{tgt}"] = {
677
- "public_count": len(public_subset),
678
- "complete_count": len(complete_subset),
679
- "alignment": len(public_subset) == len(complete_subset),
680
- "meets_minimum": len(public_subset) >= min_per_pair,
681
- }
682
-
683
- # Track summary
684
- total_pairs = len(pair_coverage)
685
- adequate_pairs = sum(1 for info in pair_coverage.values() if info["meets_minimum"])
686
- aligned_pairs = sum(1 for info in pair_coverage.values() if info["alignment"])
687
-
688
  track_analysis[track_name] = {
689
- "total_pairs": total_pairs,
690
- "adequate_pairs": adequate_pairs,
691
- "aligned_pairs": aligned_pairs,
692
- "adequacy_rate": adequate_pairs / max(total_pairs, 1),
693
- "alignment_rate": aligned_pairs / max(total_pairs, 1),
694
- "pair_coverage": pair_coverage,
695
- "statistical_power": calculate_track_statistical_power(track_public, track_config),
696
  }
697
 
698
- # Overall scientific adequacy
699
- adequacy_report = validate_test_set_scientific_adequacy(public_df)
700
-
701
  return {
702
  "alignment_check": public_ids <= private_ids,
703
  "total_samples": len(public_df),
704
  "track_analysis": track_analysis,
705
- "scientific_adequacy": adequacy_report,
706
  "public_samples": len(public_df),
707
  "private_samples": len(complete_df),
708
  "id_alignment_rate": len(public_ids & private_ids) / len(public_ids) if public_ids else 0.0,
709
- "integrity_score": calculate_integrity_score(track_analysis, adequacy_report),
710
  }
711
 
712
  except Exception as e:
713
  return {
714
  "alignment_check": False,
715
  "total_samples": 0,
716
- "scientific_adequacy": {},
717
  "track_analysis": {},
718
  "error": f"Validation failed: {str(e)}",
719
- }
720
-
721
-
722
- def calculate_track_statistical_power(track_data: pd.DataFrame, track_config: Dict) -> float:
723
- """Calculate statistical power estimate for a track."""
724
-
725
- if track_data.empty:
726
- return 0.0
727
-
728
- # Simple power estimation based on sample size
729
- min_required = track_config["min_samples_per_pair"]
730
- languages = track_config["languages"]
731
- total_pairs = len(languages) * (len(languages) - 1)
732
-
733
- # Calculate average samples per pair
734
- pair_counts = []
735
- for src in languages:
736
- for tgt in languages:
737
- if src == tgt:
738
- continue
739
-
740
- pair_samples = track_data[
741
- (track_data["source_language"] == src) &
742
- (track_data["target_language"] == tgt)
743
- ]
744
- pair_counts.append(len(pair_samples))
745
-
746
- if not pair_counts:
747
- return 0.0
748
-
749
- avg_samples_per_pair = np.mean(pair_counts)
750
-
751
- # Rough power estimation (0.8 power at 2x minimum, 0.95 at 4x minimum)
752
- if avg_samples_per_pair >= min_required * 4:
753
- return 0.95
754
- elif avg_samples_per_pair >= min_required * 2:
755
- return 0.8
756
- elif avg_samples_per_pair >= min_required:
757
- return 0.6
758
- else:
759
- return max(0.0, avg_samples_per_pair / min_required * 0.6)
760
-
761
-
762
- def calculate_integrity_score(track_analysis: Dict, adequacy_report: Dict) -> float:
763
- """Calculate overall integrity score for the test set."""
764
-
765
- if not track_analysis or not adequacy_report:
766
- return 0.0
767
-
768
- # Track adequacy scores
769
- track_scores = []
770
- for track_info in track_analysis.values():
771
- adequacy_rate = track_info.get("adequacy_rate", 0.0)
772
- alignment_rate = track_info.get("alignment_rate", 0.0)
773
- track_score = (adequacy_rate + alignment_rate) / 2
774
- track_scores.append(track_score)
775
-
776
- # Overall adequacy mapping
777
- adequacy_mapping = {
778
- "excellent": 1.0,
779
- "good": 0.8,
780
- "fair": 0.6,
781
- "insufficient": 0.2,
782
- }
783
-
784
- overall_adequacy_score = adequacy_mapping.get(
785
- adequacy_report.get("overall_adequacy", "insufficient"), 0.2
786
- )
787
-
788
- # Combined score
789
- if track_scores:
790
- track_avg = np.mean(track_scores)
791
- integrity_score = (track_avg + overall_adequacy_score) / 2
792
- else:
793
- integrity_score = overall_adequacy_score
794
-
795
- return float(integrity_score)
 
12
  ALL_UG40_LANGUAGES,
13
  GOOGLE_SUPPORTED_LANGUAGES,
14
  EVALUATION_TRACKS,
15
+ LANGUAGE_NAMES,
 
16
  )
17
  import salt.dataset
18
+ from src.utils import get_all_language_pairs
19
  from typing import Dict, List, Optional, Tuple
20
 
21
 
 
22
  # Local CSV filenames for persistence
23
+ LOCAL_PUBLIC_CSV = "salt_test_set.csv"
24
+ LOCAL_COMPLETE_CSV = "salt_complete_test_set.csv"
 
 
 
25
 
26
 
27
+ def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
28
+ """Generate test set from SALT dataset."""
 
 
 
 
29
 
30
+ print("🔬 Generating SALT test set...")
31
 
32
  try:
33
  # Build SALT dataset config
 
54
  test_samples = []
55
  sample_id_counter = 1
56
 
57
+ # Generate samples for each language pair
 
 
 
58
  for src_lang in ALL_UG40_LANGUAGES:
59
  for tgt_lang in ALL_UG40_LANGUAGES:
60
  if src_lang == tgt_lang:
61
  continue
62
 
 
 
 
 
 
 
 
63
  # Filter for this language pair
64
  pair_data = full_data[
65
  (full_data["source.language"] == src_lang) &
 
70
  print(f"⚠️ No data found for {src_lang} → {tgt_lang}")
71
  continue
72
 
73
+ # Sample data for this pair
74
+ n_samples = min(len(pair_data), max_samples_per_pair)
75
+ sampled = pair_data.sample(n=n_samples, random_state=42)
 
 
 
 
76
 
77
  print(f"✅ {src_lang} → {tgt_lang}: {len(sampled)} samples")
78
 
79
  for _, row in sampled.iterrows():
 
 
 
 
 
 
 
80
  test_samples.append({
81
  "sample_id": f"salt_{sample_id_counter:06d}",
82
  "source_text": row["source"],
 
88
  src_lang in GOOGLE_SUPPORTED_LANGUAGES and
89
  tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
90
  ),
 
 
 
 
91
  })
92
  sample_id_counter += 1
93
 
 
96
  if test_df.empty:
97
  raise ValueError("No test samples generated - check SALT dataset availability")
98
 
99
+ print(f"✅ Generated test set: {len(test_df):,} samples")
 
 
 
 
100
 
101
  return test_df
102
 
103
  except Exception as e:
104
+ print(f"❌ Error generating test set: {e}")
105
  return pd.DataFrame(columns=[
106
  "sample_id", "source_text", "target_text", "source_language",
107
+ "target_language", "domain", "google_comparable"
 
108
  ])
109
 
110
 
111
+ def _generate_and_save_test_set() -> Tuple[pd.DataFrame, pd.DataFrame]:
112
+ """Generate and save both public and complete versions of the test set."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
+ print("🔬 Generating and saving test sets...")
 
115
 
116
+ full_df = generate_test_set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  if full_df.empty:
119
+ print("❌ Failed to generate test set")
120
  empty_public = pd.DataFrame(columns=[
121
  "sample_id", "source_text", "source_language",
122
+ "target_language", "domain", "google_comparable"
 
123
  ])
124
  empty_complete = pd.DataFrame(columns=[
125
  "sample_id", "source_text", "target_text", "source_language",
126
+ "target_language", "domain", "google_comparable"
 
127
  ])
128
  return empty_public, empty_complete
129
 
130
  # Public version (no target_text)
131
  public_df = full_df[[
132
  "sample_id", "source_text", "source_language",
133
+ "target_language", "domain", "google_comparable"
 
134
  ]].copy()
135
 
136
+ # Save versions
137
  try:
138
  public_df.to_csv(LOCAL_PUBLIC_CSV, index=False)
139
  full_df.to_csv(LOCAL_COMPLETE_CSV, index=False)
140
+ print(f"✅ Saved test sets: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
141
  except Exception as e:
142
+ print(f"⚠️ Error saving CSVs: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  return public_df, full_df
145
 
146
 
147
+ def get_public_test_set() -> pd.DataFrame:
148
+ """Load the public test set with enhanced fallback logic."""
149
 
150
  # 1) Try HF Hub
151
  try:
152
+ print("📥 Attempting to load test set from HF Hub...")
153
+ ds = load_dataset(TEST_SET_DATASET, split="train", token=HF_TOKEN)
154
  df = ds.to_pandas()
155
 
156
+ # Validate structure
157
+ required_cols = ["sample_id", "source_text", "source_language", "target_language"]
 
158
  if all(col in df.columns for col in required_cols):
159
+ print(f"✅ Loaded test set from HF Hub ({len(df):,} samples)")
160
  return df
161
  else:
162
+ print("⚠️ HF Hub test set missing columns, regenerating...")
163
 
164
  except Exception as e:
165
  print(f"⚠️ HF Hub load failed: {e}")
 
170
  df = pd.read_csv(LOCAL_PUBLIC_CSV)
171
  required_cols = ["sample_id", "source_text", "source_language", "target_language"]
172
  if all(col in df.columns for col in required_cols):
173
+ print(f"✅ Loaded test set from local CSV ({len(df):,} samples)")
174
  return df
175
  else:
176
  print("⚠️ Local CSV has invalid structure, regenerating...")
177
  except Exception as e:
178
+ print(f"⚠️ Failed to read local CSV: {e}")
179
 
180
  # 3) Regenerate & save
181
+ print("🔄 Generating new test set...")
182
+ public_df, _ = _generate_and_save_test_set()
183
  return public_df
184
 
185
 
186
+ def get_complete_test_set() -> pd.DataFrame:
187
+ """Load the complete test set with targets."""
188
 
189
  # 1) Try HF Hub private
190
  try:
191
+ print("📥 Attempting to load complete test set from HF Hub...")
192
+ ds = load_dataset(TEST_SET_DATASET + "-private", split="train", token=HF_TOKEN)
193
  df = ds.to_pandas()
194
 
195
+ required_cols = ["sample_id", "source_text", "target_text", "source_language", "target_language"]
 
196
  if all(col in df.columns for col in required_cols):
197
+ print(f"✅ Loaded complete test set from HF Hub ({len(df):,} samples)")
198
  return df
199
  else:
200
+ print("⚠️ HF Hub complete test set missing columns, regenerating...")
201
 
202
  except Exception as e:
203
  print(f"⚠️ HF Hub private load failed: {e}")
 
208
  df = pd.read_csv(LOCAL_COMPLETE_CSV)
209
  required_cols = ["sample_id", "source_text", "target_text", "source_language", "target_language"]
210
  if all(col in df.columns for col in required_cols):
211
+ print(f"✅ Loaded complete test set from local CSV ({len(df):,} samples)")
212
  return df
213
  else:
214
  print("⚠️ Local complete CSV has invalid structure, regenerating...")
215
  except Exception as e:
216
+ print(f"⚠️ Failed to read local complete CSV: {e}")
217
 
218
  # 3) Regenerate & save
219
+ print("🔄 Generating new complete test set...")
220
+ _, complete_df = _generate_and_save_test_set()
221
  return complete_df
222
 
223
 
224
+ def create_test_set_download() -> Tuple[str, Dict]:
225
+ """Create test set download with comprehensive metadata."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
+ public_df = get_public_test_set()
228
 
229
  if public_df.empty:
230
  stats = {
231
  "total_samples": 0,
232
  "track_breakdown": {},
233
+ "languages": [],
234
+ "language_pairs": 0,
235
+ "google_comparable_samples": 0,
236
  }
237
  return LOCAL_PUBLIC_CSV, stats
238
 
 
242
  try:
243
  public_df.to_csv(download_path, index=False)
244
  except Exception as e:
245
+ print(f"⚠️ Error updating CSV: {e}")
246
 
247
  # Calculate comprehensive statistics
248
  try:
 
250
  stats = {
251
  "total_samples": len(public_df),
252
  "languages": sorted(list(set(public_df["source_language"]).union(public_df["target_language"]))),
253
+ "language_pairs": len(public_df.groupby(["source_language", "target_language"])),
254
  }
255
 
256
  # Track-specific breakdown
 
263
  ]
264
 
265
  track_breakdown[track_name] = {
 
266
  "total_samples": len(track_data),
267
  "language_pairs": len(track_data.groupby(["source_language", "target_language"])),
268
+ "languages": track_languages,
 
269
  }
270
 
271
  stats["track_breakdown"] = track_breakdown
 
273
  # Google-comparable statistics
274
  if "google_comparable" in public_df.columns:
275
  stats["google_comparable_samples"] = int(public_df["google_comparable"].sum())
 
276
  else:
277
  stats["google_comparable_samples"] = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
  except Exception as e:
280
+ print(f"⚠️ Error calculating stats: {e}")
281
  stats = {
282
  "total_samples": len(public_df),
283
  "track_breakdown": {},
284
+ "languages": [],
285
+ "language_pairs": 0,
286
+ "google_comparable_samples": 0,
287
  }
288
 
289
  return download_path, stats
290
 
291
 
292
+ def get_track_test_set(track: str) -> pd.DataFrame:
293
+ """Get test set filtered for a specific track."""
294
+
295
+ if track not in EVALUATION_TRACKS:
296
+ print(f"❌ Unknown track: {track}")
297
+ return pd.DataFrame()
298
+
299
+ # Get main test set and filter
300
+ public_df = get_public_test_set()
301
+
302
+ if public_df.empty:
303
+ return pd.DataFrame()
304
+
305
+ track_languages = EVALUATION_TRACKS[track]["languages"]
306
+ track_df = public_df[
307
+ (public_df["source_language"].isin(track_languages)) &
308
+ (public_df["target_language"].isin(track_languages))
309
+ ]
310
+
311
+ print(f"✅ Filtered {track} test set: {len(track_df):,} samples")
312
+ return track_df
313
+
314
+
315
+ def validate_test_set_integrity() -> Dict:
316
+ """Validate test set integrity."""
317
 
318
  try:
319
+ public_df = get_public_test_set()
320
+ complete_df = get_complete_test_set()
321
 
322
  if public_df.empty or complete_df.empty:
323
  return {
324
  "alignment_check": False,
325
  "total_samples": 0,
 
326
  "track_analysis": {},
327
  "error": "Test sets are empty or could not be loaded",
328
  }
 
334
  track_analysis = {}
335
  for track_name, track_config in EVALUATION_TRACKS.items():
336
  track_languages = track_config["languages"]
 
337
 
338
  # Analyze public set for this track
339
  track_public = public_df[
 
347
  (complete_df["target_language"].isin(track_languages))
348
  ]
349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  track_analysis[track_name] = {
351
+ "public_samples": len(track_public),
352
+ "complete_samples": len(track_complete),
353
+ "alignment": len(track_public) == len(track_complete),
354
+ "languages": track_languages,
 
 
 
355
  }
356
 
 
 
 
357
  return {
358
  "alignment_check": public_ids <= private_ids,
359
  "total_samples": len(public_df),
360
  "track_analysis": track_analysis,
 
361
  "public_samples": len(public_df),
362
  "private_samples": len(complete_df),
363
  "id_alignment_rate": len(public_ids & private_ids) / len(public_ids) if public_ids else 0.0,
 
364
  }
365
 
366
  except Exception as e:
367
  return {
368
  "alignment_check": False,
369
  "total_samples": 0,
 
370
  "track_analysis": {},
371
  "error": f"Validation failed: {str(e)}",
372
+ }