rohansampath committed · verified
Commit 6d1be3a · 1 parent: 4105791

Update mmlu_eval_original.py

Files changed (1): mmlu_eval_original.py (+31 −10)
mmlu_eval_original.py CHANGED
@@ -153,9 +153,16 @@ def eval (subject, model, tokenizer, dev_df, test_df, num_questions_per_subject=
 
 def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=5):
     """
-    Evaluates the model on MMLU across all subjects.
+    Evaluates the model on MMLU across specified number of subjects.
+
+    Args:
+        model: The model to evaluate
+        tokenizer: The tokenizer to use
+        num_subjects (int): Number of subjects to evaluate. If -1, evaluates all subjects
+        num_questions (int): Number of questions per subject
+        num_shots (int): Number of few-shot examples to use
     """
-    model.eval() # Ensure Dropout and BatchNorm behave appropriately for inference.
+    model.eval() # Ensure Dropout and BatchNorm behave appropriately for inference
 
     dataset = load_dataset_from_hf(verbose=True)
 
@@ -167,12 +174,17 @@ def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=
     test_df = test_df.sort_values(['subject', 'question'])
     dev_df = dev_df.sort_values(['subject', 'question'])
 
-    subjects = sorted(test_df['subject'].unique())
+    # Get all unique subjects
+    all_subjects = sorted(test_df['subject'].unique())
+
+    # Select subjects based on num_subjects parameter
+    if num_subjects == -1 or num_subjects >= len(all_subjects):
+        subjects = all_subjects
+    else:
+        # Take the first num_subjects subjects
+        subjects = all_subjects[:num_subjects]
 
     results = {}
-    correct_examples = []
-    incorrect_examples = []
-    all_accuracies = []
     all_cors = []
     results_table = []
 
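The new selection branch takes a deterministic prefix of the alphabetically sorted subject list. A minimal standalone sketch of the same logic, using an invented `select_subjects` helper and toy subject names rather than the real MMLU list:

```python
def select_subjects(all_subjects, num_subjects=-1):
    # Mirrors the committed branch: -1 (or a count at/above the total)
    # selects everything; otherwise take the first num_subjects entries.
    if num_subjects == -1 or num_subjects >= len(all_subjects):
        return all_subjects
    return all_subjects[:num_subjects]

toy = sorted(["anatomy", "astronomy", "econometrics", "virology"])
print(select_subjects(toy))       # all four
print(select_subjects(toy, 2))    # ['anatomy', 'astronomy']
print(select_subjects(toy, 99))   # all four (count clamped to total)
```

One edge worth noting: a value below -1 (say, -2) falls into the slice branch, so `all_subjects[:-2]` silently drops the last two subjects instead of raising.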
@@ -183,7 +195,16 @@ def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=
         # Log subject and sample counts
         logger.info(f"Subject: {subject}, Test Samples: {len(test_samples)}, Dev Samples: {len(dev_samples)}")
 
-        cors, acc, probs = eval(subject, model, tokenizer, dev_samples, test_samples, num_questions_per_subject=num_questions, train_shots=num_shots)
+        cors, acc, probs = eval(
+            subject,
+            model,
+            tokenizer,
+            dev_samples,
+            test_samples,
+            num_questions_per_subject=num_questions,
+            train_shots=num_shots
+        )
+
         results[subject] = acc
         all_cors.append(cors)
 
@@ -193,7 +214,7 @@ def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=
             'Num_correct': int(np.sum(cors)),
             'Accuracy': acc
         })
-
+
     weighted_acc = np.mean(np.concatenate(all_cors))
 
     min_acc_subject = min(results.items(), key=lambda x: x[1])[0]
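The unchanged aggregation line `weighted_acc = np.mean(np.concatenate(all_cors))` computes a question-weighted (micro) accuracy, not a mean of per-subject accuracies, so subjects with more questions count for more. A toy illustration with invented correctness vectors:

```python
import numpy as np

all_cors = [
    np.array([1, 1, 1, 1]),  # subject A: 4 questions, all correct
    np.array([0, 0]),        # subject B: 2 questions, none correct
]

micro = np.mean(np.concatenate(all_cors))      # 4 correct of 6 -> 0.667
macro = np.mean([c.mean() for c in all_cors])  # (1.0 + 0.0) / 2 -> 0.5
print(micro, macro)
```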
@@ -203,5 +224,5 @@ def evaluate_mmlu(model, tokenizer, num_subjects=-1, num_questions=5, num_shots=
         "overall_accuracy": weighted_acc,
         "min_accuracy_subject": (min_acc_subject, results[min_acc_subject]),
         "max_accuracy_subject": (max_acc_subject, results[max_acc_subject]),
-        "full_accuracy_table": results_table
-    }
+        "full_accuracy_table": results_table,
+    }
 
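For reference, a hypothetical call site for the updated function. The model and tokenizer setup is assumed (it is not part of this commit), and only the return keys visible in the diff are used:

```python
# Assumes `model` and `tokenizer` are already loaded; evaluate_mmlu is
# the function updated in this commit.
results = evaluate_mmlu(model, tokenizer, num_subjects=3, num_questions=5, num_shots=5)

print(f"Overall accuracy: {results['overall_accuracy']:.3f}")

subject, acc = results['min_accuracy_subject']
print(f"Weakest subject: {subject} ({acc:.3f})")

subject, acc = results['max_accuracy_subject']
print(f"Strongest subject: {subject} ({acc:.3f})")

for row in results['full_accuracy_table']:
    print(row)  # per-subject dict with 'Num_correct' and 'Accuracy' fields
```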