# Provenance (repository page header, kept as comments so the module parses):
#   leaderboard / src / test_set.py
#   akera - "Update src/test_set.py" - d82b528 (verified)
# src/test_set.py
import os
import pandas as pd
import yaml
import numpy as np
from datasets import load_dataset
from config import (
TEST_SET_DATASET,
SALT_DATASET,
MAX_TEST_SAMPLES,
HF_TOKEN,
ALL_UG40_LANGUAGES,
GOOGLE_SUPPORTED_LANGUAGES,
EVALUATION_TRACKS,
LANGUAGE_NAMES,
)
import salt.dataset
from src.utils import get_all_language_pairs
from typing import Dict, List, Optional, Tuple
# Local CSV filenames for persistence.
# LOCAL_PUBLIC_CSV receives the test set without the "target_text" column;
# LOCAL_COMPLETE_CSV receives the full test set including "target_text".
# Both are written relative to the current working directory.
LOCAL_PUBLIC_CSV = "salt_test_set.csv"
LOCAL_COMPLETE_CSV = "salt_complete_test_set.csv"
def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
    """Generate test set from SALT dataset.

    Loads the ``test`` split of the ``text-all`` configuration of
    ``SALT_DATASET`` through ``salt.dataset.create`` and, for every ordered
    pair of distinct languages in ``ALL_UG40_LANGUAGES``, draws up to
    ``max_samples_per_pair`` rows with a fixed seed (``random_state=42``).

    Args:
        max_samples_per_pair: Upper bound on the number of rows sampled for
            each (source language, target language) pair.

    Returns:
        A DataFrame with the columns ``sample_id``, ``source_text``,
        ``target_text``, ``source_language``, ``target_language``, ``domain``
        and ``google_comparable``. Any failure (including "no samples at
        all") is reported on stdout and yields an empty DataFrame that still
        carries those columns; no exception escapes this function.
    """
    output_columns = [
        "sample_id", "source_text", "target_text", "source_language",
        "target_language", "domain", "google_comparable",
    ]
    print("πŸ”¬ Generating SALT test set...")
    try:
        languages = list(ALL_UG40_LANGUAGES)
        # Build the SALT config as a plain mapping rather than formatting a
        # Python repr into YAML text and parsing it back: the result is the
        # same for a list of plain strings, but this also works when the
        # language collection is a tuple or the dataset path contains
        # characters that are significant in YAML.
        # NOTE(review): nesting follows the usual SALT config layout
        # (huggingface_load / source / target sections) - confirm.
        config = {
            "huggingface_load": {
                "path": SALT_DATASET,
                "name": "text-all",
                "split": "test",
            },
            "source": {"type": "text", "language": list(languages)},
            "target": {"type": "text", "language": list(languages)},
            "allow_same_src_and_tgt_language": False,
        }
        print("πŸ“₯ Loading SALT dataset...")
        full_data = pd.DataFrame(salt.dataset.create(config))
        print(f"πŸ“Š Loaded {len(full_data):,} samples from SALT dataset")

        # The "domain" column is optional; fall back to "general" without it.
        has_domain = "domain" in full_data.columns
        pair_frames = []
        next_id = 1  # sample IDs are numbered consecutively across all pairs

        # Generate samples for each ordered language pair
        for src_lang in languages:
            src_on_google = src_lang in GOOGLE_SUPPORTED_LANGUAGES
            for tgt_lang in languages:
                if src_lang == tgt_lang:
                    continue

                # Filter for this language pair
                pair_data = full_data[
                    (full_data["source.language"] == src_lang)
                    & (full_data["target.language"] == tgt_lang)
                ]
                if pair_data.empty:
                    print(f"⚠️ No data found for {src_lang} β†’ {tgt_lang}")
                    continue

                # Sample data for this pair (fixed seed keeps it reproducible)
                n_samples = min(len(pair_data), max_samples_per_pair)
                sampled = pair_data.sample(n=n_samples, random_state=42)
                print(f"βœ… {src_lang} β†’ {tgt_lang}: {len(sampled)} samples")
                if sampled.empty:
                    # Nothing to add (e.g. max_samples_per_pair == 0).
                    continue

                # Assemble the whole pair column-wise instead of row by row;
                # scalar values are broadcast to every sampled row.
                last_id = next_id + len(sampled)
                pair_frames.append(pd.DataFrame({
                    "sample_id": [f"salt_{i:06d}" for i in range(next_id, last_id)],
                    "source_text": sampled["source"].to_numpy(),
                    "target_text": sampled["target"].to_numpy(),
                    "source_language": src_lang,
                    "target_language": tgt_lang,
                    "domain": sampled["domain"].to_numpy() if has_domain else "general",
                    "google_comparable": bool(
                        src_on_google and tgt_lang in GOOGLE_SUPPORTED_LANGUAGES
                    ),
                }))
                next_id = last_id

        if not pair_frames:
            raise ValueError("No test samples generated - check SALT dataset availability")

        test_df = pd.concat(pair_frames, ignore_index=True)[output_columns]
        print(f"βœ… Generated test set: {len(test_df):,} samples")
        return test_df
    except Exception as e:
        # Deliberate best-effort boundary: callers test for an empty frame
        # instead of handling exceptions.
        print(f"❌ Error generating test set: {e}")
        return pd.DataFrame(columns=output_columns)
def _generate_and_save_test_set() -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Build the test set and persist its public and complete variants.

    Returns:
        ``(public_df, complete_df)``. The public frame is the complete frame
        without ``target_text``. When generation yields nothing, two empty
        frames with the respective column layouts are returned and no file
        is written. A failure while writing the CSVs is only reported on
        stdout; the frames are still returned.
    """
    complete_columns = [
        "sample_id", "source_text", "target_text", "source_language",
        "target_language", "domain", "google_comparable",
    ]
    # Same layout with the reference translations left out.
    public_columns = [name for name in complete_columns if name != "target_text"]

    print("πŸ”¬ Generating and saving test sets...")
    complete_frame = generate_test_set()

    if complete_frame.empty:
        print("❌ Failed to generate test set")
        return (
            pd.DataFrame(columns=public_columns),
            pd.DataFrame(columns=complete_columns),
        )

    # Public version (no target_text)
    public_frame = complete_frame[public_columns].copy()

    # Persist both variants; a write failure must not lose the in-memory data.
    try:
        public_frame.to_csv(LOCAL_PUBLIC_CSV, index=False)
        complete_frame.to_csv(LOCAL_COMPLETE_CSV, index=False)
        print(f"βœ… Saved test sets: {LOCAL_PUBLIC_CSV}, {LOCAL_COMPLETE_CSV}")
    except Exception as err:
        print(f"⚠️ Error saving CSVs: {err}")

    return public_frame, complete_frame
def get_public_test_set() -> pd.DataFrame:
    """Return the public test set, trying progressively more local sources.

    Sources are tried in order: the ``TEST_SET_DATASET`` repository on the
    HF Hub (``train`` split), the local CSV ``LOCAL_PUBLIC_CSV``, and finally
    a fresh generation via ``_generate_and_save_test_set``. A source is
    accepted only if it provides the columns ``sample_id``, ``source_text``,
    ``source_language`` and ``target_language``; failures are reported on
    stdout and the next source is tried.
    """
    needed = ("sample_id", "source_text", "source_language", "target_language")

    # 1) HF Hub
    try:
        print("πŸ“₯ Attempting to load test set from HF Hub...")
        hub_frame = load_dataset(TEST_SET_DATASET, split="train", token=HF_TOKEN).to_pandas()
        absent = [name for name in needed if name not in hub_frame.columns]
        if not absent:
            print(f"βœ… Loaded test set from HF Hub ({len(hub_frame):,} samples)")
            return hub_frame
        print("⚠️ HF Hub test set missing columns, regenerating...")
    except Exception as err:
        print(f"⚠️ HF Hub load failed: {err}")

    # 2) Local CSV
    if os.path.exists(LOCAL_PUBLIC_CSV):
        try:
            local_frame = pd.read_csv(LOCAL_PUBLIC_CSV)
            absent = [name for name in needed if name not in local_frame.columns]
            if not absent:
                print(f"βœ… Loaded test set from local CSV ({len(local_frame):,} samples)")
                return local_frame
            print("⚠️ Local CSV has invalid structure, regenerating...")
        except Exception as err:
            print(f"⚠️ Failed to read local CSV: {err}")

    # 3) Regenerate & save
    print("πŸ”„ Generating new test set...")
    fresh_public, _unused_complete = _generate_and_save_test_set()
    return fresh_public
def get_complete_test_set() -> pd.DataFrame:
    """Return the complete test set, including the ``target_text`` column.

    Sources are tried in order: the ``TEST_SET_DATASET + "-private"``
    repository on the HF Hub (``train`` split), the local CSV
    ``LOCAL_COMPLETE_CSV``, and finally a fresh generation via
    ``_generate_and_save_test_set``. A source is accepted only if it provides
    ``sample_id``, ``source_text``, ``target_text``, ``source_language`` and
    ``target_language``; failures are reported on stdout and the next source
    is tried.
    """
    needed = (
        "sample_id", "source_text", "target_text",
        "source_language", "target_language",
    )

    # 1) HF Hub (private repository)
    try:
        print("πŸ“₯ Attempting to load complete test set from HF Hub...")
        hub_frame = load_dataset(
            TEST_SET_DATASET + "-private", split="train", token=HF_TOKEN
        ).to_pandas()
        absent = [name for name in needed if name not in hub_frame.columns]
        if not absent:
            print(f"βœ… Loaded complete test set from HF Hub ({len(hub_frame):,} samples)")
            return hub_frame
        print("⚠️ HF Hub complete test set missing columns, regenerating...")
    except Exception as err:
        print(f"⚠️ HF Hub private load failed: {err}")

    # 2) Local CSV
    if os.path.exists(LOCAL_COMPLETE_CSV):
        try:
            local_frame = pd.read_csv(LOCAL_COMPLETE_CSV)
            absent = [name for name in needed if name not in local_frame.columns]
            if not absent:
                print(f"βœ… Loaded complete test set from local CSV ({len(local_frame):,} samples)")
                return local_frame
            print("⚠️ Local complete CSV has invalid structure, regenerating...")
        except Exception as err:
            print(f"⚠️ Failed to read local complete CSV: {err}")

    # 3) Regenerate & save
    print("πŸ”„ Generating new complete test set...")
    _unused_public, fresh_complete = _generate_and_save_test_set()
    return fresh_complete
def create_test_set_download() -> Tuple[str, Dict]:
    """Prepare the public test set CSV for download and summarise it.

    Returns:
        ``(path, stats)`` where ``path`` is always ``LOCAL_PUBLIC_CSV`` and
        ``stats`` holds the total sample count, the sorted languages, the
        number of language pairs, a per-track breakdown taken from
        ``EVALUATION_TRACKS`` and the count of Google-comparable samples.
        If the test set is empty, or computing the statistics fails, the
        breakdown fields are returned empty/zero.
    """

    def _blank_stats(sample_count):
        # Shape used whenever detailed statistics are unavailable.
        return {
            "total_samples": sample_count,
            "track_breakdown": {},
            "languages": [],
            "language_pairs": 0,
            "google_comparable_samples": 0,
        }

    test_frame = get_public_test_set()
    if test_frame.empty:
        return LOCAL_PUBLIC_CSV, _blank_stats(0)

    csv_path = LOCAL_PUBLIC_CSV

    # Rewrite the CSV so the downloadable file matches what was just loaded.
    try:
        test_frame.to_csv(csv_path, index=False)
    except Exception as err:
        print(f"⚠️ Error updating CSV: {err}")

    pair_keys = ["source_language", "target_language"]
    try:
        # Overall figures
        every_language = set(test_frame["source_language"]) | set(test_frame["target_language"])
        stats = {
            "total_samples": len(test_frame),
            "languages": sorted(every_language),
            "language_pairs": len(test_frame.groupby(pair_keys)),
        }

        # Per-track figures: a row belongs to a track when both of its
        # languages are listed for that track.
        per_track = {}
        for track_name, track_config in EVALUATION_TRACKS.items():
            allowed = track_config["languages"]
            in_track = (
                test_frame["source_language"].isin(allowed)
                & test_frame["target_language"].isin(allowed)
            )
            subset = test_frame[in_track]
            per_track[track_name] = {
                "total_samples": len(subset),
                "language_pairs": len(subset.groupby(pair_keys)),
                "languages": allowed,
            }
        stats["track_breakdown"] = per_track

        # Google-comparable figures (column may be absent in older files)
        stats["google_comparable_samples"] = (
            int(test_frame["google_comparable"].sum())
            if "google_comparable" in test_frame.columns
            else 0
        )
    except Exception as err:
        print(f"⚠️ Error calculating stats: {err}")
        stats = _blank_stats(len(test_frame))

    return csv_path, stats
def get_track_test_set(track: str) -> pd.DataFrame:
    """Return the public test set restricted to one evaluation track.

    Args:
        track: Key into ``EVALUATION_TRACKS``.

    Returns:
        The rows whose source and target languages are both listed for the
        track. An empty DataFrame (without columns) is returned for an
        unknown track or when the public test set is empty.
    """
    if track not in EVALUATION_TRACKS:
        print(f"❌ Unknown track: {track}")
        return pd.DataFrame()

    # Load the full public set, then narrow it down to the track's languages.
    all_rows = get_public_test_set()
    if all_rows.empty:
        return pd.DataFrame()

    allowed = EVALUATION_TRACKS[track]["languages"]
    source_ok = all_rows["source_language"].isin(allowed)
    target_ok = all_rows["target_language"].isin(allowed)
    selected = all_rows[source_ok & target_ok]

    print(f"βœ… Filtered {track} test set: {len(selected):,} samples")
    return selected
def validate_test_set_integrity() -> Dict:
    """Check that the public and complete test sets agree with each other.

    Returns:
        On success, a dict with ``alignment_check`` (every public
        ``sample_id`` also occurs in the complete set), ``total_samples``,
        ``track_analysis`` (per-track row counts for both sets and whether
        they match), ``public_samples``, ``private_samples`` and
        ``id_alignment_rate`` (share of public IDs found in the complete
        set). If either set is empty or any step raises, a dict with
        ``alignment_check`` set to ``False`` and an ``error`` message is
        returned instead.
    """

    def _failure(message):
        # Common shape for every unsuccessful outcome.
        return {
            "alignment_check": False,
            "total_samples": 0,
            "track_analysis": {},
            "error": message,
        }

    def _rows_for(frame, allowed):
        # Keep rows whose source and target languages are both allowed.
        keep = frame["source_language"].isin(allowed) & frame["target_language"].isin(allowed)
        return frame[keep]

    try:
        public_frame = get_public_test_set()
        complete_frame = get_complete_test_set()

        if public_frame.empty or complete_frame.empty:
            return _failure("Test sets are empty or could not be loaded")

        public_ids = set(public_frame["sample_id"])
        private_ids = set(complete_frame["sample_id"])

        # Track-specific analysis
        per_track = {}
        for track_name, track_config in EVALUATION_TRACKS.items():
            allowed = track_config["languages"]
            public_count = len(_rows_for(public_frame, allowed))
            complete_count = len(_rows_for(complete_frame, allowed))
            per_track[track_name] = {
                "public_samples": public_count,
                "complete_samples": complete_count,
                "alignment": public_count == complete_count,
                "languages": allowed,
            }

        if public_ids:
            shared_rate = len(public_ids & private_ids) / len(public_ids)
        else:
            shared_rate = 0.0

        return {
            "alignment_check": public_ids <= private_ids,
            "total_samples": len(public_frame),
            "track_analysis": per_track,
            "public_samples": len(public_frame),
            "private_samples": len(complete_frame),
            "id_alignment_rate": shared_rate,
        }
    except Exception as err:
        return _failure(f"Validation failed: {str(err)}")