# (Removed scrape artifact: "Spaces: / Sleeping / Sleeping" — Hugging Face
# Space status banner captured by the page scrape, not part of this module.)
# src/validation.py
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple, Optional
import json
import io
import re
from config import (
    PREDICTION_FORMAT,
    VALIDATION_CONFIG,
    MODEL_CATEGORIES,
    EVALUATION_TRACKS,
    ALL_UG40_LANGUAGES,
    SAMPLE_SIZE_RECOMMENDATIONS,
)
def detect_model_category(model_name: str, author: str, description: str) -> str:
    """Infer a model category ("commercial", "research", "baseline" or
    "community") from the submission's name, author and description text."""
    # Pool all metadata into a single lowercase haystack for substring tests.
    haystack = " ".join((model_name, author, description)).lower()

    def mentions(needles) -> bool:
        # True when any needle occurs as a substring of the pooled text.
        return any(needle in haystack for needle in needles)

    # Configured per-system patterns take precedence, in a fixed order.
    patterns = PREDICTION_FORMAT["category_detection"]
    for key, category in (
        ("google", "commercial"),
        ("nllb", "research"),
        ("m2m", "research"),
        ("baseline", "baseline"),
    ):
        if mentions(patterns.get(key, [])):
            return category

    # Generic academic / research vocabulary.
    if mentions((
        "university", "research", "paper", "arxiv", "acl", "emnlp", "naacl",
        "transformer", "bert", "gpt", "t5", "mbart", "academic",
    )):
        return "research"

    # Generic commercial vocabulary.
    if mentions((
        "google", "microsoft", "azure", "aws", "openai", "anthropic",
        "commercial", "api", "cloud", "translate",
    )):
        return "commercial"

    # Nothing matched: assume a community submission.
    return "community"
def validate_file_format_enhanced(file_content: bytes, filename: str) -> Dict:
    """Enhanced file format validation with stricter requirements.

    Parses the uploaded bytes as CSV/TSV/JSON (chosen by file extension,
    case-insensitively), then checks required columns, emptiness, missing
    values, duplicate ids, and the ``salt_NNNNNN`` sample-id format.

    Args:
        file_content: Raw uploaded bytes.
        filename: Original filename; only the extension is used.

    Returns:
        Dict with at least ``valid`` (and ``error`` when invalid); on a
        successful parse it also carries ``dataframe``, ``row_count`` and
        ``columns`` so callers can continue with partial results.
    """
    try:
        # Pick a parser from the extension; .CSV etc. are accepted too.
        lowered = filename.lower()
        if lowered.endswith(".csv"):
            df = pd.read_csv(io.BytesIO(file_content))
        elif lowered.endswith(".tsv"):
            df = pd.read_csv(io.BytesIO(file_content), sep="\t")
        elif lowered.endswith(".json"):
            df = pd.DataFrame(json.loads(file_content.decode("utf-8")))
        else:
            return {
                "valid": False,
                "error": f"Unsupported file type. Use: {', '.join(PREDICTION_FORMAT['file_types'])}",
            }

        # All configured columns must be present.
        missing_cols = set(PREDICTION_FORMAT["required_columns"]) - set(df.columns)
        if missing_cols:
            return {
                "valid": False,
                "error": f"Missing required columns: {', '.join(missing_cols)}",
            }

        if len(df) == 0:
            return {"valid": False, "error": "File is empty"}

        # Collect all data-quality problems instead of stopping at the first.
        validation_issues = []

        # Required fields must not contain NaN.
        if df["sample_id"].isna().any():
            validation_issues.append("Missing sample_id values found")
        if df["prediction"].isna().any():
            na_count = df["prediction"].isna().sum()
            validation_issues.append(
                f"Missing prediction values found ({na_count} empty predictions)"
            )

        # Each sample may be predicted only once.
        duplicates = df["sample_id"].duplicated()
        if duplicates.any():
            dup_count = duplicates.sum()
            validation_issues.append(
                f"Duplicate sample_id values found ({dup_count} duplicates)"
            )

        # Normalize sample_id to string so the regex check below works
        # (CSV/JSON may have parsed numeric-looking ids as numbers).
        if not df["sample_id"].dtype == "object" and not df["sample_id"].dtype.name.startswith("str"):
            df["sample_id"] = df["sample_id"].astype(str)

        # Ids must match salt_NNNNNN exactly.  fullmatch (not match) so ids
        # with trailing garbage ("salt_000001x", 7+ digits) are rejected too.
        invalid_ids = ~df["sample_id"].str.fullmatch(r"salt_\d{6}", na=False)
        if invalid_ids.any():
            invalid_count = invalid_ids.sum()
            validation_issues.append(
                f"Invalid sample_id format found ({invalid_count} invalid IDs)"
            )

        if validation_issues:
            return {
                "valid": False,
                "error": "; ".join(validation_issues),
                "dataframe": df,
                "row_count": len(df),
                "columns": list(df.columns),
            }

        return {
            "valid": True,
            "dataframe": df,
            "row_count": len(df),
            "columns": list(df.columns),
        }
    except Exception as e:
        # Any parse/processing failure is reported, never raised to the UI.
        return {"valid": False, "error": f"Error parsing file: {str(e)}"}
def validate_predictions_content_enhanced(predictions: pd.DataFrame) -> Dict:
    """Enhanced prediction content validation with stricter quality checks.

    Inspects the ``prediction`` column for empty strings, implausible
    lengths, duplicate texts, placeholder-like junk, and script
    distribution, then folds everything into a 0..1 quality score.

    Args:
        predictions: Non-empty DataFrame with a string ``prediction`` column.

    Returns:
        Dict with ``has_issues``, ``issues``, ``warnings``,
        ``quality_score`` and ``quality_metrics``.
    """
    issues = []
    warnings = []
    quality_metrics = {}

    # Whitespace-only predictions are hard failures.
    empty_predictions = predictions["prediction"].str.strip().eq("").sum()
    if empty_predictions > 0:
        issues.append(f"{empty_predictions} empty predictions found")

    # Length distribution.  std() is NaN for a single row; report 0.0 so
    # the metric stays finite/JSON-serializable.
    pred_lengths = predictions["prediction"].str.len()
    quality_metrics["avg_length"] = float(pred_lengths.mean())
    std_length = pred_lengths.std()
    quality_metrics["std_length"] = 0.0 if pd.isna(std_length) else float(std_length)

    # Suspiciously short predictions (issue above 5% of rows).
    short_predictions = (pred_lengths < 3).sum()
    if short_predictions > len(predictions) * 0.05:
        issues.append(f"{short_predictions} very short predictions (< 3 characters)")

    # Suspiciously long predictions (warning above 1% of rows).
    long_predictions = (pred_lengths > 500).sum()
    if long_predictions > len(predictions) * 0.01:
        warnings.append(f"{long_predictions} very long predictions (> 500 characters)")

    # Repeated prediction texts.
    duplicate_predictions = predictions["prediction"].duplicated().sum()
    duplicate_rate = duplicate_predictions / len(predictions)
    quality_metrics["duplicate_rate"] = float(duplicate_rate)
    if duplicate_rate > VALIDATION_CONFIG["quality_thresholds"]["max_duplicate_rate"]:
        issues.append(
            f"{duplicate_predictions} duplicate prediction texts ({duplicate_rate:.1%})"
        )

    # Placeholder/junk detection.  OR the per-pattern masks so a row that
    # matches several patterns (e.g. "aaa" matches both the literal list
    # and the short-gibberish pattern) is counted once, not once per pattern.
    placeholder_patterns = [
        r"^(test|placeholder|todo|xxx|aaa|bbb)$",
        r"^[a-z]{1,3}$",  # Very short gibberish
        r"^\d+$",  # Just numbers
        r"^[^\w\s]*$",  # Only punctuation
    ]
    placeholder_mask = pd.Series(False, index=predictions.index)
    for pattern in placeholder_patterns:
        placeholder_mask |= predictions["prediction"].str.match(
            pattern, flags=re.IGNORECASE, na=False
        )
    placeholder_count = int(placeholder_mask.sum())
    if placeholder_count > len(predictions) * 0.02:  # More than 2%
        issues.append(f"{placeholder_count} placeholder-like predictions detected")

    # Rough script check: share of rows containing any non-ASCII character.
    non_ascii_rate = predictions["prediction"].str.contains(r"[^\x00-\x7f]", na=False).mean()
    quality_metrics["non_ascii_rate"] = float(non_ascii_rate)
    if non_ascii_rate < 0.1:  # Mostly ASCII might indicate English-only output
        warnings.append(
            "Low non-ASCII character rate - check if translations include local language scripts"
        )

    # Fold everything into a bounded 0..1 quality score.
    quality_score = 1.0
    quality_score -= len(issues) * 0.3  # Major penalty for issues
    quality_score -= len(warnings) * 0.1  # Minor penalty for warnings
    quality_score -= max(0, duplicate_rate - 0.05) * 2  # Excessive duplicates
    if quality_metrics["avg_length"] < VALIDATION_CONFIG["quality_thresholds"]["min_avg_length"]:
        quality_score -= 0.2
    elif quality_metrics["avg_length"] > VALIDATION_CONFIG["quality_thresholds"]["max_avg_length"]:
        quality_score -= 0.1
    quality_score = max(0.0, min(1.0, quality_score))

    return {
        "has_issues": len(issues) > 0,
        "issues": issues,
        "warnings": warnings,
        "quality_score": quality_score,
        "quality_metrics": quality_metrics,
    }
def validate_against_test_set_enhanced(
    predictions: pd.DataFrame, test_set: pd.DataFrame
) -> Dict:
    """Enhanced validation against test set with track-specific analysis.

    Compares predicted sample ids against the reference test set overall,
    per evaluation track, and per language pair.

    Args:
        predictions: Must contain a ``sample_id`` column.
        test_set: Must contain ``sample_id``, ``source_language`` and
            ``target_language`` columns; assumed non-empty.

    Returns:
        Dict with coverage counts/rates, per-track and per-pair
        breakdowns, and small samples of missing/extra ids for display.
    """
    # Compare ids as strings so int vs str dtypes cannot cause mismatches.
    pred_ids = set(predictions["sample_id"].astype(str))
    test_ids = set(test_set["sample_id"].astype(str))

    # Overall coverage.
    missing_ids = test_ids - pred_ids
    extra_ids = pred_ids - test_ids
    matching_ids = pred_ids & test_ids
    overall_coverage = len(matching_ids) / len(test_ids)

    # Track-specific coverage analysis.
    track_coverage = {}
    for track_name, track_config in EVALUATION_TRACKS.items():
        track_languages = track_config["languages"]
        # Keep only pairs whose both sides are inside this track's languages.
        track_test_set = test_set[
            (test_set["source_language"].isin(track_languages))
            & (test_set["target_language"].isin(track_languages))
        ]
        if len(track_test_set) == 0:
            continue
        track_test_ids = set(track_test_set["sample_id"].astype(str))
        track_matching_ids = pred_ids & track_test_ids
        min_required = VALIDATION_CONFIG["min_samples_per_track"][track_name]
        track_coverage[track_name] = {
            "total_samples": len(track_test_set),
            "covered_samples": len(track_matching_ids),
            "coverage_rate": len(track_matching_ids) / len(track_test_set),
            "meets_minimum": len(track_matching_ids) >= min_required,
            "min_required": min_required,
        }

    # Language pair coverage analysis.  Vectorized flags + plain-value
    # iteration instead of the much slower DataFrame.iterrows().
    covered_flags = test_set["sample_id"].astype(str).isin(pred_ids)
    pair_labels = (
        test_set["source_language"].astype(str)
        + "_"
        + test_set["target_language"].astype(str)
    )
    pair_coverage = {}
    for pair_key, covered in zip(pair_labels, covered_flags):
        pair_info = pair_coverage.setdefault(pair_key, {"total": 0, "covered": 0})
        pair_info["total"] += 1
        if covered:
            pair_info["covered"] += 1
    for pair_info in pair_coverage.values():
        pair_info["coverage_rate"] = pair_info["covered"] / pair_info["total"]

    # Missing rate vs the configured tolerance.
    missing_rate = len(missing_ids) / len(test_ids)
    meets_missing_threshold = missing_rate <= VALIDATION_CONFIG["max_missing_rate"]

    return {
        "overall_coverage": overall_coverage,
        "missing_count": len(missing_ids),
        "extra_count": len(extra_ids),
        "matching_count": len(matching_ids),
        "missing_rate": missing_rate,
        "meets_missing_threshold": meets_missing_threshold,
        "is_complete": overall_coverage == 1.0,
        "track_coverage": track_coverage,
        "pair_coverage": pair_coverage,
        "missing_ids_sample": list(missing_ids)[:10],
        "extra_ids_sample": list(extra_ids)[:10],
    }
def assess_statistical_adequacy(
    validation_result: Dict, model_category: str
) -> Dict:
    """Assess statistical adequacy for scientific evaluation.

    Args:
        validation_result: Output of ``validate_against_test_set_enhanced``;
            only its ``track_coverage`` entry is consulted (each entry must
            carry ``covered_samples``, ``min_required``, ``coverage_rate``).
        model_category: Detected category; "commercial" models get an extra
            recommendation when the Google-comparable track is inadequate.

    Returns:
        Dict with overall/per-track adequacy flags, a per-track statistical
        power estimate, and human-readable recommendations.
    """
    adequacy_assessment = {
        "overall_adequate": True,
        "track_adequacy": {},
        "recommendations": [],
        "statistical_power_estimate": {},
    }

    track_coverage = validation_result.get("track_coverage", {})
    for track_name, coverage_info in track_coverage.items():
        # (Removed an unused EVALUATION_TRACKS[track_name] lookup that
        # created a spurious config dependency and could KeyError on
        # unknown track names.)

        # Sample size adequacy: enough covered samples for this track.
        covered_samples = coverage_info["covered_samples"]
        min_required = coverage_info["min_required"]
        sample_adequate = covered_samples >= min_required

        # Coverage rate adequacy: at least 80% of the track's test set.
        coverage_rate = coverage_info["coverage_rate"]
        coverage_adequate = coverage_rate >= 0.8

        # Simplified power estimate: full power at 1.5x the minimum samples.
        estimated_power = min(1.0, covered_samples / (min_required * 1.5))

        track_adequate = sample_adequate and coverage_adequate
        adequacy_assessment["track_adequacy"][track_name] = {
            "sample_adequate": sample_adequate,
            "coverage_adequate": coverage_adequate,
            "overall_adequate": track_adequate,
            "covered_samples": covered_samples,
            "min_required": min_required,
            "coverage_rate": coverage_rate,
            "estimated_power": estimated_power,
        }
        if not track_adequate:
            adequacy_assessment["overall_adequate"] = False
        adequacy_assessment["statistical_power_estimate"][track_name] = estimated_power

    # Name the inadequate tracks in a single recommendation.
    if not adequacy_assessment["overall_adequate"]:
        inadequate_tracks = [
            track for track, info in adequacy_assessment["track_adequacy"].items()
            if not info["overall_adequate"]
        ]
        adequacy_assessment["recommendations"].append(
            f"Insufficient samples for tracks: {', '.join(inadequate_tracks)}"
        )

    # Category-specific recommendations.
    if model_category == "commercial" and not adequacy_assessment["track_adequacy"].get("google_comparable", {}).get("overall_adequate", False):
        adequacy_assessment["recommendations"].append(
            "Commercial models should ensure adequate coverage of Google-comparable track"
        )

    return adequacy_assessment
def generate_scientific_validation_report(
    format_result: Dict,
    content_result: Dict,
    test_set_result: Dict,
    adequacy_result: Dict,
    model_name: str = "",
    detected_category: str = "community",
) -> str:
    """Generate comprehensive scientific validation report.

    Renders the four validator results into one markdown document.  When
    ``format_result`` is invalid the report short-circuits after the
    format section, so the remaining result dicts may be empty.

    NOTE(review): the status markers in the literals below ("β", "β ",
    "π¬", ...) appear to be mojibake — presumably emoji that lost their
    UTF-8 encoding upstream; confirm and fix the encoding at the source.
    They are preserved here verbatim because downstream code matches on
    these exact characters.

    Returns:
        The full report as a single newline-joined markdown string.
    """
    report = []
    # Header
    report.append(f"# π¬ Scientific Validation Report: {model_name or 'Submission'}")
    report.append("")
    # Model categorization — unknown categories fall back to "community".
    category_info = MODEL_CATEGORIES.get(detected_category, MODEL_CATEGORIES["community"])
    report.append(f"**Detected Model Category**: {category_info['name']}")
    report.append(f"**Category Description**: {category_info['description']}")
    report.append("")
    # File format validation
    if format_result["valid"]:
        report.append("β **File Format**: Valid")
        report.append(f" - Rows: {format_result['row_count']:,}")
        report.append(f" - Columns: {', '.join(format_result['columns'])}")
    else:
        report.append("β **File Format**: Invalid")
        report.append(f" - Error: {format_result['error']}")
        # An unparseable file leaves nothing further to report.
        return "\n".join(report)
    # Content quality validation
    quality_score = content_result.get("quality_score", 0.0)
    if content_result["has_issues"]:
        report.append("β **Content Quality**: Issues Found")
        for issue in content_result["issues"]:
            report.append(f" - β {issue}")
    else:
        report.append("β **Content Quality**: Good")
    if content_result["warnings"]:
        for warning in content_result["warnings"]:
            report.append(f" - β οΈ {warning}")
    report.append(f" - **Quality Score**: {quality_score:.2f}/1.00")
    report.append("")
    # Test set coverage validation
    overall_coverage = test_set_result["overall_coverage"]
    meets_threshold = test_set_result["meets_missing_threshold"]
    if overall_coverage == 1.0:
        report.append("β **Test Set Coverage**: Complete")
    elif overall_coverage >= 0.95 and meets_threshold:
        report.append("β **Test Set Coverage**: Adequate")
    else:
        report.append("β **Test Set Coverage**: Insufficient")
    report.append(f" - Coverage: {overall_coverage:.1%} ({test_set_result['matching_count']:,} / {test_set_result['matching_count'] + test_set_result['missing_count']:,})")
    report.append(f" - Missing Rate: {test_set_result['missing_rate']:.1%}")
    report.append("")
    # Track-specific coverage analysis
    report.append("## π Track-Specific Analysis")
    track_coverage = test_set_result.get("track_coverage", {})
    for track_name, coverage_info in track_coverage.items():
        track_config = EVALUATION_TRACKS[track_name]
        status = "β " if coverage_info["meets_minimum"] else "β"
        report.append(f"### {status} {track_config['name']}")
        report.append(f" - **Samples**: {coverage_info['covered_samples']:,} / {coverage_info['total_samples']:,}")
        report.append(f" - **Coverage**: {coverage_info['coverage_rate']:.1%}")
        report.append(f" - **Minimum Required**: {coverage_info['min_required']:,}")
        report.append(f" - **Status**: {'Adequate' if coverage_info['meets_minimum'] else 'Insufficient'}")
        report.append("")
    # Statistical adequacy assessment
    report.append("## π¬ Statistical Adequacy Assessment")
    if adequacy_result["overall_adequate"]:
        report.append("β **Overall Assessment**: Statistically adequate for scientific evaluation")
    else:
        report.append("β **Overall Assessment**: Insufficient for rigorous scientific evaluation")
    # Per-track power estimates come from assess_statistical_adequacy.
    for track_name, track_adequacy in adequacy_result["track_adequacy"].items():
        track_config = EVALUATION_TRACKS[track_name]
        power = track_adequacy["estimated_power"]
        status = "β " if track_adequacy["overall_adequate"] else "β"
        report.append(f" - {status} **{track_config['name']}**: Statistical power β {power:.1%}")
    # Recommendations
    if adequacy_result["recommendations"]:
        report.append("")
        report.append("## π‘ Recommendations")
        for rec in adequacy_result["recommendations"]:
            report.append(f" - {rec}")
    # Final verdict
    report.append("")
    # Strict gate: every validator must pass cleanly.
    all_checks_pass = (
        format_result["valid"] and
        not content_result["has_issues"] and
        overall_coverage >= 0.95 and
        meets_threshold and
        adequacy_result["overall_adequate"]
    )
    # Permissive gate.  NOTE(review): the "β" substring never occurs in
    # issues produced by validate_predictions_content_enhanced (its issue
    # strings are plain text), so this `any(...)` is currently always
    # False — confirm which marker was intended for "critical" issues.
    can_evaluate_with_limits = (
        format_result["valid"] and
        overall_coverage >= 0.8 and
        not any("β" in issue for issue in content_result.get("issues", []))
    )
    if all_checks_pass:
        report.append("π **Final Verdict**: Ready for scientific evaluation!")
    elif can_evaluate_with_limits:
        report.append("β οΈ **Final Verdict**: Can be evaluated with limitations")
        report.append(" - Results will include notes about statistical limitations")
        report.append(" - Consider improving coverage/quality for publication-grade results")
    else:
        report.append("β **Final Verdict**: Please address critical issues before submission")
    return "\n".join(report)
def validate_submission_scientific(
    file_content: bytes,
    filename: str,
    test_set: pd.DataFrame,
    model_name: str = "",
    author: str = "",
    description: str = ""
) -> Dict:
    """Run the full scientific validation pipeline on one submission.

    Orchestrates category detection, format/content/test-set validation
    and statistical adequacy assessment, and returns a summary dict with
    the rendered report plus evaluation-eligibility flags.
    """
    # Step 1: category is inferred from the free-text metadata.
    category = detect_model_category(model_name, author, description)

    # Step 2: parse and structurally validate the uploaded file.
    fmt = validate_file_format_enhanced(file_content, filename)
    if not fmt["valid"]:
        # Unparseable file: short-circuit with a format-only report.
        return {
            "valid": False,
            "can_evaluate": False,  # New field for evaluation eligibility
            "category": category,
            "report": generate_scientific_validation_report(
                fmt, {}, {}, {}, model_name, category
            ),
            "predictions": None,
            "adequacy": {},
        }

    predictions = fmt["dataframe"]

    # Steps 3-5: content quality, test-set coverage, statistical adequacy.
    content = validate_predictions_content_enhanced(predictions)
    coverage = validate_against_test_set_enhanced(predictions, test_set)
    adequacy = assess_statistical_adequacy(coverage, category)

    # Step 6: render the full markdown report.
    report = generate_scientific_validation_report(
        fmt, content, coverage, adequacy, model_name, category
    )

    # Strict gate: everything must pass for publication-grade validity.
    is_scientifically_valid = bool(
        fmt["valid"]
        and not content["has_issues"]
        and coverage["overall_coverage"] >= 0.95
        and coverage["meets_missing_threshold"]
        and adequacy["overall_adequate"]
    )

    # Permissive gate: evaluable with caveats at >= 80% coverage and no
    # critical ("β"-marked) content issues.
    has_critical_issue = any("β" in issue for issue in content.get("issues", []))
    can_evaluate = bool(
        fmt["valid"]
        and coverage["overall_coverage"] >= 0.8
        and not has_critical_issue
    )

    return {
        "valid": is_scientifically_valid,
        "can_evaluate": can_evaluate,  # New field
        "category": category,
        "coverage": coverage["overall_coverage"],
        "report": report,
        "predictions": predictions,
        "adequacy": adequacy,
        "quality_score": content.get("quality_score", 0.8),
        "track_coverage": coverage.get("track_coverage", {}),
        "scientific_metadata": {
            "validation_timestamp": pd.Timestamp.now().isoformat(),
            "validation_version": "2.0-scientific",
            "detected_category": category,
            "statistical_adequacy": adequacy["overall_adequate"],
            "evaluation_recommended": can_evaluate,
        },
    }