Update agent.py
Browse files
agent.py
CHANGED
@@ -1,16 +1,12 @@
|
|
1 |
import os
|
2 |
-
import asyncio
|
3 |
import re
|
4 |
from openai import OpenAI as OpenAIClient
|
5 |
-
from llama_index.llms.openai import OpenAI
|
6 |
-
from llama_index.core.agent.react import ReActAgent
|
7 |
-
from llama_index.core.tools import FunctionTool
|
8 |
from duckduckgo_search import DDGS
|
9 |
|
10 |
def duckduckgo_search(query: str) -> str:
|
11 |
try:
|
12 |
with DDGS() as ddg:
|
13 |
-
results = ddg.text(query=query, region="wt-wt", max_results=5)
|
14 |
return "\n".join(r.get('body', '') for r in results if r.get('body'))
|
15 |
except Exception as e:
|
16 |
return f"ERROR: {e}"
|
@@ -65,28 +61,51 @@ def format_gaia_answer(answer: str, question: str = "") -> str:
|
|
65 |
return ', '.join(items)
|
66 |
return ans.strip().rstrip('.')
|
67 |
|
68 |
-
llm = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) # no model= param
|
69 |
-
|
70 |
-
tools = [
|
71 |
-
FunctionTool.from_defaults(duckduckgo_search, name="duckduckgo_search", description="Web search tool"),
|
72 |
-
FunctionTool.from_defaults(eval_python_code, name="python_eval", description="Evaluate Python code"),
|
73 |
-
FunctionTool.from_defaults(format_gaia_answer, name="format_gaia_answer", description="GAIA output formatting")
|
74 |
-
]
|
75 |
-
|
76 |
-
agent = ReActAgent.from_tools(
|
77 |
-
tools=tools,
|
78 |
-
llm=llm,
|
79 |
-
system_prompt="You're a GAIA benchmark agent. Always use tools and strictly output only the final answer in GAIA format—no apologies or explanations.",
|
80 |
-
verbose=False
|
81 |
-
)
|
82 |
-
|
83 |
-
async def answer_question(question: str, task_id: str = None, file_path: str = None) -> str:
|
84 |
-
resp = await agent.achat(question)
|
85 |
-
return resp.response
|
86 |
-
|
87 |
-
def answer_question_sync(question: str, task_id: str = None, file_path: str = None) -> str:
|
88 |
-
return asyncio.run(answer_question(question, task_id, file_path))
|
89 |
-
|
90 |
class GaiaAgent:
|
|
|
|
|
|
|
91 |
def __call__(self, question: str, task_id: str = None) -> str:
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
|
|
2 |
import re
|
3 |
from openai import OpenAI as OpenAIClient
|
|
|
|
|
|
|
4 |
from duckduckgo_search import DDGS
|
5 |
|
6 |
def duckduckgo_search(query: str) -> str:
    """Run a DuckDuckGo text search and return the result snippets.

    Args:
        query: Free-text search query.

    Returns:
        The ``body`` fields of the hits joined with newlines, an empty
        string when the search yields no results, or an ``"ERROR: ..."``
        string when the underlying search call raises.
    """
    try:
        with DDGS() as ddg:
            # NOTE(review): `query=` matches the newer ddgs API; some older
            # duckduckgo_search releases name this parameter `keywords` —
            # confirm against the pinned library version.
            results = ddg.text(query=query, region="wt-wt", max_results=5)
            # ddg.text may return None; treat that as "no results" rather
            # than letting the join raise and surface as an ERROR string.
            return "\n".join(r.get('body', '') for r in results or [] if r.get('body'))
    except Exception as e:
        # Best-effort tool: report failures as text so the agent can continue.
        return f"ERROR: {e}"
|
|
|
61 |
return ', '.join(items)
|
62 |
return ans.strip().rstrip('.')
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
class GaiaAgent:
    """Keyword-routing GAIA benchmark agent.

    Each question is routed to one of three strategies — web search +
    LLM synthesis, embedded-Python evaluation, or a direct LLM call —
    and the result is normalized with ``format_gaia_answer``.
    """

    # Keywords suggesting the question needs a web lookup.
    _SEARCH_KEYWORDS = (
        "who", "when", "what", "which", "how many", "number", "name",
        "albums", "surname", "at bats", "nasa", "city", "winner", "code",
    )

    def __init__(self):
        # Raw OpenAI client; the API key is read from the environment.
        self.llm = OpenAIClient(api_key=os.getenv("OPENAI_API_KEY"))

    def _ask_llm(self, system_prompt: str, user_prompt: str) -> str:
        """Run one deterministic gpt-4o chat completion and return the stripped text."""
        response = self.llm.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.0,
            max_tokens=256,
        )
        return response.choices[0].message.content.strip()

    def _search_and_answer(self, question: str) -> str:
        """Web-search the question, then have the LLM answer from the results."""
        web_result = duckduckgo_search(question)
        return self._ask_llm(
            "You are a research assistant. Based on the following web search results and question, answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations.",
            f"Web search results:\n{web_result}\n\nQuestion: {question}",
        )

    def __call__(self, question: str, task_id: str = None) -> str:
        """Answer a GAIA question.

        Args:
            question: The benchmark question text.
            task_id: Accepted for interface compatibility; unused here.

        Returns:
            The GAIA-formatted answer string.
        """
        lowered = question.lower()

        # Route to tools by keyword.
        # NOTE(review): this branch is evaluated before the Python-code
        # branch, so a code question containing e.g. "what" or "output"
        # is routed to web search — original routing order preserved.
        if any(kw in lowered for kw in self._SEARCH_KEYWORDS):
            return format_gaia_answer(self._search_and_answer(question), question)

        # Code/math: execute an embedded ```python``` block.
        if "output" in lowered and "python" in lowered:
            code_match = re.search(r'```python(.*?)```', question, re.DOTALL)
            code = code_match.group(1) if code_match else ""
            return format_gaia_answer(eval_python_code(code), question)

        # List/ingredients/vegetables: same search-then-answer path.
        if "list" in lowered or "ingredient" in lowered or "vegetable" in lowered:
            return format_gaia_answer(self._search_and_answer(question), question)

        # Fallback: direct LLM answer with no search context.
        llm_answer = self._ask_llm(
            "You are a research assistant. Answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations.",
            question,
        )
        return format_gaia_answer(llm_answer, question)
|