dawid-lorek committed on
Commit
99134fe
·
verified ·
1 Parent(s): bd03e7f

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +12 -55
agent.py CHANGED
@@ -2,9 +2,8 @@ import os
2
  import io
3
  import re
4
  import requests
5
- import mimetypes
6
- import subprocess
7
  import tempfile
 
8
  from openai import OpenAI
9
  from duckduckgo_search import DDGS
10
  from PIL import Image
@@ -34,15 +33,15 @@ def run_web_search(query, max_results=3):
34
  try:
35
  ddgs = DDGS()
36
  results = ddgs.text(query)
 
37
  for i, r in enumerate(results):
38
  if i >= max_results:
39
  break
40
- # Prefer summary/body if available
41
  if r.get('body'):
42
- return r['body']
43
  elif r.get('title'):
44
- return r['title']
45
- return ""
46
  except Exception:
47
  return ""
48
 
@@ -134,63 +133,38 @@ def guess_youtube_link(question):
134
  return None
135
 
136
  def format_gaia_answer(answer, question=None):
137
- """Enforces strict GAIA benchmark answer formatting rules."""
138
  if not answer or not isinstance(answer, str):
139
  return ""
140
-
141
- # Remove apologies and boilerplate
142
  answer = re.sub(r"(?i)i'?m sorry[,\.]?|i cannot|i can't|unable to|please provide.*|information not available|I can't assist.*|I'm unable.*|process the file directly", "", answer)
143
- answer = answer.strip()
144
-
145
- # Remove "Final Answer:" and similar prefixes
146
  answer = re.sub(r'(?i)final answer:?\s*', '', answer).strip()
147
-
148
- # Remove enclosing quotes/brackets
149
  if answer.startswith('"') and answer.endswith('"'):
150
  answer = answer[1:-1]
151
  if answer.startswith('[') and answer.endswith(']'):
152
  answer = answer[1:-1]
153
-
154
- # Remove period at end unless part of the answer (like "Indeed.")
155
  if not re.match(r'^[A-Za-z]+\.$', answer):
156
  answer = re.sub(r'\.$', '', answer)
157
 
158
- # For specific answer types:
159
  if question:
160
- # Numeric answer only
161
  if re.search(r'how many|number of|at bats|total sales|albums|output.*python|highest number', question, re.I):
162
  num_match = re.search(r'(\$?\d[\d,\.]*)', answer)
163
  if num_match:
164
  return num_match.group(1).replace(',', '')
165
 
166
- # Only first name (e.g. Malko, Magda M)
167
  if re.search(r'first name', question, re.I):
168
- first = answer.strip().split()[0]
169
- return first
170
-
171
- # Only surname
172
  if re.search(r'surname', question, re.I):
173
- surname = answer.strip().split()[-1]
174
- return surname
175
-
176
- # Only city
177
  if re.search(r'city', question, re.I):
178
- city = answer.strip().split()[0]
179
- return city
180
-
181
- # Only code (Olympics, NASA award)
182
  if re.search(r'IOC country code|award number|NASA', question, re.I):
183
  code_match = re.search(r'[A-Z0-9]{3,}', answer)
184
  if code_match:
185
  return code_match.group(0)
186
-
187
- # Only algebraic move (chess)
188
  if 'algebraic notation' in question or 'chess' in question:
189
  move_match = re.search(r'[A-Za-z0-9]+[#\+]?$', answer)
190
  if move_match:
191
  return move_match.group(0)
192
-
193
- # Direct quote (Teal'c)
194
  if "what does teal'c say" in question.lower():
195
  qmatch = re.search(r'"(Indeed\.)"', answer)
196
  if qmatch:
@@ -198,27 +172,19 @@ def format_gaia_answer(answer, question=None):
198
  if "Indeed." in answer:
199
  return "Indeed."
200
  return answer
201
-
202
- # For lists (ingredients, vegetables, page numbers, etc)
203
  if re.search(r'list|comma.*separated|page numbers', question, re.I):
204
- # Extract all possible meaningful phrases
205
  items = [x.strip('",.').lower() for x in re.split(r'[,\n]', answer) if x.strip()]
206
- # Remove likely non-items (like "and", "or", etc.)
207
  items = [item for item in items if item and not re.match(r'(and|or|to|with|for|a|the)$', item)]
208
- # For page numbers, sort as int
209
  if 'page numbers' in question:
210
  nums = [int(x) for x in re.findall(r'\d+', answer)]
211
  return ', '.join(str(n) for n in sorted(nums))
212
- # For vegetables, ingredients, etc. sort alpha
213
  if 'ingredient' in question or 'vegetable' in question or 'grocery' in question:
214
- # merge multi-word items split by commas (heuristic)
215
  merged = []
216
  skip = False
217
  for i, item in enumerate(items):
218
  if skip:
219
  skip = False
220
  continue
221
- # Try to merge known phrases (e.g., "sweet potatoes", "green beans", etc.)
222
  if i+1 < len(items) and item in ['sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh']:
223
  merged.append(f"{item} {items[i+1]}")
224
  skip = True
@@ -227,13 +193,10 @@ def format_gaia_answer(answer, question=None):
227
  merged = sorted(set(merged))
228
  return ', '.join(merged)
229
  return ', '.join(items)
230
-
231
- # Only last names for pitchers (before/after)
232
  if re.search(r'pitcher.*before.*after', question, re.I):
233
  names = re.findall(r'\b[A-Z][a-z]+', answer)
234
  return ', '.join(names[:2])
235
 
236
- # Generic fallback
237
  return answer.strip().rstrip('.').strip()
238
 
239
  class GaiaAgent:
@@ -249,10 +212,10 @@ class GaiaAgent:
249
  "Always output the answer only—no explanations, no extra text."
250
  )
251
 
252
- def answer_with_tools(self, question, task_id):
253
  file_text = ""
254
  prompt_parts = [self.instructions]
255
- # 1. File handling (image, Excel, CSV, PDF, text, audio)
256
  if task_id:
257
  file_bytes, content_type = fetch_file(task_id)
258
  if file_bytes and content_type:
@@ -265,7 +228,7 @@ class GaiaAgent:
265
  transcript = transcribe_youtube_audio(youtube_url)
266
  if transcript:
267
  prompt_parts.append(f"Here is the transcript of the video:\n{transcript}\n")
268
- # 3. Web search fallback if not enough info
269
  search_needed = not file_text and not youtube_url
270
  search_keywords = [
271
  "who", "what", "when", "where", "name", "number", "how many",
@@ -275,13 +238,8 @@ class GaiaAgent:
275
  search_results = run_web_search(question)
276
  if search_results:
277
  prompt_parts.append(f"Here are relevant web search results:\n{search_results}\n")
278
- # 4. Compose prompt
279
  prompt_parts.append(f"Question: {question}\nAnswer strictly and concisely.")
280
  prompt = "\n".join(prompt_parts)
281
- return prompt
282
-
283
- def __call__(self, question: str, task_id: str = None) -> str:
284
- prompt = self.answer_with_tools(question, task_id)
285
  response = self.client.chat.completions.create(
286
  model="gpt-4o",
287
  messages=[
@@ -293,7 +251,6 @@ class GaiaAgent:
293
  )
294
  raw_output = safe_strip(response.choices[0].message.content)
295
  formatted = format_gaia_answer(raw_output, question)
296
- # Retry with web search if result is empty or likely incorrect for key factual types
297
  if not formatted or formatted.lower() in ('', 'unknown', 'none', 'n/a') or 'apolog' in formatted.lower():
298
  web_info = run_web_search(question)
299
  if web_info:
 
2
  import io
3
  import re
4
  import requests
 
 
5
  import tempfile
6
+ import subprocess
7
  from openai import OpenAI
8
  from duckduckgo_search import DDGS
9
  from PIL import Image
 
33
  try:
34
  ddgs = DDGS()
35
  results = ddgs.text(query)
36
+ bodies = []
37
  for i, r in enumerate(results):
38
  if i >= max_results:
39
  break
 
40
  if r.get('body'):
41
+ bodies.append(r['body'])
42
  elif r.get('title'):
43
+ bodies.append(r['title'])
44
+ return "\n".join(bodies)
45
  except Exception:
46
  return ""
47
 
 
133
  return None
134
 
135
  def format_gaia_answer(answer, question=None):
 
136
  if not answer or not isinstance(answer, str):
137
  return ""
 
 
138
  answer = re.sub(r"(?i)i'?m sorry[,\.]?|i cannot|i can't|unable to|please provide.*|information not available|I can't assist.*|I'm unable.*|process the file directly", "", answer)
 
 
 
139
  answer = re.sub(r'(?i)final answer:?\s*', '', answer).strip()
 
 
140
  if answer.startswith('"') and answer.endswith('"'):
141
  answer = answer[1:-1]
142
  if answer.startswith('[') and answer.endswith(']'):
143
  answer = answer[1:-1]
 
 
144
  if not re.match(r'^[A-Za-z]+\.$', answer):
145
  answer = re.sub(r'\.$', '', answer)
146
 
 
147
  if question:
148
+ # Pure number answers
149
  if re.search(r'how many|number of|at bats|total sales|albums|output.*python|highest number', question, re.I):
150
  num_match = re.search(r'(\$?\d[\d,\.]*)', answer)
151
  if num_match:
152
  return num_match.group(1).replace(',', '')
153
 
 
154
  if re.search(r'first name', question, re.I):
155
+ return answer.strip().split()[0]
 
 
 
156
  if re.search(r'surname', question, re.I):
157
+ return answer.strip().split()[-1]
 
 
 
158
  if re.search(r'city', question, re.I):
159
+ return answer.strip().split()[0]
 
 
 
160
  if re.search(r'IOC country code|award number|NASA', question, re.I):
161
  code_match = re.search(r'[A-Z0-9]{3,}', answer)
162
  if code_match:
163
  return code_match.group(0)
 
 
164
  if 'algebraic notation' in question or 'chess' in question:
165
  move_match = re.search(r'[A-Za-z0-9]+[#\+]?$', answer)
166
  if move_match:
167
  return move_match.group(0)
 
 
168
  if "what does teal'c say" in question.lower():
169
  qmatch = re.search(r'"(Indeed\.)"', answer)
170
  if qmatch:
 
172
  if "Indeed." in answer:
173
  return "Indeed."
174
  return answer
 
 
175
  if re.search(r'list|comma.*separated|page numbers', question, re.I):
 
176
  items = [x.strip('",.').lower() for x in re.split(r'[,\n]', answer) if x.strip()]
 
177
  items = [item for item in items if item and not re.match(r'(and|or|to|with|for|a|the)$', item)]
 
178
  if 'page numbers' in question:
179
  nums = [int(x) for x in re.findall(r'\d+', answer)]
180
  return ', '.join(str(n) for n in sorted(nums))
 
181
  if 'ingredient' in question or 'vegetable' in question or 'grocery' in question:
 
182
  merged = []
183
  skip = False
184
  for i, item in enumerate(items):
185
  if skip:
186
  skip = False
187
  continue
 
188
  if i+1 < len(items) and item in ['sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh']:
189
  merged.append(f"{item} {items[i+1]}")
190
  skip = True
 
193
  merged = sorted(set(merged))
194
  return ', '.join(merged)
195
  return ', '.join(items)
 
 
196
  if re.search(r'pitcher.*before.*after', question, re.I):
197
  names = re.findall(r'\b[A-Z][a-z]+', answer)
198
  return ', '.join(names[:2])
199
 
 
200
  return answer.strip().rstrip('.').strip()
201
 
202
  class GaiaAgent:
 
212
  "Always output the answer only—no explanations, no extra text."
213
  )
214
 
215
+ def __call__(self, question: str, task_id: str = None) -> str:
216
  file_text = ""
217
  prompt_parts = [self.instructions]
218
+ # 1. File (image, Excel, etc)
219
  if task_id:
220
  file_bytes, content_type = fetch_file(task_id)
221
  if file_bytes and content_type:
 
228
  transcript = transcribe_youtube_audio(youtube_url)
229
  if transcript:
230
  prompt_parts.append(f"Here is the transcript of the video:\n{transcript}\n")
231
+ # 3. Web search for open facts
232
  search_needed = not file_text and not youtube_url
233
  search_keywords = [
234
  "who", "what", "when", "where", "name", "number", "how many",
 
238
  search_results = run_web_search(question)
239
  if search_results:
240
  prompt_parts.append(f"Here are relevant web search results:\n{search_results}\n")
 
241
  prompt_parts.append(f"Question: {question}\nAnswer strictly and concisely.")
242
  prompt = "\n".join(prompt_parts)
 
 
 
 
243
  response = self.client.chat.completions.create(
244
  model="gpt-4o",
245
  messages=[
 
251
  )
252
  raw_output = safe_strip(response.choices[0].message.content)
253
  formatted = format_gaia_answer(raw_output, question)
 
254
  if not formatted or formatted.lower() in ('', 'unknown', 'none', 'n/a') or 'apolog' in formatted.lower():
255
  web_info = run_web_search(question)
256
  if web_info: