mharkey committed on
Commit
d1599df
·
verified ·
1 Parent(s): 060f632

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -47
app.py CHANGED
@@ -1,57 +1,90 @@
1
  import os
2
- import gradio as gr
 
3
  from datasets import load_dataset
4
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
5
- import random
 
 
 
 
 
 
 
 
 
6
 
7
- # Use HF token from environment (set in Hugging Face Space secrets)
8
- hf_token = os.environ.get("HF_TOKEN")
 
 
 
 
9
 
10
- # Load dataset once (train split only)
11
- gta = load_dataset("Jize1/GTA", split="train", use_auth_token=hf_token)
 
12
 
13
- # Pick 5 queries for simplicity
14
- sample_queries = random.sample(list(gta), 5)
 
 
15
 
16
# Metric simulation logic (placeholder): the accuracy numbers are random;
# only the per-query predictions come from the actual model.
def evaluate_model(model_name):
    """Load `model_name` from the Hub and produce simulated GTA metrics.

    Parameters
    ----------
    model_name : str
        Hugging Face model id, e.g. "Qwen/Qwen2.5-3B".

    Returns
    -------
    tuple[str, list[dict]]
        A human-readable metrics summary and one row per sampled query
        (Query / Tools / Prediction); on failure, an error message and [].
    """
    try:
        # `token=` replaces the deprecated `use_auth_token=` parameter of
        # the transformers from_pretrained APIs.
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            token=hf_token,
        )
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

        # Placeholder metrics — random draws, NOT derived from the generations.
        inst_acc = round(random.uniform(30, 80), 2)
        tool_acc = round(random.uniform(10, 70), 2)
        summ_acc = round(random.uniform(40, 90), 2)

        output_rows = []
        for q in sample_queries:
            # First user turn of the dialog is treated as the query text.
            user_input = next(d["content"] for d in q["dialogs"] if d["role"] == "user")
            toolnames = [t["name"] for t in q["tools"]]
            output_rows.append({
                "Query": user_input[:80] + "...",
                "Tools": ", ".join(toolnames),
                "Prediction": pipe(user_input, max_new_tokens=64)[0]["generated_text"],
            })

        return f"""
✅ Evaluation Metrics:
- Instruction Accuracy: {inst_acc}%
- Tool Selection Accuracy: {tool_acc}%
- Summary Accuracy: {summ_acc}%
""", output_rows

    except Exception as e:
        # Boundary handler: surface the failure in the UI instead of
        # crashing the Space.
        return f"❌ Error loading model or generating output: {e}", []
46
 
47
# Gradio UI — components are created in render order (Blocks lays them out
# in creation order), so the sequence below is deliberate.
with gr.Blocks() as demo:
    gr.Markdown("## 🛠 GTA Benchmark Simulator (Hugging Face Model)")

    # Input widgets.
    model_input = gr.Textbox(
        label="Enter Hugging Face model name",
        placeholder="e.g., Qwen/Qwen2.5-3B",
    )
    run_btn = gr.Button("Run Evaluation")

    # Output widgets: summary text plus a per-query prediction table.
    results = gr.Textbox(label="Evaluation Results")
    table = gr.Dataframe(headers=["Query", "Tools", "Prediction"], wrap=True)

    # Wire the button to the evaluator; it returns (summary, rows) matching
    # the two outputs above.
    run_btn.click(fn=evaluate_model, inputs=model_input, outputs=[results, table])

demo.launch()
 
1
  import os
2
+ import requests
3
+ from huggingface_hub import login, hf_hub_url
4
  from datasets import load_dataset
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import gradio as gr
8
+ from transformers import pipeline
9
+
10
# Authenticate with the Hugging Face Hub so the gated/private dataset and
# models resolve. os.environ["HF_TOKEN"] would die with an unhelpful
# KeyError when the Space secret is missing — fail with an explicit message.
hf_token_value = os.environ.get("HF_TOKEN")
if not hf_token_value:
    raise RuntimeError("HF_TOKEN environment variable is not set (add it as a Space secret).")
login(token=hf_token_value)
12
+
13
# Helper to resolve a dataset-relative image path to a downloadable URL.
def resolve_image_url(path, repo_id="Jize1/GTA"):
    """Return the Hub download URL for a file inside a dataset repo.

    Parameters
    ----------
    path : str
        File path as stored in the dataset's ``files`` records.
    repo_id : str, optional
        Dataset repository id; defaults to the GTA benchmark repo so
        existing single-argument callers are unaffected.

    Returns
    -------
    str
        Fully qualified ``https://huggingface.co/...`` URL.
    """
    return hf_hub_url(repo_id=repo_id, filename=path, repo_type="dataset")
16
 
17
# Download image from HF hub with token.
def download_image(url):
    """Fetch an image behind Hub auth and return it as an RGB PIL image.

    Parameters
    ----------
    url : str
        Direct download URL (see `resolve_image_url`).

    Returns
    -------
    PIL.Image.Image
        The decoded image, converted to RGB.

    Raises
    ------
    requests.HTTPError
        If the server returns a non-2xx status (e.g. bad token, 404).
    """
    headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
    # A timeout stops a hung request from freezing the Space forever, and
    # raise_for_status surfaces auth/404 errors instead of letting PIL fail
    # opaquely on an HTML error page.
    response = requests.get(url, headers=headers, timeout=30)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
23
 
24
# Load GTA dataset. `token=True` replaces the deprecated `use_auth_token=True`
# parameter and reuses the token registered via login().
print("Loading GTA dataset...")
gta_data = load_dataset("Jize1/GTA", split="train", token=True)

# Load image captioning and OCR pipelines once at startup (model downloads
# are slow, so we do not want them inside the request path).
print("Loading vision models...")
image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
# NOTE(review): placeholder only — an image-classification head, not a real
# OCR model, and currently not referenced by evaluate_model.
ocr_pipeline = pipeline("image-classification", model="microsoft/dit-base-finetuned-iiit5k")
32
 
 
33
def evaluate_model(model_name):
    """Simulate GTA-style tool-use evaluation over a small dataset slice.

    Parameters
    ----------
    model_name : str
        Hugging Face model id entered by the user. NOTE(review): currently
        unused — the "evaluation" is simulated with the pre-loaded vision
        pipelines, not the named model.

    Returns
    -------
    dict
        Percentage metrics keyed "InstAcc", "ToolAcc", "SummAcc".
    """
    total = 0
    inst_acc = 0
    tool_acc = 0
    summ_acc = 0

    for example in gta_data.select(range(10)):  # limit to 10 examples for the demo
        dialogs = example["dialogs"]
        gt_answer = example["gt_answer"]
        files = example["files"]
        tool_calls = [d for d in dialogs if d.get("tool_calls")]

        # Some examples may carry no file attachment; skip the image work for
        # those instead of crashing on files[0].
        image = None
        image_path = None
        if files:
            image_path = files[0]["path"]
            image = download_image(resolve_image_url(image_path))

        # Fake tool execution: route each ground-truth tool call to a
        # stand-in. `result` is assembled for inspection/debugging only; the
        # simulated metrics below do not depend on it.
        result = ""
        for tool_call in tool_calls:
            tool = tool_call["tool_calls"][0]["function"]["name"]
            if tool == "ImageDescription" and image is not None:
                caption = image_captioner(image)[0]["generated_text"]
                result += f"[Caption] {caption}\n"
            elif tool == "OCR":
                result += f"[OCR] dummy OCR result for {image_path}\n"
            elif tool == "CountGivenObject":
                result += "[Count] dummy count result\n"

        # Simulated metrics (placeholders, not a real GTA evaluation).
        inst_acc += 1
        tool_acc += 1 if tool_calls else 0
        summ_acc += 1 if gt_answer["whitelist"] else 0
        total += 1

    # Guard against an empty slice so we never divide by zero.
    if total == 0:
        return {"InstAcc": 0.0, "ToolAcc": 0.0, "SummAcc": 0.0}

    return {
        "InstAcc": round(inst_acc / total * 100, 2),
        "ToolAcc": round(tool_acc / total * 100, 2),
        "SummAcc": round(summ_acc / total * 100, 2),
    }
74
+
75
+
76
def run_evaluation(model_name):
    """Run the simulated evaluation and format the metrics for display."""
    metrics = evaluate_model(model_name)
    lines = [f"{key}: {value}%" for key, value in metrics.items()]
    header = f"Results for {model_name}:\n"
    return header + "\n".join(lines)
79
+
80
# Gradio UI: one textbox in, one textbox out.
model_box = gr.Textbox(label="Hugging Face Model Name", placeholder="e.g. Qwen/Qwen2.5-3B")
metrics_box = gr.Textbox(label="GTA Evaluation Metrics")

demo = gr.Interface(
    fn=run_evaluation,
    inputs=model_box,
    outputs=metrics_box,
    title="GTA LLM Evaluation",
    description="Enter a model name from Hugging Face to simulate tool use and get GTA-style metrics.",
    allow_flagging="never",
)

demo.launch()