Update agent.py
Browse files
agent.py
CHANGED
@@ -8,7 +8,6 @@ def duckduckgo_search(query: str) -> str:
|
|
8 |
with DDGS() as ddg:
|
9 |
results = ddg.text(query=query, region="wt-wt", max_results=5)
|
10 |
bodies = [r.get('body', '') for r in results if r.get('body')]
|
11 |
-
# For GAIA, prefer the first non-empty, or join a few if possible
|
12 |
return "\n".join(bodies[:3])
|
13 |
except Exception as e:
|
14 |
return f"ERROR: {e}"
|
@@ -20,21 +19,24 @@ def eval_python_code(code: str) -> str:
|
|
20 |
return f"ERROR: {e}"
|
21 |
|
22 |
def format_gaia_answer(answer: str, question: str = "") -> str:
|
23 |
-
"""
|
24 |
if not answer:
|
25 |
return ""
|
26 |
-
# Remove
|
27 |
answer = re.sub(
|
28 |
-
r'(?i)(
|
29 |
'', answer).strip()
|
30 |
-
# Remove
|
31 |
-
|
|
|
|
|
|
|
32 |
# Only numbers for count questions
|
33 |
if re.search(r'how many|number of|albums|at bats|total sales|output', question, re.I):
|
34 |
match = re.search(r'(\d+)', answer)
|
35 |
if match:
|
36 |
return match.group(1)
|
37 |
-
# Only
|
38 |
if "surname" in question:
|
39 |
return answer.split()[-1]
|
40 |
if "first name" in question:
|
@@ -48,7 +50,7 @@ def format_gaia_answer(answer: str, question: str = "") -> str:
|
|
48 |
code = re.search(r'[A-Z0-9]{3,}', answer)
|
49 |
if code:
|
50 |
return code.group(0)
|
51 |
-
# For lists:
|
52 |
if "list" in question or "ingredient" in question or "vegetable" in question:
|
53 |
items = [x.strip(' "\'') for x in re.split(r'[,\n]', answer) if x.strip()]
|
54 |
merged = []
|
@@ -65,59 +67,68 @@ def format_gaia_answer(answer: str, question: str = "") -> str:
|
|
65 |
merged = [x.lower() for x in merged]
|
66 |
merged = sorted(set(merged))
|
67 |
return ', '.join(merged)
|
68 |
-
# For chess: algebraic move
|
69 |
if "algebraic notation" in question or "chess" in question:
|
70 |
move = re.findall(r'[KQRBN]?[a-h]?[1-8]?x?[a-h][1-8][+#]?', answer)
|
71 |
if move:
|
72 |
return move[-1]
|
73 |
-
|
74 |
-
answer = answer.split('\n')[0].split('.')[0].strip()
|
75 |
-
return answer
|
76 |
|
77 |
class GaiaAgent:
|
78 |
def __init__(self):
|
79 |
self.llm = OpenAIClient(api_key=os.getenv("OPENAI_API_KEY"))
|
80 |
|
81 |
def __call__(self, question: str, task_id: str = None) -> str:
|
82 |
-
# 1. Try tool-based search for all fact/list/code questions
|
83 |
-
ql = question.lower()
|
84 |
-
# Try to route every "who", "what", "number", "albums", "at bats", "surname", etc. to web search
|
85 |
search_keywords = [
|
86 |
"who", "when", "what", "which", "how many", "number", "name", "albums", "surname", "at bats",
|
87 |
"nasa", "city", "winner", "code", "vegetable", "ingredient", "magda m.", "featured article"
|
88 |
]
|
89 |
-
|
|
|
90 |
web_result = duckduckgo_search(question)
|
91 |
llm_answer = self.llm.chat.completions.create(
|
92 |
model="gpt-4o",
|
93 |
messages=[
|
94 |
-
{"role": "system", "content": "You are a research assistant. Based on the
|
95 |
{"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
|
96 |
],
|
97 |
temperature=0.0,
|
98 |
max_tokens=256,
|
99 |
).choices[0].message.content.strip()
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
code_match = re.search(r'```python(.*?)```', question, re.DOTALL)
|
104 |
code = code_match.group(1) if code_match else ""
|
105 |
result = eval_python_code(code)
|
106 |
return format_gaia_answer(result, question)
|
107 |
-
#
|
108 |
-
if "list" in
|
109 |
web_result = duckduckgo_search(question)
|
110 |
llm_answer = self.llm.chat.completions.create(
|
111 |
model="gpt-4o",
|
112 |
messages=[
|
113 |
-
{"role": "system", "content": "You are a research assistant. Based on the
|
114 |
{"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
|
115 |
],
|
116 |
temperature=0.0,
|
117 |
max_tokens=256,
|
118 |
).choices[0].message.content.strip()
|
119 |
return format_gaia_answer(llm_answer, question)
|
120 |
-
#
|
121 |
llm_answer = self.llm.chat.completions.create(
|
122 |
model="gpt-4o",
|
123 |
messages=[
|
@@ -125,6 +136,6 @@ class GaiaAgent:
|
|
125 |
{"role": "user", "content": question}
|
126 |
],
|
127 |
temperature=0.0,
|
128 |
-
max_tokens=
|
129 |
).choices[0].message.content.strip()
|
130 |
return format_gaia_answer(llm_answer, question)
|
|
|
8 |
with DDGS() as ddg:
|
9 |
results = ddg.text(query=query, region="wt-wt", max_results=5)
|
10 |
bodies = [r.get('body', '') for r in results if r.get('body')]
|
|
|
11 |
return "\n".join(bodies[:3])
|
12 |
except Exception as e:
|
13 |
return f"ERROR: {e}"
|
|
|
19 |
return f"ERROR: {e}"
|
20 |
|
21 |
def format_gaia_answer(answer: str, question: str = "") -> str:
|
22 |
+
"""Strict GAIA output, eliminate apologies, extract only answer value."""
|
23 |
if not answer:
|
24 |
return ""
|
25 |
+
# Remove apologies and anything after
|
26 |
answer = re.sub(
|
27 |
+
r'(?i)(I[\' ]?m sorry.*|Unfortunately.*|I cannot.*|I am unable.*|error:.*|no file.*|but.*|however.*|unable to.*|not available.*|if you have access.*|I can\'t.*)',
|
28 |
'', answer).strip()
|
29 |
+
# Remove everything after the first period if it's not a list
|
30 |
+
if not ("list" in question or "ingredient" in question or "vegetable" in question):
|
31 |
+
answer = answer.split('\n')[0].split('.')[0]
|
32 |
+
# Remove quotes/brackets
|
33 |
+
answer = answer.strip(' "\'[](),;:')
|
34 |
# Only numbers for count questions
|
35 |
if re.search(r'how many|number of|albums|at bats|total sales|output', question, re.I):
|
36 |
match = re.search(r'(\d+)', answer)
|
37 |
if match:
|
38 |
return match.group(1)
|
39 |
+
# Only last word for "surname", first for "first name"
|
40 |
if "surname" in question:
|
41 |
return answer.split()[-1]
|
42 |
if "first name" in question:
|
|
|
50 |
code = re.search(r'[A-Z0-9]{3,}', answer)
|
51 |
if code:
|
52 |
return code.group(0)
|
53 |
+
# For lists: comma-separated, alpha, deduped, merged phrases
|
54 |
if "list" in question or "ingredient" in question or "vegetable" in question:
|
55 |
items = [x.strip(' "\'') for x in re.split(r'[,\n]', answer) if x.strip()]
|
56 |
merged = []
|
|
|
67 |
merged = [x.lower() for x in merged]
|
68 |
merged = sorted(set(merged))
|
69 |
return ', '.join(merged)
|
70 |
+
# For chess: algebraic move
|
71 |
if "algebraic notation" in question or "chess" in question:
|
72 |
move = re.findall(r'[KQRBN]?[a-h]?[1-8]?x?[a-h][1-8][+#]?', answer)
|
73 |
if move:
|
74 |
return move[-1]
|
75 |
+
return answer.strip(' "\'[](),;:')
|
|
|
|
|
76 |
|
77 |
class GaiaAgent:
|
78 |
def __init__(self):
|
79 |
self.llm = OpenAIClient(api_key=os.getenv("OPENAI_API_KEY"))
|
80 |
|
81 |
def __call__(self, question: str, task_id: str = None) -> str:
|
|
|
|
|
|
|
82 |
search_keywords = [
|
83 |
"who", "when", "what", "which", "how many", "number", "name", "albums", "surname", "at bats",
|
84 |
"nasa", "city", "winner", "code", "vegetable", "ingredient", "magda m.", "featured article"
|
85 |
]
|
86 |
+
needs_search = any(kw in question.lower() for kw in search_keywords)
|
87 |
+
if needs_search:
|
88 |
web_result = duckduckgo_search(question)
|
89 |
llm_answer = self.llm.chat.completions.create(
|
90 |
model="gpt-4o",
|
91 |
messages=[
|
92 |
+
{"role": "system", "content": "You are a research assistant. Based on the web search results and question, answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations or apologies."},
|
93 |
{"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
|
94 |
],
|
95 |
temperature=0.0,
|
96 |
max_tokens=256,
|
97 |
).choices[0].message.content.strip()
|
98 |
+
formatted = format_gaia_answer(llm_answer, question)
|
99 |
+
# Retry if apology/empty/incorrect
|
100 |
+
if not formatted or "sorry" in formatted.lower() or "unable" in formatted.lower():
|
101 |
+
llm_answer2 = self.llm.chat.completions.create(
|
102 |
+
model="gpt-4o",
|
103 |
+
messages=[
|
104 |
+
{"role": "system", "content": "Only answer with the value. No explanation. Do not apologize. Do not begin with 'I'm sorry', 'Unfortunately', or similar."},
|
105 |
+
{"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
|
106 |
+
],
|
107 |
+
temperature=0.0,
|
108 |
+
max_tokens=128,
|
109 |
+
).choices[0].message.content.strip()
|
110 |
+
formatted = format_gaia_answer(llm_answer2, question)
|
111 |
+
return formatted
|
112 |
+
# For code/math output
|
113 |
+
if "output" in question.lower() and "python" in question.lower():
|
114 |
code_match = re.search(r'```python(.*?)```', question, re.DOTALL)
|
115 |
code = code_match.group(1) if code_match else ""
|
116 |
result = eval_python_code(code)
|
117 |
return format_gaia_answer(result, question)
|
118 |
+
# For lists/ingredients, always web search and format
|
119 |
+
if "list" in question.lower() or "ingredient" in question.lower() or "vegetable" in question.lower():
|
120 |
web_result = duckduckgo_search(question)
|
121 |
llm_answer = self.llm.chat.completions.create(
|
122 |
model="gpt-4o",
|
123 |
messages=[
|
124 |
+
{"role": "system", "content": "You are a research assistant. Based on the web search results and question, answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations or apologies."},
|
125 |
{"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
|
126 |
],
|
127 |
temperature=0.0,
|
128 |
max_tokens=256,
|
129 |
).choices[0].message.content.strip()
|
130 |
return format_gaia_answer(llm_answer, question)
|
131 |
+
# Fallback: strict LLM answer, formatted
|
132 |
llm_answer = self.llm.chat.completions.create(
|
133 |
model="gpt-4o",
|
134 |
messages=[
|
|
|
136 |
{"role": "user", "content": question}
|
137 |
],
|
138 |
temperature=0.0,
|
139 |
+
max_tokens=128,
|
140 |
).choices[0].message.content.strip()
|
141 |
return format_gaia_answer(llm_answer, question)
|