# src/validation.py
import io
import json
import re
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from config import (
    PREDICTION_FORMAT,
    VALIDATION_CONFIG,
    MODEL_CATEGORIES,
    EVALUATION_TRACKS,
    ALL_UG40_LANGUAGES,
)
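
# Assumed shape of the structures imported from config.py (config.py is not
# shown here; the keys below are inferred from how they are used in this
# module and may not match the real file exactly):
#
#     PREDICTION_FORMAT = {
#         "file_types": ["csv", "tsv", "json"],
#         "required_columns": ["sample_id", "prediction", ...],
#         "category_detection": {"google": [...], "nllb": [...], "m2m": [...],
#                                "baseline": [...]},
#     }
#     VALIDATION_CONFIG = {
#         "max_missing_rate": float,
#         "min_samples_per_track": {track_name: int, ...},
#         "quality_thresholds": {"max_duplicate_rate": float,
#                                "min_avg_length": float,
#                                "max_avg_length": float},
#     }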

def detect_model_category(model_name: str, author: str, description: str) -> str:
    """Automatically detect the model category from its name and metadata."""
    # Combine all text for analysis
    text_to_analyze = f"{model_name} {author} {description}".lower()

    # Category detection patterns from config
    detection_patterns = PREDICTION_FORMAT["category_detection"]

    # Check for known system-specific patterns first
    if any(pattern in text_to_analyze for pattern in detection_patterns.get("google", [])):
        return "commercial"
    if any(pattern in text_to_analyze for pattern in detection_patterns.get("nllb", [])):
        return "research"
    if any(pattern in text_to_analyze for pattern in detection_patterns.get("m2m", [])):
        return "research"
    if any(pattern in text_to_analyze for pattern in detection_patterns.get("baseline", [])):
        return "baseline"

    # Check for research indicators
    research_indicators = [
        "university", "research", "paper", "arxiv", "acl", "emnlp", "naacl",
        "transformer", "bert", "gpt", "t5", "mbart", "academic",
    ]
    if any(indicator in text_to_analyze for indicator in research_indicators):
        return "research"

    # Check for commercial indicators
    commercial_indicators = [
        "google", "microsoft", "azure", "aws", "openai", "anthropic",
        "commercial", "api", "cloud", "translate",
    ]
    if any(indicator in text_to_analyze for indicator in commercial_indicators):
        return "commercial"

    # Default to community
    return "community"

def validate_file_format(file_content: bytes, filename: str) -> Dict:
    """Validate file format and structure."""
    try:
        # Determine file type
        if filename.endswith(".csv"):
            df = pd.read_csv(io.BytesIO(file_content))
        elif filename.endswith(".tsv"):
            df = pd.read_csv(io.BytesIO(file_content), sep="\t")
        elif filename.endswith(".json"):
            data = json.loads(file_content.decode("utf-8"))
            df = pd.DataFrame(data)
        else:
            return {
                "valid": False,
                "error": f"Unsupported file type. Use: {', '.join(PREDICTION_FORMAT['file_types'])}",
            }

        # Check required columns
        missing_cols = set(PREDICTION_FORMAT["required_columns"]) - set(df.columns)
        if missing_cols:
            return {
                "valid": False,
                "error": f"Missing required columns: {', '.join(missing_cols)}",
            }

        # Basic data validation
        if len(df) == 0:
            return {"valid": False, "error": "File is empty"}

        validation_issues = []

        # Check for required data
        if df["sample_id"].isna().any():
            validation_issues.append("Missing sample_id values found")
        if df["prediction"].isna().any():
            na_count = df["prediction"].isna().sum()
            validation_issues.append(
                f"Missing prediction values found ({na_count} empty predictions)"
            )

        # Check for duplicates
        duplicates = df["sample_id"].duplicated()
        if duplicates.any():
            dup_count = duplicates.sum()
            validation_issues.append(
                f"Duplicate sample_id values found ({dup_count} duplicates)"
            )

        # Normalize sample_id to string before pattern checks
        if df["sample_id"].dtype != "object":
            df["sample_id"] = df["sample_id"].astype(str)

        # Check sample_id format. fullmatch is used so that IDs with trailing
        # characters (e.g. "salt_000001x") are rejected; plain match only
        # anchors at the start of the string.
        invalid_ids = ~df["sample_id"].str.fullmatch(r"salt_\d{6}", na=False)
        if invalid_ids.any():
            invalid_count = invalid_ids.sum()
            validation_issues.append(
                f"Invalid sample_id format found ({invalid_count} invalid IDs)"
            )

        # Return results
        if validation_issues:
            return {
                "valid": False,
                "error": "; ".join(validation_issues),
                "dataframe": df,
                "row_count": len(df),
                "columns": list(df.columns),
            }

        return {
            "valid": True,
            "dataframe": df,
            "row_count": len(df),
            "columns": list(df.columns),
        }

    except Exception as e:
        return {"valid": False, "error": f"Error parsing file: {str(e)}"}

def validate_predictions_content(predictions: pd.DataFrame) -> Dict:
    """Validate prediction content quality."""
    issues = []
    warnings = []
    quality_metrics = {}

    # Basic content checks
    empty_predictions = predictions["prediction"].str.strip().eq("").sum()
    if empty_predictions > 0:
        issues.append(f"{empty_predictions} empty predictions found")

    # Length analysis
    pred_lengths = predictions["prediction"].str.len()
    quality_metrics["avg_length"] = float(pred_lengths.mean())
    quality_metrics["std_length"] = float(pred_lengths.std())

    # Check for suspiciously short predictions
    short_predictions = (pred_lengths < 3).sum()
    if short_predictions > len(predictions) * 0.05:  # More than 5%
        issues.append(f"{short_predictions} very short predictions (< 3 characters)")

    # Check for suspiciously long predictions
    long_predictions = (pred_lengths > 500).sum()
    if long_predictions > len(predictions) * 0.01:  # More than 1%
        warnings.append(f"{long_predictions} very long predictions (> 500 characters)")

    # Check for repeated predictions
    duplicate_predictions = predictions["prediction"].duplicated().sum()
    duplicate_rate = duplicate_predictions / len(predictions)
    quality_metrics["duplicate_rate"] = float(duplicate_rate)
    if duplicate_rate > VALIDATION_CONFIG["quality_thresholds"]["max_duplicate_rate"]:
        issues.append(
            f"{duplicate_predictions} duplicate prediction texts ({duplicate_rate:.1%})"
        )

    # Check for placeholder text
    placeholder_patterns = [
        r"^(test|placeholder|todo|xxx|aaa|bbb)$",
        r"^[a-z]{1,3}$",  # Very short gibberish
        r"^\d+$",  # Just numbers
        r"^[^\w\s]*$",  # Only punctuation
    ]
    # OR the per-pattern masks so that a prediction matching several patterns
    # (e.g. "xxx" matches both the keyword list and the short-gibberish
    # pattern) is only counted once.
    placeholder_mask = pd.Series(False, index=predictions.index)
    for pattern in placeholder_patterns:
        placeholder_mask |= predictions["prediction"].str.match(
            pattern, flags=re.IGNORECASE, na=False
        )
    placeholder_count = int(placeholder_mask.sum())
    if placeholder_count > len(predictions) * 0.02:  # More than 2%
        issues.append(f"{placeholder_count} placeholder-like predictions detected")

    # Calculate overall quality score
    quality_score = 1.0
    quality_score -= len(issues) * 0.3  # Major penalty for issues
    quality_score -= len(warnings) * 0.1  # Minor penalty for warnings
    quality_score -= max(0, duplicate_rate - 0.05) * 2  # Penalty for excessive duplicates

    # Length appropriateness
    if quality_metrics["avg_length"] < VALIDATION_CONFIG["quality_thresholds"]["min_avg_length"]:
        quality_score -= 0.2
    elif quality_metrics["avg_length"] > VALIDATION_CONFIG["quality_thresholds"]["max_avg_length"]:
        quality_score -= 0.1

    quality_score = max(0.0, min(1.0, quality_score))

    return {
        "has_issues": len(issues) > 0,
        "issues": issues,
        "warnings": warnings,
        "quality_score": quality_score,
        "quality_metrics": quality_metrics,
    }
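
# Usage sketch (hypothetical frame):
#
#     preds = pd.DataFrame({
#         "sample_id": ["salt_000001", "salt_000002"],
#         "prediction": ["Webale nnyo", "xxx"],
#     })
#     content = validate_predictions_content(preds)
#     # content["quality_score"] is clamped to [0, 1]; "xxx" counts as
#     # placeholder-like, and an issue is raised because such predictions
#     # exceed 2% of the file.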

def validate_against_test_set(
    predictions: pd.DataFrame, test_set: pd.DataFrame
) -> Dict:
    """Validate predictions against the test set."""
    # Convert IDs to string for comparison
    pred_ids = set(predictions["sample_id"].astype(str))
    test_ids = set(test_set["sample_id"].astype(str))

    # Check overall coverage
    missing_ids = test_ids - pred_ids
    extra_ids = pred_ids - test_ids
    matching_ids = pred_ids & test_ids
    overall_coverage = len(matching_ids) / len(test_ids)

    # Track-specific coverage analysis
    track_coverage = {}
    for track_name, track_config in EVALUATION_TRACKS.items():
        track_languages = track_config["languages"]

        # Filter test set to track languages
        track_test_set = test_set[
            (test_set["source_language"].isin(track_languages))
            & (test_set["target_language"].isin(track_languages))
        ]
        if len(track_test_set) == 0:
            continue

        track_test_ids = set(track_test_set["sample_id"].astype(str))
        track_matching_ids = pred_ids & track_test_ids
        min_required = VALIDATION_CONFIG["min_samples_per_track"][track_name]

        track_coverage[track_name] = {
            "total_samples": len(track_test_set),
            "covered_samples": len(track_matching_ids),
            "coverage_rate": len(track_matching_ids) / len(track_test_set),
            "meets_minimum": len(track_matching_ids) >= min_required,
            "min_required": min_required,
        }

    # Missing rate validation
    missing_rate = len(missing_ids) / len(test_ids)
    meets_missing_threshold = missing_rate <= VALIDATION_CONFIG["max_missing_rate"]

    return {
        "overall_coverage": overall_coverage,
        "missing_count": len(missing_ids),
        "extra_count": len(extra_ids),
        "matching_count": len(matching_ids),
        "missing_rate": missing_rate,
        "meets_missing_threshold": meets_missing_threshold,
        "is_complete": overall_coverage == 1.0,
        "track_coverage": track_coverage,
        "missing_ids_sample": list(missing_ids)[:10],
        "extra_ids_sample": list(extra_ids)[:10],
    }
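
# Usage sketch: with a test set of IDs salt_000001..salt_000100 and a
# predictions frame covering 90 of them, overall_coverage is 0.9 and
# missing_count is 10; meets_missing_threshold then depends on
# VALIDATION_CONFIG["max_missing_rate"].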

def generate_validation_report(
    format_result: Dict,
    content_result: Dict,
    test_set_result: Dict,
    model_name: str = "",
    detected_category: str = "community",
) -> str:
    """Generate a comprehensive validation report."""
    report = []

    # Header
    report.append(f"### 🔬 Validation Report: {model_name or 'Submission'}")
    report.append("")

    # Model categorization
    category_info = MODEL_CATEGORIES.get(detected_category, MODEL_CATEGORIES["community"])
    report.append(f"**Detected Model Category**: {category_info['name']}")
    report.append("")

    # File format validation
    if format_result["valid"]:
        report.append("✅ **File Format**: Valid")
        report.append(f" - Rows: {format_result['row_count']:,}")
        report.append(f" - Columns: {', '.join(format_result['columns'])}")
    else:
        report.append("❌ **File Format**: Invalid")
        report.append(f" - Error: {format_result['error']}")
        # Nothing further can be checked if the file could not be parsed
        return "\n".join(report)

    # Content quality validation
    quality_score = content_result.get("quality_score", 0.0)
    if content_result["has_issues"]:
        report.append("❌ **Content Quality**: Issues Found")
        for issue in content_result["issues"]:
            report.append(f" - ❌ {issue}")
    else:
        report.append("✅ **Content Quality**: Good")
    if content_result["warnings"]:
        for warning in content_result["warnings"]:
            report.append(f" - ⚠️ {warning}")
    report.append(f" - **Quality Score**: {quality_score:.2f}/1.00")
    report.append("")

    # Test set coverage validation
    overall_coverage = test_set_result["overall_coverage"]
    meets_threshold = test_set_result["meets_missing_threshold"]
    if overall_coverage == 1.0:
        report.append("✅ **Test Set Coverage**: Complete")
    elif overall_coverage >= 0.95 and meets_threshold:
        report.append("✅ **Test Set Coverage**: Adequate")
    else:
        report.append("❌ **Test Set Coverage**: Insufficient")
    report.append(f" - Coverage: {overall_coverage:.1%} ({test_set_result['matching_count']:,} / {test_set_result['matching_count'] + test_set_result['missing_count']:,})")
    report.append(f" - Missing Rate: {test_set_result['missing_rate']:.1%}")
    report.append("")

    # Track-specific coverage analysis
    report.append("#### 📊 Track-Specific Analysis")
    track_coverage = test_set_result.get("track_coverage", {})
    for track_name, coverage_info in track_coverage.items():
        track_config = EVALUATION_TRACKS[track_name]
        status = "✅" if coverage_info["meets_minimum"] else "❌"
        report.append(f"**{status} {track_config['name']}**:")
        report.append(f" - **Samples**: {coverage_info['covered_samples']:,} / {coverage_info['total_samples']:,}")
        report.append(f" - **Coverage**: {coverage_info['coverage_rate']:.1%}")
        report.append(f" - **Minimum Required**: {coverage_info['min_required']:,}")
        report.append(f" - **Status**: {'Adequate' if coverage_info['meets_minimum'] else 'Insufficient'}")
        report.append("")

    # Final verdict
    all_checks_pass = (
        format_result["valid"]
        and not content_result["has_issues"]
        and overall_coverage >= 0.95
        and meets_threshold
    )
    # Content issues are listed above but do not block evaluation on their
    # own; only an unparseable file or coverage below 80% does.
    can_evaluate_with_limits = format_result["valid"] and overall_coverage >= 0.8

    if all_checks_pass:
        report.append("🎉 **Final Verdict**: Ready for evaluation!")
    elif can_evaluate_with_limits:
        report.append("⚠️ **Final Verdict**: Can be evaluated with limitations")
        report.append(" - Results will include notes about limitations")
    else:
        report.append("❌ **Final Verdict**: Please address critical issues before submission")

    return "\n".join(report)

def validate_submission(
    file_content: bytes,
    filename: str,
    test_set: pd.DataFrame,
    model_name: str = "",
    author: str = "",
    description: str = "",
) -> Dict:
    """Complete validation pipeline for submissions."""
    # Step 1: Detect model category
    detected_category = detect_model_category(model_name, author, description)

    # Step 2: File format validation
    format_result = validate_file_format(file_content, filename)
    if not format_result["valid"]:
        return {
            "valid": False,
            "can_evaluate": False,
            "category": detected_category,
            "report": generate_validation_report(
                format_result, {}, {}, model_name, detected_category
            ),
            "predictions": None,
        }

    predictions = format_result["dataframe"]

    # Step 3: Content validation
    content_result = validate_predictions_content(predictions)

    # Step 4: Test set validation
    test_set_result = validate_against_test_set(predictions, test_set)

    # Step 5: Generate report
    report = generate_validation_report(
        format_result, content_result, test_set_result, model_name, detected_category
    )

    # Overall validity determination
    is_valid = (
        format_result["valid"]
        and not content_result["has_issues"]
        and test_set_result["overall_coverage"] >= 0.95
        and test_set_result["meets_missing_threshold"]
    )

    # Evaluation eligibility (more permissive): content issues lower the
    # quality score and appear in the report, but only a format failure or
    # coverage below 80% blocks evaluation outright.
    can_evaluate = (
        format_result["valid"]
        and test_set_result["overall_coverage"] >= 0.8
    )

    return {
        "valid": is_valid,
        "can_evaluate": can_evaluate,
        "category": detected_category,
        "coverage": test_set_result["overall_coverage"],
        "report": report,
        "predictions": predictions,
        "quality_score": content_result.get("quality_score", 0.8),
        "track_coverage": test_set_result.get("track_coverage", {}),
    }
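
# Minimal smoke test; a sketch, assuming config.py provides the structures
# sketched near the imports above. The language codes and sample values below
# are hypothetical. Run with: python -m src.validation
if __name__ == "__main__":
    # Two-row synthetic test set
    demo_test_set = pd.DataFrame(
        {
            "sample_id": ["salt_000001", "salt_000002"],
            "source_language": ["eng", "lug"],
            "target_language": ["lug", "eng"],
        }
    )
    # Matching predictions, supplied as CSV bytes as they would be uploaded
    demo_csv = b"sample_id,prediction\nsalt_000001,Webale nnyo\nsalt_000002,Thank you\n"

    result = validate_submission(
        demo_csv,
        "demo_predictions.csv",
        demo_test_set,
        model_name="demo-model",
        author="example-author",
        description="toy submission for a smoke test",
    )
    print(result["report"])
    print(f"valid={result['valid']} can_evaluate={result['can_evaluate']}")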