dawid-lorek commited on
Commit
40f559b
·
verified ·
1 Parent(s): 36284fd

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +86 -124
agent.py CHANGED
@@ -1,4 +1,4 @@
1
- # agent_v31.py
2
  import os
3
  import re
4
  import io
@@ -24,159 +24,121 @@ class GaiaAgent:
24
  except Exception:
25
  return None, None
26
 
27
- def get_web_info(self, query):
28
  try:
29
- return self.search_tool.run(query)
30
  except Exception:
31
  return "[NO WEB INFO FOUND]"
32
 
33
- def ask(self, prompt, model="gpt-4-turbo"):
 
 
 
 
34
  response = self.client.chat.completions.create(
35
  model=model,
36
- messages=[
37
- {"role": "system", "content": "Return only a short factual answer. Format it properly. Never guess."},
38
- {"role": "user", "content": prompt.strip() + "\nAnswer:"}
39
- ],
40
  temperature=0.0,
41
  )
42
  return response.choices[0].message.content.strip()
43
 
44
- def ask_audio(self, audio_bytes, question):
45
- path = "/tmp/audio.mp3"
46
- with open(path, "wb") as f:
47
- f.write(audio_bytes)
48
- transcript = self.client.audio.transcriptions.create(model="whisper-1", file=open(path, "rb"))
49
- return self.ask(f"Audio transcript: {transcript.text}\n\n{question}")
50
-
51
- def ask_image(self, image_bytes, question):
52
- image_b64 = base64.b64encode(image_bytes).decode("utf-8")
53
- messages = [
54
- {"role": "system", "content": "Return only the winning move in chess algebraic notation (e.g., Qd1). No explanation."},
55
- {
56
- "role": "user",
57
- "content": [
58
- {"type": "text", "text": question},
59
- {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
60
- ]
61
- }
62
- ]
63
- response = self.client.chat.completions.create(model="gpt-4o", messages=messages)
64
- return response.choices[0].message.content.strip()
65
-
66
- def extract_from_excel(self, file_bytes):
67
- try:
68
- df = pd.read_excel(io.BytesIO(file_bytes), engine="openpyxl")
69
- df.columns = [col.lower() for col in df.columns]
70
- if 'category' in df.columns and 'sales' in df.columns:
71
- df['sales'] = pd.to_numeric(df['sales'], errors='coerce')
72
- food_df = df[df['category'].str.lower() == 'food']
73
- total = food_df['sales'].sum()
74
- return f"${total:.2f}" if not pd.isna(total) else "$0.00"
75
- except Exception:
76
- pass
77
- return "$0.00"
78
-
79
- def extract_commutative_set(self, question):
80
- try:
81
- rows = re.findall(r"\|([a-e])\|([a-e\|]+)\|", question)
82
- table = {}
83
- for row in rows:
84
- key, values = row
85
- table[key] = values.strip('|').split('|')
86
- elements = list(table.keys())
87
- non_comm = set()
88
- for i, x in enumerate(elements):
89
- for j, y in enumerate(elements):
90
- if x != y:
91
- a = table[x][j]
92
- b = table[y][i]
93
- if a != b:
94
- non_comm.update([x, y])
95
- return ", ".join(sorted(non_comm))
96
- except:
97
- return ""
98
-
99
- def extract_answer(self, raw, question):
100
  q = question.lower()
101
- raw = raw.strip().strip("\"'").strip()
102
-
103
- if "studio albums" in q:
104
- try:
105
- return str(w2n.word_to_num(raw))
106
- except:
107
- match = re.search(r"\b\d+\b", raw)
108
- return match.group(0) if match else raw
109
-
110
- if "algebraic notation" in q:
111
- match = re.search(r"\b([KQBNR]?[a-h]?[1-8]?x?[a-h][1-8][+#]?)\b", raw)
112
- return match.group(1) if match else raw
113
 
114
  if "usd with two decimal places" in q:
115
- match = re.search(r"\$?([0-9]+(?:\.[0-9]{1,2})?)", raw)
116
  return f"${float(match.group(1)):.2f}" if match else "$0.00"
117
 
 
 
 
 
118
  if "ioc country code" in q:
119
- match = re.search(r"\b[A-Z]{3}\b", raw.upper())
120
  return match.group(0)
121
 
 
 
 
122
  if "page numbers" in q:
123
- pages = sorted(set(re.findall(r"\b\d+\b", raw)))
124
- return ", ".join(pages)
125
 
126
  if "at bats" in q:
127
- match = re.search(r"\b(\d{3,4})\b", raw)
128
- return match.group(1)
129
 
130
- if "first name" in q:
131
- return raw.split()[0]
 
 
 
 
132
 
133
  if "award number" in q:
134
- match = re.search(r"80NSSC[0-9A-Z]{6,7}", raw)
135
- return match.group(0) if match else raw
136
 
137
  if "vegetables" in q or "ingredients" in q:
138
- stopwords = set(["pure", "extract", "granulated", "sugar", "juice", "vanilla", "ripe", "fresh", "whole", "bean", "pinch", "cups", "salt", "water"])
139
- tokens = [t.lower() for t in re.findall(r"[a-zA-Z]+", raw)]
140
- clean = [t for t in tokens if t not in stopwords and len(t) > 2]
141
- return ", ".join(sorted(set(clean)))
 
 
142
 
143
- return raw
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  def __call__(self, question, task_id=None):
146
  file_bytes, ctype = None, ""
147
  if task_id:
148
  file_bytes, ctype = self.fetch_file(task_id)
149
 
150
- try:
151
- if "youtube.com" in question:
152
- video_id = re.search(r"v=([\w-]+)", question)
153
- if video_id:
154
- summary = self.get_web_info(f"youtube video transcript {video_id.group(1)}")
155
- return self.ask(f"Transcript: {summary}\n\n{question}")
156
-
157
- if "malko competition" in question.lower():
158
- search = self.get_web_info("malko competition winner yugoslavia after 1977 site:wikipedia.org")
159
- return self.ask(f"Using the search result:\n{search}\n\n{question}")
160
-
161
- if "commutative" in question:
162
- result = self.extract_commutative_set(question)
163
- return result
164
-
165
- if file_bytes and "image" in ctype:
166
- raw = self.ask_image(file_bytes, question)
167
- elif file_bytes and ("audio" in ctype or task_id.endswith(".mp3")):
168
- raw = self.ask_audio(file_bytes, question)
169
- elif file_bytes and ("excel" in ctype or task_id.endswith(".xlsx")):
170
- return self.extract_from_excel(file_bytes)
171
- elif file_bytes:
172
- try:
173
- text = file_bytes.decode("utf-8")
174
- raw = self.ask(f"Text content:\n{text[:3000]}\n\n{question}")
175
- except:
176
- raw = "[UNREADABLE FILE CONTENT]"
177
- else:
178
- raw = self.ask(question)
179
- except Exception as e:
180
- return f"[ERROR: {e}]"
181
-
182
- return self.extract_answer(raw, question)
 
1
+ # agent_v31.py (wersja generyczna – podejście uniwersalne bez ifów per pytanie)
2
  import os
3
  import re
4
  import io
 
24
  except Exception:
25
  return None, None
26
 
27
+ def search_web_context(self, question):
28
  try:
29
+ return self.search_tool.run(question)
30
  except Exception:
31
  return "[NO WEB INFO FOUND]"
32
 
33
+ def ask(self, context, question, model="gpt-4-turbo"):
34
+ messages = [
35
+ {"role": "system", "content": "You are an expert assistant. Use provided web or file context to answer. Output only the short final answer, formatted correctly. Do not explain."},
36
+ {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"}
37
+ ]
38
  response = self.client.chat.completions.create(
39
  model=model,
40
+ messages=messages,
 
 
 
41
  temperature=0.0,
42
  )
43
  return response.choices[0].message.content.strip()
44
 
45
+ def format_answer(self, answer, question):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  q = question.lower()
47
+ a = answer.strip().strip("\"'").strip()
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  if "usd with two decimal places" in q:
50
+ match = re.search(r"\$?([0-9]+(?:\.[0-9]{1,2})?)", a)
51
  return f"${float(match.group(1)):.2f}" if match else "$0.00"
52
 
53
+ if "algebraic notation" in q:
54
+ match = re.search(r"\b([KQBNR]?[a-h]?[1-8]?x?[a-h][1-8][+#]?)\b", a)
55
+ return match.group(1) if match else a
56
+
57
  if "ioc country code" in q:
58
+ match = re.search(r"\b[A-Z]{3}\b", a.upper())
59
  return match.group(0)
60
 
61
+ if "first name" in q:
62
+ return a.split()[0]
63
+
64
  if "page numbers" in q:
65
+ nums = sorted(set(re.findall(r"\b\d+\b", a)))
66
+ return ", ".join(nums)
67
 
68
  if "at bats" in q:
69
+ match = re.search(r"\b(\d{3,4})\b", a)
70
+ return match.group(1) if match else a
71
 
72
+ if "studio albums" in q or "how many" in q:
73
+ try:
74
+ return str(w2n.word_to_num(a))
75
+ except:
76
+ match = re.search(r"\b\d+\b", a)
77
+ return match.group(0) if match else a
78
 
79
  if "award number" in q:
80
+ match = re.search(r"80NSSC[0-9A-Z]{6,7}", a)
81
+ return match.group(0) if match else a
82
 
83
  if "vegetables" in q or "ingredients" in q:
84
+ tokens = [t.lower() for t in re.findall(r"[a-zA-Z]+", a)]
85
+ blacklist = {"extract", "juice", "pure", "vanilla", "sugar", "granulated", "fresh", "ripe", "pinch", "water", "whole", "cups", "salt"}
86
+ clean = sorted(set(t for t in tokens if t not in blacklist and len(t) > 2))
87
+ return ", ".join(clean)
88
+
89
+ return a
90
 
91
+ def handle_file_context(self, file_bytes, ctype, question):
92
+ if not file_bytes:
93
+ return ""
94
+ if "image" in ctype:
95
+ image_b64 = base64.b64encode(file_bytes).decode("utf-8")
96
+ messages = [
97
+ {"role": "system", "content": "You're a visual reasoning assistant. Answer the question based on the image. Output only the move in algebraic notation."},
98
+ {
99
+ "role": "user",
100
+ "content": [
101
+ {"type": "text", "text": question},
102
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
103
+ ]
104
+ }
105
+ ]
106
+ response = self.client.chat.completions.create(model="gpt-4o", messages=messages)
107
+ return response.choices[0].message.content.strip()
108
+ elif "audio" in ctype or question.endswith(".mp3"):
109
+ path = "/tmp/audio.mp3"
110
+ with open(path, "wb") as f:
111
+ f.write(file_bytes)
112
+ transcript = self.client.audio.transcriptions.create(model="whisper-1", file=open(path, "rb"))
113
+ return transcript.text
114
+ elif "excel" in ctype or question.endswith(".xlsx"):
115
+ try:
116
+ df = pd.read_excel(io.BytesIO(file_bytes), engine="openpyxl")
117
+ df.columns = [c.lower() for c in df.columns]
118
+ df['sales'] = pd.to_numeric(df['sales'], errors='coerce')
119
+ food_df = df[df['category'].str.lower() == 'food']
120
+ total = food_df['sales'].sum()
121
+ return f"${total:.2f}" if not pd.isna(total) else "$0.00"
122
+ except Exception:
123
+ return "[EXCEL ERROR]"
124
+ else:
125
+ try:
126
+ return file_bytes.decode("utf-8")[:3000]
127
+ except:
128
+ return ""
129
 
130
  def __call__(self, question, task_id=None):
131
  file_bytes, ctype = None, ""
132
  if task_id:
133
  file_bytes, ctype = self.fetch_file(task_id)
134
 
135
+ file_context = self.handle_file_context(file_bytes, ctype, question)
136
+ if file_context and not file_context.startswith("$"):
137
+ raw = self.ask(file_context, question)
138
+ elif file_context.startswith("$"):
139
+ return file_context # Excel result
140
+ else:
141
+ web_context = self.search_web_context(question)
142
+ raw = self.ask(web_context, question)
143
+
144
+ return self.format_answer(raw, question)