import random

import torch
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM


accuracy_metric = evaluate.load("accuracy")

mmlu_dataset = load_dataset("cais/mmlu", "all")


def generate_answer(model, tokenizer, question, choices):
    """
    Generates an answer using Mistral's instruction format.

    The four answer choices are listed in the prompt and the model is asked
    to reply with a single letter (A-D).
    """
    options = "\n".join(f"{letter}. {choice}" for letter, choice in zip("ABCD", choices))
    # The tokenizer adds the <s> BOS token itself, so it is not repeated in the prompt.
    prompt = (
        f"[INST] {question}\n{options}\n"
        "Answer with only the letter of the correct choice. [/INST]"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=8,  # a single letter is enough
            do_sample=False,   # greedy decoding instead of temperature=0.0
            pad_token_id=tokenizer.eos_token_id,  # Mistral's tokenizer defines no pad token
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens; decoding outputs[0] in full would
    # return the prompt text as well.
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()


def evaluate_mmlu(model, tokenizer, num_questions_per_task=5):
    """
    Evaluates the model on MMLU across all 57 tasks (subjects).

    Returns:
    - Overall accuracy (macro-averaged over tasks)
    - Min accuracy task
    - Max accuracy task
    - Two correct examples
    - Two incorrect examples
    """
    results = {}
    correct_examples = []
    incorrect_examples = []

    # The "all" config stores every subject in a single split, so the test
    # questions are grouped by their "subject" column to recover the 57 tasks.
    tasks = {}
    for sample in mmlu_dataset["test"]:
        tasks.setdefault(sample["subject"], []).append(sample)

    for task_name, task_samples in tasks.items():
        print("TASK NAME", task_name)
        sampled_questions = random.sample(task_samples, min(num_questions_per_task, len(task_samples)))

        predictions = []
        references = []

        for sample in sampled_questions:
            print("SAMPLE", sample)
            question = sample["question"]
            choices = sample["choices"]
            correct_idx = int(sample["answer"])  # the gold label is an index into `choices`

            model_output = generate_answer(model, tokenizer, question, choices)
            # Map the model's reply (expected to start with A-D) back to a choice
            # index; anything unparseable becomes -1 and counts as wrong.
            predicted_idx = {"A": 0, "B": 1, "C": 2, "D": 3}.get(model_output[:1].upper(), -1)

            predictions.append(predicted_idx)
            references.append(correct_idx)

            if predicted_idx == correct_idx and len(correct_examples) < 2:
                correct_examples.append((task_name, question, model_output, choices[correct_idx]))
            elif predicted_idx != correct_idx and len(incorrect_examples) < 2:
                incorrect_examples.append((task_name, question, model_output, choices[correct_idx]))

        # The accuracy metric expects integer labels, so choice indices are compared directly.
        task_accuracy = accuracy_metric.compute(predictions=predictions, references=references)["accuracy"]
        results[task_name] = task_accuracy

    overall_accuracy = sum(results.values()) / len(results)
    min_task = min(results, key=results.get)
    max_task = max(results, key=results.get)

    return {
        "overall_accuracy": overall_accuracy,
        "min_accuracy_task": (min_task, results[min_task]),
        "max_accuracy_task": (max_task, results[max_task]),
        "correct_examples": correct_examples,
        "incorrect_examples": incorrect_examples,
    }
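

# Minimal usage sketch. The checkpoint name below is a placeholder assumption;
# any Mistral-style instruct model that follows the [INST] prompt format should
# work, and a CUDA device is assumed, matching generate_answer above.
if __name__ == "__main__":
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")
    model.eval()

    report = evaluate_mmlu(model, tokenizer, num_questions_per_task=5)
    print("Overall accuracy:", report["overall_accuracy"])
    print("Weakest task:", report["min_accuracy_task"])
    print("Strongest task:", report["max_accuracy_task"])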