Update agent.py
Browse files
agent.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
-
#
|
2 |
import os
|
3 |
import re
|
4 |
import io
|
5 |
import base64
|
6 |
import requests
|
7 |
import pandas as pd
|
8 |
-
from word2number import w2n
|
9 |
from openai import OpenAI
|
|
|
10 |
|
11 |
class GaiaAgent:
|
12 |
def __init__(self):
|
@@ -26,7 +26,7 @@ class GaiaAgent:
|
|
26 |
response = self.client.chat.completions.create(
|
27 |
model=model,
|
28 |
messages=[
|
29 |
-
{"role": "system", "content": "You are a precise assistant.
|
30 |
{"role": "user", "content": prompt.strip() + "\nFinal Answer:"}
|
31 |
],
|
32 |
temperature=0.0,
|
@@ -66,6 +66,25 @@ class GaiaAgent:
|
|
66 |
except Exception:
|
67 |
return "$0.00"
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
def extract_answer(self, text, question):
|
70 |
q = question.lower()
|
71 |
text = text.strip().strip("\"'").strip()
|
@@ -81,14 +100,9 @@ class GaiaAgent:
|
|
81 |
match = re.search(r"\b([KQBNR]?[a-h]?[1-8]?x?[a-h][1-8][+#]?)\b", text)
|
82 |
return match.group(1) if match else text
|
83 |
|
84 |
-
if "
|
85 |
-
|
86 |
-
return ", ".join(sorted(set(
|
87 |
-
|
88 |
-
if "vegetables" in q:
|
89 |
-
veggies = ['acorns', 'broccoli', 'celery', 'green beans', 'lettuce', 'peanuts', 'sweet potatoes']
|
90 |
-
found = [v for v in veggies if v in text.lower()]
|
91 |
-
return ", ".join(sorted(found))
|
92 |
|
93 |
if "usd with two decimal places" in q:
|
94 |
match = re.search(r"\$?([0-9]+(?:\.[0-9]{1,2})?)", text)
|
@@ -127,6 +141,11 @@ class GaiaAgent:
|
|
127 |
file_bytes, ctype = self.fetch_file(task_id)
|
128 |
|
129 |
try:
|
|
|
|
|
|
|
|
|
|
|
130 |
if file_bytes and "image" in ctype:
|
131 |
raw = self.ask_image(file_bytes, question)
|
132 |
elif file_bytes and ("audio" in ctype or task_id.endswith(".mp3")):
|
|
|
1 |
+
# agent_v26.py
|
2 |
import os
|
3 |
import re
|
4 |
import io
|
5 |
import base64
|
6 |
import requests
|
7 |
import pandas as pd
|
|
|
8 |
from openai import OpenAI
|
9 |
+
from word2number import w2n
|
10 |
|
11 |
class GaiaAgent:
|
12 |
def __init__(self):
|
|
|
26 |
response = self.client.chat.completions.create(
|
27 |
model=model,
|
28 |
messages=[
|
29 |
+
{"role": "system", "content": "You are a precise assistant. Only return the final answer. Do not guess. Use reasoning and tools when needed."},
|
30 |
{"role": "user", "content": prompt.strip() + "\nFinal Answer:"}
|
31 |
],
|
32 |
temperature=0.0,
|
|
|
66 |
except Exception:
|
67 |
return "$0.00"
|
68 |
|
69 |
+
def analyze_commutativity(self, question):
|
70 |
+
try:
|
71 |
+
rows = re.findall(r"\|([a-e])\|([a-e\|]+)\|", question)
|
72 |
+
table = {}
|
73 |
+
for row in rows:
|
74 |
+
key, values = row
|
75 |
+
table[key] = values.strip('|').split('|')
|
76 |
+
elements = list(table.keys())
|
77 |
+
non_comm = set()
|
78 |
+
for i, x in enumerate(elements):
|
79 |
+
for j, y in enumerate(elements):
|
80 |
+
a = table[x][j]
|
81 |
+
b = table[y][i]
|
82 |
+
if a != b:
|
83 |
+
non_comm.update([x, y])
|
84 |
+
return ", ".join(sorted(non_comm))
|
85 |
+
except:
|
86 |
+
return ""
|
87 |
+
|
88 |
def extract_answer(self, text, question):
|
89 |
q = question.lower()
|
90 |
text = text.strip().strip("\"'").strip()
|
|
|
100 |
match = re.search(r"\b([KQBNR]?[a-h]?[1-8]?x?[a-h][1-8][+#]?)\b", text)
|
101 |
return match.group(1) if match else text
|
102 |
|
103 |
+
if "comma separated list" in q:
|
104 |
+
words = re.findall(r"[a-zA-Z][a-zA-Z ]+[a-zA-Z]", text)
|
105 |
+
return ", ".join(sorted(set(w.strip().lower() for w in words)))
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
if "usd with two decimal places" in q:
|
108 |
match = re.search(r"\$?([0-9]+(?:\.[0-9]{1,2})?)", text)
|
|
|
141 |
file_bytes, ctype = self.fetch_file(task_id)
|
142 |
|
143 |
try:
|
144 |
+
if "commutative" in question.lower():
|
145 |
+
result = self.analyze_commutativity(question)
|
146 |
+
if result:
|
147 |
+
return result
|
148 |
+
|
149 |
if file_bytes and "image" in ctype:
|
150 |
raw = self.ask_image(file_bytes, question)
|
151 |
elif file_bytes and ("audio" in ctype or task_id.endswith(".mp3")):
|