dawid-lorek commited on
Commit
0afb7b8
·
verified ·
1 Parent(s): e95dfeb

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +36 -25
agent.py CHANGED
@@ -8,7 +8,6 @@ def duckduckgo_search(query: str) -> str:
8
  with DDGS() as ddg:
9
  results = ddg.text(query=query, region="wt-wt", max_results=5)
10
  bodies = [r.get('body', '') for r in results if r.get('body')]
11
- # For GAIA, prefer the first non-empty, or join a few if possible
12
  return "\n".join(bodies[:3])
13
  except Exception as e:
14
  return f"ERROR: {e}"
@@ -20,21 +19,24 @@ def eval_python_code(code: str) -> str:
20
  return f"ERROR: {e}"
21
 
22
  def format_gaia_answer(answer: str, question: str = "") -> str:
23
- """Strictly format GAIA output and eliminate apologies or error text."""
24
  if not answer:
25
  return ""
26
- # Remove apology/error phrases
27
  answer = re.sub(
28
- r'(?i)(unfortunately|unable to|error:|not available|i cannot|i am unable|i can\'t|no file|skip|I do not have|I cannot access|I am currently unable|If you have access).*',
29
  '', answer).strip()
30
- # Remove leading/trailing quotes/brackets
31
- answer = answer.strip(' "\'[]')
 
 
 
32
  # Only numbers for count questions
33
  if re.search(r'how many|number of|albums|at bats|total sales|output', question, re.I):
34
  match = re.search(r'(\d+)', answer)
35
  if match:
36
  return match.group(1)
37
- # Only the last word for "surname" or first for "first name"
38
  if "surname" in question:
39
  return answer.split()[-1]
40
  if "first name" in question:
@@ -48,7 +50,7 @@ def format_gaia_answer(answer: str, question: str = "") -> str:
48
  code = re.search(r'[A-Z0-9]{3,}', answer)
49
  if code:
50
  return code.group(0)
51
- # For lists: split, merge common phrases, dedupe, alpha-sort, comma-sep
52
  if "list" in question or "ingredient" in question or "vegetable" in question:
53
  items = [x.strip(' "\'') for x in re.split(r'[,\n]', answer) if x.strip()]
54
  merged = []
@@ -65,59 +67,68 @@ def format_gaia_answer(answer: str, question: str = "") -> str:
65
  merged = [x.lower() for x in merged]
66
  merged = sorted(set(merged))
67
  return ', '.join(merged)
68
- # For chess: algebraic move (like Qd1+)
69
  if "algebraic notation" in question or "chess" in question:
70
  move = re.findall(r'[KQRBN]?[a-h]?[1-8]?x?[a-h][1-8][+#]?', answer)
71
  if move:
72
  return move[-1]
73
- # Remove everything after first period for single-word answers
74
- answer = answer.split('\n')[0].split('.')[0].strip()
75
- return answer
76
 
77
  class GaiaAgent:
78
  def __init__(self):
79
  self.llm = OpenAIClient(api_key=os.getenv("OPENAI_API_KEY"))
80
 
81
  def __call__(self, question: str, task_id: str = None) -> str:
82
- # 1. Try tool-based search for all fact/list/code questions
83
- ql = question.lower()
84
- # Try to route every "who", "what", "number", "albums", "at bats", "surname", etc. to web search
85
  search_keywords = [
86
  "who", "when", "what", "which", "how many", "number", "name", "albums", "surname", "at bats",
87
  "nasa", "city", "winner", "code", "vegetable", "ingredient", "magda m.", "featured article"
88
  ]
89
- if any(kw in ql for kw in search_keywords):
 
90
  web_result = duckduckgo_search(question)
91
  llm_answer = self.llm.chat.completions.create(
92
  model="gpt-4o",
93
  messages=[
94
- {"role": "system", "content": "You are a research assistant. Based on the following web search results and question, answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations or apologies."},
95
  {"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
96
  ],
97
  temperature=0.0,
98
  max_tokens=256,
99
  ).choices[0].message.content.strip()
100
- return format_gaia_answer(llm_answer, question)
101
- # 2. For code/math
102
- if "output" in ql and "python" in ql:
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  code_match = re.search(r'```python(.*?)```', question, re.DOTALL)
104
  code = code_match.group(1) if code_match else ""
105
  result = eval_python_code(code)
106
  return format_gaia_answer(result, question)
107
- # 3. For lists or ingredients, always web search and format
108
- if "list" in ql or "ingredient" in ql or "vegetable" in ql:
109
  web_result = duckduckgo_search(question)
110
  llm_answer = self.llm.chat.completions.create(
111
  model="gpt-4o",
112
  messages=[
113
- {"role": "system", "content": "You are a research assistant. Based on the following web search results and question, answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations or apologies."},
114
  {"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
115
  ],
116
  temperature=0.0,
117
  max_tokens=256,
118
  ).choices[0].message.content.strip()
119
  return format_gaia_answer(llm_answer, question)
120
- # 4. Fallback: strict LLM answer, formatted
121
  llm_answer = self.llm.chat.completions.create(
122
  model="gpt-4o",
123
  messages=[
@@ -125,6 +136,6 @@ class GaiaAgent:
125
  {"role": "user", "content": question}
126
  ],
127
  temperature=0.0,
128
- max_tokens=256,
129
  ).choices[0].message.content.strip()
130
  return format_gaia_answer(llm_answer, question)
 
8
  with DDGS() as ddg:
9
  results = ddg.text(query=query, region="wt-wt", max_results=5)
10
  bodies = [r.get('body', '') for r in results if r.get('body')]
 
11
  return "\n".join(bodies[:3])
12
  except Exception as e:
13
  return f"ERROR: {e}"
 
19
  return f"ERROR: {e}"
20
 
21
  def format_gaia_answer(answer: str, question: str = "") -> str:
22
+ """Strict GAIA output, eliminate apologies, extract only answer value."""
23
  if not answer:
24
  return ""
25
+ # Remove apologies and anything after
26
  answer = re.sub(
27
+ r'(?i)(I[\' ]?m sorry.*|Unfortunately.*|I cannot.*|I am unable.*|error:.*|no file.*|but.*|however.*|unable to.*|not available.*|if you have access.*|I can\'t.*)',
28
  '', answer).strip()
29
+ # Remove everything after the first period if it's not a list
30
+ if not ("list" in question or "ingredient" in question or "vegetable" in question):
31
+ answer = answer.split('\n')[0].split('.')[0]
32
+ # Remove quotes/brackets
33
+ answer = answer.strip(' "\'[](),;:')
34
  # Only numbers for count questions
35
  if re.search(r'how many|number of|albums|at bats|total sales|output', question, re.I):
36
  match = re.search(r'(\d+)', answer)
37
  if match:
38
  return match.group(1)
39
+ # Only last word for "surname", first for "first name"
40
  if "surname" in question:
41
  return answer.split()[-1]
42
  if "first name" in question:
 
50
  code = re.search(r'[A-Z0-9]{3,}', answer)
51
  if code:
52
  return code.group(0)
53
+ # For lists: comma-separated, alpha, deduped, merged phrases
54
  if "list" in question or "ingredient" in question or "vegetable" in question:
55
  items = [x.strip(' "\'') for x in re.split(r'[,\n]', answer) if x.strip()]
56
  merged = []
 
67
  merged = [x.lower() for x in merged]
68
  merged = sorted(set(merged))
69
  return ', '.join(merged)
70
+ # For chess: algebraic move
71
  if "algebraic notation" in question or "chess" in question:
72
  move = re.findall(r'[KQRBN]?[a-h]?[1-8]?x?[a-h][1-8][+#]?', answer)
73
  if move:
74
  return move[-1]
75
+ return answer.strip(' "\'[](),;:')
 
 
76
 
77
  class GaiaAgent:
78
  def __init__(self):
79
  self.llm = OpenAIClient(api_key=os.getenv("OPENAI_API_KEY"))
80
 
81
  def __call__(self, question: str, task_id: str = None) -> str:
 
 
 
82
  search_keywords = [
83
  "who", "when", "what", "which", "how many", "number", "name", "albums", "surname", "at bats",
84
  "nasa", "city", "winner", "code", "vegetable", "ingredient", "magda m.", "featured article"
85
  ]
86
+ needs_search = any(kw in question.lower() for kw in search_keywords)
87
+ if needs_search:
88
  web_result = duckduckgo_search(question)
89
  llm_answer = self.llm.chat.completions.create(
90
  model="gpt-4o",
91
  messages=[
92
+ {"role": "system", "content": "You are a research assistant. Based on the web search results and question, answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations or apologies."},
93
  {"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
94
  ],
95
  temperature=0.0,
96
  max_tokens=256,
97
  ).choices[0].message.content.strip()
98
+ formatted = format_gaia_answer(llm_answer, question)
99
+ # Retry if apology/empty/incorrect
100
+ if not formatted or "sorry" in formatted.lower() or "unable" in formatted.lower():
101
+ llm_answer2 = self.llm.chat.completions.create(
102
+ model="gpt-4o",
103
+ messages=[
104
+ {"role": "system", "content": "Only answer with the value. No explanation. Do not apologize. Do not begin with 'I'm sorry', 'Unfortunately', or similar."},
105
+ {"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
106
+ ],
107
+ temperature=0.0,
108
+ max_tokens=128,
109
+ ).choices[0].message.content.strip()
110
+ formatted = format_gaia_answer(llm_answer2, question)
111
+ return formatted
112
+ # For code/math output
113
+ if "output" in question.lower() and "python" in question.lower():
114
  code_match = re.search(r'```python(.*?)```', question, re.DOTALL)
115
  code = code_match.group(1) if code_match else ""
116
  result = eval_python_code(code)
117
  return format_gaia_answer(result, question)
118
+ # For lists/ingredients, always web search and format
119
+ if "list" in question.lower() or "ingredient" in question.lower() or "vegetable" in question.lower():
120
  web_result = duckduckgo_search(question)
121
  llm_answer = self.llm.chat.completions.create(
122
  model="gpt-4o",
123
  messages=[
124
+ {"role": "system", "content": "You are a research assistant. Based on the web search results and question, answer strictly and concisely for the GAIA benchmark. Only the answer, no explanations or apologies."},
125
  {"role": "user", "content": f"Web search results:\n{web_result}\n\nQuestion: {question}"}
126
  ],
127
  temperature=0.0,
128
  max_tokens=256,
129
  ).choices[0].message.content.strip()
130
  return format_gaia_answer(llm_answer, question)
131
+ # Fallback: strict LLM answer, formatted
132
  llm_answer = self.llm.chat.completions.create(
133
  model="gpt-4o",
134
  messages=[
 
136
  {"role": "user", "content": question}
137
  ],
138
  temperature=0.0,
139
+ max_tokens=128,
140
  ).choices[0].message.content.strip()
141
  return format_gaia_answer(llm_answer, question)