Update app.py
Browse files
app.py
CHANGED
@@ -3,46 +3,74 @@ from transformers import pipeline
|
|
3 |
from datasets import load_dataset
|
4 |
import torch
|
5 |
|
6 |
-
# Load GTA dataset
|
7 |
gta = load_dataset("Jize1/GTA", split="train")
|
8 |
|
9 |
def evaluate_model(model_name, num_samples):
|
10 |
try:
|
11 |
pipe = pipeline("text-generation", model=model_name, device=0 if torch.cuda.is_available() else -1)
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
log = []
|
16 |
|
17 |
for i in range(min(num_samples, len(gta))):
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
|
22 |
-
|
23 |
-
out = pipe(query, max_new_tokens=128, do_sample=False)[0]["generated_text"].strip().lower()
|
24 |
|
25 |
-
#
|
26 |
-
|
|
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
except Exception as e:
|
37 |
-
return f"❌
|
38 |
|
|
|
39 |
with gr.Blocks() as demo:
|
40 |
-
gr.Markdown("#
|
41 |
-
|
42 |
-
|
|
|
|
|
43 |
output_md = gr.Markdown()
|
44 |
|
45 |
-
|
46 |
-
sample_count.change(fn=evaluate_model, inputs=[model_input, sample_count], outputs=output_md)
|
47 |
|
48 |
demo.launch()
|
|
|
3 |
from datasets import load_dataset
|
4 |
import torch
|
5 |
|
|
|
6 |
# Load the GTA tool-use benchmark (train split) once at module import time so
# every evaluation call can index it without re-downloading.
gta = load_dataset("Jize1/GTA", split="train")
|
7 |
|
8 |
def _tool_names(sample):
    # Collect the lower-cased names of tools (functions) recorded in a sample's
    # dialog steps. NOTE(review): assumes a tool call is stored as
    # step["function"] = {"name": ..., ...} -- confirm against the GTA schema.
    names = []
    for step in sample["dialogs"]:
        fn = step.get("function")
        # Bug fix: the original filter was `"function" in step.get("function", {})`,
        # which looks for the key "function" *inside* the function dict and is
        # never true -- so the tool list was always empty and ToolAcc was
        # pinned at 0. Accept any step whose function dict carries a "name".
        if isinstance(fn, dict) and "name" in fn:
            names.append(fn["name"].lower())
    return names


def evaluate_model(model_name, num_samples):
    """Heuristically score `model_name` on the first `num_samples` GTA queries.

    Four proxy metrics are accumulated per query and reported as percentages:
      * InstAcc -- prediction is non-trivial and uses instruction-ish wording
      * ToolAcc -- prediction mentions any ground-truth tool name
      * SummAcc -- prediction contains a concluding phrase
      * AnsAcc  -- prediction contains any whitelisted answer phrase

    Returns a markdown report string. On any failure (bad model name, schema
    mismatch, CUDA issues, ...) it returns an "❌ Error: ..." string instead of
    raising, so the Gradio UI always has something to render.
    """
    try:
        pipe = pipeline(
            "text-generation",
            model=model_name,
            device=0 if torch.cuda.is_available() else -1,
        )

        # Coerce to int: Gradio sliders may deliver floats, which would break range().
        total = min(int(num_samples), len(gta))
        if total <= 0:
            # Guard the accuracy divisions below; the original hit
            # ZeroDivisionError here and surfaced a cryptic message.
            return "❌ Error: num_samples must be at least 1."

        inst_correct = tool_correct = summ_correct = ans_correct = 0
        logs = []

        for i in range(total):
            sample = gta[i]
            query = sample["dialogs"][0]["content"]
            tools_used = _tool_names(sample)

            prediction = pipe(query, max_new_tokens=256, do_sample=False)[0]["generated_text"].strip().lower()

            # Instruction following: answer is long enough and uses task wording.
            inst_pass = len(prediction) > 10 and any(
                w in prediction for w in ["use", "calculate", "looks like", "means", "based on"]
            )
            inst_correct += inst_pass

            # ToolAcc: any ground-truth tool name is mentioned verbatim.
            tool_pass = any(tool in prediction for tool in tools_used)
            tool_correct += tool_pass

            # SummAcc: concluding phrases or numbers as a proxy for a summary.
            summ_pass = any(
                x in prediction for x in ["so", "therefore", "the answer is", "equals", "you will need", "hence"]
            )
            summ_correct += summ_pass

            # AnsAcc: substring match against the flattened whitelist phrases.
            gt_phrases = sample["gt_answer"].get("whitelist", [])
            flat_gt = {s.strip().lower() for group in gt_phrases for s in group if isinstance(s, str)}
            ans_pass = any(g in prediction for g in flat_gt)
            ans_correct += ans_pass

            logs.append(f"""
### Query {i}
**Input**: {query}
**Prediction**: {prediction}
**GT**: {flat_gt}
**Instruction✔️**: {inst_pass}
**Tool✔️**: {tool_pass}
**Summary✔️**: {summ_pass}
**Answer✔️**: {ans_pass}
---""")

        results = {
            "InstAcc": round((inst_correct / total) * 100, 2),
            "ToolAcc": round((tool_correct / total) * 100, 2),
            "SummAcc": round((summ_correct / total) * 100, 2),
            "AnsAcc": round((ans_correct / total) * 100, 2),
        }

        summary = "\n".join([f"**{k}**: {v}%" for k, v in results.items()])
        return f"## 🔬 GTA Evaluation for `{model_name}` on {total} queries\n\n{summary}\n\n---\n" + "\n".join(logs)

    except Exception as e:
        # Surface the failure to the UI rather than crashing the Gradio app.
        return f"❌ Error: {e}"
|
64 |
|
65 |
+
# ---- Gradio front-end -------------------------------------------------
# Single-page demo: choose a model and a sample count, press the button,
# and the markdown pane fills with the metric summary + per-query logs.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 GTA Tool Use Evaluation (Real Metrics, Real Queries)")

    with gr.Row():
        model_name_box = gr.Textbox(label="Model Name", value="Qwen/Qwen2.5-3B")
        num_samples_slider = gr.Slider(
            label="Number of GTA samples", minimum=1, maximum=229, value=10, step=1
        )
        evaluate_button = gr.Button("Run Evaluation")

    report_pane = gr.Markdown()

    # Wire the button to the evaluator; its markdown report lands in the pane.
    evaluate_button.click(
        fn=evaluate_model,
        inputs=[model_name_box, num_samples_slider],
        outputs=report_pane,
    )

demo.launch()
|