dawid-lorek committed on
Commit
36284fd
·
verified ·
1 Parent(s): aef7057

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +40 -19
agent.py CHANGED
@@ -1,4 +1,4 @@
1
- # agent_v30.py
2
  import os
3
  import re
4
  import io
@@ -24,6 +24,12 @@ class GaiaAgent:
24
  except Exception:
25
  return None, None
26
 
 
 
 
 
 
 
27
  def ask(self, prompt, model="gpt-4-turbo"):
28
  response = self.client.chat.completions.create(
29
  model=model,
@@ -35,12 +41,6 @@ class GaiaAgent:
35
  )
36
  return response.choices[0].message.content.strip()
37
 
38
- def get_web_info(self, query):
39
- try:
40
- return self.search_tool.run(query)
41
- except Exception:
42
- return "[NO WEB INFO FOUND]"
43
-
44
  def ask_audio(self, audio_bytes, question):
45
  path = "/tmp/audio.mp3"
46
  with open(path, "wb") as f:
@@ -69,17 +69,36 @@ class GaiaAgent:
69
  df.columns = [col.lower() for col in df.columns]
70
  if 'category' in df.columns and 'sales' in df.columns:
71
  df['sales'] = pd.to_numeric(df['sales'], errors='coerce')
72
- food_df = df[df['category'].str.lower().str.contains('food')]
73
  total = food_df['sales'].sum()
74
  return f"${total:.2f}" if not pd.isna(total) else "$0.00"
75
  except Exception:
76
  pass
77
  return "$0.00"
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def extract_answer(self, raw, question):
80
  q = question.lower()
81
  raw = raw.strip().strip("\"'").strip()
82
- raw = re.sub(r"^[-•\s]*", "", raw)
83
 
84
  if "studio albums" in q:
85
  try:
@@ -92,11 +111,6 @@ class GaiaAgent:
92
  match = re.search(r"\b([KQBNR]?[a-h]?[1-8]?x?[a-h][1-8][+#]?)\b", raw)
93
  return match.group(1) if match else raw
94
 
95
- if "vegetables" in q or "ingredients" in q:
96
- unwanted = {"pure", "extract", "granulated", "sugar", "juice", "vanilla", "ripe", "fresh", "whole", "bean", "pinch", "cups", "salt", "water"}
97
- terms = [t.lower() for t in re.findall(r"[a-zA-Z]+(?: [a-zA-Z]+)?", raw)]
98
- return ", ".join(sorted(set(t for t in terms if t.split()[0] not in unwanted)))
99
-
100
  if "usd with two decimal places" in q:
101
  match = re.search(r"\$?([0-9]+(?:\.[0-9]{1,2})?)", raw)
102
  return f"${float(match.group(1)):.2f}" if match else "$0.00"
@@ -120,6 +134,12 @@ class GaiaAgent:
120
  match = re.search(r"80NSSC[0-9A-Z]{6,7}", raw)
121
  return match.group(0) if match else raw
122
 
 
 
 
 
 
 
123
  return raw
124
 
125
  def __call__(self, question, task_id=None):
@@ -131,15 +151,16 @@ class GaiaAgent:
131
  if "youtube.com" in question:
132
  video_id = re.search(r"v=([\w-]+)", question)
133
  if video_id:
134
- summary = self.get_web_info(f"transcript or analysis of YouTube video {video_id.group(1)}")
135
- return self.ask(f"Video summary: {summary}\n\n{question}")
136
 
137
  if "malko competition" in question.lower():
138
- search = self.get_web_info("malko competition winners after 1977 yugoslavia site:wikipedia.org")
139
- return self.ask(f"Web result: {search}\n\n{question}")
140
 
141
  if "commutative" in question:
142
- return self.ask(f"Based on this table, which elements show the operation is not commutative?\n{question}\nList them comma-separated, alphabetically.")
 
143
 
144
  if file_bytes and "image" in ctype:
145
  raw = self.ask_image(file_bytes, question)
 
1
+ # agent_v31.py
2
  import os
3
  import re
4
  import io
 
24
  except Exception:
25
  return None, None
26
 
27
def get_web_info(self, query):
    """Run ``query`` through ``self.search_tool`` and return its result.

    This is a best-effort lookup: if the search tool raises any
    ``Exception``, the placeholder string ``"[NO WEB INFO FOUND]"`` is
    returned instead of propagating the error.
    """
    try:
        found = self.search_tool.run(query)
    except Exception:
        # Deliberate fallback so a failed search never aborts the caller.
        return "[NO WEB INFO FOUND]"
    return found
32
+
33
  def ask(self, prompt, model="gpt-4-turbo"):
34
  response = self.client.chat.completions.create(
35
  model=model,
 
41
  )
42
  return response.choices[0].message.content.strip()
43
 
 
 
 
 
 
 
44
  def ask_audio(self, audio_bytes, question):
45
  path = "/tmp/audio.mp3"
46
  with open(path, "wb") as f:
 
69
  df.columns = [col.lower() for col in df.columns]
70
  if 'category' in df.columns and 'sales' in df.columns:
71
  df['sales'] = pd.to_numeric(df['sales'], errors='coerce')
72
+ food_df = df[df['category'].str.lower() == 'food']
73
  total = food_df['sales'].sum()
74
  return f"${total:.2f}" if not pd.isna(total) else "$0.00"
75
  except Exception:
76
  pass
77
  return "$0.00"
78
 
79
def extract_commutative_set(self, question):
    """Return the elements that take part in a non-commutative pair.

    The operation table is read from pipe-delimited rows inside
    ``question`` of the form ``|x|v1|v2|...|``, where every symbol is one
    of ``a``-``e``. The first cell of a row is the row label and the
    remaining cells are the results for each column. Column order is
    taken to be the order in which row labels first appear.

    Args:
        question: Text that contains the table.

    Returns:
        The sorted elements ``x``/``y`` for which ``x*y != y*x``, joined
        with ``", "``. Returns ``""`` when no such pair exists, when no
        rows are found, when ``question`` is not a string, or when a row
        is too short to index.
    """
    try:
        rows = re.findall(r"\|([a-e])\|([a-e\|]+)\|", question)
    except TypeError:
        # ``question`` was not a str (e.g. None or bytes).
        return ""

    # A label seen twice keeps its first position but its last values,
    # matching plain dict assignment in a loop.
    table = {label: cells.strip("|").split("|") for label, cells in rows}
    elements = list(table)

    non_comm = set()
    try:
        for i, x in enumerate(elements):
            # Each unordered pair is checked once; the comparison is symmetric.
            for j, y in enumerate(elements[i + 1:], start=i + 1):
                if table[x][j] != table[y][i]:
                    non_comm.update((x, y))
    except IndexError:
        # Ragged table: a row has fewer cells than there are elements.
        return ""

    return ", ".join(sorted(non_comm))
98
+
99
  def extract_answer(self, raw, question):
100
  q = question.lower()
101
  raw = raw.strip().strip("\"'").strip()
 
102
 
103
  if "studio albums" in q:
104
  try:
 
111
  match = re.search(r"\b([KQBNR]?[a-h]?[1-8]?x?[a-h][1-8][+#]?)\b", raw)
112
  return match.group(1) if match else raw
113
 
 
 
 
 
 
114
  if "usd with two decimal places" in q:
115
  match = re.search(r"\$?([0-9]+(?:\.[0-9]{1,2})?)", raw)
116
  return f"${float(match.group(1)):.2f}" if match else "$0.00"
 
134
  match = re.search(r"80NSSC[0-9A-Z]{6,7}", raw)
135
  return match.group(0) if match else raw
136
 
137
+ if "vegetables" in q or "ingredients" in q:
138
+ stopwords = set(["pure", "extract", "granulated", "sugar", "juice", "vanilla", "ripe", "fresh", "whole", "bean", "pinch", "cups", "salt", "water"])
139
+ tokens = [t.lower() for t in re.findall(r"[a-zA-Z]+", raw)]
140
+ clean = [t for t in tokens if t not in stopwords and len(t) > 2]
141
+ return ", ".join(sorted(set(clean)))
142
+
143
  return raw
144
 
145
  def __call__(self, question, task_id=None):
 
151
  if "youtube.com" in question:
152
  video_id = re.search(r"v=([\w-]+)", question)
153
  if video_id:
154
+ summary = self.get_web_info(f"youtube video transcript {video_id.group(1)}")
155
+ return self.ask(f"Transcript: {summary}\n\n{question}")
156
 
157
  if "malko competition" in question.lower():
158
+ search = self.get_web_info("malko competition winner yugoslavia after 1977 site:wikipedia.org")
159
+ return self.ask(f"Using the search result:\n{search}\n\n{question}")
160
 
161
  if "commutative" in question:
162
+ result = self.extract_commutative_set(question)
163
+ return result
164
 
165
  if file_bytes and "image" in ctype:
166
  raw = self.ask_image(file_bytes, question)