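# NOTE: the six lines below appear to be file-viewer header text captured
# together with the source; they are kept as comments so the module parses.
# dawid-lorek's picture
# Update agent.py
# 392825a verified
# raw
# history blame
# 4.49 kB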
import os
import io
import base64
import requests
import pandas as pd
from openai import OpenAI
# --- Task classification ---
# Task IDs are grouped by the kind of attachment handling they are routed to.

# Tasks whose attachment is transcribed before the question is asked.
AUDIO_TASKS = {
    "9d191bce-651d-4746-be2d-7ef8ecadb9c2",
    "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3",
    "1f975693-876d-457b-a649-393859e79bf3",
}

# Tasks whose attachment is sent to the model as an image.
IMAGE_TASKS = {
    "a1e91b78-d3d8-4675-bb8d-62741b4b68a6",
    "cca530fc-4052-43b2-b130-b30968d8aa44",
}

# Tasks whose attachment is Python source that gets executed.
CODE_TASKS = {
    "f918266a-b3e0-4914-865d-4faa564f1aef",
}

# Tasks whose attachment is a CSV/Excel table.
CSV_TASKS = {
    "7bd855d8-463d-4ed5-93ca-5fe35145f733",
}
class GaiaAgent:
    """Answer GAIA benchmark questions, using a task's attached file when known.

    Calling the agent with a question (and optionally a task ID) returns an
    answer string. Task IDs listed in ``AUDIO_TASKS``, ``IMAGE_TASKS``,
    ``CODE_TASKS`` or ``CSV_TASKS`` have their attachment downloaded and
    handled by the matching ``handle_*`` method; everything else is sent to
    the chat model as plain text.

    Failures never propagate: every method converts exceptions into a
    bracketed ``"[... ERROR: ...]"`` string that is returned as the answer.
    """

    def __init__(self):
        """Create the OpenAI client and store the scoring API base URL."""
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.api_url = "https://agents-course-unit4-scoring.hf.space"
        self.instructions = "You are a helpful assistant solving GAIA benchmark questions using any available tools."

    def fetch_file(self, task_id):
        """Download the attachment for *task_id*.

        Returns:
            ``(content_bytes, content_type)`` on success, or
            ``(None, "[FILE ERROR: ...]")`` when the request fails.
        """
        try:
            url = f"{self.api_url}/files/{task_id}"
            r = requests.get(url, timeout=15)
            r.raise_for_status()
            return r.content, r.headers.get("Content-Type", "")
        except Exception as e:
            # Best-effort boundary: report the failure as an answer string.
            return None, f"[FILE ERROR: {e}]"

    def handle_audio(self, audio_bytes, filename="audio.mp3"):
        """Transcribe *audio_bytes* with ``whisper-1`` and return the text.

        Args:
            audio_bytes: Raw audio file contents.
            filename: Name reported for the upload. An in-memory buffer has
                no name of its own, and the file extension is what tells the
                transcription endpoint which audio format it is receiving.

        Returns:
            The stripped transcript, or ``"[TRANSCRIPTION ERROR: ...]"``.
        """
        try:
            audio_file = io.BytesIO(audio_bytes)
            # Without a name the upload carries no file extension.
            audio_file.name = filename
            transcript = self.client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                response_format="text",
            )
            return transcript.strip()
        except Exception as e:
            return f"[TRANSCRIPTION ERROR: {e}]"

    def handle_image(self, image_bytes, question, mime_type="image/png"):
        """Ask ``gpt-4o`` *question* about the image in *image_bytes*.

        Args:
            image_bytes: Raw image file contents.
            question: The question to ask about the image.
            mime_type: MIME type placed in the data URL (default PNG).

        Returns:
            The stripped model reply, or ``"[IMAGE ERROR: ...]"``.
        """
        b64 = base64.b64encode(image_bytes).decode("utf-8")
        messages = [
            {"role": "system", "content": self.instructions},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                    {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}},
                ],
            },
        ]
        try:
            response = self.client.chat.completions.create(model="gpt-4o", messages=messages)
            # The message content may be None; treat that as an empty reply.
            return (response.choices[0].message.content or "").strip()
        except Exception as e:
            return f"[IMAGE ERROR: {e}]"

    def handle_csv(self, csv_bytes, question):
        """Return the total of ``sales`` for rows whose ``category`` is food.

        The aggregation is fixed: *question* is accepted but not consulted.
        Input starting with the ZIP signature is read as an Excel workbook;
        anything else is parsed as CSV.

        Returns:
            The total formatted as ``"$<amount>"`` with two decimals, or
            ``"[CSV ERROR: ...]"``.
        """
        try:
            buffer = io.BytesIO(csv_bytes)
            if csv_bytes[:4] == b"PK\x03\x04":
                df = pd.read_excel(buffer)
            else:
                df = pd.read_csv(buffer)
            is_food = df["category"].str.lower() == "food"
            total = df[is_food]["sales"].sum()
            return f"${total:.2f}"
        except Exception as e:
            return f"[CSV ERROR: {e}]"

    def handle_code(self, code_bytes):
        """Execute Python source in *code_bytes* and return its ``result``.

        SECURITY: this runs downloaded code with full interpreter privileges
        via ``exec``; only use it on attachments from a trusted source.

        Returns:
            ``str(result)`` if the script defines a top-level ``result``
            name, a placeholder message if it does not, or
            ``"[EXEC ERROR: ...]"`` when decoding or execution fails.
        """
        try:
            # One dict serves as both globals and locals. With separate dicts,
            # top-level functions defined by the script are stored in the
            # locals mapping but look names up in globals, so they cannot see
            # each other (NameError).
            namespace = {}
            exec(code_bytes.decode("utf-8"), namespace)
            return str(namespace.get("result", "[Executed. Check result variable manually]"))
        except Exception as e:
            return f"[EXEC ERROR: {e}]"

    def __call__(self, question: str, task_id: str = None) -> str:
        """Answer *question*, using the attachment of *task_id* if it is a known file task.

        Args:
            question: The question text.
            task_id: Optional task ID used to select a file handler.

        Returns:
            The answer, or a bracketed error string if a step failed.
        """
        known_file_task = task_id and (
            task_id in AUDIO_TASKS
            or task_id in IMAGE_TASKS
            or task_id in CODE_TASKS
            or task_id in CSV_TASKS
        )
        if not known_file_task:
            # No task ID, or no special handling: fall back to the LLM only.
            return self.ask_llm(question)

        content, info = self.fetch_file(task_id)
        if content is None:
            # On failure the second element is the error message.
            return info
        if not content:
            # On success the second element is the Content-Type, which must
            # not be returned as an answer.
            return "[FILE ERROR: empty file]"

        if task_id in AUDIO_TASKS:
            transcript = self.handle_audio(content)
            return self.ask_llm(f"Audio transcript: {transcript}\n\nQuestion: {question}")

        if task_id in IMAGE_TASKS:
            # Use the served image type when there is one; otherwise keep PNG.
            mime_type = info.split(";", 1)[0].strip().lower()
            if not mime_type.startswith("image/"):
                mime_type = "image/png"
            return self.handle_image(content, question, mime_type)

        if task_id in CODE_TASKS:
            return self.handle_code(content)

        return self.handle_csv(content, question)

    def ask_llm(self, prompt: str) -> str:
        """Send *prompt* to ``gpt-4-turbo`` at temperature 0 and return the reply.

        Returns:
            The stripped model reply, or ``"[LLM ERROR: ...]"``.
        """
        try:
            response = self.client.chat.completions.create(
                model="gpt-4-turbo",
                messages=[
                    {"role": "system", "content": self.instructions},
                    {"role": "user", "content": prompt.strip()},
                ],
                temperature=0.0,
            )
            # The message content may be None; treat that as an empty reply.
            return (response.choices[0].message.content or "").strip()
        except Exception as e:
            return f"[LLM ERROR: {e}]"