Update agent.py
Browse files
agent.py
CHANGED
@@ -5,128 +5,75 @@ import requests
|
|
5 |
import pandas as pd
|
6 |
from openai import OpenAI
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
"
|
11 |
-
"
|
12 |
-
"
|
13 |
-
|
14 |
-
|
15 |
-
"
|
16 |
-
"cca530fc-4052-43b2-b130-b30968d8aa44"
|
17 |
-
}
|
18 |
-
CODE_TASKS = {
|
19 |
-
"f918266a-b3e0-4914-865d-4faa564f1aef"
|
20 |
}
|
|
|
21 |
CSV_TASKS = {
|
22 |
-
"7bd855d8-463d-4ed5-93ca-5fe35145f733"
|
23 |
}
|
24 |
|
25 |
class GaiaAgent:
|
26 |
def __init__(self):
|
27 |
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
28 |
self.api_url = "https://agents-course-unit4-scoring.hf.space"
|
29 |
-
self.instructions =
|
|
|
|
|
|
|
30 |
|
31 |
def fetch_file(self, task_id):
|
32 |
try:
|
33 |
url = f"{self.api_url}/files/{task_id}"
|
34 |
-
r = requests.get(url, timeout=
|
35 |
r.raise_for_status()
|
36 |
return r.content, r.headers.get("Content-Type", "")
|
37 |
except Exception as e:
|
38 |
return None, f"[FILE ERROR: {e}]"
|
39 |
|
40 |
-
def
|
41 |
-
try:
|
42 |
-
transcript = self.client.audio.transcriptions.create(
|
43 |
-
model="whisper-1",
|
44 |
-
file=io.BytesIO(audio_bytes),
|
45 |
-
response_format="text"
|
46 |
-
)
|
47 |
-
return transcript.strip()
|
48 |
-
except Exception as e:
|
49 |
-
return f"[TRANSCRIPTION ERROR: {e}]"
|
50 |
-
|
51 |
-
def handle_image(self, image_bytes, question):
|
52 |
-
b64 = base64.b64encode(image_bytes).decode("utf-8")
|
53 |
-
messages = [
|
54 |
-
{"role": "system", "content": self.instructions},
|
55 |
-
{
|
56 |
-
"role": "user",
|
57 |
-
"content": [
|
58 |
-
{"type": "text", "text": question},
|
59 |
-
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}}
|
60 |
-
]
|
61 |
-
}
|
62 |
-
]
|
63 |
-
try:
|
64 |
-
response = self.client.chat.completions.create(model="gpt-4o", messages=messages)
|
65 |
-
return response.choices[0].message.content.strip()
|
66 |
-
except Exception as e:
|
67 |
-
return f"[IMAGE ERROR: {e}]"
|
68 |
-
|
69 |
-
def handle_csv(self, csv_bytes, question):
|
70 |
try:
|
71 |
df = pd.read_excel(io.BytesIO(csv_bytes)) if csv_bytes[:4] == b"PK\x03\x04" else pd.read_csv(io.StringIO(csv_bytes.decode()))
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
73 |
return f"${total:.2f}"
|
74 |
except Exception as e:
|
75 |
return f"[CSV ERROR: {e}]"
|
76 |
|
77 |
-
def handle_code(self, code_bytes):
    """Run a Python script given as UTF-8 bytes and report its ``result``.

    Args:
        code_bytes: The script source, UTF-8 encoded.

    Returns:
        ``str(result)`` if the script binds a top-level name ``result``;
        otherwise the placeholder
        ``"[Executed. Check result variable manually]"``. If decoding or
        execution raises an ``Exception``, returns ``"[EXEC ERROR: ...]"``.
    """
    try:
        source = code_bytes.decode("utf-8")
        # SECURITY: exec() runs the script with full interpreter privileges.
        # Only call this on files from a trusted source, or move execution
        # into a sandboxed subprocess.
        #
        # A single dict serves as both globals and locals. With separate
        # dicts, top-level functions are stored in the locals mapping but
        # resolve names through the (empty) globals mapping, so a script
        # whose functions call one another fails with NameError.
        namespace = {}
        exec(source, namespace)
        return str(namespace.get("result", "[Executed. Check result variable manually]"))
    except Exception as e:
        return f"[EXEC ERROR: {e}]"
|
84 |
-
|
85 |
def __call__(self, question: str, task_id: str = None) -> str:
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
# audio
|
90 |
-
if task_id in AUDIO_TASKS:
|
91 |
-
file, err = self.fetch_file(task_id)
|
92 |
-
if file:
|
93 |
-
transcript = self.handle_audio(file)
|
94 |
-
return self.ask_llm(f"Audio transcript: {transcript}\n\nQuestion: {question}")
|
95 |
-
return err
|
96 |
|
97 |
-
#
|
98 |
-
if task_id in IMAGE_TASKS:
|
99 |
-
file, err = self.fetch_file(task_id)
|
100 |
-
if file:
|
101 |
-
return self.handle_image(file, question)
|
102 |
-
return err
|
103 |
-
|
104 |
-
# python code
|
105 |
-
if task_id in CODE_TASKS:
|
106 |
-
file, err = self.fetch_file(task_id)
|
107 |
-
if file:
|
108 |
-
return self.handle_code(file)
|
109 |
-
return err
|
110 |
-
|
111 |
-
# CSV/Excel
|
112 |
if task_id in CSV_TASKS:
|
113 |
-
|
114 |
-
if
|
115 |
-
|
|
|
|
|
|
|
116 |
return err
|
117 |
|
118 |
-
#
|
119 |
-
return self.ask_llm(question)
|
120 |
-
|
121 |
-
def ask_llm(self, prompt: str) -> str:
|
122 |
try:
|
123 |
response = self.client.chat.completions.create(
|
124 |
model="gpt-4-turbo",
|
125 |
messages=[
|
126 |
{"role": "system", "content": self.instructions},
|
127 |
-
{"role": "user", "content":
|
128 |
],
|
129 |
-
temperature=0.0
|
130 |
)
|
131 |
return response.choices[0].message.content.strip()
|
132 |
except Exception as e:
|
|
|
5 |
import pandas as pd
|
6 |
from openai import OpenAI
|
7 |
|
8 |
+
# Task ids that are answered from the question text alone (no attachment).
TEXT_ONLY_TASKS = {
    # reversed question
    "2d83110e-a098-4ebb-9987-066c06fa42d0",
    # wikipedia FA
    "4fc2f1ae-8625-45b5-ab34-ad4433bc21f8",
    # commutative check
    "6f37996b-2ac7-44b0-8e68-6d28256631b4",
    # grocery list - vegetables
    "3cef3a44-215e-4aed-8e3b-b1e3f08063b7",
    # actor - Magda M
    "305ac316-eef6-4446-960a-92d80d542f82",
    # least athletes
    "cf106601-ab4f-4af9-b045-5295fe67b37d",
    # Malko Competition
    "5a0c1adf-205e-4841-a666-7c3ef95def9d",
}
|
17 |
+
|
18 |
# Task ids whose attachment is a spreadsheet handled by the CSV/Excel path.
CSV_TASKS = {
    # Excel - food sales
    "7bd855d8-463d-4ed5-93ca-5fe35145f733",
}
|
21 |
|
22 |
class GaiaAgent:
|
23 |
def __init__(self):
    """Create the OpenAI client and store the scoring API URL and system prompt.

    The API key is read from the ``OPENAI_API_KEY`` environment variable.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    system_prompt = (
        "You are a precise assistant solving GAIA benchmark questions. "
        "Only answer if you are confident you can provide the exact correct result."
    )

    self.client = OpenAI(api_key=api_key)
    self.api_url = "https://agents-course-unit4-scoring.hf.space"
    self.instructions = system_prompt
|
30 |
|
31 |
def fetch_file(self, task_id):
    """Download the file attached to ``task_id`` from the scoring API.

    Returns:
        ``(content_bytes, content_type)`` on success; the content type is
        ``""`` when the response has no ``Content-Type`` header.
        ``(None, "[FILE ERROR: ...]")`` if anything raises, including a
        non-2xx status.
    """
    try:
        response = requests.get(f"{self.api_url}/files/{task_id}", timeout=10)
        response.raise_for_status()
        payload = response.content
        content_type = response.headers.get("Content-Type", "")
        return payload, content_type
    except Exception as exc:
        return None, f"[FILE ERROR: {exc}]"
|
39 |
|
40 |
+
def handle_csv_sales(self, csv_bytes):
    """Sum the ``sales`` column over rows whose ``category`` is "food".

    The payload is parsed as an Excel workbook when it starts with the ZIP
    magic bytes (``PK\\x03\\x04``), otherwise as UTF-8 CSV text.

    Header names are matched after stripping whitespace and lower-casing,
    and category cells are stripped before comparison, so ``"Category"`` /
    ``" Food "`` are recognised as well as the exact lower-case forms.

    Args:
        csv_bytes: Raw bytes of the CSV or Excel file.

    Returns:
        The total formatted as ``"$<amount>"`` with two decimals, or one of
        the bracketed markers ``"[MISSING COLUMN]"``,
        ``"[NO FOOD ITEMS FOUND]"`` or ``"[CSV ERROR: ...]"``.
    """
    try:
        if csv_bytes[:4] == b"PK\x03\x04":
            df = pd.read_excel(io.BytesIO(csv_bytes))
        else:
            df = pd.read_csv(io.StringIO(csv_bytes.decode()))

        # Normalise headers so " Sales " or "Category" still match.
        df = df.rename(columns=lambda name: str(name).strip().lower())
        if 'category' not in df.columns or 'sales' not in df.columns:
            return "[MISSING COLUMN]"

        categories = df['category'].str.strip().str.lower()
        food_df = df[categories == 'food']
        if food_df.empty:
            return "[NO FOOD ITEMS FOUND]"

        # Convert explicitly: summing an object column of numeric strings
        # would concatenate them. Unparseable cells still raise and are
        # reported through the error marker below rather than being dropped.
        total = pd.to_numeric(food_df['sales']).sum()
        return f"${total:.2f}"
    except Exception as e:
        return f"[CSV ERROR: {e}]"
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
def __call__(self, question: str, task_id: str = None) -> str:
|
54 |
+
# 1. Task filtering
|
55 |
+
if task_id not in TEXT_ONLY_TASKS and task_id not in CSV_TASKS:
|
56 |
+
return "[SKIPPED: Task not eligible for high-confidence answer]"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
+
# 2. CSV handling
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
if task_id in CSV_TASKS:
|
60 |
+
csv_bytes, err = self.fetch_file(task_id)
|
61 |
+
if csv_bytes:
|
62 |
+
result = self.handle_csv_sales(csv_bytes)
|
63 |
+
if result.startswith("["):
|
64 |
+
return "[SKIPPED: Confidence check failed]"
|
65 |
+
return result
|
66 |
return err
|
67 |
|
68 |
+
# 3. Text questions with high confidence
|
|
|
|
|
|
|
69 |
try:
|
70 |
response = self.client.chat.completions.create(
|
71 |
model="gpt-4-turbo",
|
72 |
messages=[
|
73 |
{"role": "system", "content": self.instructions},
|
74 |
+
{"role": "user", "content": f"QUESTION: {question}\nANSWER (concise):"}
|
75 |
],
|
76 |
+
temperature=0.0
|
77 |
)
|
78 |
return response.choices[0].message.content.strip()
|
79 |
except Exception as e:
|