Spaces:
Sleeping
Sleeping
File size: 5,656 Bytes
4191a9b 52d1305 4191a9b 9623335 4191a9b 9623335 4191a9b 73bb16b 4191a9b 73bb16b 4191a9b 9623335 4191a9b 73bb16b 4191a9b 73bb16b 4191a9b 73bb16b 4191a9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
# Custom tools for smolagents GAIA agent
from __future__ import annotations
import contextlib
import io
import os
from typing import Any, Dict, List
from smolagents import Tool
# ---- 1. PythonRunTool ------------------------------------------------------
class PythonRunTool(Tool):
    """Run a snippet of Python and report what it printed and evaluated to.

    The snippet is executed with ``exec`` in a fresh namespace on every call.
    SECURITY: ``exec`` gives the snippet full interpreter access; this tool
    must only ever receive trusted code.
    """

    name = "python_run"
    description = (
        "\nExecute trusted Python code and return printed output + repr() of the "
        "last expression (or _result variable).\n"
    )
    inputs = {
        "code": {
            "type": "string",
            "description": "Python code to execute",
            "required": True,
        }
    }
    output_type = "string"

    def forward(self, code: str) -> str:
        """Execute *code* and return its captured stdout plus a result repr.

        The result is the ``_result`` variable when the snippet sets it to a
        non-None value; otherwise it is the value of the snippet's trailing
        bare expression, if there is one.

        Args:
            code: Python source to execute.

        Returns:
            Captured stdout followed by ``repr()`` of the result (when not
            None), stripped of surrounding whitespace.

        Raises:
            RuntimeError: If the snippet fails to parse or raises while
                running; the original exception is chained as ``__cause__``.
        """
        import ast

        captured = io.StringIO()
        # One dict serves as both globals and locals.  With separate dicts,
        # functions defined by the snippet resolve names through the (empty)
        # globals and cannot see the snippet's own imports or variables.
        namespace = {}
        value = None
        try:
            tree = ast.parse(code, "<agent-python>", "exec")
            trailing = None
            if tree.body and isinstance(tree.body[-1], ast.Expr):
                # Split off a trailing bare expression so its value can be
                # reported, as the tool description promises.
                trailing = ast.Expression(tree.body.pop().value)
            with contextlib.redirect_stdout(captured):
                exec(compile(tree, "<agent-python>", "exec"), namespace)
                if trailing is not None:
                    value = eval(
                        compile(trailing, "<agent-python>", "eval"), namespace
                    )
            # An explicit _result takes precedence over the trailing value.
            explicit = namespace.get("_result")
            if explicit is not None:
                value = explicit
        except Exception as e:
            raise RuntimeError(
                f"PythonRunTool error: {type(e).__name__}: {e}"
            ) from e

        printed = captured.getvalue()
        pieces = [printed]
        if value is not None:
            if printed and not printed.endswith("\n"):
                # Keep the repr from being glued onto the last printed text.
                pieces.append("\n")
            pieces.append(repr(value))
        return "".join(pieces).strip()
# ---- 2. ExcelLoaderTool ----------------------------------------------------
class ExcelLoaderTool(Tool):
    """Load a CSV or Excel file and return its rows as stringified records."""

    name = "load_spreadsheet"
    description = (
        "\nRead .xlsx/.xls/.csv from disk and return rows as a list of "
        "dictionaries with string keys.\n"
    )
    inputs = {
        "path": {
            "type": "string",
            "description": "Path to .csv/.xls/.xlsx file",
            "required": True,
        },
        "sheet": {
            "type": "string",
            "description": "Sheet name or index (optional, required for Excel files only)",
            "required": False,
            "default": "",
            "nullable": True,
        },
    }
    # NOTE(review): forward() returns str(records), not a list, while this
    # declares "array" -- confirm which of the two the consumers rely on.
    output_type = "array"

    def forward(self, path: str, sheet: str | int | None = None) -> str:
        """Read *path* and return ``str()`` of its rows as a list of dicts.

        Args:
            path: File to read.  ``.csv`` goes through ``pandas.read_csv``;
                any other extension is opened as an Excel workbook.
            sheet: Excel sheet name or index.  ``None`` or ``""`` selects the
                first sheet.  A numeric string that is not an existing sheet
                name is treated as a zero-based index.  Ignored for CSV.

        Returns:
            ``str()`` of a list holding one dict per row, with column names
            converted to strings.

        Raises:
            FileNotFoundError: If *path* is not an existing file.
        """
        import pandas as pd

        if not os.path.isfile(path):
            raise FileNotFoundError(path)

        if os.path.splitext(path)[1].lower() == ".csv":
            frame = pd.read_csv(path)
        else:
            with pd.ExcelFile(path) as workbook:
                target = 0  # no sheet requested: read only the first one
                if sheet is not None and sheet != "":
                    target = sheet
                    if isinstance(sheet, str) and sheet not in workbook.sheet_names:
                        # The input schema is string-typed, so an index
                        # arrives as e.g. "0"; real sheet names win above.
                        candidate = sheet.strip()
                        if candidate.isdecimal():
                            target = int(candidate)
                frame = workbook.parse(sheet_name=target)

        records = [
            {str(column): cell for column, cell in row.items()}
            for row in frame.to_dict(orient="records")
        ]
        # Always return a string
        return str(records)
# ---- 3. YouTubeTranscriptTool ---------------------------------------------
class YouTubeTranscriptTool(Tool):
    """Fetch the subtitles of a YouTube video as one plain-text string."""

    name = "youtube_transcript"
    description = (
        "\nReturn the subtitles of a YouTube URL using youtube-transcript-api.\n"
    )
    inputs = {
        "url": {
            "type": "string",
            "description": "YouTube URL",
            "required": True,
        },
        "lang": {
            "type": "string",
            "description": "Transcript language (default: en)",
            "required": False,
            "default": "en",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, url: str, lang: str = "en") -> str:
        """Return the transcript text for the video referenced by *url*.

        Args:
            url: A YouTube URL (``watch?v=ID``, ``youtu.be/ID``, ``/shorts/ID``,
                ``/embed/ID``) or a bare video id.
            lang: Preferred transcript language; English variants are always
                tried as fallbacks.  ``None`` or ``""`` means English only.

        Returns:
            All transcript snippets joined with single spaces, stripped.

        Raises:
            ValueError: If no video id can be extracted from *url*.
        """
        from urllib.parse import parse_qs, urlparse

        from youtube_transcript_api import YouTubeTranscriptApi

        parsed = urlparse(url.strip())
        video_id = parse_qs(parsed.query).get("v", [None])[0]
        if not video_id:
            # Take the last path segment instead of splitting the raw URL, so
            # a query string ("youtu.be/ID?t=30") or trailing slash is not
            # mistaken for part of the id.
            segments = [part for part in parsed.path.split("/") if part]
            video_id = segments[-1] if segments else ""
        if not video_id:
            raise ValueError(f"Could not extract a YouTube video id from: {url!r}")

        # Drop a missing language and de-duplicate while keeping priority order.
        languages = list(
            dict.fromkeys(code for code in (lang, "en", "en-US", "en-GB") if code)
        )

        if hasattr(YouTubeTranscriptApi, "get_transcript"):
            entries = YouTubeTranscriptApi.get_transcript(video_id, languages=languages)
        else:
            # NOTE(review): newer youtube-transcript-api releases appear to
            # expose only the instance API -- confirm against the pinned version.
            entries = (
                YouTubeTranscriptApi()
                .fetch(video_id, languages=languages)
                .to_raw_data()
            )

        return " ".join(entry["text"] for entry in entries).strip()
# ---- 4. AudioTranscriptionTool --------------------------------------------
class AudioTranscriptionTool(Tool):
    """Transcribe an audio file through the OpenAI transcription endpoint."""

    name = "transcribe_audio"
    # The stray trailing double quote in the original description is removed.
    description = (
        "\nTranscribe an audio file with OpenAI Whisper, returns plain text.\n"
    )
    inputs = {
        "path": {
            "type": "string",
            "description": "Path to audio file",
            "required": True,
        },
        "model": {
            "type": "string",
            "description": "Model name for transcription (default: whisper-1)",
            "required": False,
            "default": "whisper-1",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, path: str, model: str = "whisper-1") -> str:
        """Return the transcription of the audio file at *path*.

        Args:
            path: Audio file to upload for transcription.
            model: Transcription model name.  ``None`` or ``""`` falls back to
                ``"whisper-1"``, because the input schema marks it nullable.

        Returns:
            The transcript text, stripped of surrounding whitespace.

        Raises:
            FileNotFoundError: If *path* is not an existing file.
        """
        import openai

        if not os.path.isfile(path):
            raise FileNotFoundError(path)

        # The API key is read from the OPENAI_API_KEY environment variable.
        client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        with open(path, "rb") as audio_file:
            transcript = client.audio.transcriptions.create(
                model=model or "whisper-1",
                file=audio_file,
            )
        return str(transcript.text.strip())
# ---- 5. SimpleOCRTool ------------------------------------------------------
class SimpleOCRTool(Tool):
    """Extract text from an image file with pytesseract."""

    name = "image_ocr"
    description = "\nReturn any text spotted in an image via pytesseract OCR.\n"
    inputs = {
        "path": {
            "type": "string",
            "description": "Path to image file",
            "required": True,
        }
    }
    output_type = "string"

    def forward(self, path: str) -> str:
        """Return the text that pytesseract recognises in the image at *path*.

        Args:
            path: Image file to run OCR on.

        Returns:
            The recognised text, stripped of surrounding whitespace.

        Raises:
            FileNotFoundError: If *path* is not an existing file.
        """
        from PIL import Image
        import pytesseract

        if not os.path.isfile(path):
            raise FileNotFoundError(path)

        # Open inside a context manager so the file handle is released even
        # when OCR raises; the original never closed the image.
        with Image.open(path) as image:
            text = pytesseract.image_to_string(image)
        return str(text.strip())
# ---------------------------------------------------------------------------
# Names exported by a star-import of this module: the five tool classes
# defined above.
__all__ = [
"PythonRunTool",
"ExcelLoaderTool",
"YouTubeTranscriptTool",
"AudioTranscriptionTool",
"SimpleOCRTool",
]
|