File size: 13,222 Bytes
aecc3e1
57c7739
c1926c2
 
37b3c92
57c7739
 
 
 
 
 
 
37b3c92
 
d82b528
57c7739
cc0d353
d82b528
fe8454f
 
 
57c7739
d82b528
 
57c7739
37b3c92
d82b528
 
37b3c92
d82b528
aecc3e1
 
37b3c92
 
aecc3e1
 
 
 
 
 
 
 
 
 
 
37b3c92
aecc3e1
 
 
 
 
 
 
 
 
 
d82b528
aecc3e1
 
 
 
37b3c92
aecc3e1
 
37b3c92
 
aecc3e1
 
 
 
 
 
d82b528
 
 
aecc3e1
37b3c92
aecc3e1
 
 
37b3c92
 
 
 
 
 
 
aecc3e1
 
37b3c92
aecc3e1
 
 
 
 
 
 
 
d82b528
aecc3e1
 
 
 
d82b528
aecc3e1
37b3c92
d82b528
aecc3e1
8003b5b
37b3c92
d82b528
 
37b3c92
d82b528
37b3c92
d82b528
aecc3e1
 
d82b528
aecc3e1
37b3c92
d82b528
aecc3e1
 
37b3c92
d82b528
aecc3e1
 
 
57c7739
 
37b3c92
d82b528
aecc3e1
 
d82b528
aecc3e1
 
 
d82b528
aecc3e1
d82b528
aecc3e1
57c7739
 
37b3c92
d82b528
 
37b3c92
57c7739
c1926c2
d82b528
 
57c7739
37b3c92
d82b528
 
37b3c92
d82b528
37b3c92
 
d82b528
37b3c92
c1926c2
aecc3e1
57c7739
 
 
8003b5b
57c7739
37b3c92
aecc3e1
d82b528
aecc3e1
 
 
57c7739
d82b528
57c7739
 
d82b528
 
57c7739
 
37b3c92
d82b528
 
37b3c92
57c7739
c1926c2
d82b528
 
57c7739
37b3c92
d82b528
37b3c92
d82b528
37b3c92
 
d82b528
37b3c92
c1926c2
37b3c92
57c7739
 
 
 
 
37b3c92
aecc3e1
d82b528
aecc3e1
 
37b3c92
57c7739
d82b528
57c7739
 
d82b528
 
57c7739
 
37b3c92
d82b528
 
37b3c92
d82b528
aecc3e1
 
 
37b3c92
 
d82b528
 
 
aecc3e1
 
 
57c7739
37b3c92
57c7739
aecc3e1
 
 
d82b528
c1926c2
37b3c92
aecc3e1
37b3c92
aecc3e1
37b3c92
 
d82b528
37b3c92
 
 
 
 
 
 
 
 
 
 
 
 
 
d82b528
37b3c92
 
 
 
 
 
 
 
 
 
aecc3e1
d82b528
aecc3e1
37b3c92
 
d82b528
 
 
aecc3e1
 
c1926c2
 
37b3c92
d82b528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37b3c92
aecc3e1
d82b528
 
aecc3e1
 
 
37b3c92
 
 
 
aecc3e1
57c7739
37b3c92
 
57c7739
37b3c92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d82b528
 
 
 
37b3c92
 
aecc3e1
37b3c92
 
 
 
 
 
aecc3e1
 
 
 
37b3c92
 
 
 
d82b528
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
# src/test_set.py
import os
import pandas as pd
import yaml
import numpy as np
from datasets import load_dataset
from config import (
    TEST_SET_DATASET,
    SALT_DATASET,
    MAX_TEST_SAMPLES,
    HF_TOKEN,
    ALL_UG40_LANGUAGES,
    GOOGLE_SUPPORTED_LANGUAGES,
    EVALUATION_TRACKS,
    LANGUAGE_NAMES,
)
import salt.dataset
from src.utils import get_all_language_pairs
from typing import Dict, List, Optional, Tuple


# Local CSV filenames used to persist generated test sets between runs.
# The public file holds the test set without the "target_text" column;
# the complete file additionally includes the reference translations.
LOCAL_PUBLIC_CSV = "salt_test_set.csv"
LOCAL_COMPLETE_CSV = "salt_complete_test_set.csv"


def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
    """Sample a translation test set from the SALT dataset's test split.

    For every ordered pair of distinct languages in ``ALL_UG40_LANGUAGES``,
    at most ``max_samples_per_pair`` rows are drawn with a fixed random seed
    and numbered sequentially as ``salt_NNNNNN``.

    Args:
        max_samples_per_pair: Upper bound on rows sampled per language pair.

    Returns:
        A DataFrame with the columns ``sample_id``, ``source_text``,
        ``target_text``, ``source_language``, ``target_language``, ``domain``
        and ``google_comparable``. Any failure (including the case where no
        rows were produced) is printed and results in an empty frame with
        those columns instead of an exception.
    """

    print("πŸ”¬ Generating SALT test set...")

    try:
        # The SALT loader configuration is written as YAML and parsed below.
        dataset_config = f"""
        huggingface_load:
          path: {SALT_DATASET}
          name: text-all
          split: test
        source:
          type: text
          language: {ALL_UG40_LANGUAGES}
        target:
          type: text
          language: {ALL_UG40_LANGUAGES}
        allow_same_src_and_tgt_language: False
        """

        salt_config = yaml.safe_load(dataset_config)
        print("πŸ“₯ Loading SALT dataset...")
        full_data = pd.DataFrame(salt.dataset.create(salt_config))

        print(f"πŸ“Š Loaded {len(full_data):,} samples from SALT dataset")

        # Every ordered (source, target) combination of two different languages,
        # source-major so the id numbering follows the configured language order.
        ordered_pairs = [
            (source_code, target_code)
            for source_code in ALL_UG40_LANGUAGES
            for target_code in ALL_UG40_LANGUAGES
            if source_code != target_code
        ]

        records = []

        for src_lang, tgt_lang in ordered_pairs:
            source_matches = full_data["source.language"] == src_lang
            target_matches = full_data["target.language"] == tgt_lang
            pair_rows = full_data[source_matches & target_matches]

            if pair_rows.empty:
                print(f"⚠️  No data found for {src_lang} β†’ {tgt_lang}")
                continue

            # Fixed seed keeps the sampled subset reproducible across runs.
            sample_size = min(len(pair_rows), max_samples_per_pair)
            sampled = pair_rows.sample(n=sample_size, random_state=42)

            print(f"βœ… {src_lang} β†’ {tgt_lang}: {len(sampled)} samples")

            # The flag only depends on the language pair, not on the row.
            comparable = (
                src_lang in GOOGLE_SUPPORTED_LANGUAGES
                and tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
            )

            # Ids continue from however many records were collected so far.
            next_number = len(records) + 1
            numbered_rows = enumerate(sampled.iterrows(), start=next_number)
            records.extend(
                {
                    "sample_id": f"salt_{sample_id_counter:06d}",
                    "source_text": row["source"],
                    "target_text": row["target"],
                    "source_language": src_lang,
                    "target_language": tgt_lang,
                    "domain": row.get("domain", "general"),
                    "google_comparable": comparable,
                }
                for sample_id_counter, (_, row) in numbered_rows
            )

        test_df = pd.DataFrame(records)

        if test_df.empty:
            raise ValueError("No test samples generated - check SALT dataset availability")

        print(f"βœ… Generated test set: {len(test_df):,} samples")

        return test_df

    except Exception as e:
        print(f"❌ Error generating test set: {e}")
        fallback_columns = [
            "sample_id",
            "source_text",
            "target_text",
            "source_language",
            "target_language",
            "domain",
            "google_comparable",
        ]
        return pd.DataFrame(columns=fallback_columns)


def _generate_and_save_test_set() -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Generate the test set and persist its public and complete variants.

    Returns:
        A ``(public_df, complete_df)`` tuple. The public frame is the complete
        frame without the ``target_text`` column. When generation produced no
        rows, both frames are empty but still carry their column headers.
        A failure while writing the CSV files is printed and otherwise ignored.
    """

    print("πŸ”¬ Generating and saving test sets...")

    complete_columns = [
        "sample_id",
        "source_text",
        "target_text",
        "source_language",
        "target_language",
        "domain",
        "google_comparable",
    ]
    # Same columns in the same order, minus the reference translation.
    public_columns = [name for name in complete_columns if name != "target_text"]

    full_df = generate_test_set()

    if full_df.empty:
        print("❌ Failed to generate test set")
        return (
            pd.DataFrame(columns=public_columns),
            pd.DataFrame(columns=complete_columns),
        )

    public_df = full_df[public_columns].copy()

    # Best-effort persistence: the frames are returned even if saving fails.
    try:
        for frame, csv_path in (
            (public_df, LOCAL_PUBLIC_CSV),
            (full_df, LOCAL_COMPLETE_CSV),
        ):
            frame.to_csv(csv_path, index=False)
        print(f"βœ… Saved test sets: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
    except Exception as e:
        print(f"⚠️  Error saving CSVs: {e}")

    return public_df, full_df


def get_public_test_set() -> pd.DataFrame:
    """Return the public test set, trying progressively slower sources.

    Sources are tried in this order: the ``TEST_SET_DATASET`` dataset on the
    HF Hub, the local CSV at ``LOCAL_PUBLIC_CSV``, and finally a fresh
    generation (which also saves the CSVs). A source is accepted only if it
    contains the columns ``sample_id``, ``source_text``, ``source_language``
    and ``target_language``; load errors are printed and the next source is
    tried.
    """

    required_cols = ["sample_id", "source_text", "source_language", "target_language"]

    # First choice: the dataset published on the HF Hub.
    try:
        print("πŸ“₯ Attempting to load test set from HF Hub...")
        df = load_dataset(TEST_SET_DATASET, split="train", token=HF_TOKEN).to_pandas()

        absent = [col for col in required_cols if col not in df.columns]
        if absent:
            print("⚠️  HF Hub test set missing columns, regenerating...")
        else:
            print(f"βœ… Loaded test set from HF Hub ({len(df):,} samples)")
            return df

    except Exception as e:
        print(f"⚠️  HF Hub load failed: {e}")

    # Second choice: a CSV left behind by an earlier generation.
    if os.path.exists(LOCAL_PUBLIC_CSV):
        try:
            df = pd.read_csv(LOCAL_PUBLIC_CSV)
            absent = [col for col in required_cols if col not in df.columns]
            if absent:
                print("⚠️  Local CSV has invalid structure, regenerating...")
            else:
                print(f"βœ… Loaded test set from local CSV ({len(df):,} samples)")
                return df
        except Exception as e:
            print(f"⚠️  Failed to read local CSV: {e}")

    # Last resort: build the test set from scratch and keep the public half.
    print("πŸ”„ Generating new test set...")
    return _generate_and_save_test_set()[0]


def get_complete_test_set() -> pd.DataFrame:
    """Return the complete test set (including ``target_text``).

    Sources are tried in this order: the ``TEST_SET_DATASET`` + ``"-private"``
    dataset on the HF Hub, the local CSV at ``LOCAL_COMPLETE_CSV``, and
    finally a fresh generation (which also saves the CSVs). A source is
    accepted only if it contains the columns ``sample_id``, ``source_text``,
    ``target_text``, ``source_language`` and ``target_language``; load errors
    are printed and the next source is tried.
    """

    required_cols = ["sample_id", "source_text", "target_text", "source_language", "target_language"]

    # First choice: the private companion dataset on the HF Hub.
    try:
        print("πŸ“₯ Attempting to load complete test set from HF Hub...")
        private_name = TEST_SET_DATASET + "-private"
        df = load_dataset(private_name, split="train", token=HF_TOKEN).to_pandas()

        absent = [col for col in required_cols if col not in df.columns]
        if absent:
            print("⚠️  HF Hub complete test set missing columns, regenerating...")
        else:
            print(f"βœ… Loaded complete test set from HF Hub ({len(df):,} samples)")
            return df

    except Exception as e:
        print(f"⚠️  HF Hub private load failed: {e}")

    # Second choice: a CSV left behind by an earlier generation.
    if os.path.exists(LOCAL_COMPLETE_CSV):
        try:
            df = pd.read_csv(LOCAL_COMPLETE_CSV)
            absent = [col for col in required_cols if col not in df.columns]
            if absent:
                print("⚠️  Local complete CSV has invalid structure, regenerating...")
            else:
                print(f"βœ… Loaded complete test set from local CSV ({len(df):,} samples)")
                return df
        except Exception as e:
            print(f"⚠️  Failed to read local complete CSV: {e}")

    # Last resort: build the test set from scratch and keep the complete half.
    print("πŸ”„ Generating new complete test set...")
    return _generate_and_save_test_set()[1]


def create_test_set_download() -> Tuple[str, Dict]:
    """Prepare the public test set CSV for download and summarise it.

    Returns:
        A ``(path, stats)`` tuple. ``path`` is always ``LOCAL_PUBLIC_CSV``.
        ``stats`` contains ``total_samples``, ``languages``,
        ``language_pairs``, ``track_breakdown`` (one entry per key of
        ``EVALUATION_TRACKS``) and ``google_comparable_samples``. If the test
        set is empty, or computing the statistics fails, the breakdown fields
        are returned empty/zero.
    """

    public_df = get_public_test_set()

    if public_df.empty:
        return LOCAL_PUBLIC_CSV, {
            "total_samples": 0,
            "track_breakdown": {},
            "languages": [],
            "language_pairs": 0,
            "google_comparable_samples": 0,
        }

    download_path = LOCAL_PUBLIC_CSV

    # Rewrite the CSV so the downloadable file matches the loaded frame.
    try:
        public_df.to_csv(download_path, index=False)
    except Exception as e:
        print(f"⚠️  Error updating CSV: {e}")

    pair_keys = ["source_language", "target_language"]

    try:
        # Overall figures.
        seen_languages = set(public_df["source_language"]) | set(public_df["target_language"])
        stats = {
            "total_samples": len(public_df),
            "languages": sorted(seen_languages),
            "language_pairs": len(public_df.groupby(pair_keys)),
        }

        # Per-track figures: a row belongs to a track when both its source
        # and target language are in that track's language list.
        per_track = {}
        for track_name, track_config in EVALUATION_TRACKS.items():
            track_languages = track_config["languages"]
            source_ok = public_df["source_language"].isin(track_languages)
            target_ok = public_df["target_language"].isin(track_languages)
            track_rows = public_df[source_ok & target_ok]

            per_track[track_name] = {
                "total_samples": len(track_rows),
                "language_pairs": len(track_rows.groupby(pair_keys)),
                "languages": track_languages,
            }

        stats["track_breakdown"] = per_track

        # Number of rows flagged as comparable with Google, when the flag exists.
        stats["google_comparable_samples"] = (
            int(public_df["google_comparable"].sum())
            if "google_comparable" in public_df.columns
            else 0
        )

    except Exception as e:
        print(f"⚠️  Error calculating stats: {e}")
        stats = {
            "total_samples": len(public_df),
            "track_breakdown": {},
            "languages": [],
            "language_pairs": 0,
            "google_comparable_samples": 0,
        }

    return download_path, stats


def get_track_test_set(track: str) -> pd.DataFrame:
    """Return the public test set restricted to one evaluation track.

    Args:
        track: A key of ``EVALUATION_TRACKS``.

    Returns:
        The rows whose source and target language are both in the track's
        language list. An empty DataFrame is returned when the track is
        unknown or the public test set is empty.
    """

    if track not in EVALUATION_TRACKS:
        print(f"❌ Unknown track: {track}")
        return pd.DataFrame()

    public_df = get_public_test_set()

    if public_df.empty:
        return pd.DataFrame()

    # Keep a row only if both ends of the pair belong to the track.
    track_languages = EVALUATION_TRACKS[track]["languages"]
    source_in_track = public_df["source_language"].isin(track_languages)
    target_in_track = public_df["target_language"].isin(track_languages)
    track_df = public_df[source_in_track & target_in_track]

    print(f"βœ… Filtered {track} test set: {len(track_df):,} samples")
    return track_df


def validate_test_set_integrity() -> Dict:
    """Check that the public and complete test sets line up.

    Returns:
        A dict with ``alignment_check`` (True when every public ``sample_id``
        also exists in the complete set), ``total_samples``,
        ``track_analysis`` (per-track row counts for both sets and whether
        those counts match), ``public_samples``, ``private_samples`` and
        ``id_alignment_rate``. When either set is empty or an exception
        occurs, a reduced dict with ``alignment_check`` False and an
        ``error`` message is returned instead.
    """

    def _rows_for_languages(frame, languages):
        # Rows whose source and target language are both in ``languages``.
        source_ok = frame["source_language"].isin(languages)
        target_ok = frame["target_language"].isin(languages)
        return frame[source_ok & target_ok]

    try:
        public_df = get_public_test_set()
        complete_df = get_complete_test_set()

        if public_df.empty or complete_df.empty:
            return {
                "alignment_check": False,
                "total_samples": 0,
                "track_analysis": {},
                "error": "Test sets are empty or could not be loaded",
            }

        public_ids = set(public_df["sample_id"])
        private_ids = set(complete_df["sample_id"])

        # Compare the two sets track by track.
        track_analysis = {}
        for track_name, track_config in EVALUATION_TRACKS.items():
            track_languages = track_config["languages"]
            public_count = len(_rows_for_languages(public_df, track_languages))
            complete_count = len(_rows_for_languages(complete_df, track_languages))

            track_analysis[track_name] = {
                "public_samples": public_count,
                "complete_samples": complete_count,
                "alignment": public_count == complete_count,
                "languages": track_languages,
            }

        # Share of public ids that also appear in the complete set.
        if public_ids:
            shared_fraction = len(public_ids & private_ids) / len(public_ids)
        else:
            shared_fraction = 0.0

        return {
            "alignment_check": public_ids.issubset(private_ids),
            "total_samples": len(public_df),
            "track_analysis": track_analysis,
            "public_samples": len(public_df),
            "private_samples": len(complete_df),
            "id_alignment_rate": shared_fraction,
        }

    except Exception as e:
        return {
            "alignment_check": False,
            "total_samples": 0,
            "track_analysis": {},
            "error": f"Validation failed: {str(e)}",
        }