Update agent.py
Browse files
agent.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# agent_v31.py
|
2 |
import os
|
3 |
import re
|
4 |
import io
|
@@ -24,159 +24,121 @@ class GaiaAgent:
|
|
24 |
except Exception:
|
25 |
return None, None
|
26 |
|
27 |
-
def
|
28 |
try:
|
29 |
-
return self.search_tool.run(
|
30 |
except Exception:
|
31 |
return "[NO WEB INFO FOUND]"
|
32 |
|
33 |
-
def ask(self,
|
|
|
|
|
|
|
|
|
34 |
response = self.client.chat.completions.create(
|
35 |
model=model,
|
36 |
-
messages=
|
37 |
-
{"role": "system", "content": "Return only a short factual answer. Format it properly. Never guess."},
|
38 |
-
{"role": "user", "content": prompt.strip() + "\nAnswer:"}
|
39 |
-
],
|
40 |
temperature=0.0,
|
41 |
)
|
42 |
return response.choices[0].message.content.strip()
|
43 |
|
44 |
-
def
|
45 |
-
path = "/tmp/audio.mp3"
|
46 |
-
with open(path, "wb") as f:
|
47 |
-
f.write(audio_bytes)
|
48 |
-
transcript = self.client.audio.transcriptions.create(model="whisper-1", file=open(path, "rb"))
|
49 |
-
return self.ask(f"Audio transcript: {transcript.text}\n\n{question}")
|
50 |
-
|
51 |
-
def ask_image(self, image_bytes, question):
|
52 |
-
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
53 |
-
messages = [
|
54 |
-
{"role": "system", "content": "Return only the winning move in chess algebraic notation (e.g., Qd1). No explanation."},
|
55 |
-
{
|
56 |
-
"role": "user",
|
57 |
-
"content": [
|
58 |
-
{"type": "text", "text": question},
|
59 |
-
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
|
60 |
-
]
|
61 |
-
}
|
62 |
-
]
|
63 |
-
response = self.client.chat.completions.create(model="gpt-4o", messages=messages)
|
64 |
-
return response.choices[0].message.content.strip()
|
65 |
-
|
66 |
-
def extract_from_excel(self, file_bytes):
|
67 |
-
try:
|
68 |
-
df = pd.read_excel(io.BytesIO(file_bytes), engine="openpyxl")
|
69 |
-
df.columns = [col.lower() for col in df.columns]
|
70 |
-
if 'category' in df.columns and 'sales' in df.columns:
|
71 |
-
df['sales'] = pd.to_numeric(df['sales'], errors='coerce')
|
72 |
-
food_df = df[df['category'].str.lower() == 'food']
|
73 |
-
total = food_df['sales'].sum()
|
74 |
-
return f"${total:.2f}" if not pd.isna(total) else "$0.00"
|
75 |
-
except Exception:
|
76 |
-
pass
|
77 |
-
return "$0.00"
|
78 |
-
|
79 |
-
def extract_commutative_set(self, question):
|
80 |
-
try:
|
81 |
-
rows = re.findall(r"\|([a-e])\|([a-e\|]+)\|", question)
|
82 |
-
table = {}
|
83 |
-
for row in rows:
|
84 |
-
key, values = row
|
85 |
-
table[key] = values.strip('|').split('|')
|
86 |
-
elements = list(table.keys())
|
87 |
-
non_comm = set()
|
88 |
-
for i, x in enumerate(elements):
|
89 |
-
for j, y in enumerate(elements):
|
90 |
-
if x != y:
|
91 |
-
a = table[x][j]
|
92 |
-
b = table[y][i]
|
93 |
-
if a != b:
|
94 |
-
non_comm.update([x, y])
|
95 |
-
return ", ".join(sorted(non_comm))
|
96 |
-
except:
|
97 |
-
return ""
|
98 |
-
|
99 |
-
def extract_answer(self, raw, question):
|
100 |
q = question.lower()
|
101 |
-
|
102 |
-
|
103 |
-
if "studio albums" in q:
|
104 |
-
try:
|
105 |
-
return str(w2n.word_to_num(raw))
|
106 |
-
except:
|
107 |
-
match = re.search(r"\b\d+\b", raw)
|
108 |
-
return match.group(0) if match else raw
|
109 |
-
|
110 |
-
if "algebraic notation" in q:
|
111 |
-
match = re.search(r"\b([KQBNR]?[a-h]?[1-8]?x?[a-h][1-8][+#]?)\b", raw)
|
112 |
-
return match.group(1) if match else raw
|
113 |
|
114 |
if "usd with two decimal places" in q:
|
115 |
-
match = re.search(r"\$?([0-9]+(?:\.[0-9]{1,2})?)",
|
116 |
return f"${float(match.group(1)):.2f}" if match else "$0.00"
|
117 |
|
|
|
|
|
|
|
|
|
118 |
if "ioc country code" in q:
|
119 |
-
match = re.search(r"\b[A-Z]{3}\b",
|
120 |
return match.group(0)
|
121 |
|
|
|
|
|
|
|
122 |
if "page numbers" in q:
|
123 |
-
|
124 |
-
return ", ".join(
|
125 |
|
126 |
if "at bats" in q:
|
127 |
-
match = re.search(r"\b(\d{3,4})\b",
|
128 |
-
return match.group(1)
|
129 |
|
130 |
-
if "
|
131 |
-
|
|
|
|
|
|
|
|
|
132 |
|
133 |
if "award number" in q:
|
134 |
-
match = re.search(r"80NSSC[0-9A-Z]{6,7}",
|
135 |
-
return match.group(0) if match else
|
136 |
|
137 |
if "vegetables" in q or "ingredients" in q:
|
138 |
-
|
139 |
-
|
140 |
-
clean =
|
141 |
-
return ", ".join(
|
|
|
|
|
142 |
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
def __call__(self, question, task_id=None):
|
146 |
file_bytes, ctype = None, ""
|
147 |
if task_id:
|
148 |
file_bytes, ctype = self.fetch_file(task_id)
|
149 |
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
if "commutative" in question:
|
162 |
-
result = self.extract_commutative_set(question)
|
163 |
-
return result
|
164 |
-
|
165 |
-
if file_bytes and "image" in ctype:
|
166 |
-
raw = self.ask_image(file_bytes, question)
|
167 |
-
elif file_bytes and ("audio" in ctype or task_id.endswith(".mp3")):
|
168 |
-
raw = self.ask_audio(file_bytes, question)
|
169 |
-
elif file_bytes and ("excel" in ctype or task_id.endswith(".xlsx")):
|
170 |
-
return self.extract_from_excel(file_bytes)
|
171 |
-
elif file_bytes:
|
172 |
-
try:
|
173 |
-
text = file_bytes.decode("utf-8")
|
174 |
-
raw = self.ask(f"Text content:\n{text[:3000]}\n\n{question}")
|
175 |
-
except:
|
176 |
-
raw = "[UNREADABLE FILE CONTENT]"
|
177 |
-
else:
|
178 |
-
raw = self.ask(question)
|
179 |
-
except Exception as e:
|
180 |
-
return f"[ERROR: {e}]"
|
181 |
-
|
182 |
-
return self.extract_answer(raw, question)
|
|
|
1 |
+
# agent_v31.py (wersja generyczna – podejście uniwersalne bez ifów per pytanie)
|
2 |
import os
|
3 |
import re
|
4 |
import io
|
|
|
24 |
except Exception:
|
25 |
return None, None
|
26 |
|
27 |
+
def search_web_context(self, question):
|
28 |
try:
|
29 |
+
return self.search_tool.run(question)
|
30 |
except Exception:
|
31 |
return "[NO WEB INFO FOUND]"
|
32 |
|
33 |
+
def ask(self, context, question, model="gpt-4-turbo"):
|
34 |
+
messages = [
|
35 |
+
{"role": "system", "content": "You are an expert assistant. Use provided web or file context to answer. Output only the short final answer, formatted correctly. Do not explain."},
|
36 |
+
{"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"}
|
37 |
+
]
|
38 |
response = self.client.chat.completions.create(
|
39 |
model=model,
|
40 |
+
messages=messages,
|
|
|
|
|
|
|
41 |
temperature=0.0,
|
42 |
)
|
43 |
return response.choices[0].message.content.strip()
|
44 |
|
45 |
+
def format_answer(self, answer, question):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
q = question.lower()
|
47 |
+
a = answer.strip().strip("\"'").strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
if "usd with two decimal places" in q:
|
50 |
+
match = re.search(r"\$?([0-9]+(?:\.[0-9]{1,2})?)", a)
|
51 |
return f"${float(match.group(1)):.2f}" if match else "$0.00"
|
52 |
|
53 |
+
if "algebraic notation" in q:
|
54 |
+
match = re.search(r"\b([KQBNR]?[a-h]?[1-8]?x?[a-h][1-8][+#]?)\b", a)
|
55 |
+
return match.group(1) if match else a
|
56 |
+
|
57 |
if "ioc country code" in q:
|
58 |
+
match = re.search(r"\b[A-Z]{3}\b", a.upper())
|
59 |
return match.group(0)
|
60 |
|
61 |
+
if "first name" in q:
|
62 |
+
return a.split()[0]
|
63 |
+
|
64 |
if "page numbers" in q:
|
65 |
+
nums = sorted(set(re.findall(r"\b\d+\b", a)))
|
66 |
+
return ", ".join(nums)
|
67 |
|
68 |
if "at bats" in q:
|
69 |
+
match = re.search(r"\b(\d{3,4})\b", a)
|
70 |
+
return match.group(1) if match else a
|
71 |
|
72 |
+
if "studio albums" in q or "how many" in q:
|
73 |
+
try:
|
74 |
+
return str(w2n.word_to_num(a))
|
75 |
+
except:
|
76 |
+
match = re.search(r"\b\d+\b", a)
|
77 |
+
return match.group(0) if match else a
|
78 |
|
79 |
if "award number" in q:
|
80 |
+
match = re.search(r"80NSSC[0-9A-Z]{6,7}", a)
|
81 |
+
return match.group(0) if match else a
|
82 |
|
83 |
if "vegetables" in q or "ingredients" in q:
|
84 |
+
tokens = [t.lower() for t in re.findall(r"[a-zA-Z]+", a)]
|
85 |
+
blacklist = {"extract", "juice", "pure", "vanilla", "sugar", "granulated", "fresh", "ripe", "pinch", "water", "whole", "cups", "salt"}
|
86 |
+
clean = sorted(set(t for t in tokens if t not in blacklist and len(t) > 2))
|
87 |
+
return ", ".join(clean)
|
88 |
+
|
89 |
+
return a
|
90 |
|
91 |
+
def handle_file_context(self, file_bytes, ctype, question):
|
92 |
+
if not file_bytes:
|
93 |
+
return ""
|
94 |
+
if "image" in ctype:
|
95 |
+
image_b64 = base64.b64encode(file_bytes).decode("utf-8")
|
96 |
+
messages = [
|
97 |
+
{"role": "system", "content": "You're a visual reasoning assistant. Answer the question based on the image. Output only the move in algebraic notation."},
|
98 |
+
{
|
99 |
+
"role": "user",
|
100 |
+
"content": [
|
101 |
+
{"type": "text", "text": question},
|
102 |
+
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
|
103 |
+
]
|
104 |
+
}
|
105 |
+
]
|
106 |
+
response = self.client.chat.completions.create(model="gpt-4o", messages=messages)
|
107 |
+
return response.choices[0].message.content.strip()
|
108 |
+
elif "audio" in ctype or question.endswith(".mp3"):
|
109 |
+
path = "/tmp/audio.mp3"
|
110 |
+
with open(path, "wb") as f:
|
111 |
+
f.write(file_bytes)
|
112 |
+
transcript = self.client.audio.transcriptions.create(model="whisper-1", file=open(path, "rb"))
|
113 |
+
return transcript.text
|
114 |
+
elif "excel" in ctype or question.endswith(".xlsx"):
|
115 |
+
try:
|
116 |
+
df = pd.read_excel(io.BytesIO(file_bytes), engine="openpyxl")
|
117 |
+
df.columns = [c.lower() for c in df.columns]
|
118 |
+
df['sales'] = pd.to_numeric(df['sales'], errors='coerce')
|
119 |
+
food_df = df[df['category'].str.lower() == 'food']
|
120 |
+
total = food_df['sales'].sum()
|
121 |
+
return f"${total:.2f}" if not pd.isna(total) else "$0.00"
|
122 |
+
except Exception:
|
123 |
+
return "[EXCEL ERROR]"
|
124 |
+
else:
|
125 |
+
try:
|
126 |
+
return file_bytes.decode("utf-8")[:3000]
|
127 |
+
except:
|
128 |
+
return ""
|
129 |
|
130 |
def __call__(self, question, task_id=None):
|
131 |
file_bytes, ctype = None, ""
|
132 |
if task_id:
|
133 |
file_bytes, ctype = self.fetch_file(task_id)
|
134 |
|
135 |
+
file_context = self.handle_file_context(file_bytes, ctype, question)
|
136 |
+
if file_context and not file_context.startswith("$"):
|
137 |
+
raw = self.ask(file_context, question)
|
138 |
+
elif file_context.startswith("$"):
|
139 |
+
return file_context # Excel result
|
140 |
+
else:
|
141 |
+
web_context = self.search_web_context(question)
|
142 |
+
raw = self.ask(web_context, question)
|
143 |
+
|
144 |
+
return self.format_answer(raw, question)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|