dawid-lorek commited on
Commit
b7f42a2
·
verified ·
1 Parent(s): 25d6735

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +32 -37
agent.py CHANGED
@@ -1,28 +1,28 @@
1
  import os
2
  import asyncio
3
  import re
4
- from openai import OpenAI
 
5
  from llama_index.core.agent.react import ReActAgent
6
  from llama_index.core.tools import FunctionTool
7
  from duckduckgo_search import DDGS
8
 
9
- # --- Custom DuckDuckGo Search Tool ---
10
  class DuckDuckGoSearchTool:
11
  def __init__(self):
12
  self.metadata = {
13
  "name": "duckduckgo_search",
14
- "description": "Search web via DuckDuckGo and return brief summaries."
15
  }
16
  def __call__(self, query: str) -> str:
17
  try:
18
  with DDGS() as ddg:
19
  results = ddg.text(query=query, region="wt-wt", max_results=3)
20
- return "\n".join(r.get('body', '') for r in results if r.get('body'))
21
  except Exception as e:
22
  return f"ERROR: {e}"
23
 
24
- # --- Other Tools ---
25
-
26
  def eval_python_code(code: str) -> str:
27
  try:
28
  return str(eval(code, {"__builtins__": {}}))
@@ -32,36 +32,32 @@ def eval_python_code(code: str) -> str:
32
  def format_gaia_answer(answer: str, question: str = "") -> str:
33
  if not answer:
34
  return ""
35
- answer = re.sub(r'(?i)final answer:?\s*', '', answer).strip()
36
- answer = re.sub(r'(?i)i(\'?m| cannot| can\'t| unable| apologize| not available|process the file).*', '', answer).strip()
37
- if answer.startswith('"') and answer.endswith('"'):
38
- answer = answer[1:-1]
39
- if answer.startswith('[') and answer.endswith(']'):
40
- answer = answer[1:-1]
41
- if not re.match(r'^[A-Za-z]+\.$', answer):
42
- answer = re.sub(r'\.$', '', answer)
43
 
44
  if question:
45
- num_q = re.search(r'how many|number of|at bats|total sales|albums|output.*python|highest number', question, re.I)
46
- list_q = re.search(r'list|comma.*separated|page numbers', question, re.I)
47
- if num_q:
48
- num = re.search(r'(\$?\d[\d,\.]*)', answer)
49
- if num:
50
- return num.group(1).replace(',', '')
51
  if 'first name' in question:
52
- return answer.split()[0]
53
  if 'surname' in question:
54
- return answer.split()[-1]
55
  if 'city' in question:
56
- return answer.split()[0]
57
  if re.search(r'IOC country code|award number|NASA', question, re.I):
58
- code = re.search(r'[A-Z0-9]{3,}', answer)
59
- if code:
60
- return code.group(0)
61
- if list_q:
62
- items = [x.strip('",.').lower() for x in re.split(r'[,\n]', answer) if x.strip()]
63
  if 'page numbers' in question:
64
- nums = sorted(int(x) for x in re.findall(r'\d+', answer))
65
  return ', '.join(str(n) for n in nums)
66
  if 'ingredient' in question or 'vegetable' in question or 'grocery' in question:
67
  merged, skip = [], False
@@ -76,28 +72,27 @@ def format_gaia_answer(answer: str, question: str = "") -> str:
76
  merged.append(x)
77
  return ', '.join(sorted(set(merged)))
78
  return ', '.join(items)
 
79
 
80
- return answer.strip().rstrip('.')
81
-
82
- # --- LLM & Tools Setup ---
83
  llm = OpenAI(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))
84
 
85
  tools = [
86
- FunctionTool.from_defaults(DuckDuckGoSearchTool(), name="duckduckgo_search", description="Searches the web via DuckDuckGo"),
87
  FunctionTool.from_defaults(eval_python_code, name="python_eval", description="Evaluate Python code"),
88
- FunctionTool.from_defaults(format_gaia_answer, name="format_gaia_answer", description="Strict GAIA output formatting")
89
  ]
90
 
91
  agent = ReActAgent.from_tools(
92
  tools=tools,
93
  llm=llm,
94
- system_prompt="You're a GAIA benchmark agent. Use tools and always output only the final answer in strict format—no explanation or apology.",
95
  verbose=False
96
  )
97
 
98
  async def answer_question(question: str, task_id: str = None, file_path: str = None) -> str:
99
- result = await agent.achat(question)
100
- return result.response
101
 
102
  def answer_question_sync(question: str, task_id: str = None, file_path: str = None) -> str:
103
  return asyncio.run(answer_question(question, task_id, file_path))
 
1
  import os
2
  import asyncio
3
  import re
4
+ from openai import OpenAI as OpenAIClient
5
+ from llama_index.llms.openai import OpenAI
6
  from llama_index.core.agent.react import ReActAgent
7
  from llama_index.core.tools import FunctionTool
8
  from duckduckgo_search import DDGS
9
 
10
+ # --- Custom DuckDuckGo Tool ---
11
  class DuckDuckGoSearchTool:
12
  def __init__(self):
13
  self.metadata = {
14
  "name": "duckduckgo_search",
15
+ "description": "Search the web via DuckDuckGo."
16
  }
17
  def __call__(self, query: str) -> str:
18
  try:
19
  with DDGS() as ddg:
20
  results = ddg.text(query=query, region="wt-wt", max_results=3)
21
+ return "\n".join(r.get('body','') for r in results if r.get('body'))
22
  except Exception as e:
23
  return f"ERROR: {e}"
24
 
25
+ # --- Tools ---
 
26
  def eval_python_code(code: str) -> str:
27
  try:
28
  return str(eval(code, {"__builtins__": {}}))
 
32
  def format_gaia_answer(answer: str, question: str = "") -> str:
33
  if not answer:
34
  return ""
35
+ ans = re.sub(r'(?i)final answer:?\s*', '', answer).strip()
36
+ ans = re.sub(r'(?i)i(\'?m| cannot| can\'t| unable| apologize| not available).*', '', ans).strip()
37
+ if ans.startswith('"') and ans.endswith('"'):
38
+ ans = ans[1:-1]
39
+ if ans.startswith('[') and ans.endswith(']'):
40
+ ans = ans[1:-1]
41
+ if not re.match(r'^[A-Za-z]+\.$', ans):
42
+ ans = re.sub(r'\.$', '', ans)
43
 
44
  if question:
45
+ if re.search(r'how many|number of|at bats|total sales|albums|output.*python|highest number', question, re.I):
46
+ m = re.search(r'(\$?\d[\d,\.]*)', ans)
47
+ if m: return m.group(1).replace(',', '')
 
 
 
48
  if 'first name' in question:
49
+ return ans.split()[0]
50
  if 'surname' in question:
51
+ return ans.split()[-1]
52
  if 'city' in question:
53
+ return ans.split()[0]
54
  if re.search(r'IOC country code|award number|NASA', question, re.I):
55
+ c = re.search(r'[A-Z0-9]{3,}', ans)
56
+ if c: return c.group(0)
57
+ if re.search(r'list|comma.*separated|page numbers', question, re.I):
58
+ items = [x.strip('",.').lower() for x in re.split(r'[,\n]', ans) if x.strip()]
 
59
  if 'page numbers' in question:
60
+ nums = sorted(int(x) for x in re.findall(r'\d+', ans))
61
  return ', '.join(str(n) for n in nums)
62
  if 'ingredient' in question or 'vegetable' in question or 'grocery' in question:
63
  merged, skip = [], False
 
72
  merged.append(x)
73
  return ', '.join(sorted(set(merged)))
74
  return ', '.join(items)
75
+ return ans.strip().rstrip('.')
76
 
77
+ # --- LLM setup ---
 
 
78
  llm = OpenAI(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))
79
 
80
  tools = [
81
+ FunctionTool.from_defaults(DuckDuckGoSearchTool(), name="duckduckgo_search", description="Web search tool"),
82
  FunctionTool.from_defaults(eval_python_code, name="python_eval", description="Evaluate Python code"),
83
+ FunctionTool.from_defaults(format_gaia_answer, name="format_gaia_answer", description="GAIA output formatting")
84
  ]
85
 
86
  agent = ReActAgent.from_tools(
87
  tools=tools,
88
  llm=llm,
89
+ system_prompt="You're a GAIA benchmark agent. Always use tools and strictly output only the final answer in GAIA format—no apologies or explanations.",
90
  verbose=False
91
  )
92
 
93
  async def answer_question(question: str, task_id: str = None, file_path: str = None) -> str:
94
+ resp = await agent.achat(question)
95
+ return resp.response
96
 
97
  def answer_question_sync(question: str, task_id: str = None, file_path: str = None) -> str:
98
  return asyncio.run(answer_question(question, task_id, file_path))