# NOTE: "Spaces: Running Running" below is a Hugging Face Spaces status banner
# captured when this file was exported from the Space page; it is not part of
# the module source.
# src/test_set.py
from typing import Dict, Tuple

import pandas as pd
import yaml
from datasets import Dataset, load_dataset

import salt.dataset
from config import *
def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
    """Build the standardized SALT test set.

    Draws up to ``max_samples_per_pair`` rows (seeded, so reproducible) for
    every ordered pair of distinct UG40 languages and tags each row with a
    unique zero-padded ``sample_id``.

    Returns:
        DataFrame with columns: sample_id, source_text, target_text,
        source_language, target_language, domain, google_comparable.
    """
    print("Generating SALT test set...")

    # Declare the SALT load as YAML and let salt.dataset materialize it.
    raw_config = f'''
huggingface_load:
  path: {SALT_DATASET}
  name: text-all
  split: test
source:
  type: text
  language: {ALL_UG40_LANGUAGES}
target:
  type: text
  language: {ALL_UG40_LANGUAGES}
allow_same_src_and_tgt_language: False
'''
    full_data = pd.DataFrame(salt.dataset.create(yaml.safe_load(raw_config)))

    records = []
    for src in ALL_UG40_LANGUAGES:
        for tgt in ALL_UG40_LANGUAGES:
            if src == tgt:
                continue
            # Rows belonging to this ordered language pair.
            pair_rows = full_data[
                (full_data['source.language'] == src)
                & (full_data['target.language'] == tgt)
            ]
            if pair_rows.empty:
                continue
            # Deterministic subsample so regeneration yields the same set.
            chosen = pair_rows.sample(
                n=min(len(pair_rows), max_samples_per_pair), random_state=42
            )
            comparable = (
                src in GOOGLE_SUPPORTED_LANGUAGES
                and tgt in GOOGLE_SUPPORTED_LANGUAGES
            )
            for _, row in chosen.iterrows():
                records.append({
                    'sample_id': f"salt_{len(records) + 1:06d}",
                    'source_text': row['source'],
                    # Targets are kept here but stripped from the public set.
                    'target_text': row['target'],
                    'source_language': src,
                    'target_language': tgt,
                    'domain': row.get('domain', 'general'),
                    'google_comparable': comparable,
                })

    test_df = pd.DataFrame(records)
    print(
        f"Generated test set with {len(test_df)} samples across "
        f"{len(get_all_language_pairs())} language pairs"
    )
    return test_df
def get_public_test_set() -> pd.DataFrame:
    """Return the public test set: source texts only, targets withheld."""
    try:
        # Prefer the already-published dataset when it is reachable.
        existing = load_dataset(TEST_SET_DATASET, split='train')
        test_df = existing.to_pandas()
        print(f"Loaded existing test set with {len(test_df)} samples")
    except Exception as e:
        # Fall back to regenerating from SALT and publishing both variants.
        print(f"Could not load existing test set: {e}")
        print("Generating new test set...")
        test_df = generate_test_set()
        # Persist the full set (with targets) to private storage.
        save_complete_test_set(test_df)

    # Only these columns are safe to expose publicly (no target_text).
    visible = [
        'sample_id', 'source_text', 'source_language',
        'target_language', 'domain', 'google_comparable',
    ]
    return test_df[visible].copy()
def get_complete_test_set() -> pd.DataFrame:
    """Return the full test set, including target texts, for evaluation."""
    try:
        # The "-private" companion dataset holds the hidden targets.
        private = load_dataset(TEST_SET_DATASET + "-private", split='train')
        return private.to_pandas()
    except Exception as e:
        # Private copy unavailable — rebuild the set from SALT on the fly.
        print(f"Regenerating complete test set: {e}")
        return generate_test_set()
def save_complete_test_set(test_df: pd.DataFrame) -> bool:
    """Push both variants of the test set to the Hugging Face Hub.

    The public dataset omits ``target_text``; the private companion keeps
    every column and is pushed with ``private=True``.

    Returns:
        True when both pushes succeed, False otherwise (best-effort: the
        error is printed, not raised).
    """
    visible = [
        'sample_id', 'source_text', 'source_language',
        'target_language', 'domain', 'google_comparable',
    ]
    try:
        # Public copy: targets dropped so answers stay hidden.
        Dataset.from_pandas(test_df[visible].copy()).push_to_hub(
            TEST_SET_DATASET,
            token=HF_TOKEN,
            commit_message="Update public test set",
        )
        # Private copy: full rows, including targets.
        Dataset.from_pandas(test_df).push_to_hub(
            TEST_SET_DATASET + "-private",
            token=HF_TOKEN,
            private=True,
            commit_message="Update private test set with targets",
        )
        print("Test sets saved successfully!")
        return True
    except Exception as e:
        print(f"Error saving test sets: {e}")
        return False
def create_test_set_download() -> Tuple[str, Dict]:
    """Write the public test set to a CSV file and compute summary statistics.

    Returns:
        Tuple of (path to the written CSV, statistics dict with total sample
        count, number of language pairs, Google-comparable sample count, the
        set of languages involved, and the domains present).
    """
    public_test = get_public_test_set()

    # Materialize the downloadable file in the working directory.
    download_path = "salt_test_set.csv"
    public_test.to_csv(download_path, index=False)

    stats = {
        'total_samples': len(public_test),
        # Count distinct (source, target) pairs without building groups.
        'language_pairs': len(
            public_test[['source_language', 'target_language']].drop_duplicates()
        ),
        # Sum the boolean column directly (no `== True` comparison); int()
        # converts numpy int64 so the stats dict stays JSON-serializable.
        'google_comparable_samples': int(public_test['google_comparable'].sum()),
        'languages': list(
            set(public_test['source_language'].unique())
            | set(public_test['target_language'].unique())
        ),
        'domains': (
            list(public_test['domain'].unique())
            if 'domain' in public_test.columns
            else ['general']
        ),
    }
    return download_path, stats
def validate_test_set_integrity() -> Dict:
    """Check public/private ID alignment and per-pair sample coverage.

    Returns:
        Dict with alignment result, total sample count, per-pair coverage,
        and the list of under-covered pairs — or {'error': ...} on failure.
    """
    try:
        public_test = get_public_test_set()
        complete_test = get_complete_test_set()

        # Every public sample_id must also exist in the private set.
        orphaned = set(public_test['sample_id']) - set(complete_test['sample_id'])

        coverage = {}
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src == tgt:
                    continue
                mask = (
                    (public_test['source_language'] == src)
                    & (public_test['target_language'] == tgt)
                )
                n = int(mask.sum())
                coverage[f"{src}_{tgt}"] = {
                    'count': n,
                    'has_samples': n >= MIN_SAMPLES_PER_PAIR,
                }

        return {
            'alignment_check': len(orphaned) == 0,
            'total_samples': len(public_test),
            'coverage_by_pair': coverage,
            'missing_pairs': [
                pair for pair, info in coverage.items()
                if not info['has_samples']
            ],
        }
    except Exception as e:
        return {'error': str(e)}