Spaces:
Sleeping
Sleeping
Commit
·
5800fed
1
Parent(s):
a168d8d
answer caching implemented via hf datasets
Browse files
- app.py +129 -36
- requirements.txt +3 -1
app.py
CHANGED
@@ -4,6 +4,8 @@ import requests
|
|
4 |
import inspect
|
5 |
import pandas as pd
|
6 |
import json
|
|
|
|
|
7 |
from gaia_agent import GaiaAgent
|
8 |
|
9 |
# (Keep Constants as is)
|
@@ -13,32 +15,108 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
13 |
# To check if we are running locally
|
14 |
running_on_hf = bool(os.getenv("SPACE_ID") or os.getenv("SPACE_HOST"))
|
15 |
|
16 |
-
#
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def load_answers_cache():
|
20 |
-
"""Load cached answers from
|
|
|
|
|
|
|
21 |
try:
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
except Exception as e:
|
26 |
-
print(f"
|
27 |
-
|
28 |
|
29 |
def save_answers_cache(cache):
|
30 |
-
"""Save cached answers to
|
|
|
|
|
|
|
31 |
try:
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
return True
|
|
|
35 |
except Exception as e:
|
36 |
-
print(f"Error saving cache: {e}")
|
37 |
return False
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
def run_and_cache_answers(profile: gr.OAuthProfile | None):
|
40 |
"""
|
41 |
-
Runs agent on questions and caches correct
|
42 |
"""
|
43 |
if not running_on_hf:
|
44 |
return "Caching only available on HuggingFace Spaces", None
|
@@ -64,15 +142,14 @@ def run_and_cache_answers(profile: gr.OAuthProfile | None):
|
|
64 |
except Exception as e:
|
65 |
return f"Error fetching questions: {e}", None
|
66 |
|
67 |
-
# 3. Load existing cache
|
68 |
cache = load_answers_cache()
|
69 |
|
70 |
-
# 4. Run agent on
|
71 |
results_log = []
|
72 |
-
|
73 |
-
new_answers = 0
|
74 |
|
75 |
-
for idx in
|
76 |
if idx >= len(questions_data):
|
77 |
continue
|
78 |
|
@@ -83,13 +160,13 @@ def run_and_cache_answers(profile: gr.OAuthProfile | None):
|
|
83 |
if not task_id or question_text is None:
|
84 |
continue
|
85 |
|
86 |
-
# Skip if already cached
|
87 |
if task_id in cache:
|
88 |
results_log.append({
|
89 |
"Task ID": task_id,
|
90 |
"Question": question_text[:100] + "...",
|
91 |
"Answer": cache[task_id],
|
92 |
-
"Status": "CACHED"
|
93 |
})
|
94 |
continue
|
95 |
|
@@ -97,15 +174,17 @@ def run_and_cache_answers(profile: gr.OAuthProfile | None):
|
|
97 |
print(f"Processing question {idx+1}: {question_text[:100]}...")
|
98 |
submitted_answer = agent(question_text)
|
99 |
|
100 |
-
#
|
101 |
-
|
102 |
-
|
|
|
|
|
103 |
|
104 |
results_log.append({
|
105 |
"Task ID": task_id,
|
106 |
"Question": question_text[:100] + "...",
|
107 |
"Answer": submitted_answer,
|
108 |
-
"Status": "
|
109 |
})
|
110 |
|
111 |
except Exception as e:
|
@@ -113,17 +192,34 @@ def run_and_cache_answers(profile: gr.OAuthProfile | None):
|
|
113 |
"Task ID": task_id,
|
114 |
"Question": question_text[:100] + "...",
|
115 |
"Answer": f"ERROR: {e}",
|
116 |
-
"Status": "FAILED"
|
117 |
})
|
118 |
|
119 |
-
# 5.
|
120 |
-
if
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
else:
|
124 |
-
status = f"
|
125 |
else:
|
126 |
-
status = "All target questions already cached
|
127 |
|
128 |
return status, pd.DataFrame(results_log)
|
129 |
|
@@ -241,11 +337,8 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
241 |
results_log = []
|
242 |
answers_payload = []
|
243 |
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
print(f"Running agent on {len(solvable_indices)} solvable questions...")
|
248 |
-
for idx in solvable_indices:
|
249 |
if idx >= len(questions_data):
|
250 |
continue
|
251 |
item = questions_data[idx]
|
|
|
4 |
import inspect
|
5 |
import pandas as pd
|
6 |
import json
|
7 |
+
from datasets import Dataset
|
8 |
+
from huggingface_hub import HfApi
|
9 |
from gaia_agent import GaiaAgent
|
10 |
|
11 |
# (Keep Constants as is)
|
|
|
15 |
# To check if we are running locally
|
16 |
running_on_hf = bool(os.getenv("SPACE_ID") or os.getenv("SPACE_HOST"))
|
17 |
|
18 |
+
# Questions the agent can reliably solve (no images, audio, video)
|
19 |
+
SOLVABLE_INDICES = [0, 2, 4] # Mercedes Sosa, Reversed text, Dinosaur Featured Article
|
20 |
+
|
21 |
+
def get_dataset_name():
    """Build the name of the private dataset that stores this Space's answers.

    Reads the ``SPACE_ID`` environment variable. When it is set and non-empty,
    every ``/`` in it is replaced with ``--`` and ``-gaia-answers`` is
    appended. Otherwise the fixed fallback ``gaia-answers-cache`` is returned.
    """
    current_space = os.getenv("SPACE_ID")
    if not current_space:
        return "gaia-answers-cache"

    # Flatten "owner/space" into a single path-free token.
    flattened = "--".join(current_space.split("/"))
    return flattened + "-gaia-answers"
|
27 |
|
28 |
def load_answers_cache():
    """Load the cached answers from the HuggingFace dataset for this Space.

    Returns:
        A dict mapping each stored ``task_id`` to its ``answer``. An empty
        dict is returned when not running on HuggingFace, or when the dataset
        cannot be loaded for any reason (the failure is printed, not raised).
    """
    if not running_on_hf:
        return {}

    # `datasets.Dataset` has no `load_from_hub` constructor; the previous call
    # always raised AttributeError, which the broad handler below reported as
    # "no existing cache". Hub datasets are fetched with `load_dataset`.
    from datasets import load_dataset

    try:
        dataset_name = get_dataset_name()
        # NOTE(review): dataset_name carries no "<user>/" namespace prefix;
        # confirm the Hub resolves it to the repo save_answers_cache pushes to.
        dataset = load_dataset(dataset_name, split="train")

        # Convert the rows back into a task_id -> answer dictionary.
        cache = {item["task_id"]: item["answer"] for item in dataset}

        print(f"✅ Loaded {len(cache)} cached answers from dataset: {dataset_name}")
        return cache

    except Exception as e:
        # Best-effort: any failure is treated as "start with an empty cache".
        # NOTE(review): the leading emoji of this message was mis-encoded in
        # the source and could not be recovered exactly; confirm the glyph.
        print(f"📝 No existing cache found (will create new): {e}")
        return {}
|
49 |
|
50 |
def save_answers_cache(cache):
    """Persist the answer cache to a private HuggingFace dataset.

    Args:
        cache: Dict mapping task ids to answers.

    Returns:
        True when the dataset was pushed. False when not running on
        HuggingFace, when ``cache`` is empty, or when building or pushing the
        dataset raised (the error is printed, not re-raised).
    """
    if not running_on_hf or not cache:
        return False

    try:
        dataset_name = get_dataset_name()

        # Column-oriented layout: keys and values stay aligned because both
        # lists are taken from the same dict without mutation in between.
        columns = {
            "task_id": list(cache.keys()),
            "answer": list(cache.values()),
        }

        Dataset.from_dict(columns).push_to_hub(dataset_name, private=True)

        # The emoji here was mis-encoded ("πΎ") in the source; restored to the
        # intended floppy-disk glyph.
        print(f"💾 Saved {len(cache)} answers to private dataset: {dataset_name}")
        return True

    except Exception as e:
        print(f"Error saving cache to dataset: {e}")
        return False
|
73 |
|
74 |
+
def check_answers_correctness(answers_payload, questions_data):
    """Submit answers to the scoring API and return the ones marked correct.

    Args:
        answers_payload: List of dicts, each with a ``task_id`` and a
            ``submitted_answer`` key.
        questions_data: Unused by this function; kept so existing callers
            keep working.

    Returns:
        A dict mapping ``task_id`` to the submitted answer for every entry the
        API reported as correct. An empty dict is returned when not running on
        HuggingFace, when there is nothing to validate, or when the request or
        response handling raised (the error is printed, not re-raised).
    """
    # Nothing to validate: avoid a pointless network round-trip.
    if not running_on_hf or not answers_payload:
        return {}

    try:
        # Index the payload once so each result is matched in O(1) instead of
        # rescanning the whole list. `setdefault` keeps the first entry for a
        # duplicated task_id, matching a first-match linear scan.
        submitted_by_task = {}
        for answer in answers_payload:
            submitted_by_task.setdefault(answer["task_id"], answer["submitted_answer"])

        # Prepare minimal submission for validation.
        # NOTE(review): SPACE_ID may be unset when only SPACE_HOST is defined,
        # which would yield ".../spaces/None/tree/main" — confirm acceptable.
        space_id = os.getenv("SPACE_ID")
        agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"

        submission_data = {
            "username": "validation_check",
            "agent_code": agent_code,
            "answers": answers_payload,
        }

        submit_url = f"{DEFAULT_API_URL}/submit"

        response = requests.post(submit_url, json=submission_data, timeout=60)
        response.raise_for_status()
        result_data = response.json()

        # Parse which answers were correct.
        # NOTE(review): assumes the response carries a "detailed_results" list
        # of {"task_id", "correct"} items — verify against the scoring API;
        # when the key is absent nothing is reported as correct.
        correct_answers = {}
        if "detailed_results" in result_data:
            for result in result_data["detailed_results"]:
                if not result.get("correct", False):
                    continue
                task_id = result.get("task_id")
                if task_id in submitted_by_task:
                    correct_answers[task_id] = submitted_by_task[task_id]

        return correct_answers

    except Exception as e:
        print(f"Error checking answer correctness: {e}")
        return {}
|
116 |
+
|
117 |
def run_and_cache_answers(profile: gr.OAuthProfile | None):
|
118 |
"""
|
119 |
+
Runs agent on questions, validates answers, and caches only correct ones
|
120 |
"""
|
121 |
if not running_on_hf:
|
122 |
return "Caching only available on HuggingFace Spaces", None
|
|
|
142 |
except Exception as e:
|
143 |
return f"Error fetching questions: {e}", None
|
144 |
|
145 |
+
# 3. Load existing cache (verified correct answers)
|
146 |
cache = load_answers_cache()
|
147 |
|
148 |
+
# 4. Run agent only on unsolved questions
|
149 |
results_log = []
|
150 |
+
new_answers_payload = []
|
|
|
151 |
|
152 |
+
for idx in SOLVABLE_INDICES:
|
153 |
if idx >= len(questions_data):
|
154 |
continue
|
155 |
|
|
|
160 |
if not task_id or question_text is None:
|
161 |
continue
|
162 |
|
163 |
+
# Skip if already have correct answer cached
|
164 |
if task_id in cache:
|
165 |
results_log.append({
|
166 |
"Task ID": task_id,
|
167 |
"Question": question_text[:100] + "...",
|
168 |
"Answer": cache[task_id],
|
169 |
+
"Status": "β
CORRECT (CACHED)"
|
170 |
})
|
171 |
continue
|
172 |
|
|
|
174 |
print(f"Processing question {idx+1}: {question_text[:100]}...")
|
175 |
submitted_answer = agent(question_text)
|
176 |
|
177 |
+
# Add to payload for validation
|
178 |
+
new_answers_payload.append({
|
179 |
+
"task_id": task_id,
|
180 |
+
"submitted_answer": submitted_answer
|
181 |
+
})
|
182 |
|
183 |
results_log.append({
|
184 |
"Task ID": task_id,
|
185 |
"Question": question_text[:100] + "...",
|
186 |
"Answer": submitted_answer,
|
187 |
+
"Status": "π VALIDATING..."
|
188 |
})
|
189 |
|
190 |
except Exception as e:
|
|
|
192 |
"Task ID": task_id,
|
193 |
"Question": question_text[:100] + "...",
|
194 |
"Answer": f"ERROR: {e}",
|
195 |
+
"Status": "β FAILED"
|
196 |
})
|
197 |
|
198 |
+
# 5. Validate new answers and cache only correct ones
|
199 |
+
if new_answers_payload:
|
200 |
+
print(f"π Validating {len(new_answers_payload)} new answers...")
|
201 |
+
correct_answers = check_answers_correctness(new_answers_payload, questions_data)
|
202 |
+
|
203 |
+
# Update cache with only correct answers
|
204 |
+
cache.update(correct_answers)
|
205 |
+
|
206 |
+
# Update results log with validation results
|
207 |
+
for log_entry in results_log:
|
208 |
+
if log_entry["Status"] == "π VALIDATING...":
|
209 |
+
task_id = log_entry["Task ID"]
|
210 |
+
if task_id in correct_answers:
|
211 |
+
log_entry["Status"] = "β
CORRECT (NEW)"
|
212 |
+
else:
|
213 |
+
log_entry["Status"] = "β INCORRECT"
|
214 |
+
|
215 |
+
# Save updated cache
|
216 |
+
if correct_answers:
|
217 |
+
save_answers_cache(cache)
|
218 |
+
status = f"π Validated {len(new_answers_payload)} answers. Cached {len(correct_answers)} correct answers!"
|
219 |
else:
|
220 |
+
status = f"π Validated {len(new_answers_payload)} answers. None were correct this time."
|
221 |
else:
|
222 |
+
status = "All target questions already have correct answers cached!"
|
223 |
|
224 |
return status, pd.DataFrame(results_log)
|
225 |
|
|
|
337 |
results_log = []
|
338 |
answers_payload = []
|
339 |
|
340 |
+
print(f"Running agent on {len(SOLVABLE_INDICES)} solvable questions...")
|
341 |
+
for idx in SOLVABLE_INDICES:
|
|
|
|
|
|
|
342 |
if idx >= len(questions_data):
|
343 |
continue
|
344 |
item = questions_data[idx]
|
requirements.txt
CHANGED
@@ -3,4 +3,6 @@ requests
|
|
3 |
smolagents
|
4 |
duckduckgo-search
|
5 |
openai
|
6 |
-
wikipedia
|
|
|
|
|
|
3 |
smolagents
|
4 |
duckduckgo-search
|
5 |
openai
|
6 |
+
wikipedia
|
7 |
+
datasets
|
8 |
+
huggingface_hub
|