# src/test_set.py
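"""
Test set generation and loading for the SALT translation benchmark.

Builds a reproducible public test set (source texts only) and a complete
test set (with target references) from the SALT dataset, persists both as
local CSVs, and falls back between the HF Hub, local files, and
regeneration when loading.
"""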
import os
import pandas as pd
import yaml
from datasets import load_dataset
from config import (
    TEST_SET_DATASET,
    SALT_DATASET,
    MAX_TEST_SAMPLES,
    HF_TOKEN,
    MIN_SAMPLES_PER_PAIR,
    ALL_UG40_LANGUAGES,
    GOOGLE_SUPPORTED_LANGUAGES
)
import salt.dataset

# Local CSV filenames for persistence
LOCAL_PUBLIC_CSV = "salt_test_set.csv"
LOCAL_COMPLETE_CSV = "salt_complete_test_set.csv"

def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
    """
    Generate a standardized test set from the SALT dataset.

    Samples up to `max_samples_per_pair` examples per ordered UG40 language
    pair (fixed random seed, so output is reproducible). Returns an empty
    DataFrame with the expected columns if generation fails.
    """
    print("πŸ”„ Generating SALT test set from source dataset...")
    
    try:
        # Build SALT dataset config - using 'test' split for consistency
        dataset_config = f'''
        huggingface_load:
          path: {SALT_DATASET}
          name: text-all
          split: test
        source:
          type: text
          language: {ALL_UG40_LANGUAGES}
        target:
          type: text
          language: {ALL_UG40_LANGUAGES}
        allow_same_src_and_tgt_language: False
        '''
        
        config = yaml.safe_load(dataset_config)
        print("πŸ“₯ Loading SALT dataset...")
        full_data = pd.DataFrame(salt.dataset.create(config))
        
        print(f"πŸ“Š Loaded {len(full_data):,} samples from SALT dataset")
        
        test_samples = []
        sample_id_counter = 1
        
        # Generate samples for each language pair
        for src_lang in ALL_UG40_LANGUAGES:
            for tgt_lang in ALL_UG40_LANGUAGES:
                if src_lang == tgt_lang:
                    continue
                    
                # Filter for this language pair
                pair_data = full_data[
                    (full_data['source.language'] == src_lang) & 
                    (full_data['target.language'] == tgt_lang)
                ]
                
                if pair_data.empty:
                    print(f"⚠️  No data found for {src_lang} β†’ {tgt_lang}")
                    continue
                
                # Sample up to max_samples_per_pair
                n_samples = min(len(pair_data), max_samples_per_pair)
                sampled = pair_data.sample(n=n_samples, random_state=42)
                
                print(f"βœ… {src_lang} β†’ {tgt_lang}: {n_samples} samples")
                
                for _, row in sampled.iterrows():
                    test_samples.append({
                        'sample_id': f"salt_{sample_id_counter:06d}",
                        'source_text': row['source'],
                        'target_text': row['target'],
                        'source_language': src_lang,
                        'target_language': tgt_lang,
                        'domain': row.get('domain', 'general'),
                        'google_comparable': (
                            src_lang in GOOGLE_SUPPORTED_LANGUAGES and
                            tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
                        )
                    })
                    sample_id_counter += 1
        
        test_df = pd.DataFrame(test_samples)
        
        if test_df.empty:
            raise ValueError("No test samples generated - check SALT dataset availability")
            
        print(f"βœ… Generated test set: {len(test_df):,} samples across {len(test_df.groupby(['source_language', 'target_language'])):,} pairs")
        
        # Add some statistics
        google_samples = test_df['google_comparable'].sum()
        unique_pairs = len(test_df.groupby(['source_language', 'target_language']))
        
        print(f"πŸ“ˆ Test set statistics:")
        print(f"   - Total samples: {len(test_df):,}")
        print(f"   - Language pairs: {unique_pairs}")
        print(f"   - Google comparable: {google_samples:,} samples")
        print(f"   - UG40 only: {len(test_df) - google_samples:,} samples")
        
        return test_df
        
    except Exception as e:
        print(f"❌ Error generating test set: {e}")
        # Return empty DataFrame with correct structure
        return pd.DataFrame(columns=[
            'sample_id', 'source_text', 'target_text', 'source_language', 
            'target_language', 'domain', 'google_comparable'
        ])

def _generate_and_save_test_set() -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Generate the full test set and persist both public and complete CSV files.
    """
    print("πŸ”„ Generating and saving test sets...")
    
    full_df = generate_test_set()
    
    if full_df.empty:
        print("❌ Failed to generate test set")
        # Return empty DataFrames with correct structure
        empty_public = pd.DataFrame(columns=[
            'sample_id', 'source_text', 'source_language',
            'target_language', 'domain', 'google_comparable'
        ])
        empty_complete = pd.DataFrame(columns=[
            'sample_id', 'source_text', 'target_text', 'source_language',
            'target_language', 'domain', 'google_comparable'
        ])
        return empty_public, empty_complete
    
    # Public version (no target_text)
    public_df = full_df[[
        'sample_id', 'source_text', 'source_language',
        'target_language', 'domain', 'google_comparable'
    ]].copy()
    
    # Save both versions
    try:
        public_df.to_csv(LOCAL_PUBLIC_CSV, index=False)
        full_df.to_csv(LOCAL_COMPLETE_CSV, index=False)
        print(f"βœ… Saved local CSVs: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
    except Exception as e:
        print(f"⚠️  Error saving CSVs: {e}")
    
    return public_df, full_df

def get_public_test_set() -> pd.DataFrame:
    """
    Load the public test set (without targets).
    Tries HF Hub β†’ local CSV β†’ regenerate.
    """
    # 1) Try HF Hub
    try:
        print("πŸ“₯ Attempting to load public test set from HF Hub...")
        ds = load_dataset(TEST_SET_DATASET, split="train", token=HF_TOKEN)
        df = ds.to_pandas()
        print(f"βœ… Loaded public test set from HF Hub ({len(df):,} samples)")
        return df
    except Exception as e:
        print(f"⚠️  HF Hub load failed: {e}")

    # 2) Try local CSV
    if os.path.exists(LOCAL_PUBLIC_CSV):
        try:
            df = pd.read_csv(LOCAL_PUBLIC_CSV)
            print(f"βœ… Loaded public test set from local CSV ({len(df):,} samples)")
            # Validate basic structure
            required_cols = ['sample_id', 'source_text', 'source_language', 'target_language']
            if all(col in df.columns for col in required_cols):
                return df
            else:
                print("⚠️  Local CSV has invalid structure, regenerating...")
        except Exception as e:
            print(f"⚠️  Failed to read local CSV: {e}")

    # 3) Regenerate & save
    print("πŸ”„ Generating new public test set...")
    public_df, _ = _generate_and_save_test_set()
    return public_df

def get_complete_test_set() -> pd.DataFrame:
    """
    Load the complete test set (with targets).
    Tries HF Hub-private β†’ local CSV β†’ regenerate.
    """
    # 1) Try HF Hub private
    try:
        print("πŸ“₯ Attempting to load complete test set from HF Hub-private...")
        ds = load_dataset(TEST_SET_DATASET + "-private", split="train", token=HF_TOKEN)
        df = ds.to_pandas()
        print(f"βœ… Loaded complete test set from HF Hub-private ({len(df):,} samples)")
        return df
    except Exception as e:
        print(f"⚠️  HF Hub-private load failed: {e}")

    # 2) Try local CSV
    if os.path.exists(LOCAL_COMPLETE_CSV):
        try:
            df = pd.read_csv(LOCAL_COMPLETE_CSV)
            print(f"βœ… Loaded complete test set from local CSV ({len(df):,} samples)")
            # Validate basic structure
            required_cols = ['sample_id', 'source_text', 'target_text', 'source_language', 'target_language']
            if all(col in df.columns for col in required_cols):
                return df
            else:
                print("⚠️  Local CSV has invalid structure, regenerating...")
        except Exception as e:
            print(f"⚠️  Failed to read local complete CSV: {e}")

    # 3) Regenerate & save
    print("πŸ”„ Generating new complete test set...")
    _, complete_df = _generate_and_save_test_set()
    return complete_df

def create_test_set_download() -> tuple[str, dict]:
    """
    Create a CSV download of the public test set and return its path + stats.
    """
    public_df = get_public_test_set()
    
    if public_df.empty:
        # Create minimal stats for empty dataset
        stats = {
            'total_samples': 0,
            'language_pairs': 0,
            'google_comparable_samples': 0,
            'languages': [],
            'domains': []
        }
        return LOCAL_PUBLIC_CSV, stats
    
    download_path = LOCAL_PUBLIC_CSV
    # Ensure the CSV is up-to-date
    try:
        public_df.to_csv(download_path, index=False)
    except Exception as e:
        print(f"⚠️  Error updating CSV: {e}")

    # Calculate statistics
    try:
        stats = {
            'total_samples': len(public_df),
            'language_pairs': len(public_df.groupby(['source_language', 'target_language'])),
            'google_comparable_samples': int(public_df['google_comparable'].sum()) if 'google_comparable' in public_df.columns else 0,
            'languages': sorted(set(public_df['source_language']).union(public_df['target_language'])),
            'domains': public_df['domain'].unique().tolist() if 'domain' in public_df.columns else ['general']
        }
    except Exception as e:
        print(f"⚠️  Error calculating stats: {e}")
        stats = {
            'total_samples': len(public_df),
            'language_pairs': 0,
            'google_comparable_samples': 0,
            'languages': [],
            'domains': []
        }
    
    return download_path, stats

def validate_test_set_integrity() -> dict:
    """
    Validate test set coverage and integrity.
    """
    try:
        public_df = get_public_test_set()
        complete_df = get_complete_test_set()
        
        if public_df.empty or complete_df.empty:
            return {
                'alignment_check': False,
                'total_samples': 0,
                'coverage_by_pair': {},
                'missing_pairs': [],
                'error': 'Test sets are empty or could not be loaded'
            }

        public_ids = set(public_df['sample_id'])
        private_ids = set(complete_df['sample_id'])

        coverage_by_pair = {}
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src == tgt:
                    continue
                subset = public_df[
                    (public_df['source_language'] == src) &
                    (public_df['target_language'] == tgt)
                ]
                count = len(subset)
                coverage_by_pair[f"{src}_{tgt}"] = {
                    'count': count,
                    'has_samples': count >= MIN_SAMPLES_PER_PAIR
                }

        return {
            # Subset check: every public sample id must also exist in the private set
            'alignment_check': public_ids <= private_ids,
            'total_samples': len(public_df),
            'coverage_by_pair': coverage_by_pair,
            'missing_pairs': [k for k, v in coverage_by_pair.items() if not v['has_samples']],
            'public_samples': len(public_df),
            'private_samples': len(complete_df),
            'id_alignment_rate': len(public_ids & private_ids) / len(public_ids) if public_ids else 0.0
        }
        
    except Exception as e:
        return {
            'alignment_check': False,
            'total_samples': 0,
            'coverage_by_pair': {},
            'missing_pairs': [],
            'error': f'Validation failed: {str(e)}'
        }
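
if __name__ == "__main__":
    # Minimal smoke test for this module (a sketch, not part of the original
    # pipeline). Assumes config.py is importable and that the SALT dataset
    # and, optionally, HF_TOKEN are reachable from the current environment.
    path, stats = create_test_set_download()
    print(f"πŸ“ Public test set CSV: {path}")
    print(f"πŸ“Š Stats: {stats}")

    report = validate_test_set_integrity()
    print(f"πŸ”Ž Alignment check passed: {report.get('alignment_check', False)}")
    if report.get('missing_pairs'):
        print(f"⚠️  Pairs below MIN_SAMPLES_PER_PAIR: {len(report['missing_pairs'])}")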