# src/validation.py
import pandas as pd
import numpy as np
from typing import Dict, List, Tuple, Optional
import json
import io

from config import PREDICTION_FORMAT


def validate_file_format(file_content: bytes, filename: str) -> Dict:
    """Validate uploaded file format and structure."""
    try:
        # Determine file type
        if filename.endswith('.csv'):
            df = pd.read_csv(io.BytesIO(file_content))
        elif filename.endswith('.tsv'):
            df = pd.read_csv(io.BytesIO(file_content), sep='\t')
        elif filename.endswith('.json'):
            data = json.loads(file_content.decode('utf-8'))
            df = pd.DataFrame(data)
        else:
            return {
                'valid': False,
                'error': f"Unsupported file type. Use: {', '.join(PREDICTION_FORMAT['file_types'])}"
            }

        # Check required columns
        missing_cols = set(PREDICTION_FORMAT['required_columns']) - set(df.columns)
        if missing_cols:
            return {
                'valid': False,
                'error': f"Missing required columns: {', '.join(missing_cols)}"
            }

        # Basic data validation
        if len(df) == 0:
            return {
                'valid': False,
                'error': "File is empty"
            }

        # Check for required data
        if df['sample_id'].isna().any():
            return {
                'valid': False,
                'error': "Missing sample_id values found"
            }

        if df['prediction'].isna().any():
            na_count = df['prediction'].isna().sum()
            return {
                'valid': False,
                'error': f"Missing prediction values found ({na_count} empty predictions)"
            }

        # Check for duplicates
        duplicates = df['sample_id'].duplicated()
        if duplicates.any():
            dup_count = duplicates.sum()
            return {
                'valid': False,
                'error': f"Duplicate sample_id values found ({dup_count} duplicates)"
            }

        return {
            'valid': True,
            'dataframe': df,
            'row_count': len(df),
            'columns': list(df.columns)
        }

    except Exception as e:
        return {
            'valid': False,
            'error': f"Error parsing file: {str(e)}"
        }


def validate_predictions_content(predictions: pd.DataFrame) -> Dict:
    """Validate prediction content quality."""
    issues = []
    warnings = []

    # Check prediction text quality
    empty_predictions = predictions['prediction'].str.strip().eq('').sum()
    if empty_predictions > 0:
        issues.append(f"{empty_predictions} empty predictions found")

    # Check for suspiciously short predictions
    short_predictions = (predictions['prediction'].str.len() < 3).sum()
    if short_predictions > len(predictions) * 0.1:  # More than 10%
        warnings.append(f"{short_predictions} very short predictions (< 3 characters)")

    # Check for suspiciously long predictions
    long_predictions = (predictions['prediction'].str.len() > 500).sum()
    if long_predictions > 0:
        warnings.append(f"{long_predictions} very long predictions (> 500 characters)")

    # Check for repeated predictions
    duplicate_predictions = predictions['prediction'].duplicated().sum()
    if duplicate_predictions > len(predictions) * 0.5:  # More than 50%
        warnings.append(f"{duplicate_predictions} duplicate prediction texts")

    # Check for non-text content
    non_text_pattern = r'^[A-Za-z\s\'".,!?;:()\-]+$'
    non_text_predictions = ~predictions['prediction'].str.match(non_text_pattern, na=False)
    if non_text_predictions.sum() > 0:
        warnings.append(f"{non_text_predictions.sum()} predictions contain unusual characters")

    return {
        'has_issues': len(issues) > 0,
        'issues': issues,
        'warnings': warnings,
        'quality_score': max(0, 1.0 - len(issues) * 0.2 - len(warnings) * 0.1)
    }


def validate_against_test_set(predictions: pd.DataFrame, test_set: pd.DataFrame) -> Dict:
    """Validate predictions against the official test set."""
    # Convert IDs to string for comparison
    pred_ids = set(predictions['sample_id'].astype(str))
    test_ids = set(test_set['sample_id'].astype(str))

    # Check coverage
    missing_ids = test_ids - pred_ids
    extra_ids = pred_ids - test_ids
    matching_ids = pred_ids & test_ids
    coverage = len(matching_ids) / len(test_ids)

    # Detailed coverage by language pair
    pair_coverage = {}
    for _, row in test_set.iterrows():
        pair_key = f"{row['source_language']}_{row['target_language']}"
        if pair_key not in pair_coverage:
            pair_coverage[pair_key] = {'total': 0, 'covered': 0}
        pair_coverage[pair_key]['total'] += 1
        if str(row['sample_id']) in pred_ids:
            pair_coverage[pair_key]['covered'] += 1

    # Calculate pair-wise coverage rates
    for pair_key in pair_coverage:
        pair_info = pair_coverage[pair_key]
        pair_info['coverage_rate'] = pair_info['covered'] / pair_info['total']

    return {
        'overall_coverage': coverage,
        'missing_count': len(missing_ids),
        'extra_count': len(extra_ids),
        'matching_count': len(matching_ids),
        'is_complete': coverage == 1.0,
        'pair_coverage': pair_coverage,
        'missing_ids_sample': list(missing_ids)[:10],  # First 10 for display
        'extra_ids_sample': list(extra_ids)[:10]
    }


def generate_validation_report(
    format_result: Dict,
    content_result: Dict,
    test_set_result: Dict,
    model_name: str = ""
) -> str:
    """Generate human-readable validation report."""
    report = []

    # Header
    report.append(f"# Validation Report: {model_name or 'Submission'}")
    report.append(f"Generated: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}")
    report.append("")

    # File format validation
    if format_result['valid']:
        report.append("✅ **File Format**: Valid")
        report.append(f" - Rows: {format_result['row_count']:,}")
        report.append(f" - Columns: {', '.join(format_result['columns'])}")
    else:
        report.append("❌ **File Format**: Invalid")
        report.append(f" - Error: {format_result['error']}")
        return "\n".join(report)

    # Content validation
    if content_result['has_issues']:
        report.append("⚠️ **Content Quality**: Issues Found")
        for issue in content_result['issues']:
            report.append(f" - ❌ {issue}")
    else:
        report.append("✅ **Content Quality**: Good")

    if content_result['warnings']:
        for warning in content_result['warnings']:
            report.append(f" - ⚠️ {warning}")

    # Test set validation
    coverage = test_set_result['overall_coverage']
    if coverage == 1.0:
        report.append("✅ **Test Set Coverage**: Complete")
    elif coverage >= 0.95:
        report.append("⚠️ **Test Set Coverage**: Nearly Complete")
    else:
        report.append("❌ **Test Set Coverage**: Incomplete")

    report.append(
        f" - Coverage: {coverage:.1%} ({test_set_result['matching_count']:,} / "
        f"{test_set_result['matching_count'] + test_set_result['missing_count']:,})"
    )

    if test_set_result['missing_count'] > 0:
        report.append(f" - Missing: {test_set_result['missing_count']:,} samples")
    if test_set_result['extra_count'] > 0:
        report.append(f" - Extra: {test_set_result['extra_count']:,} samples")

    # Language pair coverage
    pair_cov = test_set_result['pair_coverage']
    incomplete_pairs = [k for k, v in pair_cov.items() if v['coverage_rate'] < 1.0]
    if incomplete_pairs:
        report.append("")
        report.append("**Incomplete Language Pairs:**")
        for pair in incomplete_pairs[:5]:  # Show first 5
            info = pair_cov[pair]
            src, tgt = pair.split('_')
            report.append(f" - {src}→{tgt}: {info['covered']}/{info['total']} ({info['coverage_rate']:.1%})")
        if len(incomplete_pairs) > 5:
            report.append(f" - ... and {len(incomplete_pairs) - 5} more pairs")

    # Final verdict
    report.append("")
    if format_result['valid'] and coverage >= 0.95 and not content_result['has_issues']:
        report.append("🎉 **Overall**: Ready for evaluation!")
    elif format_result['valid'] and coverage >= 0.8:
        report.append("⚠️ **Overall**: Can be evaluated with warnings")
    else:
        report.append("❌ **Overall**: Please fix issues before submission")

    return "\n".join(report)


def validate_submission_complete(file_content: bytes, filename: str, test_set: pd.DataFrame, model_name: str = "") -> Dict:
    """Complete validation pipeline for a submission."""
    # Step 1: File format validation
    format_result = validate_file_format(file_content, filename)
    if not format_result['valid']:
        return {
            'valid': False,
            'report': generate_validation_report(format_result, {}, {}, model_name),
            'predictions': None
        }

    predictions = format_result['dataframe']

    # Step 2: Content validation
    content_result = validate_predictions_content(predictions)

    # Step 3: Test set validation
    test_set_result = validate_against_test_set(predictions, test_set)

    # Step 4: Generate report
    report = generate_validation_report(format_result, content_result, test_set_result, model_name)

    # Overall validity
    is_valid = (
        format_result['valid']
        and not content_result['has_issues']
        and test_set_result['overall_coverage'] >= 0.95
    )

    return {
        'valid': is_valid,
        'coverage': test_set_result['overall_coverage'],
        'report': report,
        'predictions': predictions,
        'pair_coverage': test_set_result['pair_coverage']
    }
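

if __name__ == "__main__":
    # Minimal usage sketch for running the full pipeline by hand.
    # Note: "test_set.csv", "predictions.csv", and "example-model" are
    # hypothetical placeholders for illustration only; the real entry point
    # that calls validate_submission_complete lives elsewhere in the app.
    test_set = pd.read_csv("test_set.csv")
    with open("predictions.csv", "rb") as f:
        result = validate_submission_complete(
            f.read(), "predictions.csv", test_set, model_name="example-model"
        )
    print(result['report'])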