# Adapted from https://github.com/TIGER-AI-Lab/MMLU-Pro/blob/main/evaluate_from_local.py
import csv
import json
import argparse
import os
import torch
import spaces
import random
import transformers
import time
import re
from vllm import LLM, SamplingParams
from tqdm import tqdm
import logging
import sys
from datasets import load_dataset
import pandas as pd
import numpy as np

logging.basicConfig(level=logging.INFO)

# Can be found at https://github.com/TIGER-AI-Lab/MMLU-Pro/blob/main/cot_prompt_lib/initial_prompt.txt
initial_prompt = "The following are multiple choice questions (with answers) about {$}. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice."

choices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P"]
max_model_length = 4096
max_new_tokens = 2048


def preprocess(test_df):
    res_df = []
    for each in test_df:
        options = []
        for opt in each["options"]:
            if opt == "N/A":
                continue
            options.append(opt)
        each["options"] = options
        res_df.append(each)
    return res_df


def load_mmlu_pro():
    dataset = load_dataset("TIGER-Lab/MMLU-Pro")
    test_df, val_df = dataset["test"], dataset["validation"]
    test_df = preprocess(test_df)
    val_df = preprocess(val_df)
    # Convert to DataFrames right after loading and preprocessing
    test_df = pd.DataFrame(test_df)
    val_df = pd.DataFrame(val_df)
    return test_df, val_df


def load_model(model_name, gpu_utilization=0.8):
    llm = LLM(model=model_name,
              gpu_memory_utilization=float(gpu_utilization),
              tensor_parallel_size=torch.cuda.device_count(),
              max_model_len=max_model_length,
              trust_remote_code=True)
    logging.info(f"Torch Device CUDA Count: {torch.cuda.device_count()}")
    sampling_params = SamplingParams(temperature=0,
                                     max_tokens=max_new_tokens,
                                     stop=["Question:"])
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name,
                                                           trust_remote_code=True)
    return (llm, sampling_params), tokenizer


def format_cot_example(example, including_answer=True):
    # Handle both Series and dict inputs
    if isinstance(example, pd.Series):
        example = example.to_dict()
    prompt = "Question:\n"
    question = example["question"]
    options = example["options"]
    prompt += question + "\n"
    prompt += "Options:\n"
    for i, opt in enumerate(options):
        prompt += "{}. {}\n".format(choices[i], opt)
    if including_answer:
        cot_content = example["cot_content"].replace("A: Let's think step by step.",
                                                     "Answer: Let's think step by step.")
        prompt += cot_content + "\n\n"
    else:
        prompt += "Answer: Let's think step by step."
    return prompt


def generate_cot_prompt(val_df, curr, k):
    """
    Generate prompt with examples from val_df matching curr's category.

    Args:
        val_df: DataFrame containing validation examples
        curr: Series or dict representing current example
        k: Number of examples to include
    """
    prompt = initial_prompt
    # "category" lookup works for both Series and dict inputs
    subject = curr["category"]

    # Filter validation examples by category
    filtered_val_df = val_df[val_df["category"] == subject].head(k)

    prompt = prompt.replace("{$}", subject) + "\n"

    # Add each example to the prompt
    for _, example in filtered_val_df.iterrows():
        prompt += format_cot_example(example, including_answer=True)

    # Add the current example
    prompt += format_cot_example(curr, including_answer=False)
    return prompt
def extract_answer(text):
    pattern = r"answer is \(?([A-J])\)?"
    match = re.search(pattern, text)
    if match:
        return match.group(1)
    else:
        print("1st answer extract failed\n" + text)
        return extract_again(text)


def extract_again(text):
    match = re.search(r'.*[aA]nswer:\s*([A-J])', text)
    if match:
        return match.group(1)
    else:
        return extract_final(text)


def extract_final(text):
    pattern = r"\b[A-J]\b(?!.*\b[A-J]\b)"
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(0)
    else:
        return None


def batch_inference(llm, sampling_params, inference_batch, tokenizer):
    start = time.time()
    outputs = llm.generate(inference_batch, sampling_params)
    logging.info("Batch of size: %s. Time taken: %s",
                 len(inference_batch), time.time() - start)
    response_batch = []
    pred_batch = []
    for output in outputs:
        generated_text = output.outputs[0].text
        response_batch.append(generated_text)
        pred = extract_answer(generated_text)
        pred_batch.append(pred)
    return pred_batch, response_batch


def batch_inference_debug_mode(llm, sampling_params, inference_batch, tokenizer):
    start = time.time()
    outputs = llm.generate(inference_batch, sampling_params)
    logging.info("Batch of size: %s. Time taken: %s",
                 len(inference_batch), time.time() - start)
    response_batch = []
    pred_batch = []
    input_token_counts = []
    output_token_counts = []
    for i, output in enumerate(outputs):
        generated_text = output.outputs[0].text
        response_batch.append(generated_text)
        pred = extract_answer(generated_text)
        pred_batch.append(pred)
        # Proper token count using tokenizer
        input_tokens = len(tokenizer.encode(inference_batch[i]))
        output_tokens = len(tokenizer.encode(generated_text))
        input_token_counts.append(input_tokens)
        output_token_counts.append(output_tokens)

    logging.info("\n----------- PRED BATCH -----------\n%s", pred_batch)
    logging.info("\n----------- RESPONSE BATCH -----------\n%s", response_batch)

    # Convert to DataFrame for logging (handle cases with fewer than 40 requests)
    num_samples = min(40, len(inference_batch))
    summary_df = pd.DataFrame({
        'Input': inference_batch[:num_samples],
        'Response': response_batch[:num_samples]
    })
    logging.info("\n----------- Summary of first %d requests and responses -----------\n%s",
                 num_samples, summary_df.to_string())

    # Total and average input/output token statistics
    total_input_tokens = sum(input_token_counts)
    total_output_tokens = sum(output_token_counts)
    avg_input_tokens = total_input_tokens / len(input_token_counts)
    avg_output_tokens = total_output_tokens / len(output_token_counts)
    max_input_idx = np.argmax(input_token_counts)
    max_output_idx = np.argmax(output_token_counts)
    min_input_idx = np.argmin(input_token_counts)
    min_output_idx = np.argmin(output_token_counts)

    logging.info("\n----------- Token Statistics -----------")
    logging.info("Total input tokens: %d", total_input_tokens)
    logging.info("Total output tokens: %d", total_output_tokens)
    logging.info("Average input tokens: %.2f", avg_input_tokens)
    logging.info("Average output tokens: %.2f", avg_output_tokens)
    logging.info("\n----------- Request with max input tokens -----------\n"
                 "Index: %d (Tokens: %d)\nInput: %s\nOutput: %s",
                 max_input_idx, input_token_counts[max_input_idx],
                 inference_batch[max_input_idx], response_batch[max_input_idx])
    logging.info("\n----------- Request with max output tokens -----------\n"
                 "Index: %d (Tokens: %d)\nInput: %s\nOutput: %s",
                 max_output_idx, output_token_counts[max_output_idx],
                 inference_batch[max_output_idx], response_batch[max_output_idx])
    logging.info("\n----------- Request with min input tokens -----------\n"
                 "Index: %d (Tokens: %d)\nInput: %s\nOutput: %s",
                 min_input_idx, input_token_counts[min_input_idx],
                 inference_batch[min_input_idx], response_batch[min_input_idx])
    logging.info("\n----------- Request with min output tokens -----------\n"
                 "Index: %d (Tokens: %d)\nInput: %s\nOutput: %s",
                 min_output_idx, output_token_counts[min_output_idx],
                 inference_batch[min_output_idx], response_batch[min_output_idx])

    return pred_batch, response_batch


def calculate_accuracy(res):
    """
    Calculate accuracy and return an array of correctness (1 if correct, 0 if wrong)
    along with the overall accuracy.
    """
    correctness = []
    # Process predictions and compute correctness
    for i, row in res.iterrows():
        logging.info(f"Processing row {i}. Prediction: {row.get('pred')}, Answer: {row.get('answer')}")
        if not row["pred"]:
            # If prediction is None, use random choice with fixed seed
            random.seed(12345)
            options_len = len(row["options"]) if isinstance(row["options"], list) else 4
            x = random.randint(0, options_len - 1)
            is_correct = 1 if x == row["answer_index"] else 0
        else:
            is_correct = 1 if row["pred"] == row["answer"] else 0
        correctness.append(is_correct)

    # Calculate accuracy from correctness array
    if len(correctness) == 0:
        return [], 0.0
    accuracy = sum(correctness) / len(correctness)
    return correctness, accuracy


@torch.no_grad()
def eval_cot(subject, model, tokenizer, val_df, test_df, num_shots=5, debug_mode=False):
    """
    Evaluate model using chain-of-thought prompting.

    Args:
        subject: Subject category being evaluated
        model: Tuple of (llm, sampling_params)
        tokenizer: Model tokenizer
        val_df: DataFrame with validation examples
        test_df: DataFrame with test examples
        num_shots: Number of examples to include in prompt
        debug_mode: If True, log per-request token statistics
    """
    llm, sampling_params = model
    global choices
    logging.info("evaluating " + subject)
    inference_batches = []

    # Process each test example
    for i in range(len(test_df)):
        curr = test_df.iloc[i]
        k = num_shots  # Reset k for each example

        # Find prompt that fits within token limit
        prompt_length_ok = False
        prompt = None
        while not prompt_length_ok and k > 0:
            prompt = generate_cot_prompt(val_df, curr, k)
            inputs = tokenizer(prompt, return_tensors="pt")
            inputs = {key: value.cuda() for key, value in inputs.items()}
            length = len(inputs["input_ids"][0])
            if length < max_model_length - max_new_tokens:
                prompt_length_ok = True
            else:
                k -= 1

        if not prompt_length_ok:
            # If we couldn't fit any examples, use just the test question
            prompt = generate_cot_prompt(val_df.head(0), curr, 0)

        inference_batches.append(prompt)

    batch_fn = batch_inference_debug_mode if debug_mode else batch_inference
    pred_batch, response_batch = batch_fn(llm, sampling_params, inference_batches, tokenizer)

    # Add predictions to test DataFrame
    results_df = test_df.copy()
    results_df["pred"] = pred_batch
    results_df["model_outputs"] = response_batch

    # Calculate accuracy
    correctness, accuracy = calculate_accuracy(results_df)
    logging.info("This batch accuracy is: {}, correct samples: {}/{}\n".format(
        str(accuracy), str(sum(correctness)), str(len(correctness))))

    return correctness, accuracy


def evaluate_mmlu_pro(model_name, num_subjects=-1, num_questions=10, num_shots=5,
                      specific_subjects=None, flash_attention=False, regex_pattern=None):
    """
    Main evaluation function for MMLU-Pro benchmark.
    Args:
        model_name: Name/path of model to evaluate
        num_subjects: Number of subjects to test (-1 for all)
        num_questions: Number of questions per subject (-1 for all)
        num_shots: Number of examples to include in prompts
        specific_subjects: List of specific subjects to evaluate (overrides num_subjects)
        flash_attention: Whether to use flash attention (currently ignored)
        regex_pattern: Regex pattern for answer extraction (currently ignored)
    """
    print(f"Is CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

    # Load model and data
    model, tokenizer = load_model(model_name, gpu_utilization=0.8)
    test_df, val_df = load_mmlu_pro()

    # Sort DataFrames
    test_df = test_df.sort_values(['category', 'question_id'])
    val_df = val_df.sort_values(['category', 'question_id'])

    # Get unique subjects
    all_subjects = sorted(test_df['category'].unique())

    # Select subjects based on parameters
    if specific_subjects is not None:
        selected_subjects = [subject for subject in specific_subjects if subject in all_subjects]
    elif num_subjects == -1 or num_subjects >= len(all_subjects):
        selected_subjects = all_subjects
    else:
        selected_subjects = all_subjects[:num_subjects]

    logging.info("selected subjects:\n" + "\n".join(selected_subjects))

    # Prepare results tracking
    results = {}
    all_correctness = []
    results_table = []

    # Process each subject
    for subject in tqdm(selected_subjects, desc="Processing Selected Categories"):
        # Filter data for current subject
        if num_questions == -1:
            # Use all questions for this subject
            test_samples = test_df[test_df['category'] == subject]
        else:
            # Use specified number of questions
            test_samples = test_df[test_df['category'] == subject].head(num_questions)
        val_samples = val_df[val_df['category'] == subject].head(num_shots)

        # Run evaluation
        correctness, acc = eval_cot(
            subject,
            model,
            tokenizer,
            val_df=val_samples,
            test_df=test_samples,
            num_shots=num_shots
        )

        # Store results
        results[subject] = acc
        all_correctness.extend(correctness)
        results_table.append({
            'Subject': subject,
            'Num_samples': len(test_samples),
            'Num_correct': sum(correctness),
            'Accuracy': acc
        })

    # Calculate overall metrics
    weighted_acc = np.mean(all_correctness)
    min_acc_subject = min(results.items(), key=lambda x: x[1])
    max_acc_subject = max(results.items(), key=lambda x: x[1])

    # Return results summary
    return {
        "overall_accuracy": weighted_acc,
        "min_accuracy_subject": min_acc_subject,
        "max_accuracy_subject": max_acc_subject,
        "full_accuracy_table": results_table,
    }
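

# Minimal sketch of a CLI entry point for running evaluate_mmlu_pro from the
# command line. The original adaptation's entry point is not shown here, so the
# flag names below (--model, --num_subjects, --num_questions, --num_shots,
# --output) are illustrative assumptions rather than the original interface;
# only the argparse/json imports already present above are relied on.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate a model on MMLU-Pro with vLLM")
    parser.add_argument("--model", type=str, required=True,
                        help="Hugging Face model name or local path")
    parser.add_argument("--num_subjects", type=int, default=-1,
                        help="Number of subjects to evaluate (-1 for all)")
    parser.add_argument("--num_questions", type=int, default=10,
                        help="Questions per subject (-1 for all)")
    parser.add_argument("--num_shots", type=int, default=5,
                        help="Few-shot examples per prompt")
    parser.add_argument("--output", type=str, default="results.json",
                        help="Where to write the results summary as JSON")
    args = parser.parse_args()

    summary = evaluate_mmlu_pro(
        model_name=args.model,
        num_subjects=args.num_subjects,
        num_questions=args.num_questions,
        num_shots=args.num_shots,
    )
    # default=float converts numpy scalars (e.g. the overall accuracy) to JSON-serializable floats
    with open(args.output, "w") as f:
        json.dump(summary, f, indent=2, default=float)
    logging.info("Overall accuracy: %.4f", summary["overall_accuracy"])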