# src/validation.py
import pandas as pd
from typing import Dict
import json
import io
import re
from config import (
PREDICTION_FORMAT,
VALIDATION_CONFIG,
MODEL_CATEGORIES,
EVALUATION_TRACKS,
ALL_UG40_LANGUAGES,
)
def detect_model_category(model_name: str, author: str, description: str) -> str:
"""Automatically detect model category based on name and metadata."""
# Combine all text for analysis
text_to_analyze = f"{model_name} {author} {description}".lower()
# Category detection patterns
detection_patterns = PREDICTION_FORMAT["category_detection"]
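    # Expected (assumed) config shape: a mapping from group name to lowercase
    # substrings, e.g. {"google": ["google"], "nllb": ["nllb"], ...}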
# Check for specific patterns
if any(pattern in text_to_analyze for pattern in detection_patterns.get("google", [])):
return "commercial"
if any(pattern in text_to_analyze for pattern in detection_patterns.get("nllb", [])):
return "research"
if any(pattern in text_to_analyze for pattern in detection_patterns.get("m2m", [])):
return "research"
if any(pattern in text_to_analyze for pattern in detection_patterns.get("baseline", [])):
return "baseline"
# Check for research indicators
research_indicators = [
"university", "research", "paper", "arxiv", "acl", "emnlp", "naacl",
"transformer", "bert", "gpt", "t5", "mbart", "academic"
]
if any(indicator in text_to_analyze for indicator in research_indicators):
return "research"
# Check for commercial indicators
commercial_indicators = [
"google", "microsoft", "azure", "aws", "openai", "anthropic",
"commercial", "api", "cloud", "translate"
]
if any(indicator in text_to_analyze for indicator in commercial_indicators):
return "commercial"
# Default to community
return "community"
def validate_file_format(file_content: bytes, filename: str) -> Dict:
"""Validate file format and structure."""
try:
        # Determine file type from the extension (case-insensitive)
        name = filename.lower()
        if name.endswith(".csv"):
            df = pd.read_csv(io.BytesIO(file_content))
        elif name.endswith(".tsv"):
            df = pd.read_csv(io.BytesIO(file_content), sep="\t")
        elif name.endswith(".json"):
data = json.loads(file_content.decode("utf-8"))
df = pd.DataFrame(data)
else:
return {
"valid": False,
"error": f"Unsupported file type. Use: {', '.join(PREDICTION_FORMAT['file_types'])}",
}
# Check required columns
missing_cols = set(PREDICTION_FORMAT["required_columns"]) - set(df.columns)
if missing_cols:
return {
"valid": False,
"error": f"Missing required columns: {', '.join(missing_cols)}",
}
# Basic data validation
if len(df) == 0:
return {"valid": False, "error": "File is empty"}
# Validation checks
validation_issues = []
# Check for required data
if df["sample_id"].isna().any():
validation_issues.append("Missing sample_id values found")
if df["prediction"].isna().any():
na_count = df["prediction"].isna().sum()
validation_issues.append(f"Missing prediction values found ({na_count} empty predictions)")
# Check for duplicates
duplicates = df["sample_id"].duplicated()
if duplicates.any():
dup_count = duplicates.sum()
validation_issues.append(f"Duplicate sample_id values found ({dup_count} duplicates)")
        # Data type validation: normalise sample_id to string for comparisons
        if df["sample_id"].dtype != "object":
            df["sample_id"] = df["sample_id"].astype(str)
        # Check sample_id format (fullmatch, so trailing characters are rejected)
        invalid_ids = ~df["sample_id"].str.fullmatch(r"salt_\d{6}", na=False)
if invalid_ids.any():
invalid_count = invalid_ids.sum()
validation_issues.append(f"Invalid sample_id format found ({invalid_count} invalid IDs)")
# Return results
if validation_issues:
return {
"valid": False,
"error": "; ".join(validation_issues),
"dataframe": df,
"row_count": len(df),
"columns": list(df.columns),
}
return {
"valid": True,
"dataframe": df,
"row_count": len(df),
"columns": list(df.columns),
}
except Exception as e:
return {"valid": False, "error": f"Error parsing file: {str(e)}"}
def validate_predictions_content(predictions: pd.DataFrame) -> Dict:
"""Validate prediction content quality."""
issues = []
warnings = []
quality_metrics = {}
# Basic content checks
empty_predictions = predictions["prediction"].str.strip().eq("").sum()
if empty_predictions > 0:
issues.append(f"{empty_predictions} empty predictions found")
# Length analysis
pred_lengths = predictions["prediction"].str.len()
quality_metrics["avg_length"] = float(pred_lengths.mean())
quality_metrics["std_length"] = float(pred_lengths.std())
# Check for suspiciously short predictions
short_predictions = (pred_lengths < 3).sum()
if short_predictions > len(predictions) * 0.05: # More than 5%
issues.append(f"{short_predictions} very short predictions (< 3 characters)")
# Check for suspiciously long predictions
long_predictions = (pred_lengths > 500).sum()
if long_predictions > len(predictions) * 0.01: # More than 1%
warnings.append(f"{long_predictions} very long predictions (> 500 characters)")
# Check for repeated predictions
duplicate_predictions = predictions["prediction"].duplicated().sum()
duplicate_rate = duplicate_predictions / len(predictions)
quality_metrics["duplicate_rate"] = float(duplicate_rate)
if duplicate_rate > VALIDATION_CONFIG["quality_thresholds"]["max_duplicate_rate"]:
issues.append(f"{duplicate_predictions} duplicate prediction texts ({duplicate_rate:.1%})")
# Check for placeholder text
placeholder_patterns = [
r"^(test|placeholder|todo|xxx|aaa|bbb)$",
r"^[a-z]{1,3}$", # Very short gibberish
r"^\d+$", # Just numbers
r"^[^\w\s]*$", # Only punctuation
]
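    # e.g. "test", "xy", "1234" and "!!!" are all placeholder-like under these
    # patterns, while a normal sentence matches none of them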
    # Combine into one alternation so a prediction matching several patterns
    # is only counted once
    placeholder_regex = "|".join(f"(?:{p})" for p in placeholder_patterns)
    placeholder_count = int(
        predictions["prediction"]
        .str.match(placeholder_regex, flags=re.IGNORECASE, na=False)
        .sum()
    )
if placeholder_count > len(predictions) * 0.02: # More than 2%
issues.append(f"{placeholder_count} placeholder-like predictions detected")
# Calculate overall quality score
quality_score = 1.0
quality_score -= len(issues) * 0.3 # Major penalty for issues
quality_score -= len(warnings) * 0.1 # Minor penalty for warnings
quality_score -= max(0, duplicate_rate - 0.05) * 2 # Penalty for excessive duplicates
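    # e.g. one issue plus one warning with a 10% duplicate rate gives
    # 1.0 - 0.3 - 0.1 - 0.1 = 0.5 before the length adjustments below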
# Length appropriateness
if quality_metrics["avg_length"] < VALIDATION_CONFIG["quality_thresholds"]["min_avg_length"]:
quality_score -= 0.2
elif quality_metrics["avg_length"] > VALIDATION_CONFIG["quality_thresholds"]["max_avg_length"]:
quality_score -= 0.1
quality_score = max(0.0, min(1.0, quality_score))
return {
"has_issues": len(issues) > 0,
"issues": issues,
"warnings": warnings,
"quality_score": quality_score,
"quality_metrics": quality_metrics,
}
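# Worked example (illustrative): predictions that are all the literal string
# "test" trip the placeholder check (and, for most thresholds, the duplicate
# check), so the frame is flagged:
#
#   preds = pd.DataFrame({"sample_id": ["salt_000001", "salt_000002"],
#                         "prediction": ["test", "test"]})
#   validate_predictions_content(preds)["has_issues"]  -> True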
def validate_against_test_set(
predictions: pd.DataFrame, test_set: pd.DataFrame
) -> Dict:
"""Validate predictions against test set."""
# Convert IDs to string for comparison
pred_ids = set(predictions["sample_id"].astype(str))
test_ids = set(test_set["sample_id"].astype(str))
# Check overall coverage
missing_ids = test_ids - pred_ids
extra_ids = pred_ids - test_ids
matching_ids = pred_ids & test_ids
overall_coverage = len(matching_ids) / len(test_ids)
# Track-specific coverage analysis
track_coverage = {}
for track_name, track_config in EVALUATION_TRACKS.items():
track_languages = track_config["languages"]
# Filter test set to track languages
track_test_set = test_set[
(test_set["source_language"].isin(track_languages)) &
(test_set["target_language"].isin(track_languages))
]
if len(track_test_set) == 0:
continue
track_test_ids = set(track_test_set["sample_id"].astype(str))
track_matching_ids = pred_ids & track_test_ids
        # Tracks without a configured minimum default to 0 (always adequate)
        min_required = VALIDATION_CONFIG["min_samples_per_track"].get(track_name, 0)
        track_coverage[track_name] = {
            "total_samples": len(track_test_set),
            "covered_samples": len(track_matching_ids),
            "coverage_rate": len(track_matching_ids) / len(track_test_set),
            "meets_minimum": len(track_matching_ids) >= min_required,
            "min_required": min_required,
        }
# Missing rate validation
missing_rate = len(missing_ids) / len(test_ids)
meets_missing_threshold = missing_rate <= VALIDATION_CONFIG["max_missing_rate"]
return {
"overall_coverage": overall_coverage,
"missing_count": len(missing_ids),
"extra_count": len(extra_ids),
"matching_count": len(matching_ids),
"missing_rate": missing_rate,
"meets_missing_threshold": meets_missing_threshold,
"is_complete": overall_coverage == 1.0,
"track_coverage": track_coverage,
"missing_ids_sample": list(missing_ids)[:10],
"extra_ids_sample": list(extra_ids)[:10],
}
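# Coverage sketch: with a two-row test set and one matching prediction id,
# overall_coverage is 0.5 and the unmatched id appears in missing_ids_sample
# (track_coverage additionally depends on EVALUATION_TRACKS in config):
#
#   test = pd.DataFrame({"sample_id": ["salt_000001", "salt_000002"],
#                        "source_language": ["eng", "lug"],
#                        "target_language": ["lug", "eng"]})
#   preds = pd.DataFrame({"sample_id": ["salt_000001"], "prediction": ["..."]})
#   validate_against_test_set(preds, test)["overall_coverage"]  -> 0.5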
def generate_validation_report(
format_result: Dict,
content_result: Dict,
test_set_result: Dict,
model_name: str = "",
detected_category: str = "community",
) -> str:
"""Generate comprehensive validation report."""
report = []
# Header
report.append(f"### πŸ”¬ Validation Report: {model_name or 'Submission'}")
report.append("")
# Model categorization
category_info = MODEL_CATEGORIES.get(detected_category, MODEL_CATEGORIES["community"])
report.append(f"**Detected Model Category**: {category_info['name']}")
report.append("")
# File format validation
    if format_result["valid"]:
        report.append("✅ **File Format**: Valid")
        report.append(f" - Rows: {format_result['row_count']:,}")
        report.append(f" - Columns: {', '.join(format_result['columns'])}")
    else:
        report.append("❌ **File Format**: Invalid")
        report.append(f" - Error: {format_result['error']}")
return "\n".join(report)
# Content quality validation
quality_score = content_result.get("quality_score", 0.0)
    if content_result["has_issues"]:
        report.append("❌ **Content Quality**: Issues Found")
        for issue in content_result["issues"]:
            report.append(f" - ❌ {issue}")
    else:
        report.append("✅ **Content Quality**: Good")
    if content_result["warnings"]:
        for warning in content_result["warnings"]:
            report.append(f" - ⚠️ {warning}")
report.append(f" - **Quality Score**: {quality_score:.2f}/1.00")
report.append("")
# Test set coverage validation
overall_coverage = test_set_result["overall_coverage"]
meets_threshold = test_set_result["meets_missing_threshold"]
    if overall_coverage == 1.0:
        report.append("✅ **Test Set Coverage**: Complete")
    elif overall_coverage >= 0.95 and meets_threshold:
        report.append("✅ **Test Set Coverage**: Adequate")
    else:
        report.append("❌ **Test Set Coverage**: Insufficient")
report.append(f" - Coverage: {overall_coverage:.1%} ({test_set_result['matching_count']:,} / {test_set_result['matching_count'] + test_set_result['missing_count']:,})")
report.append(f" - Missing Rate: {test_set_result['missing_rate']:.1%}")
report.append("")
# Track-specific coverage analysis
report.append("#### πŸ“Š Track-Specific Analysis")
track_coverage = test_set_result.get("track_coverage", {})
for track_name, coverage_info in track_coverage.items():
track_config = EVALUATION_TRACKS[track_name]
        status = "✅" if coverage_info["meets_minimum"] else "❌"
report.append(f"**{status} {track_config['name']}**:")
report.append(f" - **Samples**: {coverage_info['covered_samples']:,} / {coverage_info['total_samples']:,}")
report.append(f" - **Coverage**: {coverage_info['coverage_rate']:.1%}")
report.append(f" - **Minimum Required**: {coverage_info['min_required']:,}")
report.append(f" - **Status**: {'Adequate' if coverage_info['meets_minimum'] else 'Insufficient'}")
report.append("")
# Final verdict
all_checks_pass = (
format_result["valid"] and
not content_result["has_issues"] and
overall_coverage >= 0.95 and
meets_threshold
)
    # More permissive path: content issues are reported above but do not
    # block evaluation on their own
    can_evaluate_with_limits = (
        format_result["valid"] and
        overall_coverage >= 0.8
    )
    if all_checks_pass:
        report.append("🎉 **Final Verdict**: Ready for evaluation!")
    elif can_evaluate_with_limits:
        report.append("⚠️ **Final Verdict**: Can be evaluated with limitations")
        report.append(" - Results will include notes about limitations")
    else:
        report.append("❌ **Final Verdict**: Please address critical issues before submission")
return "\n".join(report)
def validate_submission(
file_content: bytes,
filename: str,
test_set: pd.DataFrame,
model_name: str = "",
author: str = "",
description: str = ""
) -> Dict:
"""Complete validation pipeline for submissions."""
# Step 1: Detect model category
detected_category = detect_model_category(model_name, author, description)
# Step 2: File format validation
format_result = validate_file_format(file_content, filename)
if not format_result["valid"]:
return {
"valid": False,
"can_evaluate": False,
"category": detected_category,
"report": generate_validation_report(
format_result, {}, {}, model_name, detected_category
),
"predictions": None,
}
predictions = format_result["dataframe"]
# Step 3: Content validation
content_result = validate_predictions_content(predictions)
# Step 4: Test set validation
test_set_result = validate_against_test_set(predictions, test_set)
# Step 5: Generate report
report = generate_validation_report(
format_result, content_result, test_set_result, model_name, detected_category
)
# Overall validity determination
is_valid = (
format_result["valid"] and
not content_result["has_issues"] and
test_set_result["overall_coverage"] >= 0.95 and
test_set_result["meets_missing_threshold"]
)
    # Evaluation eligibility (more permissive: content issues are noted in the
    # report but do not block evaluation)
    can_evaluate = (
        format_result["valid"] and
        test_set_result["overall_coverage"] >= 0.8
    )
return {
"valid": is_valid,
"can_evaluate": can_evaluate,
"category": detected_category,
"coverage": test_set_result["overall_coverage"],
"report": report,
"predictions": predictions,
"quality_score": content_result.get("quality_score", 0.8),
"track_coverage": test_set_result.get("track_coverage", {}),
}
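if __name__ == "__main__":
    # Minimal smoke test (illustrative only: assumes config.py provides the
    # constants imported above; the sample ids and languages are made up).
    demo_test_set = pd.DataFrame(
        {
            "sample_id": ["salt_000001", "salt_000002"],
            "source_language": ["eng", "lug"],
            "target_language": ["lug", "eng"],
        }
    )
    demo_file = b"sample_id,prediction\nsalt_000001,Oli otya\nsalt_000002,Hello\n"
    result = validate_submission(
        demo_file, "demo.csv", demo_test_set, model_name="demo-model"
    )
    print(result["report"])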