# src/test_set.py
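"""Test set generation and management for SALT translation evaluation.

Builds a standardized test set from the SALT dataset, publishes a public
version (sources only) and a private version (with reference targets) to
the HuggingFace Hub, and provides download and integrity-check helpers.
"""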
from typing import Dict, Tuple

import pandas as pd
import yaml
from datasets import Dataset, load_dataset

import salt.dataset
# config is expected to provide MAX_TEST_SAMPLES, MIN_SAMPLES_PER_PAIR,
# SALT_DATASET, TEST_SET_DATASET, ALL_UG40_LANGUAGES,
# GOOGLE_SUPPORTED_LANGUAGES, HF_TOKEN, and get_all_language_pairs().
from config import *


def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
    """Generate a standardized test set from the SALT dataset."""
    print("Generating SALT test set...")

    # Load the full SALT dataset via the salt library's YAML config format
    dataset_config = f'''
    huggingface_load:
        path: {SALT_DATASET}
        name: text-all
        split: test
    source:
        type: text
        language: {ALL_UG40_LANGUAGES}
    target:
        type: text
        language: {ALL_UG40_LANGUAGES}
    allow_same_src_and_tgt_language: False
    '''
    config = yaml.safe_load(dataset_config)
    full_data = pd.DataFrame(salt.dataset.create(config))

    # Sample data for each directed language pair
    test_samples = []
    sample_id_counter = 1
    for src_lang in ALL_UG40_LANGUAGES:
        for tgt_lang in ALL_UG40_LANGUAGES:
            if src_lang != tgt_lang:
                # Filter for this language pair
                pair_data = full_data[
                    (full_data['source.language'] == src_lang) &
                    (full_data['target.language'] == tgt_lang)
                ].copy()
                if len(pair_data) > 0:
                    # Sample up to max_samples_per_pair rows, reproducibly
                    n_samples = min(len(pair_data), max_samples_per_pair)
                    sampled = pair_data.sample(n=n_samples, random_state=42)
                    # Add to the test set with unique IDs
                    for _, row in sampled.iterrows():
                        test_samples.append({
                            'sample_id': f"salt_{sample_id_counter:06d}",
                            'source_text': row['source'],
                            'target_text': row['target'],  # Hidden from the public test set
                            'source_language': src_lang,
                            'target_language': tgt_lang,
                            'domain': row.get('domain', 'general'),
                            'google_comparable': (src_lang in GOOGLE_SUPPORTED_LANGUAGES and
                                                  tgt_lang in GOOGLE_SUPPORTED_LANGUAGES)
                        })
                        sample_id_counter += 1

    test_df = pd.DataFrame(test_samples)
    print(f"Generated test set with {len(test_df)} samples "
          f"across {len(get_all_language_pairs())} language pairs")
    return test_df


def get_public_test_set() -> pd.DataFrame:
    """Get the public test set (sources only, no targets)."""
    try:
        # Try to load the existing test set from the Hub
        dataset = load_dataset(TEST_SET_DATASET, split='train')
        test_df = dataset.to_pandas()
        print(f"Loaded existing test set with {len(test_df)} samples")
    except Exception as e:
        print(f"Could not load existing test set: {e}")
        print("Generating new test set...")
        # Generate a new test set and save the complete version
        # (with targets) privately
        test_df = generate_test_set()
        save_complete_test_set(test_df)

    # Return the public version (without targets)
    public_columns = [
        'sample_id', 'source_text', 'source_language',
        'target_language', 'domain', 'google_comparable'
    ]
    return test_df[public_columns].copy()


def get_complete_test_set() -> pd.DataFrame:
    """Get the complete test set with targets (for evaluation)."""
    try:
        # Load from private storage, or regenerate on failure
        dataset = load_dataset(TEST_SET_DATASET + "-private", split='train')
        return dataset.to_pandas()
    except Exception as e:
        print(f"Regenerating complete test set: {e}")
        return generate_test_set()


def save_complete_test_set(test_df: pd.DataFrame) -> bool:
    """Save the public and private test set versions to the HuggingFace Hub."""
    try:
        # Save the public version (no targets)
        public_df = test_df[[
            'sample_id', 'source_text', 'source_language',
            'target_language', 'domain', 'google_comparable'
        ]].copy()
        public_dataset = Dataset.from_pandas(public_df)
        public_dataset.push_to_hub(
            TEST_SET_DATASET,
            token=HF_TOKEN,
            commit_message="Update public test set"
        )

        # Save the private version (with targets)
        private_dataset = Dataset.from_pandas(test_df)
        private_dataset.push_to_hub(
            TEST_SET_DATASET + "-private",
            token=HF_TOKEN,
            private=True,
            commit_message="Update private test set with targets"
        )

        print("Test sets saved successfully!")
        return True
    except Exception as e:
        print(f"Error saving test sets: {e}")
        return False


def create_test_set_download() -> Tuple[str, Dict]:
    """Create a downloadable test set file and summary statistics."""
    public_test = get_public_test_set()

    # Write the download file
    download_path = "salt_test_set.csv"
    public_test.to_csv(download_path, index=False)

    # Generate summary statistics
    stats = {
        'total_samples': len(public_test),
        'language_pairs': len(public_test.groupby(['source_language', 'target_language'])),
        'google_comparable_samples': len(public_test[public_test['google_comparable']]),
        'languages': list(set(public_test['source_language'].unique()) |
                          set(public_test['target_language'].unique())),
        'domains': list(public_test['domain'].unique()) if 'domain' in public_test.columns else ['general']
    }
    return download_path, stats


def validate_test_set_integrity() -> Dict:
    """Validate test set integrity and language-pair coverage."""
    try:
        public_test = get_public_test_set()
        complete_test = get_complete_test_set()

        # Check that every public sample has a private counterpart
        public_ids = set(public_test['sample_id'])
        private_ids = set(complete_test['sample_id'])

        # Count samples for every directed language pair
        coverage_by_pair = {}
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src != tgt:
                    pair_samples = public_test[
                        (public_test['source_language'] == src) &
                        (public_test['target_language'] == tgt)
                    ]
                    coverage_by_pair[f"{src}_{tgt}"] = {
                        'count': len(pair_samples),
                        'has_samples': len(pair_samples) >= MIN_SAMPLES_PER_PAIR
                    }

        return {
            'alignment_check': len(public_ids - private_ids) == 0,
            'total_samples': len(public_test),
            'coverage_by_pair': coverage_by_pair,
            'missing_pairs': [k for k, v in coverage_by_pair.items() if not v['has_samples']]
        }
    except Exception as e:
        return {'error': str(e)}
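

# Minimal usage sketch (an assumption, not part of the original pipeline):
# it presumes `config` supplies valid values for the constants above,
# including an HF_TOKEN with write access. Note that a cache miss inside
# get_public_test_set() regenerates the test set and pushes it to the Hub,
# so this is illustrative rather than a dry run.
if __name__ == "__main__":
    download_path, stats = create_test_set_download()
    print(f"Wrote {download_path}: {stats['total_samples']} samples, "
          f"{stats['language_pairs']} language pairs")

    report = validate_test_set_integrity()
    print(f"Alignment check passed: {report.get('alignment_check')}")
    print(f"Pairs below MIN_SAMPLES_PER_PAIR: {report.get('missing_pairs', [])}")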