akera committed on
Commit
37b3c92
·
verified ·
1 Parent(s): a411078

Update src/test_set.py

Browse files
Files changed (1) hide show
  1. src/test_set.py +629 -156
src/test_set.py CHANGED
@@ -2,32 +2,42 @@
2
  import os
3
  import pandas as pd
4
  import yaml
 
5
  from datasets import load_dataset
6
  from config import (
7
  TEST_SET_DATASET,
8
  SALT_DATASET,
9
  MAX_TEST_SAMPLES,
10
  HF_TOKEN,
11
- MIN_SAMPLES_PER_PAIR,
12
  ALL_UG40_LANGUAGES,
13
- GOOGLE_SUPPORTED_LANGUAGES
 
 
 
14
  )
15
  import salt.dataset
16
- from src.utils import get_all_language_pairs
17
 
18
  # Local CSV filenames for persistence
19
- LOCAL_PUBLIC_CSV = "salt_test_set.csv"
20
- LOCAL_COMPLETE_CSV = "salt_complete_test_set.csv"
 
 
 
21
 
22
- def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
23
- """
24
- Generate standardized test set from the SALT dataset.
25
- """
26
- print("🔄 Generating SALT test set from source dataset...")
 
 
 
 
27
 
28
  try:
29
- # Build SALT dataset config - using 'test' split for consistency
30
- dataset_config = f'''
31
  huggingface_load:
32
  path: {SALT_DATASET}
33
  name: text-all
@@ -39,7 +49,7 @@ def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFr
39
  type: text
40
  language: {ALL_UG40_LANGUAGES}
41
  allow_same_src_and_tgt_language: False
42
- '''
43
 
44
  config = yaml.safe_load(dataset_config)
45
  print("📥 Loading SALT dataset...")
@@ -50,40 +60,65 @@ def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFr
50
  test_samples = []
51
  sample_id_counter = 1
52
 
53
- # Generate samples for each language pair
 
 
 
54
  for src_lang in ALL_UG40_LANGUAGES:
55
  for tgt_lang in ALL_UG40_LANGUAGES:
56
  if src_lang == tgt_lang:
57
  continue
58
-
 
 
 
 
 
 
 
59
  # Filter for this language pair
60
  pair_data = full_data[
61
- (full_data['source.language'] == src_lang) &
62
- (full_data['target.language'] == tgt_lang)
63
  ]
64
 
65
  if pair_data.empty:
66
  print(f"⚠️ No data found for {src_lang} → {tgt_lang}")
67
  continue
68
 
69
- # Sample up to max_samples_per_pair
70
- n_samples = min(len(pair_data), max_samples_per_pair)
71
- sampled = pair_data.sample(n=n_samples, random_state=42)
 
 
 
 
72
 
73
- print(f"✅ {src_lang} → {tgt_lang}: {n_samples} samples")
74
 
75
  for _, row in sampled.iterrows():
 
 
 
 
 
 
 
76
  test_samples.append({
77
- 'sample_id': f"salt_{sample_id_counter:06d}",
78
- 'source_text': row['source'],
79
- 'target_text': row['target'],
80
- 'source_language': src_lang,
81
- 'target_language': tgt_lang,
82
- 'domain': row.get('domain', 'general'),
83
- 'google_comparable': (
84
  src_lang in GOOGLE_SUPPORTED_LANGUAGES and
85
  tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
86
- )
 
 
 
 
87
  })
88
  sample_id_counter += 1
89
 
@@ -91,78 +126,315 @@ def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFr
91
 
92
  if test_df.empty:
93
  raise ValueError("No test samples generated - check SALT dataset availability")
94
-
95
- print(f"✅ Generated test set: {len(test_df):,} samples across {len(test_df.groupby(['source_language', 'target_language'])):,} pairs")
96
 
97
- # Add some statistics
98
- google_samples = test_df['google_comparable'].sum()
99
- unique_pairs = len(test_df.groupby(['source_language', 'target_language']))
100
 
101
- print(f"📈 Test set statistics:")
102
- print(f" - Total samples: {len(test_df):,}")
103
- print(f" - Language pairs: {unique_pairs}")
104
- print(f" - Google comparable: {google_samples:,} samples")
105
- print(f" - UG40 only: {len(test_df) - google_samples:,} samples")
106
 
107
  return test_df
108
 
109
  except Exception as e:
110
- print(f"❌ Error generating test set: {e}")
111
- # Return empty DataFrame with correct structure
112
  return pd.DataFrame(columns=[
113
- 'sample_id', 'source_text', 'target_text', 'source_language',
114
- 'target_language', 'domain', 'google_comparable'
 
115
  ])
116
 
117
- def _generate_and_save_test_set() -> tuple[pd.DataFrame, pd.DataFrame]:
118
- """
119
- Generate the full test set and persist both public and complete CSV files.
120
- """
121
- print("🔄 Generating and saving test sets...")
122
 
123
- full_df = generate_test_set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  if full_df.empty:
126
- print("❌ Failed to generate test set")
127
- # Return empty DataFrames with correct structure
128
  empty_public = pd.DataFrame(columns=[
129
- 'sample_id', 'source_text', 'source_language',
130
- 'target_language', 'domain', 'google_comparable'
 
131
  ])
132
  empty_complete = pd.DataFrame(columns=[
133
- 'sample_id', 'source_text', 'target_text', 'source_language',
134
- 'target_language', 'domain', 'google_comparable'
 
135
  ])
136
  return empty_public, empty_complete
137
 
138
  # Public version (no target_text)
139
  public_df = full_df[[
140
- 'sample_id', 'source_text', 'source_language',
141
- 'target_language', 'domain', 'google_comparable'
 
142
  ]].copy()
143
 
144
- # Save both versions
145
  try:
146
  public_df.to_csv(LOCAL_PUBLIC_CSV, index=False)
147
  full_df.to_csv(LOCAL_COMPLETE_CSV, index=False)
148
- print(f"✅ Saved local CSVs: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
149
  except Exception as e:
150
- print(f"⚠️ Error saving CSVs: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  return public_df, full_df
153
 
154
- def get_public_test_set() -> pd.DataFrame:
155
- """
156
- Load the public test set (without targets).
157
- Tries HF Hub → local CSV → regenerate.
158
- """
159
  # 1) Try HF Hub
160
  try:
161
- print("📥 Attempting to load public test set from HF Hub...")
162
- ds = load_dataset(TEST_SET_DATASET, split="train", token=HF_TOKEN)
163
  df = ds.to_pandas()
164
- print(f"✅ Loaded public test set from HF Hub ({len(df):,} samples)")
165
- return df
 
 
 
 
 
 
 
 
166
  except Exception as e:
167
  print(f"⚠️ HF Hub load failed: {e}")
168
 
@@ -170,150 +442,351 @@ def get_public_test_set() -> pd.DataFrame:
170
  if os.path.exists(LOCAL_PUBLIC_CSV):
171
  try:
172
  df = pd.read_csv(LOCAL_PUBLIC_CSV)
173
- print(f"✅ Loaded public test set from local CSV ({len(df):,} samples)")
174
- # Validate basic structure
175
- required_cols = ['sample_id', 'source_text', 'source_language', 'target_language']
176
  if all(col in df.columns for col in required_cols):
 
177
  return df
178
  else:
179
  print("⚠️ Local CSV has invalid structure, regenerating...")
180
  except Exception as e:
181
- print(f"⚠️ Failed to read local CSV: {e}")
182
 
183
  # 3) Regenerate & save
184
- print("🔄 Generating new public test set...")
185
- public_df, _ = _generate_and_save_test_set()
186
  return public_df
187
 
188
- def get_complete_test_set() -> pd.DataFrame:
189
- """
190
- Load the complete test set (with targets).
191
- Tries HF Hub-private → local CSV → regenerate.
192
- """
193
  # 1) Try HF Hub private
194
  try:
195
- print("📥 Attempting to load complete test set from HF Hub-private...")
196
- ds = load_dataset(TEST_SET_DATASET + "-private", split="train", token=HF_TOKEN)
197
  df = ds.to_pandas()
198
- print(f"✅ Loaded complete test set from HF Hub-private ({len(df):,} samples)")
199
- return df
 
 
 
 
 
 
 
200
  except Exception as e:
201
- print(f"⚠️ HF Hub-private load failed: {e}")
202
 
203
  # 2) Try local CSV
204
  if os.path.exists(LOCAL_COMPLETE_CSV):
205
  try:
206
  df = pd.read_csv(LOCAL_COMPLETE_CSV)
207
- print(f"✅ Loaded complete test set from local CSV ({len(df):,} samples)")
208
- # Validate basic structure
209
- required_cols = ['sample_id', 'source_text', 'target_text', 'source_language', 'target_language']
210
  if all(col in df.columns for col in required_cols):
 
211
  return df
212
  else:
213
- print("⚠️ Local CSV has invalid structure, regenerating...")
214
  except Exception as e:
215
- print(f"⚠️ Failed to read local complete CSV: {e}")
216
 
217
  # 3) Regenerate & save
218
- print("🔄 Generating new complete test set...")
219
- _, complete_df = _generate_and_save_test_set()
220
  return complete_df
221
 
222
- def create_test_set_download() -> tuple[str, dict]:
223
- """
224
- Create a CSV download of the public test set and return its path + stats.
225
- """
226
- public_df = get_public_test_set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  if public_df.empty:
229
- # Create minimal stats for empty dataset
230
  stats = {
231
- 'total_samples': 0,
232
- 'language_pairs': 0,
233
- 'google_comparable_samples': 0,
234
- 'languages': [],
235
- 'domains': []
236
  }
237
  return LOCAL_PUBLIC_CSV, stats
238
 
239
  download_path = LOCAL_PUBLIC_CSV
 
240
  # Ensure the CSV is up-to-date
241
  try:
242
  public_df.to_csv(download_path, index=False)
243
  except Exception as e:
244
- print(f"⚠️ Error updating CSV: {e}")
245
 
246
- # Calculate statistics
247
  try:
 
248
  stats = {
249
- 'total_samples': len(public_df),
250
- 'language_pairs': len(public_df.groupby(['source_language', 'target_language'])),
251
- 'google_comparable_samples': int(public_df['google_comparable'].sum()) if 'google_comparable' in public_df.columns else 0,
252
- 'languages': sorted(list(set(public_df['source_language']).union(public_df['target_language']))),
253
- 'domains': public_df['domain'].unique().tolist() if 'domain' in public_df.columns else ['general']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  }
 
255
  except Exception as e:
256
- print(f"⚠️ Error calculating stats: {e}")
257
  stats = {
258
- 'total_samples': len(public_df),
259
- 'language_pairs': 0,
260
- 'google_comparable_samples': 0,
261
- 'languages': [],
262
- 'domains': []
263
  }
264
 
265
  return download_path, stats
266
 
267
- def validate_test_set_integrity() -> dict:
268
- """
269
- Validate test set coverage and integrity.
270
- """
271
  try:
272
- public_df = get_public_test_set()
273
- complete_df = get_complete_test_set()
274
 
275
  if public_df.empty or complete_df.empty:
276
  return {
277
- 'alignment_check': False,
278
- 'total_samples': 0,
279
- 'coverage_by_pair': {},
280
- 'missing_pairs': [],
281
- 'error': 'Test sets are empty or could not be loaded'
282
  }
283
 
284
- public_ids = set(public_df['sample_id'])
285
- private_ids = set(complete_df['sample_id'])
286
 
287
- coverage_by_pair = {}
288
- for src in ALL_UG40_LANGUAGES:
289
- for tgt in ALL_UG40_LANGUAGES:
290
- if src == tgt:
291
- continue
292
- subset = public_df[
293
- (public_df['source_language'] == src) &
294
- (public_df['target_language'] == tgt)
295
- ]
296
- count = len(subset)
297
- coverage_by_pair[f"{src}_{tgt}"] = {
298
- 'count': count,
299
- 'has_samples': count >= MIN_SAMPLES_PER_PAIR
300
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
  return {
303
- 'alignment_check': public_ids <= private_ids,
304
- 'total_samples': len(public_df),
305
- 'coverage_by_pair': coverage_by_pair,
306
- 'missing_pairs': [k for k, v in coverage_by_pair.items() if not v['has_samples']],
307
- 'public_samples': len(public_df),
308
- 'private_samples': len(complete_df),
309
- 'id_alignment_rate': len(public_ids & private_ids) / len(public_ids) if public_ids else 0.0
 
310
  }
311
 
312
  except Exception as e:
313
  return {
314
- 'alignment_check': False,
315
- 'total_samples': 0,
316
- 'coverage_by_pair': {},
317
- 'missing_pairs': [],
318
- 'error': f'Validation failed: {str(e)}'
319
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os
3
  import pandas as pd
4
  import yaml
5
+ import numpy as np
6
  from datasets import load_dataset
7
  from config import (
8
  TEST_SET_DATASET,
9
  SALT_DATASET,
10
  MAX_TEST_SAMPLES,
11
  HF_TOKEN,
 
12
  ALL_UG40_LANGUAGES,
13
+ GOOGLE_SUPPORTED_LANGUAGES,
14
+ EVALUATION_TRACKS,
15
+ SAMPLE_SIZE_RECOMMENDATIONS,
16
+ STATISTICAL_CONFIG,
17
  )
18
  import salt.dataset
19
+ from src.utils import get_all_language_pairs, get_track_language_pairs
20
 
21
  # Local CSV filenames for persistence
22
+ LOCAL_PUBLIC_CSV = "salt_test_set_scientific.csv"
23
+ LOCAL_COMPLETE_CSV = "salt_complete_test_set_scientific.csv"
24
+ LOCAL_TRACK_CSVS = {
25
+ track: f"salt_test_set_{track}.csv" for track in EVALUATION_TRACKS.keys()
26
+ }
27
 
28
+
29
+ def generate_scientific_test_set(
30
+ max_samples_per_pair: int = MAX_TEST_SAMPLES,
31
+ stratified_sampling: bool = True,
32
+ balance_tracks: bool = True,
33
+ ) -> pd.DataFrame:
34
+ """Generate scientifically rigorous test set with stratified sampling."""
35
+
36
+ print("🔬 Generating scientific SALT test set...")
37
 
38
  try:
39
+ # Build SALT dataset config
40
+ dataset_config = f"""
41
  huggingface_load:
42
  path: {SALT_DATASET}
43
  name: text-all
 
49
  type: text
50
  language: {ALL_UG40_LANGUAGES}
51
  allow_same_src_and_tgt_language: False
52
+ """
53
 
54
  config = yaml.safe_load(dataset_config)
55
  print("📥 Loading SALT dataset...")
 
60
  test_samples = []
61
  sample_id_counter = 1
62
 
63
+ # Calculate target samples per track for balanced evaluation
64
+ track_targets = calculate_track_sampling_targets(balance_tracks)
65
+
66
+ # Generate samples for each language pair with stratified sampling
67
  for src_lang in ALL_UG40_LANGUAGES:
68
  for tgt_lang in ALL_UG40_LANGUAGES:
69
  if src_lang == tgt_lang:
70
  continue
71
+
72
+ # Determine target sample size for this pair
73
+ pair_targets = calculate_pair_sampling_targets(
74
+ src_lang, tgt_lang, track_targets, max_samples_per_pair
75
+ )
76
+
77
+ target_samples = max(pair_targets.values()) if pair_targets else max_samples_per_pair
78
+
79
  # Filter for this language pair
80
  pair_data = full_data[
81
+ (full_data["source.language"] == src_lang) &
82
+ (full_data["target.language"] == tgt_lang)
83
  ]
84
 
85
  if pair_data.empty:
86
  print(f"⚠️ No data found for {src_lang} → {tgt_lang}")
87
  continue
88
 
89
+ # Stratified sampling if enabled
90
+ if stratified_sampling and len(pair_data) > target_samples:
91
+ sampled = stratified_sample_pair_data(pair_data, target_samples)
92
+ else:
93
+ # Simple random sampling
94
+ n_samples = min(len(pair_data), target_samples)
95
+ sampled = pair_data.sample(n=n_samples, random_state=42)
96
 
97
+ print(f"✅ {src_lang} → {tgt_lang}: {len(sampled)} samples")
98
 
99
  for _, row in sampled.iterrows():
100
+ # Determine which tracks include this pair
101
+ tracks_included = []
102
+ for track_name, track_config in EVALUATION_TRACKS.items():
103
+ if (src_lang in track_config["languages"] and
104
+ tgt_lang in track_config["languages"]):
105
+ tracks_included.append(track_name)
106
+
107
  test_samples.append({
108
+ "sample_id": f"salt_{sample_id_counter:06d}",
109
+ "source_text": row["source"],
110
+ "target_text": row["target"],
111
+ "source_language": src_lang,
112
+ "target_language": tgt_lang,
113
+ "domain": row.get("domain", "general"),
114
+ "google_comparable": (
115
  src_lang in GOOGLE_SUPPORTED_LANGUAGES and
116
  tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
117
+ ),
118
+ "tracks_included": ",".join(tracks_included),
119
+ "statistical_weight": calculate_statistical_weight(
120
+ src_lang, tgt_lang, tracks_included
121
+ ),
122
  })
123
  sample_id_counter += 1
124
 
 
126
 
127
  if test_df.empty:
128
  raise ValueError("No test samples generated - check SALT dataset availability")
 
 
129
 
130
+ # Validate scientific adequacy
131
+ adequacy_report = validate_test_set_scientific_adequacy(test_df)
 
132
 
133
+ print(f" Generated scientific test set: {len(test_df):,} samples")
134
+ print(f"📈 Test set adequacy: {adequacy_report['overall_adequacy']}")
 
 
 
135
 
136
  return test_df
137
 
138
  except Exception as e:
139
+ print(f"❌ Error generating scientific test set: {e}")
 
140
  return pd.DataFrame(columns=[
141
+ "sample_id", "source_text", "target_text", "source_language",
142
+ "target_language", "domain", "google_comparable", "tracks_included",
143
+ "statistical_weight"
144
  ])
145
 
146
+
147
def calculate_track_sampling_targets(balance_tracks: bool) -> dict[str, int]:
    """Calculate the total sample target for each evaluation track.

    For every entry in ``EVALUATION_TRACKS`` the number of ordered language
    pairs is ``n * (n - 1)`` for ``n`` track languages, and the track target
    is ``target_per_pair * n_pairs``.

    Args:
        balance_tracks: When true, the per-pair target is the larger of the
            track's ``min_samples_per_pair`` and
            ``SAMPLE_SIZE_RECOMMENDATIONS["publication_quality"]`` split
            evenly across the track's pairs. When false, the per-pair target
            is simply ``min_samples_per_pair``.

    Returns:
        Mapping of track name to the total number of samples targeted.
    """
    track_targets: dict[str, int] = {}

    for track_name, track_config in EVALUATION_TRACKS.items():
        min_per_pair = track_config["min_samples_per_pair"]

        # Ordered (source, target) pairs with source != target.
        n_languages = len(track_config["languages"])
        n_pairs = n_languages * (n_languages - 1)

        # A track with fewer than two languages has no pairs; skip the
        # publication-quality split there to avoid dividing by zero.
        if balance_tracks and n_pairs > 0:
            target_per_pair = max(
                min_per_pair,
                SAMPLE_SIZE_RECOMMENDATIONS["publication_quality"] // n_pairs,
            )
        else:
            target_per_pair = min_per_pair

        track_targets[track_name] = target_per_pair * n_pairs

        print(f"📊 {track_name}: targeting {target_per_pair} samples/pair × {n_pairs} pairs = {track_targets[track_name]} total")

    return track_targets
174
+
175
+
176
def calculate_pair_sampling_targets(
    src_lang: str, tgt_lang: str, track_targets: dict[str, int], max_samples: int
) -> dict[str, int]:
    """Calculate the per-track sampling target for one language pair.

    Args:
        src_lang: Source language code of the pair.
        tgt_lang: Target language code of the pair.
        track_targets: Total target per track, as produced by
            ``calculate_track_sampling_targets``.
        max_samples: Upper bound applied to every per-pair target.

    Returns:
        Mapping of track name to the pair's target, containing only the
        tracks in ``EVALUATION_TRACKS`` whose language list includes both
        ``src_lang`` and ``tgt_lang``.

    Raises:
        KeyError: If a matching track is missing from ``track_targets``.
    """
    pair_targets: dict[str, int] = {}

    for track_name, track_config in EVALUATION_TRACKS.items():
        languages = track_config["languages"]
        if src_lang not in languages or tgt_lang not in languages:
            continue

        n_pairs_in_track = len(languages) * (len(languages) - 1)
        if n_pairs_in_track == 0:
            # Single-language track: no translation pairs, and dividing the
            # track total by zero pairs would raise.
            continue

        target_per_pair = track_targets[track_name] // n_pairs_in_track
        pair_targets[track_name] = min(target_per_pair, max_samples)

    return pair_targets
193
+
194
+
195
def stratified_sample_pair_data(pair_data: pd.DataFrame, target_samples: int) -> pd.DataFrame:
    """Sample up to ``target_samples`` rows from one pair, stratified by domain.

    When a ``domain`` column with more than one distinct value is present,
    each domain contributes rows in proportion to its share of ``pair_data``
    (at least one row per domain). If rounding leaves the sample short, it is
    topped up from rows that were NOT already selected, so the result never
    contains the same source row twice. Without usable domain information a
    plain random sample is drawn.

    Args:
        pair_data: Rows belonging to a single language pair.
        target_samples: Desired number of rows.

    Returns:
        A DataFrame with at most ``target_samples`` rows. The stratified path
        returns a freshly numbered index; the fallback path keeps the
        original row labels.
    """
    if target_samples <= 0 or pair_data.empty:
        return pair_data.iloc[0:0]

    if "domain" in pair_data.columns and pair_data["domain"].nunique() > 1:
        # Re-label rows 0..n-1 so labels are unique and stay attached to the
        # selected rows; the top-up step below relies on them to exclude rows
        # that were already drawn.
        pool = pair_data.reset_index(drop=True)
        total_rows = len(pool)

        parts = []
        for domain, count in pool["domain"].value_counts().items():
            domain_rows = pool[pool["domain"] == domain]

            # Proportional quota, at least one row, never more than available.
            quota = max(1, int(target_samples * (count / total_rows)))
            quota = min(quota, len(domain_rows))
            parts.append(domain_rows.sample(n=quota, random_state=42))

        # Keep the pool labels here (no ignore_index) so already-chosen rows
        # can be identified.
        chosen = pd.concat(parts)

        shortfall = target_samples - len(chosen)
        if shortfall > 0:
            leftover = pool.drop(index=chosen.index)
            if not leftover.empty:
                extra = leftover.sample(
                    n=min(shortfall, len(leftover)), random_state=42
                )
                chosen = pd.concat([chosen, extra])

        # The one-row-per-domain minimum can overshoot the target; trim it.
        return chosen.head(target_samples).reset_index(drop=True)

    # Fallback to simple random sampling
    return pair_data.sample(n=min(target_samples, len(pair_data)), random_state=42)
232
+
233
+
234
def calculate_statistical_weight(
    src_lang: str, tgt_lang: str, tracks_included: list[str]
) -> float:
    """Calculate the statistical weight of a sample from its track membership.

    The weight starts at 1.0, is multiplied by the number of tracks the pair
    belongs to, is multiplied by 1.5 when both languages are in
    ``GOOGLE_SUPPORTED_LANGUAGES``, and is finally capped at 5.0.

    Args:
        src_lang: Source language code.
        tgt_lang: Target language code.
        tracks_included: Names of the tracks that contain this pair.

    Returns:
        The weight, in the range ``[0.0, 5.0]``.
    """
    # NOTE(review): a pair that belongs to no track gets weight 0.0 because
    # of the multiplication by len(tracks_included) -- confirm that is the
    # intended treatment rather than falling back to the base weight.
    weight = 1.0 * len(tracks_included)

    # Google-comparable pairs are boosted (enable baseline comparison).
    if src_lang in GOOGLE_SUPPORTED_LANGUAGES and tgt_lang in GOOGLE_SUPPORTED_LANGUAGES:
        weight *= 1.5

    # Cap to a reasonable range.
    return min(weight, 5.0)
252
+
253
+
254
def validate_test_set_scientific_adequacy(test_df: pd.DataFrame) -> dict:
    """Validate that the test set meets the per-track sample requirements.

    For every track in ``EVALUATION_TRACKS`` the function counts samples for
    each ordered language pair, computes the share of pairs that reach the
    track's ``min_samples_per_pair``, and grades the track:
    ``>= 0.9`` excellent, ``>= 0.8`` good, ``>= 0.6`` fair, else insufficient.
    The track grades are then combined into one overall grade.

    Args:
        test_df: Test set with ``source_language`` and ``target_language``
            columns. ``google_comparable`` and ``domain`` are used for the
            summary statistics when present.

    Returns:
        A report dict with keys ``overall_adequacy``, ``track_adequacy``,
        ``issues``, ``recommendations`` and ``statistics``. For an empty
        frame the report is returned early with a single issue recorded.
    """
    adequacy_report = {
        "overall_adequacy": "insufficient",
        "track_adequacy": {},
        "issues": [],
        "recommendations": [],
        "statistics": {},
    }

    if test_df.empty:
        adequacy_report["issues"].append("Test set is empty")
        return adequacy_report

    # Count rows per (source, target) once instead of re-filtering the frame
    # for every pair of every track.
    pair_sizes = (
        test_df.groupby(["source_language", "target_language"]).size().to_dict()
    )

    track_adequacies = []

    for track_name, track_config in EVALUATION_TRACKS.items():
        track_languages = track_config["languages"]
        min_per_pair = track_config["min_samples_per_pair"]

        # Rows whose source and target are both track languages.
        in_track = (
            test_df["source_language"].isin(track_languages)
            & test_df["target_language"].isin(track_languages)
        )

        pair_counts = {
            f"{src}_{tgt}": int(pair_sizes.get((src, tgt), 0))
            for src in track_languages
            for tgt in track_languages
            if src != tgt
        }

        total_pairs = len(pair_counts)
        adequate_pairs = sum(
            1 for count in pair_counts.values() if count >= min_per_pair
        )
        adequacy_rate = adequate_pairs / max(total_pairs, 1)

        if adequacy_rate >= 0.9:
            track_adequacy = "excellent"
        elif adequacy_rate >= 0.8:
            track_adequacy = "good"
        elif adequacy_rate >= 0.6:
            track_adequacy = "fair"
        else:
            track_adequacy = "insufficient"

        adequacy_report["track_adequacy"][track_name] = {
            "adequacy": track_adequacy,
            "adequacy_rate": adequacy_rate,
            "total_samples": int(in_track.sum()),
            "total_pairs": total_pairs,
            "adequate_pairs": adequate_pairs,
            "min_samples_per_pair": min_per_pair,
            "pair_counts": pair_counts,
        }

        track_adequacies.append(track_adequacy)

        if track_adequacy == "insufficient":
            inadequate_pairs = [k for k, v in pair_counts.items() if v < min_per_pair]
            adequacy_report["issues"].append(
                f"{track_name}: {len(inadequate_pairs)} pairs below minimum"
            )

    # Overall adequacy assessment. With no tracks configured there is nothing
    # to certify, so the grade stays "insufficient" rather than becoming a
    # vacuous "excellent" via all([]).
    if not track_adequacies:
        adequacy_report["overall_adequacy"] = "insufficient"
    elif all(level in ("excellent", "good") for level in track_adequacies):
        adequacy_report["overall_adequacy"] = "excellent"
    elif all(level in ("excellent", "good", "fair") for level in track_adequacies):
        adequacy_report["overall_adequacy"] = "good"
    elif any(level in ("good", "fair") for level in track_adequacies):
        adequacy_report["overall_adequacy"] = "fair"
    else:
        adequacy_report["overall_adequacy"] = "insufficient"

    # Optional columns: a frame loaded from an older CSV may lack them, so
    # fall back to neutral values instead of raising KeyError.
    if "google_comparable" in test_df.columns:
        google_comparable_samples = int(test_df["google_comparable"].sum())
    else:
        google_comparable_samples = 0

    if "domain" in test_df.columns:
        domain_distribution = test_df["domain"].value_counts().to_dict()
    else:
        domain_distribution = {}

    adequacy_report["statistics"] = {
        "total_samples": len(test_df),
        "total_language_pairs": len(pair_sizes),
        "google_comparable_samples": google_comparable_samples,
        "domain_distribution": domain_distribution,
        "track_sample_distribution": {
            track: adequacy_report["track_adequacy"][track]["total_samples"]
            for track in EVALUATION_TRACKS
        },
    }

    if adequacy_report["overall_adequacy"] in ("insufficient", "fair"):
        adequacy_report["recommendations"].append(
            "Consider increasing sample size for better statistical power"
        )

    if google_comparable_samples < 1000:
        adequacy_report["recommendations"].append(
            "More Google-comparable samples recommended for baseline comparison"
        )

    return adequacy_report
363
+
364
+
365
def _generate_and_save_scientific_test_set() -> tuple[pd.DataFrame, pd.DataFrame]:
    """Generate the scientific test set and persist its CSV variants.

    Calls ``generate_scientific_test_set()`` and derives the public frame
    (everything except ``target_text``). It then writes, on a best-effort
    basis, the public CSV, the complete CSV and one public CSV per track in
    ``EVALUATION_TRACKS``. Write failures are reported and do not abort.

    Returns:
        ``(public_df, complete_df)``. When generation yields no rows, two
        empty frames with the expected columns are returned and nothing is
        written.
    """
    # Single source of truth for the two schemas (previously repeated inline).
    public_columns = [
        "sample_id", "source_text", "source_language",
        "target_language", "domain", "google_comparable",
        "tracks_included", "statistical_weight",
    ]
    complete_columns = [
        "sample_id", "source_text", "target_text", "source_language",
        "target_language", "domain", "google_comparable",
        "tracks_included", "statistical_weight",
    ]

    print("🔬 Generating and saving scientific test sets...")

    full_df = generate_scientific_test_set()

    if full_df.empty:
        print("❌ Failed to generate scientific test set")
        return pd.DataFrame(columns=public_columns), pd.DataFrame(columns=complete_columns)

    # Public version (no target_text)
    public_df = full_df[public_columns].copy()

    # Save main versions; persistence is best-effort, so report and continue.
    try:
        public_df.to_csv(LOCAL_PUBLIC_CSV, index=False)
        full_df.to_csv(LOCAL_COMPLETE_CSV, index=False)
        print(f"✅ Saved main test sets: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
    except Exception as e:
        print(f"⚠️ Error saving main CSVs: {e}")

    # Save track-specific versions for easier analysis
    for track_name, track_config in EVALUATION_TRACKS.items():
        try:
            track_languages = track_config["languages"]
            track_public = public_df[
                public_df["source_language"].isin(track_languages)
                & public_df["target_language"].isin(track_languages)
            ]

            track_filename = LOCAL_TRACK_CSVS[track_name]
            track_public.to_csv(track_filename, index=False)
            print(f"✅ Saved {track_name} track: {track_filename} ({len(track_public):,} samples)")
        except Exception as e:
            print(f"⚠️ Error saving {track_name} track CSV: {e}")

    return public_df, full_df
418
 
419
+
420
+ def get_public_test_set_scientific() -> pd.DataFrame:
421
+ """Load the scientific public test set with enhanced fallback logic."""
422
+
 
423
  # 1) Try HF Hub
424
  try:
425
+ print("📥 Attempting to load scientific test set from HF Hub...")
426
+ ds = load_dataset(TEST_SET_DATASET + "-scientific", split="train", token=HF_TOKEN)
427
  df = ds.to_pandas()
428
+
429
+ # Validate scientific structure
430
+ required_cols = ["sample_id", "source_text", "source_language", "target_language",
431
+ "tracks_included", "statistical_weight"]
432
+ if all(col in df.columns for col in required_cols):
433
+ print(f"✅ Loaded scientific test set from HF Hub ({len(df):,} samples)")
434
+ return df
435
+ else:
436
+ print("⚠️ HF Hub test set missing scientific columns, regenerating...")
437
+
438
  except Exception as e:
439
  print(f"⚠️ HF Hub load failed: {e}")
440
 
 
442
  if os.path.exists(LOCAL_PUBLIC_CSV):
443
  try:
444
  df = pd.read_csv(LOCAL_PUBLIC_CSV)
445
+ required_cols = ["sample_id", "source_text", "source_language", "target_language"]
 
 
446
  if all(col in df.columns for col in required_cols):
447
+ print(f"✅ Loaded scientific test set from local CSV ({len(df):,} samples)")
448
  return df
449
  else:
450
  print("⚠️ Local CSV has invalid structure, regenerating...")
451
  except Exception as e:
452
+ print(f"⚠️ Failed to read local scientific CSV: {e}")
453
 
454
  # 3) Regenerate & save
455
+ print("🔄 Generating new scientific test set...")
456
+ public_df, _ = _generate_and_save_scientific_test_set()
457
  return public_df
458
 
459
+
460
def get_complete_test_set_scientific() -> pd.DataFrame:
    """Return the complete scientific test set (including ``target_text``).

    Sources are tried in order: the ``-scientific-private`` dataset on the
    HF Hub, then the local complete CSV, and finally a fresh generation that
    also saves the CSVs. A source is skipped when loading fails or when its
    required columns are missing.
    """
    hub_columns = {
        "sample_id", "source_text", "target_text", "source_language",
        "target_language", "tracks_included", "statistical_weight",
    }
    local_columns = {
        "sample_id", "source_text", "target_text",
        "source_language", "target_language",
    }

    # 1) Try HF Hub private
    try:
        print("📥 Attempting to load complete scientific test set from HF Hub...")
        hub_frame = load_dataset(
            TEST_SET_DATASET + "-scientific-private", split="train", token=HF_TOKEN
        ).to_pandas()

        if hub_columns.issubset(hub_frame.columns):
            print(f"✅ Loaded complete scientific test set from HF Hub ({len(hub_frame):,} samples)")
            return hub_frame
        print("⚠️ HF Hub complete test set missing scientific columns, regenerating...")
    except Exception as err:
        print(f"⚠️ HF Hub private load failed: {err}")

    # 2) Try local CSV
    if os.path.exists(LOCAL_COMPLETE_CSV):
        try:
            local_frame = pd.read_csv(LOCAL_COMPLETE_CSV)
            if not local_columns.issubset(local_frame.columns):
                print("⚠️ Local complete CSV has invalid structure, regenerating...")
            else:
                print(f"✅ Loaded complete scientific test set from local CSV ({len(local_frame):,} samples)")
                return local_frame
        except Exception as err:
            print(f"⚠️ Failed to read local complete scientific CSV: {err}")

    # 3) Regenerate & save
    print("🔄 Generating new complete scientific test set...")
    generated = _generate_and_save_scientific_test_set()
    return generated[1]
497
 
498
+
499
def get_track_test_set(track: str) -> pd.DataFrame:
    """Return the public test set restricted to one evaluation track.

    An unknown track yields an empty frame. Otherwise the track's own CSV is
    used when it exists and can be read; failing that, the main public test
    set is loaded and filtered to rows whose source and target languages
    both belong to the track.
    """
    if track not in EVALUATION_TRACKS:
        print(f"❌ Unknown track: {track}")
        return pd.DataFrame()

    # Preferred source: the per-track CSV written during generation.
    csv_path = LOCAL_TRACK_CSVS.get(track)
    if csv_path and os.path.exists(csv_path):
        try:
            cached = pd.read_csv(csv_path)
            print(f"✅ Loaded {track} test set from track-specific CSV ({len(cached):,} samples)")
            return cached
        except Exception as err:
            print(f"⚠️ Failed to read {track} CSV: {err}")

    # Otherwise derive the track subset from the main public set.
    main_df = get_public_test_set_scientific()
    if main_df.empty:
        return pd.DataFrame()

    languages = EVALUATION_TRACKS[track]["languages"]
    source_ok = main_df["source_language"].isin(languages)
    target_ok = main_df["target_language"].isin(languages)
    subset = main_df[source_ok & target_ok]

    print(f"✅ Filtered {track} test set from main set ({len(subset):,} samples)")
    return subset
530
+
531
+
532
def create_test_set_download_scientific() -> Tuple[str, Dict]:
    """Refresh the public test-set CSV and describe it for download.

    Returns:
        A ``(csv_path, stats)`` tuple. ``csv_path`` is always
        ``LOCAL_PUBLIC_CSV``. ``stats`` is a reduced four-key summary when the
        public set is empty or when computing the statistics raises; otherwise
        it holds sample counts, languages, domains, a per-track breakdown,
        Google-comparable figures, the adequacy report and scientific metadata.
    """
    frame = get_public_test_set_scientific()

    # Nothing to export: report an empty, insufficient test set.
    if frame.empty:
        return LOCAL_PUBLIC_CSV, {
            "total_samples": 0,
            "track_breakdown": {},
            "adequacy_assessment": "insufficient",
            "scientific_metadata": {},
        }

    download_path = LOCAL_PUBLIC_CSV

    # Best-effort refresh of the CSV on disk; a failed write is only reported.
    try:
        frame.to_csv(download_path, index=False)
    except Exception as e:
        print(f"⚠️ Error updating scientific CSV: {e}")

    try:
        sources = frame["source_language"]
        targets = frame["target_language"]

        if "domain" in frame.columns:
            domains = frame["domain"].unique().tolist()
        else:
            domains = ["general"]

        stats = {
            "total_samples": len(frame),
            "languages": sorted(set(sources) | set(targets)),
            "domains": domains,
        }

        # Per-track view: only rows whose source AND target are track languages.
        track_breakdown = {}
        for track_name, track_config in EVALUATION_TRACKS.items():
            langs = track_config["languages"]
            in_track = frame[sources.isin(langs) & targets.isin(langs)]

            per_pair_floor = track_config["min_samples_per_pair"]
            # Adequacy compares the track total against the floor multiplied
            # by the number of ordered language pairs.
            required_total = per_pair_floor * len(langs) * (len(langs) - 1)

            track_breakdown[track_name] = {
                "name": track_config["name"],
                "total_samples": len(in_track),
                "language_pairs": len(in_track.groupby(["source_language", "target_language"])),
                "min_samples_per_pair": per_pair_floor,
                "statistical_adequacy": len(in_track) >= required_total,
            }

        stats["track_breakdown"] = track_breakdown

        # Share of samples flagged as comparable with Google.
        if "google_comparable" in frame.columns:
            comparable = frame["google_comparable"]
            stats["google_comparable_samples"] = int(comparable.sum())
            stats["google_comparable_rate"] = float(comparable.mean())
        else:
            stats["google_comparable_samples"] = 0
            stats["google_comparable_rate"] = 0.0

        # Overall adequacy verdict plus the full report behind it.
        adequacy_report = validate_test_set_scientific_adequacy(frame)
        stats["adequacy_assessment"] = adequacy_report["overall_adequacy"]
        stats["adequacy_details"] = adequacy_report

        adequate_tracks = [
            track for track, info in track_breakdown.items()
            if info["statistical_adequacy"]
        ]
        stats["scientific_metadata"] = {
            "stratified_sampling": True,
            "statistical_weighting": "statistical_weight" in frame.columns,
            "track_balanced": True,
            "confidence_level": STATISTICAL_CONFIG["confidence_level"],
            "recommended_for": adequate_tracks,
        }

    except Exception as e:
        # Any failure above discards the partial stats in favour of a minimal summary.
        print(f"⚠️ Error calculating scientific stats: {e}")
        stats = {
            "total_samples": len(frame),
            "track_breakdown": {},
            "adequacy_assessment": "unknown",
            "scientific_metadata": {},
        }

    return download_path, stats
617
 
618
+
619
def validate_test_set_integrity_scientific() -> Dict:
    """Cross-check the public and complete test sets, track by track.

    Returns:
        A report dict. On success it contains the ID alignment check, sample
        counts, per-track pair coverage, the adequacy report and an integrity
        score. If either set is empty, or anything raises, it contains a
        failed ``alignment_check``, zeroed/empty fields and an ``error``
        message instead.
    """

    def _failure(message: str) -> Dict:
        # Shape shared by the "empty input" and "exception" outcomes.
        return {
            "alignment_check": False,
            "total_samples": 0,
            "scientific_adequacy": {},
            "track_analysis": {},
            "error": message,
        }

    def _within_languages(frame: pd.DataFrame, langs) -> pd.DataFrame:
        # Keep rows whose source and target are both in ``langs``.
        keep = frame["source_language"].isin(langs) & frame["target_language"].isin(langs)
        return frame[keep]

    def _pair_size(frame: pd.DataFrame, src, tgt) -> int:
        # Number of rows for one ordered (source, target) pair.
        hits = (frame["source_language"] == src) & (frame["target_language"] == tgt)
        return len(frame[hits])

    try:
        public_df = get_public_test_set_scientific()
        complete_df = get_complete_test_set_scientific()

        if public_df.empty or complete_df.empty:
            return _failure("Test sets are empty or could not be loaded")

        public_ids = set(public_df["sample_id"])
        private_ids = set(complete_df["sample_id"])

        track_analysis = {}
        for track_name, track_config in EVALUATION_TRACKS.items():
            langs = track_config["languages"]
            floor = track_config["min_samples_per_pair"]

            track_public = _within_languages(public_df, langs)
            track_complete = _within_languages(complete_df, langs)

            # Coverage of every ordered pair of distinct track languages.
            pair_coverage = {}
            for src in langs:
                for tgt in langs:
                    if src == tgt:
                        continue

                    n_public = _pair_size(track_public, src, tgt)
                    n_complete = _pair_size(track_complete, src, tgt)

                    pair_coverage[f"{src}_{tgt}"] = {
                        "public_count": n_public,
                        "complete_count": n_complete,
                        "alignment": n_public == n_complete,
                        "meets_minimum": n_public >= floor,
                    }

            total_pairs = len(pair_coverage)
            adequate_pairs = len(
                [info for info in pair_coverage.values() if info["meets_minimum"]]
            )
            aligned_pairs = len(
                [info for info in pair_coverage.values() if info["alignment"]]
            )
            # A track with no pairs yields rates of 0 rather than dividing by zero.
            denominator = max(total_pairs, 1)

            track_analysis[track_name] = {
                "total_pairs": total_pairs,
                "adequate_pairs": adequate_pairs,
                "aligned_pairs": aligned_pairs,
                "adequacy_rate": adequate_pairs / denominator,
                "alignment_rate": aligned_pairs / denominator,
                "pair_coverage": pair_coverage,
                "statistical_power": calculate_track_statistical_power(track_public, track_config),
            }

        adequacy_report = validate_test_set_scientific_adequacy(public_df)

        if public_ids:
            id_alignment_rate = len(public_ids & private_ids) / len(public_ids)
        else:
            id_alignment_rate = 0.0

        return {
            # True when every public sample ID also exists in the complete set.
            "alignment_check": public_ids <= private_ids,
            "total_samples": len(public_df),
            "track_analysis": track_analysis,
            "scientific_adequacy": adequacy_report,
            "public_samples": len(public_df),
            "private_samples": len(complete_df),
            "id_alignment_rate": id_alignment_rate,
            "integrity_score": calculate_integrity_score(track_analysis, adequacy_report),
        }

    except Exception as e:
        return _failure(f"Validation failed: {str(e)}")
717
+
718
+
719
def calculate_track_statistical_power(track_data: pd.DataFrame, track_config: Dict) -> float:
    """Estimate statistical power for a track from its average samples per pair.

    This is a coarse tiered heuristic based on sample size, not a formal power
    analysis.

    Args:
        track_data: Samples belonging to the track. When non-empty it must
            have ``source_language`` and ``target_language`` columns.
        track_config: Track settings; ``min_samples_per_pair`` and
            ``languages`` are read.

    Returns:
        0.95, 0.8 or 0.6 when the average number of samples per ordered
        language pair reaches 4x, 2x or 1x the configured minimum; below the
        minimum, a value scaled linearly between 0.0 and 0.6. Returns 0.0 when
        ``track_data`` is empty or the track has no ordered language pairs.
        The result is always a plain Python ``float``.
    """
    if track_data.empty:
        return 0.0

    min_required = track_config["min_samples_per_pair"]
    languages = track_config["languages"]

    # Count every (source, target) combination in one grouped pass instead of
    # re-filtering the whole frame once per language pair.
    observed = (
        track_data.groupby(["source_language", "target_language"], sort=False)
        .size()
        .to_dict()
    )

    # Pairs with no samples still count (as 0) so they pull the average down.
    pair_counts = [
        observed.get((src, tgt), 0)
        for src in languages
        for tgt in languages
        if src != tgt
    ]

    if not pair_counts:
        return 0.0

    avg_samples_per_pair = float(np.mean(pair_counts))

    # Rough power estimation (0.8 power at 2x minimum, 0.95 at 4x minimum)
    if avg_samples_per_pair >= min_required * 4:
        return 0.95
    if avg_samples_per_pair >= min_required * 2:
        return 0.8
    if avg_samples_per_pair >= min_required:
        return 0.6

    # Reaching here means the average is below min_required, so min_required
    # is strictly positive and the division is safe.
    return float(max(0.0, avg_samples_per_pair / min_required * 0.6))
757
+
758
+
759
def calculate_integrity_score(track_analysis: Dict, adequacy_report: Dict) -> float:
    """Combine per-track quality and overall adequacy into one score.

    Args:
        track_analysis: Mapping of track name to a dict that may contain
            ``adequacy_rate`` and ``alignment_rate``; missing rates count as 0.0.
        adequacy_report: Dict whose ``overall_adequacy`` label ("excellent",
            "good", "fair" or "insufficient") is mapped to a numeric score.
            A missing or unrecognised label scores 0.2.

    Returns:
        The mean of (a) the average per-track score, where each track score is
        the mean of its adequacy and alignment rates, and (b) the overall
        adequacy score. Returns 0.0 if either argument is empty.
    """
    if not track_analysis or not adequacy_report:
        return 0.0

    # Numeric value assigned to each overall adequacy label.
    adequacy_mapping = {
        "excellent": 1.0,
        "good": 0.8,
        "fair": 0.6,
        "insufficient": 0.2,
    }
    adequacy_label = adequacy_report.get("overall_adequacy", "insufficient")
    overall_adequacy_score = adequacy_mapping.get(adequacy_label, 0.2)

    # track_analysis is non-empty past the guard above, so this list always
    # has at least one entry and no "no tracks" fallback is needed.
    track_scores = [
        (info.get("adequacy_rate", 0.0) + info.get("alignment_rate", 0.0)) / 2
        for info in track_analysis.values()
    ]
    track_avg = np.mean(track_scores)

    return float((track_avg + overall_adequacy_score) / 2)