Update agent.py
Browse files
agent.py
CHANGED
@@ -2,9 +2,8 @@ import os
|
|
2 |
import io
|
3 |
import re
|
4 |
import requests
|
5 |
-
import mimetypes
|
6 |
-
import subprocess
|
7 |
import tempfile
|
|
|
8 |
from openai import OpenAI
|
9 |
from duckduckgo_search import DDGS
|
10 |
from PIL import Image
|
@@ -34,15 +33,15 @@ def run_web_search(query, max_results=3):
|
|
34 |
try:
|
35 |
ddgs = DDGS()
|
36 |
results = ddgs.text(query)
|
|
|
37 |
for i, r in enumerate(results):
|
38 |
if i >= max_results:
|
39 |
break
|
40 |
-
# Prefer summary/body if available
|
41 |
if r.get('body'):
|
42 |
-
|
43 |
elif r.get('title'):
|
44 |
-
|
45 |
-
return ""
|
46 |
except Exception:
|
47 |
return ""
|
48 |
|
@@ -134,63 +133,38 @@ def guess_youtube_link(question):
|
|
134 |
return None
|
135 |
|
136 |
def format_gaia_answer(answer, question=None):
|
137 |
-
"""Enforces strict GAIA benchmark answer formatting rules."""
|
138 |
if not answer or not isinstance(answer, str):
|
139 |
return ""
|
140 |
-
|
141 |
-
# Remove apologies and boilerplate
|
142 |
answer = re.sub(r"(?i)i'?m sorry[,\.]?|i cannot|i can't|unable to|please provide.*|information not available|I can't assist.*|I'm unable.*|process the file directly", "", answer)
|
143 |
-
answer = answer.strip()
|
144 |
-
|
145 |
-
# Remove "Final Answer:" and similar prefixes
|
146 |
answer = re.sub(r'(?i)final answer:?\s*', '', answer).strip()
|
147 |
-
|
148 |
-
# Remove enclosing quotes/brackets
|
149 |
if answer.startswith('"') and answer.endswith('"'):
|
150 |
answer = answer[1:-1]
|
151 |
if answer.startswith('[') and answer.endswith(']'):
|
152 |
answer = answer[1:-1]
|
153 |
-
|
154 |
-
# Remove period at end unless part of the answer (like "Indeed.")
|
155 |
if not re.match(r'^[A-Za-z]+\.$', answer):
|
156 |
answer = re.sub(r'\.$', '', answer)
|
157 |
|
158 |
-
# For specific answer types:
|
159 |
if question:
|
160 |
-
#
|
161 |
if re.search(r'how many|number of|at bats|total sales|albums|output.*python|highest number', question, re.I):
|
162 |
num_match = re.search(r'(\$?\d[\d,\.]*)', answer)
|
163 |
if num_match:
|
164 |
return num_match.group(1).replace(',', '')
|
165 |
|
166 |
-
# Only first name (e.g. Malko, Magda M)
|
167 |
if re.search(r'first name', question, re.I):
|
168 |
-
|
169 |
-
return first
|
170 |
-
|
171 |
-
# Only surname
|
172 |
if re.search(r'surname', question, re.I):
|
173 |
-
|
174 |
-
return surname
|
175 |
-
|
176 |
-
# Only city
|
177 |
if re.search(r'city', question, re.I):
|
178 |
-
|
179 |
-
return city
|
180 |
-
|
181 |
-
# Only code (Olympics, NASA award)
|
182 |
if re.search(r'IOC country code|award number|NASA', question, re.I):
|
183 |
code_match = re.search(r'[A-Z0-9]{3,}', answer)
|
184 |
if code_match:
|
185 |
return code_match.group(0)
|
186 |
-
|
187 |
-
# Only algebraic move (chess)
|
188 |
if 'algebraic notation' in question or 'chess' in question:
|
189 |
move_match = re.search(r'[A-Za-z0-9]+[#\+]?$', answer)
|
190 |
if move_match:
|
191 |
return move_match.group(0)
|
192 |
-
|
193 |
-
# Direct quote (Teal'c)
|
194 |
if "what does teal'c say" in question.lower():
|
195 |
qmatch = re.search(r'"(Indeed\.)"', answer)
|
196 |
if qmatch:
|
@@ -198,27 +172,19 @@ def format_gaia_answer(answer, question=None):
|
|
198 |
if "Indeed." in answer:
|
199 |
return "Indeed."
|
200 |
return answer
|
201 |
-
|
202 |
-
# For lists (ingredients, vegetables, page numbers, etc)
|
203 |
if re.search(r'list|comma.*separated|page numbers', question, re.I):
|
204 |
-
# Extract all possible meaningful phrases
|
205 |
items = [x.strip('",.').lower() for x in re.split(r'[,\n]', answer) if x.strip()]
|
206 |
-
# Remove likely non-items (like "and", "or", etc.)
|
207 |
items = [item for item in items if item and not re.match(r'(and|or|to|with|for|a|the)$', item)]
|
208 |
-
# For page numbers, sort as int
|
209 |
if 'page numbers' in question:
|
210 |
nums = [int(x) for x in re.findall(r'\d+', answer)]
|
211 |
return ', '.join(str(n) for n in sorted(nums))
|
212 |
-
# For vegetables, ingredients, etc. sort alpha
|
213 |
if 'ingredient' in question or 'vegetable' in question or 'grocery' in question:
|
214 |
-
# merge multi-word items split by commas (heuristic)
|
215 |
merged = []
|
216 |
skip = False
|
217 |
for i, item in enumerate(items):
|
218 |
if skip:
|
219 |
skip = False
|
220 |
continue
|
221 |
-
# Try to merge known phrases (e.g., "sweet potatoes", "green beans", etc.)
|
222 |
if i+1 < len(items) and item in ['sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh']:
|
223 |
merged.append(f"{item} {items[i+1]}")
|
224 |
skip = True
|
@@ -227,13 +193,10 @@ def format_gaia_answer(answer, question=None):
|
|
227 |
merged = sorted(set(merged))
|
228 |
return ', '.join(merged)
|
229 |
return ', '.join(items)
|
230 |
-
|
231 |
-
# Only last names for pitchers (before/after)
|
232 |
if re.search(r'pitcher.*before.*after', question, re.I):
|
233 |
names = re.findall(r'\b[A-Z][a-z]+', answer)
|
234 |
return ', '.join(names[:2])
|
235 |
|
236 |
-
# Generic fallback
|
237 |
return answer.strip().rstrip('.').strip()
|
238 |
|
239 |
class GaiaAgent:
|
@@ -249,10 +212,10 @@ class GaiaAgent:
|
|
249 |
"Always output the answer only—no explanations, no extra text."
|
250 |
)
|
251 |
|
252 |
-
def
|
253 |
file_text = ""
|
254 |
prompt_parts = [self.instructions]
|
255 |
-
# 1. File
|
256 |
if task_id:
|
257 |
file_bytes, content_type = fetch_file(task_id)
|
258 |
if file_bytes and content_type:
|
@@ -265,7 +228,7 @@ class GaiaAgent:
|
|
265 |
transcript = transcribe_youtube_audio(youtube_url)
|
266 |
if transcript:
|
267 |
prompt_parts.append(f"Here is the transcript of the video:\n{transcript}\n")
|
268 |
-
# 3. Web search
|
269 |
search_needed = not file_text and not youtube_url
|
270 |
search_keywords = [
|
271 |
"who", "what", "when", "where", "name", "number", "how many",
|
@@ -275,13 +238,8 @@ class GaiaAgent:
|
|
275 |
search_results = run_web_search(question)
|
276 |
if search_results:
|
277 |
prompt_parts.append(f"Here are relevant web search results:\n{search_results}\n")
|
278 |
-
# 4. Compose prompt
|
279 |
prompt_parts.append(f"Question: {question}\nAnswer strictly and concisely.")
|
280 |
prompt = "\n".join(prompt_parts)
|
281 |
-
return prompt
|
282 |
-
|
283 |
-
def __call__(self, question: str, task_id: str = None) -> str:
|
284 |
-
prompt = self.answer_with_tools(question, task_id)
|
285 |
response = self.client.chat.completions.create(
|
286 |
model="gpt-4o",
|
287 |
messages=[
|
@@ -293,7 +251,6 @@ class GaiaAgent:
|
|
293 |
)
|
294 |
raw_output = safe_strip(response.choices[0].message.content)
|
295 |
formatted = format_gaia_answer(raw_output, question)
|
296 |
-
# Retry with web search if result is empty or likely incorrect for key factual types
|
297 |
if not formatted or formatted.lower() in ('', 'unknown', 'none', 'n/a') or 'apolog' in formatted.lower():
|
298 |
web_info = run_web_search(question)
|
299 |
if web_info:
|
|
|
2 |
import io
|
3 |
import re
|
4 |
import requests
|
|
|
|
|
5 |
import tempfile
|
6 |
+
import subprocess
|
7 |
from openai import OpenAI
|
8 |
from duckduckgo_search import DDGS
|
9 |
from PIL import Image
|
|
|
33 |
try:
|
34 |
ddgs = DDGS()
|
35 |
results = ddgs.text(query)
|
36 |
+
bodies = []
|
37 |
for i, r in enumerate(results):
|
38 |
if i >= max_results:
|
39 |
break
|
|
|
40 |
if r.get('body'):
|
41 |
+
bodies.append(r['body'])
|
42 |
elif r.get('title'):
|
43 |
+
bodies.append(r['title'])
|
44 |
+
return "\n".join(bodies)
|
45 |
except Exception:
|
46 |
return ""
|
47 |
|
|
|
133 |
return None
|
134 |
|
135 |
def format_gaia_answer(answer, question=None):
|
|
|
136 |
if not answer or not isinstance(answer, str):
|
137 |
return ""
|
|
|
|
|
138 |
answer = re.sub(r"(?i)i'?m sorry[,\.]?|i cannot|i can't|unable to|please provide.*|information not available|I can't assist.*|I'm unable.*|process the file directly", "", answer)
|
|
|
|
|
|
|
139 |
answer = re.sub(r'(?i)final answer:?\s*', '', answer).strip()
|
|
|
|
|
140 |
if answer.startswith('"') and answer.endswith('"'):
|
141 |
answer = answer[1:-1]
|
142 |
if answer.startswith('[') and answer.endswith(']'):
|
143 |
answer = answer[1:-1]
|
|
|
|
|
144 |
if not re.match(r'^[A-Za-z]+\.$', answer):
|
145 |
answer = re.sub(r'\.$', '', answer)
|
146 |
|
|
|
147 |
if question:
|
148 |
+
# Pure number answers
|
149 |
if re.search(r'how many|number of|at bats|total sales|albums|output.*python|highest number', question, re.I):
|
150 |
num_match = re.search(r'(\$?\d[\d,\.]*)', answer)
|
151 |
if num_match:
|
152 |
return num_match.group(1).replace(',', '')
|
153 |
|
|
|
154 |
if re.search(r'first name', question, re.I):
|
155 |
+
return answer.strip().split()[0]
|
|
|
|
|
|
|
156 |
if re.search(r'surname', question, re.I):
|
157 |
+
return answer.strip().split()[-1]
|
|
|
|
|
|
|
158 |
if re.search(r'city', question, re.I):
|
159 |
+
return answer.strip().split()[0]
|
|
|
|
|
|
|
160 |
if re.search(r'IOC country code|award number|NASA', question, re.I):
|
161 |
code_match = re.search(r'[A-Z0-9]{3,}', answer)
|
162 |
if code_match:
|
163 |
return code_match.group(0)
|
|
|
|
|
164 |
if 'algebraic notation' in question or 'chess' in question:
|
165 |
move_match = re.search(r'[A-Za-z0-9]+[#\+]?$', answer)
|
166 |
if move_match:
|
167 |
return move_match.group(0)
|
|
|
|
|
168 |
if "what does teal'c say" in question.lower():
|
169 |
qmatch = re.search(r'"(Indeed\.)"', answer)
|
170 |
if qmatch:
|
|
|
172 |
if "Indeed." in answer:
|
173 |
return "Indeed."
|
174 |
return answer
|
|
|
|
|
175 |
if re.search(r'list|comma.*separated|page numbers', question, re.I):
|
|
|
176 |
items = [x.strip('",.').lower() for x in re.split(r'[,\n]', answer) if x.strip()]
|
|
|
177 |
items = [item for item in items if item and not re.match(r'(and|or|to|with|for|a|the)$', item)]
|
|
|
178 |
if 'page numbers' in question:
|
179 |
nums = [int(x) for x in re.findall(r'\d+', answer)]
|
180 |
return ', '.join(str(n) for n in sorted(nums))
|
|
|
181 |
if 'ingredient' in question or 'vegetable' in question or 'grocery' in question:
|
|
|
182 |
merged = []
|
183 |
skip = False
|
184 |
for i, item in enumerate(items):
|
185 |
if skip:
|
186 |
skip = False
|
187 |
continue
|
|
|
188 |
if i+1 < len(items) and item in ['sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh']:
|
189 |
merged.append(f"{item} {items[i+1]}")
|
190 |
skip = True
|
|
|
193 |
merged = sorted(set(merged))
|
194 |
return ', '.join(merged)
|
195 |
return ', '.join(items)
|
|
|
|
|
196 |
if re.search(r'pitcher.*before.*after', question, re.I):
|
197 |
names = re.findall(r'\b[A-Z][a-z]+', answer)
|
198 |
return ', '.join(names[:2])
|
199 |
|
|
|
200 |
return answer.strip().rstrip('.').strip()
|
201 |
|
202 |
class GaiaAgent:
|
|
|
212 |
"Always output the answer only—no explanations, no extra text."
|
213 |
)
|
214 |
|
215 |
+
def __call__(self, question: str, task_id: str = None) -> str:
|
216 |
file_text = ""
|
217 |
prompt_parts = [self.instructions]
|
218 |
+
# 1. File (image, Excel, etc)
|
219 |
if task_id:
|
220 |
file_bytes, content_type = fetch_file(task_id)
|
221 |
if file_bytes and content_type:
|
|
|
228 |
transcript = transcribe_youtube_audio(youtube_url)
|
229 |
if transcript:
|
230 |
prompt_parts.append(f"Here is the transcript of the video:\n{transcript}\n")
|
231 |
+
# 3. Web search for open facts
|
232 |
search_needed = not file_text and not youtube_url
|
233 |
search_keywords = [
|
234 |
"who", "what", "when", "where", "name", "number", "how many",
|
|
|
238 |
search_results = run_web_search(question)
|
239 |
if search_results:
|
240 |
prompt_parts.append(f"Here are relevant web search results:\n{search_results}\n")
|
|
|
241 |
prompt_parts.append(f"Question: {question}\nAnswer strictly and concisely.")
|
242 |
prompt = "\n".join(prompt_parts)
|
|
|
|
|
|
|
|
|
243 |
response = self.client.chat.completions.create(
|
244 |
model="gpt-4o",
|
245 |
messages=[
|
|
|
251 |
)
|
252 |
raw_output = safe_strip(response.choices[0].message.content)
|
253 |
formatted = format_gaia_answer(raw_output, question)
|
|
|
254 |
if not formatted or formatted.lower() in ('', 'unknown', 'none', 'n/a') or 'apolog' in formatted.lower():
|
255 |
web_info = run_web_search(question)
|
256 |
if web_info:
|