import os

import gradio as gr
import spaces  # Hugging Face ZeroGPU helper; provides the @spaces.GPU decorator
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM

# Local evaluation helpers (module filenames must use underscores, not hyphens,
# to be importable).
from toy_dataset_eval import evaluate_toy_dataset
from mmlu_eval import evaluate_mmlu
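
# Functions decorated with @spaces.GPU below are allocated a GPU only for the
# duration of each call; `duration=` raises the per-call time limit.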

# Log in to the Hugging Face Hub so gated model weights can be downloaded.
hf_token = os.getenv("HF_TOKEN_READ_WRITE")
if hf_token:
    login(hf_token)
else:
    print("⚠️ No HF_TOKEN_READ_WRITE found in environment")

# Lazily-initialized globals: the model and tokenizer are loaded on first use
# so the Space starts quickly.
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = None
model = None
model_loaded = False


@spaces.GPU
def load_model():
    """Loads the Mistral model and tokenizer and updates the load status."""
    global tokenizer, model, model_loaded
    try:
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        if model is None:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=hf_token,
                torch_dtype=torch.float16,
            )
            model.to("cuda")
        model_loaded = True
        return "✅ Model Loaded!"
    except Exception as e:
        model_loaded = False
        return f"❌ Model Load Failed: {str(e)}"


@spaces.GPU(duration=120)
def run_toy_evaluation():
    """Runs the toy dataset evaluation."""
    if not model_loaded:
        load_model()
    if not model_loaded:
        # Return a value for each of the two output components wired up below.
        return "⚠️ Model not loaded. Please load the model first.", ""

    results = evaluate_toy_dataset(model, tokenizer)
    return results
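
# Assumption: evaluate_toy_dataset returns a (summary_text, details_html) pair,
# matching the Textbox and HTML output components defined in the UI below.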


@spaces.GPU(duration=120)
def run_mmlu_evaluation(num_questions):
    """
    Runs the MMLU evaluation with the specified number of questions per task.
    Also displays two correct and two incorrect examples.
    """
    if not model_loaded:
        load_model()
    if not model_loaded:
        return "⚠️ Model not loaded. Please load the model first."

    results = evaluate_mmlu(model, tokenizer, num_questions)

    overall_accuracy = results["overall_accuracy"]
    min_task, min_acc = results["min_accuracy_task"]
    max_task, max_acc = results["max_accuracy_task"]
    correct_examples = results["correct_examples"]
    incorrect_examples = results["incorrect_examples"]

    def format_example(example):
        task, question, model_output, correct_answer = example
        return (
            f"**Task:** {task}\n"
            f"**Question:** {question}\n"
            f"**Model Output:** {model_output}\n"
            f"**Correct Answer:** {correct_answer}\n"
        )

    correct_text = "\n\n".join(format_example(ex) for ex in correct_examples)
    incorrect_text = "\n\n".join(format_example(ex) for ex in incorrect_examples)

    report = (
        f"### Overall Accuracy: {overall_accuracy:.2f}\n"
        f"**Min Accuracy:** {min_acc:.2f} on `{min_task}`\n"
        f"**Max Accuracy:** {max_acc:.2f} on `{max_task}`\n\n"
        "---\n\n"
        f"### ✅ Correct Examples\n{correct_text if correct_examples else 'No correct examples available.'}\n\n"
        f"### ❌ Incorrect Examples\n{incorrect_text if incorrect_examples else 'No incorrect examples available.'}"
    )

    return report
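
# MMLU covers 57 tasks, so `num_questions` questions are drawn from each task
# (e.g. the default of 5 per task gives 285 questions in total).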


with gr.Blocks() as demo:
    gr.Markdown("# Mistral-7B Math Evaluation Demo")
    gr.Markdown("This demo evaluates Mistral-7B on various datasets.")

    load_button = gr.Button("Load Model", variant="primary")
    load_status = gr.Textbox(label="Model Status", interactive=False)
    load_button.click(fn=load_model, inputs=None, outputs=load_status)

    gr.Markdown("### Toy Dataset Evaluation")
    eval_button = gr.Button("Run Evaluation", variant="primary")
    output_text = gr.Textbox(label="Results")
    output_plot = gr.HTML(label="Visualization and Details")
    eval_button.click(fn=run_toy_evaluation, inputs=None, outputs=[output_text, output_plot])

    gr.Markdown("### MMLU Evaluation")
    num_questions_input = gr.Number(label="Questions per Task (Total of 57 tasks)", value=5, precision=0)
    eval_mmlu_button = gr.Button("Run MMLU Evaluation", variant="primary")
    mmlu_output = gr.Textbox(label="MMLU Evaluation Results")
    eval_mmlu_button.click(fn=run_mmlu_evaluation, inputs=[num_questions_input], outputs=[mmlu_output])

demo.launch()