import torch
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

accuracy_metric = evaluate.load("accuracy")

def load_dataset_from_hf(verbose=False):
    """
    Loads the MMLU benchmark ("all" configuration) from the Hugging Face Hub.
    When verbose is True, logs the size, schema, and a few sample rows of each split.
    """
    mmlu_dataset = load_dataset("cais/mmlu", "all")
    if verbose:
        for split in mmlu_dataset.keys():
            dataset = mmlu_dataset[split]  # Access the dataset split
            
            # Log number of rows and columns
            num_rows = len(dataset)
            num_cols = len(dataset.column_names)
            
            logger.info(f"Dataset Split: {split}")
            logger.info(f"Number of Rows: {num_rows}")
            logger.info(f"Number of Columns: {num_cols}")
            
            # Log column names and their types
            column_types = {col: str(dataset.features[col].dtype) for col in dataset.column_names}
            logger.info(f"Column Names: {dataset.column_names}")
            logger.info(f"Column Types: {column_types}")
        
            # Log a sample of 5 rows
            sample_rows = dataset.select(range(min(5, num_rows)))  # Ensure we don't exceed available rows
            logger.info("Sample Rows:")
            for row in sample_rows:
                logger.info(row)
        
            logger.info("=" * 50)  # Separator for readability
    return mmlu_dataset
        

def format_mmlu_prompt(question, choices):
    """
    Formats the prompt according to Mistral's official instruction format.
    Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3
    """
    formatted_choices = "\n".join([f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)])
    prompt = f"""<s>[INST] You are taking a multiple choice test. Select the correct answer by responding with only the letter (A, B, C, or D) of the correct choice.

Question: {question}

Choices:
{formatted_choices} [/INST]"""
    return prompt
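
# For illustration only (hypothetical inputs, not taken from the dataset), a call
# like format_mmlu_prompt("What is 2 + 2?", ["3", "4", "5", "6"]) renders to:
#
#   [INST] You are taking a multiple choice test. Select the correct answer by responding with only the letter (A, B, C, or D) of the correct choice.
#
#   Question: What is 2 + 2?
#
#   Choices:
#   A. 3
#   B. 4
#   C. 5
#   D. 6 [/INST]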

@spaces.GPU
def generate_answer(model, tokenizer, question, choices):
    """
    Generates an answer using Mistral's instruction format for multiple choice questions.
    """
    prompt = format_mmlu_prompt(question, choices)
    
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,  # We only need a single letter
            do_sample=False,   # Use deterministic greedy decoding
            num_beams=1,       # Use simple greedy search
            # The Mistral tokenizer typically has no pad token set, so fall back to EOS.
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens; decoding the full sequence would
    # re-include the prompt, whose text already contains the letters A-D.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    # Extract just the letter answer
    for char in response:
        if char in 'ABCD':
            return char
    return response[:1]  # Fallback: take the first character (empty string if nothing was generated)

@torch.no_grad()
def evaluate_mmlu(model, tokenizer, num_questions=5):
    """
    Evaluates the model on MMLU, scoring the first `num_questions` questions of
    each test split and reporting per-task and overall accuracy statistics.
    """
    mmlu_dataset = load_dataset_from_hf(verbose=True)
    results = {}
    correct_examples = []
    incorrect_examples = []
    
    # Filter out 'auxiliary_train' and other non-test splits
    test_tasks = [k for k in mmlu_dataset.keys() if 'test' in k]
    
    for task_name in sorted(test_tasks):  # Sort tasks for deterministic order
        dataset = mmlu_dataset[task_name]
        # Instead of random sampling, take the first n questions
        total_questions = min(num_questions, len(dataset))
        sampled_questions = [dataset[i] for i in range(total_questions)]

        predictions = []
        references = []

        for sample in sampled_questions:
            print ("TASK", task_name, "Sample", sample)
            question = sample["question"]
            choices = [sample["choices"][i] for i in range(4)]
            # Convert numeric answer to letter (0->A, 1->B, etc.)
            correct_answer = chr(65 + sample["answer"])
            print ("question:", question, "\n choices:", choices, "\n correct answer:", correct_answer)
            
            model_output = generate_answer(model, tokenizer, question, choices)
            print ("model output:", model_output)
            
            predictions.append(model_output)
            references.append(correct_answer)

            # Store examples
            if model_output == correct_answer and len(correct_examples) < 2:
                correct_examples.append((task_name, question, model_output, correct_answer))
            elif model_output != correct_answer and len(incorrect_examples) < 2:
                incorrect_examples.append((task_name, question, model_output, correct_answer))

        # Compute accuracy for the task; the `accuracy` metric expects integer
        # labels, so map the answer letters back to choice indices first.
        letter_to_index = {"A": 0, "B": 1, "C": 2, "D": 3}
        task_accuracy = accuracy_metric.compute(
            predictions=[letter_to_index.get(p, -1) for p in predictions],
            references=[letter_to_index[r] for r in references]
        )["accuracy"]
        
        results[task_name] = task_accuracy

    # Compute overall statistics
    overall_accuracy = sum(results.values()) / len(results)
    min_task = min(results, key=results.get)
    max_task = max(results, key=results.get)

    return {
        "overall_accuracy": overall_accuracy,
        "min_accuracy_task": (min_task, results[min_task]),
        "max_accuracy_task": (max_task, results[max_task]),
        "correct_examples": correct_examples,
        "incorrect_examples": incorrect_examples,
        "all_results": results  # Added for detailed analysis
    }
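

# Minimal usage sketch (an assumption, not part of the original app wiring): load
# the Mistral-7B-Instruct-v0.3 checkpoint referenced above and run the evaluation
# on a handful of questions per task. Requires a CUDA device and access to the
# gated model repo.
if __name__ == "__main__":
    model_name = "mistralai/Mistral-7B-Instruct-v0.3"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )
    model.eval()

    report = evaluate_mmlu(model, tokenizer, num_questions=5)
    logger.info(f"Overall accuracy: {report['overall_accuracy']:.3f}")
    logger.info(f"Lowest-scoring task: {report['min_accuracy_task']}")
    logger.info(f"Highest-scoring task: {report['max_accuracy_task']}")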