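"""GAIA benchmark agent: answers questions with GPT-4o, optionally enriched by
attached task files (images, spreadsheets, PDFs, audio), YouTube transcripts,
and DuckDuckGo web search snippets."""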
import os
import io
import re
import subprocess
import tempfile

import requests
import openpyxl
import pytesseract
from openai import OpenAI
from duckduckgo_search import DDGS
from PIL import Image

# Optional dependencies: audio transcription and PDF extraction are skipped if
# these packages are not installed.
try:
    import whisper
except ImportError:
    whisper = None

try:
    import pdfplumber
except ImportError:
    pdfplumber = None

# Base URL of the scoring API; task files are fetched from {AGENT_API_URL}/files/<task_id>.
AGENT_API_URL = "https://agents-course-unit4-scoring.hf.space"


def safe_strip(text):
    if not text:
        return ""
    if isinstance(text, bytes):
        text = text.decode(errors="ignore")
    return str(text).replace("\r", "").strip()


def parse_final_answer(text):
    """
    Extract only the final answer from an LLM reply: no explanations,
    no 'Final Answer:' prefix.
    """
    lines = text.splitlines()
    for line in reversed(lines):
        if "Final Answer:" in line:
            return line.split("Final Answer:")[-1].strip()
    # Fall back to the last line of the reply (or "" if the reply is empty).
    return safe_strip(lines[-1]) if lines else ""


def run_web_search(query, max_results=3):
    """Return the snippet (or title) of the first usable hit among the top results."""
    try:
        ddgs = DDGS()
        results = ddgs.text(query)
        for i, r in enumerate(results):
            if i >= max_results:
                break
            if r.get("body"):
                return r["body"]
            elif r.get("title"):
                return r["title"]
        return ""
    except Exception:
        return ""


def fetch_file(task_id):
    url = f"{AGENT_API_URL}/files/{task_id}"
    try:
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        content_type = resp.headers.get("Content-Type", "")
        return resp.content, content_type
    except Exception:
        return None, None


def ocr_image(img_bytes):
    try:
        img = Image.open(io.BytesIO(img_bytes))
        return safe_strip(pytesseract.image_to_string(img))
    except Exception:
        return ""


def read_excel(file_bytes):
    try:
        wb = openpyxl.load_workbook(io.BytesIO(file_bytes), data_only=True)
        sheet = wb.active
        rows = sheet.iter_rows(values_only=True)
        lines = [
            "\t".join(str(cell) if cell is not None else "" for cell in row)
            for row in rows
        ]
        return safe_strip("\n".join(lines))
    except Exception:
        return ""


def read_pdf(file_bytes):
    if not pdfplumber:
        return ""
    try:
        with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
            return safe_strip("\n".join(page.extract_text() or "" for page in pdf.pages))
    except Exception:
        return ""


def transcribe_audio(audio_bytes):
    if not whisper:
        return ""
    try:
        # Whisper decodes audio from a file path via ffmpeg, so write the bytes
        # to a temporary file first (kept open until transcription finishes).
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as tmpfile:
            tmpfile.write(audio_bytes)
            tmpfile.flush()
            model = whisper.load_model("base")
            result = model.transcribe(tmpfile.name)
            return safe_strip(result.get("text", ""))
    except Exception:
        return ""


def transcribe_youtube_audio(youtube_url):
    """
    Download the audio track of a YouTube video with yt-dlp and transcribe it
    with Whisper (requires yt-dlp and ffmpeg on the PATH).
    """
    if not whisper:
        return ""
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            audio_path = os.path.join(tmpdir, "audio.mp3")
            cmd = [
                "yt-dlp", "-f", "bestaudio[ext=m4a]/bestaudio/best",
                "--extract-audio", "--audio-format", "mp3",
                "-o", audio_path, youtube_url,
            ]
            subprocess.run(cmd, check=True, capture_output=True)
            model = whisper.load_model("base")
            result = model.transcribe(audio_path)
            return safe_strip(result.get("text", ""))
    except Exception:
        return ""


def extract_file_text(file_bytes, content_type, task_id=""):
    """Route the downloaded file to the right extractor based on MIME type or extension."""
    if "image" in content_type:
        return ocr_image(file_bytes)

    if "spreadsheet" in content_type or "excel" in content_type or task_id.endswith(".xlsx"):
        return read_excel(file_bytes)

    if "pdf" in content_type or task_id.endswith(".pdf"):
        return read_pdf(file_bytes)

    if "audio" in content_type or task_id.endswith(".mp3") or task_id.endswith(".wav"):
        return transcribe_audio(file_bytes)

    if (
        "text" in content_type
        or "csv" in content_type
        or "json" in content_type
        or task_id.endswith((".csv", ".json", ".txt"))
    ):
        # Plain-text formats: truncate to the first 10,000 bytes to keep the prompt small.
        return safe_strip(file_bytes[:10000])

    return ""


def guess_youtube_link(question):
    matches = re.findall(r"(https?://[^\s]+)", question)
    for url in matches:
        if "youtube.com" in url or "youtu.be" in url:
            return url
    return None


class GaiaAgent:
    def __init__(self):
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.instructions = (
            "You are a top-tier research assistant for the GAIA benchmark. "
            "You analyze documents, reason step by step, and always provide a single, concise, and correct answer. "
            "If a file is provided, extract all relevant information. Use only information from the question and file. "
            "If the question refers to a video/audio file or YouTube link, always try to transcribe it. "
            "If you need additional facts, summarize web search results provided. "
            "Never apologize, never say you are unable, never output placeholders. "
            "Always output the answer only—no explanations, no extra text."
        )

    def __call__(self, question: str, task_id: str = None) -> str:
        file_text = ""
        prompt_parts = [self.instructions]

        # Pull in any file attached to the task and extract its text.
        if task_id:
            file_bytes, content_type = fetch_file(task_id)
            if file_bytes and content_type:
                file_text = extract_file_text(file_bytes, content_type, task_id)
            if file_text:
                prompt_parts.append(f"Here is the extracted file content:\n{file_text}\n")

        # Transcribe a linked YouTube video, if the question contains one.
        youtube_url = guess_youtube_link(question)
        if youtube_url:
            transcript = transcribe_youtube_audio(youtube_url)
            if transcript:
                prompt_parts.append(f"Here is the transcript of the video:\n{transcript}\n")

        # Fall back to a web search for fact-style questions, or when neither a
        # file nor a video provided any context.
        search_keywords = [
            "who", "what", "when", "where", "name", "number", "how many",
            "first", "last", "award", "recipient", "code", "surname", "year", "album", "actor", "winner"
        ]
        if (not file_text and not youtube_url) or any(kw in question.lower() for kw in search_keywords):
            search_results = run_web_search(question)
            if search_results:
                prompt_parts.append(f"Here are relevant web search results:\n{search_results}\n")

        prompt_parts.append(f"Question: {question}\nAnswer strictly and concisely.")
        prompt = "\n".join(prompt_parts)

        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": self.instructions},
                {"role": "user", "content": prompt},
            ],
            temperature=0.0,
            max_tokens=512,
        )
        raw_output = safe_strip(response.choices[0].message.content)

        return parse_final_answer(raw_output)


def answer_question(question, task_id=None):
    agent = GaiaAgent()
    return agent(question, task_id)
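

# Minimal usage sketch (illustrative only): assumes OPENAI_API_KEY is set in the
# environment; the question below is a made-up placeholder, and task_id is left
# as None because no file is attached.
if __name__ == "__main__":
    sample_question = "Example question goes here"  # hypothetical placeholder
    print(answer_question(sample_question, task_id=None))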