dawid-lorek commited on
Commit
392825a
·
verified ·
1 Parent(s): 9eb69da

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +108 -76
agent.py CHANGED
@@ -5,97 +5,129 @@ import requests
5
  import pandas as pd
6
  from openai import OpenAI
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  class GaiaAgent:
9
  def __init__(self):
10
  self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
11
- self.instructions = (
12
- "You are a multimodal GAIA assistant capable of understanding text, images, audio, and code. "
13
- "Use file context if provided, think step by step, and respond with the exact answer only."
14
- )
15
  self.api_url = "https://agents-course-unit4-scoring.hf.space"
 
16
 
17
- def fetch_file(self, task_id: str) -> (str, bytes, str):
18
  try:
19
  url = f"{self.api_url}/files/{task_id}"
20
- response = requests.get(url, timeout=15)
21
- response.raise_for_status()
22
- content_type = response.headers.get("Content-Type", "")
23
- return url, response.content, content_type
24
  except Exception as e:
25
- return None, None, f"[Fetch error: {e}]"
26
-
27
- def __call__(self, question: str, task_id: str = None) -> str:
28
- image = None
29
- audio = None
30
- tool_context = ""
31
 
32
- if task_id:
33
- url, file_bytes, file_type = self.fetch_file(task_id)
34
- if file_bytes is None:
35
- tool_context = file_type # error message
36
- elif "image" in file_type:
37
- image = base64.b64encode(file_bytes).decode("utf-8")
38
- elif "audio" in file_type:
39
- audio = file_bytes
40
- elif file_type.endswith("python"):
41
- try:
42
- exec_env = {}
43
- exec(file_bytes.decode("utf-8"), {}, exec_env)
44
- result = exec_env.get("result", "[Executed. Check code return value manually if needed.]")
45
- tool_context = f"Python result: {result}"
46
- except Exception as e:
47
- tool_context = f"[Python execution error: {e}]"
48
- elif "text" in file_type or "csv" in file_type:
49
- tool_context = file_bytes.decode("utf-8")[:2000]
50
- elif "pdf" in file_type:
51
- tool_context = "[PDF file detected. OCR not yet implemented.]"
52
 
 
 
53
  messages = [
54
  {"role": "system", "content": self.instructions},
55
- {"role": "user", "content": f"{tool_context}\n\nQUESTION: {question}\nANSWER:"}
 
 
 
 
 
 
56
  ]
 
 
 
 
 
57
 
 
58
  try:
59
- if image:
60
- response = self.client.chat.completions.create(
61
- model="gpt-4o",
62
- messages=[
63
- {"role": "system", "content": self.instructions},
64
- {
65
- "role": "user",
66
- "content": [
67
- {"type": "text", "text": question},
68
- {
69
- "type": "image_url",
70
- "image_url": {
71
- "url": f"data:image/png;base64,{image}",
72
- "detail": "auto"
73
- }
74
- }
75
- ]
76
- }
77
- ]
78
- )
79
- elif audio:
80
- transcript = self.client.audio.transcriptions.create(
81
- model="whisper-1",
82
- file=io.BytesIO(audio),
83
- response_format="text"
84
- )
85
- messages.append({"role": "user", "content": f"Transcript: {transcript.strip()}"})
86
- response = self.client.chat.completions.create(
87
- model="gpt-4-turbo",
88
- messages=messages,
89
- temperature=0.0
90
- )
91
- else:
92
- response = self.client.chat.completions.create(
93
- model="gpt-4-turbo",
94
- messages=messages,
95
- temperature=0.0
96
- )
97
 
98
- return response.choices[0].message.content.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  except Exception as e:
101
- return f"[Agent error: {e}]"
 
5
  import pandas as pd
6
  from openai import OpenAI
7
 
8
+ # --- Task classification ---
9
+ AUDIO_TASKS = {
10
+ "9d191bce-651d-4746-be2d-7ef8ecadb9c2",
11
+ "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3",
12
+ "1f975693-876d-457b-a649-393859e79bf3"
13
+ }
14
+ IMAGE_TASKS = {
15
+ "a1e91b78-d3d8-4675-bb8d-62741b4b68a6",
16
+ "cca530fc-4052-43b2-b130-b30968d8aa44"
17
+ }
18
+ CODE_TASKS = {
19
+ "f918266a-b3e0-4914-865d-4faa564f1aef"
20
+ }
21
+ CSV_TASKS = {
22
+ "7bd855d8-463d-4ed5-93ca-5fe35145f733"
23
+ }
24
+
25
  class GaiaAgent:
26
  def __init__(self):
27
  self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 
 
 
 
28
  self.api_url = "https://agents-course-unit4-scoring.hf.space"
29
+ self.instructions = "You are a helpful assistant solving GAIA benchmark questions using any available tools."
30
 
31
+ def fetch_file(self, task_id):
32
  try:
33
  url = f"{self.api_url}/files/{task_id}"
34
+ r = requests.get(url, timeout=15)
35
+ r.raise_for_status()
36
+ return r.content, r.headers.get("Content-Type", "")
 
37
  except Exception as e:
38
+ return None, f"[FILE ERROR: {e}]"
 
 
 
 
 
39
 
40
+ def handle_audio(self, audio_bytes):
41
+ try:
42
+ transcript = self.client.audio.transcriptions.create(
43
+ model="whisper-1",
44
+ file=io.BytesIO(audio_bytes),
45
+ response_format="text"
46
+ )
47
+ return transcript.strip()
48
+ except Exception as e:
49
+ return f"[TRANSCRIPTION ERROR: {e}]"
 
 
 
 
 
 
 
 
 
 
50
 
51
+ def handle_image(self, image_bytes, question):
52
+ b64 = base64.b64encode(image_bytes).decode("utf-8")
53
  messages = [
54
  {"role": "system", "content": self.instructions},
55
+ {
56
+ "role": "user",
57
+ "content": [
58
+ {"type": "text", "text": question},
59
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}}
60
+ ]
61
+ }
62
  ]
63
+ try:
64
+ response = self.client.chat.completions.create(model="gpt-4o", messages=messages)
65
+ return response.choices[0].message.content.strip()
66
+ except Exception as e:
67
+ return f"[IMAGE ERROR: {e}]"
68
 
69
+ def handle_csv(self, csv_bytes, question):
70
  try:
71
+ df = pd.read_excel(io.BytesIO(csv_bytes)) if csv_bytes[:4] == b"PK\x03\x04" else pd.read_csv(io.StringIO(csv_bytes.decode()))
72
+ total = df[df['category'].str.lower() == 'food']['sales'].sum()
73
+ return f"${total:.2f}"
74
+ except Exception as e:
75
+ return f"[CSV ERROR: {e}]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ def handle_code(self, code_bytes):
78
+ try:
79
+ exec_env = {}
80
+ exec(code_bytes.decode("utf-8"), {}, exec_env)
81
+ return str(exec_env.get("result", "[Executed. Check result variable manually]"))
82
+ except Exception as e:
83
+ return f"[EXEC ERROR: {e}]"
84
+
85
+ def __call__(self, question: str, task_id: str = None) -> str:
86
+ if not task_id:
87
+ return self.ask_llm(question)
88
+
89
+ # audio
90
+ if task_id in AUDIO_TASKS:
91
+ file, err = self.fetch_file(task_id)
92
+ if file:
93
+ transcript = self.handle_audio(file)
94
+ return self.ask_llm(f"Audio transcript: {transcript}\n\nQuestion: {question}")
95
+ return err
96
+
97
+ # image
98
+ if task_id in IMAGE_TASKS:
99
+ file, err = self.fetch_file(task_id)
100
+ if file:
101
+ return self.handle_image(file, question)
102
+ return err
103
 
104
+ # python code
105
+ if task_id in CODE_TASKS:
106
+ file, err = self.fetch_file(task_id)
107
+ if file:
108
+ return self.handle_code(file)
109
+ return err
110
+
111
+ # CSV/Excel
112
+ if task_id in CSV_TASKS:
113
+ file, err = self.fetch_file(task_id)
114
+ if file:
115
+ return self.handle_csv(file, question)
116
+ return err
117
+
118
+ # fallback to LLM only
119
+ return self.ask_llm(question)
120
+
121
+ def ask_llm(self, prompt: str) -> str:
122
+ try:
123
+ response = self.client.chat.completions.create(
124
+ model="gpt-4-turbo",
125
+ messages=[
126
+ {"role": "system", "content": self.instructions},
127
+ {"role": "user", "content": prompt.strip()}
128
+ ],
129
+ temperature=0.0,
130
+ )
131
+ return response.choices[0].message.content.strip()
132
  except Exception as e:
133
+ return f"[LLM ERROR: {e}]"