|
import os |
|
import io |
|
import base64 |
|
import requests |
|
import pandas as pd |
|
from openai import OpenAI |
|
|
|
|
|
AUDIO_TASKS = { |
|
"9d191bce-651d-4746-be2d-7ef8ecadb9c2", |
|
"99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3", |
|
"1f975693-876d-457b-a649-393859e79bf3" |
|
} |
|
IMAGE_TASKS = { |
|
"a1e91b78-d3d8-4675-bb8d-62741b4b68a6", |
|
"cca530fc-4052-43b2-b130-b30968d8aa44" |
|
} |
|
CODE_TASKS = { |
|
"f918266a-b3e0-4914-865d-4faa564f1aef" |
|
} |
|
CSV_TASKS = { |
|
"7bd855d8-463d-4ed5-93ca-5fe35145f733" |
|
} |
|
|
|
class GaiaAgent: |
|
def __init__(self): |
|
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) |
|
self.api_url = "https://agents-course-unit4-scoring.hf.space" |
|
self.instructions = "You are a helpful assistant solving GAIA benchmark questions using any available tools." |
|
|
|
def fetch_file(self, task_id): |
|
try: |
|
url = f"{self.api_url}/files/{task_id}" |
|
r = requests.get(url, timeout=15) |
|
r.raise_for_status() |
|
return r.content, r.headers.get("Content-Type", "") |
|
except Exception as e: |
|
return None, f"[FILE ERROR: {e}]" |
|
|
|
def handle_audio(self, audio_bytes): |
|
try: |
|
transcript = self.client.audio.transcriptions.create( |
|
model="whisper-1", |
|
file=io.BytesIO(audio_bytes), |
|
response_format="text" |
|
) |
|
return transcript.strip() |
|
except Exception as e: |
|
return f"[TRANSCRIPTION ERROR: {e}]" |
|
|
|
def handle_image(self, image_bytes, question): |
|
b64 = base64.b64encode(image_bytes).decode("utf-8") |
|
messages = [ |
|
{"role": "system", "content": self.instructions}, |
|
{ |
|
"role": "user", |
|
"content": [ |
|
{"type": "text", "text": question}, |
|
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}} |
|
] |
|
} |
|
] |
|
try: |
|
response = self.client.chat.completions.create(model="gpt-4o", messages=messages) |
|
return response.choices[0].message.content.strip() |
|
except Exception as e: |
|
return f"[IMAGE ERROR: {e}]" |
|
|
|
def handle_csv(self, csv_bytes, question): |
|
try: |
|
df = pd.read_excel(io.BytesIO(csv_bytes)) if csv_bytes[:4] == b"PK\x03\x04" else pd.read_csv(io.StringIO(csv_bytes.decode())) |
|
total = df[df['category'].str.lower() == 'food']['sales'].sum() |
|
return f"${total:.2f}" |
|
except Exception as e: |
|
return f"[CSV ERROR: {e}]" |
|
|
|
def handle_code(self, code_bytes): |
|
try: |
|
exec_env = {} |
|
exec(code_bytes.decode("utf-8"), {}, exec_env) |
|
return str(exec_env.get("result", "[Executed. Check result variable manually]")) |
|
except Exception as e: |
|
return f"[EXEC ERROR: {e}]" |
|
|
|
def __call__(self, question: str, task_id: str = None) -> str: |
|
if not task_id: |
|
return self.ask_llm(question) |
|
|
|
|
|
if task_id in AUDIO_TASKS: |
|
file, err = self.fetch_file(task_id) |
|
if file: |
|
transcript = self.handle_audio(file) |
|
return self.ask_llm(f"Audio transcript: {transcript}\n\nQuestion: {question}") |
|
return err |
|
|
|
|
|
if task_id in IMAGE_TASKS: |
|
file, err = self.fetch_file(task_id) |
|
if file: |
|
return self.handle_image(file, question) |
|
return err |
|
|
|
|
|
if task_id in CODE_TASKS: |
|
file, err = self.fetch_file(task_id) |
|
if file: |
|
return self.handle_code(file) |
|
return err |
|
|
|
|
|
if task_id in CSV_TASKS: |
|
file, err = self.fetch_file(task_id) |
|
if file: |
|
return self.handle_csv(file, question) |
|
return err |
|
|
|
|
|
return self.ask_llm(question) |
|
|
|
def ask_llm(self, prompt: str) -> str: |
|
try: |
|
response = self.client.chat.completions.create( |
|
model="gpt-4-turbo", |
|
messages=[ |
|
{"role": "system", "content": self.instructions}, |
|
{"role": "user", "content": prompt.strip()} |
|
], |
|
temperature=0.0, |
|
) |
|
return response.choices[0].message.content.strip() |
|
except Exception as e: |
|
return f"[LLM ERROR: {e}]" |
|
|