akera committed on
Commit
aecc3e1
·
verified ·
1 Parent(s): cb7f64d

Update src/test_set.py

Browse files
Files changed (1) hide show
  1. src/test_set.py +219 -106
src/test_set.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import pandas as pd
3
  import yaml
@@ -18,84 +19,138 @@ from src.utils import get_all_language_pairs
18
  LOCAL_PUBLIC_CSV = "salt_test_set.csv"
19
  LOCAL_COMPLETE_CSV = "salt_complete_test_set.csv"
20
 
21
-
22
  def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
23
  """
24
  Generate standardized test set from the SALT dataset.
25
  """
26
  print("🔄 Generating SALT test set from source dataset...")
27
- # Build SALT dataset config
28
- dataset_config = f'''
29
- huggingface_load:
30
- path: {SALT_DATASET}
31
- name: text-all
32
- split: test
33
- source:
34
- type: text
35
- language: {ALL_UG40_LANGUAGES}
36
- target:
37
- type: text
38
- language: {ALL_UG40_LANGUAGES}
39
- allow_same_src_and_tgt_language: False
40
- '''
41
- config = yaml.safe_load(dataset_config)
42
- full_data = pd.DataFrame(salt.dataset.create(config))
43
-
44
- test_samples = []
45
- sample_id_counter = 1
46
-
47
- for src_lang in ALL_UG40_LANGUAGES:
48
- for tgt_lang in ALL_UG40_LANGUAGES:
49
- if src_lang == tgt_lang:
50
- continue
51
- pair_data = full_data[
52
- (full_data['source.language'] == src_lang) &
53
- (full_data['target.language'] == tgt_lang)
54
- ]
55
- if pair_data.empty:
56
- continue
57
-
58
- # Sample up to max_samples_per_pair
59
- n_samples = min(len(pair_data), max_samples_per_pair)
60
- sampled = pair_data.sample(n=n_samples, random_state=42)
61
-
62
- for _, row in sampled.iterrows():
63
- test_samples.append({
64
- 'sample_id': f"salt_{sample_id_counter:06d}",
65
- 'source_text': row['source'],
66
- 'target_text': row['target'],
67
- 'source_language': src_lang,
68
- 'target_language': tgt_lang,
69
- 'domain': row.get('domain', 'general'),
70
- 'google_comparable': (
71
- src_lang in GOOGLE_SUPPORTED_LANGUAGES and
72
- tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
73
- )
74
- })
75
- sample_id_counter += 1
76
-
77
- test_df = pd.DataFrame(test_samples)
78
- print(f"✅ Generated test set: {len(test_df):,} samples across {len(get_all_language_pairs()):,} pairs")
79
- return test_df
80
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- def _generate_and_save_test_set() -> (pd.DataFrame, pd.DataFrame):
83
  """
84
  Generate the full test set and persist both public and complete CSV files.
85
  """
 
 
86
  full_df = generate_test_set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  # Public version (no target_text)
88
  public_df = full_df[[
89
  'sample_id', 'source_text', 'source_language',
90
  'target_language', 'domain', 'google_comparable'
91
- ]]
92
- public_df.to_csv(LOCAL_PUBLIC_CSV, index=False)
93
- # Complete version (with target_text)
94
- full_df.to_csv(LOCAL_COMPLETE_CSV, index=False)
95
- print(f"✅ Saved local CSVs: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
 
 
 
 
 
96
  return public_df, full_df
97
 
98
-
99
  def get_public_test_set() -> pd.DataFrame:
100
  """
101
  Load the public test set (without targets).
@@ -103,28 +158,33 @@ def get_public_test_set() -> pd.DataFrame:
103
  """
104
  # 1) Try HF Hub
105
  try:
 
106
  ds = load_dataset(TEST_SET_DATASET, split="train", token=HF_TOKEN)
107
  df = ds.to_pandas()
108
  print(f"✅ Loaded public test set from HF Hub ({len(df):,} samples)")
109
  return df
110
  except Exception as e:
111
- print("⚠️ HF Hub load failed, falling back to local CSV:", e)
112
 
113
  # 2) Try local CSV
114
  if os.path.exists(LOCAL_PUBLIC_CSV):
115
  try:
116
  df = pd.read_csv(LOCAL_PUBLIC_CSV)
117
  print(f"✅ Loaded public test set from local CSV ({len(df):,} samples)")
118
- return df
 
 
 
 
 
119
  except Exception as e:
120
- print("⚠️ Failed to read local CSV, regenerating:", e)
121
 
122
  # 3) Regenerate & save
123
- print("🔄 Generating new public test set and saving to CSV...")
124
  public_df, _ = _generate_and_save_test_set()
125
  return public_df
126
 
127
-
128
  def get_complete_test_set() -> pd.DataFrame:
129
  """
130
  Load the complete test set (with targets).
@@ -132,75 +192,128 @@ def get_complete_test_set() -> pd.DataFrame:
132
  """
133
  # 1) Try HF Hub private
134
  try:
 
135
  ds = load_dataset(TEST_SET_DATASET + "-private", split="train", token=HF_TOKEN)
136
  df = ds.to_pandas()
137
  print(f"✅ Loaded complete test set from HF Hub-private ({len(df):,} samples)")
138
  return df
139
  except Exception as e:
140
- print("⚠️ HF Hub-private load failed, falling back to local CSV:", e)
141
 
142
  # 2) Try local CSV
143
  if os.path.exists(LOCAL_COMPLETE_CSV):
144
  try:
145
  df = pd.read_csv(LOCAL_COMPLETE_CSV)
146
  print(f"✅ Loaded complete test set from local CSV ({len(df):,} samples)")
147
- return df
 
 
 
 
 
148
  except Exception as e:
149
- print("⚠️ Failed to read local complete CSV, regenerating:", e)
150
 
151
  # 3) Regenerate & save
152
- print("🔄 Generating new complete test set and saving to CSV...")
153
  _, complete_df = _generate_and_save_test_set()
154
  return complete_df
155
 
156
-
157
- def create_test_set_download() -> (str, dict):
158
  """
159
  Create a CSV download of the public test set and return its path + stats.
160
  """
161
  public_df = get_public_test_set()
 
 
 
 
 
 
 
 
 
 
 
 
162
  download_path = LOCAL_PUBLIC_CSV
163
  # Ensure the CSV is up-to-date
164
- public_df.to_csv(download_path, index=False)
 
 
 
165
 
166
- stats = {
167
- 'total_samples': len(public_df),
168
- 'language_pairs': len(public_df.groupby(['source_language', 'target_language'])),
169
- 'google_comparable_samples': int(public_df['google_comparable'].sum()),
170
- 'languages': list(set(public_df['source_language']).union(public_df['target_language'])),
171
- 'domains': public_df['domain'].unique().tolist()
172
- }
 
 
 
 
 
 
 
 
 
 
 
 
173
  return download_path, stats
174
 
175
-
176
  def validate_test_set_integrity() -> dict:
177
  """
178
  Validate test set coverage and integrity.
179
  """
180
- public_df = get_public_test_set()
181
- complete_df = get_complete_test_set()
 
 
 
 
 
 
 
 
 
 
182
 
183
- public_ids = set(public_df['sample_id'])
184
- private_ids = set(complete_df['sample_id'])
185
 
186
- coverage_by_pair = {}
187
- for src in ALL_UG40_LANGUAGES:
188
- for tgt in ALL_UG40_LANGUAGES:
189
- if src == tgt:
190
- continue
191
- subset = public_df[
192
- (public_df['source_language'] == src) &
193
- (public_df['target_language'] == tgt)
194
- ]
195
- count = len(subset)
196
- coverage_by_pair[f"{src}_{tgt}"] = {
197
- 'count': count,
198
- 'has_samples': count >= MIN_SAMPLES_PER_PAIR
199
- }
200
 
201
- return {
202
- 'alignment_check': public_ids <= private_ids,
203
- 'total_samples': len(public_df),
204
- 'coverage_by_pair': coverage_by_pair,
205
- 'missing_pairs': [k for k, v in coverage_by_pair.items() if not v['has_samples']]
206
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/test_set.py
2
  import os
3
  import pandas as pd
4
  import yaml
 
19
  LOCAL_PUBLIC_CSV = "salt_test_set.csv"
20
  LOCAL_COMPLETE_CSV = "salt_complete_test_set.csv"
21
 
 
def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
    """
    Generate standardized test set from the SALT dataset.

    Loads the ``test`` split of ``SALT_DATASET`` (config ``text-all``) through
    ``salt.dataset.create`` and, for every ordered pair of distinct languages
    in ``ALL_UG40_LANGUAGES``, samples up to ``max_samples_per_pair`` rows with
    a fixed seed (``random_state=42``).

    Args:
        max_samples_per_pair: Upper bound on the rows sampled per language pair.

    Returns:
        DataFrame with columns ``sample_id``, ``source_text``, ``target_text``,
        ``source_language``, ``target_language``, ``domain`` and
        ``google_comparable``. Any exception raised while loading or sampling
        (including the "no samples" check) is printed and an empty DataFrame
        with those columns is returned instead of raising.
    """
    print("🔄 Generating SALT test set from source dataset...")

    try:
        # Build SALT dataset config - using 'test' split for consistency.
        # NOTE(review): the key nesting below was reconstructed from a view
        # that had lost its indentation - confirm against the SALT config schema.
        dataset_config = f'''
            huggingface_load:
              path: {SALT_DATASET}
              name: text-all
              split: test
            source:
              type: text
              language: {ALL_UG40_LANGUAGES}
            target:
              type: text
              language: {ALL_UG40_LANGUAGES}
            allow_same_src_and_tgt_language: False
            '''

        config = yaml.safe_load(dataset_config)
        print("📥 Loading SALT dataset...")
        full_data = pd.DataFrame(salt.dataset.create(config))

        print(f"📊 Loaded {len(full_data):,} samples from SALT dataset")

        # Partition the frame once instead of re-scanning it with a boolean
        # mask for every language pair. groupby keeps the original row order
        # inside each group, so the seeded sample below sees the same rows in
        # the same order as a mask-based filter would.
        frames_by_pair = {
            pair: frame
            for pair, frame in full_data.groupby(
                ['source.language', 'target.language'], sort=False
            )
        }

        test_samples = []
        sample_id_counter = 1

        # Walk the pairs in ALL_UG40_LANGUAGES order so sample ids are assigned
        # deterministically.
        for src_lang in ALL_UG40_LANGUAGES:
            for tgt_lang in ALL_UG40_LANGUAGES:
                if src_lang == tgt_lang:
                    continue

                pair_data = frames_by_pair.get((src_lang, tgt_lang))
                if pair_data is None or pair_data.empty:
                    print(f"⚠️ No data found for {src_lang} → {tgt_lang}")
                    continue

                # Sample up to max_samples_per_pair
                n_samples = min(len(pair_data), max_samples_per_pair)
                sampled = pair_data.sample(n=n_samples, random_state=42)

                print(f"✅ {src_lang} → {tgt_lang}: {n_samples} samples")

                google_comparable = (
                    src_lang in GOOGLE_SUPPORTED_LANGUAGES and
                    tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
                )
                for _, row in sampled.iterrows():
                    test_samples.append({
                        'sample_id': f"salt_{sample_id_counter:06d}",
                        'source_text': row['source'],
                        'target_text': row['target'],
                        'source_language': src_lang,
                        'target_language': tgt_lang,
                        'domain': row.get('domain', 'general'),
                        'google_comparable': google_comparable,
                    })
                    sample_id_counter += 1

        test_df = pd.DataFrame(test_samples)

        if test_df.empty:
            # Handled by the except below: reported, then an empty frame is returned.
            raise ValueError("No test samples generated - check SALT dataset availability")

        # Computed once and reused by both the summary line and the statistics.
        unique_pairs = test_df.groupby(['source_language', 'target_language']).ngroups
        google_samples = test_df['google_comparable'].sum()

        print(f"✅ Generated test set: {len(test_df):,} samples across {unique_pairs:,} pairs")

        print("📈 Test set statistics:")
        print(f" - Total samples: {len(test_df):,}")
        print(f" - Language pairs: {unique_pairs}")
        print(f" - Google comparable: {google_samples:,} samples")
        print(f" - UG40 only: {len(test_df) - google_samples:,} samples")

        return test_df

    except Exception as e:
        # Best-effort boundary: callers check `.empty` rather than catching.
        print(f"❌ Error generating test set: {e}")
        # Return empty DataFrame with correct structure
        return pd.DataFrame(columns=[
            'sample_id', 'source_text', 'target_text', 'source_language',
            'target_language', 'domain', 'google_comparable'
        ])
116
 
def _generate_and_save_test_set() -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Generate the full test set and persist both public and complete CSV files.

    Returns:
        ``(public_df, full_df)``. ``public_df`` is a copy of ``full_df``
        without ``target_text``. When generation yields an empty frame, two
        empty frames with the expected columns are returned and nothing is
        written. A failure while writing the CSVs is printed, not raised.
    """
    public_columns = [
        'sample_id', 'source_text', 'source_language',
        'target_language', 'domain', 'google_comparable'
    ]
    complete_columns = [
        'sample_id', 'source_text', 'target_text', 'source_language',
        'target_language', 'domain', 'google_comparable'
    ]

    print("🔄 Generating and saving test sets...")
    full_df = generate_test_set()

    # Guard clause: nothing to persist, hand back correctly shaped empties.
    if full_df.empty:
        print("❌ Failed to generate test set")
        return pd.DataFrame(columns=public_columns), pd.DataFrame(columns=complete_columns)

    # Public version (no target_text)
    public_df = full_df[public_columns].copy()

    # Save both versions; a write error is only reported.
    try:
        public_df.to_csv(LOCAL_PUBLIC_CSV, index=False)
        full_df.to_csv(LOCAL_COMPLETE_CSV, index=False)
        print(f"✅ Saved local CSVs: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
    except Exception as e:
        print(f"⚠️ Error saving CSVs: {e}")

    return public_df, full_df
153
 
 
def get_public_test_set() -> pd.DataFrame:
    """
    Load the public test set (without targets).

    Sources are tried in order:
      1. the ``TEST_SET_DATASET`` dataset (``train`` split) via ``load_dataset``;
      2. the local CSV at ``LOCAL_PUBLIC_CSV``, accepted only if it has the
         ``sample_id``, ``source_text``, ``source_language`` and
         ``target_language`` columns;
      3. regeneration through ``_generate_and_save_test_set``.

    Returns:
        The public test set as a DataFrame.
    """
    # 1) Try HF Hub
    try:
        print("📥 Attempting to load public test set from HF Hub...")
        ds = load_dataset(TEST_SET_DATASET, split="train", token=HF_TOKEN)
        df = ds.to_pandas()
        print(f"✅ Loaded public test set from HF Hub ({len(df):,} samples)")
        return df
    except Exception as e:
        print(f"⚠️ HF Hub load failed: {e}")

    # 2) Try local CSV
    if os.path.exists(LOCAL_PUBLIC_CSV):
        try:
            df = pd.read_csv(LOCAL_PUBLIC_CSV)
        except Exception as e:
            print(f"⚠️ Failed to read local CSV: {e}")
        else:
            # Validate basic structure before reporting success, so a rejected
            # file is never logged as "loaded".
            required_cols = ['sample_id', 'source_text', 'source_language', 'target_language']
            if all(col in df.columns for col in required_cols):
                print(f"✅ Loaded public test set from local CSV ({len(df):,} samples)")
                return df
            print("⚠️ Local CSV has invalid structure, regenerating...")

    # 3) Regenerate & save
    print("🔄 Generating new public test set...")
    public_df, _ = _generate_and_save_test_set()
    return public_df
187
 
 
def get_complete_test_set() -> pd.DataFrame:
    """
    Load the complete test set (with targets).

    Sources are tried in order:
      1. the ``TEST_SET_DATASET + "-private"`` dataset (``train`` split) via
         ``load_dataset``;
      2. the local CSV at ``LOCAL_COMPLETE_CSV``, accepted only if it has the
         ``sample_id``, ``source_text``, ``target_text``, ``source_language``
         and ``target_language`` columns;
      3. regeneration through ``_generate_and_save_test_set``.

    Returns:
        The complete test set as a DataFrame.
    """
    # 1) Try HF Hub private
    try:
        print("📥 Attempting to load complete test set from HF Hub-private...")
        ds = load_dataset(TEST_SET_DATASET + "-private", split="train", token=HF_TOKEN)
        df = ds.to_pandas()
        print(f"✅ Loaded complete test set from HF Hub-private ({len(df):,} samples)")
        return df
    except Exception as e:
        print(f"⚠️ HF Hub-private load failed: {e}")

    # 2) Try local CSV
    if os.path.exists(LOCAL_COMPLETE_CSV):
        try:
            df = pd.read_csv(LOCAL_COMPLETE_CSV)
        except Exception as e:
            print(f"⚠️ Failed to read local complete CSV: {e}")
        else:
            # Validate basic structure before reporting success, so a rejected
            # file is never logged as "loaded".
            required_cols = ['sample_id', 'source_text', 'target_text', 'source_language', 'target_language']
            if all(col in df.columns for col in required_cols):
                print(f"✅ Loaded complete test set from local CSV ({len(df):,} samples)")
                return df
            print("⚠️ Local CSV has invalid structure, regenerating...")

    # 3) Regenerate & save
    print("🔄 Generating new complete test set...")
    _, complete_df = _generate_and_save_test_set()
    return complete_df
221
 
def create_test_set_download() -> tuple[str, dict]:
    """
    Create a CSV download of the public test set and return its path + stats.

    Returns:
        ``(path, stats)`` where ``path`` is ``LOCAL_PUBLIC_CSV`` and ``stats``
        has the keys ``total_samples``, ``language_pairs``,
        ``google_comparable_samples``, ``languages`` and ``domains``. An empty
        test set yields zeroed stats without touching the CSV. Errors while
        writing the CSV or computing the stats are printed, not raised.
    """
    public_df = get_public_test_set()
    download_path = LOCAL_PUBLIC_CSV

    # Empty dataset: report minimal stats and skip the rewrite.
    if public_df.empty:
        return download_path, {
            'total_samples': 0,
            'language_pairs': 0,
            'google_comparable_samples': 0,
            'languages': [],
            'domains': []
        }

    # Ensure the CSV is up-to-date
    try:
        public_df.to_csv(download_path, index=False)
    except Exception as e:
        print(f"⚠️ Error updating CSV: {e}")

    # Calculate statistics; any failure falls back to the sample count only.
    try:
        available = public_df.columns
        pair_total = public_df.groupby(['source_language', 'target_language']).ngroups

        if 'google_comparable' in available:
            google_total = int(public_df['google_comparable'].sum())
        else:
            google_total = 0

        seen_languages = set(public_df['source_language']) | set(public_df['target_language'])

        if 'domain' in available:
            domain_values = public_df['domain'].unique().tolist()
        else:
            domain_values = ['general']

        stats = {
            'total_samples': len(public_df),
            'language_pairs': pair_total,
            'google_comparable_samples': google_total,
            'languages': sorted(seen_languages),
            'domains': domain_values
        }
    except Exception as e:
        print(f"⚠️ Error calculating stats: {e}")
        stats = {
            'total_samples': len(public_df),
            'language_pairs': 0,
            'google_comparable_samples': 0,
            'languages': [],
            'domains': []
        }

    return download_path, stats
266
 
 
def validate_test_set_integrity() -> dict:
    """
    Validate test set coverage and integrity.

    Loads the public and complete test sets, checks that every public
    ``sample_id`` also appears in the complete set, and counts the public
    samples for every ordered pair of distinct languages in
    ``ALL_UG40_LANGUAGES``.

    Returns:
        Dict with ``alignment_check``, ``total_samples``, ``coverage_by_pair``
        (``"{src}_{tgt}"`` -> ``{'count', 'has_samples'}``, where
        ``has_samples`` means ``count >= MIN_SAMPLES_PER_PAIR``),
        ``missing_pairs``, ``public_samples``, ``private_samples`` and
        ``id_alignment_rate``. If either set is empty, or any exception
        occurs, a reduced dict with an ``error`` message is returned instead
        of raising.
    """
    try:
        public_df = get_public_test_set()
        complete_df = get_complete_test_set()

        if public_df.empty or complete_df.empty:
            return {
                'alignment_check': False,
                'total_samples': 0,
                'coverage_by_pair': {},
                'missing_pairs': [],
                'error': 'Test sets are empty or could not be loaded'
            }

        public_ids = set(public_df['sample_id'])
        private_ids = set(complete_df['sample_id'])

        # Count rows per (source, target) in one pass rather than filtering
        # the whole frame once per language pair.
        pair_counts = (
            public_df.groupby(['source_language', 'target_language'])
            .size()
            .to_dict()
        )

        coverage_by_pair = {}
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src == tgt:
                    continue
                # int() keeps the value a plain Python int, as len() would give.
                count = int(pair_counts.get((src, tgt), 0))
                coverage_by_pair[f"{src}_{tgt}"] = {
                    'count': count,
                    'has_samples': count >= MIN_SAMPLES_PER_PAIR
                }

        return {
            'alignment_check': public_ids <= private_ids,
            'total_samples': len(public_df),
            'coverage_by_pair': coverage_by_pair,
            'missing_pairs': [k for k, v in coverage_by_pair.items() if not v['has_samples']],
            'public_samples': len(public_df),
            'private_samples': len(complete_df),
            'id_alignment_rate': len(public_ids & private_ids) / len(public_ids) if public_ids else 0.0
        }

    except Exception as e:
        return {
            'alignment_check': False,
            'total_samples': 0,
            'coverage_by_pair': {},
            'missing_pairs': [],
            'error': f'Validation failed: {str(e)}'
        }