File size: 14,843 Bytes
cbd1959
 
 
 
 
77d4add
599b7a0
cbd1959
 
 
 
 
 
b748395
cbd1959
 
532a4a4
a3cb7ba
b748395
 
77d4add
cbd1959
3bcb863
00afad7
cbd1959
 
 
00afad7
532a4a4
cbd1959
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3cb7ba
 
 
 
 
cbd1959
 
 
 
 
 
 
 
b5e28ab
cbd1959
 
 
 
 
 
 
a3cb7ba
 
 
 
cbd1959
 
 
 
 
 
 
 
 
 
 
 
 
532a4a4
 
 
cbd1959
a3cb7ba
 
 
 
 
 
 
 
cbd1959
a3cb7ba
 
 
 
 
 
 
 
 
 
cbd1959
a3cb7ba
 
 
cbd1959
a3cb7ba
 
cbd1959
a3cb7ba
00afad7
 
77d4add
cbd1959
 
 
 
 
 
 
 
532a4a4
 
cbd1959
 
 
 
 
 
3404c97
73c8042
cbd1959
 
 
 
 
 
 
73c8042
 
bf967fb
cbd1959
 
90b1ba7
cbd1959
 
 
 
 
 
 
 
73c8042
ea8fa3f
 
 
90b1ba7
ea8fa3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49b1770
 
ea8fa3f
 
 
 
 
 
 
49b1770
ea8fa3f
 
 
 
 
 
 
 
 
 
 
 
49b1770
 
ea8fa3f
 
 
 
49b1770
ea8fa3f
 
49b1770
ea8fa3f
 
49b1770
ea8fa3f
 
49b1770
ea8fa3f
49b1770
ea8fa3f
 
 
cbd1959
 
 
 
 
 
a3cb7ba
 
 
8fb171c
a3cb7ba
cbd1959
 
a3cb7ba
 
 
cbd1959
a3cb7ba
cbd1959
 
 
 
 
3404c97
cbd1959
 
 
a3cb7ba
cbd1959
90b1ba7
a3cb7ba
 
 
 
 
 
 
 
 
 
 
cbd1959
 
 
 
a3cb7ba
 
 
 
 
 
 
cbd1959
 
a3cb7ba
cbd1959
 
 
 
 
 
a3cb7ba
 
 
 
 
 
 
cbd1959
 
ea8fa3f
 
bf967fb
a3cb7ba
 
 
 
 
 
 
 
cbd1959
 
 
 
 
e6f8dd1
a3cb7ba
 
e900098
a3cb7ba
 
 
e6f8dd1
a3cb7ba
e6f8dd1
 
 
a3cb7ba
397d798
 
cbd1959
a3cb7ba
 
cbd1959
a3cb7ba
 
cbd1959
a3cb7ba
 
cbd1959
6d1be3a
e6f8dd1
 
 
 
cbd1959
6d1be3a
cbd1959
 
 
 
a3cb7ba
77d4add
cbd1959
714de6d
a3cb7ba
cbd1959
a3cb7ba
e6f8dd1
 
 
 
 
 
 
cbd1959
a3cb7ba
 
 
 
 
 
 
 
 
 
 
714de6d
cbd1959
714de6d
cbd1959
 
 
714de6d
 
6d1be3a
a3cb7ba
cbd1959
e6f8dd1
 
714de6d
a3cb7ba
77d4add
532a4a4
e6f8dd1
 
6d1be3a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
# Adapted from https://github.com/TIGER-AI-Lab/MMLU-Pro/blob/main/evaluate_from_local.py
import csv
import json
import argparse
import os
import torch
import spaces
import random
import transformers
import time
import re
from vllm import LLM, SamplingParams
from tqdm import tqdm
import logging
import sys
from datasets import load_dataset
import pandas as pd
import numpy as np

# Send module log records through the root logger at INFO level.
logging.basicConfig(level=logging.INFO)

# Preamble for every few-shot prompt; "{$}" is substituted with the subject
# (category) name. Can be found at
# https://github.com/TIGER-AI-Lab/MMLU-Pro/blob/main/cot_prompt_lib/initial_prompt.txt
initial_prompt = "The following are multiple choice questions (with answers) about {$}. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice."

# Option letters, indexed by option position (0 -> "A", 1 -> "B", ...).
choices = list("ABCDEFGHIJKLMNOP")

# Token budget: prompts must be shorter than max_model_length - max_new_tokens
# so that generation still has room for max_new_tokens of output.
max_model_length = 4096
max_new_tokens = 2048


def preprocess(test_df):
    """Strip "N/A" placeholder options from every row.

    Each row yielded by iteration has its "options" entry replaced (in place)
    by a filtered list, and the rows are collected into a new list.

    Args:
        test_df: Iterable of dict-like rows, each holding an "options" sequence.

    Returns:
        list: The row objects, in iteration order, without "N/A" options.
    """
    cleaned_rows = []
    for row in test_df:
        row["options"] = [option for option in row["options"] if option != "N/A"]
        cleaned_rows.append(row)
    return cleaned_rows


def load_mmlu_pro():
    """Download MMLU-Pro and return its (test, validation) splits as DataFrames.

    Both splits go through ``preprocess`` (dropping "N/A" options) before being
    wrapped in ``pd.DataFrame``.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: (test_df, val_df).
    """
    dataset = load_dataset("TIGER-Lab/MMLU-Pro")
    raw_test, raw_val = dataset["test"], dataset["validation"]
    cleaned_splits = [preprocess(raw_test), preprocess(raw_val)]
    test_frame, val_frame = (pd.DataFrame(rows) for rows in cleaned_splits)
    return test_frame, val_frame


def load_model(model_name, gpu_utilization=0.8):
    """Create the vLLM engine, its sampling parameters, and the HF tokenizer.

    The engine is sharded across every visible CUDA device
    (``tensor_parallel_size``) and capped at ``max_model_length`` tokens.
    Sampling uses temperature 0, at most ``max_new_tokens`` new tokens, and
    stops on "Question:" (the marker that starts each few-shot item).

    Args:
        model_name: Model id or path, passed to both vLLM and ``AutoTokenizer``.
        gpu_utilization: Forwarded (as float) to vLLM's ``gpu_memory_utilization``.

    Returns:
        tuple: ``((llm, sampling_params), tokenizer)``.
    """
    engine = LLM(
        model=model_name,
        gpu_memory_utilization=float(gpu_utilization),
        tensor_parallel_size=torch.cuda.device_count(),
        max_model_len=max_model_length,
        trust_remote_code=True,
    )
    logging.info(f"Torch Device CUDA Count: {torch.cuda.device_count()}")
    greedy_params = SamplingParams(
        temperature=0,
        max_tokens=max_new_tokens,
        stop=["Question:"],
    )
    hf_tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    return (engine, greedy_params), hf_tokenizer


def format_cot_example(example, including_answer=True):
    """Render one MMLU-Pro item as a chain-of-thought prompt fragment.

    Layout: "Question:", the question, "Options:", then one "<letter>. <text>"
    line per option. With ``including_answer`` the worked CoT answer follows
    (its "A: Let's think step by step." prefix rewritten to "Answer: ...") plus
    a blank line; otherwise the fragment ends with an open
    "Answer: Let's think step by step." cue for the model to continue.

    Args:
        example: ``pd.Series`` or dict with "question" and "options" keys, and
            "cot_content" when ``including_answer`` is true.
        including_answer: Whether to append the worked answer.

    Returns:
        str: The formatted fragment.
    """
    # A pandas row is normalised to a plain dict so both input kinds share one path.
    if isinstance(example, pd.Series):
        example = example.to_dict()

    question_text = example["question"]
    option_texts = example["options"]

    pieces = ["Question:\n", question_text + "\n", "Options:\n"]
    pieces.extend(
        "{}. {}\n".format(choices[position], option)
        for position, option in enumerate(option_texts)
    )
    if including_answer:
        worked_answer = example["cot_content"].replace("A: Let's think step by step.",
                                                       "Answer: Let's think step by step.")
        pieces.append(worked_answer + "\n\n")
    else:
        pieces.append("Answer: Let's think step by step.")
    return "".join(pieces)


def generate_cot_prompt(val_df, curr, k):
    """
    Generate a k-shot chain-of-thought prompt for ``curr``.

    The prompt is ``initial_prompt`` with "{$}" replaced by curr's category,
    then up to ``k`` worked examples (in DataFrame order) from the rows of
    ``val_df`` sharing that category, then ``curr`` itself with an open
    answer cue.

    Args:
        val_df: DataFrame containing validation examples (needs a "category" column)
        curr: Series or dict representing current example (needs "category")
        k: Maximum number of examples to include (0 gives a zero-shot prompt)

    Returns:
        str: The complete prompt.
    """
    # Series and dict both support ["category"], so no type dispatch is needed
    # (the previous if/else had two identical branches).
    subject = curr["category"]

    # .head(k) keeps at most k examples; fewer are used if the category is short.
    shots = val_df[val_df["category"] == subject].head(k)

    parts = [initial_prompt.replace("{$}", subject) + "\n"]
    parts.extend(
        format_cot_example(example, including_answer=True)
        for _, example in shots.iterrows()
    )
    parts.append(format_cot_example(curr, including_answer=False))
    return "".join(parts)


def extract_answer(text):
    """Extract the predicted option letter (A-J) from a model completion.

    Strategies, in order: the instructed "answer is (X)" phrasing, then
    ``extract_again`` ("Answer: X"), then ``extract_final`` (last standalone
    capital letter A-J).

    Args:
        text: Generated completion text.

    Returns:
        str | None: The letter, or None when every strategy fails.
    """
    match = re.search(r"answer is \(?([A-J])\)?", text)
    if match:
        return match.group(1)
    # Report through logging (like the rest of this module) rather than print(),
    # so the message honours the configured log level and handlers.
    logging.warning("1st answer extract failed\n%s", text)
    return extract_again(text)


def extract_again(text):
    """Second-chance extractor: look for an ``Answer: X`` / ``answer: X`` label.

    Returns the letter (A-J) following the label, or defers to
    ``extract_final`` when no such label is found.
    """
    labelled = re.search(r'.*[aA]nswer:\s*([A-J])', text)
    return labelled.group(1) if labelled else extract_final(text)


def extract_final(text):
    """Last-resort extractor: return the final standalone capital letter A-J.

    ``re.DOTALL`` lets the negative lookahead scan across newlines, so the
    match is the last word-bounded A-J anywhere in ``text``.

    Returns:
        str | None: That letter, or None if ``text`` contains none.
    """
    last_letter = re.search(r"\b[A-J]\b(?!.*\b[A-J]\b)", text, re.DOTALL)
    if last_letter is None:
        return None
    return last_letter.group(0)


def batch_inference(llm, sampling_params, inference_batch, tokenizer):
    """Generate completions for a batch of prompts and extract answer letters.

    Args:
        llm: Engine exposing ``generate(prompts, sampling_params)``.
        sampling_params: Passed straight to ``llm.generate``.
        inference_batch: List of prompt strings.
        tokenizer: Unused here; kept for signature parity with
            ``batch_inference_debug_mode``.

    Returns:
        tuple[list, list]: (predicted letters or None, raw completion texts),
        both in the order of ``llm.generate``'s outputs.
    """
    began = time.time()
    generations = llm.generate(inference_batch, sampling_params)
    logging.info("Batch of size: %s. Time taken: %s", len(inference_batch), time.time() - began)

    # Only the first candidate of each request is scored.
    response_batch = [generation.outputs[0].text for generation in generations]
    pred_batch = [extract_answer(completion) for completion in response_batch]
    return pred_batch, response_batch

def batch_inference_debug_mode(llm, sampling_params, inference_batch, tokenizer):
    """Like ``batch_inference``, but also logs outputs and token statistics.

    Besides generating and extracting answers, this logs the prediction and
    response lists, a table of the first (up to) 40 prompt/response pairs,
    total and average token counts, and the requests with the largest and
    smallest input and output token counts. An empty batch skips the
    statistics (they are undefined) instead of crashing.

    Args:
        llm: Engine exposing ``generate(prompts, sampling_params)``.
        sampling_params: Passed straight to ``llm.generate``.
        inference_batch: List of prompt strings.
        tokenizer: Tokenizer whose ``encode`` is used to count tokens.

    Returns:
        tuple[list, list]: (predicted letters or None, raw completion texts).
    """
    start = time.time()
    outputs = llm.generate(inference_batch, sampling_params)
    logging.info("Batch of size: %s. Time taken: %s", len(inference_batch), time.time() - start)

    response_batch = []
    pred_batch = []
    input_token_counts = []
    output_token_counts = []
    for prompt, output in zip(inference_batch, outputs):
        generated_text = output.outputs[0].text
        response_batch.append(generated_text)
        pred_batch.append(extract_answer(generated_text))
        # Proper token count using tokenizer
        input_token_counts.append(len(tokenizer.encode(prompt)))
        output_token_counts.append(len(tokenizer.encode(generated_text)))

    logging.info("\n----------- PRED BATCH -----------\n%s", pred_batch)
    logging.info("\n----------- RESPONSE BATCH -----------\n%s", response_batch)

    # With no responses the averages would divide by zero and np.argmax /
    # np.argmin raise ValueError, so there is nothing meaningful to report.
    if not response_batch:
        logging.info("Empty batch: skipping token statistics.")
        return pred_batch, response_batch

    # Summary table of the first (up to) 40 requests.
    num_samples = min(40, len(response_batch))
    summary_df = pd.DataFrame({
        'Input': inference_batch[:num_samples],
        'Response': response_batch[:num_samples],
    })
    logging.info("\n----------- Summary of first %d requests and responses -----------\n%s",
                 num_samples, summary_df.to_string())

    total_input_tokens = sum(input_token_counts)
    total_output_tokens = sum(output_token_counts)
    logging.info("\n----------- Token Statistics -----------")
    logging.info("Total input tokens: %d", total_input_tokens)
    logging.info("Total output tokens: %d", total_output_tokens)
    logging.info("Average input tokens: %.2f", total_input_tokens / len(input_token_counts))
    logging.info("Average output tokens: %.2f", total_output_tokens / len(output_token_counts))

    # (label, per-request counts, index of the extreme request), logged in the
    # order: max input, max output, min input, min output.
    extremes = (
        ("max input", input_token_counts, int(np.argmax(input_token_counts))),
        ("max output", output_token_counts, int(np.argmax(output_token_counts))),
        ("min input", input_token_counts, int(np.argmin(input_token_counts))),
        ("min output", output_token_counts, int(np.argmin(output_token_counts))),
    )
    for label, counts, idx in extremes:
        logging.info(
            "\n----------- Request with %s tokens -----------\nIndex: %d (Tokens: %d)\nInput: %s\nOutput: %s",
            label, idx, counts[idx], inference_batch[idx], response_batch[idx],
        )

    return pred_batch, response_batch


def calculate_accuracy(res):
    """
    Calculate accuracy and return an array of correctness (1 if correct, 0 if wrong)
    along with the overall accuracy.

    Rows with a missing prediction (None, empty string, or NaN) are scored by a
    random guess over the row's options. The guess comes from a fresh
    ``random.Random(12345)`` per row, which yields exactly the draws the former
    ``random.seed(12345)`` call did, but without resetting the process-wide
    ``random`` state.

    Args:
        res: DataFrame with "pred", "answer", "answer_index" and "options" columns.

    Returns:
        tuple[list[int], float]: per-row 0/1 correctness in row order, and its
        mean (0.0 for an empty frame).
    """
    correctness = []

    for i, row in res.iterrows():
        logging.info("Processing row %s. Prediction: %s, Answer: %s", i, row.get('pred'), row.get('answer'))
        pred = row["pred"]
        # NaN is truthy, so it has to be detected explicitly alongside None/"".
        missing = not pred or (isinstance(pred, float) and np.isnan(pred))
        if missing:
            options = row["options"]
            # Fall back to 4 options when the cell is not a sized sequence.
            if isinstance(options, (list, tuple, np.ndarray)):
                options_len = len(options)
            else:
                options_len = 4
            if options_len <= 0:
                # Nothing to guess from; randint(0, -1) would raise.
                is_correct = 0
            else:
                guess = random.Random(12345).randint(0, options_len - 1)
                is_correct = 1 if guess == row["answer_index"] else 0
        else:
            is_correct = 1 if pred == row["answer"] else 0
        correctness.append(is_correct)

    if not correctness:
        return [], 0.0
    return correctness, sum(correctness) / len(correctness)


def _fit_cot_prompt(val_df, curr, num_shots, tokenizer, token_budget):
    """Return curr's prompt with the most shots (<= num_shots) shorter than token_budget tokens."""
    for k in range(num_shots, 0, -1):
        prompt = generate_cot_prompt(val_df, curr, k)
        # Only the length matters, so tokenize without building tensors or
        # moving them to the GPU (the old code called .cuda() just to count).
        if len(tokenizer(prompt)["input_ids"]) < token_budget:
            return prompt
    # If we couldn't fit any examples (or num_shots <= 0), use just the test question.
    return generate_cot_prompt(val_df.head(0), curr, 0)


@torch.no_grad()
def eval_cot(subject, model, tokenizer, val_df, test_df, num_shots=5, debug_mode=False):
    """
    Evaluate model using chain-of-thought prompting.

    For each test row the largest k <= num_shots whose prompt is shorter than
    ``max_model_length - max_new_tokens`` tokens is used; if no k >= 1 fits,
    a zero-shot prompt is used. All prompts are generated as one batch.

    Args:
        subject: Subject category being evaluated (used for logging)
        model: Tuple of (llm, sampling_params)
        tokenizer: Model tokenizer (callable returning a mapping with "input_ids")
        val_df: DataFrame with validation examples
        test_df: DataFrame with test examples
        num_shots: Number of examples to include in prompt
        debug_mode: If true, use ``batch_inference_debug_mode`` for verbose logging

    Returns:
        tuple[list[int], float]: (correctness, accuracy) from ``calculate_accuracy``.
    """
    llm, sampling_params = model
    logging.info("evaluating " + subject)

    # Nothing to evaluate: skip the model call entirely.
    if test_df.empty:
        logging.info("No test examples for %s; skipping.", subject)
        return [], 0.0

    # Leave room for max_new_tokens of generation inside the model's context.
    token_budget = max_model_length - max_new_tokens
    inference_batches = [
        _fit_cot_prompt(val_df, curr, num_shots, tokenizer, token_budget)
        for _, curr in test_df.iterrows()
    ]

    batch_fn = batch_inference_debug_mode if debug_mode else batch_inference
    pred_batch, response_batch = batch_fn(llm, sampling_params, inference_batches, tokenizer)

    # Add predictions to a copy so the caller's DataFrame is not mutated.
    results_df = test_df.copy()
    results_df["pred"] = pred_batch
    results_df["model_outputs"] = response_batch

    correctness, accuracy = calculate_accuracy(results_df)
    logging.info("This batch accuracy is: {}, correct samples: {}/{}\n".format(
        str(accuracy), str(sum(correctness)), str(len(correctness))))

    return correctness, accuracy

def evaluate_mmlu_pro(model_name, num_subjects=-1, num_questions=10, num_shots=5,
                      specific_subjects=None, flash_attention=False, regex_pattern=None,
                      debug_mode=False):
    """
    Main evaluation function for MMLU-Pro benchmark.

    Args:
        model_name: Name/path of model to evaluate
        num_subjects: Number of subjects to test (-1 for all)
        num_questions: Number of questions per subject (-1 for all)
        num_shots: Number of examples to include in prompts
        specific_subjects: List of specific subjects to evaluate (overrides num_subjects);
            names absent from the test split are logged and skipped
        flash_attention: Whether to use flash attention (currently ignored)
        regex_pattern: Regex pattern for answer extraction (currently ignored)
        debug_mode: Forwarded to ``eval_cot``; enables verbose per-batch logging

    Returns:
        dict with keys:
            "overall_accuracy": accuracy over all evaluated questions pooled
                together (0.0 if no question was evaluated),
            "min_accuracy_subject" / "max_accuracy_subject": (subject, accuracy),
            "full_accuracy_table": list of per-subject result dicts.

    Raises:
        ValueError: If the selection leaves no subject to evaluate.
    """
    cuda_available = torch.cuda.is_available()
    print(f"Is CUDA available: {cuda_available}")
    # get_device_name raises without a CUDA device, so only query it when present.
    if cuda_available:
        print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

    # Load and order the data, and resolve the subject selection, before the
    # expensive model load so that an empty selection fails fast.
    test_df, val_df = load_mmlu_pro()
    test_df = test_df.sort_values(['category', 'question_id'])
    val_df = val_df.sort_values(['category', 'question_id'])
    all_subjects = sorted(test_df['category'].unique())

    if specific_subjects is not None:
        selected_subjects = [s for s in specific_subjects if s in all_subjects]
        unknown_subjects = [s for s in specific_subjects if s not in all_subjects]
        if unknown_subjects:
            logging.warning("Ignoring subjects not found in MMLU-Pro: %s", unknown_subjects)
    elif num_subjects == -1 or num_subjects >= len(all_subjects):
        selected_subjects = all_subjects
    else:
        selected_subjects = all_subjects[:num_subjects]

    if not selected_subjects:
        raise ValueError("No MMLU-Pro subjects selected for evaluation.")

    logging.info("selected subjects:\n" + "\n".join(selected_subjects))

    model, tokenizer = load_model(model_name, gpu_utilization=0.8)

    results = {}
    all_correctness = []
    results_table = []
    for subject in tqdm(selected_subjects, desc="Processing Selected Categories"):
        subject_tests = test_df[test_df['category'] == subject]
        # num_questions == -1 means "use every question of this subject".
        test_samples = subject_tests if num_questions == -1 else subject_tests.head(num_questions)
        val_samples = val_df[val_df['category'] == subject].head(num_shots)

        correctness, acc = eval_cot(
            subject,
            model,
            tokenizer,
            val_df=val_samples,
            test_df=test_samples,
            num_shots=num_shots,
            debug_mode=debug_mode,
        )

        results[subject] = acc
        all_correctness.extend(correctness)
        results_table.append({
            'Subject': subject,
            'Num_samples': len(test_samples),
            'Num_correct': sum(correctness),
            'Accuracy': acc,
        })

    # Question-weighted accuracy; np.mean([]) would be NaN (with a RuntimeWarning).
    weighted_acc = np.mean(all_correctness) if all_correctness else 0.0
    min_acc_subject = min(results.items(), key=lambda x: x[1])
    max_acc_subject = max(results.items(), key=lambda x: x[1])

    return {
        "overall_accuracy": weighted_acc,
        "min_accuracy_subject": min_acc_subject,
        "max_accuracy_subject": max_acc_subject,
        "full_accuracy_table": results_table,
    }