# Uploaded by Tesvia: "Upload 5 files", commit 4191a9b (verified)
# Custom tools for smolagents GAIA agent
from __future__ import annotations
import contextlib
import io
import os
from typing import Any, Dict, List
from smolagents import Tool
# ---- 1. PythonRunTool ------------------------------------------------------
class PythonRunTool(Tool):
    """smolagents tool that executes a Python snippet and reports what it produced.

    The snippet runs in-process through ``exec`` with stdout redirected into a
    buffer. NOTE(security): there is no sandboxing of any kind -- the snippet has
    full access to the interpreter, filesystem and network, so only route
    trusted code through this tool.
    """

    name = "python_run"
    description = """
Execute trusted Python code and return printed output + repr() of the last expression (or _result variable).
"""
    inputs = {
        "code": {
            "type": "string",
            "description": "Python code to execute",
            "required": True,
        }
    }
    output_type = "string"

    def forward(self, code: str) -> str:
        """Run ``code`` and return its captured stdout plus a trailing value.

        The trailing value is ``_result`` when the snippet assigns a non-None
        value to that name; otherwise it is the value of the snippet's final
        statement if that statement is a bare expression (REPL-style). A
        ``None`` trailing value is not appended.

        Args:
            code: Python source to execute.

        Returns:
            The captured stdout followed by ``repr()`` of the trailing value,
            with surrounding whitespace stripped.

        Raises:
            RuntimeError: wrapping any exception (including ``SyntaxError``)
                raised while parsing or executing the snippet.
        """
        import ast

        filename = "<agent-python>"
        buf = io.StringIO()
        # One dict serves as both globals and locals. With separate dicts,
        # top-level names of the snippet end up in the locals mapping and are
        # invisible to functions/lambdas it defines (exec then behaves like a
        # class body), e.g. `import math; def f(): return math.pi; f()` fails.
        ns: Dict[str, Any] = {}
        last: Any = None
        try:
            tree = ast.parse(code, filename=filename, mode="exec")
            # Split off a trailing bare expression so its value can be reported,
            # as the tool description promises.
            tail = None
            if tree.body and isinstance(tree.body[-1], ast.Expr):
                tail = ast.Expression(body=tree.body.pop().value)
            with contextlib.redirect_stdout(buf):
                exec(compile(tree, filename, "exec"), ns)
                if tail is not None:
                    last = eval(compile(tail, filename, "eval"), ns)
        except Exception as e:
            raise RuntimeError(f"PythonRunTool error: {e}") from e

        # An explicit `_result` keeps precedence over the trailing expression.
        if ns.get("_result") is not None:
            last = ns["_result"]

        out = buf.getvalue()
        if last is not None:
            # Keep the value on its own line even if the last print used end="".
            if out and not out.endswith("\n"):
                out += "\n"
            out += repr(last)
        # Always return a string
        return out.strip()
# ---- 2. ExcelLoaderTool ----------------------------------------------------
class ExcelLoaderTool(Tool):
    """smolagents tool that loads a CSV file or one Excel sheet as row dictionaries."""

    name = "load_spreadsheet"
    description = """
Read .xlsx/.xls/.csv from disk and return rows as a list of dictionaries with string keys.
"""
    inputs = {
        "path": {
            "type": "string",
            "description": "Path to .csv/.xls/.xlsx file",
            "required": True,
        },
        "sheet": {
            "type": "string",
            "description": "Sheet name or 0-based index (optional; defaults to the first sheet; ignored for CSV files)",
            "required": False,
            "default": "",
            "nullable": True,
        },
    }
    output_type = "array"

    def forward(self, path: str, sheet: str | int | None = None) -> str:
        """Load ``path`` and return its rows.

        Args:
            path: Path to a ``.csv`` file (read with ``pandas.read_csv``) or any
                other extension, which is read as an Excel workbook.
            sheet: Excel only. A sheet name, or a 0-based index given as an int
                or as a digit string such as ``"1"``. An existing sheet whose name
                matches the string exactly wins over the index interpretation.
                ``None``/``""`` selects the first sheet.

        Returns:
            ``str()`` of a list of ``{column_name: value}`` dicts, one per row,
            with column names converted to ``str``.

        Raises:
            FileNotFoundError: if ``path`` is not an existing file.
        """
        import pandas as pd

        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        ext = os.path.splitext(path)[1].lower()

        if ext == ".csv":
            df = pd.read_csv(path)
        else:
            with pd.ExcelFile(path) as xls:
                if sheet is None or sheet == "":
                    # First sheet only; sheet_name=None would parse every sheet.
                    target: str | int = 0
                elif (
                    isinstance(sheet, str)
                    and sheet not in xls.sheet_names
                    and sheet.strip().isdecimal()
                ):
                    # Tool inputs arrive as strings, so "1" means index 1 unless a
                    # sheet is literally named "1".
                    target = int(sheet)
                else:
                    target = sheet
                df = xls.parse(sheet_name=target)

        records = [{str(k): v for k, v in row.items()} for row in df.to_dict(orient="records")]
        # Always return a string
        return str(records)
# ---- 3. YouTubeTranscriptTool ---------------------------------------------
class YouTubeTranscriptTool(Tool):
    """smolagents tool that returns a YouTube video's subtitles as one string."""

    name = "youtube_transcript"
    description = """
Return the subtitles of a YouTube URL using youtube-transcript-api.
"""
    inputs = {
        "url": {
            "type": "string",
            "description": "YouTube URL",
            "required": True,
        },
        "lang": {
            "type": "string",
            "description": "Transcript language (default: en)",
            "required": False,
            "default": "en",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, url: str, lang: str = "en") -> str:
        """Fetch the transcript for ``url`` and join its snippets with spaces.

        Args:
            url: A ``watch?v=<id>`` URL, a path-style URL (``youtu.be/<id>``,
                ``/shorts/<id>``, ``/embed/<id>``) or a bare video id.
            lang: Preferred language code. English variants are always tried
                as fallbacks after it. ``None``/``""`` means "English only".

        Returns:
            The transcript text, stripped of surrounding whitespace.

        Raises:
            ValueError: if no video id can be extracted from ``url``.
        """
        from urllib.parse import parse_qs, urlparse
        from youtube_transcript_api import YouTubeTranscriptApi

        parsed = urlparse(url)
        vid = parse_qs(parsed.query).get("v", [None])[0]
        if not vid:
            # Use only the path component so query strings (?t=42) and
            # fragments don't leak into the id; tolerate a trailing slash.
            vid = parsed.path.rstrip("/").rsplit("/", 1)[-1]
        if not vid:
            raise ValueError(f"Could not extract a YouTube video id from {url!r}")

        # Preferred language first, then English fallbacks; drop None/"" and duplicates.
        languages = list(dict.fromkeys(code for code in (lang, "en", "en-US", "en-GB") if code))

        if hasattr(YouTubeTranscriptApi, "get_transcript"):
            # youtube-transcript-api < 1.2: class-level helper returning list[dict].
            data = YouTubeTranscriptApi.get_transcript(vid, languages=languages)
        else:
            # Newer releases: instance API returning a FetchedTranscript.
            data = YouTubeTranscriptApi().fetch(vid, languages=languages).to_raw_data()
        return " ".join(d["text"] for d in data).strip()
# ---- 4. AudioTranscriptionTool --------------------------------------------
class AudioTranscriptionTool(Tool):
    """smolagents tool that transcribes an audio file through the OpenAI API."""

    name = "transcribe_audio"
    description = """
Transcribe an audio file with OpenAI Whisper, returns plain text.
"""
    inputs = {
        "path": {
            "type": "string",
            "description": "Path to audio file",
            "required": True,
        },
        "model": {
            "type": "string",
            "description": "Model name for transcription (default: whisper-1)",
            "required": False,
            "default": "whisper-1",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, path: str, model: str = "whisper-1") -> str:
        """Transcribe the audio file at ``path``.

        Args:
            path: Path to the audio file, uploaded as-is.
            model: Transcription model name. ``None``/``""`` (allowed because the
                input is nullable) falls back to ``"whisper-1"``.

        Returns:
            The transcript text, stripped of surrounding whitespace.

        Raises:
            FileNotFoundError: if ``path`` is not an existing file.
        """
        import openai

        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        # The input is declared nullable, so the agent may pass None explicitly.
        model = model or "whisper-1"
        client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        with open(path, "rb") as fp:
            transcript = client.audio.transcriptions.create(model=model, file=fp)
        return str(transcript.text.strip())
# ---- 5. SimpleOCRTool ------------------------------------------------------
class SimpleOCRTool(Tool):
    """smolagents tool that extracts text from an image with Tesseract OCR."""

    name = "image_ocr"
    description = """
Return any text spotted in an image via pytesseract OCR.
"""
    inputs = {
        "path": {
            "type": "string",
            "description": "Path to image file",
            "required": True,
        }
    }
    output_type = "string"

    def forward(self, path: str) -> str:
        """Run OCR on the image at ``path``.

        Args:
            path: Path to an image file readable by Pillow.

        Returns:
            The recognised text, stripped of surrounding whitespace.

        Raises:
            FileNotFoundError: if ``path`` is not an existing file.
        """
        from PIL import Image
        import pytesseract

        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it once OCR is done.
        with Image.open(path) as img:
            text = pytesseract.image_to_string(img)
        return str(text).strip()
# ---------------------------------------------------------------------------
# Names exported by `from <this module> import *`: the five tool classes.
__all__ = [
    "PythonRunTool", "ExcelLoaderTool", "YouTubeTranscriptTool",
    "AudioTranscriptionTool", "SimpleOCRTool",
]