# src/test_set.py
import os
import pandas as pd
import yaml
import numpy as np
from datasets import load_dataset
from typing import Optional, Dict, Tuple, List

from config import (
    TEST_SET_DATASET,
    SALT_DATASET,
    MAX_TEST_SAMPLES,
    HF_TOKEN,
    ALL_UG40_LANGUAGES,
    GOOGLE_SUPPORTED_LANGUAGES,
    EVALUATION_TRACKS,
    SAMPLE_SIZE_RECOMMENDATIONS,
    STATISTICAL_CONFIG,
)

import salt.dataset

from src.utils import get_all_language_pairs, get_track_language_pairs

# Local CSV filenames for persistence
LOCAL_PUBLIC_CSV = "salt_test_set_scientific.csv"
LOCAL_COMPLETE_CSV = "salt_complete_test_set_scientific.csv"
LOCAL_TRACK_CSVS = {
    track: f"salt_test_set_{track}.csv" for track in EVALUATION_TRACKS.keys()
}


def generate_scientific_test_set(
    max_samples_per_pair: int = MAX_TEST_SAMPLES,
    stratified_sampling: bool = True,
    balance_tracks: bool = True,
) -> pd.DataFrame:
    """Generate a scientifically rigorous test set with stratified sampling."""
    print("🔬 Generating scientific SALT test set...")

    try:
        # Build SALT dataset config
        dataset_config = f"""
huggingface_load:
    path: {SALT_DATASET}
    name: text-all
    split: test
source:
    type: text
    language: {ALL_UG40_LANGUAGES}
target:
    type: text
    language: {ALL_UG40_LANGUAGES}
allow_same_src_and_tgt_language: False
"""
        config = yaml.safe_load(dataset_config)

        print("📥 Loading SALT dataset...")
        full_data = pd.DataFrame(salt.dataset.create(config))
        print(f"📊 Loaded {len(full_data):,} samples from SALT dataset")

        test_samples = []
        sample_id_counter = 1

        # Calculate target samples per track for balanced evaluation
        track_targets = calculate_track_sampling_targets(balance_tracks)

        # Generate samples for each language pair with stratified sampling
        for src_lang in ALL_UG40_LANGUAGES:
            for tgt_lang in ALL_UG40_LANGUAGES:
                if src_lang == tgt_lang:
                    continue

                # Determine the target sample size for this pair
                pair_targets = calculate_pair_sampling_targets(
                    src_lang, tgt_lang, track_targets, max_samples_per_pair
                )
                target_samples = (
                    max(pair_targets.values()) if pair_targets else max_samples_per_pair
                )

                # Filter for this language pair
                pair_data = full_data[
                    (full_data["source.language"] == src_lang)
                    & (full_data["target.language"] == tgt_lang)
                ]

                if pair_data.empty:
                    print(f"⚠️ No data found for {src_lang} → {tgt_lang}")
                    continue

                # Stratified sampling if enabled
                if stratified_sampling and len(pair_data) > target_samples:
                    sampled = stratified_sample_pair_data(pair_data, target_samples)
                else:
                    # Simple random sampling
                    n_samples = min(len(pair_data), target_samples)
                    sampled = pair_data.sample(n=n_samples, random_state=42)

                print(f"✅ {src_lang} → {tgt_lang}: {len(sampled)} samples")

                for _, row in sampled.iterrows():
                    # Determine which tracks include this pair
                    tracks_included = []
                    for track_name, track_config in EVALUATION_TRACKS.items():
                        if (
                            src_lang in track_config["languages"]
                            and tgt_lang in track_config["languages"]
                        ):
                            tracks_included.append(track_name)

                    test_samples.append({
                        "sample_id": f"salt_{sample_id_counter:06d}",
                        "source_text": row["source"],
                        "target_text": row["target"],
                        "source_language": src_lang,
                        "target_language": tgt_lang,
                        "domain": row.get("domain", "general"),
                        "google_comparable": (
                            src_lang in GOOGLE_SUPPORTED_LANGUAGES
                            and tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
                        ),
                        "tracks_included": ",".join(tracks_included),
                        "statistical_weight": calculate_statistical_weight(
                            src_lang, tgt_lang, tracks_included
                        ),
                    })
                    sample_id_counter += 1

        test_df = pd.DataFrame(test_samples)

        if test_df.empty:
            raise ValueError("No test samples generated - check SALT dataset availability")

        # Validate scientific adequacy
        adequacy_report = validate_test_set_scientific_adequacy(test_df)

        print(f"✅ Generated scientific test set: {len(test_df):,} samples")
        print(f"📈 Test set adequacy: {adequacy_report['overall_adequacy']}")

        return test_df

    except Exception as e:
        print(f"❌ Error generating scientific test set: {e}")
        return pd.DataFrame(columns=[
            "sample_id", "source_text", "target_text", "source_language",
            "target_language", "domain", "google_comparable",
            "tracks_included", "statistical_weight",
        ])


def calculate_track_sampling_targets(balance_tracks: bool) -> Dict[str, int]:
    """Calculate target sample sizes for each track to ensure statistical adequacy."""
    track_targets = {}

    for track_name, track_config in EVALUATION_TRACKS.items():
        # Base requirement from config
        min_per_pair = track_config["min_samples_per_pair"]

        # Number of directed language pairs in this track
        n_pairs = len(track_config["languages"]) * (len(track_config["languages"]) - 1)

        # Calculate total samples needed for statistical adequacy
        if balance_tracks:
            # Use the publication-quality recommendation
            target_per_pair = max(
                min_per_pair,
                SAMPLE_SIZE_RECOMMENDATIONS["publication_quality"] // n_pairs,
            )
        else:
            target_per_pair = min_per_pair

        track_targets[track_name] = target_per_pair * n_pairs

        print(
            f"📊 {track_name}: targeting {target_per_pair} samples/pair × {n_pairs} pairs "
            f"= {track_targets[track_name]} total"
        )

    return track_targets


def calculate_pair_sampling_targets(
    src_lang: str, tgt_lang: str, track_targets: Dict[str, int], max_samples: int
) -> Dict[str, int]:
    """Calculate sampling targets for a specific language pair across tracks."""
    pair_targets = {}

    for track_name, track_config in EVALUATION_TRACKS.items():
        if (
            src_lang in track_config["languages"]
            and tgt_lang in track_config["languages"]
        ):
            n_pairs_in_track = (
                len(track_config["languages"]) * (len(track_config["languages"]) - 1)
            )
            target_per_pair = track_targets[track_name] // n_pairs_in_track
            pair_targets[track_name] = min(target_per_pair, max_samples)

    return pair_targets


def stratified_sample_pair_data(pair_data: pd.DataFrame, target_samples: int) -> pd.DataFrame:
    """Perform stratified sampling on pair data to ensure representativeness."""
    # Try to stratify by domain if available
    if "domain" in pair_data.columns and pair_data["domain"].nunique() > 1:
        # Sample proportionally from each domain
        domain_counts = pair_data["domain"].value_counts()
        sampled_parts = []

        for domain, count in domain_counts.items():
            domain_data = pair_data[pair_data["domain"] == domain]

            # Calculate proportional sample size
            proportion = count / len(pair_data)
            domain_target = max(1, int(target_samples * proportion))
            domain_target = min(domain_target, len(domain_data))

            if len(domain_data) >= domain_target:
                domain_sample = domain_data.sample(n=domain_target, random_state=42)
                sampled_parts.append(domain_sample)

        if sampled_parts:
            # Keep the original indices so the anti-join below can exclude rows that
            # were already sampled (resetting the index here would allow duplicates).
            stratified_sample = pd.concat(sampled_parts)

            # If we didn't get enough samples, top up with random sampling
            if len(stratified_sample) < target_samples:
                remaining_data = pair_data[~pair_data.index.isin(stratified_sample.index)]
                additional_needed = target_samples - len(stratified_sample)

                if len(remaining_data) >= additional_needed:
                    additional_sample = remaining_data.sample(
                        n=additional_needed, random_state=42
                    )
                    stratified_sample = pd.concat([stratified_sample, additional_sample])

            return stratified_sample.head(target_samples)

    # Fallback to simple random sampling
    return pair_data.sample(n=min(target_samples, len(pair_data)), random_state=42)

def calculate_statistical_weight(
    src_lang: str, tgt_lang: str, tracks_included: List[str]
) -> float:
    """Calculate the statistical weight of a sample based on track inclusion."""
    # Base weight
    weight = 1.0

    # Higher weight for samples in multiple tracks (more valuable)
    weight *= len(tracks_included)

    # Higher weight for Google-comparable pairs (enables baseline comparison)
    if (
        src_lang in GOOGLE_SUPPORTED_LANGUAGES
        and tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
    ):
        weight *= 1.5

    # Cap to a reasonable range
    return min(weight, 5.0)


def validate_test_set_scientific_adequacy(test_df: pd.DataFrame) -> Dict:
    """Validate that the test set meets scientific adequacy requirements."""
    adequacy_report = {
        "overall_adequacy": "insufficient",
        "track_adequacy": {},
        "issues": [],
        "recommendations": [],
        "statistics": {},
    }

    if test_df.empty:
        adequacy_report["issues"].append("Test set is empty")
        return adequacy_report

    # Check each track
    track_adequacies = []

    for track_name, track_config in EVALUATION_TRACKS.items():
        track_languages = track_config["languages"]
        min_per_pair = track_config["min_samples_per_pair"]

        # Filter to this track's data
        track_data = test_df[
            (test_df["source_language"].isin(track_languages))
            & (test_df["target_language"].isin(track_languages))
        ]

        # Analyze pair coverage
        pair_counts = {}
        for src in track_languages:
            for tgt in track_languages:
                if src == tgt:
                    continue
                pair_samples = track_data[
                    (track_data["source_language"] == src)
                    & (track_data["target_language"] == tgt)
                ]
                pair_counts[f"{src}_{tgt}"] = len(pair_samples)

        # Calculate adequacy metrics
        total_pairs = len(pair_counts)
        adequate_pairs = sum(1 for count in pair_counts.values() if count >= min_per_pair)
        adequacy_rate = adequate_pairs / max(total_pairs, 1)

        # Determine the track adequacy level
        if adequacy_rate >= 0.9:
            track_adequacy = "excellent"
        elif adequacy_rate >= 0.8:
            track_adequacy = "good"
        elif adequacy_rate >= 0.6:
            track_adequacy = "fair"
        else:
            track_adequacy = "insufficient"

        adequacy_report["track_adequacy"][track_name] = {
            "adequacy": track_adequacy,
            "adequacy_rate": adequacy_rate,
            "total_samples": len(track_data),
            "total_pairs": total_pairs,
            "adequate_pairs": adequate_pairs,
            "min_samples_per_pair": min_per_pair,
            "pair_counts": pair_counts,
        }

        track_adequacies.append(track_adequacy)

        # Add specific issues
        if track_adequacy == "insufficient":
            inadequate_pairs = [k for k, v in pair_counts.items() if v < min_per_pair]
            adequacy_report["issues"].append(
                f"{track_name}: {len(inadequate_pairs)} pairs below minimum"
            )

    # Overall adequacy assessment
    if all(adequacy in ["excellent", "good"] for adequacy in track_adequacies):
        adequacy_report["overall_adequacy"] = "excellent"
    elif all(adequacy in ["excellent", "good", "fair"] for adequacy in track_adequacies):
        adequacy_report["overall_adequacy"] = "good"
    elif any(adequacy in ["good", "fair"] for adequacy in track_adequacies):
        adequacy_report["overall_adequacy"] = "fair"
    else:
        adequacy_report["overall_adequacy"] = "insufficient"

    # Overall statistics
    adequacy_report["statistics"] = {
        "total_samples": len(test_df),
        "total_language_pairs": len(test_df.groupby(["source_language", "target_language"])),
        "google_comparable_samples": int(test_df["google_comparable"].sum()),
        "domain_distribution": test_df["domain"].value_counts().to_dict(),
        "track_sample_distribution": {
            track: adequacy_report["track_adequacy"][track]["total_samples"]
            for track in EVALUATION_TRACKS.keys()
        },
    }

    # Generate recommendations
    if adequacy_report["overall_adequacy"] in ["insufficient", "fair"]:
        adequacy_report["recommendations"].append(
            "Consider increasing sample size for better statistical power"
        )

adequacy_report["statistics"]["google_comparable_samples"] < 1000: adequacy_report["recommendations"].append( "More Google-comparable samples recommended for baseline comparison" ) return adequacy_report def _generate_and_save_scientific_test_set() -> Tuple[pd.DataFrame, pd.DataFrame]: """Generate and save both public and complete versions of the scientific test set.""" print("🔬 Generating and saving scientific test sets...") full_df = generate_scientific_test_set() if full_df.empty: print("❌ Failed to generate scientific test set") empty_public = pd.DataFrame(columns=[ "sample_id", "source_text", "source_language", "target_language", "domain", "google_comparable", "tracks_included", "statistical_weight" ]) empty_complete = pd.DataFrame(columns=[ "sample_id", "source_text", "target_text", "source_language", "target_language", "domain", "google_comparable", "tracks_included", "statistical_weight" ]) return empty_public, empty_complete # Public version (no target_text) public_df = full_df[[ "sample_id", "source_text", "source_language", "target_language", "domain", "google_comparable", "tracks_included", "statistical_weight" ]].copy() # Save main versions try: public_df.to_csv(LOCAL_PUBLIC_CSV, index=False) full_df.to_csv(LOCAL_COMPLETE_CSV, index=False) print(f"✅ Saved main test sets: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}") except Exception as e: print(f"⚠️ Error saving main CSVs: {e}") # Save track-specific versions for easier analysis for track_name, track_config in EVALUATION_TRACKS.items(): try: track_languages = track_config["languages"] track_public = public_df[ (public_df["source_language"].isin(track_languages)) & (public_df["target_language"].isin(track_languages)) ] track_filename = LOCAL_TRACK_CSVS[track_name] track_public.to_csv(track_filename, index=False) print(f"✅ Saved {track_name} track: {track_filename} ({len(track_public):,} samples)") except Exception as e: print(f"⚠️ Error saving {track_name} track CSV: {e}") return public_df, full_df def get_public_test_set_scientific() -> pd.DataFrame: """Load the scientific public test set with enhanced fallback logic.""" # 1) Try HF Hub try: print("📥 Attempting to load scientific test set from HF Hub...") ds = load_dataset(TEST_SET_DATASET + "-scientific", split="train", token=HF_TOKEN) df = ds.to_pandas() # Validate scientific structure required_cols = ["sample_id", "source_text", "source_language", "target_language", "tracks_included", "statistical_weight"] if all(col in df.columns for col in required_cols): print(f"✅ Loaded scientific test set from HF Hub ({len(df):,} samples)") return df else: print("⚠️ HF Hub test set missing scientific columns, regenerating...") except Exception as e: print(f"⚠️ HF Hub load failed: {e}") # 2) Try local CSV if os.path.exists(LOCAL_PUBLIC_CSV): try: df = pd.read_csv(LOCAL_PUBLIC_CSV) required_cols = ["sample_id", "source_text", "source_language", "target_language"] if all(col in df.columns for col in required_cols): print(f"✅ Loaded scientific test set from local CSV ({len(df):,} samples)") return df else: print("⚠️ Local CSV has invalid structure, regenerating...") except Exception as e: print(f"⚠️ Failed to read local scientific CSV: {e}") # 3) Regenerate & save print("🔄 Generating new scientific test set...") public_df, _ = _generate_and_save_scientific_test_set() return public_df def get_complete_test_set_scientific() -> pd.DataFrame: """Load the complete scientific test set with targets.""" # 1) Try HF Hub private try: print("📥 Attempting to load complete scientific test set from HF 
Hub...") ds = load_dataset(TEST_SET_DATASET + "-scientific-private", split="train", token=HF_TOKEN) df = ds.to_pandas() required_cols = ["sample_id", "source_text", "target_text", "source_language", "target_language", "tracks_included", "statistical_weight"] if all(col in df.columns for col in required_cols): print(f"✅ Loaded complete scientific test set from HF Hub ({len(df):,} samples)") return df else: print("⚠️ HF Hub complete test set missing scientific columns, regenerating...") except Exception as e: print(f"⚠️ HF Hub private load failed: {e}") # 2) Try local CSV if os.path.exists(LOCAL_COMPLETE_CSV): try: df = pd.read_csv(LOCAL_COMPLETE_CSV) required_cols = ["sample_id", "source_text", "target_text", "source_language", "target_language"] if all(col in df.columns for col in required_cols): print(f"✅ Loaded complete scientific test set from local CSV ({len(df):,} samples)") return df else: print("⚠️ Local complete CSV has invalid structure, regenerating...") except Exception as e: print(f"⚠️ Failed to read local complete scientific CSV: {e}") # 3) Regenerate & save print("🔄 Generating new complete scientific test set...") _, complete_df = _generate_and_save_scientific_test_set() return complete_df def get_track_test_set(track: str) -> pd.DataFrame: """Get test set filtered for a specific track.""" if track not in EVALUATION_TRACKS: print(f"❌ Unknown track: {track}") return pd.DataFrame() # Try track-specific CSV first track_csv = LOCAL_TRACK_CSVS.get(track) if track_csv and os.path.exists(track_csv): try: df = pd.read_csv(track_csv) print(f"✅ Loaded {track} test set from track-specific CSV ({len(df):,} samples)") return df except Exception as e: print(f"⚠️ Failed to read {track} CSV: {e}") # Fallback to filtering main test set public_df = get_public_test_set_scientific() if public_df.empty: return pd.DataFrame() track_languages = EVALUATION_TRACKS[track]["languages"] track_df = public_df[ (public_df["source_language"].isin(track_languages)) & (public_df["target_language"].isin(track_languages)) ] print(f"✅ Filtered {track} test set from main set ({len(track_df):,} samples)") return track_df def create_test_set_download_scientific() -> Tuple[str, Dict]: """Create scientific test set download with comprehensive metadata.""" public_df = get_public_test_set_scientific() if public_df.empty: stats = { "total_samples": 0, "track_breakdown": {}, "adequacy_assessment": "insufficient", "scientific_metadata": {}, } return LOCAL_PUBLIC_CSV, stats download_path = LOCAL_PUBLIC_CSV # Ensure the CSV is up-to-date try: public_df.to_csv(download_path, index=False) except Exception as e: print(f"⚠️ Error updating scientific CSV: {e}") # Calculate comprehensive statistics try: # Basic statistics stats = { "total_samples": len(public_df), "languages": sorted(list(set(public_df["source_language"]).union(public_df["target_language"]))), "domains": public_df["domain"].unique().tolist() if "domain" in public_df.columns else ["general"], } # Track-specific breakdown track_breakdown = {} for track_name, track_config in EVALUATION_TRACKS.items(): track_languages = track_config["languages"] track_data = public_df[ (public_df["source_language"].isin(track_languages)) & (public_df["target_language"].isin(track_languages)) ] track_breakdown[track_name] = { "name": track_config["name"], "total_samples": len(track_data), "language_pairs": len(track_data.groupby(["source_language", "target_language"])), "min_samples_per_pair": track_config["min_samples_per_pair"], "statistical_adequacy": len(track_data) >= 
track_config["min_samples_per_pair"] * len(track_languages) * (len(track_languages) - 1), } stats["track_breakdown"] = track_breakdown # Google-comparable statistics if "google_comparable" in public_df.columns: stats["google_comparable_samples"] = int(public_df["google_comparable"].sum()) stats["google_comparable_rate"] = float(public_df["google_comparable"].mean()) else: stats["google_comparable_samples"] = 0 stats["google_comparable_rate"] = 0.0 # Scientific adequacy assessment adequacy_report = validate_test_set_scientific_adequacy(public_df) stats["adequacy_assessment"] = adequacy_report["overall_adequacy"] stats["adequacy_details"] = adequacy_report # Scientific metadata stats["scientific_metadata"] = { "stratified_sampling": True, "statistical_weighting": "statistical_weight" in public_df.columns, "track_balanced": True, "confidence_level": STATISTICAL_CONFIG["confidence_level"], "recommended_for": [ track for track, info in track_breakdown.items() if info["statistical_adequacy"] ], } except Exception as e: print(f"⚠️ Error calculating scientific stats: {e}") stats = { "total_samples": len(public_df), "track_breakdown": {}, "adequacy_assessment": "unknown", "scientific_metadata": {}, } return download_path, stats def validate_test_set_integrity_scientific() -> Dict: """Comprehensive validation of scientific test set integrity.""" try: public_df = get_public_test_set_scientific() complete_df = get_complete_test_set_scientific() if public_df.empty or complete_df.empty: return { "alignment_check": False, "total_samples": 0, "scientific_adequacy": {}, "track_analysis": {}, "error": "Test sets are empty or could not be loaded", } public_ids = set(public_df["sample_id"]) private_ids = set(complete_df["sample_id"]) # Track-specific analysis track_analysis = {} for track_name, track_config in EVALUATION_TRACKS.items(): track_languages = track_config["languages"] min_per_pair = track_config["min_samples_per_pair"] # Analyze public set for this track track_public = public_df[ (public_df["source_language"].isin(track_languages)) & (public_df["target_language"].isin(track_languages)) ] # Analyze complete set for this track track_complete = complete_df[ (complete_df["source_language"].isin(track_languages)) & (complete_df["target_language"].isin(track_languages)) ] # Calculate coverage pair_coverage = {} for src in track_languages: for tgt in track_languages: if src == tgt: continue public_subset = track_public[ (track_public["source_language"] == src) & (track_public["target_language"] == tgt) ] complete_subset = track_complete[ (track_complete["source_language"] == src) & (track_complete["target_language"] == tgt) ] pair_coverage[f"{src}_{tgt}"] = { "public_count": len(public_subset), "complete_count": len(complete_subset), "alignment": len(public_subset) == len(complete_subset), "meets_minimum": len(public_subset) >= min_per_pair, } # Track summary total_pairs = len(pair_coverage) adequate_pairs = sum(1 for info in pair_coverage.values() if info["meets_minimum"]) aligned_pairs = sum(1 for info in pair_coverage.values() if info["alignment"]) track_analysis[track_name] = { "total_pairs": total_pairs, "adequate_pairs": adequate_pairs, "aligned_pairs": aligned_pairs, "adequacy_rate": adequate_pairs / max(total_pairs, 1), "alignment_rate": aligned_pairs / max(total_pairs, 1), "pair_coverage": pair_coverage, "statistical_power": calculate_track_statistical_power(track_public, track_config), } # Overall scientific adequacy adequacy_report = validate_test_set_scientific_adequacy(public_df) return { 
"alignment_check": public_ids <= private_ids, "total_samples": len(public_df), "track_analysis": track_analysis, "scientific_adequacy": adequacy_report, "public_samples": len(public_df), "private_samples": len(complete_df), "id_alignment_rate": len(public_ids & private_ids) / len(public_ids) if public_ids else 0.0, "integrity_score": calculate_integrity_score(track_analysis, adequacy_report), } except Exception as e: return { "alignment_check": False, "total_samples": 0, "scientific_adequacy": {}, "track_analysis": {}, "error": f"Validation failed: {str(e)}", } def calculate_track_statistical_power(track_data: pd.DataFrame, track_config: Dict) -> float: """Calculate statistical power estimate for a track.""" if track_data.empty: return 0.0 # Simple power estimation based on sample size min_required = track_config["min_samples_per_pair"] languages = track_config["languages"] total_pairs = len(languages) * (len(languages) - 1) # Calculate average samples per pair pair_counts = [] for src in languages: for tgt in languages: if src == tgt: continue pair_samples = track_data[ (track_data["source_language"] == src) & (track_data["target_language"] == tgt) ] pair_counts.append(len(pair_samples)) if not pair_counts: return 0.0 avg_samples_per_pair = np.mean(pair_counts) # Rough power estimation (0.8 power at 2x minimum, 0.95 at 4x minimum) if avg_samples_per_pair >= min_required * 4: return 0.95 elif avg_samples_per_pair >= min_required * 2: return 0.8 elif avg_samples_per_pair >= min_required: return 0.6 else: return max(0.0, avg_samples_per_pair / min_required * 0.6) def calculate_integrity_score(track_analysis: Dict, adequacy_report: Dict) -> float: """Calculate overall integrity score for the test set.""" if not track_analysis or not adequacy_report: return 0.0 # Track adequacy scores track_scores = [] for track_info in track_analysis.values(): adequacy_rate = track_info.get("adequacy_rate", 0.0) alignment_rate = track_info.get("alignment_rate", 0.0) track_score = (adequacy_rate + alignment_rate) / 2 track_scores.append(track_score) # Overall adequacy mapping adequacy_mapping = { "excellent": 1.0, "good": 0.8, "fair": 0.6, "insufficient": 0.2, } overall_adequacy_score = adequacy_mapping.get( adequacy_report.get("overall_adequacy", "insufficient"), 0.2 ) # Combined score if track_scores: track_avg = np.mean(track_scores) integrity_score = (track_avg + overall_adequacy_score) / 2 else: integrity_score = overall_adequacy_score return float(integrity_score)