ArturoNereu commited on
Commit
5800fed
Β·
1 Parent(s): a168d8d

answer caching implemented via hf datasets

Browse files
Files changed (2) hide show
  1. app.py +129 -36
  2. requirements.txt +3 -1
app.py CHANGED
@@ -4,6 +4,8 @@ import requests
4
  import inspect
5
  import pandas as pd
6
  import json
 
 
7
  from gaia_agent import GaiaAgent
8
 
9
  # (Keep Constants as is)
@@ -13,32 +15,108 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
13
  # To check if we are running locally
14
  running_on_hf = bool(os.getenv("SPACE_ID") or os.getenv("SPACE_HOST"))
15
 
16
- # Cache file for storing correct answers
17
- CACHE_FILE = "answers_cache.json"
 
 
 
 
 
 
 
18
 
19
  def load_answers_cache():
20
- """Load cached answers from file"""
 
 
 
21
  try:
22
- if os.path.exists(CACHE_FILE):
23
- with open(CACHE_FILE, 'r') as f:
24
- return json.load(f)
 
 
 
 
 
 
 
 
 
25
  except Exception as e:
26
- print(f"Error loading cache: {e}")
27
- return {}
28
 
29
  def save_answers_cache(cache):
30
- """Save cached answers to file"""
 
 
 
31
  try:
32
- with open(CACHE_FILE, 'w') as f:
33
- json.dump(cache, f, indent=2)
 
 
 
 
 
 
 
 
 
 
34
  return True
 
35
  except Exception as e:
36
- print(f"Error saving cache: {e}")
37
  return False
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def run_and_cache_answers(profile: gr.OAuthProfile | None):
40
  """
41
- Runs agent on questions and caches correct answers for later submission
42
  """
43
  if not running_on_hf:
44
  return "Caching only available on HuggingFace Spaces", None
@@ -64,15 +142,14 @@ def run_and_cache_answers(profile: gr.OAuthProfile | None):
64
  except Exception as e:
65
  return f"Error fetching questions: {e}", None
66
 
67
- # 3. Load existing cache
68
  cache = load_answers_cache()
69
 
70
- # 4. Run agent on solvable questions
71
  results_log = []
72
- solvable_indices = [0, 2, 4] # Focus on proven questions
73
- new_answers = 0
74
 
75
- for idx in solvable_indices:
76
  if idx >= len(questions_data):
77
  continue
78
 
@@ -83,13 +160,13 @@ def run_and_cache_answers(profile: gr.OAuthProfile | None):
83
  if not task_id or question_text is None:
84
  continue
85
 
86
- # Skip if already cached
87
  if task_id in cache:
88
  results_log.append({
89
  "Task ID": task_id,
90
  "Question": question_text[:100] + "...",
91
  "Answer": cache[task_id],
92
- "Status": "CACHED"
93
  })
94
  continue
95
 
@@ -97,15 +174,17 @@ def run_and_cache_answers(profile: gr.OAuthProfile | None):
97
  print(f"Processing question {idx+1}: {question_text[:100]}...")
98
  submitted_answer = agent(question_text)
99
 
100
- # Cache the answer (we'll validate it later)
101
- cache[task_id] = submitted_answer
102
- new_answers += 1
 
 
103
 
104
  results_log.append({
105
  "Task ID": task_id,
106
  "Question": question_text[:100] + "...",
107
  "Answer": submitted_answer,
108
- "Status": "NEW"
109
  })
110
 
111
  except Exception as e:
@@ -113,17 +192,34 @@ def run_and_cache_answers(profile: gr.OAuthProfile | None):
113
  "Task ID": task_id,
114
  "Question": question_text[:100] + "...",
115
  "Answer": f"ERROR: {e}",
116
- "Status": "FAILED"
117
  })
118
 
119
- # 5. Save updated cache
120
- if new_answers > 0:
121
- if save_answers_cache(cache):
122
- status = f"βœ… Processed {len(solvable_indices)} questions. Added {new_answers} new answers to cache."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  else:
124
- status = f"⚠️ Generated {new_answers} answers but failed to save cache."
125
  else:
126
- status = "All target questions already cached."
127
 
128
  return status, pd.DataFrame(results_log)
129
 
@@ -241,11 +337,8 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
241
  results_log = []
242
  answers_payload = []
243
 
244
- # Focus on the 3 questions we know work correctly
245
- solvable_indices = [0, 2, 4] # Mercedes Sosa, Reversed text, Dinosaur Featured Article
246
-
247
- print(f"Running agent on {len(solvable_indices)} solvable questions...")
248
- for idx in solvable_indices:
249
  if idx >= len(questions_data):
250
  continue
251
  item = questions_data[idx]
 
4
  import inspect
5
  import pandas as pd
6
  import json
7
+ from datasets import Dataset
8
+ from huggingface_hub import HfApi
9
  from gaia_agent import GaiaAgent
10
 
11
  # (Keep Constants as is)
 
15
  # To check if we are running locally
16
  running_on_hf = bool(os.getenv("SPACE_ID") or os.getenv("SPACE_HOST"))
17
 
18
+ # Questions the agent can reliably solve (no images, audio, video)
19
+ SOLVABLE_INDICES = [0, 2, 4] # Mercedes Sosa, Reversed text, Dinosaur Featured Article
20
+
21
+ def get_dataset_name():
22
+ """Get the private dataset name for this space"""
23
+ space_id = os.getenv("SPACE_ID")
24
+ if space_id:
25
+ return f"{space_id.replace('/', '--')}-gaia-answers"
26
+ return "gaia-answers-cache"
27
 
28
  def load_answers_cache():
29
+ """Load cached answers from HuggingFace Dataset"""
30
+ if not running_on_hf:
31
+ return {}
32
+
33
  try:
34
+ dataset_name = get_dataset_name()
35
+ dataset = Dataset.load_from_hub(dataset_name, split="train")
36
+
37
+ # Convert back to dictionary
38
+ cache = {}
39
+ if len(dataset) > 0:
40
+ for item in dataset:
41
+ cache[item["task_id"]] = item["answer"]
42
+
43
+ print(f"βœ… Loaded {len(cache)} cached answers from dataset: {dataset_name}")
44
+ return cache
45
+
46
  except Exception as e:
47
+ print(f"πŸ“ No existing cache found (will create new): {e}")
48
+ return {}
49
 
50
  def save_answers_cache(cache):
51
+ """Save cached answers to HuggingFace Dataset"""
52
+ if not running_on_hf or not cache:
53
+ return False
54
+
55
  try:
56
+ dataset_name = get_dataset_name()
57
+
58
+ # Convert dictionary to dataset format
59
+ data = {
60
+ "task_id": list(cache.keys()),
61
+ "answer": list(cache.values())
62
+ }
63
+
64
+ dataset = Dataset.from_dict(data)
65
+ dataset.push_to_hub(dataset_name, private=True)
66
+
67
+ print(f"πŸ’Ύ Saved {len(cache)} answers to private dataset: {dataset_name}")
68
  return True
69
+
70
  except Exception as e:
71
+ print(f"Error saving cache to dataset: {e}")
72
  return False
73
 
74
+ def check_answers_correctness(answers_payload, questions_data):
75
+ """
76
+ Submit answers to get correctness feedback and return which ones were correct
77
+ """
78
+ if not running_on_hf:
79
+ return {}
80
+
81
+ try:
82
+ # Prepare minimal submission for validation
83
+ space_id = os.getenv("SPACE_ID")
84
+ agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
85
+
86
+ submission_data = {
87
+ "username": "validation_check",
88
+ "agent_code": agent_code,
89
+ "answers": answers_payload
90
+ }
91
+
92
+ api_url = DEFAULT_API_URL
93
+ submit_url = f"{api_url}/submit"
94
+
95
+ response = requests.post(submit_url, json=submission_data, timeout=60)
96
+ response.raise_for_status()
97
+ result_data = response.json()
98
+
99
+ # Parse which answers were correct
100
+ correct_answers = {}
101
+ if "detailed_results" in result_data:
102
+ for result in result_data["detailed_results"]:
103
+ if result.get("correct", False):
104
+ task_id = result.get("task_id")
105
+ # Find the corresponding answer
106
+ for answer in answers_payload:
107
+ if answer["task_id"] == task_id:
108
+ correct_answers[task_id] = answer["submitted_answer"]
109
+ break
110
+
111
+ return correct_answers
112
+
113
+ except Exception as e:
114
+ print(f"Error checking answer correctness: {e}")
115
+ return {}
116
+
117
  def run_and_cache_answers(profile: gr.OAuthProfile | None):
118
  """
119
+ Runs agent on questions, validates answers, and caches only correct ones
120
  """
121
  if not running_on_hf:
122
  return "Caching only available on HuggingFace Spaces", None
 
142
  except Exception as e:
143
  return f"Error fetching questions: {e}", None
144
 
145
+ # 3. Load existing cache (verified correct answers)
146
  cache = load_answers_cache()
147
 
148
+ # 4. Run agent only on unsolved questions
149
  results_log = []
150
+ new_answers_payload = []
 
151
 
152
+ for idx in SOLVABLE_INDICES:
153
  if idx >= len(questions_data):
154
  continue
155
 
 
160
  if not task_id or question_text is None:
161
  continue
162
 
163
+ # Skip if already have correct answer cached
164
  if task_id in cache:
165
  results_log.append({
166
  "Task ID": task_id,
167
  "Question": question_text[:100] + "...",
168
  "Answer": cache[task_id],
169
+ "Status": "βœ… CORRECT (CACHED)"
170
  })
171
  continue
172
 
 
174
  print(f"Processing question {idx+1}: {question_text[:100]}...")
175
  submitted_answer = agent(question_text)
176
 
177
+ # Add to payload for validation
178
+ new_answers_payload.append({
179
+ "task_id": task_id,
180
+ "submitted_answer": submitted_answer
181
+ })
182
 
183
  results_log.append({
184
  "Task ID": task_id,
185
  "Question": question_text[:100] + "...",
186
  "Answer": submitted_answer,
187
+ "Status": "πŸ”„ VALIDATING..."
188
  })
189
 
190
  except Exception as e:
 
192
  "Task ID": task_id,
193
  "Question": question_text[:100] + "...",
194
  "Answer": f"ERROR: {e}",
195
+ "Status": "❌ FAILED"
196
  })
197
 
198
+ # 5. Validate new answers and cache only correct ones
199
+ if new_answers_payload:
200
+ print(f"πŸ” Validating {len(new_answers_payload)} new answers...")
201
+ correct_answers = check_answers_correctness(new_answers_payload, questions_data)
202
+
203
+ # Update cache with only correct answers
204
+ cache.update(correct_answers)
205
+
206
+ # Update results log with validation results
207
+ for log_entry in results_log:
208
+ if log_entry["Status"] == "πŸ”„ VALIDATING...":
209
+ task_id = log_entry["Task ID"]
210
+ if task_id in correct_answers:
211
+ log_entry["Status"] = "βœ… CORRECT (NEW)"
212
+ else:
213
+ log_entry["Status"] = "❌ INCORRECT"
214
+
215
+ # Save updated cache
216
+ if correct_answers:
217
+ save_answers_cache(cache)
218
+ status = f"πŸŽ‰ Validated {len(new_answers_payload)} answers. Cached {len(correct_answers)} correct answers!"
219
  else:
220
+ status = f"πŸ˜” Validated {len(new_answers_payload)} answers. None were correct this time."
221
  else:
222
+ status = "All target questions already have correct answers cached!"
223
 
224
  return status, pd.DataFrame(results_log)
225
 
 
337
  results_log = []
338
  answers_payload = []
339
 
340
+ print(f"Running agent on {len(SOLVABLE_INDICES)} solvable questions...")
341
+ for idx in SOLVABLE_INDICES:
 
 
 
342
  if idx >= len(questions_data):
343
  continue
344
  item = questions_data[idx]
requirements.txt CHANGED
@@ -3,4 +3,6 @@ requests
3
  smolagents
4
  duckduckgo-search
5
  openai
6
- wikipedia
 
 
 
3
  smolagents
4
  duckduckgo-search
5
  openai
6
+ wikipedia
7
+ datasets
8
+ huggingface_hub