import os
from typing import Tuple

import pandas as pd
import yaml
from datasets import load_dataset
from config import (
    TEST_SET_DATASET,
    SALT_DATASET,
    MAX_TEST_SAMPLES,
    HF_TOKEN,
    MIN_SAMPLES_PER_PAIR,
    ALL_UG40_LANGUAGES,
    GOOGLE_SUPPORTED_LANGUAGES,
)
import salt.dataset
from src.utils import get_all_language_pairs

# Local CSV filenames for persistence
LOCAL_PUBLIC_CSV = "salt_test_set.csv"
LOCAL_COMPLETE_CSV = "salt_complete_test_set.csv"

# Column layout of the complete test set (includes reference translations).
# Kept as lists: indexing a DataFrame with a tuple would be treated as one key.
_COMPLETE_COLUMNS = [
    'sample_id',
    'source_text',
    'target_text',
    'source_language',
    'target_language',
    'domain',
    'google_comparable',
]
# Column layout of the public test set: the complete layout minus target_text.
_PUBLIC_COLUMNS = [c for c in _COMPLETE_COLUMNS if c != 'target_text']


def _salt_dataset_config() -> dict:
    """Return the ``salt.dataset.create`` config for the SALT test split.

    Built as a plain dict instead of formatting Python values into a YAML
    string and re-parsing it, so the language list reaches the loader as a
    list no matter how ``str()`` would render ``ALL_UG40_LANGUAGES``.
    """
    return {
        'huggingface_load': {
            'path': SALT_DATASET,
            'name': 'text-all',
            'split': 'test',
        },
        'source': {'type': 'text', 'language': list(ALL_UG40_LANGUAGES)},
        'target': {'type': 'text', 'language': list(ALL_UG40_LANGUAGES)},
        'allow_same_src_and_tgt_language': False,
    }


def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
    """
    Generate standardized test set from the SALT dataset.

    For every ordered pair of distinct languages in ``ALL_UG40_LANGUAGES`` that
    has rows in the SALT test split, up to ``max_samples_per_pair`` rows are
    drawn with ``random_state=42``, so the same source data yields the same
    sample. Sample ids (``salt_000001``, ...) are assigned sequentially in
    pair-iteration order.

    Args:
        max_samples_per_pair: Upper bound on the number of rows drawn per pair.

    Returns:
        DataFrame with columns ``sample_id``, ``source_text``, ``target_text``,
        ``source_language``, ``target_language``, ``domain`` and
        ``google_comparable``. It is empty (but still has these columns) when
        no pair has data.
    """
    print("🔄 Generating SALT test set from source dataset...")

    full_data = pd.DataFrame(salt.dataset.create(_salt_dataset_config()))

    # Group once rather than re-filtering full_data for every language pair.
    # Row order within each group is preserved, so sampling is unchanged.
    pair_groups = {
        key: group
        for key, group in full_data.groupby(
            ['source.language', 'target.language'], sort=False
        )
    }

    test_samples = []
    pairs_with_data = 0
    for src_lang in ALL_UG40_LANGUAGES:
        for tgt_lang in ALL_UG40_LANGUAGES:
            if src_lang == tgt_lang:
                continue

            pair_data = pair_groups.get((src_lang, tgt_lang))
            if pair_data is None or pair_data.empty:
                continue
            pairs_with_data += 1

            google_comparable = (
                src_lang in GOOGLE_SUPPORTED_LANGUAGES
                and tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
            )

            # Sample up to max_samples_per_pair; the fixed seed keeps it reproducible.
            n_samples = min(len(pair_data), max_samples_per_pair)
            sampled = pair_data.sample(n=n_samples, random_state=42)

            for _, row in sampled.iterrows():
                test_samples.append({
                    'sample_id': f"salt_{len(test_samples) + 1:06d}",
                    'source_text': row['source'],
                    'target_text': row['target'],
                    'source_language': src_lang,
                    'target_language': tgt_lang,
                    'domain': row.get('domain', 'general'),
                    'google_comparable': google_comparable,
                })

    # Explicit columns so an empty result still has the expected schema.
    test_df = pd.DataFrame(test_samples, columns=_COMPLETE_COLUMNS)
    print(
        f"✅ Generated test set: {len(test_df):,} samples across "
        f"{pairs_with_data:,} of {len(get_all_language_pairs()):,} pairs"
    )
    return test_df


def _generate_and_save_test_set() -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Generate the full test set and persist both public and complete CSV files.

    Returns:
        ``(public_df, complete_df)``. ``public_df`` is ``complete_df`` without
        the ``target_text`` column. They are written to ``LOCAL_PUBLIC_CSV`` and
        ``LOCAL_COMPLETE_CSV`` respectively.
    """
    full_df = generate_test_set()

    # Public version (no target_text)
    public_df = full_df[_PUBLIC_COLUMNS]
    public_df.to_csv(LOCAL_PUBLIC_CSV, index=False)

    # Complete version (with target_text)
    full_df.to_csv(LOCAL_COMPLETE_CSV, index=False)

    print(f"✅ Saved local CSVs: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
    return public_df, full_df


def get_public_test_set() -> pd.DataFrame:
    """
    Load the public test set (without targets).

    Tries, in order: the ``train`` split of ``TEST_SET_DATASET`` on the HF Hub,
    then ``LOCAL_PUBLIC_CSV``, then regenerating (which also rewrites both
    local CSVs). Failures of the first two sources are printed and skipped.
    """
    # 1) Try HF Hub
    try:
        ds = load_dataset(TEST_SET_DATASET, split="train", token=HF_TOKEN)
        df = ds.to_pandas()
        print(f"✅ Loaded public test set from HF Hub ({len(df):,} samples)")
        return df
    except Exception as e:  # best-effort: any Hub/network failure falls through
        print("⚠️ HF Hub load failed, falling back to local CSV:", e)

    # 2) Try local CSV
    if os.path.exists(LOCAL_PUBLIC_CSV):
        try:
            df = pd.read_csv(LOCAL_PUBLIC_CSV)
            print(f"✅ Loaded public test set from local CSV ({len(df):,} samples)")
            return df
        except Exception as e:  # corrupt/unreadable cache -> regenerate
            print("⚠️ Failed to read local CSV, regenerating:", e)

    # 3) Regenerate & save
    print("🔄 Generating new public test set and saving to CSV...")
    public_df, _ = _generate_and_save_test_set()
    return public_df


def get_complete_test_set() -> pd.DataFrame:
    """
    Load the complete test set (with targets).

    Tries, in order: the ``train`` split of ``TEST_SET_DATASET + "-private"``
    on the HF Hub, then ``LOCAL_COMPLETE_CSV``, then regenerating (which also
    rewrites both local CSVs). Failures of the first two sources are printed
    and skipped.
    """
    # 1) Try HF Hub private
    try:
        ds = load_dataset(TEST_SET_DATASET + "-private", split="train", token=HF_TOKEN)
        df = ds.to_pandas()
        print(f"✅ Loaded complete test set from HF Hub-private ({len(df):,} samples)")
        return df
    except Exception as e:  # best-effort: any Hub/network failure falls through
        print("⚠️ HF Hub-private load failed, falling back to local CSV:", e)

    # 2) Try local CSV
    if os.path.exists(LOCAL_COMPLETE_CSV):
        try:
            df = pd.read_csv(LOCAL_COMPLETE_CSV)
            print(f"✅ Loaded complete test set from local CSV ({len(df):,} samples)")
            return df
        except Exception as e:  # corrupt/unreadable cache -> regenerate
            print("⚠️ Failed to read local complete CSV, regenerating:", e)

    # 3) Regenerate & save
    print("🔄 Generating new complete test set and saving to CSV...")
    _, complete_df = _generate_and_save_test_set()
    return complete_df


def create_test_set_download() -> Tuple[str, dict]:
    """
    Create a CSV download of the public test set and return its path + stats.

    Returns:
        ``(path, stats)``. ``path`` is ``LOCAL_PUBLIC_CSV``, rewritten from the
        currently loaded public set. ``stats`` holds sample/pair counts, the
        number of Google-comparable samples, the sorted list of languages seen
        on either side, and the domains in first-appearance order.
    """
    public_df = get_public_test_set()
    download_path = LOCAL_PUBLIC_CSV

    # Ensure the CSV is up-to-date
    public_df.to_csv(download_path, index=False)

    languages = set(public_df['source_language']).union(public_df['target_language'])
    stats = {
        'total_samples': len(public_df),
        'language_pairs': len(public_df.groupby(['source_language', 'target_language'])),
        'google_comparable_samples': int(public_df['google_comparable'].sum()),
        # Sorted so the reported list is stable across runs (set order is not).
        'languages': sorted(languages),
        'domains': public_df['domain'].unique().tolist(),
    }
    return download_path, stats


def validate_test_set_integrity() -> dict:
    """
    Validate test set coverage and integrity.

    Returns:
        A dict with these keys:

        - ``alignment_check``: True when every public ``sample_id`` also
          appears in the complete set.
        - ``total_samples``: the public sample count.
        - ``coverage_by_pair``: maps ``"{src}_{tgt}"`` to ``count`` and
          ``has_samples`` (``count >= MIN_SAMPLES_PER_PAIR``).
        - ``missing_pairs``: the keys whose ``has_samples`` is False.
    """
    public_df = get_public_test_set()
    complete_df = get_complete_test_set()

    public_ids = set(public_df['sample_id'])
    private_ids = set(complete_df['sample_id'])

    # Count every pair in one pass instead of filtering once per pair.
    pair_counts = public_df.groupby(['source_language', 'target_language']).size()

    coverage_by_pair = {}
    for src in ALL_UG40_LANGUAGES:
        for tgt in ALL_UG40_LANGUAGES:
            if src == tgt:
                continue
            count = int(pair_counts.get((src, tgt), 0))
            coverage_by_pair[f"{src}_{tgt}"] = {
                'count': count,
                'has_samples': count >= MIN_SAMPLES_PER_PAIR,
            }

    return {
        'alignment_check': public_ids <= private_ids,
        'total_samples': len(public_df),
        'coverage_by_pair': coverage_by_pair,
        'missing_pairs': [k for k, v in coverage_by_pair.items() if not v['has_samples']],
    }