akera committed on
Commit
57c7739
·
verified ·
1 Parent(s): 7a9f2cb

Update src/test_set.py

Browse files
Files changed (1) hide show
  1. src/test_set.py +164 -162
src/test_set.py CHANGED
@@ -1,19 +1,30 @@
1
- # src/test_set.py
2
  import pandas as pd
3
  import yaml
4
- from datasets import Dataset, load_dataset
5
- from typing import Dict, Tuple
6
- from config import *
 
 
 
 
 
 
 
7
  import salt.dataset
8
-
9
  from src.utils import get_all_language_pairs
10
 
 
 
 
 
 
11
  def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
12
- """Generate standardized test set from SALT dataset."""
13
-
14
- print("Generating SALT test set...")
15
-
16
- # Load full SALT dataset
17
  dataset_config = f'''
18
  huggingface_load:
19
  path: {SALT_DATASET}
@@ -27,178 +38,169 @@ def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFr
27
  language: {ALL_UG40_LANGUAGES}
28
  allow_same_src_and_tgt_language: False
29
  '''
30
-
31
  config = yaml.safe_load(dataset_config)
32
  full_data = pd.DataFrame(salt.dataset.create(config))
33
-
34
- # Sample data for each language pair
35
  test_samples = []
36
  sample_id_counter = 1
37
-
38
  for src_lang in ALL_UG40_LANGUAGES:
39
  for tgt_lang in ALL_UG40_LANGUAGES:
40
- if src_lang != tgt_lang:
41
- # Filter for this language pair
42
- pair_data = full_data[
43
- (full_data['source.language'] == src_lang) &
44
- (full_data['target.language'] == tgt_lang)
45
- ].copy()
46
-
47
- if len(pair_data) > 0:
48
- # Sample up to max_samples_per_pair
49
- n_samples = min(len(pair_data), max_samples_per_pair)
50
- sampled = pair_data.sample(n=n_samples, random_state=42)
51
-
52
- # Add to test set with unique IDs
53
- for _, row in sampled.iterrows():
54
- test_samples.append({
55
- 'sample_id': f"salt_{sample_id_counter:06d}",
56
- 'source_text': row['source'],
57
- 'target_text': row['target'], # Hidden from public test set
58
- 'source_language': src_lang,
59
- 'target_language': tgt_lang,
60
- 'domain': row.get('domain', 'general'),
61
- 'google_comparable': (src_lang in GOOGLE_SUPPORTED_LANGUAGES and
62
- tgt_lang in GOOGLE_SUPPORTED_LANGUAGES)
63
- })
64
- sample_id_counter += 1
65
-
 
 
66
  test_df = pd.DataFrame(test_samples)
67
-
68
- print(f"Generated test set with {len(test_df)} samples across {len(get_all_language_pairs())} language pairs")
69
-
70
  return test_df
71
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def get_public_test_set() -> pd.DataFrame:
74
- """Get public test set (sources only, no targets)."""
75
-
 
 
 
76
  try:
77
- # Try to load existing test set
78
- print(f"Loading test set from: {TEST_SET_DATASET}")
79
- dataset = load_dataset(TEST_SET_DATASET, split='train')
80
- test_df = dataset.to_pandas()
81
- print(f"Loaded existing test set with {len(test_df)} samples")
82
-
83
  except Exception as e:
84
- print(f"Could not load existing test set: {e}")
85
- print("This is expected for first run. Generating new test set...")
86
-
87
- # Generate new test set
88
- test_df = generate_test_set()
89
-
90
- # Save complete test set (with targets) privately
91
- print("Saving test set for future use...")
92
  try:
93
- save_complete_test_set(test_df)
94
- except Exception as save_error:
95
- print(f"Warning: Could not save test set: {save_error}")
96
- print("Continuing with generated test set...")
97
-
98
- # Return public version (without targets)
99
- public_columns = [
100
- 'sample_id', 'source_text', 'source_language',
101
- 'target_language', 'domain', 'google_comparable'
102
- ]
103
-
104
- return test_df[public_columns].copy()
105
 
106
  def get_complete_test_set() -> pd.DataFrame:
107
- """Get complete test set with targets (for evaluation)."""
108
-
 
 
 
109
  try:
110
- # Load from private storage or regenerate
111
- dataset = load_dataset(TEST_SET_DATASET + "-private", split='train')
112
- return dataset.to_pandas()
113
-
114
  except Exception as e:
115
- print(f"Regenerating complete test set: {e}")
116
- return generate_test_set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- def save_complete_test_set(test_df: pd.DataFrame) -> bool:
119
- """Save complete test set to HuggingFace dataset."""
120
-
121
- try:
122
- # Save public version (no targets)
123
- public_df = test_df[[
124
- 'sample_id', 'source_text', 'source_language',
125
- 'target_language', 'domain', 'google_comparable'
126
- ]].copy()
127
-
128
- public_dataset = Dataset.from_pandas(public_df)
129
- public_dataset.push_to_hub(
130
- TEST_SET_DATASET,
131
- token=HF_TOKEN,
132
- commit_message="Update public test set"
133
- )
134
-
135
- # Save private version (with targets)
136
- private_dataset = Dataset.from_pandas(test_df)
137
- private_dataset.push_to_hub(
138
- TEST_SET_DATASET + "-private",
139
- token=HF_TOKEN,
140
- private=True,
141
- commit_message="Update private test set with targets"
142
- )
143
-
144
- print("Test sets saved successfully!")
145
- return True
146
-
147
- except Exception as e:
148
- print(f"Error saving test sets: {e}")
149
- return False
150
-
151
- def create_test_set_download() -> Tuple[str, Dict]:
152
- """Create downloadable test set file and statistics."""
153
-
154
- public_test = get_public_test_set()
155
-
156
- # Create download file
157
- download_path = "salt_test_set.csv"
158
- public_test.to_csv(download_path, index=False)
159
-
160
- # Generate statistics
161
  stats = {
162
- 'total_samples': len(public_test),
163
- 'language_pairs': len(public_test.groupby(['source_language', 'target_language'])),
164
- 'google_comparable_samples': len(public_test[public_test['google_comparable'] == True]),
165
- 'languages': list(set(public_test['source_language'].unique()) | set(public_test['target_language'].unique())),
166
- 'domains': list(public_test['domain'].unique()) if 'domain' in public_test.columns else ['general']
167
  }
168
-
169
  return download_path, stats
170
 
171
- def validate_test_set_integrity() -> Dict:
172
- """Validate test set integrity and coverage."""
173
-
174
- try:
175
- public_test = get_public_test_set()
176
- complete_test = get_complete_test_set()
177
-
178
- # Check alignment
179
- public_ids = set(public_test['sample_id'])
180
- private_ids = set(complete_test['sample_id'])
181
-
182
- coverage_by_pair = {}
183
- for src in ALL_UG40_LANGUAGES:
184
- for tgt in ALL_UG40_LANGUAGES:
185
- if src != tgt:
186
- pair_samples = public_test[
187
- (public_test['source_language'] == src) &
188
- (public_test['target_language'] == tgt)
189
- ]
190
-
191
- coverage_by_pair[f"{src}_{tgt}"] = {
192
- 'count': len(pair_samples),
193
- 'has_samples': len(pair_samples) >= MIN_SAMPLES_PER_PAIR
194
- }
195
-
196
- return {
197
- 'alignment_check': len(public_ids - private_ids) == 0,
198
- 'total_samples': len(public_test),
199
- 'coverage_by_pair': coverage_by_pair,
200
- 'missing_pairs': [k for k, v in coverage_by_pair.items() if not v['has_samples']]
201
- }
202
-
203
- except Exception as e:
204
- return {'error': str(e)}
 
1
+ import os
2
  import pandas as pd
3
  import yaml
4
+ from datasets import load_dataset
5
+ from config import (
6
+ TEST_SET_DATASET,
7
+ SALT_DATASET,
8
+ MAX_TEST_SAMPLES,
9
+ HF_TOKEN,
10
+ MIN_SAMPLES_PER_PAIR,
11
+ ALL_UG40_LANGUAGES,
12
+ GOOGLE_SUPPORTED_LANGUAGES
13
+ )
14
  import salt.dataset
 
15
  from src.utils import get_all_language_pairs
16
 
17
+ # Local CSV filenames for persistence (written when the test set is regenerated, read back as the fallback when the HF Hub load fails)
18
+ LOCAL_PUBLIC_CSV = "salt_test_set.csv"  # public split: saved without the target_text column
19
+ LOCAL_COMPLETE_CSV = "salt_complete_test_set.csv"  # complete split: saved with target_text
20
+
21
+
22
  def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
23
+ """
24
+ Generate standardized test set from the SALT dataset.
25
+ """
26
+ print("πŸ”„ Generating SALT test set from source dataset...")
27
+ # Build SALT dataset config
28
  dataset_config = f'''
29
  huggingface_load:
30
  path: {SALT_DATASET}
 
38
  language: {ALL_UG40_LANGUAGES}
39
  allow_same_src_and_tgt_language: False
40
  '''
 
41
  config = yaml.safe_load(dataset_config)
42
  full_data = pd.DataFrame(salt.dataset.create(config))
43
+
 
44
  test_samples = []
45
  sample_id_counter = 1
46
+
47
  for src_lang in ALL_UG40_LANGUAGES:
48
  for tgt_lang in ALL_UG40_LANGUAGES:
49
+ if src_lang == tgt_lang:
50
+ continue
51
+ pair_data = full_data[
52
+ (full_data['source.language'] == src_lang) &
53
+ (full_data['target.language'] == tgt_lang)
54
+ ]
55
+ if pair_data.empty:
56
+ continue
57
+
58
+ # Sample up to max_samples_per_pair
59
+ n_samples = min(len(pair_data), max_samples_per_pair)
60
+ sampled = pair_data.sample(n=n_samples, random_state=42)
61
+
62
+ for _, row in sampled.iterrows():
63
+ test_samples.append({
64
+ 'sample_id': f"salt_{sample_id_counter:06d}",
65
+ 'source_text': row['source'],
66
+ 'target_text': row['target'],
67
+ 'source_language': src_lang,
68
+ 'target_language': tgt_lang,
69
+ 'domain': row.get('domain', 'general'),
70
+ 'google_comparable': (
71
+ src_lang in GOOGLE_SUPPORTED_LANGUAGES and
72
+ tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
73
+ )
74
+ })
75
+ sample_id_counter += 1
76
+
77
  test_df = pd.DataFrame(test_samples)
78
+ print(f"βœ… Generated test set: {len(test_df):,} samples across {len(get_all_language_pairs()):,} pairs")
 
 
79
  return test_df
80
 
81
 
def _generate_and_save_test_set() -> "tuple[pd.DataFrame, pd.DataFrame]":
    """
    Generate the full test set and persist both public and complete CSV files.

    Calls ``generate_test_set()`` with its default sample cap, then writes
    the target-free projection to ``LOCAL_PUBLIC_CSV`` and the full frame to
    ``LOCAL_COMPLETE_CSV`` (both without the index column).

    Returns:
        ``(public_df, full_df)``: the public frame (no ``target_text``)
        followed by the complete frame (with ``target_text``).

    Raises:
        KeyError: if the generated frame lacks any of the public columns,
            which is what happens when no samples were produced at all.
    """
    full_df = generate_test_set()

    # Public version (no target_text)
    public_columns = [
        'sample_id', 'source_text', 'source_language',
        'target_language', 'domain', 'google_comparable',
    ]
    # .copy() hands callers an independent frame instead of a slice of
    # full_df, so later in-place edits cannot trigger SettingWithCopy issues.
    public_df = full_df[public_columns].copy()
    public_df.to_csv(LOCAL_PUBLIC_CSV, index=False)

    # Complete version (with target_text)
    full_df.to_csv(LOCAL_COMPLETE_CSV, index=False)

    print(f"✅ Saved local CSVs: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
    return public_df, full_df
97
+
98
+
def get_public_test_set() -> pd.DataFrame:
    """
    Load the public test set (without targets).

    Sources are tried in order: the HF Hub dataset ``TEST_SET_DATASET``,
    then the local ``LOCAL_PUBLIC_CSV`` file, then regeneration through
    ``_generate_and_save_test_set()``. A failure at one stage is printed
    and the next stage is tried.

    Whatever the source, a ``target_text`` column is dropped if present so
    reference translations can never leak through the public accessor.

    Returns:
        DataFrame of public test samples.
    """
    # 1) Try HF Hub
    try:
        ds = load_dataset(TEST_SET_DATASET, split="train", token=HF_TOKEN)
        df = ds.to_pandas().drop(columns=["target_text"], errors="ignore")
        print(f"✅ Loaded public test set from HF Hub ({len(df):,} samples)")
        return df
    except Exception as e:
        # Deliberately broad: any hub failure (auth, network, missing repo)
        # should fall through to the local copy rather than abort.
        print("⚠️ HF Hub load failed, falling back to local CSV:", e)

    # 2) Try local CSV
    if os.path.exists(LOCAL_PUBLIC_CSV):
        try:
            df = pd.read_csv(LOCAL_PUBLIC_CSV).drop(
                columns=["target_text"], errors="ignore"
            )
            print(f"✅ Loaded public test set from local CSV ({len(df):,} samples)")
            return df
        except Exception as e:
            print("⚠️ Failed to read local CSV, regenerating:", e)

    # 3) Regenerate & save (the helper already returns a target-free frame)
    print("🔄 Generating new public test set and saving to CSV...")
    public_df, _ = _generate_and_save_test_set()
    return public_df
126
+
 
127
 
def get_complete_test_set() -> pd.DataFrame:
    """
    Load the complete test set (with targets).

    Sources are tried in order: the private HF Hub dataset
    (``TEST_SET_DATASET + "-private"``), then the local
    ``LOCAL_COMPLETE_CSV`` file, then regeneration through
    ``_generate_and_save_test_set()``.

    A loaded frame is only accepted when it carries a ``target_text``
    column; a frame without it is reported as a failure and the next
    source is tried.

    Returns:
        DataFrame of test samples including ``target_text``.
    """
    # 1) Try HF Hub private
    try:
        ds = load_dataset(TEST_SET_DATASET + "-private", split="train", token=HF_TOKEN)
        df = ds.to_pandas()
        if "target_text" not in df.columns:
            # Routed through the except below so it shares the fallback path.
            raise ValueError("dataset has no 'target_text' column")
        print(f"✅ Loaded complete test set from HF Hub-private ({len(df):,} samples)")
        return df
    except Exception as e:
        # Deliberately broad: any hub failure falls through to the local copy.
        print("⚠️ HF Hub-private load failed, falling back to local CSV:", e)

    # 2) Try local CSV
    if os.path.exists(LOCAL_COMPLETE_CSV):
        try:
            df = pd.read_csv(LOCAL_COMPLETE_CSV)
            if "target_text" not in df.columns:
                raise ValueError("CSV has no 'target_text' column")
            print(f"✅ Loaded complete test set from local CSV ({len(df):,} samples)")
            return df
        except Exception as e:
            print("⚠️ Failed to read local complete CSV, regenerating:", e)

    # 3) Regenerate & save
    print("🔄 Generating new complete test set and saving to CSV...")
    _, complete_df = _generate_and_save_test_set()
    return complete_df
155
+
156
+
def create_test_set_download() -> "tuple[str, dict]":
    """
    Create a CSV download of the public test set and return its path + stats.

    The public frame from ``get_public_test_set()`` is (re)written to
    ``LOCAL_PUBLIC_CSV`` so the file on disk matches what was loaded.

    Returns:
        ``(download_path, stats)`` where ``stats`` holds:
        ``total_samples``, ``language_pairs`` (distinct source/target
        combinations), ``google_comparable_samples``, ``languages``
        (sorted, so the output is deterministic) and ``domains``
        (``['general']`` when the frame has no ``domain`` column).
    """
    public_df = get_public_test_set()
    download_path = LOCAL_PUBLIC_CSV
    # Ensure the CSV is up-to-date
    public_df.to_csv(download_path, index=False)

    languages = (
        set(public_df['source_language'].dropna())
        | set(public_df['target_language'].dropna())
    )

    # A hub-hosted copy may predate the domain column; fall back instead of
    # raising KeyError.
    if 'domain' in public_df.columns:
        domains = public_df['domain'].unique().tolist()
    else:
        domains = ['general']

    stats = {
        'total_samples': len(public_df),
        'language_pairs': int(
            public_df.groupby(['source_language', 'target_language']).ngroups
        ),
        'google_comparable_samples': int(public_df['google_comparable'].sum()),
        'languages': sorted(languages),
        'domains': domains,
    }
    return download_path, stats
174
 
175
+
def validate_test_set_integrity() -> dict:
    """
    Validate test set coverage and integrity.

    Loads the public and complete test sets and reports, for every ordered
    pair of distinct languages in ``ALL_UG40_LANGUAGES``, how many public
    samples exist and whether that meets ``MIN_SAMPLES_PER_PAIR``.

    Returns:
        dict with:
        ``alignment_check``: True when every public ``sample_id`` also
        appears in the complete set;
        ``total_samples``: number of public rows;
        ``coverage_by_pair``: ``{"src_tgt": {"count", "has_samples"}}``;
        ``missing_pairs``: keys of pairs below the minimum.
    """
    public_df = get_public_test_set()
    complete_df = get_complete_test_set()

    public_ids = set(public_df['sample_id'])
    private_ids = set(complete_df['sample_id'])

    # Count every (source, target) pair in one pass instead of filtering the
    # whole frame once per pair.
    pair_counts = (
        public_df.groupby(['source_language', 'target_language']).size().to_dict()
    )

    coverage_by_pair = {}
    for src in ALL_UG40_LANGUAGES:
        for tgt in ALL_UG40_LANGUAGES:
            if src == tgt:
                continue
            # int() keeps the value a plain Python int, as len() produced.
            count = int(pair_counts.get((src, tgt), 0))
            coverage_by_pair[f"{src}_{tgt}"] = {
                'count': count,
                'has_samples': count >= MIN_SAMPLES_PER_PAIR
            }

    return {
        'alignment_check': public_ids <= private_ids,
        'total_samples': len(public_df),
        'coverage_by_pair': coverage_by_pair,
        'missing_pairs': [k for k, v in coverage_by_pair.items() if not v['has_samples']]
    }