File size: 7,026 Bytes
c1926c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# src/test_set.py
import pandas as pd
import yaml
from datasets import Dataset, load_dataset
from typing import Dict, Tuple
import salt.dataset
from config import *

def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES,
                      random_state: int = 42) -> pd.DataFrame:
    """Generate a standardized test set from the SALT dataset.

    Args:
        max_samples_per_pair: Cap on samples drawn per directed
            (source, target) language pair.
        random_state: Seed for reproducible per-pair sampling
            (defaults to the previously hard-coded 42).

    Returns:
        DataFrame with one row per sample: sample_id, source_text,
        target_text (hidden from the public set), source_language,
        target_language, domain, google_comparable.
    """
    print("Generating SALT test set...")

    # Build the salt.dataset config; dataset path and languages come from config.
    dataset_config = f'''
    huggingface_load:
      path: {SALT_DATASET}
      name: text-all
      split: test
    source:
      type: text
      language: {ALL_UG40_LANGUAGES}
    target:
      type: text
      language: {ALL_UG40_LANGUAGES}
    allow_same_src_and_tgt_language: False
    '''

    config = yaml.safe_load(dataset_config)
    full_data = pd.DataFrame(salt.dataset.create(config))

    # Hoisted: O(1) membership tests instead of list scans per pair.
    google_supported = set(GOOGLE_SUPPORTED_LANGUAGES)

    test_samples = []
    sample_id_counter = 1

    for src_lang in ALL_UG40_LANGUAGES:
        for tgt_lang in ALL_UG40_LANGUAGES:
            if src_lang == tgt_lang:
                continue

            # Rows belonging to this directed language pair.
            pair_data = full_data[
                (full_data['source.language'] == src_lang) &
                (full_data['target.language'] == tgt_lang)
            ]

            if pair_data.empty:
                continue

            # Sample up to max_samples_per_pair rows, reproducibly.
            n_samples = min(len(pair_data), max_samples_per_pair)
            sampled = pair_data.sample(n=n_samples, random_state=random_state)

            # Invariant for the whole pair; computed once, not per row.
            google_comparable = (src_lang in google_supported and
                                 tgt_lang in google_supported)

            for _, row in sampled.iterrows():
                test_samples.append({
                    'sample_id': f"salt_{sample_id_counter:06d}",
                    'source_text': row['source'],
                    'target_text': row['target'],  # Hidden from public test set
                    'source_language': src_lang,
                    'target_language': tgt_lang,
                    'domain': row.get('domain', 'general'),
                    'google_comparable': google_comparable,
                })
                sample_id_counter += 1

    test_df = pd.DataFrame(test_samples)

    print(f"Generated test set with {len(test_df)} samples across {len(get_all_language_pairs())} language pairs")

    return test_df

def get_public_test_set() -> pd.DataFrame:
    """Get public test set (sources only, no targets)."""

    # Columns that are safe to expose publicly (no reference targets).
    public_columns = [
        'sample_id', 'source_text', 'source_language',
        'target_language', 'domain', 'google_comparable'
    ]

    try:
        # Prefer the already-published test set when it is available.
        test_df = load_dataset(TEST_SET_DATASET, split='train').to_pandas()
        print(f"Loaded existing test set with {len(test_df)} samples")
    except Exception as e:
        print(f"Could not load existing test set: {e}")
        print("Generating new test set...")

        test_df = generate_test_set()

        # Persist the full set (targets included) to private storage.
        save_complete_test_set(test_df)

    return test_df[public_columns].copy()

def get_complete_test_set() -> pd.DataFrame:
    """Get complete test set with targets (for evaluation)."""

    private_repo = TEST_SET_DATASET + "-private"
    try:
        # The private dataset variant carries the reference targets.
        return load_dataset(private_repo, split='train').to_pandas()
    except Exception as e:
        print(f"Regenerating complete test set: {e}")
        return generate_test_set()

def save_complete_test_set(test_df: pd.DataFrame) -> bool:
    """Save complete test set to HuggingFace dataset."""

    # Target-free column subset published to the public repo.
    public_columns = [
        'sample_id', 'source_text', 'source_language',
        'target_language', 'domain', 'google_comparable'
    ]

    try:
        # Push the public split (no targets) first.
        Dataset.from_pandas(test_df[public_columns].copy()).push_to_hub(
            TEST_SET_DATASET,
            token=HF_TOKEN,
            commit_message="Update public test set"
        )

        # Then push the full frame, targets included, to the private repo.
        Dataset.from_pandas(test_df).push_to_hub(
            TEST_SET_DATASET + "-private",
            token=HF_TOKEN,
            private=True,
            commit_message="Update private test set with targets"
        )

        print("Test sets saved successfully!")
        return True
    except Exception as e:
        print(f"Error saving test sets: {e}")
        return False

def create_test_set_download() -> Tuple[str, Dict]:
    """Create downloadable test set file and statistics.

    Returns:
        Tuple of (path to the written CSV file, statistics dict with keys
        total_samples, language_pairs, google_comparable_samples,
        languages, domains).
    """
    public_test = get_public_test_set()

    # Write the public test set to a CSV for download.
    download_path = "salt_test_set.csv"
    public_test.to_csv(download_path, index=False)

    # Union of languages on either side of a pair; sorted so repeated
    # runs produce a stable listing (set iteration order is not).
    languages = sorted(
        set(public_test['source_language'].unique()) |
        set(public_test['target_language'].unique())
    )

    stats = {
        'total_samples': len(public_test),
        # ngroups counts distinct (source, target) pairs without materializing groups.
        'language_pairs': public_test.groupby(['source_language', 'target_language']).ngroups,
        # Summing the boolean column counts True values; avoids `== True` filtering.
        'google_comparable_samples': int(public_test['google_comparable'].sum()),
        'languages': languages,
        'domains': list(public_test['domain'].unique()) if 'domain' in public_test.columns else ['general']
    }

    return download_path, stats

def validate_test_set_integrity() -> Dict:
    """Validate test set integrity and coverage."""

    try:
        public_test = get_public_test_set()
        complete_test = get_complete_test_set()

        # Every public sample_id must also exist in the private (complete) set.
        public_ids = set(public_test['sample_id'])
        private_ids = set(complete_test['sample_id'])
        aligned = public_ids <= private_ids

        src_col = public_test['source_language']
        tgt_col = public_test['target_language']

        coverage_by_pair = {}
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src == tgt:
                    continue
                # Vectorized count of rows for this directed pair.
                n = int(((src_col == src) & (tgt_col == tgt)).sum())
                coverage_by_pair[f"{src}_{tgt}"] = {
                    'count': n,
                    'has_samples': n >= MIN_SAMPLES_PER_PAIR
                }

        return {
            'alignment_check': aligned,
            'total_samples': len(public_test),
            'coverage_by_pair': coverage_by_pair,
            'missing_pairs': [pair for pair, info in coverage_by_pair.items()
                              if not info['has_samples']]
        }
    except Exception as e:
        return {'error': str(e)}