import base64
import html
import io
import re

import evaluate
import matplotlib.pyplot as plt
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import spaces
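
# Toy evaluation harness for a Mistral-style instruct model on Hugging Face
# Spaces (ZeroGPU). `model` and `tokenizer` are assumed to be loaded elsewhere
# in the app. The three-question arithmetic set below is a smoke test for the
# pipeline, not a meaningful benchmark.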
test_data = [
    {"question": "What is 2+2?", "answer": "4"},
    {"question": "What is 3*3?", "answer": "9"},
    {"question": "What is 10/2?", "answer": "5"},
]
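
# The `accuracy` metric compares labels exactly (and its default config
# expects integer-like values), so both sides are normalized before scoring.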
accuracy_metric = evaluate.load("accuracy")


@spaces.GPU
def generate_answer(question, model, tokenizer):
    """
    Generates an answer using Mistral's instruction format.
    """
    # The tokenizer inserts the BOS token itself, so the prompt must not also
    # contain a literal "<s>" (that would double the BOS).
    prompt = f"[INST] {question} Provide only the numerical answer. [/INST]"

    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Mistral tokenizers usually define no pad token; fall back to EOS.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            pad_token_id=pad_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens. Slicing by prompt length is more
    # robust than str.replace(question, ...), which leaves the
    # "[INST] ... [/INST]" wrapper in the output and breaks whenever the
    # decoded text doesn't contain the question verbatim.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
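
# Example (assuming a Mistral-7B-Instruct-style checkpoint is loaded):
#   generate_answer("What is 2+2?", model, tokenizer)  # -> e.g. "4"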


def parse_answer(model_output):
    """
    Extract numeric answer from model's text output.
    """
    match = re.search(r"(-?\d*\.?\d+)", model_output)
    if match:
        return match.group(1)
    return model_output.strip()
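
# e.g. parse_answer("The answer is 4.")  -> "4"
#      parse_answer("roughly -2.5")      -> "-2.5"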


# duration=120 requests a longer ZeroGPU slot, since the loop below runs one
# generation per sample.
@spaces.GPU(duration=120)
def evaluate_toy_dataset(model, tokenizer):
    predictions = []
    references = []
    raw_outputs = []

    for sample in test_data:
        question = sample["question"]
        reference_answer = sample["answer"]

        model_output = generate_answer(question, model, tokenizer)
        predicted_answer = parse_answer(model_output)

        predictions.append(predicted_answer)
        references.append(reference_answer)
        raw_outputs.append({
            "question": question,
            "model_output": model_output,
            "parsed_answer": predicted_answer,
            "reference": reference_answer
        })
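
    # Normalize both sides so pure formatting differences ("5.0" vs "5")
    # are not scored as errors.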
    def normalize_answer(ans):
        s = str(ans).lower().strip()
        try:
            # Canonicalize numbers: the `accuracy` metric expects integer-like
            # labels, and "5.0" should match the reference "5".
            value = float(s)
            return str(int(value)) if value.is_integer() else s
        except ValueError:
            return s

    norm_preds = [normalize_answer(p) for p in predictions]
    norm_refs = [normalize_answer(r) for r in references]

    results = accuracy_metric.compute(predictions=norm_preds, references=norm_refs)
    accuracy = results["accuracy"]

    # Summarize correct vs. incorrect counts as a bar chart.
    fig, ax = plt.subplots(figsize=(8, 6))
    correct_count = sum(p == r for p, r in zip(norm_preds, norm_refs))
    incorrect_count = len(test_data) - correct_count

    bars = ax.bar(["Correct", "Incorrect"],
                  [correct_count, incorrect_count],
                  color=["#2ecc71", "#e74c3c"])

    # Annotate each bar with its count.
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height,
                f'{int(height)}',
                ha='center', va='bottom')

    ax.set_title("Evaluation Results")
    ax.set_ylabel("Count")
    ax.set_ylim([0, len(test_data) + 0.5])

    # Serialize the figure to a base64-encoded PNG so it can be embedded
    # directly in the returned HTML.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches='tight', dpi=300)
    buf.seek(0)
    plt.close(fig)
    data = base64.b64encode(buf.read()).decode("utf-8")

    # Per-question breakdown as an HTML table.
    details_html = """
    <div style="margin-top: 20px;">
        <h3>Detailed Results:</h3>
        <table style="width:100%; border-collapse: collapse;">
            <tr style="background-color: #f5f5f5;">
                <th style="padding: 8px; border: 1px solid #ddd;">Question</th>
                <th style="padding: 8px; border: 1px solid #ddd;">Model Output</th>
                <th style="padding: 8px; border: 1px solid #ddd;">Parsed Answer</th>
                <th style="padding: 8px; border: 1px solid #ddd;">Reference</th>
            </tr>
    """

    for result in raw_outputs:
        # Escape cell text so a "<" or "&" in the model output can't break the markup.
        cells = {key: html.escape(str(value)) for key, value in result.items()}
        details_html += f"""
            <tr>
                <td style="padding: 8px; border: 1px solid #ddd;">{cells['question']}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{cells['model_output']}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{cells['parsed_answer']}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{cells['reference']}</td>
            </tr>
        """

    details_html += "</table></div>"

    full_html = f"""
    <div>
        <img src="data:image/png;base64,{data}" style="width:100%; max-width:600px;">
        {details_html}
    </div>
    """

    return f"Accuracy: {accuracy:.2f}", full_html
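
# A minimal way to wire this into a Gradio UI (hypothetical sketch; the real
# app presumably defines its own interface elsewhere):
#
#   import gradio as gr
#
#   with gr.Blocks() as demo:
#       score = gr.Textbox(label="Score")
#       report = gr.HTML()
#       gr.Button("Run evaluation").click(
#           lambda: evaluate_toy_dataset(model, tokenizer),
#           outputs=[score, report],
#       )
#   demo.launch()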