dawid-lorek committed on
Commit
9eb69da
·
verified ·
1 Parent(s): d48b3cc

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +74 -52
agent.py CHANGED
@@ -1,79 +1,101 @@
1
  import os
2
  import io
3
- import pandas as pd
4
  import requests
 
5
  from openai import OpenAI
6
 
7
# Task IDs that GaiaAgent.__call__ answers with "SKIPPED" instead of querying
# the model, because they require video, image, or audio input.
SKIPPED_TASKS = {
    "a1e91b78-d3d8-4675-bb8d-62741b4b68a6",  # YouTube birds
    "cca530fc-4052-43b2-b130-b30968d8aa44",  # Chess image
    "9d191bce-651d-4746-be2d-7ef8ecadb9c2",  # Teal'c audio
    "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3",  # Strawberry pie.mp3
    "1f975693-876d-457b-a649-393859e79bf3",  # Homework.mp3
}
15
-
16
  class GaiaAgent:
17
  def __init__(self):
18
  self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
19
  self.instructions = (
20
- "You are a precise and logical assistant solving GAIA benchmark questions. "
21
- "Use any context or data provided. Respond with only the final answer."
22
  )
23
  self.api_url = "https://agents-course-unit4-scoring.hf.space"
24
 
25
- def analyze_csv(self, csv_text: str, question: str) -> str:
26
- try:
27
- df = pd.read_csv(io.StringIO(csv_text))
28
- q = question.lower()
29
- if "total" in q and "food" in q and "not including drinks" in q:
30
- food_items = df[df["category"].str.lower() == "food"]
31
- return f"Total food sales: ${food_items["sales"].sum():.2f}"
32
- return f"Sample row: {df.iloc[0].to_dict()}"
33
- except Exception as e:
34
- return f"[CSV parse failed: {e}]"
35
-
36
- def fetch_file_context(self, task_id: str, question: str) -> str:
37
  try:
38
  url = f"{self.api_url}/files/{task_id}"
39
- response = requests.get(url, timeout=10)
40
  response.raise_for_status()
41
  content_type = response.headers.get("Content-Type", "")
42
-
43
- if "csv" in content_type or url.endswith(".csv"):
44
- return self.analyze_csv(response.text, question)
45
- elif "json" in content_type:
46
- return f"JSON Preview: {response.text[:1000]}"
47
- elif "text/plain" in content_type:
48
- return f"Text Preview: {response.text[:1000]}"
49
- elif "pdf" in content_type:
50
- return "[PDF detected. OCR not supported.]"
51
- else:
52
- return f"[Unsupported file type: {content_type}]"
53
-
54
  except Exception as e:
55
- return f"[File error: {e}]"
56
 
57
  def __call__(self, question: str, task_id: str = None) -> str:
58
- if task_id in SKIPPED_TASKS:
59
- return "SKIPPED"
 
60
 
61
- file_fact = ""
62
  if task_id:
63
- file_fact = self.fetch_file_context(task_id, question)
64
- file_fact = f"FILE CONTEXT:\n{file_fact}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- prompt = f"{self.instructions}\n\n{file_fact}QUESTION: {question}\nANSWER:"
 
 
 
67
 
68
  try:
69
- response = self.client.chat.completions.create(
70
- model="gpt-4-turbo",
71
- messages=[
72
- {"role": "system", "content": self.instructions},
73
- {"role": "user", "content": prompt}
74
- ],
75
- temperature=0.0,
76
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  return response.choices[0].message.content.strip()
 
78
  except Exception as e:
79
  return f"[Agent error: {e}]"
 
1
  import os
2
  import io
3
+ import base64
4
  import requests
5
+ import pandas as pd
6
  from openai import OpenAI
7
 
 
 
 
 
 
 
 
 
 
8
  class GaiaAgent:
9
  def __init__(self):
10
  self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
11
  self.instructions = (
12
+ "You are a multimodal GAIA assistant capable of understanding text, images, audio, and code. "
13
+ "Use file context if provided, think step by step, and respond with the exact answer only."
14
  )
15
  self.api_url = "https://agents-course-unit4-scoring.hf.space"
16
 
17
+ def fetch_file(self, task_id: str) -> (str, bytes, str):
 
 
 
 
 
 
 
 
 
 
 
18
  try:
19
  url = f"{self.api_url}/files/{task_id}"
20
+ response = requests.get(url, timeout=15)
21
  response.raise_for_status()
22
  content_type = response.headers.get("Content-Type", "")
23
+ return url, response.content, content_type
 
 
 
 
 
 
 
 
 
 
 
24
  except Exception as e:
25
+ return None, None, f"[Fetch error: {e}]"
26
 
27
  def __call__(self, question: str, task_id: str = None) -> str:
28
+ image = None
29
+ audio = None
30
+ tool_context = ""
31
 
 
32
  if task_id:
33
+ url, file_bytes, file_type = self.fetch_file(task_id)
34
+ if file_bytes is None:
35
+ tool_context = file_type # error message
36
+ elif "image" in file_type:
37
+ image = base64.b64encode(file_bytes).decode("utf-8")
38
+ elif "audio" in file_type:
39
+ audio = file_bytes
40
+ elif file_type.endswith("python"):
41
+ try:
42
+ exec_env = {}
43
+ exec(file_bytes.decode("utf-8"), {}, exec_env)
44
+ result = exec_env.get("result", "[Executed. Check code return value manually if needed.]")
45
+ tool_context = f"Python result: {result}"
46
+ except Exception as e:
47
+ tool_context = f"[Python execution error: {e}]"
48
+ elif "text" in file_type or "csv" in file_type:
49
+ tool_context = file_bytes.decode("utf-8")[:2000]
50
+ elif "pdf" in file_type:
51
+ tool_context = "[PDF file detected. OCR not yet implemented.]"
52
 
53
+ messages = [
54
+ {"role": "system", "content": self.instructions},
55
+ {"role": "user", "content": f"{tool_context}\n\nQUESTION: {question}\nANSWER:"}
56
+ ]
57
 
58
  try:
59
+ if image:
60
+ response = self.client.chat.completions.create(
61
+ model="gpt-4o",
62
+ messages=[
63
+ {"role": "system", "content": self.instructions},
64
+ {
65
+ "role": "user",
66
+ "content": [
67
+ {"type": "text", "text": question},
68
+ {
69
+ "type": "image_url",
70
+ "image_url": {
71
+ "url": f"data:image/png;base64,{image}",
72
+ "detail": "auto"
73
+ }
74
+ }
75
+ ]
76
+ }
77
+ ]
78
+ )
79
+ elif audio:
80
+ transcript = self.client.audio.transcriptions.create(
81
+ model="whisper-1",
82
+ file=io.BytesIO(audio),
83
+ response_format="text"
84
+ )
85
+ messages.append({"role": "user", "content": f"Transcript: {transcript.strip()}"})
86
+ response = self.client.chat.completions.create(
87
+ model="gpt-4-turbo",
88
+ messages=messages,
89
+ temperature=0.0
90
+ )
91
+ else:
92
+ response = self.client.chat.completions.create(
93
+ model="gpt-4-turbo",
94
+ messages=messages,
95
+ temperature=0.0
96
+ )
97
+
98
  return response.choices[0].message.content.strip()
99
+
100
  except Exception as e:
101
  return f"[Agent error: {e}]"