import os
from typing import Tuple

import pandas as pd
import yaml
from datasets import load_dataset
from config import (
    TEST_SET_DATASET,
    SALT_DATASET,
    MAX_TEST_SAMPLES,
    HF_TOKEN,
    MIN_SAMPLES_PER_PAIR,
    ALL_UG40_LANGUAGES,
    GOOGLE_SUPPORTED_LANGUAGES
)
import salt.dataset
from src.utils import get_all_language_pairs

# Local CSV filenames for persistence
LOCAL_PUBLIC_CSV = "salt_test_set.csv"
LOCAL_COMPLETE_CSV = "salt_complete_test_set.csv"
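# The public CSV omits target_text so the reference translations stay
# private; the complete CSV retains them. Both files are written to the
# current working directory.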


def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
    """
    Generate a standardized test set from the SALT dataset, sampling up to
    max_samples_per_pair examples for each ordered language pair.
    """
    print("πŸ”„ Generating SALT test set from source dataset...")
    # Build SALT dataset config
    dataset_config = f'''
    huggingface_load:
      path: {SALT_DATASET}
      name: text-all
      split: test
    source:
      type: text
      language: {ALL_UG40_LANGUAGES}
    target:
      type: text
      language: {ALL_UG40_LANGUAGES}
    allow_same_src_and_tgt_language: False
    '''
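    # The f-string interpolates the Python language lists verbatim; a list
    # repr such as ['lug', 'eng'] (codes illustrative) is also valid YAML
    # flow-sequence syntax, so yaml.safe_load recovers the lists intact.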
    config = yaml.safe_load(dataset_config)
    full_data = pd.DataFrame(salt.dataset.create(config))

    test_samples = []
    sample_id_counter = 1

    for src_lang in ALL_UG40_LANGUAGES:
        for tgt_lang in ALL_UG40_LANGUAGES:
            if src_lang == tgt_lang:
                continue
            pair_data = full_data[
                (full_data['source.language'] == src_lang) &
                (full_data['target.language'] == tgt_lang)
            ]
            if pair_data.empty:
                continue

            # Sample up to max_samples_per_pair
            n_samples = min(len(pair_data), max_samples_per_pair)
            sampled = pair_data.sample(n=n_samples, random_state=42)

            for _, row in sampled.iterrows():
                test_samples.append({
                    'sample_id': f"salt_{sample_id_counter:06d}",
                    'source_text': row['source'],
                    'target_text': row['target'],
                    'source_language': src_lang,
                    'target_language': tgt_lang,
                    'domain': row.get('domain', 'general'),
                    'google_comparable': (
                        src_lang in GOOGLE_SUPPORTED_LANGUAGES and
                        tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
                    )
                })
                sample_id_counter += 1

    test_df = pd.DataFrame(test_samples)
    print(f"βœ… Generated test set: {len(test_df):,} samples across {len(get_all_language_pairs()):,} pairs")
    return test_df

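# Illustrative row produced by generate_test_set() (values hypothetical):
#   {'sample_id': 'salt_000001', 'source_text': '...', 'target_text': '...',
#    'source_language': 'lug', 'target_language': 'eng',
#    'domain': 'general', 'google_comparable': False}
# The fixed random_state above makes regeneration deterministic for a given
# upstream dataset revision.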

def _generate_and_save_test_set() -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Generate the full test set and persist both public and complete CSV files.
    """
    full_df = generate_test_set()
    # Public version (no target_text)
    public_df = full_df[[
        'sample_id', 'source_text', 'source_language',
        'target_language', 'domain', 'google_comparable'
    ]]
    public_df.to_csv(LOCAL_PUBLIC_CSV, index=False)
    # Complete version (with target_text)
    full_df.to_csv(LOCAL_COMPLETE_CSV, index=False)
    print(f"βœ… Saved local CSVs: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
    return public_df, full_df


def get_public_test_set() -> pd.DataFrame:
    """
    Load the public test set (without targets).
    Tries HF Hub → local CSV → regenerate.
    """
    # 1) Try HF Hub
    try:
        ds = load_dataset(TEST_SET_DATASET, split="train", token=HF_TOKEN)
        df = ds.to_pandas()
        print(f"βœ… Loaded public test set from HF Hub ({len(df):,} samples)")
        return df
    except Exception as e:
        print("⚠️ HF Hub load failed, falling back to local CSV:", e)

    # 2) Try local CSV
    if os.path.exists(LOCAL_PUBLIC_CSV):
        try:
            df = pd.read_csv(LOCAL_PUBLIC_CSV)
            print(f"βœ… Loaded public test set from local CSV ({len(df):,} samples)")
            return df
        except Exception as e:
            print("⚠️ Failed to read local CSV, regenerating:", e)

    # 3) Regenerate & save
    print("πŸ”„ Generating new public test set and saving to CSV...")
    public_df, _ = _generate_and_save_test_set()
    return public_df


def get_complete_test_set() -> pd.DataFrame:
    """
    Load the complete test set (with targets).
    Tries HF Hub-private → local CSV → regenerate.
    """
    # 1) Try HF Hub private
    try:
        ds = load_dataset(TEST_SET_DATASET + "-private", split="train", token=HF_TOKEN)
        df = ds.to_pandas()
        print(f"βœ… Loaded complete test set from HF Hub-private ({len(df):,} samples)")
        return df
    except Exception as e:
        print("⚠️ HF Hub-private load failed, falling back to local CSV:", e)

    # 2) Try local CSV
    if os.path.exists(LOCAL_COMPLETE_CSV):
        try:
            df = pd.read_csv(LOCAL_COMPLETE_CSV)
            print(f"βœ… Loaded complete test set from local CSV ({len(df):,} samples)")
            return df
        except Exception as e:
            print("⚠️ Failed to read local complete CSV, regenerating:", e)

    # 3) Regenerate & save
    print("πŸ”„ Generating new complete test set and saving to CSV...")
    _, complete_df = _generate_and_save_test_set()
    return complete_df


def create_test_set_download() -> Tuple[str, dict]:
    """
    Create a CSV download of the public test set and return its path and
    summary statistics.
    """
    public_df = get_public_test_set()
    download_path = LOCAL_PUBLIC_CSV
    # Ensure the CSV is up-to-date
    public_df.to_csv(download_path, index=False)

    stats = {
        'total_samples': len(public_df),
        'language_pairs': len(public_df.groupby(['source_language', 'target_language'])),
        'google_comparable_samples': int(public_df['google_comparable'].sum()),
        'languages': list(set(public_df['source_language']).union(public_df['target_language'])),
        'domains': public_df['domain'].unique().tolist()
    }
    return download_path, stats


def validate_test_set_integrity() -> dict:
    """
    Validate test set coverage and integrity.
    """
    public_df = get_public_test_set()
    complete_df = get_complete_test_set()

    public_ids = set(public_df['sample_id'])
    private_ids = set(complete_df['sample_id'])

    coverage_by_pair = {}
    for src in ALL_UG40_LANGUAGES:
        for tgt in ALL_UG40_LANGUAGES:
            if src == tgt:
                continue
            subset = public_df[
                (public_df['source_language'] == src) &
                (public_df['target_language'] == tgt)
            ]
            count = len(subset)
            coverage_by_pair[f"{src}_{tgt}"] = {
                'count': count,
                'has_samples': count >= MIN_SAMPLES_PER_PAIR
            }

    return {
        'alignment_check': public_ids <= private_ids,  # every public ID must exist in the private set
        'total_samples': len(public_df),
        'coverage_by_pair': coverage_by_pair,
        'missing_pairs': [k for k, v in coverage_by_pair.items() if not v['has_samples']]
    }
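

if __name__ == "__main__":
    # Minimal usage sketch, assuming config.py supplies the imported
    # constants and the SALT source dataset is reachable: refresh the
    # public CSV, then run the integrity checks.
    path, stats = create_test_set_download()
    print(f"Public test set at {path}: {stats['total_samples']:,} samples, "
          f"{stats['language_pairs']} language pairs")

    report = validate_test_set_integrity()
    print("ID alignment:", "OK" if report['alignment_check'] else "MISMATCH")
    if report['missing_pairs']:
        print("Pairs below MIN_SAMPLES_PER_PAIR:", report['missing_pairs'])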