Update agent.py
agent.py CHANGED

```diff
@@ -5,97 +5,129 @@ import requests
 import pandas as pd
 from openai import OpenAI
 
+# --- Task classification ---
+AUDIO_TASKS = {
+    "9d191bce-651d-4746-be2d-7ef8ecadb9c2",
+    "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3",
+    "1f975693-876d-457b-a649-393859e79bf3"
+}
+IMAGE_TASKS = {
+    "a1e91b78-d3d8-4675-bb8d-62741b4b68a6",
+    "cca530fc-4052-43b2-b130-b30968d8aa44"
+}
+CODE_TASKS = {
+    "f918266a-b3e0-4914-865d-4faa564f1aef"
+}
+CSV_TASKS = {
+    "7bd855d8-463d-4ed5-93ca-5fe35145f733"
+}
+
 class GaiaAgent:
     def __init__(self):
         self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
-        self.instructions = (
-            "You are a multimodal GAIA assistant capable of understanding text, images, audio, and code. "
-            "Use file context if provided, think step by step, and respond with the exact answer only."
-        )
         self.api_url = "https://agents-course-unit4-scoring.hf.space"
+        self.instructions = "You are a helpful assistant solving GAIA benchmark questions using any available tools."
 
-    def fetch_file(self, task_id
+    def fetch_file(self, task_id):
         try:
             url = f"{self.api_url}/files/{task_id}"
-
-
-
-            return url, response.content, content_type
+            r = requests.get(url, timeout=15)
+            r.raise_for_status()
+            return r.content, r.headers.get("Content-Type", "")
         except Exception as e:
-            return None,
-
-    def __call__(self, question: str, task_id: str = None) -> str:
-        image = None
-        audio = None
-        tool_context = ""
+            return None, f"[FILE ERROR: {e}]"
 
-
-
-
-
-
-
-
-
-
-
-                exec_env = {}
-                exec(file_bytes.decode("utf-8"), {}, exec_env)
-                result = exec_env.get("result", "[Executed. Check code return value manually if needed.]")
-                tool_context = f"Python result: {result}"
-            except Exception as e:
-                tool_context = f"[Python execution error: {e}]"
-        elif "text" in file_type or "csv" in file_type:
-            tool_context = file_bytes.decode("utf-8")[:2000]
-        elif "pdf" in file_type:
-            tool_context = "[PDF file detected. OCR not yet implemented.]"
+    def handle_audio(self, audio_bytes):
+        try:
+            transcript = self.client.audio.transcriptions.create(
+                model="whisper-1",
+                file=io.BytesIO(audio_bytes),
+                response_format="text"
+            )
+            return transcript.strip()
+        except Exception as e:
+            return f"[TRANSCRIPTION ERROR: {e}]"
 
+    def handle_image(self, image_bytes, question):
+        b64 = base64.b64encode(image_bytes).decode("utf-8")
         messages = [
             {"role": "system", "content": self.instructions},
-            {
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": question},
+                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}}
+                ]
+            }
         ]
+        try:
+            response = self.client.chat.completions.create(model="gpt-4o", messages=messages)
+            return response.choices[0].message.content.strip()
+        except Exception as e:
+            return f"[IMAGE ERROR: {e}]"
 
+    def handle_csv(self, csv_bytes, question):
         try:
-            if
-
-
-
-
-                    {
-                        "role": "user",
-                        "content": [
-                            {"type": "text", "text": question},
-                            {
-                                "type": "image_url",
-                                "image_url": {
-                                    "url": f"data:image/png;base64,{image}",
-                                    "detail": "auto"
-                                }
-                            }
-                        ]
-                    }
-                ]
-            )
-            elif audio:
-                transcript = self.client.audio.transcriptions.create(
-                    model="whisper-1",
-                    file=io.BytesIO(audio),
-                    response_format="text"
-                )
-                messages.append({"role": "user", "content": f"Transcript: {transcript.strip()}"})
-                response = self.client.chat.completions.create(
-                    model="gpt-4-turbo",
-                    messages=messages,
-                    temperature=0.0
-                )
-            else:
-                response = self.client.chat.completions.create(
-                    model="gpt-4-turbo",
-                    messages=messages,
-                    temperature=0.0
-                )
+            df = pd.read_excel(io.BytesIO(csv_bytes)) if csv_bytes[:4] == b"PK\x03\x04" else pd.read_csv(io.StringIO(csv_bytes.decode()))
+            total = df[df['category'].str.lower() == 'food']['sales'].sum()
+            return f"${total:.2f}"
+        except Exception as e:
+            return f"[CSV ERROR: {e}]"
 
-
+    def handle_code(self, code_bytes):
+        try:
+            exec_env = {}
+            exec(code_bytes.decode("utf-8"), {}, exec_env)
+            return str(exec_env.get("result", "[Executed. Check result variable manually]"))
+        except Exception as e:
+            return f"[EXEC ERROR: {e}]"
+
+    def __call__(self, question: str, task_id: str = None) -> str:
+        if not task_id:
+            return self.ask_llm(question)
+
+        # audio
+        if task_id in AUDIO_TASKS:
+            file, err = self.fetch_file(task_id)
+            if file:
+                transcript = self.handle_audio(file)
+                return self.ask_llm(f"Audio transcript: {transcript}\n\nQuestion: {question}")
+            return err
+
+        # image
+        if task_id in IMAGE_TASKS:
+            file, err = self.fetch_file(task_id)
+            if file:
+                return self.handle_image(file, question)
+            return err
 
+        # python code
+        if task_id in CODE_TASKS:
+            file, err = self.fetch_file(task_id)
+            if file:
+                return self.handle_code(file)
+            return err
+
+        # CSV/Excel
+        if task_id in CSV_TASKS:
+            file, err = self.fetch_file(task_id)
+            if file:
+                return self.handle_csv(file, question)
+            return err
+
+        # fallback to LLM only
+        return self.ask_llm(question)
+
+    def ask_llm(self, prompt: str) -> str:
+        try:
+            response = self.client.chat.completions.create(
+                model="gpt-4-turbo",
+                messages=[
+                    {"role": "system", "content": self.instructions},
+                    {"role": "user", "content": prompt.strip()}
+                ],
+                temperature=0.0,
+            )
+            return response.choices[0].message.content.strip()
         except Exception as e:
-            return f"[
+            return f"[LLM ERROR: {e}]"
```
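For reference, a minimal usage sketch of the rewritten agent (not part of the commit). It assumes `OPENAI_API_KEY` is exported, the scoring space at `agents-course-unit4-scoring.hf.space` is reachable, and that this module is importable as `agent`; the task ID shown is one of the hard-coded `AUDIO_TASKS` entries from the diff above.

```python
# Usage sketch for GaiaAgent (assumptions: OPENAI_API_KEY is set, the scoring
# API is reachable, and this file is importable as agent.py).
from agent import GaiaAgent

agent = GaiaAgent()

# No task_id: the question is routed straight to ask_llm() (a plain
# gpt-4-turbo chat completion with temperature 0).
print(agent("What is the capital of France?"))

# A task_id listed in AUDIO_TASKS: fetch_file() downloads the attachment,
# handle_audio() transcribes it with whisper-1, and ask_llm() answers from
# the transcript.
print(agent("What does the recording ask for?",
            task_id="9d191bce-651d-4746-be2d-7ef8ecadb9c2"))
```

Note that dispatch is keyed to these hard-coded GAIA task IDs rather than to the `Content-Type` string that `fetch_file` also returns, so any unlisted task falls through to the text-only `ask_llm` path.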