import time
import traceback

import pandas as pd
import gradio as gr
import spaces
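# NOTE: `spaces` (Hugging Face's ZeroGPU helper package) is not referenced in this
# function; it is presumably imported for @spaces.GPU decorators elsewhere in the app.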

from mmlu_pro_eval_adapted import evaluate_mmlu_pro
from configs.dataset_config import get_subject_mode_param, get_subject_names


def run_mmlu_evaluation(subject_selection_mode, num_subjects, selected_subjects,
                        all_questions, num_questions, model_configs, progress=gr.Progress()):
    """
    Runs the MMLU-Pro evaluation for both models with the specified parameters.

    Args:
        subject_selection_mode (str): Mode of subject selection ("all", "number", or "specific").
        num_subjects (int): Number of subjects to evaluate (1-14).
        selected_subjects (list): List of specific subjects to evaluate.
        all_questions (bool): Whether to evaluate all questions per subject.
        num_questions (int): Number of questions per subject (1-100, or -1 for all).
        model_configs (dict): Configuration for both models (see the expected shape below).
        progress (gr.Progress): Gradio progress indicator.
    """
    try:
        # Convert parameters if needed
        if subject_selection_mode == "all":
            num_subjects = -1
            selected_subjects = []
        elif subject_selection_mode == "specific":
            num_subjects = len(selected_subjects) if selected_subjects else -1

        if all_questions:
            num_questions = -1

        # Extract model configurations
        model1_config = model_configs["model1"]
        model2_config = model_configs["model2"]

        # Run evaluation for Model 1
        start_time_model1 = time.time()
        model1_results = evaluate_mmlu_pro(
            model1_config["name"],
            num_subjects=num_subjects,
            num_questions=num_questions,
            num_shots=model1_config["shots"],
            specific_subjects=selected_subjects if subject_selection_mode == "specific" else None,
            flash_attention=model1_config["flash_attention"],
            regex_pattern=model1_config["regex"] if model1_config["regex"] else None
        )
        model1_elapsed_time = time.time() - start_time_model1

        # Run evaluation for Model 2
        start_time_model2 = time.time()
        model2_results = evaluate_mmlu_pro(
            model2_config["name"],
            num_subjects=num_subjects,
            num_questions=num_questions,
            num_shots=model2_config["shots"],
            specific_subjects=selected_subjects if subject_selection_mode == "specific" else None,
            flash_attention=model2_config["flash_attention"],
            regex_pattern=model2_config["regex"] if model2_config["regex"] else None
        )
        model2_elapsed_time = time.time() - start_time_model2
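
        # The two evaluate_mmlu_pro calls above are deliberately symmetric: only the
        # per-model config values (name, shots, flash_attention, regex) differ.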
        # Format summary results
        model1_overall_acc = model1_results["overall_accuracy"]
        model1_min_subject, model1_min_acc = model1_results["min_accuracy_subject"]
        model1_max_subject, model1_max_acc = model1_results["max_accuracy_subject"]

        model2_overall_acc = model2_results["overall_accuracy"]
        model2_min_subject, model2_min_acc = model2_results["min_accuracy_subject"]
        model2_max_subject, model2_max_acc = model2_results["max_accuracy_subject"]

        # Create merged results DataFrame
        results_df1 = pd.DataFrame(model1_results["full_accuracy_table"])
        results_df2 = pd.DataFrame(model2_results["full_accuracy_table"])

        # Take the union of subjects covered by either result table
        subjects = sorted(set(results_df1['Subject'].tolist() + results_df2['Subject'].tolist()))

        # Create comparison DataFrame
        comparison_data = []
        for subject in subjects:
            model1_row = results_df1[results_df1['Subject'] == subject]
            model2_row = results_df2[results_df2['Subject'] == subject]

            model1_acc = model1_row['Accuracy'].iloc[0] if not model1_row.empty else 0
            model2_acc = model2_row['Accuracy'].iloc[0] if not model2_row.empty else 0

            # Calculate the difference and determine the winner
            diff = model1_acc - model2_acc
            winner = "Model 1" if diff > 0 else ("Model 2" if diff < 0 else "Tie")

            comparison_data.append({
                'Subject': subject,
                'Model 1 Accuracy': model1_acc,
                'Model 2 Accuracy': model2_acc,
                'Difference': abs(diff),
                'Winner': winner
            })

        # Add overall row
        model1_total_samples = results_df1['Num_samples'].sum()
        model1_total_correct = results_df1['Num_correct'].sum()
        model2_total_samples = results_df2['Num_samples'].sum()
        model2_total_correct = results_df2['Num_correct'].sum()

        overall_diff = model1_overall_acc - model2_overall_acc
        overall_winner = "Model 1" if overall_diff > 0 else ("Model 2" if overall_diff < 0 else "Tie")

        comparison_data.insert(0, {
            'Subject': '**Overall**',
            'Model 1 Accuracy': model1_overall_acc,
            'Model 2 Accuracy': model2_overall_acc,
            'Difference': abs(overall_diff),
            'Winner': overall_winner
        })

        comparison_df = pd.DataFrame(comparison_data)
        # Format the report
        report = (
            f"### Head-to-Head Comparison Results\n\n"
            f"#### Model 1: {model1_config['name']}\n"
            f"* Overall Accuracy: {model1_overall_acc:.3f}\n"
            f"* Best Performance: {model1_max_subject} ({model1_max_acc:.3f})\n"
            f"* Worst Performance: {model1_min_subject} ({model1_min_acc:.3f})\n"
            f"* Evaluation completed in {model1_elapsed_time:.2f} seconds\n\n"
            f"#### Model 2: {model2_config['name']}\n"
            f"* Overall Accuracy: {model2_overall_acc:.3f}\n"
            f"* Best Performance: {model2_max_subject} ({model2_max_acc:.3f})\n"
            f"* Worst Performance: {model2_min_subject} ({model2_min_acc:.3f})\n"
            f"* Evaluation completed in {model2_elapsed_time:.2f} seconds\n\n"
            f"#### Overall Winner: {overall_winner}\n"
            f"* Margin: {abs(overall_diff):.3f}\n"
        )

        # Return values the caller uses to update the display and re-enable UI components
        return {
            'report': report,
            'comparison_df': comparison_df,
            'success': True
        }
    except Exception:
        # Handle errors gracefully
        error_trace = traceback.format_exc()
        error_message = f"### Error during evaluation\n```\n{error_trace}\n```"

        # Return error information
        return {
            'report': error_message,
            'comparison_df': None,
            'success': False
        }
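

# --- Illustrative usage (a minimal sketch; not how the Gradio app wires this in) ---
# The model names, shot counts, and subject settings below are placeholders chosen for
# demonstration only; in the app this function is presumably called from the UI event handlers.
if __name__ == "__main__":
    example_configs = {
        "model1": {"name": "Qwen/Qwen2.5-0.5B-Instruct", "shots": 5,
                   "flash_attention": False, "regex": ""},
        "model2": {"name": "HuggingFaceTB/SmolLM2-135M-Instruct", "shots": 5,
                   "flash_attention": False, "regex": ""},
    }
    result = run_mmlu_evaluation(
        subject_selection_mode="number",  # one of "all", "number", or "specific"
        num_subjects=2,
        selected_subjects=[],
        all_questions=False,
        num_questions=5,
        model_configs=example_configs,
    )
    print(result["report"])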