dawid-lorek committed on
Commit
a3a06d3
·
verified ·
1 Parent(s): d4f0ca4

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +63 -52
agent.py CHANGED
@@ -1,96 +1,107 @@
1
  import os
2
  import asyncio
3
  import re
4
- from typing import Any
5
-
6
- from llama_index.llms.openai import OpenAI
7
  from llama_index.core.agent.react import ReActAgent
8
  from llama_index.core.tools import FunctionTool
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Correct import for LlamaIndex >= 0.10
11
- from llama_index.tools.duckduckgo_search import DuckDuckGoSearchTool
12
 
13
- # Simple tool: Evaluate Python code for math/code questions
14
  def eval_python_code(code: str) -> str:
15
  try:
16
  return str(eval(code, {"__builtins__": {}}))
17
  except Exception as e:
18
  return f"ERROR: {e}"
19
 
20
- # Strict output formatting for GAIA
21
  def format_gaia_answer(answer: str, question: str = "") -> str:
22
  if not answer:
23
  return ""
24
  answer = re.sub(r'(?i)final answer:?\s*', '', answer).strip()
25
- answer = re.sub(r'(?i)i(\'?m| cannot| can\'t| unable to| apologize| not available|process the file).*', '', answer).strip()
26
- if answer.startswith('"') and answer.endswith('"'): answer = answer[1:-1]
27
- if answer.startswith('[') and answer.endswith(']'): answer = answer[1:-1]
28
- if not re.match(r'^[A-Za-z]+\.$', answer): answer = re.sub(r'\.$', '', answer)
 
 
 
 
29
  if question:
30
- if re.search(r'how many|number of|at bats|total sales|albums|output.*python|highest number', question, re.I):
 
 
31
  num = re.search(r'(\$?\d[\d,\.]*)', answer)
32
- if num: return num.group(1).replace(',', '')
33
- if 'first name' in question: return answer.split()[0]
34
- if 'surname' in question: return answer.split()[-1]
35
- if 'city' in question: return answer.split()[0]
 
 
 
 
36
  if re.search(r'IOC country code|award number|NASA', question, re.I):
37
  code = re.search(r'[A-Z0-9]{3,}', answer)
38
- if code: return code.group(0)
39
- if re.search(r'list|comma.*separated|page numbers', question, re.I):
 
40
  items = [x.strip('",.').lower() for x in re.split(r'[,\n]', answer) if x.strip()]
41
  if 'page numbers' in question:
42
- nums = [int(x) for x in re.findall(r'\d+', answer)]
43
- return ', '.join(str(n) for n in sorted(nums))
44
- if 'ingredient' in question or 'vegetable' in question:
45
- merged = []
46
- skip = False
47
- for i, item in enumerate(items):
48
- if skip: skip = False; continue
49
- if i+1 < len(items) and item in ['sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh']:
50
- merged.append(f"{item} {items[i+1]}")
 
51
  skip = True
52
- else: merged.append(item)
53
- merged = sorted(set(merged))
54
- return ', '.join(merged)
55
  return ', '.join(items)
56
- return answer.strip().rstrip('.').strip()
57
 
58
- # LLM setup
59
- llm = OpenAI(model="gpt-4o", api_key=os.environ.get("OPENAI_API_KEY"))
 
 
60
 
61
- # Tool registry
62
  tools = [
63
- DuckDuckGoSearchTool(),
64
- FunctionTool.from_defaults(
65
- eval_python_code,
66
- name="python_eval",
67
- description="Evaluate simple Python code and return result as string. Use for math or code output."
68
- ),
69
- FunctionTool.from_defaults(
70
- format_gaia_answer,
71
- name="format_gaia_answer",
72
- description="Postprocess and enforce strict GAIA format on answers given a question."
73
- ),
74
  ]
75
 
76
- # Main agent
77
  agent = ReActAgent.from_tools(
78
  tools=tools,
79
  llm=llm,
80
- system_prompt="You are a helpful GAIA benchmark agent. For every question, use the best tools available and always return only the final answer in the strict GAIA-required format—never explain, never apologize.",
81
  verbose=False
82
  )
83
 
84
- # Async entrypoint
85
  async def answer_question(question: str, task_id: str = None, file_path: str = None) -> str:
86
  result = await agent.achat(question)
87
  return result.response
88
 
89
- # Synchronous wrapper
90
  def answer_question_sync(question: str, task_id: str = None, file_path: str = None) -> str:
91
  return asyncio.run(answer_question(question, task_id, file_path))
92
 
93
- # For compatibility with app.py (GAIAAgent class)
94
  class GaiaAgent:
95
- def __call__(self, question: str, task_id: str = None, file_path: str = None) -> str:
96
- return answer_question_sync(question, task_id, file_path)
 
1
  import os
2
  import asyncio
3
  import re
4
+ from openai import OpenAI
 
 
5
  from llama_index.core.agent.react import ReActAgent
6
  from llama_index.core.tools import FunctionTool
7
+ from duckduckgo_search import DDGS
8
+
9
+ # --- Custom DuckDuckGo Search Tool ---
class DuckDuckGoSearchTool:
    """Callable web-search helper backed by DuckDuckGo's ``DDGS`` client.

    An instance is a plain callable taking a query string and returning the
    result snippets joined by newlines, so it can be wrapped as a function
    tool. Failures are reported as an ``"ERROR: ..."`` string instead of
    raising, so a failed search does not abort the agent run.
    """

    def __init__(self):
        # Descriptive metadata stored on the instance.
        self.metadata = {
            "name": "duckduckgo_search",
            "description": "Search web via DuckDuckGo and return brief summaries."
        }

    def __call__(self, query: str) -> str:
        """Return up to three result snippets for *query*, one per line.

        Args:
            query: Free-text search string.

        Returns:
            The non-empty ``body`` fields of the results joined with newlines,
            or ``"ERROR: <message>"`` if the search raised.
        """
        try:
            with DDGS() as ddg:
                # The search string is passed positionally: DDGS.text() names
                # its first parameter ``keywords``, so ``query=...`` raised
                # TypeError and every search came back as an ERROR string.
                results = ddg.text(query, region="wt-wt", max_results=3)
                # ``or []`` keeps the join safe if the client returns None.
                return "\n".join(r.get('body', '') for r in (results or []) if r.get('body'))
        except Exception as e:
            return f"ERROR: {e}"
23
 
24
+ # --- Other Tools ---
 
25
 
 
# Pure, side-effect-free builtins exposed to evaluated expressions. Anything
# not listed here (open, __import__, exec, ...) is unavailable by name.
_SAFE_EVAL_BUILTINS = {
    "abs": abs, "all": all, "any": any, "bool": bool, "dict": dict,
    "divmod": divmod, "enumerate": enumerate, "float": float, "int": int,
    "len": len, "list": list, "max": max, "min": min, "pow": pow,
    "range": range, "reversed": reversed, "round": round, "set": set,
    "sorted": sorted, "str": str, "sum": sum, "tuple": tuple, "zip": zip,
}


def eval_python_code(code: str) -> str:
    """Evaluate a single Python expression and return its result as a string.

    Only expressions are supported (``eval``, not ``exec``); statements such
    as assignments produce an ``"ERROR: ..."`` string. A small whitelist of
    pure builtins and the ``math`` module are available, so ordinary
    arithmetic like ``sum(range(5))`` or ``math.sqrt(16)`` works.

    SECURITY NOTE: ``eval`` with a restricted ``__builtins__`` is NOT a
    sandbox. Attribute-walking tricks can still reach arbitrary objects, and
    expressions like ``9**9**9`` can exhaust CPU/memory. The input here comes
    from an LLM; do not expose this to untrusted users without real isolation.

    Args:
        code: The expression source text.

    Returns:
        ``str(result)`` on success, otherwise ``"ERROR: <message>"``.
    """
    import math  # local import keeps this block self-contained

    namespace = {"__builtins__": dict(_SAFE_EVAL_BUILTINS), "math": math}
    try:
        return str(eval(code, namespace))
    except Exception as e:
        return f"ERROR: {e}"
31
 
 
# Leading "Final answer:" label (any case).
_FINAL_ANSWER_RE = re.compile(r'final answer:?\s*', re.I)

# Refusal / apology phrases; everything from the phrase to end of line is cut.
# The word boundaries matter: without them ``i'?m`` matched inside ordinary
# words ("Lima", "time", "image") and truncated valid answers.
# NOTE(review): the last alternative only matches the literal "iprocess the
# file"; it is kept unchanged from the original pattern -- confirm intent.
_REFUSAL_RE = re.compile(
    r"\bi(?:'?m| cannot| can't| unable| apologize| not available|process the file)\b.*",
    re.I,
)

# Question shapes that select a specific output format.
_NUMERIC_Q_RE = re.compile(
    r'how many|number of|at bats|total sales|albums|output.*python|highest number',
    re.I,
)
_LIST_Q_RE = re.compile(r'list|comma.*separated|page numbers', re.I)
_CODE_Q_RE = re.compile(r'IOC country code|award number|NASA', re.I)

# Adjectives that are glued to the following list item ("fresh" + "basil").
_ITEM_MODIFIERS = ('sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh')


def format_gaia_answer(answer: str, question: str = "") -> str:
    """Normalize a raw model answer into the terse form GAIA scoring expects.

    The answer is first cleaned (label, refusal text, wrapping quotes or
    brackets, trailing period). If *question* is given, keyword heuristics on
    the question then pick a narrower format: a bare number, a single word,
    an upper-case code, or a comma-separated lower-case list.

    Args:
        answer: Raw answer text from the model.
        question: The question being answered; enables the format heuristics.

    Returns:
        The formatted answer, or ``""`` when nothing usable remains.
    """
    if not answer:
        return ""

    answer = _FINAL_ANSWER_RE.sub('', answer).strip()
    answer = _REFUSAL_RE.sub('', answer).strip()
    if answer.startswith('"') and answer.endswith('"'):
        answer = answer[1:-1]
    if answer.startswith('[') and answer.endswith(']'):
        answer = answer[1:-1]
    # Drop a sentence-final period, but not on a lone word ending in "."
    if not re.match(r'^[A-Za-z]+\.$', answer):
        answer = re.sub(r'\.$', '', answer)

    # Cleaning can leave nothing (e.g. a pure refusal). Bail out here so the
    # word-indexing branches below cannot raise IndexError.
    words = answer.split()
    if not words:
        return ""

    if question:
        # Lower-cased copy so the plain substring checks are case-insensitive,
        # matching the re.I regex checks beside them.
        q = question.lower()

        if _NUMERIC_Q_RE.search(question):
            num = re.search(r'\$?\d[\d,.]*', answer)
            if num:
                # Remove thousands separators and any trailing sentence period
                # the greedy character class picked up ("3.5." -> "3.5").
                return num.group(0).replace(',', '').rstrip('.')
        if 'first name' in q:
            return words[0]
        if 'surname' in q:
            return words[-1]
        # Whole-word match so "capacity"/"electricity" do not count as "city".
        if re.search(r'\bcity\b', q):
            return words[0]
        if _CODE_Q_RE.search(question):
            code = re.search(r'[A-Z0-9]{3,}', answer)
            if code:
                return code.group(0)
        if _LIST_Q_RE.search(question):
            if 'page numbers' in q:
                nums = sorted(int(x) for x in re.findall(r'\d+', answer))
                return ', '.join(str(n) for n in nums)

            # Strip whitespace as well as quote/punctuation characters;
            # otherwise "a, b" yields " b" and the output gets double spaces.
            items = []
            for piece in re.split(r'[,\n]', answer):
                cleaned = piece.strip().strip('",.').strip().lower()
                if cleaned:
                    items.append(cleaned)

            if 'ingredient' in q or 'vegetable' in q or 'grocery' in q:
                merged, skip = [], False
                for i, item in enumerate(items):
                    if skip:
                        skip = False
                        continue
                    if i + 1 < len(items) and item in _ITEM_MODIFIERS:
                        merged.append(f"{item} {items[i + 1]}")
                        skip = True
                    else:
                        merged.append(item)
                return ', '.join(sorted(set(merged)))
            return ', '.join(items)

    # NOTE: this final rstrip also removes the period kept for a lone word
    # above; behavior is unchanged from the original.
    return answer.strip().rstrip('.')
81
+
# --- LLM & Tools Setup ---
# The agent needs a LlamaIndex LLM object. The ``OpenAI`` class from the
# ``openai`` SDK is an API client whose constructor takes no ``model``
# argument, so the LlamaIndex wrapper is imported here under its own name.
from llama_index.llms.openai import OpenAI as LlamaIndexOpenAI

llm = LlamaIndexOpenAI(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))

# Tools offered to the agent: web search, expression evaluation, and the
# GAIA answer formatter.
tools = [
    FunctionTool.from_defaults(DuckDuckGoSearchTool(), name="duckduckgo_search", description="Searches the web via DuckDuckGo"),
    FunctionTool.from_defaults(eval_python_code, name="python_eval", description="Evaluate Python code"),
    FunctionTool.from_defaults(format_gaia_answer, name="format_gaia_answer", description="Strict GAIA output formatting"),
]

# Module-level agent shared by the answer_question entry points.
agent = ReActAgent.from_tools(
    tools=tools,
    llm=llm,
    system_prompt="You're a GAIA benchmark agent. Use tools and always output only the final answer in strict format—no explanation or apology.",
    verbose=False
)
97
 
 
async def answer_question(question: str, task_id: str = None, file_path: str = None) -> str:
    """Ask the module-level agent *question* and return the reply text.

    ``task_id`` and ``file_path`` are accepted for signature compatibility
    but are not used by this function.
    """
    chat_result = await agent.achat(question)
    reply_text = chat_result.response
    return reply_text
101
 
 
def answer_question_sync(question: str, task_id: str = None, file_path: str = None) -> str:
    """Blocking wrapper: run ``answer_question`` to completion via ``asyncio.run``."""
    pending = answer_question(question, task_id, file_path)
    return asyncio.run(pending)
104
 
 
class GaiaAgent:
    """Callable facade that forwards to ``answer_question_sync``."""

    def __call__(self, question: str, task_id: str = None, file_path: str = None) -> str:
        """Answer *question* synchronously.

        ``file_path`` is optional and defaulted, so the signature matches
        ``answer_question_sync`` and callers may pass two or three arguments.

        Args:
            question: The question text.
            task_id: Optional task identifier, forwarded unchanged.
            file_path: Optional attachment path, forwarded unchanged.

        Returns:
            The answer string returned by ``answer_question_sync``.
        """
        return answer_question_sync(question, task_id, file_path)