dawid-lorek's picture
Update agent.py
4695b90 verified
raw
history blame
7.83 kB
import os
import io
import requests
import mimetypes
import subprocess
import tempfile
from openai import OpenAI
from duckduckgo_search import DDGS
from PIL import Image
import pytesseract
import openpyxl
try:
import whisper
except ImportError:
whisper = None
try:
import pdfplumber
except ImportError:
pdfplumber = None
# Base URL of the course scoring service; fetch_file downloads task attachments
# from "<AGENT_API_URL>/files/<task_id>".
AGENT_API_URL = "https://agents-course-unit4-scoring.hf.space"
def safe_strip(text):
    """Normalize *text* to a trimmed ``str`` with carriage returns removed.

    Any falsy input (``None``, ``""``, ``b""``, ``0``) yields ``""``.
    ``bytes`` are decoded with the default codec, silently dropping bytes that
    cannot be decoded; every other object goes through ``str()``.
    """
    if text:
        if isinstance(text, bytes):
            text = text.decode(errors="ignore")
        return str(text).replace("\r", "").strip()
    return ""
def parse_final_answer(text):
    """Extract only the final answer from an LLM reply.

    Lines are scanned bottom-up for a ``Final Answer:`` marker; the text after
    the last marker on that line is returned, stripped. If the marker stands
    alone on its line, the next non-blank line is taken as the answer. Without
    any marker, the last non-blank line is returned. Empty or ``None`` input
    yields ``""`` instead of raising.
    """
    if not text:
        return ""
    marker = "Final Answer:"
    # splitlines() also splits on "\r", so no line keeps a stray carriage return.
    lines = str(text).splitlines()
    for idx, line in reversed(list(enumerate(lines))):
        if marker in line:
            answer = line.split(marker)[-1].strip()
            if answer:
                return answer
            # Marker on its own line: the answer follows on a later line.
            return next((nxt.strip() for nxt in lines[idx + 1:] if nxt.strip()), "")
    # No marker: fall back to the last line that actually has content.
    return next((ln.strip() for ln in reversed(lines) if ln.strip()), "")
def run_web_search(query, max_results=3):
    """Return text snippets from a DuckDuckGo text search for *query*.

    Up to *max_results* hits are considered. Each hit contributes its ``body``
    snippet, falling back to its ``title``. The snippets are joined with blank
    lines so the caller can paste them straight into a prompt.

    Best-effort by design: any failure while searching (network, rate limit,
    library error) yields ``""`` instead of raising.
    """
    import itertools

    try:
        results = DDGS().text(query) or []
        snippets = []
        for hit in itertools.islice(results, max_results):
            snippet = hit.get("body") or hit.get("title")
            if snippet:
                snippets.append(snippet)
        return "\n\n".join(snippets)
    except Exception:
        return ""
def fetch_file(task_id):
    """Download the attachment for *task_id* from the scoring API.

    Returns ``(content_bytes, content_type)`` on success, and ``(None, None)``
    when the request fails or the server answers with a non-2xx status.

    When the server sends no Content-Type, or only the generic
    ``application/octet-stream``, the type is guessed from the filename in
    the ``Content-Disposition`` header. Callers dispatch on the content type,
    so a missing one would otherwise cause the file to be ignored.
    """
    import re

    url = f"{AGENT_API_URL}/files/{task_id}"
    try:
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
    except requests.RequestException:
        return None, None

    content_type = resp.headers.get("Content-Type", "")
    if not content_type or content_type.startswith("application/octet-stream"):
        disposition = resp.headers.get("Content-Disposition", "")
        match = re.search(r'filename="?([^";]+)"?', disposition)
        if match:
            guessed, _encoding = mimetypes.guess_type(match.group(1))
            content_type = guessed or content_type
    return resp.content, content_type
def ocr_image(img_bytes):
    """OCR an in-memory image with Tesseract and return the cleaned text.

    Any failure (unreadable image, Tesseract not available) yields ``""``.
    """
    buffer = io.BytesIO(img_bytes)
    try:
        picture = Image.open(buffer)
        recognized = pytesseract.image_to_string(picture)
        return safe_strip(recognized)
    except Exception:
        return ""
def read_excel(file_bytes):
    """Render every worksheet of an ``.xlsx`` workbook as tab-separated text.

    The workbook is loaded with ``data_only=True``, so formula cells contribute
    their stored values. Empty cells become empty fields. A single-sheet
    workbook yields exactly that sheet's rows.

    With several sheets, each one is introduced by a ``# Sheet: <title>`` line
    and sheets are separated by a blank line. Previously only the active sheet
    was read, so data on other sheets was silently lost.

    Returns ``""`` if the bytes cannot be parsed as a workbook.
    """
    try:
        wb = openpyxl.load_workbook(io.BytesIO(file_bytes), data_only=True)
        sheets = wb.worksheets
        sections = []
        for sheet in sheets:
            lines = [
                "\t".join("" if cell is None else str(cell) for cell in row)
                for row in sheet.iter_rows(values_only=True)
            ]
            body = "\n".join(lines)
            if len(sheets) > 1:
                body = f"# Sheet: {sheet.title}\n{body}"
            sections.append(body)
        return safe_strip("\n\n".join(sections))
    except Exception:
        return ""
def read_pdf(file_bytes):
    """Concatenate the text of every PDF page, one page per line group.

    Pages without extractable text contribute an empty string. Returns ``""``
    when pdfplumber is not installed or the document cannot be parsed.
    """
    if pdfplumber:
        try:
            with pdfplumber.open(io.BytesIO(file_bytes)) as document:
                page_texts = [page.extract_text() or "" for page in document.pages]
                return safe_strip("\n".join(page_texts))
        except Exception:
            return ""
    return ""
def transcribe_audio(audio_bytes):
    """Transcribe in-memory audio with Whisper's ``base`` model.

    ``model.transcribe`` is given a file path, so the bytes are first written
    to a file inside a temporary directory. A closed file in a
    ``TemporaryDirectory`` is used instead of an open
    ``NamedTemporaryFile``, because the latter cannot be reopened by name on
    Windows while it is still open.

    Returns the stripped transcript, or ``""`` when Whisper is not installed
    or transcription fails for any reason.
    """
    if not whisper:
        return ""
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            audio_path = os.path.join(tmpdir, "audio.mp3")
            with open(audio_path, "wb") as audio_file:
                audio_file.write(audio_bytes)
            model = whisper.load_model("base")
            result = model.transcribe(audio_path)
            return safe_strip(result.get("text", ""))
    except Exception:
        return ""
def transcribe_youtube_audio(youtube_url, timeout=600):
    """Download the audio of a YouTube video and transcribe it with Whisper.

    The ``yt-dlp`` command-line tool (expected on PATH) fetches the best audio
    stream and converts it to MP3 in a temporary directory. Whisper's ``base``
    model then transcribes the result.

    Args:
        youtube_url: URL of a single video; playlists are not expanded.
        timeout: Seconds to allow yt-dlp to run before giving up.

    Returns:
        The stripped transcript, or ``""`` if Whisper is unavailable, yt-dlp is
        missing, the download fails or times out, or transcription fails.
    """
    if not whisper:
        return ""
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            cmd = [
                "yt-dlp", "-f", "bestaudio[ext=m4a]/bestaudio/best",
                "--extract-audio", "--audio-format", "mp3",
                "--no-playlist",
                # Let yt-dlp pick the download extension; the audio
                # post-processor then produces audio.mp3 in tmpdir.
                "-P", tmpdir, "-o", "audio.%(ext)s",
                # "--" stops option parsing so the URL can never be read as a flag.
                "--", youtube_url,
            ]
            subprocess.run(cmd, check=True, capture_output=True, timeout=timeout)
            audio_path = os.path.join(tmpdir, "audio.mp3")
            model = whisper.load_model("base")
            result = model.transcribe(audio_path)
            return safe_strip(result.get("text", ""))
    except Exception:
        return ""
def extract_file_text(file_bytes, content_type, task_id=""):
    """Dispatch *file_bytes* to the matching extractor and return its text.

    The extractor is chosen from the MIME *content_type*: image → OCR,
    spreadsheet/excel → workbook text, pdf → page text, audio → Whisper
    transcript, and text/csv/json → the decoded first 10,000 bytes.

    *task_id* is also checked for a file extension. That only helps when it
    actually ends in a filename extension. Both values are compared
    case-insensitively, and either may be ``None``.

    Returns ``""`` for unsupported types.
    """
    ctype = (content_type or "").lower()
    name = (task_id or "").lower()

    if "image" in ctype:
        return ocr_image(file_bytes)
    if "spreadsheet" in ctype or "excel" in ctype or name.endswith(".xlsx"):
        return read_excel(file_bytes)
    if "pdf" in ctype or name.endswith(".pdf"):
        return read_pdf(file_bytes)
    if "audio" in ctype or name.endswith((".mp3", ".wav")):
        return transcribe_audio(file_bytes)
    if any(kind in ctype for kind in ("text", "csv", "json")) or name.endswith((".csv", ".json", ".txt")):
        # Cap text-like payloads so a huge file cannot flood the prompt.
        return safe_strip(file_bytes[:10000])
    return ""
def guess_youtube_link(question):
    """Return the first YouTube URL mentioned in *question*, or ``None``.

    URLs are whitespace-delimited ``http(s)://`` tokens. Trailing sentence
    punctuation and closing brackets/quotes are stripped, so "…?v=ID, what"
    yields a clean ``…?v=ID`` for yt-dlp. Both ``youtube.com`` and
    ``youtu.be`` links are recognized.
    """
    import re

    for url in re.findall(r"https?://\S+", question or ""):
        url = url.rstrip(".,;:!?)]}'\"")
        if "youtube.com" in url or "youtu.be" in url:
            return url
    return None
class GaiaAgent:
    """Answer GAIA benchmark questions with GPT-4o plus gathered context.

    For each question the agent collects optional context: the task's attached
    file (downloaded from the scoring API), a transcript of any YouTube link in
    the question, and DuckDuckGo snippets. It then sends a single prompt to the
    model and returns only the final answer.
    """

    # Words suggesting an open-world factoid question worth a web search.
    # They are matched at word starts, so "who" fires on "whose" but not on
    # "whole", and "code" does not fire on "decode".
    SEARCH_KEYWORDS = (
        "who", "what", "when", "where", "name", "number", "how many",
        "first", "last", "award", "recipient", "code", "surname", "year",
        "album", "actor", "winner",
    )

    def __init__(self):
        """Create the OpenAI client (key from ``OPENAI_API_KEY``) and the system prompt."""
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.instructions = (
            "You are a top-tier research assistant for the GAIA benchmark. "
            "You analyze documents, reason step by step, and always provide a single, concise, and correct answer. "
            "If a file is provided, extract all relevant information. Use only information from the question and file. "
            "If the question refers to a video/audio file or YouTube link, always try to transcribe it. "
            "If you need additional facts, summarize web search results provided. "
            "Never apologize, never say you are unable, never output placeholders. "
            "Always output the answer only—no explanations, no extra text."
        )

    def __call__(self, question: str, task_id: str = None) -> str:
        """Answer *question*, using the file attached to *task_id* if there is one.

        Returns the model's final answer with any ``Final Answer:`` prefix
        removed, or ``""`` if the model produced no content. Errors raised by
        the OpenAI client propagate to the caller.
        """
        prompt_parts = self._gather_context(question, task_id)
        prompt_parts.append(f"Question: {question}\nAnswer strictly and concisely.")
        raw_output = self._ask_llm("\n".join(prompt_parts))
        if not raw_output:
            return ""
        return parse_final_answer(raw_output)

    def _gather_context(self, question, task_id):
        """Build prompt sections from the task file, a video transcript and web search."""
        parts = []

        # 1. Attached file (image, Excel, CSV, PDF, text, audio).
        file_text = ""
        if task_id:
            file_bytes, content_type = fetch_file(task_id)
            if file_bytes and content_type:
                file_text = extract_file_text(file_bytes, content_type, task_id)
        if file_text:
            parts.append(f"Here is the extracted file content:\n{file_text}\n")

        # 2. YouTube link in the question.
        transcript = ""
        youtube_url = guess_youtube_link(question)
        if youtube_url:
            transcript = transcribe_youtube_audio(youtube_url)
        if transcript:
            parts.append(f"Here is the transcript of the video:\n{transcript}\n")

        # 3. Web search. The fallback keys on context actually obtained, so a
        #    failed transcription still triggers a search.
        if self._needs_web_search(question, has_context=bool(file_text or transcript)):
            search_results = run_web_search(question)
            if search_results:
                parts.append(f"Here are relevant web search results:\n{search_results}\n")
        return parts

    def _needs_web_search(self, question, has_context):
        """Search when no other context exists, or the question looks like a factoid lookup."""
        import re

        if not has_context:
            return True
        lowered = question.lower()
        return any(re.search(r"\b" + re.escape(kw), lowered) for kw in self.SEARCH_KEYWORDS)

    def _ask_llm(self, prompt):
        """Send *prompt* to GPT-4o, with the instructions as the system message, and return stripped text."""
        # The instructions travel only in the system message; they are not
        # repeated in the user prompt.
        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": self.instructions},
                {"role": "user", "content": prompt},
            ],
            temperature=0.0,
            max_tokens=512,
        )
        return safe_strip(response.choices[0].message.content)
# Functional entry point kept for older callers that import ``answer_question``.
def answer_question(question, task_id=None):
    """Answer *question* with a freshly constructed ``GaiaAgent``.

    Thin compatibility wrapper: it builds a new agent on every call and
    forwards *question* and *task_id* to it unchanged.
    """
    return GaiaAgent()(question, task_id)