Tesvia's picture
Upload tools.py
ac3e234 verified
raw
history blame
5.22 kB
"""
Custom function tools for OpenAI Agents SDK GAIA agent.
"""
from __future__ import annotations

import ast
import contextlib
import datetime
import functools
import io
import os
import time
from typing import TypedDict, List, Union

from agents import function_tool
class DuckDuckGoResult(TypedDict):
    """One search hit as produced by the ``duckduckgo_search`` tool.

    Each field is copied from the corresponding key of the raw DDGS hit,
    defaulting to an empty string when the key is missing.
    """

    title: str  # the hit's "title" field
    href: str  # the hit's "href" field (presumably the result URL)
    body: str  # the hit's "body" field (presumably the snippet text)
class SpreadsheetRow(TypedDict):
    """Row type returned by ``load_spreadsheet``.

    No keys are declared because the columns are not known ahead of time;
    at runtime each row maps the stringified column label to its cell value.
    """

    # NOTE(review): declare concrete columns here if the spreadsheet schema
    # becomes known, so type checkers can validate row access.
    pass
def log(msg):
    """Print ``msg`` to stdout, prefixed with a local ``[YYYY-MM-DD HH:MM:SS]`` timestamp."""
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{stamp}] {msg}")
def log_tool_call(func):
    """Wrap ``func`` so every call logs its start, its duration, and any exception.

    The wrapper is built with ``functools.wraps``. It therefore keeps ``func``'s
    ``__name__``, ``__doc__`` and ``__annotations__``, and exposes ``__wrapped__``
    so ``inspect.signature`` reports the original parameters. This matters
    because ``@function_tool`` is stacked on top of this decorator in this
    module. Without ``wraps``, every tool would be named ``wrapper``, have a
    ``(*args, **kwargs)`` signature and carry no description.

    Args:
        func: The callable to instrument.

    Returns:
        A wrapper with the same call signature as ``func``. It re-raises any
        exception from ``func`` after logging it.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so durations are unaffected by clock changes.
        t0 = time.perf_counter()
        log(f"Step: {func.__name__} started.")
        try:
            result = func(*args, **kwargs)
        except Exception as e:
            log(f"Step: {func.__name__} error: {e}")
            raise
        log(f"Step: {func.__name__} completed in {time.perf_counter() - t0:.2f}s.")
        return result

    return wrapper
# 1. --------------------------------------------------------------------
@function_tool
@log_tool_call
def python_run(code: str) -> str:
    """Execute trusted Python code and return the captured stdout together with
    the repr() of the last expression (or `_result` variable if set).

    If `_result` is set to a non-None value it is reported instead of the
    value of a trailing expression.

    Args:
        code: Python code to execute.
    """
    # NOTE(security): the snippet runs in-process with full privileges and no
    # sandboxing; only pass trusted code.
    buf = io.StringIO()
    # One dict serves as both globals and locals. With separate dicts the
    # snippet runs with class-body scoping, so functions/comprehensions it
    # defines could not see its own top-level names or imports.
    ns: dict = {}
    last = None
    try:
        tree = ast.parse(code, filename="<agent-python>", mode="exec")
        # Split off a trailing bare expression so its value can be reported,
        # as the interactive interpreter does.
        tail = None
        if tree.body and isinstance(tree.body[-1], ast.Expr):
            tail = ast.Expression(body=tree.body.pop().value)
        with contextlib.redirect_stdout(buf):
            exec(compile(tree, "<agent-python>", "exec"), ns)
            if tail is not None:
                last = eval(compile(tail, "<agent-python>", "eval"), ns)
        # An explicit `_result` takes precedence over the trailing expression.
        if ns.get("_result") is not None:
            last = ns["_result"]
    except (Exception, SystemExit) as e:
        # SystemExit is caught too, so exit()/sys.exit() inside the snippet
        # cannot terminate the host agent process.
        raise RuntimeError(f"python_run error: {e}") from e
    out = buf.getvalue()
    return (out + (repr(last) if last is not None else "")).strip()
# 2. --------------------------------------------------------------------
@function_tool
@log_tool_call
def load_spreadsheet(path: str, sheet: Union[str, int, None] = None) -> List[SpreadsheetRow]:
    """Read .csv, .xls or .xlsx from disk and return rows as list of dictionaries.
    Args:
        path: Path to spreadsheet file.
        sheet: Sheet name or index (for Excel files only).
    """
    import pandas as pd

    if not os.path.isfile(path):
        raise FileNotFoundError(path)

    # CSV holds a single table; every other extension goes to the Excel reader.
    if os.path.splitext(path)[1].lower() == ".csv":
        frames = [pd.read_csv(path)]
    else:
        # An empty or missing selector means "all sheets"; pandas then returns
        # a dict of DataFrames keyed by sheet name.
        selector = None if sheet in ("", None) else sheet
        loaded = pd.read_excel(path, sheet_name=selector)
        frames = list(loaded.values()) if isinstance(loaded, dict) else [loaded]

    # Flatten every frame into row dicts; column labels may be non-strings
    # (ints, timestamps), so they are coerced to str keys.
    return [
        {str(column): value for column, value in record.items()}
        for frame in frames
        for record in frame.to_dict(orient="records")
    ]
# 3. --------------------------------------------------------------------
def _youtube_video_id(url: str) -> str:
    """Return the video id from a YouTube URL, or ``url`` itself if it is a bare id.

    Handles ``watch?v=<id>`` as well as path forms such as ``youtu.be/<id>``,
    ``/shorts/<id>`` and ``/embed/<id>``. Query strings, fragments and
    trailing slashes are ignored.

    Raises:
        ValueError: If no id can be found.
    """
    from urllib.parse import urlparse, parse_qs

    parsed = urlparse(url.strip())
    vid = parse_qs(parsed.query).get("v", [None])[0]
    if vid:
        return vid
    # Fall back to the last non-empty path segment; urlparse has already
    # split off "?..." and "#...", so e.g. "youtu.be/ID?si=x" yields "ID".
    segments = [seg for seg in parsed.path.split("/") if seg]
    if not segments:
        raise ValueError(f"Could not extract a YouTube video id from {url!r}")
    return segments[-1]


@function_tool
@log_tool_call
def youtube_transcript(url: str, lang: str = "en") -> str:
    """Fetch the subtitles of a YouTube video.
    Args:
        url: YouTube video URL.
        lang: Preferred transcript language code (default "en").
    """
    from youtube_transcript_api import YouTubeTranscriptApi

    vid = _youtube_video_id(url)
    # Preferred language first, then English fallbacks; dict.fromkeys removes
    # duplicates (e.g. when lang == "en") while keeping the order.
    languages = list(dict.fromkeys([lang, "en", "en-US", "en-GB"]))
    if hasattr(YouTubeTranscriptApi, "get_transcript"):
        # Older releases: classmethod returning a list of {"text": ...} dicts.
        data = YouTubeTranscriptApi.get_transcript(vid, languages=languages)
        texts = [chunk["text"] for chunk in data]
    else:
        # Newer releases removed get_transcript; fetch() returns an iterable of
        # snippet objects exposing a .text attribute.
        fetched = YouTubeTranscriptApi().fetch(vid, languages=languages)
        texts = [snippet.text for snippet in fetched]
    return " ".join(texts).strip()
# 4. --------------------------------------------------------------------
@function_tool
@log_tool_call
def transcribe_audio(path: str, model: str = "whisper-1") -> str:
    """Transcribe an audio file using OpenAI Whisper.
    Args:
        path: Path to audio file (wav / mp3 / m4a / etc.).
        model: Whisper model name (default "whisper-1").
    """
    import openai

    if not os.path.isfile(path):
        raise FileNotFoundError(path)

    # The key is passed explicitly from the environment (None when unset).
    api_key = os.getenv("OPENAI_API_KEY")
    client = openai.OpenAI(api_key=api_key)
    # Stream the raw bytes to the transcription endpoint; the file is closed
    # as soon as the request returns.
    with open(path, "rb") as audio_file:
        response = client.audio.transcriptions.create(model=model, file=audio_file)
    text = response.text
    return text.strip()
# 5. --------------------------------------------------------------------
@function_tool
@log_tool_call
def image_ocr(path: str) -> str:
    """Perform OCR on an image using Tesseract.
    Args:
        path: Path to image file.
    """
    from PIL import Image
    import pytesseract

    if not os.path.isfile(path):
        raise FileNotFoundError(path)
    # Image.open is lazy and keeps the file handle open; using it as a context
    # manager closes the handle once OCR is done, instead of leaking one per call.
    with Image.open(path) as img:
        text = pytesseract.image_to_string(img)
    return text.strip()
# 6. --------------------------------------------------------------------
@function_tool
@log_tool_call
def duckduckgo_search(query: str, max_results: int = 5) -> List[DuckDuckGoResult]:
    """Search DuckDuckGo and return a list of result dicts with title, href and body.
    Args:
        query: The search query.
        max_results: Maximum results to return (default 5).
    """
    from duckduckgo_search import DDGS

    # Normalise each raw hit to exactly three string fields, defaulting
    # missing ones to "" so callers can index them safely.
    with DDGS() as client:
        hits = [
            {
                "title": hit.get("title", ""),
                "href": hit.get("href", ""),
                "body": hit.get("body", ""),
            }
            for hit in client.text(query, max_results=max_results)
        ]
    return hits