dawid-lorek commited on
Commit
70672a2
·
verified ·
1 Parent(s): 6e0803e

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +30 -11
agent.py CHANGED
@@ -1,12 +1,12 @@
1
- # agent_v25.py
2
  import os
3
  import re
4
  import io
5
  import base64
6
  import requests
7
  import pandas as pd
8
- from word2number import w2n
9
  from openai import OpenAI
 
10
 
11
  class GaiaAgent:
12
  def __init__(self):
@@ -26,7 +26,7 @@ class GaiaAgent:
26
  response = self.client.chat.completions.create(
27
  model=model,
28
  messages=[
29
- {"role": "system", "content": "You are a precise assistant. Return only the final answer. Do not explain."},
30
  {"role": "user", "content": prompt.strip() + "\nFinal Answer:"}
31
  ],
32
  temperature=0.0,
@@ -66,6 +66,25 @@ class GaiaAgent:
66
  except Exception:
67
  return "$0.00"
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def extract_answer(self, text, question):
70
  q = question.lower()
71
  text = text.strip().strip("\"'").strip()
@@ -81,14 +100,9 @@ class GaiaAgent:
81
  match = re.search(r"\b([KQBNR]?[a-h]?[1-8]?x?[a-h][1-8][+#]?)\b", text)
82
  return match.group(1) if match else text
83
 
84
- if "ingredients" in q or "comma separated list" in q:
85
- items = re.findall(r"[a-zA-Z]+(?: [a-zA-Z]+)?", text)
86
- return ", ".join(sorted(set(i.lower() for i in items)))
87
-
88
- if "vegetables" in q:
89
- veggies = ['acorns', 'broccoli', 'celery', 'green beans', 'lettuce', 'peanuts', 'sweet potatoes']
90
- found = [v for v in veggies if v in text.lower()]
91
- return ", ".join(sorted(found))
92
 
93
  if "usd with two decimal places" in q:
94
  match = re.search(r"\$?([0-9]+(?:\.[0-9]{1,2})?)", text)
@@ -127,6 +141,11 @@ class GaiaAgent:
127
  file_bytes, ctype = self.fetch_file(task_id)
128
 
129
  try:
 
 
 
 
 
130
  if file_bytes and "image" in ctype:
131
  raw = self.ask_image(file_bytes, question)
132
  elif file_bytes and ("audio" in ctype or task_id.endswith(".mp3")):
 
1
+ # agent_v26.py
2
  import os
3
  import re
4
  import io
5
  import base64
6
  import requests
7
  import pandas as pd
 
8
  from openai import OpenAI
9
+ from word2number import w2n
10
 
11
  class GaiaAgent:
12
  def __init__(self):
 
26
  response = self.client.chat.completions.create(
27
  model=model,
28
  messages=[
29
+ {"role": "system", "content": "You are a precise assistant. Only return the final answer. Do not guess. Use reasoning and tools when needed."},
30
  {"role": "user", "content": prompt.strip() + "\nFinal Answer:"}
31
  ],
32
  temperature=0.0,
 
66
  except Exception:
67
  return "$0.00"
68
 
69
+ def analyze_commutativity(self, question):
70
+ try:
71
+ rows = re.findall(r"\|([a-e])\|([a-e\|]+)\|", question)
72
+ table = {}
73
+ for row in rows:
74
+ key, values = row
75
+ table[key] = values.strip('|').split('|')
76
+ elements = list(table.keys())
77
+ non_comm = set()
78
+ for i, x in enumerate(elements):
79
+ for j, y in enumerate(elements):
80
+ a = table[x][j]
81
+ b = table[y][i]
82
+ if a != b:
83
+ non_comm.update([x, y])
84
+ return ", ".join(sorted(non_comm))
85
+ except:
86
+ return ""
87
+
88
  def extract_answer(self, text, question):
89
  q = question.lower()
90
  text = text.strip().strip("\"'").strip()
 
100
  match = re.search(r"\b([KQBNR]?[a-h]?[1-8]?x?[a-h][1-8][+#]?)\b", text)
101
  return match.group(1) if match else text
102
 
103
+ if "comma separated list" in q:
104
+ words = re.findall(r"[a-zA-Z][a-zA-Z ]+[a-zA-Z]", text)
105
+ return ", ".join(sorted(set(w.strip().lower() for w in words)))
 
 
 
 
 
106
 
107
  if "usd with two decimal places" in q:
108
  match = re.search(r"\$?([0-9]+(?:\.[0-9]{1,2})?)", text)
 
141
  file_bytes, ctype = self.fetch_file(task_id)
142
 
143
  try:
144
+ if "commutative" in question.lower():
145
+ result = self.analyze_commutativity(question)
146
+ if result:
147
+ return result
148
+
149
  if file_bytes and "image" in ctype:
150
  raw = self.ask_image(file_bytes, question)
151
  elif file_bytes and ("audio" in ctype or task_id.endswith(".mp3")):