mharkey committed on
Commit
2c7a7bc
·
verified ·
1 Parent(s): 25b3bcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -23
app.py CHANGED
@@ -3,46 +3,74 @@ from transformers import pipeline
3
  from datasets import load_dataset
4
  import torch
5
 
6
- # Load GTA dataset
7
  # Load the "train" split of the Jize1/GTA dataset once, at module import;
  # evaluate_model reads samples from this module-level object by index.
  gta = load_dataset("Jize1/GTA", split="train")
8
 
9
  def evaluate_model(model_name, num_samples):
10
  try:
11
  pipe = pipeline("text-generation", model=model_name, device=0 if torch.cuda.is_available() else -1)
12
 
13
- correct = 0
14
- total = 0
15
- log = []
16
 
17
  for i in range(min(num_samples, len(gta))):
18
- query = gta[i]["dialogs"][0]["content"]
19
- gt_answers = gta[i]["gt_answer"].get("whitelist", [])
20
- flat_gt = {ans.strip().lower() for group in gt_answers for ans in group if isinstance(ans, str)}
21
 
22
- # Generate model output
23
- out = pipe(query, max_new_tokens=128, do_sample=False)[0]["generated_text"].strip().lower()
24
 
25
- # Match: exact substring match with any whitelist answer
26
- matched = any(gt in out for gt in flat_gt)
 
27
 
28
- log.append(f"### Query {i}\n**Input**: {query}\n**Prediction**: {out}\n**GT**: {flat_gt}\n**✔️ Correct**: {matched}\n")
29
- correct += int(matched)
30
- total += 1
31
 
32
- acc = round((correct / total) * 100, 2)
33
- summary = f"### 🔍 GTA Answer Accuracy (AnsAcc) for `{model_name}`: **{acc}%** on {total} queries\n\n---\n"
34
- return summary + "\n".join(log)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  except Exception as e:
37
- return f"❌ Evaluation failed: {e}"
38
 
 
39
  with gr.Blocks() as demo:
40
- gr.Markdown("# 🧪 Real GTA Evaluation (Answer Accuracy Only)")
41
- model_input = gr.Textbox(label="Enter Hugging Face Model Name", value="Qwen/Qwen2.5-3B")
42
- sample_count = gr.Slider(label="Number of GTA samples to evaluate", minimum=1, maximum=229, value=10, step=1)
 
 
43
  output_md = gr.Markdown()
44
 
45
- model_input.change(fn=evaluate_model, inputs=[model_input, sample_count], outputs=output_md)
46
- sample_count.change(fn=evaluate_model, inputs=[model_input, sample_count], outputs=output_md)
47
 
48
  demo.launch()
 
3
  from datasets import load_dataset
4
  import torch
5
 
 
6
  # Module-level GTA data: the "train" split of Jize1/GTA, loaded once at
  # import. evaluate_model uses len(gta) and gta[i] to walk the samples.
  gta = load_dataset("Jize1/GTA", split="train")
7
 
8
  def evaluate_model(model_name, num_samples):
9
  try:
10
  pipe = pipeline("text-generation", model=model_name, device=0 if torch.cuda.is_available() else -1)
11
 
12
+ inst_correct, tool_correct, summ_correct, ans_correct = 0, 0, 0, 0
13
+ logs = []
 
14
 
15
  for i in range(min(num_samples, len(gta))):
16
+ sample = gta[i]
17
+ query = sample["dialogs"][0]["content"]
18
+ tools_used = [step["function"]["name"].lower() for step in sample["dialogs"] if "function" in step.get("function", {})]
19
 
20
+ prediction = pipe(query, max_new_tokens=256, do_sample=False)[0]["generated_text"].strip().lower()
 
21
 
22
+ # Instruction following: if answer is long enough and not hallucinated
23
+ inst_pass = len(prediction) > 10 and any(w in prediction for w in ["use", "calculate", "looks like", "means", "based on"])
24
+ inst_correct += inst_pass
25
 
26
+ # ToolAcc: if any known tool name is mentioned
27
+ tool_pass = any(tool in prediction for tool in tools_used)
28
+ tool_correct += tool_pass
29
 
30
+ # SummAcc: if answer includes concluding phrases or numbers (as proxy)
31
+ summ_pass = any(x in prediction for x in ["so", "therefore", "the answer is", "equals", "you will need", "hence"])
32
+ summ_correct += summ_pass
33
+
34
+ # AnsAcc: match whitelist phrase
35
+ gt_phrases = sample["gt_answer"].get("whitelist", [])
36
+ flat_gt = {s.strip().lower() for group in gt_phrases for s in group if isinstance(s, str)}
37
+ ans_pass = any(g in prediction for g in flat_gt)
38
+ ans_correct += ans_pass
39
+
40
+ logs.append(f"""
41
+ ### Query {i}
42
+ **Input**: {query}
43
+ **Prediction**: {prediction}
44
+ **GT**: {flat_gt}
45
+ **Instruction✔️**: {inst_pass}
46
+ **Tool✔️**: {tool_pass}
47
+ **Summary✔️**: {summ_pass}
48
+ **Answer✔️**: {ans_pass}
49
+ ---""")
50
+
51
+ total = min(num_samples, len(gta))
52
+ results = {
53
+ "InstAcc": round((inst_correct / total) * 100, 2),
54
+ "ToolAcc": round((tool_correct / total) * 100, 2),
55
+ "SummAcc": round((summ_correct / total) * 100, 2),
56
+ "AnsAcc": round((ans_correct / total) * 100, 2),
57
+ }
58
+
59
+ summary = "\n".join([f"**{k}**: {v}%" for k, v in results.items()])
60
+ return f"## 🔬 GTA Evaluation for `{model_name}` on {total} queries\n\n{summary}\n\n---\n" + "\n".join(logs)
61
 
62
  except Exception as e:
63
+ return f"❌ Error: {e}"
64
 
65
+ # Gradio UI
66
  with gr.Blocks() as demo:
67
+ gr.Markdown("# 🧠 GTA Tool Use Evaluation (Real Metrics, Real Queries)")
68
+ with gr.Row():
69
+ model_input = gr.Textbox(label="Model Name", value="Qwen/Qwen2.5-3B")
70
+ sample_slider = gr.Slider(label="Number of GTA samples", minimum=1, maximum=229, value=10, step=1)
71
+ run_btn = gr.Button("Run Evaluation")
72
  output_md = gr.Markdown()
73
 
74
+ run_btn.click(fn=evaluate_model, inputs=[model_input, sample_slider], outputs=output_md)
 
75
 
76
  demo.launch()