# src/test_set.py
import os
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import yaml
from datasets import load_dataset

import salt.dataset
from config import (
    TEST_SET_DATASET,
    SALT_DATASET,
    MAX_TEST_SAMPLES,
    HF_TOKEN,
    ALL_UG40_LANGUAGES,
    GOOGLE_SUPPORTED_LANGUAGES,
    EVALUATION_TRACKS,
    SAMPLE_SIZE_RECOMMENDATIONS,
    STATISTICAL_CONFIG,
)
from src.utils import get_all_language_pairs, get_track_language_pairs

# Local CSV filenames for persistence
LOCAL_PUBLIC_CSV = "salt_test_set_scientific.csv"
LOCAL_COMPLETE_CSV = "salt_complete_test_set_scientific.csv"
LOCAL_TRACK_CSVS = {
    track: f"salt_test_set_{track}.csv" for track in EVALUATION_TRACKS.keys()
}
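
# For illustration only (hypothetical track keys, since EVALUATION_TRACKS lives
# in config.py): with tracks "ug40_full" and "google_comparable", this renders as
#   {"ug40_full": "salt_test_set_ug40_full.csv",
#    "google_comparable": "salt_test_set_google_comparable.csv"}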


def generate_scientific_test_set(
    max_samples_per_pair: int = MAX_TEST_SAMPLES,
    stratified_sampling: bool = True,
    balance_tracks: bool = True,
) -> pd.DataFrame:
    """Generate a scientifically rigorous test set with stratified sampling."""
    print("🔬 Generating scientific SALT test set...")

    try:
        # Build SALT dataset config
        dataset_config = f"""
huggingface_load:
    path: {SALT_DATASET}
    name: text-all
    split: test
source:
    type: text
    language: {ALL_UG40_LANGUAGES}
target:
    type: text
    language: {ALL_UG40_LANGUAGES}
allow_same_src_and_tgt_language: False
"""
        config = yaml.safe_load(dataset_config)

        print("📥 Loading SALT dataset...")
        full_data = pd.DataFrame(salt.dataset.create(config))
        print(f"📊 Loaded {len(full_data):,} samples from SALT dataset")

        test_samples = []
        sample_id_counter = 1

        # Calculate target samples per track for balanced evaluation
        track_targets = calculate_track_sampling_targets(balance_tracks)

        # Generate samples for each language pair with stratified sampling
        for src_lang in ALL_UG40_LANGUAGES:
            for tgt_lang in ALL_UG40_LANGUAGES:
                if src_lang == tgt_lang:
                    continue

                # Determine target sample size for this pair
                pair_targets = calculate_pair_sampling_targets(
                    src_lang, tgt_lang, track_targets, max_samples_per_pair
                )
                target_samples = (
                    max(pair_targets.values()) if pair_targets else max_samples_per_pair
                )

                # Filter for this language pair
                pair_data = full_data[
                    (full_data["source.language"] == src_lang)
                    & (full_data["target.language"] == tgt_lang)
                ]

                if pair_data.empty:
                    print(f"⚠️ No data found for {src_lang} → {tgt_lang}")
                    continue

                # Stratified sampling if enabled
                if stratified_sampling and len(pair_data) > target_samples:
                    sampled = stratified_sample_pair_data(pair_data, target_samples)
                else:
                    # Simple random sampling
                    n_samples = min(len(pair_data), target_samples)
                    sampled = pair_data.sample(n=n_samples, random_state=42)

                print(f"✅ {src_lang} → {tgt_lang}: {len(sampled)} samples")

                for _, row in sampled.iterrows():
                    # Determine which tracks include this pair
                    tracks_included = []
                    for track_name, track_config in EVALUATION_TRACKS.items():
                        if (src_lang in track_config["languages"]
                                and tgt_lang in track_config["languages"]):
                            tracks_included.append(track_name)

                    test_samples.append({
                        "sample_id": f"salt_{sample_id_counter:06d}",
                        "source_text": row["source"],
                        "target_text": row["target"],
                        "source_language": src_lang,
                        "target_language": tgt_lang,
                        "domain": row.get("domain", "general"),
                        "google_comparable": (
                            src_lang in GOOGLE_SUPPORTED_LANGUAGES
                            and tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
                        ),
                        "tracks_included": ",".join(tracks_included),
                        "statistical_weight": calculate_statistical_weight(
                            src_lang, tgt_lang, tracks_included
                        ),
                    })
                    sample_id_counter += 1

        test_df = pd.DataFrame(test_samples)

        if test_df.empty:
            raise ValueError("No test samples generated - check SALT dataset availability")

        # Validate scientific adequacy
        adequacy_report = validate_test_set_scientific_adequacy(test_df)

        print(f"✅ Generated scientific test set: {len(test_df):,} samples")
        print(f"📊 Test set adequacy: {adequacy_report['overall_adequacy']}")

        return test_df

    except Exception as e:
        print(f"❌ Error generating scientific test set: {e}")
        return pd.DataFrame(columns=[
            "sample_id", "source_text", "target_text", "source_language",
            "target_language", "domain", "google_comparable", "tracks_included",
            "statistical_weight",
        ])


def calculate_track_sampling_targets(balance_tracks: bool) -> Dict[str, int]:
    """Calculate target sample sizes for each track to ensure statistical adequacy."""
    track_targets = {}

    for track_name, track_config in EVALUATION_TRACKS.items():
        # Base requirement from config
        min_per_pair = track_config["min_samples_per_pair"]

        # Number of directed language pairs in this track
        n_pairs = len(track_config["languages"]) * (len(track_config["languages"]) - 1)

        # Calculate total samples needed for statistical adequacy
        if balance_tracks:
            # Use publication-quality recommendation
            target_per_pair = max(
                min_per_pair,
                SAMPLE_SIZE_RECOMMENDATIONS["publication_quality"] // n_pairs,
            )
        else:
            target_per_pair = min_per_pair

        track_targets[track_name] = target_per_pair * n_pairs
        print(
            f"📊 {track_name}: targeting {target_per_pair} samples/pair × "
            f"{n_pairs} pairs = {track_targets[track_name]} total"
        )

    return track_targets
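
# Worked example (hypothetical numbers, since the actual values live in config.py):
# a track with 4 languages has 4 * 3 = 12 directed pairs. With
# min_samples_per_pair = 100 and a publication_quality recommendation of 6000,
# balanced targeting gives target_per_pair = max(100, 6000 // 12) = 500,
# i.e. 500 * 12 = 6000 samples total for the track.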


def calculate_pair_sampling_targets(
    src_lang: str, tgt_lang: str, track_targets: Dict[str, int], max_samples: int
) -> Dict[str, int]:
    """Calculate sampling targets for a specific language pair across tracks."""
    pair_targets = {}

    for track_name, track_config in EVALUATION_TRACKS.items():
        if (src_lang in track_config["languages"]
                and tgt_lang in track_config["languages"]):
            n_pairs_in_track = (
                len(track_config["languages"]) * (len(track_config["languages"]) - 1)
            )
            target_per_pair = track_targets[track_name] // n_pairs_in_track
            pair_targets[track_name] = min(target_per_pair, max_samples)

    return pair_targets
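
# Continuing the hypothetical example above: with max_samples = 400, the pair's
# target for that track is min(500, 400) = 400. generate_scientific_test_set
# then samples the maximum target across all tracks that include the pair.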


def stratified_sample_pair_data(pair_data: pd.DataFrame, target_samples: int) -> pd.DataFrame:
    """Perform stratified sampling on pair data to ensure representativeness."""
    # Try to stratify by domain if available
    if "domain" in pair_data.columns and pair_data["domain"].nunique() > 1:
        # Sample proportionally from each domain
        domain_counts = pair_data["domain"].value_counts()
        sampled_parts = []

        for domain, count in domain_counts.items():
            domain_data = pair_data[pair_data["domain"] == domain]
            # Proportional sample size, at least 1, capped at the domain's size
            proportion = count / len(pair_data)
            domain_target = max(1, int(target_samples * proportion))
            domain_target = min(domain_target, len(domain_data))
            sampled_parts.append(domain_data.sample(n=domain_target, random_state=42))

        if sampled_parts:
            # Keep the original index so the isin() check below can exclude
            # already-sampled rows (ignore_index=True would break that check).
            stratified_sample = pd.concat(sampled_parts)

            # If we didn't get enough samples, fill with random sampling
            if len(stratified_sample) < target_samples:
                remaining_data = pair_data[~pair_data.index.isin(stratified_sample.index)]
                additional_needed = target_samples - len(stratified_sample)
                if len(remaining_data) >= additional_needed:
                    additional_sample = remaining_data.sample(
                        n=additional_needed, random_state=42
                    )
                    stratified_sample = pd.concat([stratified_sample, additional_sample])

            return stratified_sample.head(target_samples).reset_index(drop=True)

    # Fallback to simple random sampling
    return pair_data.sample(n=min(target_samples, len(pair_data)), random_state=42)
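
# Illustrative behaviour (hypothetical data; not executed at import time):
#
#   toy = pd.DataFrame({
#       "domain": ["health"] * 80 + ["law"] * 20,
#       "source": ["..."] * 100,
#       "target": ["..."] * 100,
#   })
#   stratified_sample_pair_data(toy, target_samples=10)
#   # -> ~8 "health" rows and ~2 "law" rows, mirroring the 80/20 domain mix.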


def calculate_statistical_weight(
    src_lang: str, tgt_lang: str, tracks_included: List[str]
) -> float:
    """Calculate statistical weight for a sample based on track inclusion."""
    # Base weight
    weight = 1.0

    # Higher weight for samples in multiple tracks (more valuable)
    weight *= len(tracks_included)

    # Higher weight for Google-comparable pairs (enables baseline comparison)
    if (src_lang in GOOGLE_SUPPORTED_LANGUAGES
            and tgt_lang in GOOGLE_SUPPORTED_LANGUAGES):
        weight *= 1.5

    # Cap at a reasonable maximum
    return min(weight, 5.0)
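
# Worked example: a sample whose pair appears in two tracks and is
# Google-comparable gets weight 1.0 * 2 * 1.5 = 3.0. The cap keeps every
# weight at or below 5.0, and a pair included in no track receives 0.0.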


def validate_test_set_scientific_adequacy(test_df: pd.DataFrame) -> Dict:
    """Validate that the test set meets scientific adequacy requirements."""
    adequacy_report = {
        "overall_adequacy": "insufficient",
        "track_adequacy": {},
        "issues": [],
        "recommendations": [],
        "statistics": {},
    }

    if test_df.empty:
        adequacy_report["issues"].append("Test set is empty")
        return adequacy_report

    # Check each track
    track_adequacies = []

    for track_name, track_config in EVALUATION_TRACKS.items():
        track_languages = track_config["languages"]
        min_per_pair = track_config["min_samples_per_pair"]

        # Filter to track data
        track_data = test_df[
            (test_df["source_language"].isin(track_languages))
            & (test_df["target_language"].isin(track_languages))
        ]

        # Analyze pair coverage
        pair_counts = {}
        for src in track_languages:
            for tgt in track_languages:
                if src == tgt:
                    continue
                pair_samples = track_data[
                    (track_data["source_language"] == src)
                    & (track_data["target_language"] == tgt)
                ]
                pair_counts[f"{src}_{tgt}"] = len(pair_samples)

        # Calculate adequacy metrics
        total_pairs = len(pair_counts)
        adequate_pairs = sum(1 for count in pair_counts.values() if count >= min_per_pair)
        adequacy_rate = adequate_pairs / max(total_pairs, 1)

        # Determine track adequacy level
        if adequacy_rate >= 0.9:
            track_adequacy = "excellent"
        elif adequacy_rate >= 0.8:
            track_adequacy = "good"
        elif adequacy_rate >= 0.6:
            track_adequacy = "fair"
        else:
            track_adequacy = "insufficient"

        adequacy_report["track_adequacy"][track_name] = {
            "adequacy": track_adequacy,
            "adequacy_rate": adequacy_rate,
            "total_samples": len(track_data),
            "total_pairs": total_pairs,
            "adequate_pairs": adequate_pairs,
            "min_samples_per_pair": min_per_pair,
            "pair_counts": pair_counts,
        }
        track_adequacies.append(track_adequacy)

        # Add specific issues
        if track_adequacy == "insufficient":
            inadequate_pairs = [k for k, v in pair_counts.items() if v < min_per_pair]
            adequacy_report["issues"].append(
                f"{track_name}: {len(inadequate_pairs)} pairs below minimum"
            )

    # Overall adequacy assessment
    if all(adequacy in ["excellent", "good"] for adequacy in track_adequacies):
        adequacy_report["overall_adequacy"] = "excellent"
    elif all(adequacy in ["excellent", "good", "fair"] for adequacy in track_adequacies):
        adequacy_report["overall_adequacy"] = "good"
    elif any(adequacy in ["good", "fair"] for adequacy in track_adequacies):
        adequacy_report["overall_adequacy"] = "fair"
    else:
        adequacy_report["overall_adequacy"] = "insufficient"

    # Overall statistics
    adequacy_report["statistics"] = {
        "total_samples": len(test_df),
        "total_language_pairs": len(test_df.groupby(["source_language", "target_language"])),
        "google_comparable_samples": int(test_df["google_comparable"].sum()),
        "domain_distribution": test_df["domain"].value_counts().to_dict(),
        "track_sample_distribution": {
            track: adequacy_report["track_adequacy"][track]["total_samples"]
            for track in EVALUATION_TRACKS.keys()
        },
    }

    # Generate recommendations
    if adequacy_report["overall_adequacy"] in ["insufficient", "fair"]:
        adequacy_report["recommendations"].append(
            "Consider increasing sample size for better statistical power"
        )

    if adequacy_report["statistics"]["google_comparable_samples"] < 1000:
        adequacy_report["recommendations"].append(
            "More Google-comparable samples recommended for baseline comparison"
        )

    return adequacy_report


def _generate_and_save_scientific_test_set() -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Generate and save both public and complete versions of the scientific test set."""
    print("🔬 Generating and saving scientific test sets...")

    full_df = generate_scientific_test_set()

    if full_df.empty:
        print("❌ Failed to generate scientific test set")
        empty_public = pd.DataFrame(columns=[
            "sample_id", "source_text", "source_language",
            "target_language", "domain", "google_comparable",
            "tracks_included", "statistical_weight",
        ])
        empty_complete = pd.DataFrame(columns=[
            "sample_id", "source_text", "target_text", "source_language",
            "target_language", "domain", "google_comparable",
            "tracks_included", "statistical_weight",
        ])
        return empty_public, empty_complete

    # Public version (no target_text)
    public_df = full_df[[
        "sample_id", "source_text", "source_language",
        "target_language", "domain", "google_comparable",
        "tracks_included", "statistical_weight",
    ]].copy()

    # Save main versions
    try:
        public_df.to_csv(LOCAL_PUBLIC_CSV, index=False)
        full_df.to_csv(LOCAL_COMPLETE_CSV, index=False)
        print(f"✅ Saved main test sets: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
    except Exception as e:
        print(f"⚠️ Error saving main CSVs: {e}")

    # Save track-specific versions for easier analysis
    for track_name, track_config in EVALUATION_TRACKS.items():
        try:
            track_languages = track_config["languages"]
            track_public = public_df[
                (public_df["source_language"].isin(track_languages))
                & (public_df["target_language"].isin(track_languages))
            ]
            track_filename = LOCAL_TRACK_CSVS[track_name]
            track_public.to_csv(track_filename, index=False)
            print(f"✅ Saved {track_name} track: {track_filename} ({len(track_public):,} samples)")
        except Exception as e:
            print(f"⚠️ Error saving {track_name} track CSV: {e}")

    return public_df, full_df


def get_public_test_set_scientific() -> pd.DataFrame:
    """Load the scientific public test set with enhanced fallback logic."""
    # 1) Try HF Hub
    try:
        print("📥 Attempting to load scientific test set from HF Hub...")
        ds = load_dataset(TEST_SET_DATASET + "-scientific", split="train", token=HF_TOKEN)
        df = ds.to_pandas()

        # Validate scientific structure
        required_cols = ["sample_id", "source_text", "source_language", "target_language",
                         "tracks_included", "statistical_weight"]
        if all(col in df.columns for col in required_cols):
            print(f"✅ Loaded scientific test set from HF Hub ({len(df):,} samples)")
            return df
        else:
            print("⚠️ HF Hub test set missing scientific columns, regenerating...")
    except Exception as e:
        print(f"⚠️ HF Hub load failed: {e}")

    # 2) Try local CSV
    if os.path.exists(LOCAL_PUBLIC_CSV):
        try:
            df = pd.read_csv(LOCAL_PUBLIC_CSV)
            required_cols = ["sample_id", "source_text", "source_language", "target_language"]
            if all(col in df.columns for col in required_cols):
                print(f"✅ Loaded scientific test set from local CSV ({len(df):,} samples)")
                return df
            else:
                print("⚠️ Local CSV has invalid structure, regenerating...")
        except Exception as e:
            print(f"⚠️ Failed to read local scientific CSV: {e}")

    # 3) Regenerate & save
    print("🔄 Generating new scientific test set...")
    public_df, _ = _generate_and_save_scientific_test_set()
    return public_df


def get_complete_test_set_scientific() -> pd.DataFrame:
    """Load the complete scientific test set with targets."""
    # 1) Try HF Hub private
    try:
        print("📥 Attempting to load complete scientific test set from HF Hub...")
        ds = load_dataset(TEST_SET_DATASET + "-scientific-private", split="train", token=HF_TOKEN)
        df = ds.to_pandas()

        required_cols = ["sample_id", "source_text", "target_text", "source_language",
                         "target_language", "tracks_included", "statistical_weight"]
        if all(col in df.columns for col in required_cols):
            print(f"✅ Loaded complete scientific test set from HF Hub ({len(df):,} samples)")
            return df
        else:
            print("⚠️ HF Hub complete test set missing scientific columns, regenerating...")
    except Exception as e:
        print(f"⚠️ HF Hub private load failed: {e}")

    # 2) Try local CSV
    if os.path.exists(LOCAL_COMPLETE_CSV):
        try:
            df = pd.read_csv(LOCAL_COMPLETE_CSV)
            required_cols = ["sample_id", "source_text", "target_text", "source_language", "target_language"]
            if all(col in df.columns for col in required_cols):
                print(f"✅ Loaded complete scientific test set from local CSV ({len(df):,} samples)")
                return df
            else:
                print("⚠️ Local complete CSV has invalid structure, regenerating...")
        except Exception as e:
            print(f"⚠️ Failed to read local complete scientific CSV: {e}")

    # 3) Regenerate & save
    print("🔄 Generating new complete scientific test set...")
    _, complete_df = _generate_and_save_scientific_test_set()
    return complete_df


def get_track_test_set(track: str) -> pd.DataFrame:
    """Get the test set filtered for a specific track."""
    if track not in EVALUATION_TRACKS:
        print(f"❌ Unknown track: {track}")
        return pd.DataFrame()

    # Try track-specific CSV first
    track_csv = LOCAL_TRACK_CSVS.get(track)
    if track_csv and os.path.exists(track_csv):
        try:
            df = pd.read_csv(track_csv)
            print(f"✅ Loaded {track} test set from track-specific CSV ({len(df):,} samples)")
            return df
        except Exception as e:
            print(f"⚠️ Failed to read {track} CSV: {e}")

    # Fall back to filtering the main test set
    public_df = get_public_test_set_scientific()
    if public_df.empty:
        return pd.DataFrame()

    track_languages = EVALUATION_TRACKS[track]["languages"]
    track_df = public_df[
        (public_df["source_language"].isin(track_languages))
        & (public_df["target_language"].isin(track_languages))
    ]
    print(f"✅ Filtered {track} test set from main set ({len(track_df):,} samples)")
    return track_df


def create_test_set_download_scientific() -> Tuple[str, Dict]:
    """Create scientific test set download with comprehensive metadata."""
    public_df = get_public_test_set_scientific()

    if public_df.empty:
        stats = {
            "total_samples": 0,
            "track_breakdown": {},
            "adequacy_assessment": "insufficient",
            "scientific_metadata": {},
        }
        return LOCAL_PUBLIC_CSV, stats

    download_path = LOCAL_PUBLIC_CSV

    # Ensure the CSV is up to date
    try:
        public_df.to_csv(download_path, index=False)
    except Exception as e:
        print(f"⚠️ Error updating scientific CSV: {e}")

    # Calculate comprehensive statistics
    try:
        # Basic statistics
        stats = {
            "total_samples": len(public_df),
            "languages": sorted(set(public_df["source_language"]).union(public_df["target_language"])),
            "domains": public_df["domain"].unique().tolist() if "domain" in public_df.columns else ["general"],
        }

        # Track-specific breakdown
        track_breakdown = {}
        for track_name, track_config in EVALUATION_TRACKS.items():
            track_languages = track_config["languages"]
            track_data = public_df[
                (public_df["source_language"].isin(track_languages))
                & (public_df["target_language"].isin(track_languages))
            ]
            n_track_pairs = len(track_languages) * (len(track_languages) - 1)
            track_breakdown[track_name] = {
                "name": track_config["name"],
                "total_samples": len(track_data),
                "language_pairs": len(track_data.groupby(["source_language", "target_language"])),
                "min_samples_per_pair": track_config["min_samples_per_pair"],
                "statistical_adequacy": (
                    len(track_data) >= track_config["min_samples_per_pair"] * n_track_pairs
                ),
            }

        stats["track_breakdown"] = track_breakdown

        # Google-comparable statistics
        if "google_comparable" in public_df.columns:
            stats["google_comparable_samples"] = int(public_df["google_comparable"].sum())
            stats["google_comparable_rate"] = float(public_df["google_comparable"].mean())
        else:
            stats["google_comparable_samples"] = 0
            stats["google_comparable_rate"] = 0.0

        # Scientific adequacy assessment
        adequacy_report = validate_test_set_scientific_adequacy(public_df)
        stats["adequacy_assessment"] = adequacy_report["overall_adequacy"]
        stats["adequacy_details"] = adequacy_report

        # Scientific metadata
        stats["scientific_metadata"] = {
            "stratified_sampling": True,
            "statistical_weighting": "statistical_weight" in public_df.columns,
            "track_balanced": True,
            "confidence_level": STATISTICAL_CONFIG["confidence_level"],
            "recommended_for": [
                track for track, info in track_breakdown.items()
                if info["statistical_adequacy"]
            ],
        }
    except Exception as e:
        print(f"⚠️ Error calculating scientific stats: {e}")
        stats = {
            "total_samples": len(public_df),
            "track_breakdown": {},
            "adequacy_assessment": "unknown",
            "scientific_metadata": {},
        }

    return download_path, stats


def validate_test_set_integrity_scientific() -> Dict:
    """Comprehensive validation of scientific test set integrity."""
    try:
        public_df = get_public_test_set_scientific()
        complete_df = get_complete_test_set_scientific()

        if public_df.empty or complete_df.empty:
            return {
                "alignment_check": False,
                "total_samples": 0,
                "scientific_adequacy": {},
                "track_analysis": {},
                "error": "Test sets are empty or could not be loaded",
            }

        public_ids = set(public_df["sample_id"])
        private_ids = set(complete_df["sample_id"])

        # Track-specific analysis
        track_analysis = {}
        for track_name, track_config in EVALUATION_TRACKS.items():
            track_languages = track_config["languages"]
            min_per_pair = track_config["min_samples_per_pair"]

            # Analyze public set for this track
            track_public = public_df[
                (public_df["source_language"].isin(track_languages))
                & (public_df["target_language"].isin(track_languages))
            ]

            # Analyze complete set for this track
            track_complete = complete_df[
                (complete_df["source_language"].isin(track_languages))
                & (complete_df["target_language"].isin(track_languages))
            ]

            # Calculate coverage
            pair_coverage = {}
            for src in track_languages:
                for tgt in track_languages:
                    if src == tgt:
                        continue
                    public_subset = track_public[
                        (track_public["source_language"] == src)
                        & (track_public["target_language"] == tgt)
                    ]
                    complete_subset = track_complete[
                        (track_complete["source_language"] == src)
                        & (track_complete["target_language"] == tgt)
                    ]
                    pair_coverage[f"{src}_{tgt}"] = {
                        "public_count": len(public_subset),
                        "complete_count": len(complete_subset),
                        "alignment": len(public_subset) == len(complete_subset),
                        "meets_minimum": len(public_subset) >= min_per_pair,
                    }

            # Track summary
            total_pairs = len(pair_coverage)
            adequate_pairs = sum(1 for info in pair_coverage.values() if info["meets_minimum"])
            aligned_pairs = sum(1 for info in pair_coverage.values() if info["alignment"])

            track_analysis[track_name] = {
                "total_pairs": total_pairs,
                "adequate_pairs": adequate_pairs,
                "aligned_pairs": aligned_pairs,
                "adequacy_rate": adequate_pairs / max(total_pairs, 1),
                "alignment_rate": aligned_pairs / max(total_pairs, 1),
                "pair_coverage": pair_coverage,
                "statistical_power": calculate_track_statistical_power(track_public, track_config),
            }

        # Overall scientific adequacy
        adequacy_report = validate_test_set_scientific_adequacy(public_df)

        return {
            # True when every public sample_id also exists in the private set
            "alignment_check": public_ids <= private_ids,
            "total_samples": len(public_df),
            "track_analysis": track_analysis,
            "scientific_adequacy": adequacy_report,
            "public_samples": len(public_df),
            "private_samples": len(complete_df),
            "id_alignment_rate": len(public_ids & private_ids) / len(public_ids) if public_ids else 0.0,
            "integrity_score": calculate_integrity_score(track_analysis, adequacy_report),
        }
    except Exception as e:
        return {
            "alignment_check": False,
            "total_samples": 0,
            "scientific_adequacy": {},
            "track_analysis": {},
            "error": f"Validation failed: {str(e)}",
        }


def calculate_track_statistical_power(track_data: pd.DataFrame, track_config: Dict) -> float:
    """Estimate statistical power for a track from its per-pair sample sizes."""
    if track_data.empty:
        return 0.0

    # Simple power estimation based on sample size
    min_required = track_config["min_samples_per_pair"]
    languages = track_config["languages"]

    # Calculate average samples per pair
    pair_counts = []
    for src in languages:
        for tgt in languages:
            if src == tgt:
                continue
            pair_samples = track_data[
                (track_data["source_language"] == src)
                & (track_data["target_language"] == tgt)
            ]
            pair_counts.append(len(pair_samples))

    if not pair_counts:
        return 0.0

    avg_samples_per_pair = np.mean(pair_counts)

    # Rough power heuristic (0.8 power at 2x minimum, 0.95 at 4x minimum)
    if avg_samples_per_pair >= min_required * 4:
        return 0.95
    elif avg_samples_per_pair >= min_required * 2:
        return 0.8
    elif avg_samples_per_pair >= min_required:
        return 0.6
    else:
        return max(0.0, avg_samples_per_pair / min_required * 0.6)
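
# Worked example: with min_samples_per_pair = 100 and an average of 250
# samples/pair, the heuristic returns 0.8; at 90 samples/pair it returns
# 90 / 100 * 0.6 = 0.54. These are coarse piecewise estimates, not a formal
# power analysis.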


def calculate_integrity_score(track_analysis: Dict, adequacy_report: Dict) -> float:
    """Calculate overall integrity score for the test set."""
    if not track_analysis or not adequacy_report:
        return 0.0

    # Track adequacy scores
    track_scores = []
    for track_info in track_analysis.values():
        adequacy_rate = track_info.get("adequacy_rate", 0.0)
        alignment_rate = track_info.get("alignment_rate", 0.0)
        track_score = (adequacy_rate + alignment_rate) / 2
        track_scores.append(track_score)

    # Overall adequacy mapping
    adequacy_mapping = {
        "excellent": 1.0,
        "good": 0.8,
        "fair": 0.6,
        "insufficient": 0.2,
    }
    overall_adequacy_score = adequacy_mapping.get(
        adequacy_report.get("overall_adequacy", "insufficient"), 0.2
    )

    # Combined score
    if track_scores:
        track_avg = np.mean(track_scores)
        integrity_score = (track_avg + overall_adequacy_score) / 2
    else:
        integrity_score = overall_adequacy_score

    return float(integrity_score)
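

if __name__ == "__main__":
    # Minimal smoke-test sketch (an assumption, not part of the original module):
    # it requires config.py, the salt package, and HF Hub credentials in place.
    public_df, complete_df = _generate_and_save_scientific_test_set()
    integrity = validate_test_set_integrity_scientific()
    print(f"Public samples: {len(public_df):,}, complete samples: {len(complete_df):,}")
    print(f"Integrity score: {integrity.get('integrity_score', 0.0):.2f}")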