File size: 5,048 Bytes
52d1305
 
 
 
 
365b711
52d1305
 
 
 
 
 
40ad9f8
 
 
365b711
10acbd8
365b711
8cf4b42
52d1305
365b711
52d1305
 
 
 
 
 
 
 
 
 
 
 
 
 
40ad9f8
 
 
365b711
09fe455
 
 
 
 
365b711
10acbd8
365b711
 
c965083
09fe455
365b711
 
05193be
52d1305
365b711
52d1305
 
 
 
09fe455
c965083
 
52d1305
 
 
 
 
 
 
 
 
 
40ad9f8
 
 
365b711
10acbd8
 
365b711
8cf4b42
52d1305
365b711
52d1305
 
 
 
 
 
 
 
 
40ad9f8
 
 
365b711
10acbd8
 
365b711
8cf4b42
52d1305
365b711
52d1305
 
 
c965083
52d1305
c965083
 
52d1305
 
 
 
40ad9f8
 
 
365b711
10acbd8
365b711
8cf4b42
52d1305
365b711
52d1305
 
 
 
 
 
 
 
 
 
 
 
 
365b711
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Custom tools for smolagents GAIA agent
from __future__ import annotations
import contextlib
import io
import os
from typing import Any, Dict, List

from smolagents import Tool

# ---- 1. PythonRunTool ------------------------------------------------------
class PythonRunTool(Tool):
    """Run a snippet of trusted Python and report what it printed/produced.

    SECURITY NOTE: the snippet is passed to ``exec`` with full builtins and
    no sandboxing. Only use this tool with code you trust.
    """

    name = "python_run"
    description = """
        Execute trusted Python code and return printed output + repr() of the last expression (or _result variable).
    """
    inputs = {
        "code": {"type": "string", "description": "Python code to execute", "required": True}
    }
    output_type = "string"

    def forward(self, code: str) -> str:
        """Execute ``code`` and return its captured stdout plus a result repr.

        The result is the ``_result`` variable when the snippet sets it to a
        non-None value; otherwise it is the value of the snippet's trailing
        expression statement, if there is one. A ``None`` result adds nothing
        to the output.

        Args:
            code: Python source to execute.

        Returns:
            Captured stdout followed by ``repr(result)``, stripped of
            surrounding whitespace.

        Raises:
            RuntimeError: If the snippet fails to parse or raises while
                running; the original exception is chained as the cause.
        """
        import ast

        buf = io.StringIO()
        # A single dict serves as both globals and locals. With two separate
        # mappings the snippet would run like a class body, and functions or
        # comprehensions it defines could not see its own top-level names.
        ns: Dict[str, Any] = {}
        tail_value = None
        try:
            tree = ast.parse(code, "<agent-python>", "exec")
            tail = None
            if tree.body and isinstance(tree.body[-1], ast.Expr):
                # Detach the trailing expression so its value can be captured.
                tail = ast.Expression(body=tree.body.pop().value)
            with contextlib.redirect_stdout(buf):
                exec(compile(tree, "<agent-python>", "exec"), ns)
                if tail is not None:
                    tail_value = eval(compile(tail, "<agent-python>", "eval"), ns)
        except Exception as e:
            raise RuntimeError(f"PythonRunTool error: {e}") from e
        # An explicit _result keeps precedence over the trailing expression.
        last = ns.get("_result")
        if last is None:
            last = tail_value
        out = buf.getvalue()
        return (out + (repr(last) if last is not None else "")).strip()

# ---- 2. ExcelLoaderTool ----------------------------------------------------
class ExcelLoaderTool(Tool):
    """Load a CSV or Excel file from disk as a list of row dictionaries."""

    name = "load_spreadsheet"
    description = """
        Read .xlsx/.xls/.csv from disk and return rows as a list of dictionaries with string keys.
    """
    inputs = {
        "path": {
            "type": "string",
            "description": "Path to .csv/.xls/.xlsx file",
            "required": True
        },
        "sheet": {
            "type": "string",
            "description": "Sheet name or index (optional, required for Excel files only)",
            "required": False,
            "default": "",
            "nullable": True
        }
    }
    output_type = "array"

    def forward(self, path: str, sheet: str | int | None = None) -> List[Dict[str, Any]]:
        """Read the file at ``path`` and return one dict per row.

        Args:
            path: Path to the file. A ``.csv`` extension is read with
                ``pandas.read_csv``; anything else is opened as an Excel
                workbook.
            sheet: Sheet name or zero-based index (Excel only). ``None`` or a
                blank string selects the first sheet. A numeric string is
                used as an index only when no sheet has that exact name.

        Returns:
            The rows as dictionaries whose keys are the column labels
            converted to ``str``.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
        """
        import pandas as pd
        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        ext = os.path.splitext(path)[1].lower()
        if ext == ".csv":
            df = pd.read_csv(path)
        else:
            with pd.ExcelFile(path) as workbook:
                if sheet is None or (isinstance(sheet, str) and not sheet.strip()):
                    # sheet_name=None would return a dict of ALL sheets, which
                    # has no .to_dict(); default to the first sheet instead.
                    target: str | int = 0
                elif (
                    isinstance(sheet, str)
                    and sheet not in workbook.sheet_names
                    and sheet.strip().isdecimal()
                ):
                    # The input schema is string-typed, so an index arrives as
                    # text such as "1"; convert it unless a sheet has that name.
                    target = int(sheet.strip())
                else:
                    target = sheet
                df = workbook.parse(sheet_name=target)
        records = [{str(k): v for k, v in row.items()} for row in df.to_dict(orient="records")]
        return records

# ---- 3. YouTubeTranscriptTool ---------------------------------------------
class YouTubeTranscriptTool(Tool):
    """Fetch the transcript (subtitles) of a YouTube video as plain text."""

    name = "youtube_transcript"
    description = """
        Return the subtitles of a YouTube URL using youtube-transcript-api.
    """
    inputs = {
        "url": {"type": "string", "description": "YouTube URL", "required": True},
        # "nullable" marks the input as optional, matching the default in forward().
        "lang": {"type": "string", "description": "Transcript language (default: en)", "required": False, "default": "en", "nullable": True}
    }
    output_type = "string"

    def forward(self, url: str, lang: str = "en") -> str:
        """Return the transcript text for the video referenced by ``url``.

        Args:
            url: A YouTube URL or a bare video id. The id is taken from the
                ``v`` query parameter when present, otherwise from the last
                non-empty path segment.
            lang: Preferred transcript language code. English variants are
                always tried as fallbacks. ``None`` or ``""`` means ``"en"``.

        Returns:
            All transcript snippets joined by single spaces.

        Raises:
            ValueError: If no video id can be extracted from ``url``.
        """
        from urllib.parse import urlparse, parse_qs
        # Import from the package root (the public entry point), not `_api`.
        from youtube_transcript_api import YouTubeTranscriptApi

        parsed = urlparse(url.strip())
        vid = parse_qs(parsed.query).get("v", [None])[0]
        if not vid:
            # Use the parsed *path* so a query string such as "?t=30" or a
            # trailing slash does not end up inside the video id.
            segments = [seg for seg in parsed.path.split("/") if seg]
            if not segments:
                raise ValueError(f"Could not find a video id in URL: {url!r}")
            vid = segments[-1]
        # dict.fromkeys removes duplicates (e.g. lang == "en") but keeps order.
        languages = list(dict.fromkeys([lang or "en", "en", "en-US", "en-GB"]))
        data = YouTubeTranscriptApi.get_transcript(vid, languages=languages)
        return " ".join(d["text"] for d in data).strip()

# ---- 4. AudioTranscriptionTool --------------------------------------------
class AudioTranscriptionTool(Tool):
    """Transcribe an audio file to text via the OpenAI transcription API."""

    name = "transcribe_audio"
    description = """
        Transcribe an audio file with OpenAI Whisper, returns plain text.
    """
    inputs = {
        "path": {"type": "string", "description": "Path to audio file", "required": True},
        # "nullable" marks the input as optional, matching the default in forward().
        "model": {"type": "string", "description": "Model name for transcription (default: whisper-1)", "required": False, "default": "whisper-1", "nullable": True}
    }
    output_type = "string"

    def forward(self, path: str, model: str = "whisper-1") -> str:
        """Send the audio file at ``path`` for transcription and return the text.

        Args:
            path: Path to the audio file on disk.
            model: Transcription model name. ``None`` or ``""`` means
                ``"whisper-1"``.

        Returns:
            The transcript text with surrounding whitespace stripped.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
        """
        import openai
        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        # The API key is read from the OPENAI_API_KEY environment variable.
        client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        with open(path, "rb") as fp:
            transcript = client.audio.transcriptions.create(model=model or "whisper-1", file=fp)
        return transcript.text.strip()

# ---- 5. SimpleOCRTool ------------------------------------------------------
class SimpleOCRTool(Tool):
    """Extract text from an image file using pytesseract OCR."""

    name = "image_ocr"
    description = """
        Return any text spotted in an image via pytesseract OCR.
    """
    inputs = {
        "path": {"type": "string", "description": "Path to image file", "required": True}
    }
    output_type = "string"

    def forward(self, path: str) -> str:
        """Run OCR on the image at ``path`` and return the recognised text.

        Args:
            path: Path to the image file on disk.

        Returns:
            The OCR output with surrounding whitespace stripped.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
        """
        from PIL import Image
        import pytesseract
        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        # Open via a context manager so the file handle is released promptly.
        with Image.open(path) as image:
            return pytesseract.image_to_string(image).strip()

# ---------------------------------------------------------------------------
# Names exported by a star-import of this module: the five tool classes above.
__all__ = [
    "PythonRunTool",
    "ExcelLoaderTool",
    "YouTubeTranscriptTool",
    "AudioTranscriptionTool",
    "SimpleOCRTool",
]