# src/test_set.py
import os
import pandas as pd
import yaml  # noqa: F401 - no longer used in this module; retained, callers may rely on it
from datasets import load_dataset
from config import (
    TEST_SET_DATASET, SALT_DATASET, MAX_TEST_SAMPLES, HF_TOKEN,
    MIN_SAMPLES_PER_PAIR, ALL_UG40_LANGUAGES, GOOGLE_SUPPORTED_LANGUAGES
)
import salt.dataset
from src.utils import get_all_language_pairs  # noqa: F401 - unused in this module; retained

# Local CSV filenames for persistence
LOCAL_PUBLIC_CSV = "salt_test_set.csv"
LOCAL_COMPLETE_CSV = "salt_complete_test_set.csv"

# Column layout of the complete test set (includes reference translations).
COMPLETE_COLUMNS = [
    'sample_id', 'source_text', 'target_text', 'source_language',
    'target_language', 'domain', 'google_comparable',
]
# Public test set: same layout with the reference translations removed.
PUBLIC_COLUMNS = [col for col in COMPLETE_COLUMNS if col != 'target_text']

# Minimum columns a loaded test set must carry to be accepted.
_PUBLIC_REQUIRED_COLUMNS = [
    'sample_id', 'source_text', 'source_language', 'target_language',
]
_COMPLETE_REQUIRED_COLUMNS = [
    'sample_id', 'source_text', 'target_text', 'source_language', 'target_language',
]

# Seed for per-pair sampling; makes regeneration reproducible for the same source data.
_SAMPLING_SEED = 42


def _empty_frame(columns: list[str]) -> pd.DataFrame:
    """Return an empty DataFrame with the given column layout."""
    return pd.DataFrame(columns=columns)


def _is_usable_test_set(df: pd.DataFrame, required_columns: list[str]) -> bool:
    """Return True when *df* has at least one row and every required column."""
    return not df.empty and all(col in df.columns for col in required_columns)


def _build_salt_config() -> dict:
    """Build the ``salt.dataset.create`` config covering all UG40 language pairs.

    Built as a plain dict instead of interpolating Python reprs into a YAML
    string, which only parsed correctly when the language collection happened
    to be a list of plain strings.
    """
    return {
        'huggingface_load': {
            'path': SALT_DATASET,
            'name': 'text-all',
            # Use the 'test' split for consistency.
            'split': 'test',
        },
        'source': {'type': 'text', 'language': list(ALL_UG40_LANGUAGES)},
        'target': {'type': 'text', 'language': list(ALL_UG40_LANGUAGES)},
        'allow_same_src_and_tgt_language': False,
    }


def _sample_pair(
    pair_data: pd.DataFrame,
    src_lang: str,
    tgt_lang: str,
    max_samples_per_pair: int,
) -> pd.DataFrame:
    """Sample up to ``max_samples_per_pair`` rows of one pair, in test-set schema.

    Returns every column of ``COMPLETE_COLUMNS`` except ``sample_id``, which the
    caller assigns once all pairs are collected.
    """
    n_samples = min(len(pair_data), max_samples_per_pair)
    sampled = pair_data.sample(n=n_samples, random_state=_SAMPLING_SEED)
    print(f"✅ {src_lang} → {tgt_lang}: {n_samples} samples")

    # Fall back to 'general' when the source data has no domain column.
    domain = sampled['domain'].to_numpy() if 'domain' in sampled.columns else 'general'
    return pd.DataFrame({
        'source_text': sampled['source'].to_numpy(),
        'target_text': sampled['target'].to_numpy(),
        'source_language': src_lang,
        'target_language': tgt_lang,
        'domain': domain,
        'google_comparable': (
            src_lang in GOOGLE_SUPPORTED_LANGUAGES
            and tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
        ),
    })


def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
    """
    Generate standardized test set from the SALT dataset.

    For every ordered pair of distinct UG40 languages, up to
    ``max_samples_per_pair`` rows are sampled with a fixed seed. Sample IDs
    (``salt_000001``, ...) are assigned sequentially in pair-iteration order.

    Args:
        max_samples_per_pair: Upper bound on samples drawn per language pair.

    Returns:
        DataFrame with ``COMPLETE_COLUMNS``. On any error (including when no
        samples could be produced) the error is printed and an empty
        DataFrame with ``COMPLETE_COLUMNS`` is returned instead of raising.
    """
    print("🔄 Generating SALT test set from source dataset...")

    try:
        print("📥 Loading SALT dataset...")
        full_data = pd.DataFrame(salt.dataset.create(_build_salt_config()))
        print(f"📊 Loaded {len(full_data):,} samples from SALT dataset")

        # Group once up front instead of re-filtering the whole dataset for
        # every (src, tgt) pair, which cost O(pairs × rows).
        pair_groups = {
            key: frame
            for key, frame in full_data.groupby(
                ['source.language', 'target.language'], sort=False
            )
        }

        pair_frames = []
        for src_lang in ALL_UG40_LANGUAGES:
            for tgt_lang in ALL_UG40_LANGUAGES:
                if src_lang == tgt_lang:
                    continue
                pair_data = pair_groups.get((src_lang, tgt_lang))
                # `.empty` also covers zero-row groups (e.g. unobserved categories).
                if pair_data is None or pair_data.empty:
                    print(f"⚠️ No data found for {src_lang} → {tgt_lang}")
                    continue
                pair_frames.append(
                    _sample_pair(pair_data, src_lang, tgt_lang, max_samples_per_pair)
                )

        test_df = (
            pd.concat(pair_frames, ignore_index=True) if pair_frames else pd.DataFrame()
        )
        if test_df.empty:
            raise ValueError("No test samples generated - check SALT dataset availability")

        test_df.insert(
            0, 'sample_id', [f"salt_{i:06d}" for i in range(1, len(test_df) + 1)]
        )

        unique_pairs = test_df.groupby(['source_language', 'target_language']).ngroups
        google_samples = int(test_df['google_comparable'].sum())

        print(f"✅ Generated test set: {len(test_df):,} samples across {unique_pairs:,} pairs")
        print("📈 Test set statistics:")
        print(f" - Total samples: {len(test_df):,}")
        print(f" - Language pairs: {unique_pairs}")
        print(f" - Google comparable: {google_samples:,} samples")
        print(f" - UG40 only: {len(test_df) - google_samples:,} samples")

        return test_df

    except Exception as e:
        # Deliberate best-effort: callers treat an empty frame as "generation failed".
        print(f"❌ Error generating test set: {e}")
        return _empty_frame(COMPLETE_COLUMNS)


def _generate_and_save_test_set() -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Generate the full test set and persist both public and complete CSV files.

    Returns:
        ``(public_df, complete_df)``. Both are empty (with the public/complete
        column layouts respectively) when generation fails; in that case no
        CSV is written. A failure to write the CSVs is printed, not raised.
    """
    print("🔄 Generating and saving test sets...")

    full_df = generate_test_set()
    if full_df.empty:
        print("❌ Failed to generate test set")
        return _empty_frame(PUBLIC_COLUMNS), _empty_frame(COMPLETE_COLUMNS)

    # Public version (no target_text)
    public_df = full_df[PUBLIC_COLUMNS].copy()

    try:
        public_df.to_csv(LOCAL_PUBLIC_CSV, index=False)
        full_df.to_csv(LOCAL_COMPLETE_CSV, index=False)
        print(f"✅ Saved local CSVs: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
    except Exception as e:
        print(f"⚠️ Error saving CSVs: {e}")

    return public_df, full_df


def _load_test_set(
    *,
    label: str,
    hub_repo: str,
    hub_name: str,
    local_csv: str,
    csv_name: str,
    required_columns: list[str],
    regenerate,
) -> pd.DataFrame:
    """Load a test set, trying HF Hub → local CSV → regeneration in that order.

    A source is accepted only when it is non-empty and has every column in
    ``required_columns``; otherwise the next source is tried.

    Args:
        label: "public" or "complete"; used only in log messages.
        hub_repo: HF Hub dataset repository whose ``train`` split is loaded.
        hub_name: Display name of the Hub source for log messages.
        local_csv: Path of the local CSV fallback.
        csv_name: Display name of the CSV source for log messages.
        required_columns: Columns a loaded frame must have to be accepted.
        regenerate: Zero-argument callable returning the frame to use when
            neither source is usable.
    """
    # 1) Try HF Hub
    try:
        print(f"📥 Attempting to load {label} test set from {hub_name}...")
        df = load_dataset(hub_repo, split="train", token=HF_TOKEN).to_pandas()
    except Exception as e:
        print(f"⚠️ {hub_name} load failed: {e}")
    else:
        if _is_usable_test_set(df, required_columns):
            print(f"✅ Loaded {label} test set from {hub_name} ({len(df):,} samples)")
            return df
        print(f"⚠️ {hub_name} {label} test set is empty or missing required columns, falling back...")

    # 2) Try local CSV
    if os.path.exists(local_csv):
        try:
            df = pd.read_csv(local_csv)
        except Exception as e:
            print(f"⚠️ Failed to read {csv_name}: {e}")
        else:
            if _is_usable_test_set(df, required_columns):
                print(f"✅ Loaded {label} test set from local CSV ({len(df):,} samples)")
                return df
            print("⚠️ Local CSV has invalid structure, regenerating...")

    # 3) Regenerate & save
    print(f"🔄 Generating new {label} test set...")
    return regenerate()


def get_public_test_set() -> pd.DataFrame:
    """
    Load the public test set (without targets).

    Tries HF Hub → local CSV → regenerate. Regeneration also rewrites both
    local CSV files. May return an empty frame if every source fails.
    """
    return _load_test_set(
        label="public",
        hub_repo=TEST_SET_DATASET,
        hub_name="HF Hub",
        local_csv=LOCAL_PUBLIC_CSV,
        csv_name="local CSV",
        required_columns=_PUBLIC_REQUIRED_COLUMNS,
        regenerate=lambda: _generate_and_save_test_set()[0],
    )


def get_complete_test_set() -> pd.DataFrame:
    """
    Load the complete test set (with targets).

    Tries HF Hub-private (``TEST_SET_DATASET + "-private"``) → local CSV →
    regenerate. Regeneration also rewrites both local CSV files. May return
    an empty frame if every source fails.
    """
    return _load_test_set(
        label="complete",
        hub_repo=TEST_SET_DATASET + "-private",
        hub_name="HF Hub-private",
        local_csv=LOCAL_COMPLETE_CSV,
        csv_name="local complete CSV",
        required_columns=_COMPLETE_REQUIRED_COLUMNS,
        regenerate=lambda: _generate_and_save_test_set()[1],
    )


def create_test_set_download() -> tuple[str, dict]:
    """
    Create a CSV download of the public test set and return its path + stats.

    Returns:
        ``(path, stats)`` where ``stats`` has keys ``total_samples``,
        ``language_pairs``, ``google_comparable_samples``, ``languages`` and
        ``domains``. When the public set is empty, zeroed stats are returned
        and the file at ``path`` is not (re)written, so it may not exist.
    """
    public_df = get_public_test_set()
    download_path = LOCAL_PUBLIC_CSV

    if public_df.empty:
        # Create minimal stats for empty dataset
        return download_path, {
            'total_samples': 0,
            'language_pairs': 0,
            'google_comparable_samples': 0,
            'languages': [],
            'domains': [],
        }

    # Ensure the CSV is up-to-date
    try:
        public_df.to_csv(download_path, index=False)
    except Exception as e:
        print(f"⚠️ Error updating CSV: {e}")

    try:
        stats = {
            'total_samples': len(public_df),
            'language_pairs': public_df.groupby(['source_language', 'target_language']).ngroups,
            'google_comparable_samples': (
                int(public_df['google_comparable'].sum())
                if 'google_comparable' in public_df.columns else 0
            ),
            'languages': sorted(
                set(public_df['source_language']).union(public_df['target_language'])
            ),
            'domains': (
                public_df['domain'].unique().tolist()
                if 'domain' in public_df.columns else ['general']
            ),
        }
    except Exception as e:
        print(f"⚠️ Error calculating stats: {e}")
        stats = {
            'total_samples': len(public_df),
            'language_pairs': 0,
            'google_comparable_samples': 0,
            'languages': [],
            'domains': [],
        }

    return download_path, stats


def validate_test_set_integrity() -> dict:
    """
    Validate test set coverage and integrity.

    Checks that every public sample ID also appears in the complete set and
    counts public samples for every ordered pair of distinct UG40 languages;
    a pair "has samples" when its count is at least ``MIN_SAMPLES_PER_PAIR``.

    Returns:
        A report dict. On empty inputs or any exception, a dict with
        ``alignment_check=False`` and an ``error`` message is returned
        instead of raising.
    """
    try:
        public_df = get_public_test_set()
        complete_df = get_complete_test_set()

        if public_df.empty or complete_df.empty:
            return {
                'alignment_check': False,
                'total_samples': 0,
                'coverage_by_pair': {},
                'missing_pairs': [],
                'error': 'Test sets are empty or could not be loaded',
            }

        public_ids = set(public_df['sample_id'])
        private_ids = set(complete_df['sample_id'])

        # Count every pair in one pass instead of re-filtering per pair.
        pair_counts = (
            public_df.groupby(['source_language', 'target_language']).size().to_dict()
        )

        coverage_by_pair = {}
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src == tgt:
                    continue
                count = int(pair_counts.get((src, tgt), 0))
                coverage_by_pair[f"{src}_{tgt}"] = {
                    'count': count,
                    'has_samples': count >= MIN_SAMPLES_PER_PAIR,
                }

        return {
            'alignment_check': public_ids <= private_ids,
            'total_samples': len(public_df),
            'coverage_by_pair': coverage_by_pair,
            'missing_pairs': [k for k, v in coverage_by_pair.items() if not v['has_samples']],
            'public_samples': len(public_df),
            'private_samples': len(complete_df),
            'id_alignment_rate': (
                len(public_ids & private_ids) / len(public_ids) if public_ids else 0.0
            ),
        }

    except Exception as e:
        return {
            'alignment_check': False,
            'total_samples': 0,
            'coverage_by_pair': {},
            'missing_pairs': [],
            'error': f'Validation failed: {str(e)}',
        }