# src/test_set.py
import pandas as pd
import yaml
from datasets import Dataset, load_dataset
from typing import Dict, Tuple

import salt.dataset

from config import *


def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
    """Generate standardized test set from SALT dataset."""
    print("Generating SALT test set...")

    # Load full SALT dataset
    dataset_config = f'''
    huggingface_load:
        path: {SALT_DATASET}
        name: text-all
        split: test
    source:
        type: text
        language: {ALL_UG40_LANGUAGES}
    target:
        type: text
        language: {ALL_UG40_LANGUAGES}
    allow_same_src_and_tgt_language: False
    '''

    config = yaml.safe_load(dataset_config)
    full_data = pd.DataFrame(salt.dataset.create(config))

    # Sample data for each language pair
    test_samples = []
    sample_id_counter = 1

    for src_lang in ALL_UG40_LANGUAGES:
        for tgt_lang in ALL_UG40_LANGUAGES:
            if src_lang != tgt_lang:
                # Filter for this language pair
                pair_data = full_data[
                    (full_data['source.language'] == src_lang) &
                    (full_data['target.language'] == tgt_lang)
                ].copy()

                if len(pair_data) > 0:
                    # Sample up to max_samples_per_pair
                    n_samples = min(len(pair_data), max_samples_per_pair)
                    sampled = pair_data.sample(n=n_samples, random_state=42)

                    # Add to test set with unique IDs
                    for _, row in sampled.iterrows():
                        test_samples.append({
                            'sample_id': f"salt_{sample_id_counter:06d}",
                            'source_text': row['source'],
                            'target_text': row['target'],  # Hidden from public test set
                            'source_language': src_lang,
                            'target_language': tgt_lang,
                            'domain': row.get('domain', 'general'),
                            'google_comparable': (src_lang in GOOGLE_SUPPORTED_LANGUAGES and
                                                  tgt_lang in GOOGLE_SUPPORTED_LANGUAGES)
                        })
                        sample_id_counter += 1

    test_df = pd.DataFrame(test_samples)

    print(f"Generated test set with {len(test_df)} samples across {len(get_all_language_pairs())} language pairs")

    return test_df


def get_public_test_set() -> pd.DataFrame:
    """Get public test set (sources only, no targets)."""
    try:
        # Try to load existing test set
        dataset = load_dataset(TEST_SET_DATASET, split='train')
        test_df = dataset.to_pandas()
        print(f"Loaded existing test set with {len(test_df)} samples")
    except Exception as e:
        print(f"Could not load existing test set: {e}")
        print("Generating new test set...")

        # Generate new test set
        test_df = generate_test_set()

        # Save complete test set (with targets) privately
        save_complete_test_set(test_df)

    # Return public version (without targets)
    public_columns = [
        'sample_id', 'source_text', 'source_language',
        'target_language', 'domain', 'google_comparable'
    ]

    return test_df[public_columns].copy()


def get_complete_test_set() -> pd.DataFrame:
    """Get complete test set with targets (for evaluation)."""
    try:
        # Load from private storage or regenerate
        dataset = load_dataset(TEST_SET_DATASET + "-private", split='train')
        return dataset.to_pandas()
    except Exception as e:
        print(f"Regenerating complete test set: {e}")
        return generate_test_set()


def save_complete_test_set(test_df: pd.DataFrame) -> bool:
    """Save complete test set to HuggingFace dataset."""
    try:
        # Save public version (no targets)
        public_df = test_df[[
            'sample_id', 'source_text', 'source_language',
            'target_language', 'domain', 'google_comparable'
        ]].copy()

        public_dataset = Dataset.from_pandas(public_df)
        public_dataset.push_to_hub(
            TEST_SET_DATASET,
            token=HF_TOKEN,
            commit_message="Update public test set"
        )

        # Save private version (with targets)
        private_dataset = Dataset.from_pandas(test_df)
        private_dataset.push_to_hub(
            TEST_SET_DATASET + "-private",
            token=HF_TOKEN,
            private=True,
            commit_message="Update private test set with targets"
        )

        print("Test sets saved successfully!")
        return True

    except Exception as e:
        print(f"Error saving test sets: {e}")
        return False


def create_test_set_download() -> Tuple[str, Dict]:
    """Create downloadable test set file and statistics."""
    public_test = get_public_test_set()

    # Create download file
    download_path = "salt_test_set.csv"
    public_test.to_csv(download_path, index=False)

    # Generate statistics
    stats = {
        'total_samples': len(public_test),
        'language_pairs': len(public_test.groupby(['source_language', 'target_language'])),
        'google_comparable_samples': len(public_test[public_test['google_comparable'] == True]),
        'languages': list(set(public_test['source_language'].unique()) |
                          set(public_test['target_language'].unique())),
        'domains': list(public_test['domain'].unique()) if 'domain' in public_test.columns else ['general']
    }

    return download_path, stats


def validate_test_set_integrity() -> Dict:
    """Validate test set integrity and coverage."""
    try:
        public_test = get_public_test_set()
        complete_test = get_complete_test_set()

        # Check alignment
        public_ids = set(public_test['sample_id'])
        private_ids = set(complete_test['sample_id'])

        coverage_by_pair = {}
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src != tgt:
                    pair_samples = public_test[
                        (public_test['source_language'] == src) &
                        (public_test['target_language'] == tgt)
                    ]
                    coverage_by_pair[f"{src}_{tgt}"] = {
                        'count': len(pair_samples),
                        'has_samples': len(pair_samples) >= MIN_SAMPLES_PER_PAIR
                    }

        return {
            'alignment_check': len(public_ids - private_ids) == 0,
            'total_samples': len(public_test),
            'coverage_by_pair': coverage_by_pair,
            'missing_pairs': [k for k, v in coverage_by_pair.items() if not v['has_samples']]
        }

    except Exception as e:
        return {'error': str(e)}
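

# Illustrative usage sketch, not part of the original module: it assumes config.py
# provides HF_TOKEN plus the dataset/language constants imported above, and that
# the SALT dataset and the Hugging Face Hub are reachable. Running it regenerates
# the test set, pushes the public and private versions, writes the download CSV,
# and prints a brief integrity summary.
if __name__ == "__main__":
    test_df = generate_test_set()
    save_complete_test_set(test_df)

    download_path, stats = create_test_set_download()
    print(f"Wrote {download_path}: {stats['total_samples']} samples, "
          f"{stats['language_pairs']} language pairs")

    integrity = validate_test_set_integrity()
    print(f"Missing pairs: {integrity.get('missing_pairs', [])}")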