rohansampath committed (verified)
Commit 00afad7
1 Parent(s): bd9ca6e

Update mmlu_eval.py

Files changed (1)
  1. mmlu_eval.py +48 -31
mmlu_eval.py CHANGED
@@ -1,63 +1,78 @@
 import torch
-import random
 import evaluate
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import spaces
 
-# Load Accuracy Metric
 accuracy_metric = evaluate.load("accuracy")
-
-# Load MMLU dataset
 mmlu_dataset = load_dataset("cais/mmlu", "all")
 
+def format_mmlu_prompt(question, choices):
+    """
+    Formats the prompt according to Mistral's official instruction format.
+    Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3
+    """
+    formatted_choices = "\n".join([f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)])
+    prompt = f"""<s>[INST] You are taking a multiple choice test. Select the correct answer by responding with only the letter (A, B, C, or D) of the correct choice.
+
+Question: {question}
+
+Choices:
+{formatted_choices} [/INST]"""
+    return prompt
+
 @spaces.GPU
-def generate_answer(model, tokenizer, question):
+def generate_answer(model, tokenizer, question, choices):
     """
-    Generates an answer using Mistral's instruction format.
+    Generates an answer using Mistral's instruction format for multiple choice questions.
     """
-    prompt = f"<s>[INST] {question}. Provide only the correct answer. [/INST]"
+    prompt = format_mmlu_prompt(question, choices)
 
     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=50,
-            temperature=0.0,
+            max_new_tokens=5,  # We only need a single letter
+            do_sample=False,  # Use deterministic greedy decoding
+            num_beams=1,  # Use simple greedy search
             pad_token_id=tokenizer.pad_token_id,
             eos_token_id=tokenizer.eos_token_id
         )
-    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+    # Extract just the letter answer
+    for char in response:
+        if char in 'ABCD':
+            return char
+    return response[:1]  # Fallback: take first character
 
 def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
     """
-    Evaluates the model on MMLU across all 57 tasks.
-
-    Returns:
-    - Overall accuracy
-    - Min accuracy task
-    - Max accuracy task
-    - Two correct examples
-    - Two incorrect examples
+    Evaluates the model on MMLU across all tasks.
     """
     results = {}
     correct_examples = []
     incorrect_examples = []
-
-    for task_name in mmlu_dataset.keys():
-        print ("TASK NAME: ", task_name)
+
+    # Filter out 'auxiliary_train' and other non-test splits
+    test_tasks = [k for k in mmlu_dataset.keys() if 'test' in k]
+
+    for task_name in sorted(test_tasks):  # Sort tasks for deterministic order
         dataset = mmlu_dataset[task_name]
-        sampled_questions = random.sample(list(dataset), min(num_questions_per_task, len(dataset)))
+        # Instead of random sampling, take the first n questions
+        total_questions = min(num_questions_per_task, len(dataset))
+        sampled_questions = [dataset[i] for i in range(total_questions)]
 
         predictions = []
         references = []
 
         for sample in sampled_questions:
-            print ("SAMPLE", sample)
             question = sample["question"]
-            correct_answer = str(sample["answer"]).strip().lower()
-            model_output = generate_answer(model, tokenizer, question).strip().lower()
-
+            choices = [sample["choices"][i] for i in range(4)]
+            # Convert numeric answer to letter (0->A, 1->B, etc.)
+            correct_answer = chr(65 + sample["answer"])
+
+            model_output = generate_answer(model, tokenizer, question, choices)
+
             predictions.append(model_output)
             references.append(correct_answer)
 
@@ -68,10 +83,11 @@ def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
                incorrect_examples.append((task_name, question, model_output, correct_answer))
 
         # Compute accuracy for the task
-        norm_preds = [str(p).lower().strip() for p in predictions]
-        norm_refs = [str(r).lower().strip() for r in references]
-        task_accuracy = accuracy_metric.compute(predictions=norm_preds, references=norm_refs)["accuracy"]
-
+        task_accuracy = accuracy_metric.compute(
+            predictions=predictions,
+            references=references
+        )["accuracy"]
+
         results[task_name] = task_accuracy
 
     # Compute overall statistics
@@ -85,4 +101,5 @@ def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
         "max_accuracy_task": (max_task, results[max_task]),
         "correct_examples": correct_examples,
         "incorrect_examples": incorrect_examples,
-    }
+        "all_results": results  # Added for detailed analysis
+    }
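
For reference, a minimal sketch of how the updated evaluator might be driven. The checkpoint name follows the Mistral model referenced in format_mmlu_prompt's docstring; the loading details (half precision, device placement) and the standalone-script framing are assumptions, not part of this commit, and only result keys visible in the diff are used.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from mmlu_eval import evaluate_mmlu

# Hypothetical driver, not part of this commit. Assumes it runs inside the
# Space (the `spaces` package and a CUDA device are available) and that the
# model is the checkpoint named in format_mmlu_prompt's docstring.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,  # assumption: fp16 to fit a single GPU
).to("cuda")

# 5 questions per task, matching the function's default
results = evaluate_mmlu(model, tokenizer, num_questions_per_task=5)

# Keys below are the ones visible in this diff
print("Best task:", results["max_accuracy_task"])
for task, accuracy in sorted(results["all_results"].items()):
    print(f"{task}: {accuracy:.2f}")
for task, question, predicted, expected in results["incorrect_examples"][:2]:
    print(f"[{task}] predicted {predicted}, expected {expected}: {question}")

Because the updated evaluate_mmlu takes the first num_questions_per_task examples of each sorted test split and generate_answer decodes greedily, repeated runs of a driver like this should produce identical scores.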