Spaces:
Sleeping
Sleeping
File size: 4,167 Bytes
52d1305 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
# Custom tools for smolagents GAIA agent
from __future__ import annotations
import contextlib
import io
import os
from typing import Any, Dict, List, Hashable
from smolagents import Tool
# ---- 1. PythonRunTool ------------------------------------------------------
class PythonRunTool(Tool):
    """Run a snippet of trusted Python code and report what it produced.

    The returned string is everything the snippet printed to stdout,
    followed by ``repr()`` of its result.  The result is the ``_result``
    variable when the snippet sets it to a non-None value; otherwise it is
    the value of the snippet's final statement when that statement is a bare
    expression.

    SECURITY: the code runs through ``exec`` with full interpreter
    privileges and no sandbox.  Only pass code you trust.
    """

    name = "python_run"
    description = (
        "Execute trusted Python code and return printed output "
        "+ repr() of the last expression (or _result variable)."
    )
    # smolagents expects every tool to declare its input schema and output type.
    inputs = {
        "code": {
            "type": "string",
            "description": "Python source code to execute.",
        },
    }
    output_type = "string"

    def forward(self, code: str) -> str:  # type: ignore[override]
        """Execute ``code`` and return its stdout plus the repr of its result.

        Args:
            code: Python source to run.

        Returns:
            Captured stdout concatenated with ``repr(result)`` (omitted when
            the result is None), with surrounding whitespace stripped.

        Raises:
            RuntimeError: If parsing or running the code raises any
                ``Exception``; the original error is chained as the cause.
        """
        import ast

        captured = io.StringIO()
        # A single namespace serves as both globals and locals.  With two
        # separate dicts, functions and comprehensions defined by the snippet
        # cannot see its top-level names or imports.
        namespace: Dict[str, Any] = {}
        result = None
        try:
            tree = ast.parse(code, "<agent-python>", "exec")
            tail = None
            # Peel off a trailing bare expression so its value can be
            # captured instead of being evaluated and thrown away.
            if tree.body and isinstance(tree.body[-1], ast.Expr):
                tail = ast.Expression(tree.body.pop().value)
            with contextlib.redirect_stdout(captured):
                exec(compile(tree, "<agent-python>", "exec"), namespace)
                if tail is not None:
                    result = eval(compile(tail, "<agent-python>", "eval"), namespace)
            # An explicit _result always wins over the last expression.
            explicit = namespace.get("_result")
            if explicit is not None:
                result = explicit
        except Exception as e:
            raise RuntimeError(f"PythonRunTool error: {e}") from e
        suffix = repr(result) if result is not None else ""
        return (captured.getvalue() + suffix).strip()
# ---- 2. ExcelLoaderTool ----------------------------------------------------
class ExcelLoaderTool(Tool):
    """Load a spreadsheet file from disk as a list of row dictionaries."""

    name = "load_spreadsheet"
    description = (
        "Read .xlsx/.xls/.csv from disk and return "
        "rows as a list of dictionaries with string keys."
    )
    # smolagents expects every tool to declare its input schema and output
    # type; arguments that have a default value are marked nullable.
    inputs = {
        "path": {
            "type": "string",
            "description": "Path to the .xlsx, .xls or .csv file on disk.",
        },
        "sheet": {
            "type": "any",
            "description": (
                "Worksheet name or zero-based index for Excel files; "
                "defaults to the first sheet. Ignored for .csv files."
            ),
            "nullable": True,
        },
    }
    output_type = "array"

    def forward(self, path: str, sheet: str | int | None = None) -> List[Dict[str, Any]]:  # type: ignore[override]
        """Read the file at ``path`` and return one dict per data row.

        Args:
            path: Location of the spreadsheet.  A ``.csv`` extension
                (case-insensitive) selects the CSV reader; anything else is
                handed to the Excel reader.
            sheet: Worksheet name or zero-based index for Excel files.
                ``None`` selects the first worksheet.  Unused for CSV.

        Returns:
            A list of rows, each mapping the stringified column label to the
            cell value.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
        """
        import pandas as pd

        if not os.path.isfile(path):
            raise FileNotFoundError(path)

        extension = os.path.splitext(path)[1].lower()
        if extension == ".csv":
            frame = pd.read_csv(path)
        else:
            # sheet_name=None makes pandas return a dict of *all* sheets
            # rather than a DataFrame, which breaks the conversion below.
            # Fall back to the first sheet instead.
            frame = pd.read_excel(path, sheet_name=0 if sheet is None else sheet)

        # Column labels may be ints, tuples, etc.; normalise them to str.
        return [
            {str(column): value for column, value in row.items()}
            for row in frame.to_dict(orient="records")
        ]
# ---- 3. YouTubeTranscriptTool ---------------------------------------------
class YouTubeTranscriptTool(Tool):
    """Fetch the subtitle text of a YouTube video via youtube-transcript-api."""

    name = "youtube_transcript"
    description = "Return the subtitles of a YouTube URL using youtube-transcript-api."
    # smolagents expects every tool to declare its input schema and output
    # type; arguments that have a default value are marked nullable.
    inputs = {
        "url": {
            "type": "string",
            "description": "YouTube video URL (watch, youtu.be, embed, ...) or a bare video id.",
        },
        "lang": {
            "type": "string",
            "description": "Preferred subtitle language code (default: en).",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, url: str, lang: str = "en") -> str:  # type: ignore[override]
        """Return the transcript of the video referenced by ``url``.

        Args:
            url: A YouTube URL or a bare video id.  The id is taken from the
                ``v`` query parameter when present, otherwise from the last
                segment of the URL path.
            lang: Preferred language code; English variants are tried as
                fallbacks.

        Returns:
            All subtitle fragments joined by single spaces and stripped.

        Raises:
            ValueError: If no video id can be extracted from ``url``.
        """
        from urllib.parse import parse_qs, urlparse

        from youtube_transcript_api import YouTubeTranscriptApi

        parsed = urlparse(url)
        video_id = parse_qs(parsed.query).get("v", [None])[0]
        if not video_id:
            # Use the parsed path, not the raw URL, so a query string or
            # fragment (e.g. "youtu.be/<id>?t=30") does not leak into the id.
            video_id = parsed.path.rstrip("/").split("/")[-1]
        if not video_id:
            raise ValueError(f"Could not extract a YouTube video id from: {url!r}")

        # Preference order with duplicates removed (lang is often "en").
        languages = list(dict.fromkeys([lang, "en", "en-US", "en-GB"]))
        # NOTE(review): recent youtube-transcript-api releases appear to replace
        # the static get_transcript() with an instance-level fetch(); confirm
        # which version is pinned for this project.
        data = YouTubeTranscriptApi.get_transcript(video_id, languages=languages)
        return " ".join(fragment["text"] for fragment in data).strip()
# ---- 4. AudioTranscriptionTool --------------------------------------------
class AudioTranscriptionTool(Tool):
    """Transcribe an audio file to text with the OpenAI Whisper API."""

    name = "transcribe_audio"
    description = "Transcribe an audio file with OpenAI Whisper, returns plain text."
    # smolagents expects every tool to declare its input schema and output
    # type; arguments that have a default value are marked nullable.
    inputs = {
        "path": {
            "type": "string",
            "description": "Path to the audio file on disk.",
        },
        "model": {
            "type": "string",
            "description": "OpenAI transcription model name (default: whisper-1).",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, path: str, model: str = "whisper-1") -> str:  # type: ignore[override]
        """Send the audio file at ``path`` to OpenAI and return the transcript.

        The API key is read from the ``OPENAI_API_KEY`` environment variable.
        Both the current client API (``openai.OpenAI``) and the legacy
        module-level ``openai.Audio`` API are supported.

        Args:
            path: Location of the audio file.
            model: Name of the transcription model.

        Returns:
            The transcribed text with surrounding whitespace stripped.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
            ImportError: If the installed ``openai`` package exposes neither
                transcription API.
        """
        import openai

        if not os.path.isfile(path):
            raise FileNotFoundError(path)

        api_key = os.getenv("OPENAI_API_KEY")
        openai.api_key = api_key

        # openai>=1.0 removed openai.Audio.transcribe in favour of a client
        # object, so prefer the client whenever the package provides it.
        if hasattr(openai, "OpenAI"):
            client = openai.OpenAI(api_key=api_key)
            with open(path, "rb") as fp:
                transcription = client.audio.transcriptions.create(model=model, file=fp)
            return transcription.text.strip()

        # Legacy (pre-1.0) module-level API.
        if not hasattr(openai, "Audio"):
            raise ImportError(
                "Your OpenAI package does not support Audio. "
                "Please upgrade it with: pip install --upgrade openai"
            )
        with open(path, "rb") as fp:
            return openai.Audio.transcribe(model=model, file=fp)["text"].strip()  # type: ignore[attr-defined]
# ---- 5. SimpleOCRTool ------------------------------------------------------
class SimpleOCRTool(Tool):
    """Extract text from an image file using pytesseract."""

    name = "image_ocr"
    description = "Return any text spotted in an image via pytesseract OCR."
    # smolagents expects every tool to declare its input schema and output type.
    inputs = {
        "path": {
            "type": "string",
            "description": "Path to the image file on disk.",
        },
    }
    output_type = "string"

    def forward(self, path: str) -> str:  # type: ignore[override]
        """Run OCR on the image at ``path`` and return the recognised text.

        Args:
            path: Location of the image file.

        Returns:
            The text reported by ``pytesseract.image_to_string``, stripped of
            surrounding whitespace.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
        """
        from PIL import Image
        import pytesseract

        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        # Open inside a context manager so the underlying file handle is
        # released as soon as OCR finishes.
        with Image.open(path) as image:
            return pytesseract.image_to_string(image).strip()
# ---------------------------------------------------------------------------
# Public API of this module: the tool classes meant to be imported by the agent.
__all__ = [
    "PythonRunTool",
    "ExcelLoaderTool",
    "YouTubeTranscriptTool",
    "AudioTranscriptionTool",
    "SimpleOCRTool",
]