# src/validation.py
import io
import json
from typing import Dict

import pandas as pd

from config import PREDICTION_FORMAT
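
# Expected shape of PREDICTION_FORMAT (defined in config.py), inferred from
# how it is used below. The exact values are an assumption for illustration,
# not the real config:
# PREDICTION_FORMAT = {
#     'file_types': ['csv', 'tsv', 'json'],
#     'required_columns': ['sample_id', 'prediction'],
# }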


def validate_file_format(file_content: bytes, filename: str) -> Dict:
    """Validate uploaded file format and structure."""
    try:
        # Determine file type and parse accordingly
        if filename.endswith('.csv'):
            df = pd.read_csv(io.BytesIO(file_content))
        elif filename.endswith('.tsv'):
            df = pd.read_csv(io.BytesIO(file_content), sep='\t')
        elif filename.endswith('.json'):
            data = json.loads(file_content.decode('utf-8'))
            df = pd.DataFrame(data)
        else:
            return {
                'valid': False,
                'error': f"Unsupported file type. Use: {', '.join(PREDICTION_FORMAT['file_types'])}"
            }

        # Check required columns
        missing_cols = set(PREDICTION_FORMAT['required_columns']) - set(df.columns)
        if missing_cols:
            return {
                'valid': False,
                'error': f"Missing required columns: {', '.join(missing_cols)}"
            }

        # Basic data validation
        if len(df) == 0:
            return {
                'valid': False,
                'error': "File is empty"
            }

        # Check for required data
        if df['sample_id'].isna().any():
            return {
                'valid': False,
                'error': "Missing sample_id values found"
            }

        if df['prediction'].isna().any():
            na_count = df['prediction'].isna().sum()
            return {
                'valid': False,
                'error': f"Missing prediction values found ({na_count} empty predictions)"
            }

        # Check for duplicate sample IDs
        duplicates = df['sample_id'].duplicated()
        if duplicates.any():
            dup_count = duplicates.sum()
            return {
                'valid': False,
                'error': f"Duplicate sample_id values found ({dup_count} duplicates)"
            }

        return {
            'valid': True,
            'dataframe': df,
            'row_count': len(df),
            'columns': list(df.columns)
        }
    except Exception as e:
        return {
            'valid': False,
            'error': f"Error parsing file: {str(e)}"
        }


def validate_predictions_content(predictions: pd.DataFrame) -> Dict:
    """Validate prediction content quality."""
    issues = []
    warnings = []

    # Coerce to string so the .str accessor is safe even if the column
    # was parsed with a numeric dtype
    pred_text = predictions['prediction'].astype(str)

    # Check prediction text quality
    empty_predictions = pred_text.str.strip().eq('').sum()
    if empty_predictions > 0:
        issues.append(f"{empty_predictions} empty predictions found")

    # Check for suspiciously short predictions
    short_predictions = (pred_text.str.len() < 3).sum()
    if short_predictions > len(predictions) * 0.1:  # More than 10%
        warnings.append(f"{short_predictions} very short predictions (< 3 characters)")

    # Check for suspiciously long predictions
    long_predictions = (pred_text.str.len() > 500).sum()
    if long_predictions > 0:
        warnings.append(f"{long_predictions} very long predictions (> 500 characters)")

    # Check for repeated predictions
    duplicate_predictions = pred_text.duplicated().sum()
    if duplicate_predictions > len(predictions) * 0.5:  # More than 50%
        warnings.append(f"{duplicate_predictions} duplicate prediction texts")

    # Check for non-text content. Note: the pattern only covers basic Latin
    # letters and common punctuation, so non-Latin scripts trigger a warning.
    non_text_pattern = r'^[A-Za-z\s\'".,!?;:()\-]+$'
    non_text_predictions = ~pred_text.str.match(non_text_pattern, na=False)
    if non_text_predictions.sum() > 0:
        warnings.append(f"{non_text_predictions.sum()} predictions contain unusual characters")

    return {
        'has_issues': len(issues) > 0,
        'issues': issues,
        'warnings': warnings,
        'quality_score': max(0, 1.0 - len(issues) * 0.2 - len(warnings) * 0.1)
    }


def validate_against_test_set(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Validate predictions against the official test set."""
    # Convert IDs to string for comparison
    pred_ids = set(predictions['sample_id'].astype(str))
    test_ids = set(test_set['sample_id'].astype(str))

    # Check coverage
    missing_ids = test_ids - pred_ids
    extra_ids = pred_ids - test_ids
    matching_ids = pred_ids & test_ids
    coverage = len(matching_ids) / len(test_ids)

    # Detailed coverage by language pair
    pair_coverage = {}
    for _, row in test_set.iterrows():
        pair_key = f"{row['source_language']}_{row['target_language']}"
        if pair_key not in pair_coverage:
            pair_coverage[pair_key] = {'total': 0, 'covered': 0}
        pair_coverage[pair_key]['total'] += 1
        if str(row['sample_id']) in pred_ids:
            pair_coverage[pair_key]['covered'] += 1

    # Calculate pair-wise coverage rates
    for pair_key in pair_coverage:
        pair_info = pair_coverage[pair_key]
        pair_info['coverage_rate'] = pair_info['covered'] / pair_info['total']

    return {
        'overall_coverage': coverage,
        'missing_count': len(missing_ids),
        'extra_count': len(extra_ids),
        'matching_count': len(matching_ids),
        'is_complete': coverage == 1.0,
        'pair_coverage': pair_coverage,
        'missing_ids_sample': list(missing_ids)[:10],  # First 10 for display
        'extra_ids_sample': list(extra_ids)[:10]
    }
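
# Illustrative shape of the 'pair_coverage' mapping returned above (the
# language pairs and counts are hypothetical, for documentation only):
#   {'en_fr': {'total': 500, 'covered': 500, 'coverage_rate': 1.0},
#    'en_de': {'total': 500, 'covered': 480, 'coverage_rate': 0.96}}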


def generate_validation_report(
    format_result: Dict,
    content_result: Dict,
    test_set_result: Dict,
    model_name: str = ""
) -> str:
    """Generate human-readable validation report."""
    report = []

    # Header
    report.append(f"# Validation Report: {model_name or 'Submission'}")
    report.append(f"Generated: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}")
    report.append("")

    # File format validation
    if format_result['valid']:
        report.append("✅ **File Format**: Valid")
        report.append(f"   - Rows: {format_result['row_count']:,}")
        report.append(f"   - Columns: {', '.join(format_result['columns'])}")
    else:
        report.append("❌ **File Format**: Invalid")
        report.append(f"   - Error: {format_result['error']}")
        return "\n".join(report)

    # Content validation
    if content_result['has_issues']:
        report.append("⚠️ **Content Quality**: Issues Found")
        for issue in content_result['issues']:
            report.append(f"   - ❌ {issue}")
    else:
        report.append("✅ **Content Quality**: Good")

    if content_result['warnings']:
        for warning in content_result['warnings']:
            report.append(f"   - ⚠️ {warning}")

    # Test set validation
    coverage = test_set_result['overall_coverage']
    if coverage == 1.0:
        report.append("✅ **Test Set Coverage**: Complete")
    elif coverage >= 0.95:
        report.append("⚠️ **Test Set Coverage**: Nearly Complete")
    else:
        report.append("❌ **Test Set Coverage**: Incomplete")

    report.append(
        f"   - Coverage: {coverage:.1%} "
        f"({test_set_result['matching_count']:,} / "
        f"{test_set_result['matching_count'] + test_set_result['missing_count']:,})"
    )
    if test_set_result['missing_count'] > 0:
        report.append(f"   - Missing: {test_set_result['missing_count']:,} samples")
    if test_set_result['extra_count'] > 0:
        report.append(f"   - Extra: {test_set_result['extra_count']:,} samples")

    # Language pair coverage
    pair_cov = test_set_result['pair_coverage']
    incomplete_pairs = [k for k, v in pair_cov.items() if v['coverage_rate'] < 1.0]
    if incomplete_pairs:
        report.append("")
        report.append("**Incomplete Language Pairs:**")
        for pair in incomplete_pairs[:5]:  # Show first 5
            info = pair_cov[pair]
            src, tgt = pair.split('_')
            report.append(f"   - {src}→{tgt}: {info['covered']}/{info['total']} ({info['coverage_rate']:.1%})")
        if len(incomplete_pairs) > 5:
            report.append(f"   - ... and {len(incomplete_pairs) - 5} more pairs")

    # Final verdict
    report.append("")
    if format_result['valid'] and coverage >= 0.95 and not content_result['has_issues']:
        report.append("🎉 **Overall**: Ready for evaluation!")
    elif format_result['valid'] and coverage >= 0.8:
        report.append("⚠️ **Overall**: Can be evaluated with warnings")
    else:
        report.append("❌ **Overall**: Please fix issues before submission")

    return "\n".join(report)


def validate_submission_complete(
    file_content: bytes,
    filename: str,
    test_set: pd.DataFrame,
    model_name: str = ""
) -> Dict:
    """Complete validation pipeline for a submission."""
    # Step 1: File format validation
    format_result = validate_file_format(file_content, filename)
    if not format_result['valid']:
        return {
            'valid': False,
            'report': generate_validation_report(format_result, {}, {}, model_name),
            'predictions': None
        }

    predictions = format_result['dataframe']

    # Step 2: Content validation
    content_result = validate_predictions_content(predictions)

    # Step 3: Test set validation
    test_set_result = validate_against_test_set(predictions, test_set)

    # Step 4: Generate report
    report = generate_validation_report(format_result, content_result, test_set_result, model_name)

    # Overall validity
    is_valid = (
        format_result['valid'] and
        not content_result['has_issues'] and
        test_set_result['overall_coverage'] >= 0.95
    )

    return {
        'valid': is_valid,
        'coverage': test_set_result['overall_coverage'],
        'report': report,
        'predictions': predictions,
        'pair_coverage': test_set_result['pair_coverage']
    }
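

# A minimal smoke test for the full pipeline, assuming PREDICTION_FORMAT
# requires the 'sample_id' and 'prediction' columns. The toy rows below are
# illustrative only, not part of any official test set.
if __name__ == "__main__":
    toy_test_set = pd.DataFrame({
        'sample_id': ['1', '2'],
        'source_language': ['en', 'en'],
        'target_language': ['fr', 'de'],
    })
    toy_submission = pd.DataFrame({
        'sample_id': ['1', '2'],
        'prediction': ['Bonjour le monde.', 'Hallo Welt.'],
    })
    # Round-trip through CSV bytes to exercise the file-format validation
    content = toy_submission.to_csv(index=False).encode('utf-8')
    result = validate_submission_complete(content, 'predictions.csv', toy_test_set, 'demo-model')
    print(result['report'])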